代码拉取完成,页面将自动刷新
import torch
import torch.nn as nn
from model import ActionClassificationModel
from dataset import CustomImageFolder
from torchvision.transforms import transforms
# 1. Define data preprocessing
test_transforms = transforms.Compose([
transforms.Resize((112, 112)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 2. Load test dataset
test_dataset = CustomImageFolder("dataset/val", transform=test_transforms)
test_loader = torch.utils.data.DataLoader(
test_dataset, batch_size=4, shuffle=False, num_workers=0
)
# 3. Load the saved model
model = ActionClassificationModel()
model.load_state_dict(torch.load("model_epoch_10.pth")) # Replace with your desired checkpoint file
# 4. Test the model on the test dataset
correct = 0
total = 0
with torch.no_grad():
for data in test_loader:
inputs, labels = data
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print('Accuracy of the network on the test images: %d %%' % accuracy)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。