# [Hosting-page notice, not part of the script] Code fetch complete; the page will refresh automatically.
import torch
import torch.nn as nn
import torch.optim as optim
from model import ActionClassificationModel
from dataset import CustomImageFolder
from torchvision.transforms import transforms
# Select the compute device: the first CUDA GPU when CUDA is available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Using device:", device)
# 1. Define the data preprocessing pipeline applied to every image.
train_transforms = transforms.Compose(
    [
        # Resize each image to 112 x 112.
        transforms.Resize(size=(112, 112)),
        # Convert the image to a tensor.
        transforms.ToTensor(),
        # Normalize each of the three channels with the given mean / std
        # (these values match the commonly used ImageNet channel statistics).
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# 2. Load the dataset.
# NOTE(review): the training data is read from "dataset/val"; confirm this is
# intentional and that a dedicated training split was not meant instead.
train_dataset = CustomImageFolder("dataset/val", transform=train_transforms)
# shuffle=True reshuffles the samples at every epoch, so SGD does not see the
# mini-batches in the same fixed order each time (the original used
# shuffle=False for the training loader).
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=4, shuffle=True, num_workers=0
)

# Model hyperparameters:
#   num_classes     number of video classes
#   input_channels  number of channels of the input images (3 by default)
#   img_height      height of an input frame
#   img_width       width of an input frame
#   time_steps      number of frames sampled from each video
#                   (each video is turned into time_steps frame images)
#   hidden_size     size of the LSTM hidden layer
num_classes = 5
input_channels = 3
img_height = 112
img_width = 112
time_steps = 20
hidden_size = 256
# 3. Build the model from the hyperparameters, then move it to the
#    selected device.
model = ActionClassificationModel(
    num_classes=num_classes,
    input_channels=input_channels,
    img_height=img_height,
    img_width=img_width,
    time_steps=time_steps,
    hidden_size=hidden_size,
)
model = model.to(device)

# 4. Define the loss function and the optimizer (SGD with momentum).
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)
# 5. Train the model for 10 epochs.
for epoch in range(10):
    running_loss = 0.0
    # Number of mini-batches accumulated into running_loss since the last
    # log line; used as the divisor so the printed value is a true mean.
    batches_since_log = 0
    print('Training-', epoch)
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # Standard optimization step: clear old gradients, forward pass,
        # compute the loss, backpropagate, update the parameters.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        batches_since_log += 1
        # Log on every other mini-batch (i = 0, 2, 4, ...). The original
        # divided by a hard-coded 100 regardless of how many batches had
        # been accumulated, which under-reported the average loss.
        if i % 2 == 0:
            print('[%d, %5d] loss: %.3f'
                  % (epoch + 1, i + 1, running_loss / batches_since_log))
            running_loss = 0.0
            batches_since_log = 0

    # Save a checkpoint of the model weights after every 10th epoch.
    if (epoch + 1) % 10 == 0:
        checkpoint_path = "model_epoch_{}.pth".format(epoch + 1)
        torch.save(model.state_dict(), checkpoint_path)

print('Finished Training')
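# [Hosting-page notice, not part of the script] This location may contain content that is unsuitable for display, so the page does not show it. You can review and modify it yourself using the relevant editing features.
# If you confirm that the content involves no inappropriate language / pure advertising / violence / vulgar or pornographic material / infringement / piracy / falsehood / worthless content, nor anything that violates national laws and regulations, you can click Submit to appeal, and we will handle it as soon as possible.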