import json
import os
import random

import jsonlines
import matplotlib.pyplot as plt
import torch
from transformers import BertTokenizer, BertModel

from util.file_utils import *
from util.duee_utils import *
def use_roberta():
    """Sanity-check the configured BERT/RoBERTa checkpoint on one sentence.

    Tokenizes a sample Chinese headline, prints the token ids obtained from
    ``encode`` vs. ``convert_tokens_to_ids``, and prints the shape of the
    model's last hidden states. Purely a smoke test; returns nothing.
    """
    config_path = r'config/duee_config_roberta.json'
    # Explicit utf-8: the JSON config may contain non-ASCII values, and the
    # platform default encoding (e.g. on Windows) is not guaranteed to match.
    with open(config_path, 'r', encoding='utf-8') as fp:
        config_json = json.load(fp)
    tokenizer = BertTokenizer.from_pretrained(
        pretrained_model_name_or_path=config_json['model_name_or_path'])
    sample = '消失的“外企光环”,5月份在华裁员900余人,香饽饽变“臭”了'
    tokens = tokenizer.tokenize(sample)
    print(tokens)
    print(tokenizer.convert_tokens_to_ids(tokens))
    model = BertModel.from_pretrained(
        pretrained_model_name_or_path=config_json['model_name_or_path'])
    # add_special_tokens=True inserts [CLS]/[SEP] (or the model's equivalents)
    # the right way for this architecture.
    input_ids = torch.tensor(
        [tokenizer.encode(sample, add_special_tokens=True)])
    print(type(input_ids))
    print(input_ids)
    input_ids_a = tokenizer.convert_tokens_to_ids(tokens)
    print(type(input_ids_a))
    print(input_ids_a)
    with torch.no_grad():
        # Model outputs are tuples; element 0 is the last hidden states.
        last_hidden_states = model(input_ids)[0]
    print(last_hidden_states.size())
def schema_event_type_process(schema_path, save_path):
    """Build a BIO trigger-label vocabulary from an event schema file.

    Reads the schema (one JSON object per line), collects the distinct
    ``event_type`` values, and writes one ``B-<type>\\t<id>`` and one
    ``I-<type>\\t<id>`` line per type, followed by a final ``O`` label.
    """
    if not schema_path or not save_path:
        raise Exception("set schema_path and save_path first")
    event_types = set()
    for raw_line in read_by_lines(schema_path):
        event_types.add(json.loads(raw_line)["event_type"])
    # Each label's id is simply its position in the output list.
    labels = []
    for event_type in event_types:
        labels.append(u"B-{}\t{}".format(event_type, len(labels)))
        labels.append(u"I-{}\t{}".format(event_type, len(labels)))
    labels.append(u"O\t{}".format(len(labels)))
    print(u"include event type {}, create label {}".format(
        len(event_types), len(labels)))
    write_by_lines(save_path, labels)
def schema_role_process(schema_path, save_path):
    """Build a BIO role-label vocabulary from an event schema file.

    Reads the schema (one JSON object per line), collects every distinct
    role name from each record's ``role_list``, and writes ``B-``/``I-``
    label lines plus a trailing ``O`` label.
    """
    if not schema_path or not save_path:
        raise Exception("set schema_path and save_path first")
    roles = set()
    for raw_line in read_by_lines(schema_path):
        for role_item in json.loads(raw_line)["role_list"]:
            roles.add(role_item["role"])
    # Each label's id is simply its position in the output list.
    labels = []
    for role_name in roles:
        labels.append(u"B-{}\t{}".format(role_name, len(labels)))
        labels.append(u"I-{}\t{}".format(role_name, len(labels)))
    labels.append(u"O\t{}".format(len(labels)))
    print(u"include roles {},create label {}".format(len(roles), len(labels)))
    write_by_lines(save_path, labels)
