rizheng213/stock_robot — learn_ppo.py (2.21 KB)
https://gitee.com/rizheng213/stock_robot.git
Committed by 邹吉华 on 2023-04-12 16:27, message "1.6"
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3 import PPO
from stock_env import StockEnv
from save_model import SaveModelCallback
import torch as th
TB_LOG_PATH = "../tb_log"
MODEL_PATH = "./model/ppo"
def make_env(rank, seed=0):
    """
    Utility function for a multiprocessed env.

    :param rank: (int) index of the subprocess
    :param seed: (int) the initial seed for RNG
    """
    def _init():
        env = Monitor(StockEnv(range(2011, 2021)), MODEL_PATH + '/' + str(rank))
        env.seed(seed + rank)
        return env
    set_random_seed(seed)
    return _init
'''
Hyperparameter search record (Optuna):
mean_reward 193.5548455
[I 2022-08-19 17:59:21,513] Trial 16 finished with value: 193.5548455 and parameters:
{'na_num': 6, '0': 32, '1': 128, '2': 128, '3': 768, '4': 32, '5': 32,
 'n_steps': 2048, 'gamma': 0.8590208698070629, 'learning_rate': 4.044243543027443e-05,
 'clip_range': 0.23308626459396495, 'gae_lambda': 0.8641983426984212}.
Best is trial 16 with value: 193.5548455.
'''
def optimize_param():
    """Return the fixed PPO hyperparameters taken from the Optuna trial above."""
    policy = dict(
        activation_fn=th.nn.ReLU,
        net_arch=[32, 128, 128, 768, 32, 32],
    )
    return {
        'n_steps': 2048,
        'gamma': 0.8590208698070629,
        'learning_rate': 4.044243543027443e-05,
        'clip_range': 0.23308626459396495,
        'gae_lambda': 0.8641983426984212,
        'policy_kwargs': policy,
    }
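# The hyperparameters above were presumably found with an Optuna study over
# the PPO configuration. A minimal sketch of such a study follows; the search
# spaces, the shortened training budget, and the helper name tune_params()
# are assumptions, not the original tuning code.
def tune_params(n_trials=20):
    import optuna  # optional dependency, only needed when re-running the search

    def objective(trial):
        # Mirror the parameter names in the trial record: 'na_num' is the
        # number of hidden layers, '0'..'na_num-1' their widths.
        na_num = trial.suggest_int('na_num', 2, 8)
        net_arch = [trial.suggest_categorical(str(i), [32, 64, 128, 256, 768])
                    for i in range(na_num)]
        params = {
            'n_steps': trial.suggest_categorical('n_steps', [512, 1024, 2048]),
            'gamma': trial.suggest_float('gamma', 0.8, 0.9999),
            'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True),
            'clip_range': trial.suggest_float('clip_range', 0.1, 0.4),
            'gae_lambda': trial.suggest_float('gae_lambda', 0.8, 1.0),
            'policy_kwargs': dict(activation_fn=th.nn.ReLU, net_arch=net_arch),
        }
        env = Monitor(StockEnv(range(2011, 2021)))
        model = PPO('MlpPolicy', env, verbose=0, **params)
        model.learn(total_timesteps=100_000)  # short budget just for the search
        mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=5)
        return mean_reward

    study = optuna.create_study(direction='maximize')
    study.optimize(objective, n_trials=n_trials)
    return study.best_params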
if __name__ == '__main__':
    num_cpu = 128  # Number of processes to use
    # Create the vectorized environment
    env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
    # env = DummyVecEnv([make_env(i) for i in range(num_cpu)])
    model_params = optimize_param()
    model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=TB_LOG_PATH, **model_params)
    model.learn(total_timesteps=10000000, callback=SaveModelCallback(1024, "./model/ppo"))
    mean_reward, std_reward = evaluate_policy(model, env)
    model.save("./ppo_model")
    print("learn finish", mean_reward, std_reward)