1 Star 2 Fork 1

learning-limitless/DuEE-transformers

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
duee_process.py 5.91 KB
一键复制 编辑 原始数据 按行查看 历史
fx 提交于 2020-08-24 11:31 . 预处理自动创建保存目录
import torch
from transformers import BertTokenizer, BertModel
import matplotlib.pyplot as plt
import json
import random
import jsonlines
from util.file_utils import *
from util.duee_utils import *
def use_roberta():
    """Smoke-test the pretrained tokenizer/model named in the RoBERTa config.

    Tokenizes a sample Chinese sentence, prints the tokens and their ids,
    then feeds the encoded input through the model and prints the shape of
    the last hidden states. Purely for manual inspection; returns nothing.
    """
    cfg_path = r'config/duee_config_roberta.json'
    with open(cfg_path, 'r') as cfg_file:
        cfg = json.load(cfg_file)

    sample = '消失的“外企光环”,5月份在华裁员900余人,香饽饽变“臭”了'
    tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=cfg['model_name_or_path'])
    pieces = tokenizer.tokenize(sample)
    print(pieces)
    print(tokenizer.convert_tokens_to_ids(pieces))

    model = BertModel.from_pretrained(pretrained_model_name_or_path=cfg['model_name_or_path'])
    # Encode text; add_special_tokens inserts [CLS]/[SEP] (or model-specific
    # markers) in the right positions for this model family.
    input_ids = torch.tensor([tokenizer.encode(sample, add_special_tokens=True)])
    print(type(input_ids))
    print(input_ids)
    input_ids_a = tokenizer.convert_tokens_to_ids(pieces)
    print(type(input_ids_a))
    print(input_ids_a)
    # assert len(set(input_ids.tolist())-set(input_ids_a.tolist()))==0
    with torch.no_grad():
        last_hidden_states = model(input_ids)[0]  # model outputs are tuples
    print(last_hidden_states.size())
def schema_event_type_process(schema_path, save_path):
    """Build a BIO trigger-label vocabulary from the DuEE event schema.

    Reads one JSON object per line from ``schema_path``, collects every
    ``event_type``, and writes ``B-<type>\\t<id>`` / ``I-<type>\\t<id>``
    lines plus a trailing ``O\\t<id>`` line to ``save_path``.

    Args:
        schema_path: path to the schema file (one JSON dict per line,
            each with an ``event_type`` key).
        save_path: destination path for the label-map file.

    Raises:
        Exception: if either path is empty/None.
    """
    if not schema_path or not save_path:
        raise Exception("set schema_path and save_path first")
    event_types = set()
    for line in read_by_lines(schema_path):
        d_json = json.loads(line)
        event_types.add(d_json["event_type"])
    outputs = []
    index = 0
    # Sort so label ids are stable across runs; bare set iteration order
    # varies between interpreter sessions, which would silently desync the
    # label map from any model trained against an earlier run.
    for et in sorted(event_types):
        outputs.append(u"B-{}\t{}".format(et, index))
        index += 1
        outputs.append(u"I-{}\t{}".format(et, index))
        index += 1
    outputs.append(u"O\t{}".format(index))
    print(u"include event type {}, create label {}".format(
        len(event_types), len(outputs)))
    write_by_lines(save_path, outputs)
def schema_role_process(schema_path, save_path):
    """Build a BIO argument-role label vocabulary from the DuEE event schema.

    Reads one JSON object per line from ``schema_path``, collects every role
    name found in each record's ``role_list``, and writes
    ``B-<role>\\t<id>`` / ``I-<role>\\t<id>`` lines plus a trailing
    ``O\\t<id>`` line to ``save_path``.

    Args:
        schema_path: path to the schema file (one JSON dict per line,
            each with a ``role_list`` of ``{"role": ...}`` dicts).
        save_path: destination path for the label-map file.

    Raises:
        Exception: if either path is empty/None.
    """
    if not schema_path or not save_path:
        raise Exception("set schema_path and save_path first")
    roles = set()
    for line in read_by_lines(schema_path):
        d_json = json.loads(line)
        for role in d_json["role_list"]:
            roles.add(role["role"])
    outputs = []
    index = 0
    # Sort so label ids are stable across runs; raw set iteration order is
    # nondeterministic between sessions and would produce inconsistent maps.
    for r in sorted(roles):
        outputs.append(u"B-{}\t{}".format(r, index))
        index += 1
        outputs.append(u"I-{}\t{}".format(r, index))
        index += 1
    outputs.append(u"O\t{}".format(index))
    print(u"include roles {},create label {}".format(len(roles), len(outputs)))
    write_by_lines(save_path, outputs)
