# NOTE(review): removed a code-hosting page banner ("代码拉取完成,页面将自动刷新" /
# "code pull complete, the page will refresh") that was accidentally captured
# with the source; it is not part of the program.
import json
import logging
import os
from argparse import ArgumentParser
from datetime import datetime
from typing import List, Tuple
import lightgbm as lgb
import xgboost as xgb
from torch import nn
from torch.utils.data import DataLoader
from torchinfo import summary
from src.cache import cache_factor
from src.data import MinuteFactorDataset, get_stock
from src.loss import CCCLoss, ICLoss
from src.model import *
from src.models.DLinear import DLinear
from src.models.TCN import TCN
from src.models2.PatchTST import Model as PatchTST
from src.test import Tester, test_gb
from src.train import train, train_gb
from src.util import *
from src.transformer import MyTransformer, MyModel, FinalModel
if __name__ == '__main__':
    # Entry point: parse CLI args, build datasets/dataloaders, construct the
    # requested model, then train and evaluate it. Side-effect heavy: creates
    # an output directory, configures logging, and writes args.txt.
    parser = ArgumentParser(description='基于Transformer架构的量化预测研究')
    parser.add_argument('--stock', nargs='+', default=['zz500'], help='股票池范围,hs300,zz500,zz1000')
    parser.add_argument('--start-date', type=str, default='2020-02-20', help='开始日期')
    parser.add_argument('--end-date', type=str, default='2022-02-20', help='结束日期')
    parser.add_argument('--seq-len', type=int, default=240, help='使用多少条分钟频数据预测,最好是240的整数倍且不超1200')
    parser.add_argument('--daily-len', type=int, default=20, help='使用多少条日频数据预测')
    parser.add_argument('--pred-num', type=int, default=240, help='预测多少条数据后的变化率,60,180,240,1200,4800')
    parser.add_argument('--pred-time-points', nargs='+', default=['15:15:00'], help='进行预测的时间点')
    parser.add_argument('--lr', type=float, default=1e-4, help='学习率')
    parser.add_argument('--batch-size', type=int, default=128, help='batch大小')
    parser.add_argument('--epoch-num', type=int, default=15, help='epoch数')
    parser.add_argument('--model', type=str, default='Final', help='模型名称')
    parser.add_argument('--loss-fn', type=str, default='MSE', help='损失函数')
    parser.add_argument('--early-stop', type=int, default=10, help='早停轮数')
    parser.add_argument('--save-path', type=str, default='', help='模型、log的保存路径')
    parser.add_argument('--continue-train', action='store_true', help='是否继续训练,按照save_path的数据')
    parser.add_argument('--cache', action='store_true', help='缓存因子数据 e.g `--cache --stock "zz500"`')
    parser.add_argument('--prepare', action='store_true', help='准备数据集 e.g `--prepare --pred-time-points "10:00:00" "10:30:00" "11:00:00" "11:45:00" "13:30:00" "14:00:00" "14:30:00" "15:15:00"`')
    parser.add_argument('--test', type=str, default='', help='生成测试')
    parser.add_argument('--hidden-size', type=int, default=256, help='隐藏层大小')
    parser.add_argument('--layer-num', type=int, default=2, help='隐藏层层数')
    parser.add_argument('--dropout', type=float, default=0.0, help='dropout')
    parser.add_argument('--worker-num', type=int, default=0, help='num_workers大小')
    parser.add_argument('--attention', type=str, default='origin', help='注意力层类型')
    ns = parser.parse_args()

    seq_len: int = ns.seq_len
    daily_len: int = ns.daily_len
    minute_size = 67  # dimensionality of the minute-frequency feature vector
    daily_size = 6   # dimensionality of the daily-frequency feature vector
    pred_num: int = ns.pred_num
    pred_time_points: List[str] = ns.pred_time_points
    batch_size: int = ns.batch_size
    lr: float = ns.lr
    epoch_num: int = ns.epoch_num
    model_name: str = ns.model
    save_path: str = ns.save_path
    continue_train: bool = ns.continue_train
    worker_num: int = ns.worker_num

    # Resolve the output directory: auto-generate a timestamped one unless the
    # user supplied --save-path (required when resuming with --continue-train).
    now = datetime.now()
    if save_path == '':
        if continue_train:
            raise RuntimeError('continue_train is True but no save_path specified')
        save_path = f'{Env.output_path}{model_name}_{now.month:02d}_{now.day:02d}_{now.hour:02d}_{now.minute:02d}_{now.second:02d}/'
    else:
        save_path = f'{Env.output_path}{save_path}/'
    ns.save_path = save_path
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    logging.basicConfig(filename=f'{save_path}/log.txt', level=logging.DEBUG, filemode='w', format='%(message)s')
    logging.debug(f'save_path: {save_path}')
    logging_banner('args')
    args = vars(ns)  # public API; equivalent to the former ns._get_kwargs() dict
    logging.debug(f'args:{args}')
    with open(f'{save_path}/args.txt', 'w') as fp:
        json.dump(args, fp)

    stock_df, stocks, dates = get_stock(ns.stock, ns.start_date, ns.end_date)

    # --cache mode: precompute/cache factor data for every stock, then quit.
    if ns.cache:
        logging_banner('cache')
        for code in stocks:
            logging.debug(code)
            cache_factor(code)
        raise SystemExit(0)

    # Chronological 70/10/20 train/val/test split over trading dates.
    n_dates = len(dates)
    n1, n2 = int(n_dates * 0.7), int(n_dates * 0.8)
    train_dates, val_dates, test_dates = dates[:n1], dates[n1:n2], dates[n2:]
    logging_banner('info')
    logging.debug(f'feat_len:{minute_size}, feat_len2:{daily_size}')
    logging.debug(f'stock+date:{stock_df.shape[0]}')
    logging.debug(f'dates0:{dates[0]}, dates1:{dates[n1]}, dates2:{dates[n2]}, dates-1:{dates[-1]}')

    # --prepare mode: materialize the dataset once (prepare=True) and quit.
    if ns.prepare:
        dataset = MinuteFactorDataset(stock_df, seq_len, pred_num, pred_time_points, prepare=True)
        for idx in range(len(dataset)):
            (inputs1, inputs2, targets, info) = dataset[idx]
            print(info)
        raise SystemExit(0)

    train_dataset = MinuteFactorDataset(stock_df[(stock_df[DATE] >= dates[0]) & (stock_df[DATE] < dates[n1])], seq_len, pred_num, pred_time_points)
    val_dataset = MinuteFactorDataset(stock_df[(stock_df[DATE] >= dates[n1]) & (stock_df[DATE] < dates[n2])], seq_len, pred_num, pred_time_points)
    test_dataset = MinuteFactorDataset(stock_df[stock_df[DATE] >= dates[n2]], seq_len, pred_num, pred_time_points)
    logging.debug(f'train_dataset:{len(train_dataset)}, val_dataset:{len(val_dataset)}, test_dataset:{len(test_dataset)}')
    # Shuffle train/val; keep test in order for deterministic evaluation.
    train_dataloader: DataLoader = DataLoader(train_dataset, batch_size, True, num_workers=worker_num, pin_memory=True)
    val_dataloader: DataLoader = DataLoader(val_dataset, batch_size, True, num_workers=worker_num, pin_memory=True)
    test_dataloader: DataLoader = DataLoader(test_dataset, batch_size, False, num_workers=worker_num, pin_memory=True)

    logging.debug(f'{model_name}')
    hidden_size: int = ns.hidden_size
    layer_num: int = ns.layer_num
    dropout: float = ns.dropout
    output_size = 1
    head_num = 1

    # Model dispatch. Gradient-boosting models (LGBM/XGB) follow a separate
    # train/test path and exit early; everything else is a torch nn.Module.
    if model_name == 'LSTM':
        model = LSTM(minute_size, hidden_size)
    elif model_name == 'LSTM+daily':
        model = LSTMPlus(minute_size, daily_size, hidden_size)
    elif model_name == 'Transformer':
        model = Transformer(minute_size, hidden_size, output_size, layer_num, dropout)
    elif model_name == 'TimeTransformer':
        model = TimeVecTransformer(minute_size, hidden_size, output_size, layer_num, dropout, t2v_dim=3)
    elif model_name == 'TimeTransformer+daily':
        model = TimeVecTransformerPlus(minute_size, daily_size, hidden_size, output_size, layer_num, dropout, t2v_dim=3)
    elif model_name == 'LSTM+Attention':
        model = LSTMAttention(minute_size, hidden_size, output_size, layer_num, dropout)
    elif model_name == 'TCN':
        model = TCN(minute_size, 1, [128] * 1, 2, dropout)
    elif model_name == 'DLinear':
        model = DLinear(seq_len, minute_size, pred_len=1, individual=False, enc_in=1)
    elif model_name == 'NN':
        model = NN(minute_size, seq_len, hidden_size)
    elif model_name == 'LGBM' or model_name == 'XGB':
        if model_name == 'LGBM':
            model = lgb.LGBMRegressor(device='gpu', learning_rate=lr)
        else:
            model = xgb.XGBRegressor(verbosity=1, gpu_id=0, tree_method='gpu_hist', learning_rate=lr)
        logging.debug(f'{model}')
        logging_banner('train')
        model = train_gb(model, epoch_num, train_dataloader, val_dataloader, save_path)
        logging_banner('test')
        test_gb(model, test_dataloader)
        raise SystemExit(0)
    elif model_name == 'Final':
        model = FinalModel(minute_size, daily_size, hidden_size, output_size, layer_num, dropout)
    elif model_name == 'MyModel':
        model = MyModel(minute_size, daily_size, hidden_size, output_size, layer_num, dropout, ns.attention)
    elif model_name == 'PatchTST':
        class Conf:
            # Minimal config object matching the attributes PatchTST reads.
            seq_len = seq_len
            pred_len = output_size
            d_model = minute_size
            enc_in = minute_size
            factor = 5
            e_layers = layer_num
            d_ff = hidden_size
            n_heads = head_num
            dropout = dropout
            num_class = 1
            activation = 'relu'
            output_attention = False
        model = PatchTST(Conf())
    else:
        # Previously an unknown name crashed later with a NameError on `model`.
        raise ValueError(f'unknown model: {model_name!r}')

    # NOTE(review): summary() always feeds two inputs (minute + daily); single-input
    # models (e.g. plain LSTM) presumably ignore or reject the second — confirm.
    logging.debug(summary(model, input_data=[torch.rand((batch_size, seq_len, minute_size)), torch.rand((batch_size, daily_len, daily_size))]))

    if ns.loss_fn == 'MSE':
        loss_fn = nn.MSELoss()
    elif ns.loss_fn == 'IC':
        loss_fn = ICLoss()
    elif ns.loss_fn == 'CCC':
        loss_fn = CCCLoss()
    else:
        # Previously an unknown name crashed later with a NameError on `loss_fn`.
        raise ValueError(f'unknown loss function: {ns.loss_fn!r}')

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model, optimizer = accelerator.prepare(model, optimizer)
    logging_banner('train')
    train_dataloader, val_dataloader, test_dataloader = accelerator.prepare(train_dataloader, val_dataloader, test_dataloader)
    best_val_model = train(model, loss_fn, optimizer, epoch_num, train_dataloader, val_dataloader, save_path, ns.early_stop)
    logging_banner('test')
    Tester.eval_model(best_val_model, test_dataloader, loss_fn)
# NOTE(review): removed a code-hosting platform's content-moderation notice
# that was accidentally captured at the end of the file; it is not part of
# the program.