Lu/CENET

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
valid.py 1.97 KB
omegaxyz committed on 2023-04-06 22:12: add validation process
# -*- coding: utf-8 -*-
# Name: valid
# Author: Reacubeth
# Time: 2021/8/25 10:30
# Mail: noverfitting@gmail.com
# Site: www.omegaxyz.com
import argparse
import numpy as np
import torch
import pickle
import time
import datetime
import os
import random
import utils
from cenet_model import CENET


def execute_valid(args, total_data, model,
                  data,
                  s_history, o_history,
                  s_label, o_label,
                  s_frequency, o_frequency):
    """Run the model on the validation split and collect per-query ranks.

    Returns six lists: subject, object and combined ranks from the model's
    second rank output (suffix 2), then the same three lists for its third
    rank output (suffix 3). This function does not call model.eval(); do so
    beforehand if the model contains dropout or batch-norm layers.
    """
    s_ranks2 = []
    o_ranks2 = []
    all_ranks2 = []
    s_ranks3 = []
    o_ranks3 = []
    all_ranks3 = []
    # Move the full fact array to the device once; it is passed to the model on every batch.
    total_data = utils.to_device(torch.from_numpy(total_data))
    for batch_data in utils.make_batch(data,
                                       s_history,
                                       o_history,
                                       s_label,
                                       o_label,
                                       s_frequency,
                                       o_frequency,
                                       args.batch_size):
        # Index 0 holds the batch of facts from `data`; indices 3-6 hold
        # s_label, o_label, s_frequency and o_frequency, converted to float
        # tensors. Indices 1 and 2 (the histories) are passed through unchanged.
        batch_data[0] = utils.to_device(torch.from_numpy(batch_data[0]))
        for i in (3, 4, 5, 6):
            batch_data[i] = utils.to_device(torch.from_numpy(batch_data[i])).float()
        with torch.no_grad():
            (_, _, _,
             sub_rank2, obj_rank2, cur_loss2,
             sub_rank3, obj_rank3, cur_loss3, ce_all_acc) = model(batch_data, 'Valid', total_data)
        # The ranks are accumulated as Python lists, so `+` concatenates:
        # all_ranks* pools subject-side and object-side queries together.
        s_ranks2 += sub_rank2
        o_ranks2 += obj_rank2
        all_ranks2 += sub_rank2 + obj_rank2
        s_ranks3 += sub_rank3
        o_ranks3 += obj_rank3
        all_ranks3 += sub_rank3 + obj_rank3
    return s_ranks2, o_ranks2, all_ranks2, s_ranks3, o_ranks3, all_ranks3
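
execute_valid relies on utils.make_batch to walk seven parallel sequences in lockstep, with the batch size as the last positional argument, and then overwrites some entries of each yielded batch in place. The source of utils.make_batch is not shown here, so the following is only a minimal sketch of that batching pattern under stated assumptions: make_batch_sketch is a hypothetical name, and it assumes every input supports len() and slicing (NumPy arrays or Python lists).

```python
import numpy as np


def make_batch_sketch(*arrays_and_batch_size):
    """Hypothetical stand-in for utils.make_batch (not the repository's code).

    Takes several parallel sequences followed by a batch size (the calling
    convention execute_valid uses) and yields one list of aligned slices per
    batch. The final batch may be smaller than batch_size.
    """
    *arrays, batch_size = arrays_and_batch_size
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    n = len(arrays[0])
    if any(len(a) != n for a in arrays):
        raise ValueError("all sequences must have the same length")
    for start in range(0, n, batch_size):
        # Yield a list, not a tuple, so the caller can overwrite entries in
        # place, as execute_valid does with batch_data[0] and batch_data[3..6].
        yield [a[start:start + batch_size] for a in arrays]


# Usage: two aligned arrays of length 5 with batch size 2 give batches of 2, 2 and 1 rows.
facts = np.arange(15).reshape(5, 3)
labels = np.ones((5, 4), dtype=np.float32)
for batch in make_batch_sketch(facts, labels, 2):
    print(batch[0].shape, batch[1].shape)
```

Returning a mutable list per batch is the design choice that matters here: the in-place assignments such as `batch_data[0] = ...` in execute_valid would raise a TypeError on a tuple.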
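
execute_valid returns raw rank lists rather than summary scores. A common way to summarize such lists in link-prediction evaluation is mean reciprocal rank (MRR) together with Hits@k. The sketch below assumes each rank is 1-based (rank 1 means the correct entity was scored highest). summarize_ranks is a hypothetical helper, not part of the CENET repository, and CENET's own evaluation code (for example its raw versus filtered settings) may compute these metrics differently.

```python
def summarize_ranks(ranks, ks=(1, 3, 10)):
    """Hypothetical helper: MRR and Hits@k from a list of 1-based ranks.

    Each rank may be an int, a float, or a one-element tensor/array scalar;
    float() converts all of these.
    """
    values = [float(r) for r in ranks]
    if not values:
        raise ValueError("no ranks to summarize")
    if min(values) < 1:
        raise ValueError("ranks are assumed to be 1-based")
    n = len(values)
    # MRR: average of 1 / rank over all queries.
    mrr = sum(1.0 / r for r in values) / n
    # Hits@k: fraction of queries whose correct answer ranks within the top k.
    hits = {k: sum(1 for r in values if r <= k) / n for k in ks}
    return mrr, hits


# Usage with a toy list standing in for all_ranks2 from execute_valid.
# MRR = (1 + 1/2 + 1/4 + 1/10) / 4 = 0.4625
mrr, hits = summarize_ranks([1, 2, 4, 10])
print(f"MRR={mrr:.4f}", {f"Hits@{k}": v for k, v in hits.items()})
```

Because all_ranks2 and all_ranks3 are plain concatenations of the subject-side and object-side ranks, passing either one to a helper like this averages over both prediction directions, while s_ranks* and o_ranks* give the per-direction figures.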
https://gitee.com/yiweilu/CENET.git
git@gitee.com:yiweilu/CENET.git