import os
import sys

# Make sibling modules (e.g. base.py) importable when this file is run directly.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from base import Memory, env0, env1
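
# A minimal DQN setup: the network below approximates Q(s, a), and training
# (in Model.optimize) regresses Q_eval(s, a) toward the bootstrapped target
# r + max_a' Q_target(s', a') for non-terminal transitions.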
class DQN(nn.Module):
    """A small two-layer MLP mapping a state vector to one Q-value per action."""
    def __init__(self, n_state, n_action):
        super(DQN, self).__init__()
        self.seq = nn.Sequential(
            nn.Linear(n_state, 20),
            nn.ReLU(),
            nn.Linear(20, n_action)
        )
        # Initialize the linear layers with small Gaussian weights.
        # Note: iterate over modules(), not parameters() -- parameters are
        # tensors and would never match isinstance(m, nn.Linear).
        for m in self.seq.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.1)

    def forward(self, state):
        return self.seq(state)
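
# Model extends Memory from the local `base` module (not shown here). Judging
# from usage, Memory provides the replay buffer (capacity, put_transition,
# get_transition), a torch `device`, a TensorBoard-style `writer`, and a
# global `step` counter -- this interface is inferred from usage, not
# confirmed by base.py itself.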
class Model(Memory):
    def __init__(self, n_state, n_action,
                 lr=1e-3, epsilon=0.5, epsilon_decay=0.9996,
                 capacity=10000, logs='./logs'):
        super(Model, self).__init__(capacity, logs)
        self.n_state = n_state
        self.n_action = n_action
        self.eval_net = DQN(n_state, n_action).to(self.device)
        self.target_net = DQN(n_state, n_action).to(self.device)
        self.opt = torch.optim.Adam(self.eval_net.parameters(), lr=lr)
        self.epsilon = epsilon              # epsilon-greedy exploration rate
        self.epsilon_decay = epsilon_decay  # multiplicative decay per action
    def optimize(self, batch):
        s = torch.tensor(batch.s, dtype=torch.float).to(self.device)
        a = torch.tensor(batch.a, dtype=torch.long).to(self.device)
        r = torch.tensor(batch.r, dtype=torch.float).to(self.device)
        s_ = torch.tensor(batch.s_, dtype=torch.float).to(self.device)
        done = torch.tensor(batch.done, dtype=torch.float).to(self.device)
        # Q(s, a) for the actions that were actually taken.
        q_predicted = self.eval_net(s).gather(dim=1, index=a)
        # TD target: r + max_a' Q_target(s', a'), zeroed at terminal states.
        # detach() keeps gradients from flowing into the target network.
        # (No discount factor here, i.e. gamma = 1; episodes terminate, so the
        # done mask keeps the target bounded.)
        q_expected, _ = torch.max(self.target_net(s_).detach(), dim=1, keepdim=True)
        q_expected = r + (1 - done) * q_expected
        loss = F.mse_loss(q_predicted, q_expected)  # mse_loss already returns the mean
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        self.writer.add_scalar('loss', loss.item(), self.step)
        return
    def choose_action(self, state):
        # Epsilon-greedy with exponential decay: explore less as training proceeds.
        self.epsilon *= self.epsilon_decay
        if np.random.uniform(0, 1) < self.epsilon:
            action = np.random.randint(0, self.n_action)
        else:
            state = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(self.device)
            with torch.no_grad():
                action = torch.argmax(self.eval_net(state), dim=1).item()
        self.writer.add_scalar('action', action, self.step)
        return action
    def hard_update(self):
        # Copy the eval network's weights into the target network.
        self.target_net.load_state_dict(self.eval_net.state_dict())

    def save_model(self, ckpt_dir='./logs'):
        # Save the state_dict, not parameters(): load_state_dict expects a dict.
        torch.save(self.eval_net.state_dict(), os.path.join(ckpt_dir, 'model.pth'))

    def load_model(self, ckpt_dir='./logs'):
        ckpt = os.path.join(ckpt_dir, 'model.pth')
        if os.path.exists(ckpt):
            self.eval_net.load_state_dict(torch.load(ckpt))
            self.hard_update()
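
# Script flow: seed the environment, warm up the replay buffer, then train,
# syncing the target network every `update_epoch` epochs.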
if __name__ == '__main__':
    epoch = 1000
    max_step_per_epoch = 1000
    update_epoch = 10   # hard-update the target network every 10 epochs
    batch_size = 64
    env = env0
    env.seed(int(time.time()))
    model = Model(env.observation_space.shape[0], env.action_space.n)
    model.load_model()
    # Warm-up: fill the replay buffer before training. Random actions are used
    # here (a standard warm-up choice) so that epsilon is not decayed away
    # before learning starts.
    print("Sampling...")
    count = 0
    state = env.reset()
    while count < model.capacity:
        action = np.random.randint(0, env.action_space.n)
        state_, reward, done, info = env.step(action)
        model.put_transition(state, [action], [reward], state_, [done])
        state = state_
        count += 1
        if done:
            state = env.reset()
    # Training: one gradient step per environment step, logging episode length
    # as a proxy for episode reward.
    for ep in range(epoch):
        state = env.reset()
        for st in range(max_step_per_epoch):
            action = model.choose_action(state)
            state_, reward, done, info = env.step(action)
            model.put_transition(state, [action], [reward], state_, [done])
            state = state_
            batch = model.get_transition(batch_size=batch_size)
            model.optimize(batch)
            env.render()
            print('Epoch:[{}/{}],step:[{}/{}]'.format(ep + 1, epoch, st + 1, max_step_per_epoch))
            count += 1
            if done:
                model.writer.add_scalar('ep_r', st + 1, count)
                break
        if (ep + 1) % update_epoch == 0:
            model.hard_update()
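
    # Persist the trained weights so a later run can resume via load_model().
    model.save_model()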