代码拉取完成,页面将自动刷新
同步操作将从 young-rich/知识图谱推理 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import pandas as pd
import numpy as np
import json
import jsonlines
from pip import main
from sentence_transformers import SentenceTransformer, InputExample, losses, models, CrossEncoder
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, BinaryClassificationEvaluator
from torch.utils.data import DataLoader
from torch import nn
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
import logging
from datetime import datetime
import math
from sklearn.model_selection import train_test_split
# labels = ['品类_适用_场景', '品类_搭配_品类', '品类_适用_人物', '人物_蕴含_场景']
def _to_examples(sentences1, sentences2, labels):
    """Zip aligned sentence/label sequences into a list of ``InputExample``."""
    return [InputExample(texts=[s1, s2], label=label)
            for s1, s2, label in zip(sentences1, sentences2, labels)]


def train(pre_model,
          data_path='/home/yx/project/P_prediction/ccks_1_sbert/data/data.pkl',
          output_dir='/home/yx/project/P_prediction/ccks_1_sbert/output',
          predicates=('品类_适用_场景', '品类_搭配_品类', '品类_适用_人物', '人物_蕴含_场景'),
          epochs=30,
          batch_size=128,
          warmup_ratio=0.1):
    """Fine-tune one CrossEncoder per predicate on (subject, object) -> salience.

    For each value in *predicates*, the rows of the DataFrame pickled at
    *data_path* whose ``predicate`` column equals that value are split 90/10
    (``random_state=725``) into train/dev sets. A fresh
    ``CrossEncoder(pre_model, num_labels=1)`` is then trained on
    ``(subject, object)`` pairs labelled with ``salience`` and saved to
    ``<output_dir>/model_save_path_<pre_model with '/' -> '-'>_<predicate>``.

    Args:
        pre_model: Model id or path of the pretrained encoder to fine-tune.
        data_path: Pickled DataFrame with ``predicate``, ``subject``,
            ``object`` and ``salience`` columns. It is loaded with
            ``pd.read_pickle``, which unpickles arbitrary objects, so only
            point this at trusted files.
        output_dir: Directory under which each per-predicate model is saved.
        predicates: Predicate values to train a separate model for.
        epochs: Number of training epochs per predicate.
        batch_size: Training batch size.
        warmup_ratio: Fraction of all training steps used for LR warm-up.
    """
    df = pd.read_pickle(data_path)
    for name in predicates:
        df_label = df.loc[df['predicate'] == name]
        # train_test_split raises on fewer than 2 rows; skip this predicate
        # instead of aborting training for all the remaining ones.
        if len(df_label) < 2:
            logging.warning("Skipping predicate %r: only %d row(s)",
                            name, len(df_label))
            continue

        (sentence1_train, sentence1_test,
         sentence2_train, sentence2_test,
         labels_train, labels_test) = train_test_split(
            df_label.subject.values,
            df_label.object.values,
            df_label.salience.astype(np.float32),
            train_size=0.9, random_state=725)
        train_examples = _to_examples(sentence1_train, sentence2_train, labels_train)
        dev_examples = _to_examples(sentence1_test, sentence2_test, labels_test)

        train_dataloader = DataLoader(
            train_examples, shuffle=True, batch_size=batch_size)
        model = CrossEncoder(pre_model, num_labels=1)
        # NOTE(review): a *binary* classification evaluator is applied to
        # float32 salience labels; if salience is a continuous score rather
        # than 0/1, a correlation-based evaluator may fit better -- confirm.
        evaluator = CEBinaryClassificationEvaluator.from_input_examples(
            dev_examples, name='sts-dev')
        # Derive warm-up from the real epoch count so it stays warmup_ratio of
        # the whole run (a hard-coded 10 here under-warms a 30-epoch run).
        warmup_steps = math.ceil(len(train_dataloader) * epochs * warmup_ratio)
        model.fit(train_dataloader=train_dataloader,
                  evaluator=evaluator,
                  epochs=epochs,
                  warmup_steps=warmup_steps,
                  evaluation_steps=200,
                  output_path=(f"{output_dir}/model_save_path_"
                               f"{pre_model.replace('/', '-')}_{name}"),
                  )
if __name__ == "__main__":
    # Fine-tune each backbone in turn, in this order. Every train() call
    # builds and saves one CrossEncoder per predicate type.
    # 'peterchou/nezha-chinese-base' is currently disabled.
    for backbone in (
        'hfl/chinese-macbert-base',
        'hfl/chinese-roberta-wwm-ext',
        'hfl/chinese-bert-wwm-ext',
        'hfl/chinese-electra-180g-base-discriminator',
    ):
        train(backbone)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。