import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os,sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from base import env1, Memory
import time
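# Deep Deterministic Policy Gradient (DDPG) for continuous control on env1,
# which is assumed to be Pendulum-v0 (imported from base along with the
# replay-buffer/logging helper Memory).

# Actor: deterministic policy network mapping a state to a continuous action.
# The tanh output is scaled by 2 to cover Pendulum's [-2, 2] torque range.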
class Actor(nn.Module):
def __init__(self,n_state,n_action):
super(Actor, self).__init__()
self.seq = nn.Sequential(
nn.Linear(n_state,30),
nn.ReLU(),
nn.Linear(30,n_action)
)
        # Initialise the Linear layers with N(0, 0.1); iterate over modules(),
        # since parameters() yields tensors and would never match nn.Linear.
        for m in self.seq.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.1)
def forward(self,state):
x = self.seq(state)
x = torch.tanh(x)*2.
return x
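# Critic: Q-network scoring a (state, action) pair. State and action are
# embedded separately, concatenated, and passed through two further layers.
# The output width is n_action, which is 1 for Pendulum, i.e. a scalar Q-value.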
class Critic(nn.Module):
def __init__(self,n_state,n_action):
super(Critic, self).__init__()
self.fc1 = nn.Linear(n_state,30)
self.fc1.weight.data.normal_(0,0.1)
self.fc2 = nn.Linear(n_action,30)
self.fc2.weight.data.normal_(0,0.1)
self.fc3 = nn.Linear(60,60)
self.fc3.weight.data.normal_(0,0.1)
self.out = nn.Linear(60,n_action)
self.out.weight.data.normal_(0,0.1)
def forward(self,state,action):
x1 = self.fc1(state)
x2 = self.fc2(action)
x = torch.cat([x1,x2],dim=1)
x = self.fc3(F.relu(x))
x = self.out(F.relu(x))
return x
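# DDPG agent: an actor-critic pair plus slowly-tracking target copies of each,
# trained from the replay buffer provided by the Memory base class.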
class DDPG(Memory):
def __init__(self,n_state,n_action,
explore_var=3.,explore_var_decay=0.9995,
gamma=0.9,TAU=0.01,actor_lr=1e-3,critic_lr=0.002,
capacity=10000,logs='./logs'):
super(DDPG,self).__init__(capacity,logs)
        self.actor = Actor(n_state, n_action).to(self.device)
        self.actor_target = Actor(n_state, n_action).to(self.device)
        # Standard DDPG initialisation: targets start from the online weights.
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic = Critic(n_state, n_action).to(self.device)
        self.critic_target = Critic(n_state, n_action).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
self.explore_var = explore_var
self.explore_var_decay = explore_var_decay
self.gamma = gamma
self.TAU = TAU
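    # Greedy action from the actor plus zero-mean Gaussian exploration noise;
    # the noise std (explore_var) decays on every call, including during the
    # warm-up sampling phase. .item() assumes a one-dimensional action space.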
def choose_action(self,state):
self.explore_var *= self.explore_var_decay
state = torch.tensor(state,dtype=torch.float).unsqueeze(0).to(self.device)
action = self.actor(state).item()
action = np.random.normal(action,self.explore_var)
action = np.clip(action,env1.action_space.low,env1.action_space.high)
return action
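    # One DDPG update on a sampled batch:
    #   1. critic: regress Q(s, a) toward r + gamma * (1 - done) * Q'(s', mu'(s'))
    #   2. actor:  maximise Q(s, mu(s)) by minimising its negative mean
    #   3. soft-update both target networks with coefficient TAU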
def optimize(self,batch):
s = torch.tensor(batch.s,dtype=torch.float).to(self.device)
a = torch.tensor(batch.a,dtype=torch.float).to(self.device)
r = torch.tensor(batch.r,dtype=torch.float).to(self.device)
s_ = torch.tensor(batch.s_, dtype=torch.float).to(self.device)
done = torch.tensor(batch.done, dtype=torch.float).to(self.device)
        q_predicted = self.critic(s, a)
        # Build the TD target from the target networks without tracking
        # gradients; F.mse_loss already returns the batch mean.
        with torch.no_grad():
            q_expected = self.critic_target(s_, self.actor_target(s_))
            q_expected = r + self.gamma * (1 - done) * q_expected
        critic_loss = F.mse_loss(q_predicted, q_expected)
self.writer.add_scalar('loss_critic',critic_loss.item(),self.step)
self.critic_opt.zero_grad()
critic_loss.backward()
self.critic_opt.step()
        # Each loss computation and its backward pass must stay together, otherwise the gradients get overwritten.
actor_loss = -1.*self.critic(s,self.actor(s)).mean()
self.writer.add_scalar('loss_actor',actor_loss.item(),self.step)
self.actor_opt.zero_grad()
actor_loss.backward()
self.actor_opt.step()
for p,tp in zip(self.critic.parameters(),self.critic_target.parameters()):
tp.data.copy_(self.TAU*p.data + (1-self.TAU)*tp.data)
for p,tp in zip(self.actor.parameters(),self.actor_target.parameters()):
tp.data.copy_(self.TAU*p.data + (1-self.TAU)*tp.data)
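# Training script (old gym API: env.seed() and the 4-tuple env.step()):
# first fill the replay buffer with exploratory rollouts, then run `epoch`
# episodes, updating the networks once per environment step.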
if __name__=='__main__':
epoch = 200
    max_step_per_epoch = 200
# update_epoch = 10
batch_size = 64
env = env1
env.seed(int(time.time()))
model = DDPG(env.observation_space.shape[0],env.action_space.shape[0])
print("Sampling...")
count = 0
while count < model.capacity:
state = env.reset()
for _ in range(max_step_per_epoch):
action = model.choose_action(state)
state_, reward, done, info = env.step(action)
model.put_transition(state, action, [reward], state_, [done])
state = state_
count += 1
count = 0
for ep in range(epoch):
epoch_r = 0
state = env.reset()
for st in range(max_step_per_epoch):
action = model.choose_action(state)
state_,reward,done,info = env.step(action)
model.put_transition(state,action,[reward],state_,[done])
state = state_
batch = model.get_transition(batch_size=batch_size)
model.optimize(batch)
env.render()
count += 1
epoch_r += reward
print('Epoch:[{}/{}],step:[{}/{}],ep_r:{:.4f}'.format(
ep + 1, epoch, st + 1, max_step_per_epoch,epoch_r))
model.writer.add_scalar('ep_r',epoch_r,count)
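        # epoch_r is logged once per episode; assuming self.writer is a
        # TensorBoard SummaryWriter created by Memory with logs='./logs',
        # the curves can be inspected with:  tensorboard --logdir ./logs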