from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3 import PPO
from stock_env import StockEnv
from save_model import SaveModelCallback
import torch as th
TB_LOG_PATH = "../tb_log"   # TensorBoard log directory
MODEL_PATH = "./model/ppo"  # checkpoint and Monitor log directory
def make_env(rank, seed=0):
    """
    Utility function for a multiprocessed env.

    :param rank: (int) index of the subprocess
    :param seed: (int) the initial seed for the RNG
    """
    def _init():
        # Train on the 2011-2020 data; Monitor logs per-episode stats for this worker.
        env = Monitor(StockEnv(range(2011, 2021)), MODEL_PATH + '/' + str(rank))
        env.seed(seed + rank)
        return env
    set_random_seed(seed)
    return _init
'''
Best Optuna trial so far (mean_reward = 193.5548455):
[I 2022-08-19 17:59:21,513] Trial 16 finished with value: 193.5548455 and parameters: {'na_num': 6, '0': 32, '1': 128, '2': 128, '3': 768, '4': 32, '5': 32, 'n_steps': 2048, 'gamma': 0.8590208698070629, 'learning_rate': 4.044243543027443e-05, 'clip_range': 0.23308626459396495, 'gae_lambda': 0.8641983426984212}. Best is trial 16 with value: 193.5548455.
'''
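# A minimal sketch of how a trial like the one logged above could be produced with
# Optuna. This is an assumption for illustration: the actual search script is not
# part of this file, and the objective (a short PPO run scored with evaluate_policy),
# the search ranges, and the helper name run_optuna_search are all hypothetical.
def run_optuna_search(n_trials=20):
    import optuna

    def objective(trial):
        # Sample a net_arch of 'na_num' layers; the per-layer keys '0', '1', ...
        # mirror the parameter names in the trial log above.
        na_num = trial.suggest_int('na_num', 2, 8)
        net_arch = [trial.suggest_categorical(str(i), [32, 128, 256, 768])
                    for i in range(na_num)]
        params = {
            'n_steps': trial.suggest_categorical('n_steps', [512, 1024, 2048]),
            'gamma': trial.suggest_float('gamma', 0.8, 0.9999),
            'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True),
            'clip_range': trial.suggest_float('clip_range', 0.1, 0.4),
            'gae_lambda': trial.suggest_float('gae_lambda', 0.8, 1.0),
            'policy_kwargs': dict(activation_fn=th.nn.ReLU, net_arch=net_arch),
        }
        env = Monitor(StockEnv(range(2011, 2021)))
        trial_model = PPO('MlpPolicy', env, verbose=0, **params)
        trial_model.learn(total_timesteps=100_000)  # short budget per trial
        mean_reward, _ = evaluate_policy(trial_model, env, n_eval_episodes=10)
        return mean_reward

    study = optuna.create_study(direction='maximize')
    study.optimize(objective, n_trials=n_trials)
    return study.best_params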
def optimize_param():
    """Hyperparameters taken from the best Optuna trial recorded above."""
    policy = dict(
        activation_fn=th.nn.ReLU,
        net_arch=[32, 128, 128, 768, 32, 32],
    )
    return {
        'n_steps': 2048,
        'gamma': 0.8590208698070629,
        'learning_rate': 4.044243543027443e-05,
        'clip_range': 0.23308626459396495,
        'gae_lambda': 0.8641983426984212,
        'policy_kwargs': policy,
    }
if __name__ == '__main__':
    num_cpu = 128  # number of worker processes
    # Create the vectorized environment.
    env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
    # env = DummyVecEnv([make_env(i) for i in range(num_cpu)])
    model_params = optimize_param()
    model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=TB_LOG_PATH, **model_params)
    # Checkpoint the model every 1024 steps while training.
    model.learn(total_timesteps=10000000, callback=SaveModelCallback(1024, MODEL_PATH))
    mean_reward, std_reward = evaluate_policy(model, env)
    model.save("./ppo_model")
    print("learn finish", mean_reward, std_reward)