3 Star 1 Fork 0

tqychy/HUST_Python大作业

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
test.py 3.36 KB
一键复制 编辑 原始数据 按行查看 历史
tqychy 提交于 2023-01-01 14:40 . 更改测试图的画法。
"""
@Description : 模型测试
@Author : python_assignment_group
@Time : 2022/10/30 17:34:21
"""
import time
import warnings
import torch
from torch.utils.data import DataLoader, SequentialSampler
from configs import *
from nets import *
from tools.data_process import *
from tools.utils import *
warnings.filterwarnings("ignore")
nets = [Word2VecCNNNet, FastTextNet, BertNet] # 所有的网络
get_datasets = [Word2VecDataset, FastTextDataset, BertDataset] # Dataset获取
data_split = DataSplit(
train_configs[0]["raw_data_path"], test_percent=train_configs[0]["test_percent"], data_split_num=train_configs[0]["data_split_num"], resplit_data=False)
data_split()
# 数据集的测试结果
test_results = []
for i in range(len(nets)):
test_results.append([])
# 测试所有网络
for net_i in range(len(nets)):
# 选择网络对应参数
config = train_configs[net_i]
# 每个网络要测试data_split_num次
for i in range(1, 1+config["data_split_num"]):
# 加载数据集
_, valid_data = data_split.load_data(data_num=i)
# data_split.data_preprocess(data_num=1, data=valid_data)
valid_dataset = get_datasets[net_i](valid_data)
# 构建dataloader
valid_dataloader = DataLoader(
valid_dataset,
sampler=SequentialSampler(valid_dataset), # 按顺序测试
batch_size=config["batch_size"],
)
# 构建网络
net = nets[net_i]()
# 使用GPU
net.to(device)
# 加载模型的参数
cache_path = os.path.join(config["cache_path"], "data"+str(i))
load_net_stats(
cache_path, config["test_buffer_name"], net, None, mode="eval")
net.eval()
# 测试参数
t0 = time.time()
total_eval_accuracy = 0
print("Net:"+nets_names[net_i]+" Data:"+str(i)+'正在测试中...')
for batch in valid_dataloader:
# 不用反向传播
with torch.no_grad():
if net_i == 2:
outputs = net.forward(batch)
logits = outputs
b_labels = batch[2].to(device)
logits = logits.detach().cpu().numpy()
label_ids = b_labels.to('cpu').numpy()
else:
outputs = net.forward(batch)
logits = outputs
b_labels = batch[1].to(device)
logits = logits.detach().cpu().numpy()
label_ids = b_labels.to('cpu').numpy()
# 计算总测试准确率
total_eval_accuracy += flat_accuracy(logits, label_ids)
# 打印测试结果
validation_time = format_time(time.time() - t0)
avg_val_accuracy = total_eval_accuracy / len(valid_dataloader)
test_results[net_i].append(avg_val_accuracy)
print("测试完成!")
print("测试用时: {:}".format(validation_time))
print("准确率是{0:.4f}".format(avg_val_accuracy))
for i, net_name in enumerate(nets_names):
print(
net_name+"的平均测试准确率是:{}%".format(round(100*sum(test_results[i])/len(test_results[i]), 2)))
# 画图
fig_path = os.path.join("figs", "test_result_figs")
plot_for_test(fig_path, test_results)
np.save(os.path.join("figs", "info_for_figs", "test_results.npy"),
np.array(test_results))
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/tqychy/python-homework.git
git@gitee.com:tqychy/python-homework.git
tqychy
python-homework
HUST_Python大作业
master

搜索帮助