2 Star 10 Fork 1

陈泽艇/手写汉字识别

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
test.py 1.24 KB
一键复制 编辑 原始数据 按行查看 历史
陈泽艇 提交于 2022-03-19 18:05 . 设置cuda
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from config import args
from MyDataset import MyDataset
from model import Cnn1, Cnn2
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# 获取测试集数据
transform = transforms.ToTensor()
test_set = MyDataset(args.root + '/HWDB1.1tst_gnt', transforms=transform)
test_loader = DataLoader(test_set, batch_size=args.batch_size)
# 加载 Train 模型
device = torch.device('cuda' if args.cuda else 'cpu')
model = eval(args.model + '().to(device)') # Cnn().to(device)
model.load_state_dict(torch.load(args.root + '/param', map_location=device))
model.eval()
criterion = nn.CrossEntropyLoss()
eval_acc = 0
eval_loss = 0
# 测试
for data in test_loader:
image, label = data
image, label = image.to(device), label.to(device)
out = model(image) # 前行算法
loss = criterion(out, label) # 计算loss
eval_loss += loss.item() * label.size(0) # 计算总的loss
_, pred = torch.max(out, 1) # 预测结果
num_correct = (pred == label).sum() # 正确结果
eval_acc += num_correct.item() # 正确结果总数
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_set)), eval_acc / (len(test_set))))
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/zeting-chen/wordCNN.git
git@gitee.com:zeting-chen/wordCNN.git
zeting-chen
wordCNN
手写汉字识别
master

搜索帮助