vjk0909/TD3-BipedalWalkerHardcore-v2

ReplayBuffer.py 1.22 KB
XinJingHao committed on 2020-12-12 01:15: Add files via upload
import numpy as np
import torch


class ReplayBuffer(object):
    """Fixed-size circular buffer of (state, action, reward, next_state, dead) transitions."""

    def __init__(self, state_dim, action_dim, max_size=int(1e6)):
        self.max_size = max_size
        self.ptr = 0   # index of the next slot to write
        self.size = 0  # number of transitions currently stored

        # Preallocate one row per transition.
        self.state = np.zeros((max_size, state_dim))
        self.action = np.zeros((max_size, action_dim))
        self.reward = np.zeros((max_size, 1))
        self.next_state = np.zeros((max_size, state_dim))
        self.dead = np.zeros((max_size, 1))

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def add(self, state, action, reward, next_state, dead):
        # Write at ptr; once the buffer is full this overwrites the oldest transition.
        self.state[self.ptr] = state
        self.action[self.ptr] = action
        self.reward[self.ptr] = reward
        self.next_state[self.ptr] = next_state
        self.dead[self.ptr] = dead  # flags over an episode: 0, 0, 0, ..., 1

        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        # Draw indices uniformly at random (with replacement) from the filled slots.
        ind = np.random.randint(0, self.size, size=batch_size)

        return (
            torch.FloatTensor(self.state[ind]).to(self.device),
            torch.FloatTensor(self.action[ind]).to(self.device),
            torch.FloatTensor(self.reward[ind]).to(self.device),
            torch.FloatTensor(self.next_state[ind]).to(self.device),
            torch.FloatTensor(self.dead[ind]).to(self.device)
        )
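
The write logic in `add` is a ring buffer: `ptr` wraps around with a modulo, and `size` saturates at `max_size`. The sketch below replays only that index arithmetic with a plain Python list, so the overwrite order is easy to see. The helper name `trace_ring_writes` is hypothetical and not part of ReplayBuffer.py; it stores step numbers instead of real transitions.

```python
# Hypothetical helper (not in ReplayBuffer.py): mirrors the ptr/size updates
# of ReplayBuffer.add, using a list of step numbers in place of the arrays.
def trace_ring_writes(max_size, num_adds):
    slots = [None] * max_size  # slots[i] = step number last written to slot i
    ptr, size = 0, 0
    for step in range(num_adds):
        slots[ptr] = step                # same slot choice as add()
        ptr = (ptr + 1) % max_size       # wrap back to 0 after the last slot
        size = min(size + 1, max_size)   # stop growing once the buffer is full
    return slots, ptr, size


slots, ptr, size = trace_ring_writes(max_size=3, num_adds=5)
print(slots, ptr, size)  # [3, 4, 2] 2 3
```

With a capacity of 3 and 5 writes, steps 0 and 1 have been overwritten by steps 3 and 4, and the next write would replace step 2. Because `sample` draws indices from `[0, size)`, it only ever sees slots that have been filled.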