代码拉取完成,页面将自动刷新
同步操作将从 lightning-trader/stock_robot 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3 import SAC
from stock_env import StockEnv
from save_model import SaveModelCallback
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.noise import NormalActionNoise
import torch as th
import numpy as np
from stable_baselines3.common.evaluation import evaluate_policy
# Directory handed to SAC as ``tensorboard_log`` in the training script below.
TB_LOG_PATH = "../tb_log"
# Base path used for the per-worker Monitor files (see make_env) and passed
# as ``path`` to SaveModelCallback.
MODEL_PATH = "./model/sac"
def make_env(rank, seed=0):
    """
    Build a zero-argument factory that creates one monitored StockEnv.

    :param rank: (int) index of the subprocess; added to ``seed`` for the
        env seed and used as the last component of the Monitor path
    :param seed: (int) the initial seed for RNG
    :return: (callable) thunk that constructs and returns the wrapped env
    """
    # set_random_seed runs immediately, when the factory is built -- not
    # when the returned thunk is later invoked.
    set_random_seed(seed)

    def _init():
        # NOTE(review): range(2011, 2021) presumably selects the data years
        # 2011-2020 -- confirm against StockEnv.
        stock_env = StockEnv(range(2011, 2021))
        monitored = Monitor(stock_env, f"{MODEL_PATH}/{rank}")
        # Each worker gets its own seed derived from the shared base seed.
        monitored.seed(seed + rank)
        return monitored

    return _init
#mean_reward 18.8189427 0.5833992945014679
def optimize_params(actions):
    """
    Return the fixed keyword arguments used to construct the SAC model.

    :param actions: accepted for call compatibility; not read by this function
    :return: (dict) hyperparameters plus a ``policy_kwargs`` entry holding
        the activation function and the list of layer sizes
    """
    policy_kwargs = {
        'activation_fn': th.nn.ReLU6,
        'net_arch': [512, 128, 1280, 256, 1536, 64, 2048, 128],
    }
    hyperparams = dict(
        gamma=0.849893018754744,
        batch_size=16,
        buffer_size=2000000,
        learning_starts=1,
        learning_rate=1.955844338087666e-05,
        tau=0.01,
        train_freq=16,
    )
    # Added last so the key order matches the original literal.
    hyperparams['policy_kwargs'] = policy_kwargs
    return hyperparams
if __name__ == '__main__':
    # Script entry point: build the vectorized env, train SAC, evaluate it,
    # and save the final model.
    num_cpu = 64  # Number of processes to use
    # Create the vectorized environment
    #env = DummyVecEnv([lambda: Monitor(TrainingEnv(TRAINING_BEGIN_TIME), MODEL_PATH)])
    env = SubprocVecEnv([make_env(i) for i in range(num_cpu)])
    try:
        model_params = optimize_params(env.action_space.shape[-1])
        #model = SAC.load("model/sac_stock",env)
        model = SAC('MlpPolicy', env, verbose=1, tensorboard_log=TB_LOG_PATH, **model_params)
        model.learn(
            total_timesteps=10000000,
            callback=SaveModelCallback(check_freq=4096, path=MODEL_PATH),
        )
        # Evaluation reuses the training env.
        mean_reward, std_reward = evaluate_policy(model, env)
        print(f"{mean_reward} {std_reward}")
        model.save("sac_stock")
    finally:
        # Always shut down the worker subprocesses, including when training
        # or evaluation raises.
        env.close()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。