1 Star 0 Fork 0

陈志豪/Pytorch实战-物体分类

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
物体分类.py 5.81 KB
一键复制 编辑 原始数据 按行查看 历史
陈志豪 提交于 2024-03-20 14:14 . 123
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import os
#定义transform对象
batch_size = 16
transform = transforms.Compose(
[transforms.ToTensor(), # 将图片转为Tensor类型
# 对图片进行正则化。第一个参数为mean(均值),第二个为std(方差)。每个参数之所以有三个0.5,是因为有RGB三个通道。
# 综上,这句就是把图片的RGB三个通道都正则化到均值为0.5,方差为0.5的分布上。
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data1', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
shuffle=True)
testset = torchvision.datasets.CIFAR10(root='./data1', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
shuffle=False)
# CIFAR10总共10个类别
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
#定义CNN模型
class Net(nn.Module):
def __init__(self):
super().__init__()
"""
定义卷积层
nn.Conv2d包含三个重要参数:
in_channels: 输入的通道数
out_channels: 输出的通道数
kernel_size: 卷积核的大小
stride: 步长,默认为1
padding: 填充,默认为0,即不进行填充
补充:这里卷积层2d的意思是数据是“2维的”,
例如图片就是2维数据(长×宽)。同理也有Conv1d,
是针对于文本、信号等1维数据,也有Conv3d
是针对视频等这种3维数据。
"""
self.classifier = nn.Sequential(
nn.Conv2d(3, 6, 5),
# 激活函数
nn.ReLU(),
# 使用MaxPool进行下采样。
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5),
nn.ReLU(),
nn.MaxPool2d(2, 2),
# 当完成卷积后,使用flatten将数据展开
# 即将tensor的shape从(batch_size, c, h, w)变成(batch_size, c*h*w),这样才能送给全连接层
nn.Flatten(),
# 最后接全连接层。
# 计算方式可以参考:https://blog.csdn.net/zhaohongfei_358/article/details/123269313
nn.Linear(16 * 5 * 5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, 10)
# 注意这里并没有调用Softmax,也不能调Softmax
# 这是因为Softmax被包含在了CrossEntropyLoss损失函数中
# 如果这里调用的话,就会调用两遍,最后网络啥都学不着
)
#前向传播
def forward(self, x):
return self.classifier(x)
net = Net()
# 使用简单的CrossEntorpyLoss作为损失函数,一般多分类问题都用这个
criterion = nn.CrossEntropyLoss()
# 使用简单的SGD作为优化器
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 训练20次
epochs = 20
for epoch in range(epochs):
# 记录一下损失
running_loss = 0.0
for i, data in enumerate(trainloader):
# trainloader返回的是tuple,第一个是图片数,第二个对应的labels
inputs, labels = data
# 清除之前的梯度
optimizer.zero_grad()
# 进行前向传播
outputs = net(inputs)
# 计算损失
loss = criterion(outputs, labels)
# 反向传播
loss.backward()
# 更新参数
optimizer.step()
# 记录损失,每2000次打印一次损失
running_loss += loss.item()
if i % 2000 == 1999:
print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
running_loss = 0.0
print('Finished Training')
correct = 0 # 记录正确的数量
total = 0 # 记录总数
#计算模型精度
with torch.no_grad():
for data in testloader:
images, labels = data
# 前向传播
outputs = net(images)
"""
outputs.shape为(16, 10),batch_size为16, 10为类别
output这16张图片的各个类别的可能性(未经Softmax处理)
所以通过torch.max找到最大的那个。
torch.max接受两个参数,第一个是tensor,第二个是dim(维度)
这里传1,意思是在类别这个维度上取最大的
torch.max有两个输出,values和indexes,
values就是最大的数是什么,
indexes是这些最大的数的index是什么
这里我们只需要index即可,所以忽略第一个参数
"""
_, predicted = torch.max(outputs, 1)
# 记录总数量
total += labels.size(0)
# 计算正确数量
correct += (predicted == labels).sum().item()
print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')
# 统计每个类别的正确数量和总数量
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predictions = torch.max(outputs, 1)
# collect the correct predictions for each class
for label, prediction in zip(labels, predictions):
if label == prediction:
correct_pred[classes[label]] += 1
total_pred[classes[label]] += 1
#打印各类别精度数据
for classname, correct_count in correct_pred.items():
accuracy = 100 * float(correct_count) / total_pred[classname]
print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/chenzhi_hao/object-classification.git
git@gitee.com:chenzhi_hao/object-classification.git
chenzhi_hao
object-classification
Pytorch实战-物体分类
master

搜索帮助