1 Star 0 Fork 2

dongjiaceo/知识图谱推理

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
main.py 2.60 KB
一键复制 编辑 原始数据 按行查看 历史
young-rich 提交于 2022-07-27 16:27 . 添加了多分类的模型代码
import pandas as pd
import numpy as np
import json
import jsonlines
from pip import main
from sentence_transformers import SentenceTransformer, InputExample, losses, models, CrossEncoder
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, BinaryClassificationEvaluator
from torch.utils.data import DataLoader
from torch import nn
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
import logging
from datetime import datetime
import math
from sklearn.model_selection import train_test_split
# Predicates handled by train() below, one model per predicate:
# '品类_适用_场景', '品类_搭配_品类', '品类_适用_人物', '人物_蕴含_场景'
def train(pre_model,
          data_path='/home/yx/project/P_prediction/ccks_1_sbert/data/data.pkl',
          output_dir='/home/yx/project/P_prediction/ccks_1_sbert/output',
          predicates=('品类_适用_场景', '品类_搭配_品类', '品类_适用_人物', '人物_蕴含_场景'),
          epochs=30,
          batch_size=128):
    """Fine-tune one CrossEncoder per predicate on (subject, object) pairs.

    The pickled DataFrame at ``data_path`` is filtered by its ``predicate``
    column; for each predicate the ``subject`` / ``object`` columns form the
    text pair and ``salience`` (cast to float32) is the label. Each subset is
    split 90/10 with a fixed seed into train and dev examples, and a fresh
    ``CrossEncoder(pre_model, num_labels=1)`` is fitted on it.

    Args:
        pre_model: Model name or path handed to ``CrossEncoder``. Any "/" is
            replaced by "-" when building the output directory name.
        data_path: Pickle file holding the DataFrame.
        output_dir: Directory under which ``model_save_path_<model>_<predicate>``
            is passed to ``fit`` as ``output_path``.
        predicates: Values of the ``predicate`` column to train on, in order.
        epochs: Number of training epochs per predicate.
        batch_size: Batch size of the training DataLoader.
    """
    # NOTE(review): unpickling executes code embedded in the file; only point
    # data_path at trusted data.
    df = pd.read_pickle(data_path)

    for name in predicates:
        df_label = df.loc[df['predicate'] == name]
        # train_test_split cannot produce a non-empty train AND dev part from
        # fewer than two rows, so skip such predicates instead of crashing
        # the whole run.
        if len(df_label) < 2:
            logging.warning(
                "Skipping predicate %s: only %d example(s) available",
                name, len(df_label))
            continue

        subjects = df_label['subject'].values
        objects = df_label['object'].values
        labels = df_label['salience'].astype(np.float32)

        # Fixed random_state keeps the train/dev split reproducible.
        (subjects_train, subjects_dev,
         objects_train, objects_dev,
         labels_train, labels_dev) = train_test_split(
            subjects, objects, labels, train_size=0.9, random_state=725)

        train_examples = [
            InputExample(texts=[s1, s2], label=label)
            for s1, s2, label in zip(subjects_train, objects_train, labels_train)
        ]
        dev_examples = [
            InputExample(texts=[s1, s2], label=label)
            for s1, s2, label in zip(subjects_dev, objects_dev, labels_dev)
        ]

        train_dataloader = DataLoader(
            train_examples, shuffle=True, batch_size=batch_size)
        # A new model per predicate so runs do not share weights.
        model = CrossEncoder(pre_model, num_labels=1)
        evaluator = CEBinaryClassificationEvaluator.from_input_examples(
            dev_examples, name='sts-dev')

        # Warm up over 10% of all training steps. Derived from `epochs` so the
        # warmup length cannot drift from the epoch count passed to fit().
        warmup_steps = math.ceil(len(train_dataloader) * epochs * 0.1)

        model_tag = pre_model.replace("/", "-")
        model.fit(train_dataloader=train_dataloader,
                  evaluator=evaluator,
                  epochs=epochs,
                  warmup_steps=warmup_steps,
                  evaluation_steps=200,
                  output_path=output_dir + "/model_save_path_" + model_tag + '_' + name,
                  # use_amp = True,
                  )
if __name__ == "__main__":
    # Script entry point: fine-tune each pretrained checkpoint in turn.
    # The checkpoints run sequentially, in the order listed.
    for checkpoint in (
        # 'peterchou/nezha-chinese-base',  (disabled)
        'hfl/chinese-macbert-base',
        'hfl/chinese-roberta-wwm-ext',
        'hfl/chinese-bert-wwm-ext',
        'hfl/chinese-electra-180g-base-discriminator',
    ):
        train(checkpoint)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ecpbid/knowledge-map-reasoning.git
git@gitee.com:ecpbid/knowledge-map-reasoning.git
ecpbid
knowledge-map-reasoning
知识图谱推理
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385