from __future__ import absolute_import, division, print_function
import sys
import os
import argparse
from math import log
from datetime import datetime
from tqdm import tqdm
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.distributions import Categorical
import threading
from functools import reduce
from random import randint, sample
from KgEnv import BatchKGEnvironment
from TrainAgent import TestActorCritic,ActorCritic,PGPRActorCritic
from utils import *
import math
import pickle  # used by the save/load helpers below (utils may also export it)
import numpy as np  # used throughout for masks, scores and metrics
import warnings
warnings.filterwarnings('ignore')

# Global variables:
data_root='/mnt/ssd/zjyang/KAPR/OnlyProduct/AAAIData'
def save_dic(filename,dataset_obj):
    with open(filename,'wb') as f:
        pickle.dump(dataset_obj,f)

def load_pkl(file):
    with open(file,'rb') as f:
        return pickle.load(f)
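# Quick round-trip example for the two helpers above (the path is hypothetical):
#   save_dic('/tmp/demo.pkl', {'pid': [1, 2, 3]})
#   assert load_pkl('/tmp/demo.pkl') == {'pid': [1, 2, 3]}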
def evaluate_ijcai(topk_matches,test_product_rproducts,product):
    productidList=list(range(len(product)))
    invalid_r_products=[]
    num=0
    # Compute metrics
    test_product_idxs=list(test_product_rproducts.keys())
    test_num=0
    hit_num_10=0
    hit_num_30=0
    hit_num_50=0
    S=[]
    R=[]
    FailedRank={}
    TrueRank={}
    # pid is the target head-entity product (item1)
    for pid in test_product_idxs:
        num=num+1
        # Skip products for which no paths were found
        if pid not in topk_matches:
            invalid_r_products.append(pid)
            continue
        # pred_list is the predicted ranking (reversed to descending order);
        # rel_set is the ground truth; r below is the target product (item2)
        pred_list, rel_set = topk_matches[pid][::-1], test_product_rproducts[pid]
        for r in rel_set:
            productCopy=list(productidList)
            productCopy.remove(pid)
            if r in productCopy:
                productCopy.remove(r)
            # Sample 500 random negatives to rank r against
            random_product=sample(productCopy,500)
            test_num+=1
            # If r appears in the prediction list
            if r in pred_list:
                r_index=pred_list.index(r)
                R.append(r_index)
                if pid not in TrueRank.keys():
                    TrueRank[pid]=[]
                TrueRank[pid].append(r)
                # Count the sampled negatives ranked ahead of r
                ForwardProductList=pred_list[:r_index]
                SetLen=len(set(ForwardProductList)&set(random_product))
                S.append(SetLen)
                # Hit@10
                if SetLen<10:
                    hit_num_10+=1
                # Hit@30
                if SetLen<30:
                    hit_num_30+=1
                # Hit@50
                if SetLen<50:
                    hit_num_50+=1
            # If r does not appear in the prediction list
            else:
                if pid not in FailedRank.keys():
                    FailedRank[pid]=[]
                FailedRank[pid].append(r)
                continue
    Hit10=hit_num_10/test_num
    Hit30=hit_num_30/test_num
    Hit50=hit_num_50/test_num
    S.sort(reverse=True)
    R.sort(reverse=True)
    # FailedRank can also be dumped for failure analysis, e.g.:
    #   save_dic('/mnt/ssd/zjyang/KAPR/OnlyProduct/FailAna/baby_comp.pkl',FailedRank)
    TruePath='/mnt/ssd/zjyang/KAPR/OnlyProduct/FailAna/baby_sub_true.pkl'
    save_dic(TruePath,TrueRank)
    return Hit10,Hit30,Hit50,test_num
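# A minimal, self-contained sketch of the sampled-negative Hit@K protocol that
# evaluate_ijcai implements above: for each ground-truth pair (pid, r), draw
# random negatives (500 in this script) and count a hit@K when fewer than K of
# those negatives rank ahead of r. The name `demo_hit_at_k` and its arguments
# are illustrative only, not part of this module's API.
def demo_hit_at_k(pred_list,r,negatives,k=10):
    if r not in pred_list:
        return 0
    ahead=set(pred_list[:pred_list.index(r)])
    return 1 if len(ahead&set(negatives))<k else 0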
def evaluate(topk_matches,test_product_rproducts,best_paths):
    predict_path_file=args.pre_dir
    pickle.dump(best_paths,open(predict_path_file,'wb'))
    invalid_r_products=[]
    # Compute metrics: count the total number of hits
    FailPid=[]
    H=0
    HowManyProductsHaveTheirOwn=0
    targeted_path=[]
    hitss=[]
    hit_add=0
    precisions,recalls,ndcgs,hits=[],[],[],[]
    test_product_idxs=list(test_product_rproducts.keys())
    for pid in test_product_idxs:
        # Skip products with no matches or with fewer than 10 matches
        if pid not in topk_matches or len(topk_matches[pid])<10:
            invalid_r_products.append(pid)
            continue
        pred_list, rel_set = topk_matches[pid][::-1], test_product_rproducts[pid]
        if len(pred_list)==0:
            continue
        dcg = 0.0
        hit_num = 0.0
        KKK=10
        for i in range(KKK):
            if pred_list[i] == pid:
                HowManyProductsHaveTheirOwn=HowManyProductsHaveTheirOwn+1
            if pred_list[i] in rel_set:
                dcg += 1. / (log(i + 2) / log(2))
                hit_num += 1
                H += 1
                # Products appended later via TransE (when fewer than 10 paths
                # were found) have no corresponding path; exclude them here
                if pid in best_paths.keys() and pred_list[i] in best_paths[pid].keys():
                    targeted_path.append(best_paths[pid][pred_list[i]][0])
        # idcg
        idcg = 0.0
        for i in range(min(len(rel_set), len(pred_list))):
            idcg += 1. / (log(i + 2) / log(2))
        ndcg = dcg / idcg
        recall = hit_num / len(rel_set)
        precision = hit_num / len(pred_list)
        hit = 1.0 if hit_num > 0.0 else 0.0
        # Record products that were never hit
        if hit==0:
            FailPid.append(pid)
        hit_add=hit_add+hit
        ndcgs.append(ndcg)
        recalls.append(recall)
        precisions.append(precision)
        hits.append(hit)
    avg_precision = np.mean(precisions) * 100
    avg_recall = np.mean(recalls) * 100
    avg_ndcg = np.mean(ndcgs) * 100
    avg_hit = np.mean(hits) * 100
    print('NDCG={:.3f} | Recall={:.3f} | HR={:.3f} | Precision={:.3f} '.format(avg_ndcg, avg_recall, avg_hit, avg_precision))
    # Save the paths that produced hits
    targeted_path_file=args.target_dir
    pickle.dump(targeted_path,open(targeted_path_file,'wb'))
    return avg_ndcg,avg_recall,avg_hit,avg_precision
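# A self-contained sketch of the DCG/NDCG@K computed in evaluate() above,
# assuming log2 discounting, gain 1 per relevant item, and a prediction list of
# length >= k (as with KKK=10 there); `demo_ndcg_at_k` is an illustrative name.
def demo_ndcg_at_k(pred_list,rel_set,k=10):
    dcg=sum(1./(log(i+2)/log(2)) for i in range(k) if pred_list[i] in rel_set)
    idcg=sum(1./(log(i+2)/log(2)) for i in range(min(len(rel_set),k)))
    return dcg/idcg if idcg>0 else 0.0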
def load_dic(file):
    try:
        with open(file,'rb') as f:
            return pickle.load(f)
    except EOFError: # return None on EOFError
        return None

# Beam-search over the KG from each pid: returns the explored paths and the
# per-hop action probabilities along each path.
def batch_beam_search(env,model,pids,device,topk=[25,5,1],train_time=1):
    def _batch_acts_to_masks(batch_acts):
        batch_mask=[]
        for acts in batch_acts:
            num_acts=len(acts)
            # Build an act_mask marking which action slots are valid
            act_mask=np.zeros(model.act_dim,dtype=np.uint8)
            # The first num_acts slots are available actions; mark them as 1
            act_mask[:num_acts]=1
            batch_mask.append(act_mask)
        return np.vstack(batch_mask)
    state_pool,cur_node_type,cur_node_id=env.reset(pids)
    path_pool=env._batch_path
    probs_pool=[[] for _ in pids]
    # Keep all paths explored across hops
    all_path_pool=[]
    all_probs_pool=[]
    model.eval()
    # Explore paths of length 3
    for hop in range(3):
        # acts_pool is the pruned action space (up to max_acts=250 actions)
        acts_pool=env._batch_get_actions(path_pool,False)
        actmask_pool=_batch_acts_to_masks(acts_pool)
        action_emb=env.batch_action_embedding(cur_node_type,cur_node_id,acts_pool)
        state_tensor=torch.FloatTensor(state_pool).to(device) # (batch, 400)
        actmask_tensor=torch.ByteTensor(actmask_pool).to(device) # (batch, 251)
        action_emb_tensor=torch.FloatTensor(action_emb).to(device) # (batch, 251, 200)
        # Obtain the selected action IDs for each batch row
        if train_time==13 or train_time==77:
            # Use the original PGPR policy network
            probs,_=model((state_tensor,actmask_tensor)) # probs: (batch, 251)
            probs=probs+actmask_tensor.float()
            # As in PGPR, take the top-k actions (e.g. 25/5/1 per hop) to walk
            topk_probs, topk_idxs = torch.topk(probs, topk[hop], dim=1) # LongTensor of [bs, k]
            topk_idxs = topk_idxs.detach().cpu().numpy() # (batch, k)
            topk_probs = topk_probs.detach().cpu().numpy() # (batch, k)
        else:
            # The TestActorCritic model returns the top-k indices/probs directly
            topk_idxs,topk_probs=model((state_tensor,action_emb_tensor,actmask_tensor,topk[hop]))
            # (Alternative, disabled: compute per-action probs from
            #  (state, action_emb, actmask) and call torch.topk here instead.)
        # Reset new_path each hop, extending the previous hop's paths
        new_path_pool, new_probs_pool = [], []
        # Iterate over every row in the batch
        for row in range(topk_idxs.shape[0]):
            path = path_pool[row]
            probs = probs_pool[row]
            # topk may be a vector of indices or a single scalar; normalize to lists
            if isinstance(topk_idxs[row], np.ndarray):
                row_idxs,row_probs=topk_idxs[row],topk_probs[row]
            else:
                row_idxs,row_probs=[topk_idxs[row]],[topk_probs[row]]
            # Iterate over each selected action in the row
            for idx, p in zip(row_idxs, row_probs):
                if idx >= len(acts_pool[row]): # act idx is invalid
                    continue
                relation, next_node_id = acts_pool[row][idx] # (relation, next_node_id)
                if relation == SELF_LOOP:
                    next_node_type = path[-1][1]
                else:
                    next_node_type = KG_RELATION[path[-1][1]][relation]
                new_path = path + [(relation, next_node_type, next_node_id)]
                new_path_pool.append(new_path)
                new_probs_pool.append(probs + [p])
        path_pool = new_path_pool
        probs_pool = new_probs_pool
        if hop < 2:
            state_pool= env._batch_get_state(path_pool)
            cur_node_type=env._batch_get_cur_node_type(path_pool)
            cur_node_id=env._batch_get_cur_node_id(path_pool)
        all_path_pool.extend(new_path_pool)
        all_probs_pool.extend(new_probs_pool)
    # train_time 77 returns only the final-hop paths; otherwise return every
    # path accumulated across hops
    if train_time==77:
        return path_pool, probs_pool
    else:
        return all_path_pool, all_probs_pool
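# Hedged usage sketch for batch_beam_search (assumes an initialized
# BatchKGEnvironment `env` and a loaded `model`, as built in predict_paths
# below; the pids and topk values are illustrative):
#   paths,probs=batch_beam_search(env=env,model=model,pids=[0,1,2,3],
#                                 device=args.device,topk=[25,5,1])
#   # Each path looks like [(SELF_LOOP,'product',pid),(rel,type,id),...] and
#   # the matching probs entry holds the per-hop action probabilities.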
def predict_paths(policy_file,path_file,test_labels,args):
    # (Train labels and the KEIM reward matrix are not needed here; they are
    #  loaded later in evaluate_paths.)
    env=BatchKGEnvironment(args.dataset,max_acts=args.max_acts,max_path_len=args.max_path_len,state_history=args.state_history,relation=args.relation,type=args.norel,train_time=args.train_time,delrel=args.delrel)
    pretrain_sd=torch.load(policy_file)
    if args.train_time==13 or args.train_time==77:
        model=PGPRActorCritic(env.state_dim,env.act_dim,gamma=args.gamma,hidden_sizes=args.hidden).to(args.device)
    else:
        model=TestActorCritic(env.state_dim,env.act_dim,gamma=args.gamma,hidden_sizes=args.hidden).to(args.device)
    model_sd=model.state_dict()
    model_sd.update(pretrain_sd)
    model.load_state_dict(model_sd)
    test_pids=list(test_labels.keys())
    batch_size=4
    start_idx=0
    all_paths,all_probs=[],[]
    pbar=tqdm(total=len(test_pids))
    while start_idx<len(test_pids):
        end_idx=min(start_idx+batch_size,len(test_pids))
        batch_pids=test_pids[start_idx:end_idx]
        if len(batch_pids)==0:
            break
        paths,probs=batch_beam_search(env=env,model=model,pids=batch_pids,device=args.device,topk=args.topk,train_time=args.train_time)
        all_paths.extend(paths)
        all_probs.extend(probs)
        start_idx=end_idx
        pbar.update(batch_size)
    predicts={'paths':all_paths,'probs':all_probs}
    pickle.dump(predicts,open(path_file,'wb'))
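# The dumped dict can be reloaded for offline inspection, e.g.:
#   predicts=pickle.load(open(path_file,'rb'))
#   # predicts['paths'][i] and predicts['probs'][i] are aligned.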
def evaluate_paths(path_file,train_labels,test_labels,reward_score,Anotherreward_score,product_list,args):
    # Load the sets of products that appear as related items in training
    comp_trainPath=os.path.join(data_root,args.dataset,'comp_p_p_train.pkl')
    sub_trainPath=os.path.join(data_root,args.dataset,'sub_p_p_train.pkl')
    comp_train=load_dic(comp_trainPath)
    sub_train=load_dic(sub_trainPath)
    comp_related=[]
    sub_related=[]
    for i in comp_train:
        comp_related.extend(comp_train[i])
    comp_related=set(comp_related)
    for i in sub_train:
        sub_related.extend(sub_train[i])
    sub_related=set(sub_related)
embeds=load_embed(args.dataset,type=args.norel)
Patterns=[]
product_embeds=embeds[PRODUCT]
if args.relation[0]=='COMP':
try:
relation_embeds=embeds[COMP][0]
except:
relation_embeds=embeds[SUB][0]
related=comp_related
for pattern_id in PATH_PATTERN_COMP.keys():
pattern=PATH_PATTERN_COMP[pattern_id]
pattern=[SELF_LOOP]+[v[0] for v in pattern[1:]]
Patterns.append(tuple(pattern))
if args.relation[0]=='SUB':
try:
relation_embeds=embeds[SUB][0]
except:
relation_embeds=embeds[COMP][0]
related=sub_related
for pattern_id in PATH_PATTERN_SUB.keys():
pattern=PATH_PATTERN_SUB[pattern_id]
pattern=[SELF_LOOP]+[v[0] for v in pattern[1:]]
Patterns.append(tuple(pattern))
    results=pickle.load(open(path_file,'rb'))
    pred_paths={pid:{} for pid in test_labels}
    i=0
    pbar=tqdm(total=len(results['paths']))
    relation_list=[]
    ProductNum=0
    CleanNum=0
    Score=[]
    for path,probs in zip(results['paths'],results['probs']):
        i=i+1
        if i%1000==0:
            pbar.update(1000)
        pattern=tuple([v[0] for v in path])
        if path[-1][1] != PRODUCT:
            continue
        # This "related" filter is essential: the tail entity must appear as a
        # related product somewhere in training
        if path[-1][2] not in related:
            continue
        # Optionally, require the path to match a template pattern; empirically,
        # not enforcing templates works slightly better:
        #   if pattern not in Patterns: continue
        # (At test time, candidates could also be pruned by score_numpy; the
        #  effect of different pruning thresholds is untested. Paths between the
        #  same pair could also be merged.)
        # pid is the query product, r_pid the retrieved related product; the
        # final evaluation is keyed on these ids
        pid=path[0][2]
        r_pid=path[-1][2]
        # Score the pair: a TransE-style score plus the two KEIM reward scores,
        # which can rule out some unlikely products
        score1=np.dot(embeds['product'][pid]+relation_embeds,embeds['product'][r_pid].T)
        score2=reward_score[pid,r_pid]
        score3=Anotherreward_score[pid,r_pid]
        # (Disabled: skip pairs with score1<0 and score2<0.5.)
        ProductNum+=1
        if r_pid not in pred_paths[pid]:
            pred_paths[pid][r_pid]=[]
        # Use the TransE score of the two products under the given relation as
        # the path score (the discriminator score reward_score[pid,r_pid] is an
        # alternative)
        path_score=score1
        # Path probability: product of the per-hop action probabilities. Note
        # this is questionable when path lengths are mixed, since 3-hop and
        # 4-hop paths are not directly comparable under a plain product.
        path_prob=reduce(lambda x,y:x*y,probs)
        # Record (path_score, path_prob, path) linking query pid and r_pid
        pred_paths[pid][r_pid].append((path_score,path_prob,path))
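        # A hedged alternative (illustrative, not used in this script): a
        # length-normalized log-probability would make 3-hop and 4-hop paths
        # comparable, e.g.
        #   path_prob=math.exp(sum(math.log(max(s,1e-12)) for s in probs)/len(probs))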
    # (Sanity check, disabled: intersect pred_paths with test_labels to count
    #  how many correct product pairs survive the score_numpy filter.)
    # For each product pair, select the best path; pairs already seen in the
    # training set are removed
    relation_list=list(set(relation_list))
    best_pred_paths={}
    best_paths={}
    for pid in pred_paths:
        # The product may or may not have training-set labels
        if pid in train_labels.keys():
            train_r_pids=set(train_labels[pid])
        else:
            train_r_pids=set()
        best_pred_paths[pid]=[]
        for r_pid in pred_paths[pid]:
            if r_pid in train_r_pids:
                continue
            # Filter out products whose score_numpy values are below 0.5
            if args.train_time!=77 and args.train_time!=11:
                if reward_score[pid,r_pid]<0.5 and Anotherreward_score[pid,r_pid]<0.5:
                    continue
            # Rank the paths between pid and r_pid and keep the best one.
            # Explanation paths are chosen by prob, which tends to involve more
            # related products and richer explanatory information.
            sorted_path=sorted(pred_paths[pid][r_pid],key=lambda x:x[1],reverse=True)
            best_pred_paths[pid].append(sorted_path[0])
            # If every product retrieved for pid lies in the training set, pid
            # is absent from best_paths yet still present in pred_paths.keys()
            if pid not in best_paths.keys():
                best_paths[pid]={}
            if r_pid not in best_paths[pid].keys():
                best_paths[pid][r_pid]=[]
            best_paths[pid][r_pid].append(sorted_path[0])
    for pid in test_labels:
        if pid not in best_pred_paths.keys():
            best_paths[pid]={}
    # best_pred_paths now holds one reasoning path per product pair; next,
    # select the top K products for each query, ranked by score or by prob
    pred_labels={}
    AppendNum=0
    print('best_pred_paths length',len(best_pred_paths))
    ProNum=0
    PPbar=tqdm(total=len(best_pred_paths))
    for pid in best_pred_paths:
        ProNum=ProNum+1
        if ProNum%100==0:
            PPbar.update(100)
        if args.sortby=='score':
            sorted_path=sorted(best_pred_paths[pid],key=lambda x:(x[0],x[1]),reverse=True)
        elif args.sortby=='prob':
            sorted_path=sorted(best_pred_paths[pid],key=lambda x:(x[1],x[0]),reverse=True)
        top10_r_pids=[p[-1][2] for _,_,p in sorted_path[:args.K]] # ranked by descending score
        # If fewer than K products were found, pad with extra products
        if args.add_products and len(top10_r_pids)<args.K:
            if pid not in train_labels.keys():
                train_r_pids=set()
            else:
                train_r_pids=set(train_labels[pid])
            train_r_pids.add(pid)
            # Use the TransE score as the padding criterion (reward_score or
            # the KEIM scores would be alternatives)
            scores=np.dot(embeds[PRODUCT][pid]+relation_embeds,embeds[PRODUCT].T)
            cand_r_pids=np.argsort(scores)
            # Append extra products in descending TransE-score order
            for cand_r_pid in cand_r_pids[::-1]:
                if cand_r_pid in train_r_pids or cand_r_pid in top10_r_pids:
                    continue
                top10_r_pids.append(cand_r_pid)
                AppendNum+=1
                if len(top10_r_pids)>=args.K:
                    break
        # Stop padding; store in ascending order (evaluate() reverses it back)
        pred_labels[pid]=top10_r_pids[::-1]
    print('AppendNum:',AppendNum)
    avg_ndcg,avg_recall,avg_hit,avg_precision=evaluate(pred_labels,test_labels,best_paths)
    print('NDCG={:.3f} | Recall={:.3f} | HR={:.3f} | Precision={:.3f} '.format(avg_ndcg,avg_recall,avg_hit,avg_precision))
    # (Disabled: the sampled-negative protocol can be run instead, e.g.
    #  hit_10,hit_30,hit_50,test_num=evaluate_ijcai(pred_labels,test_labels,product_list)
    #  and reported as Hit@10/Hit@30/Hit@50, selected via args.eval.)
    return avg_ndcg,avg_recall,avg_hit,avg_precision
def load_txt(file):
with open(file, 'r') as f:
return [line.strip() for line in f]
def test(args):
    product_file=os.path.join(data_root,args.dataset,'product.txt')
    product_list=load_txt(product_file)
    # When train_time is 5, the generic model (saved under SUB) is used
    policy_file=args.log_dir+'/policy_model_epoch_whole_{}_{}_{}.ckpt'.format(args.train_time,args.norel,str(args.max_acts))
    if args.train_time==77:
        path_file=args.log_dir+'/pgpr_policy_paths_epoch_whole_{}_{}_{}.pkl'.format(args.train_time,args.norel,str(args.max_acts))
    else:
        path_file=args.log_dir+'/policy_paths_epoch_whole_{}_{}_{}.pkl'.format(args.train_time,args.norel,str(args.max_acts))
    print('policy_file:',policy_file)
    train_labels=load_labels(args.dataset,args.relation[0],'train',type='whole')
    test_labels=load_labels(args.dataset,args.relation[0],'test',type='whole')
    # (Optionally, subsample test_labels to ~400 products for quick debugging.)
    print(len(test_labels))
    result={}
    TestNum=args.test_num
    for i in range(TestNum):
        if args.run_path:
            predict_paths(policy_file,path_file,test_labels,args)
        if args.run_eval:
            evaluate_paths(path_file,train_labels,test_labels,reward_score,Anotherreward_score,product_list,args)
    # (Disabled: collect (NDCG, Recall, HR, Precision) per run in `result` and
    #  print the run with the best HR as the final result. The ijcai protocol
    #  with Hit@10/30/50 can be selected via args.eval.)
if __name__ == '__main__':
    boolean=lambda x:(str(x).lower()=='true')
    parser=argparse.ArgumentParser()
    parser.add_argument('--dataset',type=str,default='ele',help='dataset name, e.g. one of {ele,baby}')
    parser.add_argument('--name',type=str,default='TrainAgent',help='directory name')
    parser.add_argument('--seed',type=int,default=123,help='random seed')
    parser.add_argument('--gpu',type=str,default='0',help='gpu device')
    parser.add_argument('--epochs',type=int,default=100,help='number of epochs')
    parser.add_argument('--max_acts',type=int,default=250,help='max number of actions')
    parser.add_argument('--max_path_len',type=int,default=3,help='max path length')
    parser.add_argument('--gamma',type=float,default=0.99,help='reward discount factor')
    parser.add_argument('--state_history',type=int,default=1,help='state history length')
    parser.add_argument('--hidden',type=int,nargs='*',default=[512,256],help='hidden layer sizes')
    parser.add_argument('--add_products',type=boolean,default=True,help='pad predictions up to K products via TransE scores')
    # (Original PGPR default: topk=[25,5,1])
    parser.add_argument('--topk',type=int,nargs='*',default=[25,10,2],help='beam widths per hop')
    parser.add_argument('--run_path',type=boolean,default=True,help='generate predicted paths?')
    parser.add_argument('--run_eval',type=boolean,default=True,help='run evaluation?')
    parser.add_argument('--relation',type=str,nargs='*',default=['COMP'],help='relation type, COMP or SUB')
    parser.add_argument('--K',type=int,default=1000,help='size K of the returned ranking (hit@K)')
    parser.add_argument('--goal',type=str,default='None',help='the goal of the experiment')
    parser.add_argument('--whole_data',type=boolean,default=True,help='use the whole dataset?')
    parser.add_argument('--test_num',type=int,default=1,help='number of test runs')
    parser.add_argument('--train_time',type=int,default=7,help='train_time (13/77 select the PGPR policy)')
    parser.add_argument('--threshold',type=float,default=-1,help='reward score threshold')
    parser.add_argument('--sortby',type=str,default='score',help='rank paths by score or prob')
    parser.add_argument('--eval',type=str,default='normal',help='evaluation protocol: normal or ijcai')
    parser.add_argument('--norel',type=str,default='0',help='norel setting forwarded to BatchKGEnvironment')
    parser.add_argument('--delrel',type=str,default='None',help='relation to drop, forwarded to BatchKGEnvironment')
    args=parser.parse_args()
    print(args)
    print('********************')
    print(args.dataset,args.relation,args.train_time)
    print('********************')
    os.environ['CUDA_VISIBLE_DEVICES']=args.gpu
    args.device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'
    args.log_dir=TMP_DIR[args.dataset]+'/'+args.name+'/'+args.relation[0]
    if not os.path.isdir(args.log_dir):
        os.makedirs(args.log_dir)
    # Explainable paths behind the hits
    args.target_dir=TMP_DIR[args.dataset]+'/'+args.name+'/'+args.relation[0]+'/targeted_path_{}_{}_{}.pkl'.format(args.train_time,args.norel,str(args.max_acts))
    # Explainable paths used to predict new products
    args.pre_dir=TMP_DIR[args.dataset]+'/'+args.name+'/'+args.relation[0]+ '/predict1_path_{}_{}_{}.pkl'.format(args.train_time,args.norel,str(args.max_acts))
    # Load the KEIM reward matrix for the chosen relation, plus the matrix for
    # the complementary relation
    scorePath=os.path.join('/mnt/ssd/zjyang/KAPR/OnlyProduct/AAAITmp',args.dataset,'KEIM',args.relation[0],'score_numpy.npy')
    reward_score=np.load(scorePath)
    if args.relation[0]=='COMP':
        AnotherscorePath=os.path.join('/mnt/ssd/zjyang/KAPR/OnlyProduct/AAAITmp',args.dataset,'KEIM','SUB','score_numpy.npy')
        Anotherreward_score=np.load(AnotherscorePath)
    elif args.relation[0]=='SUB':
        AnotherscorePath=os.path.join('/mnt/ssd/zjyang/KAPR/OnlyProduct/AAAITmp',args.dataset,'KEIM','COMP','score_numpy.npy')
        Anotherreward_score=np.load(AnotherscorePath)
    # (Disabled: when train_time==4, linearly rescale reward_score, mapping
    #  each entry x to (Nmax-Nmin)*x+Nmin with Nmax=1, Nmin=-1.)
    test(args)