代码拉取完成,页面将自动刷新
同步操作将从 元原子/强化学习之超级马里奥 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import os.path
import torch
import torch.nn.functional as 火炬函数
from 工具屋.参数室 import 参数
from 工具屋.环境室 import 创建训练环境
from 模型构建屋.模型室 import 行动者和评论家类
def 测试():
torch.cuda.manual_seed(123)
# 临时路径拼接 = os.path.join(参数.视频输出的路径, "sp_{}_{}.mp4".format(参数.世界号, 参数.舞台号))
# 环境, 舞台数量, 动作数量 = 创建训练环境(参数.世界号, 参数.舞台号, 参数.操作模式, 临时路径拼接)
环境, 舞台数量, 动作数量 = 创建训练环境(参数.世界号, 参数.舞台号, 参数.操作模式, 渲染模式="human")
模型 = 行动者和评论家类(舞台数量, 动作数量)
临时路径拼接 = os.path.join(参数.保存的路径, "超级马里奥_{}_{}_已完成".format(参数.世界号, 参数.舞台号))
if 参数.是否使用图像处理单元:
模型.load_state_dict(torch.load(临时路径拼接))
模型.cuda()
else:
模型.load_state_dict(torch.load(临时路径拼接))
模型.eval()
状态 = torch.from_numpy(环境.reset())
完毕 = True
while True:
if 完毕:
隐藏的状态_零张量 = torch.zeros((1, 512), dtype=torch.float)
单元的状态_零张量 = torch.zeros((1, 512), dtype=torch.float)
环境.reset()
else:
隐藏的状态_零张量 = 隐藏的状态_零张量.detach()
单元的状态_零张量 = 单元的状态_零张量.detach()
if 参数.是否使用图像处理单元:
状态 = 状态.cuda()
隐藏的状态_零张量 = 隐藏的状态_零张量.cuda()
单元的状态_零张量 = 单元的状态_零张量.cuda()
动作_策略, 预期值, 隐藏的状态_零张量, 单元的状态_零张量 = 模型(状态, 隐藏的状态_零张量, 单元的状态_零张量)
策略 = 火炬函数.softmax(动作_策略, dim=1)
# for i in 策略[0]:
# print(round(i.item(),4))
动作 = torch.argmax(策略).item()
动作 = int(动作)
状态, 奖励, 完毕, _, 信息 = 环境.step(动作)
状态 = torch.from_numpy(状态)
环境.render()
if 信息["flag_get"]:
print("世界 {},舞台 {},执行完毕".format(参数.世界号, 参数.舞台号))
break
if __name__ == '__main__':
测试()
pass
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。