dongjiaceo/知识图谱推理

predict.py 2.88 KB
young-rich committed on 2022-07-27 16:27: added the multi-class model code
import pandas as pd
from sentence_transformers import CrossEncoder

# Project root; all inputs and outputs live under it.
BASE_DIR = '/home/yx/project/P_prediction/ccks_1_sbert'
# Pretrained backbones; one CrossEncoder was fine-tuned per (backbone, predicate) pair.
PRE_NAMES = ['hfl-chinese-bert-wwm-ext', 'hfl-chinese-macbert-base',
             'hfl-chinese-roberta-wwm-ext', 'hfl-chinese-electra-180g-base-discriminator']
# Predicates, in English: category_suits_scene, category_pairs-with_category,
# category_suits_person, person_implies_scene.
LABELS = ['品类_适用_场景', '品类_搭配_品类', '品类_适用_人物', '人物_蕴含_场景']

# The dev set is the same for every model, so load it once.
df_dev = pd.read_pickle(BASE_DIR + '/data/dev_data.pkl')

for pre_name in PRE_NAMES:
    for label in LABELS:
        model_dir = BASE_DIR + '/output/model_save_path_' + pre_name + '_' + label
        # Take the F1 threshold from the last row of the evaluator's results CSV.
        df_result = pd.read_csv(
            model_dir + '/CEBinaryClassificationEvaluator_sts-dev_results.csv')
        threshold = df_result['F1_Threshold'].values[-1]
        # Copy the filtered rows so that adding a column below does not
        # write into a slice of df_dev (avoids SettingWithCopyWarning).
        df_label_dev = df_dev.loc[df_dev['predicate'] == label].copy()
        # Each sample is a [subject, object] pair scored by the cross-encoder.
        test_samples = [[s1, s2] for s1, s2 in zip(df_label_dev['subject'].values,
                                                   df_label_dev['object'].values)]
        model = CrossEncoder(model_dir)
        results = model.predict(test_samples)
        # Binarize: 1 if the score is strictly above the threshold, else 0.
        df_label_dev['salience'] = [1 if score > threshold else 0 for score in results]
        df_label_dev.to_pickle(model_dir + '/result.pkl')
    # After all four predicates are scored, merge their results for this backbone.
    df_result1 = pd.read_pickle(
        BASE_DIR + '/output/model_save_path_' + pre_name + '_品类_搭配_品类/result.pkl')
    df_result2 = pd.read_pickle(
        BASE_DIR + '/output/model_save_path_' + pre_name + '_品类_适用_场景/result.pkl')
    df_result3 = pd.read_pickle(
        BASE_DIR + '/output/model_save_path_' + pre_name + '_品类_适用_人物/result.pkl')
    df_result4 = pd.read_pickle(
        BASE_DIR + '/output/model_save_path_' + pre_name + '_人物_蕴含_场景/result.pkl')
    df_result = pd.concat([df_result1, df_result2, df_result3, df_result4])
    # df_result[['salience', 'triple_id']].to_json(BASE_DIR + '/data/result_' + pre_name + '.jsonl', orient='records', lines=True)
    df_result[['salience', 'triple_id']].to_pickle(
        BASE_DIR + '/data/result_' + pre_name + '.pkl')
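
The thresholding step in predict.py turns each cross-encoder score into a 0/1 `salience` label by a strict comparison against `F1_Threshold`. The sketch below isolates that step so it can be checked without loading a model. The helper name `scores_to_labels` and the sample scores are hypothetical and not part of the repository; it only assumes that the scores arrive as a one-dimensional sequence of floats.

```python
import numpy as np


def scores_to_labels(scores, threshold):
    """Map scores to 0/1 labels: 1 where score > threshold, else 0.

    Hypothetical helper; mirrors the strict '>' comparison in predict.py.
    """
    scores = np.asarray(scores, dtype=float)
    return (scores > threshold).astype(int).tolist()


# Made-up scores for illustration; a score equal to the threshold maps to 0.
labels = scores_to_labels([0.12, 0.80, 0.55, 0.55001], threshold=0.55)
```

Because the comparison is strict, a score exactly equal to the threshold is labelled 0; whether that matches how the evaluator chose the threshold is worth confirming before relying on borderline cases.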