代码拉取完成,页面将自动刷新
import pandas as pd
import argparse
import numpy as np
import torch
import os
import datetime
import time
from matplotlib import pyplot as plt
from data.datautils import Dataset_ETT_hour,batch_x_ffts
from utils.util import EarlyStopping,_logger
from torch.utils.data import DataLoader
from model.encoder import Time_Frequence_Mul
from model.decoder import linear_Decoder,Attention_Decoder
from trainer import Trainer
from Config import Configs
def _parse_args():
    """Parse the command-line options for a pre-training / training run.

    Only ``data_path``, ``data_name`` (logged), ``seed``, ``patience``,
    ``logs_save_dir``, ``experiment_description`` and ``training_mode`` are
    read by this script. ``save_path``, ``modelname``, ``lamubda``,
    ``epoches``, ``run_description`` and ``size`` are not read here
    (presumably vestigial or meant for other entry points -- confirm).

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', default="d:\\data\\etth", type=str,
                        help="数据的根路径")
    parser.add_argument('--data_name', default='ETTh1.csv', type=str,
                        help="数据集名字")
    parser.add_argument('--save_path', default='./resModel', type=str,
                        help="模型存储的地方")
    parser.add_argument('--experiment_description', default='pretrain', type=str,
                        help='帮助记住这是干啥的')
    parser.add_argument('--modelname', default='linear_5_100_50.pth', type=str,
                        help=' name of saved model.')
    parser.add_argument('--seed', default=3678, type=int, help="random seed")
    # Flag keeps its original spelling ("lamubda") so existing command lines work.
    parser.add_argument('--lamubda', default=1, type=int, help='regulational size. ')
    parser.add_argument('--patience', default=30, type=int, help='early stopping.')
    parser.add_argument('--logs_save_dir', default='../experiments_logs', type=str,
                        help='saving directory')
    parser.add_argument('--epoches', default=400, type=int,
                        help='the epoches of learning. ')
    parser.add_argument('--training_mode', default='pre_train', type=str,
                        help='pre_train, training')
    parser.add_argument('--run_description', default='run1', type=str,
                        help='Experiment Description')
    # type/nargs make the option settable from the CLI (``--size 96 24 24``);
    # without them a CLI value arrived as a single unparsed string.
    parser.add_argument('--size', default=[96, 24, 24], type=int, nargs=3,
                        help='size for learning , training, testing')
    return parser.parse_args()


def _make_loader(dataset, configs):
    """Wrap *dataset* in a DataLoader using the batch/shuffle/drop-last settings in *configs*."""
    return DataLoader(dataset,
                      batch_size=configs.BATCH_SIZE,
                      shuffle=configs.SHUFFLE_FLAG,
                      drop_last=configs.DROP_LAST)


def main():
    """Set up logging, seeding, data loaders and models, then run ``Trainer``."""
    args = _parse_args()
    configs = Configs()

    seed = args.seed
    # NOTE(review): os.path.join("chec", "../experiments_logs", ...) yields
    # "chec/../experiments_logs/..." -- i.e. logs land in ./experiments_logs
    # and an empty ./chec directory may be created. Confirm "chec" is intended.
    experiment_log_dir = os.path.join("chec", args.logs_save_dir,
                                      args.experiment_description,
                                      args.training_mode + f"_seed_{seed}")
    # Create the log directory up front so the log file can be opened even if
    # _logger does not create missing parent directories itself.
    os.makedirs(experiment_log_dir, exist_ok=True)
    log_file_name = os.path.join(
        experiment_log_dir,
        f"logs_{datetime.datetime.now().strftime('%d_%m_%Y_%H_%M_%S')}.log")
    logger = _logger(log_file_name)
    logger.debug(f'Data_name: {args.data_name}')
    logger.debug(f'Mode: {args.training_mode}')

    # Seed the random number generators (torch.manual_seed covers all devices).
    torch.manual_seed(seed)
    # cuDNN must run deterministically, without benchmark autotuning, for the
    # seeds to yield reproducible GPU runs (the original set deterministic=False).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)

    # Prepare data. Datasets are built in the original order (train, test, val).
    # NOTE(review): args.data_name is only logged and is never passed to the
    # dataset; presumably Dataset_ETT_hour selects the file itself -- confirm.
    dataset_cls = Dataset_ETT_hour
    train_data_set = dataset_cls(args.data_path, flag='train', size=configs.SIZE)
    test_data_set = dataset_cls(args.data_path, flag='test', size=configs.SIZE)
    valid_data_set = dataset_cls(args.data_path, flag='val', size=configs.SIZE)
    train_dl = _make_loader(train_data_set, configs)
    test_dl = _make_loader(test_data_set, configs)
    valid_dl = _make_loader(valid_data_set, configs)
    # Logged only once the data has actually been loaded.
    logger.debug("Data loaded ...")

    # Honour --patience instead of a hard-coded 3.
    # NOTE(review): early_stopping is never passed to Trainer, so it currently
    # has no effect on training.
    early_stopping = EarlyStopping(patience=args.patience, verbose=True)

    # Build the models. Order (encoder, optimizer, decoder) is kept so that
    # parameter initialisation consumes the seeded RNG exactly as before.
    model = Time_Frequence_Mul(configs.INPUT_DIMS, configs.OUTPUT_DIMS,
                               configs.HIDDEN_DIMS, configs.lr,
                               configs.BATCH_SIZE, configs.DEVICE,
                               configs.VARS, configs).to(configs.DEVICE)
    # NOTE(review): this optimizer only covers the encoder's parameters;
    # presumably Trainer handles the decoder's -- confirm.
    optimizer = torch.optim.Adam(model.parameters(), lr=configs.lr,
                                 weight_decay=3e-4)
    decoder = linear_Decoder(configs.SEQ_LEN,
                             configs.LABEL_LEN + configs.PRED_LEN,
                             configs.HIDDEN_DIMS, configs.lr,
                             configs.BATCH_SIZE, configs.DEVICE,
                             configs.VARS).to(configs.DEVICE)

    # Time the whole training run on a monotonic clock. The original subtracted
    # a float (time.time()) from a datetime, which raised TypeError here.
    start = time.perf_counter()
    Trainer(model, optimizer, train_dl, test_dl, valid_dl, configs,
            args.training_mode, logger, decoder, experiment_log_dir)
    elapsed = datetime.timedelta(seconds=time.perf_counter() - start)
    logger.debug(f"Training time is : {elapsed}")


if __name__ == '__main__':
    main()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。