
杨雪元/hand_gester

train2.py 1.67 KB
杨雪元 committed on 2023-09-08 12:07 · first commit
import math

import torch
import torch.nn as nn


class TransformerModel(nn.Module):
    def __init__(self, d_model, nhead, num_layers, num_classes, num_nodes=21):
        super(TransformerModel, self).__init__()
        # Generate a sinusoidal position encoding for the 21 hand keypoints
        position = torch.arange(num_nodes).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        position_enc = torch.zeros(num_nodes, d_model)
        position_enc[:, 0::2] = torch.sin(position * div_term)
        position_enc[:, 1::2] = torch.cos(position * div_term)
        # Register as a buffer so it moves with the module across devices
        self.register_buffer("position_enc", position_enc)
        # Learnable projection of each (x, y) keypoint coordinate into the model dimension
        self.embedding = nn.Linear(2, d_model)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x):
        # Reshape the flat input into per-keypoint coordinates
        x = x.view(-1, 21, 2)       # [batch_size, 21, 2]
        # Embed the coordinates
        x = self.embedding(x)       # [batch_size, 21, d_model]
        # Add the position encoding
        x = x + self.position_enc
        x = x.permute(1, 0, 2)      # [seq_len, batch_size, d_model], as expected by the Transformer
        x = self.encoder(x)
        x = self.classifier(x[0])   # Use the first token for classification
        return x


# Create a model
model = TransformerModel(d_model=12, nhead=4, num_layers=2, num_classes=13)
# Test input
input_data = torch.rand((8, 42))    # Batch size is 8; 21 keypoints x 2 coordinates
output = model(input_data)
print(output.shape)                 # Should print torch.Size([8, 13])
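The file is named train2.py but, as committed, stops at this forward-pass smoke test. Below is a minimal sketch of how the classifier might be trained with torch.optim (the original file imported it but never used it). The tensors train_x and train_y are hypothetical placeholders standing in for flattened 21-keypoint hand landmarks and integer gesture labels; they are not part of the repository.

import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical placeholder data: replace with real landmark features and labels
train_x = torch.rand(256, 42)               # [N, 42] flattened 21 x (x, y) keypoints
train_y = torch.randint(0, 13, (256,))      # integer gesture labels in [0, 13)
loader = DataLoader(TensorDataset(train_x, train_y), batch_size=8, shuffle=True)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

model.train()
for epoch in range(10):
    for batch_x, batch_y in loader:
        optimizer.zero_grad()
        logits = model(batch_x)             # [batch_size, num_classes]
        loss = criterion(logits, batch_y)
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch}: loss {loss.item():.4f}")

At inference time, the predicted gesture for a batch would be logits.argmax(dim=1).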