
coshpr/emotion_classification

bert.py 1.95 KB
coshpr committed on 2022-05-29 21:25 · Save results and display the confusion matrix
from transformers import BertTokenizer, BertModel
import torch

# Pin the process to GPU 1 only when CUDA is actually available,
# so the module still imports on CPU-only machines.
if torch.cuda.is_available():
    torch.cuda.set_device(1)


class Bert:
    """Load a pretrained BERT model and encode texts into fixed-size vectors.

    Example:
        bert = Bert()
        embeddings = bert.embedding("content")
    """

    def __init__(self, plm_dir="/home/public/projects/emotion_dan/dataset/chinese_wwm_pytorch", using_gpu=True):
        self.plm_dir = plm_dir
        self.embed_size = 768  # hidden size of BERT-base
        device = "cpu"
        if using_gpu:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.tokenizer, self.model = self.load_bert_plm()

    def load_bert_plm(self):
        """Load the pretrained tokenizer and model.

        Chinese BERT pretraining notes: https://www.cnblogs.com/think90/p/13091705.html
        If plm_dir is a model name such as hfl/chinese-roberta-wwm-ext, the
        weights are downloaded from the Hugging Face hub; you can also download
        them beforehand and point plm_dir at a local directory, e.g.
        **/PLM/chinese_wwm_pytorch or **/PLM/chinese_roberta_wwm_ext_pytorch.
        :return: tokenizer, model
        """
        tokenizer = BertTokenizer.from_pretrained(self.plm_dir)
        model = BertModel.from_pretrained(self.plm_dir,
                                          output_hidden_states=True,
                                          output_attentions=True).to(self.device)
        return tokenizer, model

    def embedding(self, words):
        # Tokenize with special tokens and truncate to at most 64 tokens.
        emb = torch.tensor([self.tokenizer.encode(words)]).to(self.device)[:, :64]
        with torch.no_grad():
            all_hidden_states, all_attentions = self.model(emb)[-2:]
        # Weight each token's second-to-last-layer state by the average attention
        # it receives (mean over heads and query positions), then sum over tokens.
        rep = (all_hidden_states[-2][0] * all_attentions[-2][0].
               mean(dim=0).mean(dim=0).view(-1, 1)).sum(dim=0)
        return rep.cpu().detach().numpy()


if __name__ == '__main__':
    bert = Bert(using_gpu=False)
    embeddings = bert.embedding("北京是个好地方")
    print(embeddings)
    print(type(embeddings))
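
For a downstream emotion classifier, each text is typically reduced to one such vector and the vectors stacked into a feature matrix. A minimal sketch of that pattern, assuming network access to the hub name hfl/chinese-roberta-wwm-ext mentioned in the docstring; the sample texts and the stacking loop are illustrative, not part of this repo:

import numpy as np

# Hypothetical batch usage: one 768-d vector per text, stacked into an
# (n_samples, 768) matrix that a classifier can consume.
bert = Bert(plm_dir="hfl/chinese-roberta-wwm-ext", using_gpu=False)
texts = ["北京是个好地方", "今天心情很糟糕"]  # illustrative samples
features = np.stack([bert.embedding(t) for t in texts])
print(features.shape)  # (2, 768)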