
Xzh/ChinaVis-后端 (ChinaVis backend)

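This repository does not declare an open-source license file (LICENSE). Before using the code, check the project description and its upstream dependencies.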
similarity_analyse.py 1.22 KB
Committed by Xzh on 2023-05-24 16:57: x
from transformers import BertForMaskedLM, BertTokenizer
import torch

# Load the tokenizer and model from a local directory
# (presumably a copy of the Hugging Face checkpoint IDEA-CCNL/Erlangshen-TCBert-110M-Sentence-Embedding-Chinese)
tokenizer = BertTokenizer.from_pretrained("PLMs/Erlangshen-TCBert-110M-Sentence-Embedding-Chinese")
model = BertForMaskedLM.from_pretrained("PLMs/Erlangshen-TCBert-110M-Sentence-Embedding-Chinese")

# Cosine similarity between two 1-D vectors (dim=0); eps guards against division by zero
cos = torch.nn.CosineSimilarity(dim=0, eps=1e-8)

with torch.no_grad():
    # Sentence representation of the first ("training") poem
    training_input = tokenizer("""沙漠真人本至尊,青蛇罢祀出梧垣。
孝陵松柏犹樵牧,元庙何妨有泪痕。""", return_tensors="pt")
    training_output = model(**training_input, output_hidden_states=True)
    # hidden_states[-1] has shape (1, seq_len, hidden_size): drop the batch dimension,
    # then average over all tokens (including [CLS] and [SEP]) to get one sentence vector
    training_representation = torch.mean(training_output.hidden_states[-1].squeeze(0), dim=0)

    # Sentence representation of the second ("test") poem, computed the same way
    test_input = tokenizer("""想像承平乐事留,履綦陈迹也风流。
轻烟翠柳今何处,十六门如十六楼。""", return_tensors="pt")
    test_output = model(**test_input, output_hidden_states=True)
    test_representation = torch.mean(test_output.hidden_states[-1].squeeze(0), dim=0)

    # Cosine similarity between the two sentence vectors (a 0-dim tensor)
    similarity_score = cos(training_representation, test_representation)

print(similarity_score)
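The script turns each poem into one vector by averaging the last hidden layer over its tokens, then compares the two vectors with cosine similarity. The sketch below restates those two steps in plain Python so the arithmetic can be checked without downloading the model. It also adds masked mean pooling, which only matters if several poems are padded into one batch. That is an assumption beyond the original, which encodes one poem at a time. `masked_mean_pool`, `cosine_similarity` and `rank_by_similarity` are hypothetical helpers, not functions from transformers or torch, and the toy vectors have 2 dimensions instead of the model's hundreds.

```python
import math
from typing import List, Sequence, Tuple


def masked_mean_pool(token_vectors: Sequence[Sequence[float]],
                     attention_mask: Sequence[int]) -> List[float]:
    """Average the token vectors whose mask entry is 1, skipping padding.

    With an all-ones mask this matches the script's
    torch.mean(hidden_states[-1].squeeze(0), dim=0).
    """
    kept = [vec for vec, m in zip(token_vectors, attention_mask) if m == 1]
    if not kept:
        raise ValueError("attention_mask selects no tokens")
    dim = len(kept[0])
    return [sum(vec[i] for vec in kept) / len(kept) for i in range(dim)]


def cosine_similarity(a: Sequence[float], b: Sequence[float],
                      eps: float = 1e-8) -> float:
    """Dot product of a and b divided by the product of their L2 norms.

    Each norm is clamped to at least eps so a zero vector yields 0.0
    instead of a division error (an assumption of this sketch; PyTorch
    documents its own eps rule).
    """
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = max(math.sqrt(sum(x * x for x in a)), eps)
    norm_b = max(math.sqrt(sum(y * y for y in b)), eps)
    return dot / (norm_a * norm_b)


def rank_by_similarity(query: Sequence[float],
                       candidates: Sequence[Sequence[float]]) -> List[Tuple[int, float]]:
    """Hypothetical helper: (index, score) pairs, most similar candidate first."""
    scores = [(i, cosine_similarity(query, c)) for i, c in enumerate(candidates)]
    return sorted(scores, key=lambda pair: pair[1], reverse=True)


if __name__ == "__main__":
    # Toy "last hidden layer": 3 tokens x 2 dims; the third token is padding.
    hidden = [[2.0, 4.0], [4.0, 4.0], [100.0, 100.0]]
    mask = [1, 1, 0]
    pooled = masked_mean_pool(hidden, mask)
    print(pooled)                                  # [3.0, 4.0]
    print(cosine_similarity(pooled, [6.0, 8.0]))   # 1.0 (same direction)
    print(cosine_similarity(pooled, [-4.0, 3.0]))  # 0.0 (orthogonal)
    print(rank_by_similarity(pooled, [[-4.0, 3.0], [6.0, 8.0], [3.0, 0.0]]))
    # [(1, 1.0), (2, 0.6), (0, 0.0)]
```

To apply this to the script, `token_vectors` would be `training_output.hidden_states[-1][0].tolist()` and the mask would be `training_input["attention_mask"][0].tolist()`. For a single unpadded sentence the mask is all ones, so the result equals the script's plain mean.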
Clone (HTTPS): https://gitee.com/Xu_xiaoxu/vis-sys-backend.git
Clone (SSH): git@gitee.com:Xu_xiaoxu/vis-sys-backend.git
Branch: master