def origin_events_process(origin_events_dir, save_dir, split='train'):
    """Flatten sentence-level DuEE records into one event per output line.

    For every event in each sentence's ``event_list``, attaches the parent
    sentence's ``text`` and ``id`` plus a synthetic ``event_id`` and writes
    the events as JSON lines to ``<save_dir>/<split>.json``.
    """
    origin_events_path = origin_events_dir + ('/%s.json' % split)
    if not origin_events_dir or not save_dir:
        raise Exception("set origin_events_dir and save_dir first")
    raw_lines = read_by_lines(origin_events_path)
    flattened = []
    for raw in raw_lines:
        sentence = json.loads(raw)
        for event in sentence["event_list"]:
            # event_id combines the sentence id with the trigger word.
            event["event_id"] = u"{}_{}".format(sentence["id"],
                                                event["trigger"])
            event["text"] = sentence["text"]
            event["id"] = sentence["id"]
            flattened.append(json.dumps(event, ensure_ascii=False))
    print(u"include sentences {}, events {}, {} datas {}".format(
        len(raw_lines), len(flattened), split, len(flattened)))
    write_by_lines(u"{}/{}.json".format(save_dir, split), flattened)
def data_explore():
    """Print the config and pretty-print the first training record."""
    config_path = r'config/duee_config_roberta.json'
    config_dict = read_json(config_path)
    print(config_dict)
    train_data_path = config_dict['data_dir'] + '/train.json'
    # Peek at the first record only; sentinel distinguishes "empty file"
    # from a record that is itself None/null.
    _missing = object()
    first_record = next(iter(read_jsonlines(train_data_path)), _missing)
    if first_record is not _missing:
        print(json.dumps(first_record, ensure_ascii=False, indent=1))
def data_preprocess():
    """Prepare DuEE data: label vocabularies plus flattened event files.

    Reads ``origin_data_dir`` from the roberta config and populates
    ``<origin_data_dir>/processed`` with:
      * vocab_trigger_label_map.txt - BIO labels for event types
      * vocab_role_label_map.txt    - BIO labels for argument roles
      * train.json / dev.json       - one event per line
    Each output is skipped if it already exists, so the step is idempotent.
    """
    config_path = r'config/duee_config_roberta.json'
    config_dict = read_json(config_path)
    data_dir = config_dict['origin_data_dir']
    save_dir = data_dir + '/processed'
    # makedirs(exist_ok=True) avoids the check-then-create race and also
    # creates any missing parent directories, unlike bare os.mkdir.
    os.makedirs(save_dir, exist_ok=True)
    # Trigger (event-type) label vocabulary
    schema_path = data_dir + '/event_schema.json'
    save_path = save_dir + '/vocab_trigger_label_map.txt'
    if not os.path.exists(save_path):
        schema_event_type_process(schema_path, save_path)
    # Role label vocabulary
    save_path = save_dir + '/vocab_role_label_map.txt'
    if not os.path.exists(save_path):
        schema_role_process(schema_path, save_path)
    # Flatten the train/dev datasets to one event per line
    for split in ['train', 'dev']:
        if not os.path.exists(save_dir + ('/%s.json' % split)):
            origin_events_process(data_dir, save_dir, split=split)
def show_labels():
    """Load the trigger-config label list and print it."""
    trigger_config = read_json(r'config/duee_trigger_config_roberta.json')
    print(get_labels(trigger_config.get('labels')))
from util.duee_utils import *
def use_read_examples_from_file():
    """Smoke-test read_examples_from_file on the training split."""
    trigger_config = read_json(r'config/duee_trigger_config_roberta.json')
    tokenizer = BertTokenizer.from_pretrained(
        pretrained_model_name_or_path=trigger_config['model_name_or_path'])
    read_examples_from_file(trigger_config['data_dir'], Split.train,
                            tokenizer, trigger_label=True)
if __name__ == '__main__':
    # Alternative entry points — uncomment one to run it instead.
    # use_read_examples_from_file()
    # show_labels()
    data_preprocess()
    # data_explore()
    # use_roberta()