import numpy as np
import torch


class ReplayBuffer(object):
    """Fixed-size circular buffer storing off-policy RL transitions."""

    def __init__(self, state_dim, action_dim, max_size=int(1e6)):
        self.max_size = max_size
        self.ptr = 0   # index of the next slot to write
        self.size = 0  # number of valid transitions currently stored

        self.state = np.zeros((max_size, state_dim))
        self.action = np.zeros((max_size, action_dim))
        self.reward = np.zeros((max_size, 1))
        self.next_state = np.zeros((max_size, state_dim))
        self.dead = np.zeros((max_size, 1))  # terminal flag: 0, 0, ..., 1

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def add(self, state, action, reward, next_state, dead):
        # Write into the current slot; once full, the oldest entry is overwritten.
        self.state[self.ptr] = state
        self.action[self.ptr] = action
        self.reward[self.ptr] = reward
        self.next_state[self.ptr] = next_state
        self.dead[self.ptr] = dead  # 0, 0, 0, ..., 1

        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        # Uniformly sample indices from the filled portion of the buffer.
        ind = np.random.randint(0, self.size, size=batch_size)
        return (
            torch.FloatTensor(self.state[ind]).to(self.device),
            torch.FloatTensor(self.action[ind]).to(self.device),
            torch.FloatTensor(self.reward[ind]).to(self.device),
            torch.FloatTensor(self.next_state[ind]).to(self.device),
            torch.FloatTensor(self.dead[ind]).to(self.device),
        )
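

# A minimal usage sketch (not part of the original file): how a buffer like this
# is typically driven by an off-policy training loop, and how the `dead` column
# masks bootstrapping in the TD target. The random-transition filling and the
# `GAMMA` / critic_target / actor_target names below are hypothetical placeholders,
# not defined in this repository.
if __name__ == "__main__":
    buffer = ReplayBuffer(state_dim=3, action_dim=1)

    # Fill the buffer with random transitions (stand-in for real env interaction).
    state = np.random.randn(3)
    for _ in range(1000):
        action = np.random.uniform(-1.0, 1.0, size=1)
        next_state = np.random.randn(3)
        reward = float(np.random.randn())
        dead = float(np.random.rand() < 0.01)  # 1.0 only on the terminal step
        buffer.add(state, action, reward, next_state, dead)
        state = np.random.randn(3) if dead else next_state

    s, a, r, s_next, d = buffer.sample(batch_size=256)
    # Typical TD target: no bootstrapping from terminal next states, e.g.
    # target_q = r + (1.0 - d) * GAMMA * critic_target(s_next, actor_target(s_next))
    print(s.shape, a.shape, r.shape, s_next.shape, d.shape)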