
qinyukun/named_entity_recognition

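This repository does not declare an open-source license file (LICENSE). Before using it, check the project description and the licenses of its upstream code dependencies.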
utils.py 1.61 KB
luopeixiang committed on 2019-04-06 21:40: add evaluating.py
import pickle


def merge_maps(dict1, dict2):
    """Merge two word2id maps (or two tag2id maps).

    Keys of dict2 that are missing from dict1 get the next free id;
    dict1 is modified in place and returned.
    """
    for key in dict2.keys():
        if key not in dict1:
            dict1[key] = len(dict1)
    return dict1


def save_model(model, file_name):
    """Save a model to file_name with pickle."""
    with open(file_name, "wb") as f:
        pickle.dump(model, f)


def load_model(file_name):
    """Load a pickled model from file_name."""
    with open(file_name, "rb") as f:
        model = pickle.load(f)
    return model


# When training the LSTM model, PAD and UNK must be added to word2id and tag2id.
# For an LSTM with a CRF layer, <start> and <end> must also be added (they are needed for decoding).
def extend_maps(word2id, tag2id, for_crf=True):
    word2id['<unk>'] = len(word2id)
    word2id['<pad>'] = len(word2id)
    tag2id['<unk>'] = len(tag2id)
    tag2id['<pad>'] = len(tag2id)
    # For a BiLSTM with a CRF layer, also add the <start> and <end> tokens
    if for_crf:
        word2id['<start>'] = len(word2id)
        word2id['<end>'] = len(word2id)
        tag2id['<start>'] = len(tag2id)
        tag2id['<end>'] = len(tag2id)
    return word2id, tag2id


def prepocess_data_for_lstmcrf(word_lists, tag_lists, test=False):
    """Append an <end> token to every sentence for the BiLSTM-CRF model.

    The lists are modified in place. For test data the tag lists are left
    unchanged, so only the word lists get <end>.
    """
    assert len(word_lists) == len(tag_lists)
    for i in range(len(word_lists)):
        word_lists[i].append("<end>")
        if not test:  # test data: no <end> token is added to the tags
            tag_lists[i].append("<end>")
    return word_lists, tag_lists


def flatten_lists(lists):
    """Flatten one level of nesting: list items are spliced in, other items are appended."""
    flatten_list = []
    for l in lists:
        if isinstance(l, list):
            flatten_list += l
        else:
            flatten_list.append(l)
    return flatten_list
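The comments on `extend_maps` explain why `<unk>` and `<pad>` (plus `<start>`/`<end>` for the CRF variant) are appended to the maps. A minimal sketch of how those ids are typically consumed when a batch of sentences is turned into index lists, assuming first-seen id assignment and right-padding. `build_map`, `add_special_tokens` and `encode_batch` are hypothetical helpers written for this illustration, not functions from this repository; `add_special_tokens` uses the same order as `extend_maps`, and the test sentences end in `<end>` the way `prepocess_data_for_lstmcrf` leaves them.

```python
from typing import Dict, List, Tuple


def build_map(lists: List[List[str]]) -> Dict[str, int]:
    """Hypothetical helper: assign ids to items in first-seen order."""
    maps: Dict[str, int] = {}
    for items in lists:
        for item in items:
            if item not in maps:
                maps[item] = len(maps)
    return maps


def add_special_tokens(maps: Dict[str, int], for_crf: bool = True) -> Dict[str, int]:
    """Hypothetical helper: same order as extend_maps (<unk>, <pad>, then <start>, <end>)."""
    specials = ["<unk>", "<pad>"] + (["<start>", "<end>"] if for_crf else [])
    for tok in specials:
        maps[tok] = len(maps)
    return maps


def encode_batch(batch: List[List[str]], word2id: Dict[str, int]) -> Tuple[List[List[int]], List[int]]:
    """Hypothetical helper: map words to ids, use <unk> for unseen words, right-pad with <pad>."""
    unk, pad = word2id["<unk>"], word2id["<pad>"]
    lengths = [len(seq) for seq in batch]
    max_len = max(lengths)
    ids = [[word2id.get(w, unk) for w in seq] + [pad] * (max_len - len(seq)) for seq in batch]
    return ids, lengths


train_words = [["John", "lives", "in", "Paris"], ["Paris", "is", "big"]]
word2id = add_special_tokens(build_map(train_words))
# John 0, lives 1, in 2, Paris 3, is 4, big 5, <unk> 6, <pad> 7, <start> 8, <end> 9

test_batch = [["Mary", "lives", "in", "Paris", "<end>"], ["Paris", "<end>"]]
ids, lengths = encode_batch(test_batch, word2id)
print(ids)      # [[6, 1, 2, 3, 9], [3, 9, 7, 7, 7]]
print(lengths)  # [5, 2]
```

The unseen word "Mary" falls back to the `<unk>` id, and the shorter sentence is filled with the `<pad>` id. Keeping `lengths` next to the padded ids is what typically lets the LSTM and the CRF loss ignore the padded positions.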
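`save_model` and `load_model` are thin wrappers around `pickle.dump` and `pickle.load` on a file opened in binary mode. A minimal sketch of the same round trip, assuming a temporary file and a plain-dict stand-in for a trained model. `save_and_reload` is a hypothetical helper, and passing `protocol=pickle.HIGHEST_PROTOCOL` is an optional tweak that the original `save_model` does not use.

```python
import os
import pickle
import tempfile


def save_and_reload(obj, protocol=pickle.HIGHEST_PROTOCOL):
    """Hypothetical helper: write obj with pickle.dump, read it back with pickle.load, delete the file."""
    fd, path = tempfile.mkstemp(suffix=".pkl")
    os.close(fd)  # only the path is needed; the file is reopened in binary mode below
    try:
        with open(path, "wb") as f:
            pickle.dump(obj, f, protocol=protocol)
        with open(path, "rb") as f:
            return pickle.load(f)
    finally:
        os.remove(path)


# Hypothetical "model": plain dicts and lists pickle without extra setup.
model = {
    "word2id": {"Paris": 0, "<unk>": 1, "<pad>": 2},
    "tag2id": {"B-LOC": 0, "O": 1, "<unk>": 2, "<pad>": 3},
    "weights": [[0.1, -0.3], [0.7, 0.2]],
}
restored = save_and_reload(model)
print(restored == model, restored is model)  # True False
```

Because `pickle.load` can execute code embedded in a crafted file, only load model files you produced or trust. Instances of custom classes also need their class importable under the same module path at load time, since pickle stores classes by reference.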