1 Star 0 Fork 0

laiyijun2023/MazeCodeRepo

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
replay_buffer.py 847 Bytes
一键复制 编辑 原始数据 按行查看 历史
Yijun Lai 提交于 2024-05-14 19:40 . add all files
import collections
import numpy as np
Transition = collections.namedtuple('Experience', field_names=['state', 'action', 'next_state', 'reward', 'is_game_on'])
class ReplayBuffer:
def __init__(self, capacity):
"""回放缓冲区初始化"""
self.buffer = collections.deque(maxlen=capacity)
def push(self, transition):
"""存储经验"""
self.buffer.append(transition)
def sample(self, batch_size):
"""采样经验"""
indices = np.random.choice(len(self.buffer), batch_size, replace=False)
states, actions, next_states, rewards, dones = zip(*[self.buffer[idx] for idx in indices])
return np.array(states), np.array(actions), np.array(next_states), np.array(rewards), np.array(dones)
def __len__(self):
"""缓冲区长度"""
return len(self.buffer)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/laiyijun2023/maze-code-repo.git
git@gitee.com:laiyijun2023/maze-code-repo.git
laiyijun2023
maze-code-repo
MazeCodeRepo
master

搜索帮助