Feel_Again/transformer-quant

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
main.py 9.99 KB
一键复制 编辑 原始数据 按行查看 历史
烦却 提交于 2023-05-17 20:25 . 任务交接
import json
import logging
import os
from argparse import ArgumentParser
from datetime import datetime
from typing import List

import lightgbm as lgb
import torch
import xgboost as xgb
from torch import nn
from torch.utils.data import DataLoader
from torchinfo import summary
from src.cache import cache_factor
from src.data import MinuteFactorDataset, get_stock
from src.loss import CCCLoss, ICLoss
from src.model import *
from src.models.DLinear import DLinear
from src.models.TCN import TCN
from src.models2.PatchTST import Model as PatchTST
from src.test import Tester, test_gb
from src.train import train, train_gb
from src.util import *
from src.transformer import MyTransformer, MyModel, FinalModel
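# NOTE: the wildcard imports above are assumed (inferred from usage below, since
# the src package is not shown here) to provide:
#   src.model -> LSTM, LSTMPlus, Transformer, TimeVecTransformer,
#                TimeVecTransformerPlus, LSTMAttention, NN
#   src.util  -> Env, DATE, logging_banner, accelerator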
if __name__ == '__main__':
    parser = ArgumentParser(description='Quantitative prediction research based on the Transformer architecture')
    parser.add_argument('--stock', nargs='+', default=['zz500'], help='stock universe: hs300, zz500, zz1000')
    parser.add_argument('--start-date', type=str, default='2020-02-20', help='start date')
    parser.add_argument('--end-date', type=str, default='2022-02-20', help='end date')
    parser.add_argument('--seq-len', type=int, default=240, help='number of minute bars used for prediction; ideally a multiple of 240 and at most 1200')
    parser.add_argument('--daily-len', type=int, default=20, help='number of daily bars used for prediction')
    parser.add_argument('--pred-num', type=int, default=240, help='predict the rate of change this many bars ahead: 60, 180, 240, 1200, 4800')
    parser.add_argument('--pred-time-points', nargs='+', default=['15:15:00'], help='intraday time points at which predictions are made')
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--batch-size', type=int, default=128, help='batch size')
    parser.add_argument('--epoch-num', type=int, default=15, help='number of epochs')
    parser.add_argument('--model', type=str, default='Final', help='model name')
    parser.add_argument('--loss-fn', type=str, default='MSE', help='loss function: MSE, IC, CCC')
    parser.add_argument('--early-stop', type=int, default=10, help='early-stopping patience (epochs)')
    parser.add_argument('--save-path', type=str, default='', help='directory for model checkpoints and logs')
    parser.add_argument('--continue-train', action='store_true', help='resume training from the data under save_path')
    parser.add_argument('--cache', action='store_true', help='cache factor data, e.g. `--cache --stock "zz500"`')
    parser.add_argument('--prepare', action='store_true', help='prepare the dataset, e.g. `--prepare --pred-time-points "10:00:00" "10:30:00" "11:00:00" "11:45:00" "13:30:00" "14:00:00" "14:30:00" "15:15:00"`')
    parser.add_argument('--test', type=str, default='', help='generate test output')
    parser.add_argument('--hidden-size', type=int, default=256, help='hidden layer size')
    parser.add_argument('--layer-num', type=int, default=2, help='number of hidden layers')
    parser.add_argument('--dropout', type=float, default=0.0, help='dropout rate')
    parser.add_argument('--worker-num', type=int, default=0, help='DataLoader num_workers')
    parser.add_argument('--attention', type=str, default='origin', help='attention layer type')
    ns = parser.parse_args()
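    # Example invocations (illustrative, assembled from the help strings above):
    #   python main.py --model Final --stock zz500 --start-date 2020-02-20 --end-date 2022-02-20
    #   python main.py --cache --stock zz500
    #   python main.py --prepare --pred-time-points "10:00:00" "11:45:00" "15:15:00"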
    seq_len: int = ns.seq_len
    daily_len: int = ns.daily_len
    minute_size = 67  # dimensionality of the minute-frequency features
    # feat_len = 73  # minute-frequency dimensionality when the 4800-bar aggregation period is included
    daily_size = 6  # dimensionality of the daily-frequency features
    pred_num: int = ns.pred_num
    pred_time_points: List[str] = ns.pred_time_points
    # pred_time_points = ['10:00:00', '10:30:00', '11:00:00', '11:45:00', '13:30:00', '14:00:00', '14:30:00', '15:15:00']
    # pred_time_points = ['11:45:00', '15:15:00']
    batch_size: int = ns.batch_size
    lr: float = ns.lr
    epoch_num: int = ns.epoch_num
    model_name: str = ns.model
    # early_stop: int = ns.early_stop
    save_path: str = ns.save_path
    continue_train: bool = ns.continue_train
    worker_num: int = ns.worker_num
    now = datetime.now()
    if save_path == '':
        if continue_train:
            raise RuntimeError('continue_train is True but no save_path specified')
        save_path = f'{Env.output_path}{model_name}_{now.month:02d}_{now.day:02d}_{now.hour:02d}_{now.minute:02d}_{now.second:02d}/'
    else:
        save_path = f'{Env.output_path}{save_path}/'
    ns.save_path = save_path
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    logging.basicConfig(filename=f'{save_path}log.txt', level=logging.DEBUG, filemode='w', format='%(message)s')
    logging.debug(f'save_path: {save_path}')
    logging_banner('args')
    args = vars(ns)  # all CLI arguments as a plain dict
    logging.debug(f'args:{args}')
    with open(f'{save_path}args.txt', 'w') as fp:
        json.dump(args, fp)
    stock_df, stocks, dates = get_stock(ns.stock, ns.start_date, ns.end_date)
    if ns.cache:
        logging_banner('cache')
        for code in stocks:
            logging.debug(code)
            cache_factor(code)
        exit(0)
    # chronological split: first 70% of dates for training, next 10% for
    # validation, final 20% for testing
    date_num = len(dates)
    n1, n2 = int(date_num * 0.7), int(date_num * 0.8)
    train_dates, val_dates, test_dates = dates[:n1], dates[n1:n2], dates[n2:]
    logging_banner('info')
    logging.debug(f'minute_size:{minute_size}, daily_size:{daily_size}')
    logging.debug(f'stock+date:{stock_df.shape[0]}')
    logging.debug(f'dates0:{dates[0]}, dates1:{dates[n1]}, dates2:{dates[n2]}, dates-1:{dates[-1]}')
    if ns.prepare:
        # prepare mode: iterate the whole dataset once (prepare=True) so that
        # samples are built up front, then exit
        dataset = MinuteFactorDataset(stock_df, seq_len, pred_num, pred_time_points, prepare=True)
        # dataloader: DataLoader = accelerator.prepare(DataLoader(dataset, batch_size, True))
        for idx in range(len(dataset)):
            (inputs1, inputs2, targets, info) = dataset[idx]
            print(info)
        exit(0)
    train_dataset = MinuteFactorDataset(stock_df[(stock_df[DATE] >= dates[0]) & (stock_df[DATE] < dates[n1])], seq_len, pred_num, pred_time_points)
    val_dataset = MinuteFactorDataset(stock_df[(stock_df[DATE] >= dates[n1]) & (stock_df[DATE] < dates[n2])], seq_len, pred_num, pred_time_points)
    test_dataset = MinuteFactorDataset(stock_df[stock_df[DATE] >= dates[n2]], seq_len, pred_num, pred_time_points)
    logging.debug(f'train_dataset:{len(train_dataset)}, val_dataset:{len(val_dataset)}, test_dataset:{len(test_dataset)}')
    train_dataloader: DataLoader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=worker_num, pin_memory=True)
    val_dataloader: DataLoader = DataLoader(val_dataset, batch_size, shuffle=True, num_workers=worker_num, pin_memory=True)
    test_dataloader: DataLoader = DataLoader(test_dataset, batch_size, shuffle=False, num_workers=worker_num, pin_memory=True)
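    # Each sample is expected to be a tuple (minute_feats, daily_feats, targets,
    # info) -- see the prepare loop above -- with minute_feats of shape
    # (seq_len, minute_size) and daily_feats of shape (daily_len, daily_size),
    # matching the dummy inputs passed to summary() further below.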
    logging.debug(f'{model_name}')
    hidden_size: int = ns.hidden_size
    layer_num: int = ns.layer_num
    dropout: float = ns.dropout
    output_size = 1  # single regression target: the rate of change pred_num bars ahead
    head_num = 1
    if model_name == 'LSTM':
        model = LSTM(minute_size, hidden_size)
        # logging.debug(f'hidden_size:{hidden_size}')
    elif model_name == 'LSTM+daily':
        model = LSTMPlus(minute_size, daily_size, hidden_size)
    elif model_name == 'Transformer':
        model = Transformer(minute_size, hidden_size, output_size, layer_num, dropout)
    elif model_name == 'TimeTransformer':
        model = TimeVecTransformer(minute_size, hidden_size, output_size, layer_num, dropout, t2v_dim=3)
    elif model_name == 'TimeTransformer+daily':
        model = TimeVecTransformerPlus(minute_size, daily_size, hidden_size, output_size, layer_num, dropout, t2v_dim=3)
    elif model_name == 'LSTM+Attention':
        model = LSTMAttention(minute_size, hidden_size, output_size, layer_num, dropout)
    elif model_name == 'TCN':
        model = TCN(minute_size, 1, [128] * 1, 2, dropout)
    elif model_name == 'DLinear':
        model = DLinear(seq_len, minute_size, pred_len=1, individual=False, enc_in=1)
    elif model_name == 'NN':
        model = NN(minute_size, seq_len, hidden_size)
    elif model_name in ('LGBM', 'XGB'):
        # gradient-boosting baselines train and evaluate through their own
        # helpers, then exit before the torch training loop below
        if model_name == 'LGBM':
            model = lgb.LGBMRegressor(device='gpu', learning_rate=lr)
        else:
            model = xgb.XGBRegressor(verbosity=1, gpu_id=0, tree_method='gpu_hist', learning_rate=lr)
        logging.debug(f'{model}')
        logging_banner('train')
        model = train_gb(model, epoch_num, train_dataloader, val_dataloader, save_path)
        logging_banner('test')
        test_gb(model, test_dataloader)
        exit(0)
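    # All remaining branches construct torch models; they share the training
    # loop at the bottom of the file.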
    elif model_name == 'Final':
        model = FinalModel(minute_size, daily_size, hidden_size, output_size, layer_num, dropout)
    elif model_name == 'MyModel':
        # model = MyTransformer(layer_num, minute_size, hidden_size, head_num, dropout, t2v_dim=3, out_dim=output_size)
        model = MyModel(minute_size, daily_size, hidden_size, output_size, layer_num, dropout, ns.attention)
    elif model_name == 'PatchTST':
        class Conf:
            seq_len = seq_len
            pred_len = output_size
            d_model = minute_size
            enc_in = minute_size
            factor = 5
            e_layers = layer_num
            d_ff = hidden_size
            n_heads = head_num
            dropout = dropout
            num_class = 1
            activation = 'relu'
            output_attention = False
        model = PatchTST(Conf())
    else:
        raise ValueError(f'unknown model: {model_name}')
    logging.debug(summary(model, input_data=[torch.rand((batch_size, seq_len, minute_size)), torch.rand((batch_size, daily_len, daily_size))]))
    if ns.loss_fn == 'MSE':
        loss_fn = nn.MSELoss()
    elif ns.loss_fn == 'IC':
        loss_fn = ICLoss()
    elif ns.loss_fn == 'CCC':
        loss_fn = CCCLoss()
    else:
        raise ValueError(f'unknown loss function: {ns.loss_fn}')
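    # ICLoss and CCCLoss come from src.loss; judging by their names they
    # presumably target the information coefficient (correlation between
    # predictions and realized returns) and the concordance correlation
    # coefficient, both common objectives in return forecasting.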
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model, optimizer = accelerator.prepare(model, optimizer)
    logging_banner('train')
    train_dataloader, val_dataloader, test_dataloader = accelerator.prepare(train_dataloader, val_dataloader, test_dataloader)
    best_val_model = train(model, loss_fn, optimizer, epoch_num, train_dataloader, val_dataloader, save_path, ns.early_stop)
    logging_banner('test')
    Tester.eval_model(best_val_model, test_dataloader, loss_fn)