代码拉取完成,页面将自动刷新
同步操作将从 Lijuce/named_entity_recognition 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import pickle
def merge_maps(dict1, dict2):
"""用于合并两个word2id或者两个tag2id"""
for key in dict2.keys():
if key not in dict1:
dict1[key] = len(dict1)
return dict1
def save_model(model, file_name):
"""用于保存模型"""
with open(file_name, "wb") as f:
pickle.dump(model, f)
def load_model(file_name):
"""用于加载模型"""
with open(file_name, "rb") as f:
model = pickle.load(f)
return model
# LSTM模型训练的时候需要在word2id和tag2id加入PAD和UNK
# 如果是加了CRF的lstm还要加入<start>和<end> (解码的时候需要用到)
def extend_maps(word2id, tag2id, for_crf=True):
word2id['<unk>'] = len(word2id)
word2id['<pad>'] = len(word2id)
tag2id['<unk>'] = len(tag2id)
tag2id['<pad>'] = len(tag2id)
# 如果是加了CRF的bilstm 那么还要加入<start> 和 <end>token
if for_crf:
word2id['<start>'] = len(word2id)
word2id['<end>'] = len(word2id)
tag2id['<start>'] = len(tag2id)
tag2id['<end>'] = len(tag2id)
return word2id, tag2id
def prepocess_data_for_lstmcrf(word_lists, tag_lists, test=False):
assert len(word_lists) == len(tag_lists)
for i in range(len(word_lists)):
word_lists[i].append("<end>")
if not test: # 如果是测试数据,就不需要加end token了
tag_lists[i].append("<end>")
return word_lists, tag_lists
def flatten_lists(lists):
flatten_list = []
for l in lists:
if type(l) == list:
flatten_list += l
else:
flatten_list.append(l)
return flatten_list
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。