代码拉取完成,页面将自动刷新
"""
@Description : 模型测试
@Author : python_assignment_group
@Time : 2022/10/30 17:34:21
"""
import time
import warnings
import torch
from torch.utils.data import DataLoader, SequentialSampler
from configs import *
from nets import *
from tools.data_process import *
from tools.utils import *
warnings.filterwarnings("ignore")
nets = [Word2VecCNNNet, FastTextNet, BertNet] # 所有的网络
get_datasets = [Word2VecDataset, FastTextDataset, BertDataset] # Dataset获取
data_split = DataSplit(
train_configs[0]["raw_data_path"], test_percent=train_configs[0]["test_percent"], data_split_num=train_configs[0]["data_split_num"], resplit_data=False)
data_split()
# 数据集的测试结果
test_results = []
for i in range(len(nets)):
test_results.append([])
# 测试所有网络
for net_i in range(len(nets)):
# 选择网络对应参数
config = train_configs[net_i]
# 每个网络要测试data_split_num次
for i in range(1, 1+config["data_split_num"]):
# 加载数据集
_, valid_data = data_split.load_data(data_num=i)
# data_split.data_preprocess(data_num=1, data=valid_data)
valid_dataset = get_datasets[net_i](valid_data)
# 构建dataloader
valid_dataloader = DataLoader(
valid_dataset,
sampler=SequentialSampler(valid_dataset), # 按顺序测试
batch_size=config["batch_size"],
)
# 构建网络
net = nets[net_i]()
# 使用GPU
net.to(device)
# 加载模型的参数
cache_path = os.path.join(config["cache_path"], "data"+str(i))
load_net_stats(
cache_path, config["test_buffer_name"], net, None, mode="eval")
net.eval()
# 测试参数
t0 = time.time()
total_eval_accuracy = 0
print("Net:"+nets_names[net_i]+" Data:"+str(i)+'正在测试中...')
for batch in valid_dataloader:
# 不用反向传播
with torch.no_grad():
if net_i == 2:
outputs = net.forward(batch)
logits = outputs
b_labels = batch[2].to(device)
logits = logits.detach().cpu().numpy()
label_ids = b_labels.to('cpu').numpy()
else:
outputs = net.forward(batch)
logits = outputs
b_labels = batch[1].to(device)
logits = logits.detach().cpu().numpy()
label_ids = b_labels.to('cpu').numpy()
# 计算总测试准确率
total_eval_accuracy += flat_accuracy(logits, label_ids)
# 打印测试结果
validation_time = format_time(time.time() - t0)
avg_val_accuracy = total_eval_accuracy / len(valid_dataloader)
test_results[net_i].append(avg_val_accuracy)
print("测试完成!")
print("测试用时: {:}".format(validation_time))
print("准确率是{0:.4f}".format(avg_val_accuracy))
for i, net_name in enumerate(nets_names):
print(
net_name+"的平均测试准确率是:{}%".format(round(100*sum(test_results[i])/len(test_results[i]), 2)))
# 画图
fig_path = os.path.join("figs", "test_result_figs")
plot_for_test(fig_path, test_results)
np.save(os.path.join("figs", "info_for_figs", "test_results.npy"),
np.array(test_results))
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。