Jill/pytorch-maddpg

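This repository does not declare an open-source license file (LICENSE); before use, check the project description and the licenses of its upstream code dependencies.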
model.py 1.25 KB
xuehy committed on 2017-07-25 09:48: first commit
import torch as th
import torch.nn as nn
import torch.nn.functional as F


class Critic(nn.Module):
    def __init__(self, n_agent, dim_observation, dim_action):
        super(Critic, self).__init__()
        self.n_agent = n_agent
        self.dim_observation = dim_observation
        self.dim_action = dim_action
        obs_dim = dim_observation * n_agent
        act_dim = self.dim_action * n_agent

        self.FC1 = nn.Linear(obs_dim, 1024)
        self.FC2 = nn.Linear(1024+act_dim, 512)
        self.FC3 = nn.Linear(512, 300)
        self.FC4 = nn.Linear(300, 1)

    # obs: batch_size * obs_dim
    # acts: batch_size * act_dim
    def forward(self, obs, acts):
        result = F.relu(self.FC1(obs))
        # actions are concatenated after the first hidden layer
        combined = th.cat([result, acts], 1)
        result = F.relu(self.FC2(combined))
        return self.FC4(F.relu(self.FC3(result)))


class Actor(nn.Module):
    def __init__(self, dim_observation, dim_action):
        super(Actor, self).__init__()
        self.FC1 = nn.Linear(dim_observation, 500)
        self.FC2 = nn.Linear(500, 128)
        self.FC3 = nn.Linear(128, dim_action)

    # tanh bounds each action component to [-1, 1]; scaling to a wider
    # range (e.g. -2 to 2) is not done in this module
    def forward(self, obs):
        result = F.relu(self.FC1(obs))
        result = F.relu(self.FC2(result))
        # th.tanh replaces the deprecated F.tanh
        result = th.tanh(self.FC3(result))
        return result
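
A minimal sketch of how the two networks above might be wired together, inferred only from their layer dimensions: each agent's Actor maps its own observation to an action, and the single Critic scores the concatenated observations and actions of all agents. The sizes (`n_agent = 3`, `dim_obs = 10`, `dim_act = 2`, `batch = 4`) and the random input are hypothetical, and the classes are repeated in compact form only so the block runs on its own. This is not the repository's training code.

```python
import torch as th
import torch.nn as nn
import torch.nn.functional as F


class Critic(nn.Module):
    def __init__(self, n_agent, dim_observation, dim_action):
        super().__init__()
        obs_dim = dim_observation * n_agent
        act_dim = dim_action * n_agent
        self.FC1 = nn.Linear(obs_dim, 1024)
        self.FC2 = nn.Linear(1024 + act_dim, 512)
        self.FC3 = nn.Linear(512, 300)
        self.FC4 = nn.Linear(300, 1)

    def forward(self, obs, acts):
        result = F.relu(self.FC1(obs))
        combined = th.cat([result, acts], 1)
        result = F.relu(self.FC2(combined))
        return self.FC4(F.relu(self.FC3(result)))


class Actor(nn.Module):
    def __init__(self, dim_observation, dim_action):
        super().__init__()
        self.FC1 = nn.Linear(dim_observation, 500)
        self.FC2 = nn.Linear(500, 128)
        self.FC3 = nn.Linear(128, dim_action)

    def forward(self, obs):
        result = F.relu(self.FC1(obs))
        result = F.relu(self.FC2(result))
        return th.tanh(self.FC3(result))


# Hypothetical sizes, chosen only for illustration.
n_agent, dim_obs, dim_act, batch = 3, 10, 2, 4

actors = [Actor(dim_obs, dim_act) for _ in range(n_agent)]
critic = Critic(n_agent, dim_obs, dim_act)

# Random stand-in for a batch of per-agent observations.
obs = th.randn(batch, n_agent, dim_obs)

with th.no_grad():
    # Each actor sees only its own agent's observation slice.
    acts = th.stack([actors[i](obs[:, i]) for i in range(n_agent)], dim=1)
    # The critic sees all observations and all actions, flattened per sample.
    q = critic(obs.reshape(batch, -1), acts.reshape(batch, -1))

print(acts.shape)  # torch.Size([4, 3, 2])
print(q.shape)     # torch.Size([4, 1])
```

The flattening step matters: the Critic's first layer expects `dim_observation * n_agent` input features, and its second layer expects `1024 + dim_action * n_agent`, so both tensors must be reshaped to two dimensions before the call.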
https://gitee.com/lxqbupt/pytorch-maddpg.git
git@gitee.com:lxqbupt/pytorch-maddpg.git