代码拉取完成,页面将自动刷新
# 导入需要的包
import torch.nn as nn
class SimpleNet(nn.Module):
def __init__(self, num_classes=10):
super(SimpleNet, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=3, stride=1, padding=1)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=3, stride=1, padding=1)
self.relu2 = nn.ReLU()
self.pool = nn.MaxPool2d(kernel_size=2)
self.conv3 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=3, stride=1, padding=1)
self.relu3 = nn.ReLU()
self.conv4 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=3, stride=1, padding=1)
self.relu4 = nn.ReLU()
self.fc = nn.Linear(in_features=16 * 16 * 24, out_features=num_classes)
def forward(self, input):
output = self.conv1(input)
output = self.relu1(output)
output = self.conv2(output)
output = self.relu2(output)
output = self.pool(output)
output = self.conv3(output)
output = self.relu3(output)
output = self.conv4(output)
output = self.relu4(output)
output = output.view(-1, 16 * 16 * 24)
output = self.fc(output)
return output
class S1(nn.Module):
def __init__(self):
super(S1, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=3, stride=1, padding=1)
def forward(self, x):
return self.conv1(x)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。