1 Star 0 Fork 0

gisleung/py_pytorch_learn

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
014 nn_sequential.py 2.76 KB
一键复制 编辑 原始数据 按行查看 历史
gisleung 提交于 2022-04-01 20:26 . 复习完毕
"""
使用 Sequential 完成简单网络搭建
"""
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, ReLU, Sigmoid, Linear, Flatten, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs/014")
dataset = torchvision.datasets.CIFAR10(root="./visionData", train=False, transform=torchvision.transforms.ToTensor(),
download=True)
dataloader = DataLoader(dataset, batch_size=64)
# 注:该网络结构没有使用激活函数
'''
Conv2d:
# 通过 输入分辨率、输出分辨率、卷积个尺寸 计算 -> padding
# 计算公式 https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html?highlight=conv2d#torch.nn.Conv2d
Linear:
# 将一个线性特征经过全连接的方式,转换成另一个长度的线性特征
'''
class MyNn(nn.Module):
def __init__(self):
super(MyNn, self).__init__()
self.conv1 = Conv2d(3, 32, 5, padding=2) # 第1层卷积 ↓
self.maxpool1 = MaxPool2d(2) # 第1层池化 ↓
self.conv2 = Conv2d(32, 32, 5, padding=2) # 第2层卷积 ↓
self.maxpool2 = MaxPool2d(2) # 第2层池化 ↓
self.conv3 = Conv2d(32, 64, 5, padding=2) # 第3层卷积 ↓
self.maxpool3 = MaxPool2d(2) # 第3层池化 ↓
self.flatten = Flatten() # 展平 (全连接 ↓ )
self.linear1 = Linear(1024, 64) # 线性层1 (全连接 ↓ )
self.linear2 = Linear(64, 10) # 线性层2( 输出 )
def forward(self, x):
x = self.conv1(x)
x = self.maxpool1(x)
x = self.conv2(x)
x = self.maxpool2(x)
x = self.conv3(x)
x = self.maxpool3(x)
x = self.flatten(x)
x = self.linear1(x)
x = self.linear2(x)
return x
# 使用Sequential重构网络。使代码更加简洁
class MyNnSe(nn.Module):
def __init__(self):
super(MyNnSe, self).__init__()
self.model1 = Sequential(
Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self, x):
x = self.model1(x)
return x
# mynn = MyNn()
# print(mynn)
# input = torch.ones((64, 3, 32, 32))
# output = mynn(input)
# print(output.shape)
mySeNn = MyNnSe()
print(mySeNn)
input = torch.ones((64, 3, 32, 32)) # (1组)64张图片
output = mySeNn(input)
print(output.shape)
writer.add_graph(mySeNn, input)
writer.close()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/gisleung/py_pytorch_learn.git
git@gitee.com:gisleung/py_pytorch_learn.git
gisleung
py_pytorch_learn
py_pytorch_learn
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385