代码拉取完成,页面将自动刷新
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from model import LeNet
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 导入50000张训练图片
train_set = torchvision.datasets.CIFAR10(root='./data', # 数据集存放目录
train=True, # 表示是数据集中的训练集
download=False, # 第一次运行时为True,下载数据集,下载完成后改为False
transform=transform) # 预处理过程
# 加载训练集,实际过程需要分批次(batch)训练
train_loader = torch.utils.data.DataLoader(train_set, # 导入的训练集
batch_size=36, # 每批训练的样本数
shuffle=True, # 是否打乱训练集
num_workers=0) # 使用线程数,在windows下设置为0
# 导入10000张测试图片
test_set = torchvision.datasets.CIFAR10(root='./data',
train=False, # 表示是数据集中的测试集
download=False, transform=transform)
# 加载测试集
test_loader = torch.utils.data.DataLoader(test_set,
batch_size=10000, # 每批用于验证的样本数
shuffle=False, num_workers=0)
# 获取测试集中的图像和标签,用于accuracy计算
test_data_iter = iter(test_loader)
test_image, test_label = test_data_iter.next()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
net = LeNet() # 定义训练的网络模型
net.to(device)
loss_function = nn.CrossEntropyLoss() # 定义损失函数为交叉熵损失函数
optimizer = optim.Adam(net.parameters(), lr=0.001) # 定义优化器(训练参数,学习率)
for epoch in range(100): # 一个epoch即对整个训练集进行一次训练
running_loss = 0.0
time_start = time.perf_counter()
for step, data in enumerate(train_loader, start=0): # 遍历训练集,step从0开始计算
inputs, labels = data # 获取训练集的图像和标签
optimizer.zero_grad() # 清除历史梯度
# forward + backward + optimize
outputs = net(inputs.to(device)) # 正向传播
loss = loss_function(outputs, labels.to(device)) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 优化器更新参数
# 打印耗时、损失、准确率等数据
running_loss += loss.item()
if step % 1000 == 999: # print every 1000 mini-batches,每1000步打印一次
with torch.no_grad(): # 在以下步骤中(验证过程中)不用计算每个节点的损失梯度,防止内存占用
outputs = net(test_image.to(device)) # 测试集传入网络(test_batch_size=10000),output维度为[10000,10]
predict_y = torch.max(outputs, dim=1)[1] # 以output中值最大位置对应的索引(标签)作为预测输出
accuracy = (predict_y == test_label.to(device)).sum().item() / test_label.size(0)
print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' % # 打印epoch,step,loss,accuracy
(epoch + 1, step + 1, running_loss / 500, accuracy))
print('%f s' % (time.perf_counter() - time_start)) # 打印耗时
running_loss = 0.0
print('Finished Training')
# 保存训练得到的参数
save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。