代码拉取完成,页面将自动刷新
# Name: valid
# Author: Reacubeth
# Time: 2021/8/25 10:30
# Mail: noverfitting@gmail.com
# Site: www.omegaxyz.com
# *_*coding:utf-8 *_*
import argparse
import numpy as np
import torch
import pickle
import time
import datetime
import os
import random
import utils
from cenet_model import CENET
def execute_valid(args, total_data, model,
data,
s_history, o_history,
s_label, o_label,
s_frequency, o_frequency):
s_ranks2 = []
o_ranks2 = []
all_ranks2 = []
s_ranks3 = []
o_ranks3 = []
all_ranks3 = []
total_data = utils.to_device(torch.from_numpy(total_data))
for batch_data in utils.make_batch(data,
s_history,
o_history,
s_label,
o_label,
s_frequency,
o_frequency,
args.batch_size):
batch_data[0] = utils.to_device(torch.from_numpy(batch_data[0]))
batch_data[3] = utils.to_device(torch.from_numpy(batch_data[3])).float()
batch_data[4] = utils.to_device(torch.from_numpy(batch_data[4])).float()
batch_data[5] = utils.to_device(torch.from_numpy(batch_data[5])).float()
batch_data[6] = utils.to_device(torch.from_numpy(batch_data[6])).float()
with torch.no_grad():
_, _, _, \
sub_rank2, obj_rank2, cur_loss2, \
sub_rank3, obj_rank3, cur_loss3, ce_all_acc = model(batch_data, 'Valid', total_data)
s_ranks2 += sub_rank2
o_ranks2 += obj_rank2
tmp2 = sub_rank2 + obj_rank2
all_ranks2 += tmp2
s_ranks3 += sub_rank3
o_ranks3 += obj_rank3
tmp3 = sub_rank3 + obj_rank3
all_ranks3 += tmp3
return s_ranks2, o_ranks2, all_ranks2, s_ranks3, o_ranks3, all_ranks3
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。