代码拉取完成,页面将自动刷新
import collections
import numpy as np
# One stored experience record. The namedtuple typename must equal the
# module-level name it is bound to: pickle serialises classes by qualified
# name, so a mismatched typename makes instances unpicklable (breaking e.g.
# multiprocessing or saving the buffer to disk).
# Field order matters: ReplayBuffer.sample unpacks records positionally.
# NOTE(review): sample() names the last field `dones`, while here it is
# `is_game_on` -- confirm which polarity producers actually store.
Transition = collections.namedtuple('Transition', field_names=['state', 'action', 'next_state', 'reward', 'is_game_on'])
class ReplayBuffer:
    """Fixed-capacity experience replay buffer with uniform random sampling."""

    def __init__(self, capacity):
        """Initialise an empty buffer that holds at most ``capacity`` records.

        Backed by ``collections.deque(maxlen=capacity)``: once the buffer is
        full, each ``push`` silently discards the oldest stored record.
        """
        self.buffer = collections.deque(maxlen=capacity)

    def push(self, transition):
        """Store one experience record.

        ``sample`` unpacks every record into five fields, so ``transition``
        should be a 5-element sequence such as a ``Transition``.
        """
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Sample ``batch_size`` distinct records uniformly at random.

        Returns:
            A 5-tuple of numpy arrays ``(states, actions, next_states,
            rewards, dones)``. Each array is built with ``np.array`` from the
            corresponding field of the sampled records, in sampling order.
            When ``batch_size`` is 0, all five arrays are empty.

        Raises:
            ValueError: from ``np.random.choice`` when ``batch_size`` is larger
                than the number of stored records (sampling is without
                replacement).
        """
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        batch = [self.buffer[idx] for idx in indices]
        if not batch:
            # zip(*[]) produces no columns, so unpacking it into five names
            # would raise; return five empty arrays instead.
            return tuple(np.array([]) for _ in range(5))
        states, actions, next_states, rewards, dones = zip(*batch)
        return np.array(states), np.array(actions), np.array(next_states), np.array(rewards), np.array(dones)

    def __len__(self):
        """Return the number of records currently stored."""
        return len(self.buffer)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。