代码拉取完成,页面将自动刷新
from torchinfo import summary
import src.model
import torch
import importlib
import models.TCN
import models.DLinear
import src.models2.TimesNet
import src.models2.PatchTST
import src.transformer
importlib.reload(src.model)
importlib.reload(src.models.TCN)
importlib.reload(src.models.DLinear)
importlib.reload(src.models2.TimesNet)
importlib.reload(src.models2.PatchTST)
importlib.reload(src.transformer)
device = 'cuda'
minute_len = 240
minute_size = 67
daily_len = 20
daily_size = 6
batch_size = 128
hidden_size = 256
class Conf:
seq_len = minute_len
label_len = 1
pred_len = 1
d_model = minute_size
enc_in = minute_size
embed = 128
c_out = 1
factor = 5
e_layers = 1
d_ff = 1
n_heads = 1
num_kernels = 1
dropout = 0.1
num_class = 1
activation = 'relu'
output_attention = False
model = src.models2.PatchTST.Model(Conf())
minute_data = torch.rand((batch_size, minute_len, minute_size))
daily_data = torch.rand((batch_size, daily_len, daily_size))
o = model(minute_data, daily_data)
print(o.shape)
summary(model, input_data=[minute_data, daily_data])
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。