代码拉取完成,页面将自动刷新
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader
#数据预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
#加载数据集
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64)
#定义网络结构
class re_cifar(nn.Module):
def __init__(self):
super(re_cifar, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=5, padding=2, stride=1),
nn.MaxPool2d(2),
nn.Conv2d(32, 32, kernel_size=5, padding=2, stride=1),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, kernel_size=5, padding=2, stride=1),
nn.MaxPool2d(2),
nn.Flatten()
)
def forward(self, input):
return self.model(input)
input = torch.ones(64, 1, 32, 32)
input = torch.tensor(input, dtype=torch.float32)
input = torch.reshape(input, [-1, 3, 32, 32])
model = re_cifar()
output = model(input)
print(output)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。