1 Star 0 Fork 1

ray7jq/ShortVideo Classification

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train.py 2.22 KB
一键复制 编辑 原始数据 按行查看 历史
ray7jq 提交于 2023-05-28 16:26 . V1.0
import torch
import torch.nn as nn
import torch.optim as optim
from model import ActionClassificationModel
from dataset import CustomImageFolder
from torchvision.transforms import transforms
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# 1. 定义数据预处理
train_transforms = transforms.Compose([
transforms.Resize((112, 112)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 2. 加载数据集
train_dataset = CustomImageFolder("dataset/val", transform=train_transforms)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=4, shuffle=False, num_workers=0
)
"""
num_classes 视频类别数
input_channels 输入图片通道数,默认为3
img_height 输入帧的高
img_width 输入帧的宽
time_steps 每个视频的抽取帧数(将一个视频变为time_steps帧的图片)
hidden_size = 256 LSTM的隐藏层
"""
num_classes = 5
input_channels = 3
img_height = 112
img_width = 112
time_steps = 20
hidden_size = 256
# 3. 定义模型
model = ActionClassificationModel(num_classes=num_classes, input_channels=input_channels, img_height=img_height,
img_width=img_width, time_steps=time_steps, hidden_size=hidden_size).to(device)
# 4. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 5. 训练模型
for epoch in range(10):
running_loss = 0.0
print('Training-', epoch)
for i, data in enumerate(train_loader, 0):
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 2 == 0: # 每 100 个 mini-batch 输出一次损失
print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 100))
running_loss = 0.0
# 保存模型文件
if (epoch + 1) % 10 == 0:
checkpoint_path = "model_epoch_{}.pth".format(epoch + 1)
torch.save(model.state_dict(), checkpoint_path)
print('Finished Training')
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ray7jq/short-video-classification.git
git@gitee.com:ray7jq/short-video-classification.git
ray7jq
short-video-classification
ShortVideo Classification
master

搜索帮助

D67c1975 1850385 1daf7b77 1850385