from collections import defaultdict
from transformers import BertTokenizer, BertModel
import torch
import os
import numpy as np
# pin computation to GPU 1 when CUDA is available
if torch.cuda.is_available():
    torch.cuda.set_device(1)


class Bert:
    """ load a pre-trained BERT model and encode texts
    for example:
        bert = Bert()
        embeddings = bert.embedding("content")
    """

    def __init__(self, plm_dir="/home/public/projects/emotion_dan/dataset/chinese_wwm_pytorch", using_gpu=True):
        self.plm_dir = plm_dir
        self.embed_size = 768
        device = "cpu"
        if using_gpu:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.tokenizer, self.model = self.load_bert_plm()

    def load_bert_plm(self):
        """
        load the pre-trained model
        chinese bert pre-training: https://www.cnblogs.com/think90/p/13091705.html
        if plm_dir is a model name such as hfl/chinese-roberta-wwm-ext,
        the weights are downloaded from the hugging-face model hub;
        you can also download them in advance and point to a local dir, e.g.
            **/PLM/chinese_wwm_pytorch
            **/PLM/chinese_roberta_wwm_ext_pytorch
        :return: tokenizer, model
        """
        tokenizer = BertTokenizer.from_pretrained(self.plm_dir)
        model = BertModel.from_pretrained(self.plm_dir,
                                          output_hidden_states=True,
                                          output_attentions=True).to(self.device)
        return tokenizer, model

    def embedding(self, words):
        # tokenize and keep at most the first 64 token ids
        emb = torch.tensor([self.tokenizer.encode(words)]).to(self.device)[:, :64]
        # the last two outputs are the per-layer hidden states and attention maps
        all_hidden_states, all_attentions = self.model(emb)[-2:]
        # weight each token vector of the second-to-last layer by the attention it
        # receives (averaged over heads and query positions), then sum over tokens
        rep = (all_hidden_states[-2][0] * all_attentions[-2][0].
               mean(dim=0).mean(dim=0).view(-1, 1)).sum(dim=0)
        return rep.cpu().detach().numpy()
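

# ---------------------------------------------------------------------------
# Illustrative helper (a minimal sketch, not named in the original docstring):
# cosine similarity between two vectors returned by Bert.embedding, assuming
# both are 1-D numpy arrays of equal length (768 here).
# ---------------------------------------------------------------------------
def cosine_similarity(a, b):
    """Cosine similarity between two 1-D numpy vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))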


if __name__ == '__main__':
    bert = Bert(using_gpu=False)
    embeddings = bert.embedding("北京是个好地方")
    print(embeddings)
    print(type(embeddings))
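    # Illustrative usage sketch: compare two sentences by the cosine similarity of
    # their embeddings via the helper above; the second sentence is just an example.
    # The class can also be pointed at a hub model name,
    # e.g. Bert(plm_dir="hfl/chinese-roberta-wwm-ext"), as noted in load_bert_plm.
    emb_a = bert.embedding("北京是个好地方")
    emb_b = bert.embedding("上海也是个好地方")
    print("cosine similarity:", cosine_similarity(emb_a, emb_b))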