Feel_Again/transformer-quant

tmp.py 1.09 KB
烦却 committed on 2023-05-13 15:37: Remove the daily part of MyTransformer
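# Quick smoke test: build a model on random inputs, then print its output
# shape and a torchinfo summary.
import importlib

import torch
from torchinfo import summary

import src.model
import src.models.TCN
import src.models.DLinear
import src.models2.TimesNet
import src.models2.PatchTST
import src.transformer

# Re-execute the project modules so source edits take effect in an
# already-running interpreter (e.g. a notebook kernel) without a restart.
importlib.reload(src.model)
importlib.reload(src.models.TCN)
importlib.reload(src.models.DLinear)
importlib.reload(src.models2.TimesNet)
importlib.reload(src.models2.PatchTST)
importlib.reload(src.transformer)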
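# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
minute_len = 240   # minute-level time steps per sample
minute_size = 67   # features per minute-level step
daily_len = 20     # daily time steps per sample
daily_size = 6     # features per daily step
batch_size = 128
hidden_size = 256  # not used below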
class Conf:
    seq_len = minute_len
    label_len = 1
    pred_len = 1
    d_model = minute_size
    enc_in = minute_size
    embed = 128
    c_out = 1
    factor = 5
    e_layers = 1
    d_ff = 1
    n_heads = 1
    num_kernels = 1
    dropout = 0.1
    num_class = 1
    activation = 'relu'
    output_attention = False
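model = src.models2.PatchTST.Model(Conf()).to(device)
model.eval()

# Random stand-in batches, shaped (batch, time steps, features).
minute_data = torch.rand((batch_size, minute_len, minute_size), device=device)
daily_data = torch.rand((batch_size, daily_len, daily_size), device=device)

# Forward pass only, to check the output shape.
with torch.no_grad():
    o = model(minute_data, daily_data)
print(o.shape)

# Layer-by-layer output shapes and parameter counts.
summary(model, input_data=[minute_data, daily_data])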
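tmp.py calls `importlib.reload` on its project modules, presumably so that edits to the model code take effect in an interpreter that is already running. The following stdlib-only sketch shows that behaviour in isolation: it imports a throwaway module, rewrites its source, reloads it, and reads the new value. The module name `demo_cfg`, its `SEQ_LEN` constant, and the helper `reload_demo` are hypothetical and exist only for illustration.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Do not write .pyc files, so a reload always recompiles from the edited source.
sys.dont_write_bytecode = True


def reload_demo() -> tuple[int, int]:
    """Import a hypothetical module, edit it on disk, reload it, return both values."""
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "demo_cfg.py"      # hypothetical module file
        path.write_text("SEQ_LEN = 240\n")
        sys.path.insert(0, tmp)
        try:
            importlib.invalidate_caches()     # make sure the new directory is scanned
            mod = importlib.import_module("demo_cfg")
            before = mod.SEQ_LEN
            path.write_text("SEQ_LEN = 1200\n")  # simulate editing the source file
            mod = importlib.reload(mod)          # re-executes the module's code
            after = mod.SEQ_LEN
        finally:
            sys.path.remove(tmp)
            sys.modules.pop("demo_cfg", None)
    return before, after


print(reload_demo())  # (240, 1200)
```

Two caveats apply to the reload pattern in tmp.py. First, `reload` re-executes only the module it is given: reloading a package does not reload its submodules, which is why each submodule is reloaded individually. Second, names that another module bound with `from x import y` keep pointing at the old objects, so if `src.model` imports from the other modules, it would need to be reloaded after them rather than first.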
Clone over HTTPS: https://gitee.com/Feel_Again/transformer-quant.git
Clone over SSH: git@gitee.com:Feel_Again/transformer-quant.git
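tmp.py builds `src.models2.PatchTST.Model` with `seq_len = minute_len = 240` and `d_model = minute_size = 67`. PatchTST splits each input channel into overlapping patches, so the sequence length determines how many tokens the encoder sees per channel. The sketch below computes that count under an explicit assumption: it uses the patching scheme of the reference PatchTST implementation (patch length 16, stride 8, with `stride` steps of padding appended at the end). Whether `src.models2.PatchTST` uses the same values is not confirmed by this file, and the helper name `patchtst_num_patches` is hypothetical.

```python
def patchtst_num_patches(seq_len: int, patch_len: int = 16, stride: int = 8,
                         pad_end: bool = True) -> int:
    """Number of patches per channel for PatchTST-style patching.

    A window of length `patch_len` slides over the sequence in steps of
    `stride`. With `pad_end`, `stride` extra steps are appended first,
    which yields exactly one additional patch.
    """
    if seq_len < patch_len:
        raise ValueError("seq_len must be at least patch_len")
    n = (seq_len - patch_len) // stride + 1
    return n + 1 if pad_end else n


seq_len = 240   # minute_len in tmp.py
d_model = 67    # Conf.d_model (= minute_size) in tmp.py
n = patchtst_num_patches(seq_len)
print(n, d_model * n)  # 30 2010
```

Under these assumptions each of the 67 input channels becomes 30 patch tokens. A flatten-style prediction head that concatenates the `d_model`-dimensional patch embeddings, as the reference implementation does, would therefore take 67 × 30 = 2010 inputs per channel.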