import torch
import torch.nn as nn
import torch.optim as optim
import math

class TransformerModel(nn.Module):
    def __init__(self, d_model, nhead, num_layers, num_classes, num_nodes=21):
        super().__init__()
        # Sinusoidal position encoding:
        # PE(pos, 2i) = sin(pos / 10000^(2i/d_model)), PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
        position = torch.arange(num_nodes).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        position_enc = torch.zeros(num_nodes, d_model)
        position_enc[:, 0::2] = torch.sin(position * div_term)
        position_enc[:, 1::2] = torch.cos(position * div_term)
        # Register as a buffer so it follows the module across devices and is
        # saved with the state dict, without receiving gradient updates.
        self.register_buffer("position_enc", position_enc)
        self.num_nodes = num_nodes
        # Project each node's 2-D coordinates into the model dimension.
        self.embedding = nn.Linear(2, d_model)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x):
        # Reshape the flat input into one 2-D coordinate per node.
        x = x.view(-1, self.num_nodes, 2)  # [batch_size, num_nodes, 2]
        x = self.embedding(x)              # [batch_size, num_nodes, d_model]
        # Add position encoding (out-of-place, keeping autograd history intact).
        x = x + self.position_enc
        # nn.TransformerEncoder defaults to [seq_len, batch_size, d_model].
        x = x.permute(1, 0, 2)
        x = self.encoder(x)
        # Classify from the encoded representation of the first node.
        x = self.classifier(x[0])
        return x
# Create a model
model = TransformerModel(d_model=12, nhead=4, num_layers=2, num_classes=13)
# Test input: a batch of 8 samples, each a flat vector of 21 nodes x 2 coordinates
input_data = torch.rand((8, 42))
output = model(input_data)
print(output.shape)  # torch.Size([8, 13])
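
# A minimal training-loop sketch using the imported optim module. The loss,
# optimizer choice, learning rate, epoch count, and the toy batch below are
# illustrative assumptions, not part of the original model code.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Hypothetical toy batch: random landmark vectors with random class labels.
train_x = torch.rand((8, 42))
train_y = torch.randint(0, 13, (8,))

for epoch in range(5):
    optimizer.zero_grad()
    logits = model(train_x)
    loss = criterion(logits, train_y)
    loss.backward()
    optimizer.step()
    print(f"epoch {epoch}: loss {loss.item():.4f}")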