def origin_events_process(origin_events_dir, save_dir, split='train'):
    """Flatten DuEE sentence records into one JSON line per event.

    Reads ``<origin_events_dir>/<split>.json`` (one JSON dict per line),
    expands each record's ``event_list`` so every event becomes its own
    JSON line carrying the parent sentence's ``text`` and ``id`` plus a
    synthesized ``event_id`` (``"<sentence id>_<trigger>"``), and writes
    the result to ``<save_dir>/<split>.json``.

    Args:
        origin_events_dir: directory holding the raw split files.
        save_dir: directory to write the flattened split file into.
        split: dataset split name used for both input and output filenames.

    Raises:
        Exception: if either directory argument is empty/None.
    """
    # Validate arguments BEFORE using them: the original built the input
    # path first, so a None dir raised TypeError instead of this message.
    if not origin_events_dir or not save_dir:
        raise Exception("set origin_events_dir and save_dir first")
    origin_events_path = origin_events_dir + ('/%s.json' % split)
    output = []
    lines = read_by_lines(origin_events_path)
    for line in lines:
        d_json = json.loads(line)
        for event in d_json["event_list"]:
            # Event id combines the parent sentence id with the trigger word.
            event["event_id"] = u"{}_{}".format(d_json["id"], event["trigger"])
            event["text"] = d_json["text"]
            event["id"] = d_json["id"]
            output.append(json.dumps(event, ensure_ascii=False))
    print(
        u"include sentences {}, events {}, {} datas {}"
        .format(
            len(lines),
            len(output), split, len(output)))
    write_by_lines(u"{}/{}.json".format(save_dir, split), output)
def data_explore():
    """Print the RoBERTa config and pretty-print one training record.

    Manual inspection helper: loads the config, echoes it, then streams the
    training jsonlines file and dumps only its first record before stopping.
    """
    cfg = read_json(r'config/duee_config_roberta.json')
    print(cfg)
    data_dir = cfg['data_dir']
    reader = read_jsonlines(data_dir + '/train.json')
    for record in reader:
        print(json.dumps(record, ensure_ascii=False, indent=1))
        break
def data_preprocess():
    """Prepare DuEE data: label vocabularies plus flattened train/dev files.

    Reads ``origin_data_dir`` from the RoBERTa config and writes all
    artifacts into ``<origin_data_dir>/processed``. Each artifact is only
    rebuilt if it does not already exist.
    """
    config_path = r'config/duee_config_roberta.json'
    config_dict = read_json(config_path)
    data_dir = config_dict['origin_data_dir']
    save_dir = data_dir + '/processed'
    # makedirs(exist_ok=True): also creates missing parent directories and
    # avoids the check-then-mkdir race the original exists()/mkdir pair had.
    os.makedirs(save_dir, exist_ok=True)
    # Trigger (event-type) label map.
    schema_path = data_dir + '/event_schema.json'
    save_path = save_dir + '/vocab_trigger_label_map.txt'
    if not os.path.exists(save_path):
        schema_event_type_process(schema_path, save_path)
    # Argument-role label map.
    save_path = save_dir + '/vocab_role_label_map.txt'
    if not os.path.exists(save_path):
        schema_role_process(schema_path, save_path)
    # Flatten the raw event datasets, one output file per split.
    for split in ['train', 'dev']:
        if not os.path.exists(save_dir + ('/%s.json' % split)):
            origin_events_process(data_dir, save_dir, split=split)
def show_labels():
    """Load the trigger-config label list and print it for inspection."""
    cfg = read_json(r'config/duee_trigger_config_roberta.json')
    label_list = get_labels(cfg.get('labels'))
    print(label_list)
from util.duee_utils import *
def use_read_examples_from_file():
    """Smoke-test read_examples_from_file on the training split.

    Builds the tokenizer named in the trigger config and invokes the reader
    with trigger labels enabled; the return value is deliberately discarded.
    """
    cfg = read_json(r'config/duee_trigger_config_roberta.json')
    tok = BertTokenizer.from_pretrained(pretrained_model_name_or_path=cfg['model_name_or_path'])
    read_examples_from_file(cfg['data_dir'], Split.train, tok, trigger_label=True)
if __name__ == '__main__':
    # Entry point: only preprocessing is enabled. The other helpers
    # (use_read_examples_from_file, show_labels, data_explore, use_roberta)
    # are manual-debugging utilities invoked by editing this guard.
    data_preprocess()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/leaning-limitless/DuEE-transformers.git
git@gitee.com:leaning-limitless/DuEE-transformers.git
leaning-limitless
DuEE-transformers
DuEE-transformers
master

搜索帮助