许满坤/deep learning

train.py 1.40 KB
Committed by 许满坤 on 2024-10-09 17:16: partition training code
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader
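# Data preprocessing
# Note: CIFAR-10 images are 32x32. Resizing to 224x224 makes the network below
# produce 64x28x28 feature maps per image instead of the 64x4x4 seen with 32x32 input.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1], shape (C, H, W)
    # Per-channel (x - mean) / std, using the standard ImageNet statistics
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])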
# Load the dataset
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# Shuffle the training set every epoch; keep the test order fixed
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64)
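# Define the network structure
# Comments give the output shape per image (C, H, W), assuming H and W are divisible by 8
class re_cifar(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, padding=2, stride=1),   # (32, H, W)
            nn.MaxPool2d(2),                                         # (32, H/2, W/2)
            nn.Conv2d(32, 32, kernel_size=5, padding=2, stride=1),  # (32, H/2, W/2)
            nn.MaxPool2d(2),                                         # (32, H/4, W/4)
            nn.Conv2d(32, 64, kernel_size=5, padding=2, stride=1),  # (64, H/4, W/4)
            nn.MaxPool2d(2),                                         # (64, H/8, W/8)
            nn.Flatten(),                                            # 64 * (H/8) * (W/8) features
        )

    def forward(self, x):
        return self.model(x)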
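# Sanity-check the output shape with a dummy batch of 64 RGB 32x32 images.
# (Reshaping a (64, 1, 32, 32) tensor to [-1, 3, 32, 32] fails: 65536 elements
# are not a multiple of 3 * 32 * 32 = 3072, so build the 3-channel batch directly.)
dummy_input = torch.ones((64, 3, 32, 32), dtype=torch.float32)
model = re_cifar()
output = model(dummy_input)
print(output.shape)  # torch.Size([64, 1024])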
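`transforms.ToTensor()` scales pixel values to [0, 1], and `transforms.Normalize` then computes `(x - mean) / std` separately for each channel. The mean/std values in this file are the standard ImageNet statistics rather than values computed from CIFAR-10. The sketch below, in plain Python with no torch dependency, works out the range each channel occupies after normalization; the helper names `normalize` and `normalized_range` are hypothetical.

```python
# Per-channel mean and std passed to transforms.Normalize in train.py
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize(x, mean, std):
    """Apply the Normalize formula (x - mean) / std to one scalar value."""
    return (x - mean) / std


def normalized_range(mean, std):
    """Return the (min, max) reachable after ToTensor ([0, 1]) followed by Normalize."""
    return normalize(0.0, mean, std), normalize(1.0, mean, std)


if __name__ == "__main__":
    for channel, (m, s) in enumerate(zip(MEAN, STD)):
        lo, hi = normalized_range(m, s)
        print(f"channel {channel}: [{lo:.4f}, {hi:.4f}]")
```

For channel 0 this gives roughly [-2.1179, 2.2489], so the network sees zero-centered inputs of roughly unit scale instead of raw [0, 1] intensities.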
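With `batch_size=64` and the `DataLoader` default `drop_last=False`, each loader yields ceil(N / 64) batches per epoch, and the final batch holds whatever remains. CIFAR-10 has 50,000 training and 10,000 test images, so the per-epoch batching works out as in this plain-Python sketch. The helper `batch_sizes` is hypothetical and only mirrors the loader's batch-size arithmetic, not its data.

```python
def batch_sizes(num_samples, batch_size, drop_last=False):
    """List the size of every batch a DataLoader would yield in one epoch."""
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    # A short final batch is kept unless drop_last is set
    if remainder and not drop_last:
        sizes.append(remainder)
    return sizes


if __name__ == "__main__":
    for name, n in [("train", 50_000), ("test", 10_000)]:
        sizes = batch_sizes(n, 64)
        print(f"{name}: {len(sizes)} batches, last batch has {sizes[-1]} images")
```

Both loaders end each epoch with a 16-image batch, which matters if any later code assumes every batch holds exactly 64 samples.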
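Each `Conv2d` in `re_cifar` uses `kernel_size=5, padding=2, stride=1`, which preserves height and width (H_out = floor((H + 2*2 - 5) / 1) + 1 = H), while each `MaxPool2d(2)` halves them. The sketch below applies these output-size formulas layer by layer to get the length of the flattened feature vector, which is the input size any classifier head added after `Flatten` would need. It compares a 32x32 input with the 224x224 size that the preprocessing resizes to. The tuple-based layer list is a hypothetical hand transcription of the model, not a torch API.

```python
# Hand transcription of re_cifar (hypothetical format):
# ("conv", out_channels, kernel, padding, stride) or ("pool", kernel)
LAYERS = [
    ("conv", 32, 5, 2, 1),
    ("pool", 2),
    ("conv", 32, 5, 2, 1),
    ("pool", 2),
    ("conv", 64, 5, 2, 1),
    ("pool", 2),
]


def conv_out(size, kernel, padding, stride):
    """Conv2d output size with dilation 1: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    """MaxPool2d output size with stride = kernel and no padding."""
    return (size - kernel) // kernel + 1


def flattened_features(height, width):
    """Trace (C, H, W) through LAYERS and return C * H * W after Flatten."""
    c, h, w = 3, height, width  # RGB input
    for layer in LAYERS:
        if layer[0] == "conv":
            _, c, k, p, s = layer
            h, w = conv_out(h, k, p, s), conv_out(w, k, p, s)
        else:
            _, k = layer
            h, w = pool_out(h, k), pool_out(w, k)
    return c * h * w


if __name__ == "__main__":
    print(flattened_features(32, 32))    # 64 * 4 * 4 = 1024
    print(flattened_features(224, 224))  # 64 * 28 * 28 = 50176
```

A 32x32 image therefore flattens to 1,024 features, while a 224x224 image flattens to 50,176. Any `nn.Linear` layer appended to this model would have to match whichever input size the data pipeline actually produces.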