1 Star 0 Fork 0

善若水/uie_pytorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
utils.py 45.33 KB
一键复制 编辑 原始数据 按行查看 历史
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297
# Copyright (c) 2022 Heiheiyoyo. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import functools
import json
import logging
import math
import random
import re
import shutil
import threading
import time
from functools import partial
import colorlog
import numpy as np
import torch
from colorama import Back, Fore
from torch.utils.data import Dataset
from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
loggers = {}
log_config = {
'DEBUG': {
'level': 10,
'color': 'purple'
},
'INFO': {
'level': 20,
'color': 'green'
},
'TRAIN': {
'level': 21,
'color': 'cyan'
},
'EVAL': {
'level': 22,
'color': 'blue'
},
'WARNING': {
'level': 30,
'color': 'yellow'
},
'ERROR': {
'level': 40,
'color': 'red'
},
'CRITICAL': {
'level': 50,
'color': 'bold_red'
}
}
def set_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
np.random.seed(seed)
def get_span(start_ids, end_ids, with_prob=False):
"""
Get span set from position start and end list.
Args:
start_ids (List[int]/List[tuple]): The start index list.
end_ids (List[int]/List[tuple]): The end index list.
with_prob (bool): If True, each element for start_ids and end_ids is a tuple aslike: (index, probability).
Returns:
set: The span set without overlapping, every id can only be used once .
"""
if with_prob:
start_ids = sorted(start_ids, key=lambda x: x[0])
end_ids = sorted(end_ids, key=lambda x: x[0])
else:
start_ids = sorted(start_ids)
end_ids = sorted(end_ids)
start_pointer = 0
end_pointer = 0
len_start = len(start_ids)
len_end = len(end_ids)
couple_dict = {}
while start_pointer < len_start and end_pointer < len_end:
if with_prob:
start_id = start_ids[start_pointer][0]
end_id = end_ids[end_pointer][0]
else:
start_id = start_ids[start_pointer]
end_id = end_ids[end_pointer]
if start_id == end_id:
couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
start_pointer += 1
end_pointer += 1
continue
if start_id < end_id:
couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
start_pointer += 1
continue
if start_id > end_id:
end_pointer += 1
continue
result = [(couple_dict[end], end) for end in couple_dict]
result = set(result)
return result
def get_bool_ids_greater_than(probs, limit=0.5, return_prob=False):
"""
Get idx of the last dimension in probability arrays, which is greater than a limitation.
Args:
probs (List[List[float]]): The input probability arrays.
limit (float): The limitation for probability.
return_prob (bool): Whether to return the probability
Returns:
List[List[int]]: The index of the last dimension meet the conditions.
"""
probs = np.array(probs)
dim_len = len(probs.shape)
if dim_len > 1:
result = []
for p in probs:
result.append(get_bool_ids_greater_than(p, limit, return_prob))
return result
else:
result = []
for i, p in enumerate(probs):
if p > limit:
if return_prob:
result.append((i, p))
else:
result.append(i)
return result
class SpanEvaluator:
"""
SpanEvaluator computes the precision, recall and F1-score for span detection.
"""
def __init__(self):
super(SpanEvaluator, self).__init__()
self.num_infer_spans = 0
self.num_label_spans = 0
self.num_correct_spans = 0
def compute(self, start_probs, end_probs, gold_start_ids, gold_end_ids):
"""
Computes the precision, recall and F1-score for span detection.
"""
pred_start_ids = get_bool_ids_greater_than(start_probs)
pred_end_ids = get_bool_ids_greater_than(end_probs)
gold_start_ids = get_bool_ids_greater_than(gold_start_ids.tolist())
gold_end_ids = get_bool_ids_greater_than(gold_end_ids.tolist())
num_correct_spans = 0
num_infer_spans = 0
num_label_spans = 0
for predict_start_ids, predict_end_ids, label_start_ids, label_end_ids in zip(
pred_start_ids, pred_end_ids, gold_start_ids, gold_end_ids):
[_correct, _infer, _label] = self.eval_span(
predict_start_ids, predict_end_ids, label_start_ids,
label_end_ids)
num_correct_spans += _correct
num_infer_spans += _infer
num_label_spans += _label
return num_correct_spans, num_infer_spans, num_label_spans
def update(self, num_correct_spans, num_infer_spans, num_label_spans):
"""
This function takes (num_infer_spans, num_label_spans, num_correct_spans) as input,
to accumulate and update the corresponding status of the SpanEvaluator object.
"""
self.num_infer_spans += num_infer_spans
self.num_label_spans += num_label_spans
self.num_correct_spans += num_correct_spans
def eval_span(self, predict_start_ids, predict_end_ids, label_start_ids,
label_end_ids):
"""
evaluate position extraction (start, end)
return num_correct, num_infer, num_label
input: [1, 2, 10] [4, 12] [2, 10] [4, 11]
output: (1, 2, 2)
"""
pred_set = get_span(predict_start_ids, predict_end_ids)
label_set = get_span(label_start_ids, label_end_ids)
num_correct = len(pred_set & label_set)
num_infer = len(pred_set)
num_label = len(label_set)
return (num_correct, num_infer, num_label)
def accumulate(self):
"""
This function returns the mean precision, recall and f1 score for all accumulated minibatches.
Returns:
tuple: Returns tuple (`precision, recall, f1 score`).
"""
precision = float(self.num_correct_spans /
self.num_infer_spans) if self.num_infer_spans else 0.
recall = float(self.num_correct_spans /
self.num_label_spans) if self.num_label_spans else 0.
f1_score = float(2 * precision * recall /
(precision + recall)) if self.num_correct_spans else 0.
return precision, recall, f1_score
def reset(self):
"""
Reset function empties the evaluation memory for previous mini-batches.
"""
self.num_infer_spans = 0
self.num_label_spans = 0
self.num_correct_spans = 0
def name(self):
"""
Return name of metric instance.
"""
return "precision", "recall", "f1"
class IEDataset(Dataset):
"""
Dataset for Information Extraction fron jsonl file.
The line type is
{
content
result_list
prompt
}
"""
def __init__(self, file_path, tokenizer, max_seq_len) -> None:
super().__init__()
self.file_path = file_path
self.dataset = list(reader(file_path))
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
def __len__(self):
return len(self.dataset)
def __getitem__(self, index):
return convert_example(self.dataset[index], tokenizer=self.tokenizer, max_seq_len=self.max_seq_len)
class IEMapDataset(Dataset):
"""
Dataset for Information Extraction fron jsonl file.
The line type is
{
content
result_list
prompt
}
"""
def __init__(self, data, tokenizer, max_seq_len) -> None:
super().__init__()
self.dataset = data
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
def __len__(self):
return len(self.dataset)
def __getitem__(self, index):
return convert_example(self.dataset[index], tokenizer=self.tokenizer, max_seq_len=self.max_seq_len)
def convert_example(example, tokenizer, max_seq_len):
"""
example: {
title
prompt
content
result_list
}
"""
encoded_inputs = tokenizer(
text=[example["prompt"]],
text_pair=[example["content"]],
truncation=True,
max_length=max_seq_len,
add_special_tokens=True,
return_offsets_mapping=True)
# encoded_inputs = encoded_inputs[0]
offset_mapping = [list(x) for x in encoded_inputs["offset_mapping"][0]]
bias = 0
for index in range(1, len(offset_mapping)):
mapping = offset_mapping[index]
if mapping[0] == 0 and mapping[1] == 0 and bias == 0:
bias = offset_mapping[index - 1][1] + 1 # Includes [SEP] token
if mapping[0] == 0 and mapping[1] == 0:
continue
offset_mapping[index][0] += bias
offset_mapping[index][1] += bias
start_ids = [0 for x in range(max_seq_len)]
end_ids = [0 for x in range(max_seq_len)]
for item in example["result_list"]:
start = map_offset(item["start"] + bias, offset_mapping)
end = map_offset(item["end"] - 1 + bias, offset_mapping)
start_ids[start] = 1.0
end_ids[end] = 1.0
tokenized_output = [
encoded_inputs["input_ids"][0], encoded_inputs["token_type_ids"][0],
encoded_inputs["attention_mask"][0],
start_ids, end_ids
]
tokenized_output = [np.array(x, dtype="int64") for x in tokenized_output]
tokenized_output = [
np.pad(x, (0, max_seq_len-x.shape[-1]), 'constant') for x in tokenized_output]
return tuple(tokenized_output)
def map_offset(ori_offset, offset_mapping):
"""
map ori offset to token offset
"""
for index, span in enumerate(offset_mapping):
if span[0] <= ori_offset < span[1]:
return index
return -1
def reader(data_path, max_seq_len=512):
"""
read json
"""
with open(data_path, 'r', encoding='utf-8') as f:
for line in f:
json_line = json.loads(line)
content = json_line['content']
prompt = json_line['prompt']
# Model Input is aslike: [CLS] Prompt [SEP] Content [SEP]
# It include three summary tokens.
if max_seq_len <= len(prompt) + 3:
raise ValueError(
"The value of max_seq_len is too small, please set a larger value"
)
max_content_len = max_seq_len - len(prompt) - 3
if len(content) <= max_content_len:
yield json_line
else:
result_list = json_line['result_list']
json_lines = []
accumulate = 0
while True:
cur_result_list = []
for result in result_list:
if result['start'] + 1 <= max_content_len < result[
'end']:
max_content_len = result['start']
break
cur_content = content[:max_content_len]
res_content = content[max_content_len:]
while True:
if len(result_list) == 0:
break
elif result_list[0]['end'] <= max_content_len:
if result_list[0]['end'] > 0:
cur_result = result_list.pop(0)
cur_result_list.append(cur_result)
else:
cur_result_list = [
result for result in result_list
]
break
else:
break
json_line = {
'content': cur_content,
'result_list': cur_result_list,
'prompt': prompt
}
json_lines.append(json_line)
for result in result_list:
if result['end'] <= 0:
break
result['start'] -= max_content_len
result['end'] -= max_content_len
accumulate += max_content_len
max_content_len = max_seq_len - len(prompt) - 3
if len(res_content) == 0:
break
elif len(res_content) < max_content_len:
json_line = {
'content': res_content,
'result_list': result_list,
'prompt': prompt
}
json_lines.append(json_line)
break
else:
content = res_content
for json_line in json_lines:
yield json_line
def unify_prompt_name(prompt):
# The classification labels are shuffled during finetuning, so they need
# to be unified during evaluation.
if re.search(r'\[.*?\]$', prompt):
prompt_prefix = prompt[:prompt.find("[", 1)]
cls_options = re.search(r'\[.*?\]$', prompt).group()[1:-1].split(",")
cls_options = sorted(list(set(cls_options)))
cls_options = ",".join(cls_options)
prompt = prompt_prefix + "[" + cls_options + "]"
return prompt
return prompt
class Logger(object):
'''
Deafult logger in UIE
Args:
name(str) : Logger name, default is 'UIE'
'''
def __init__(self, name: str = None):
name = 'UIE' if not name else name
self.logger = logging.getLogger(name)
for key, conf in log_config.items():
logging.addLevelName(conf['level'], key)
self.__dict__[key] = functools.partial(
self.__call__, conf['level'])
self.__dict__[key.lower()] = functools.partial(
self.__call__, conf['level'])
self.format = colorlog.ColoredFormatter(
'%(log_color)s[%(asctime)-15s] [%(levelname)8s]%(reset)s - %(message)s',
log_colors={key: conf['color']
for key, conf in log_config.items()})
self.handler = logging.StreamHandler()
self.handler.setFormatter(self.format)
self.logger.addHandler(self.handler)
self.logLevel = 'DEBUG'
self.logger.setLevel(logging.DEBUG)
self.logger.propagate = False
self._is_enable = True
def disable(self):
self._is_enable = False
def enable(self):
self._is_enable = True
@property
def is_enable(self) -> bool:
return self._is_enable
def __call__(self, log_level: str, msg: str):
if not self.is_enable:
return
self.logger.log(log_level, msg)
@contextlib.contextmanager
def use_terminator(self, terminator: str):
old_terminator = self.handler.terminator
self.handler.terminator = terminator
yield
self.handler.terminator = old_terminator
@contextlib.contextmanager
def processing(self, msg: str, interval: float = 0.1):
'''
Continuously print a progress bar with rotating special effects.
Args:
msg(str): Message to be printed.
interval(float): Rotation interval. Default to 0.1.
'''
end = False
def _printer():
index = 0
flags = ['\\', '|', '/', '-']
while not end:
flag = flags[index % len(flags)]
with self.use_terminator('\r'):
self.info('{}: {}'.format(msg, flag))
time.sleep(interval)
index += 1
t = threading.Thread(target=_printer)
t.start()
yield
end = True
logger = Logger()
BAR_FORMAT = f'{{desc}}: {Fore.GREEN}{{percentage:3.0f}}%{Fore.RESET} {Fore.BLUE}{{bar}}{Fore.RESET} {Fore.GREEN}{{n_fmt}}/{{total_fmt}} {Fore.RED}{{rate_fmt}}{{postfix}}{Fore.RESET} eta {Fore.CYAN}{{remaining}}{Fore.RESET}'
BAR_FORMAT_NO_TIME = f'{{desc}}: {Fore.GREEN}{{percentage:3.0f}}%{Fore.RESET} {Fore.BLUE}{{bar}}{Fore.RESET} {Fore.GREEN}{{n_fmt}}/{{total_fmt}}{Fore.RESET}'
BAR_TYPE = [
"░▝▗▖▘▚▞▛▙█",
"░▖▘▝▗▚▞█",
" ▖▘▝▗▚▞█",
"░▒█",
" >=",
" ▏▎▍▌▋▊▉█"
"░▏▎▍▌▋▊▉█"
]
tqdm = partial(tqdm, bar_format=BAR_FORMAT, ascii=BAR_TYPE[0], leave=False)
def get_id_and_prob(spans, offset_map):
prompt_length = 0
for i in range(1, len(offset_map)):
if offset_map[i] != [0, 0]:
prompt_length += 1
else:
break
for i in range(1, prompt_length + 1):
offset_map[i][0] -= (prompt_length + 1)
offset_map[i][1] -= (prompt_length + 1)
sentence_id = []
prob = []
for start, end in spans:
prob.append(start[1] * end[1])
sentence_id.append(
(offset_map[start[0]][0], offset_map[end[0]][1]))
return sentence_id, prob
def cut_chinese_sent(para):
"""
Cut the Chinese sentences more precisely, reference to
"https://blog.csdn.net/blmoistawinde/article/details/82379256".
"""
para = re.sub(r'([。!?\?])([^”’])', r'\1\n\2', para)
para = re.sub(r'(\.{6})([^”’])', r'\1\n\2', para)
para = re.sub(r'(\…{2})([^”’])', r'\1\n\2', para)
para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2', para)
para = para.rstrip()
return para.split("\n")
def dbc2sbc(s):
rs = ""
for char in s:
code = ord(char)
if code == 0x3000:
code = 0x0020
else:
code -= 0xfee0
if not (0x0021 <= code and code <= 0x7e):
rs += char
continue
rs += chr(code)
return rs
class EarlyStopping:
"""Early stops the training if validation loss doesn't improve after a given patience."""
def __init__(self, patience=7, verbose=False, delta=0, save_dir='checkpoint/early_stopping', trace_func=print):
"""
Args:
patience (int): How long to wait after last time validation loss improved.
Default: 7
verbose (bool): If True, prints a message for each validation loss improvement.
Default: False
delta (float): Minimum change in the monitored quantity to qualify as an improvement.
Default: 0
path (str): Path for the checkpoint to be saved to.
Default: 'checkpoint/early_stopping'
trace_func (function): trace print function.
Default: print
"""
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = np.Inf
self.delta = delta
self.save_dir = save_dir
self.trace_func = trace_func
def __call__(self, val_loss, model):
score = -val_loss
if self.best_score is None:
self.best_score = score
self.save_checkpoint(val_loss, model)
elif score < self.best_score + self.delta:
self.counter += 1
self.trace_func(
f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss, model)
self.counter = 0
def save_checkpoint(self, val_loss, model):
'''Saves model when validation loss decrease.'''
if self.verbose:
self.trace_func(
f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
model.save_pretrained(self.save_dir)
self.val_loss_min = val_loss
def get_relation_type_dict(relation_data):
def compare(a, b):
a = a[::-1]
b = b[::-1]
res = ''
for i in range(min(len(a), len(b))):
if a[i] == b[i]:
res += a[i]
else:
break
if res == "":
return res
elif res[::-1][0] == "的":
return res[::-1][1:]
return ""
relation_type_dict = {}
added_list = []
for i in range(len(relation_data)):
added = False
if relation_data[i][0] not in added_list:
for j in range(i + 1, len(relation_data)):
match = compare(relation_data[i][0], relation_data[j][0])
if match != "":
match = unify_prompt_name(match)
if relation_data[i][0] not in added_list:
added_list.append(relation_data[i][0])
relation_type_dict.setdefault(match, []).append(
relation_data[i][1])
added_list.append(relation_data[j][0])
relation_type_dict.setdefault(match, []).append(
relation_data[j][1])
added = True
if not added:
added_list.append(relation_data[i][0])
suffix = relation_data[i][0].rsplit("的", 1)[1]
suffix = unify_prompt_name(suffix)
relation_type_dict[suffix] = relation_data[i][1]
return relation_type_dict
def add_entity_negative_example(examples, texts, prompts, label_set,
negative_ratio):
negative_examples = []
positive_examples = []
with tqdm(total=len(prompts)) as pbar:
for i, prompt in enumerate(prompts):
redundants = list(set(label_set) ^ set(prompt))
redundants.sort()
num_positive = len(examples[i])
if num_positive != 0:
actual_ratio = math.ceil(len(redundants) / num_positive)
else:
# Set num_positive to 1 for text without positive example
num_positive, actual_ratio = 1, 0
if actual_ratio <= negative_ratio or negative_ratio == -1:
idxs = [k for k in range(len(redundants))]
else:
idxs = random.sample(range(0, len(redundants)),
negative_ratio * num_positive)
for idx in idxs:
negative_result = {
"content": texts[i],
"result_list": [],
"prompt": redundants[idx]
}
negative_examples.append(negative_result)
positive_examples.extend(examples[i])
pbar.update(1)
return positive_examples, negative_examples
def add_relation_negative_example(redundants, text, num_positive, ratio):
added_example = []
rest_example = []
if num_positive != 0:
actual_ratio = math.ceil(len(redundants) / num_positive)
else:
# Set num_positive to 1 for text without positive example
num_positive, actual_ratio = 1, 0
all_idxs = [k for k in range(len(redundants))]
if actual_ratio <= ratio or ratio == -1:
idxs = all_idxs
rest_idxs = []
else:
idxs = random.sample(range(0, len(redundants)), ratio * num_positive)
rest_idxs = list(set(all_idxs) ^ set(idxs))
for idx in idxs:
negative_result = {
"content": text,
"result_list": [],
"prompt": redundants[idx]
}
added_example.append(negative_result)
for rest_idx in rest_idxs:
negative_result = {
"content": text,
"result_list": [],
"prompt": redundants[rest_idx]
}
rest_example.append(negative_result)
return added_example, rest_example
def add_full_negative_example(examples, texts, relation_prompts, predicate_set,
subject_goldens):
with tqdm(total=len(relation_prompts)) as pbar:
for i, relation_prompt in enumerate(relation_prompts):
negative_sample = []
for subject in subject_goldens[i]:
for predicate in predicate_set:
# The relation prompt is constructed as follows:
# subject + "的" + predicate
prompt = subject + "的" + predicate
if prompt not in relation_prompt:
negative_result = {
"content": texts[i],
"result_list": [],
"prompt": prompt
}
negative_sample.append(negative_result)
examples[i].extend(negative_sample)
pbar.update(1)
return examples
def generate_cls_example(text, labels, prompt_prefix, options):
random.shuffle(options)
cls_options = ",".join(options)
prompt = prompt_prefix + "[" + cls_options + "]"
result_list = []
example = {"content": text, "result_list": result_list, "prompt": prompt}
for label in labels:
start = prompt.rfind(label) - len(prompt) - 1
end = start + len(label)
result = {"text": label, "start": start, "end": end}
example["result_list"].append(result)
return example
def convert_cls_examples(raw_examples,
prompt_prefix="情感倾向",
options=["正向", "负向"]):
"""
Convert labeled data export from doccano for classification task.
"""
examples = []
logger.info(f"Converting doccano data...")
with tqdm(total=len(raw_examples)) as pbar:
for line in raw_examples:
items = json.loads(line)
# Compatible with doccano >= 1.6.2
if "data" in items.keys():
text, labels = items["data"], items["label"]
else:
text, labels = items["text"], items["label"]
example = generate_cls_example(
text, labels, prompt_prefix, options)
examples.append(example)
return examples
def convert_ext_examples(raw_examples,
negative_ratio,
prompt_prefix="情感倾向",
options=["正向", "负向"],
separator="##",
is_train=True):
"""
Convert labeled data export from doccano for extraction and aspect-level classification task.
"""
def _sep_cls_label(label, separator):
label_list = label.split(separator)
if len(label_list) == 1:
return label_list[0], None
return label_list[0], label_list[1:]
texts = []
entity_examples = []
relation_examples = []
entity_cls_examples = []
entity_prompts = []
relation_prompts = []
entity_label_set = []
entity_name_set = []
predicate_set = []
subject_goldens = []
inverse_relation_list = []
predicate_list = []
logger.info(f"Converting doccano data...")
with tqdm(total=len(raw_examples)) as pbar:
for line in raw_examples:
items = json.loads(line)
entity_id = 0
if "data" in items.keys():
relation_mode = False
if isinstance(items["label"],
dict) and "entities" in items["label"].keys():
relation_mode = True
text = items["data"]
entities = []
relations = []
if not relation_mode:
# Export file in JSONL format which doccano < 1.7.0
# e.g. {"data": "", "label": [ [0, 2, "ORG"], ... ]}
for item in items["label"]:
entity = {
"id": entity_id,
"start_offset": item[0],
"end_offset": item[1],
"label": item[2]
}
entities.append(entity)
entity_id += 1
else:
# Export file in JSONL format for relation labeling task which doccano < 1.7.0
# e.g. {"data": "", "label": {"relations": [ {"id": 0, "start_offset": 0, "end_offset": 6, "label": "ORG"}, ... ], "entities": [ {"id": 0, "from_id": 0, "to_id": 1, "type": "foundedAt"}, ... ]}}
entities.extend(
[entity for entity in items["label"]["entities"]])
if "relations" in items["label"].keys():
relations.extend([
relation for relation in items["label"]["relations"]
])
else:
# Export file in JSONL format which doccano >= 1.7.0
# e.g. {"text": "", "label": [ [0, 2, "ORG"], ... ]}
if "label" in items.keys():
text = items["text"]
entities = []
for item in items["label"]:
entity = {
"id": entity_id,
"start_offset": item[0],
"end_offset": item[1],
"label": item[2]
}
entities.append(entity)
entity_id += 1
relations = []
else:
# Export file in JSONL (relation) format
# e.g. {"text": "", "relations": [ {"id": 0, "start_offset": 0, "end_offset": 6, "label": "ORG"}, ... ], "entities": [ {"id": 0, "from_id": 0, "to_id": 1, "type": "foundedAt"}, ... ]}
text, relations, entities = items["text"], items[
"relations"], items["entities"]
texts.append(text)
entity_example = []
entity_prompt = []
entity_example_map = {}
entity_map = {} # id to entity name
for entity in entities:
entity_name = text[entity["start_offset"]:entity["end_offset"]]
entity_map[entity["id"]] = {
"name": entity_name,
"start": entity["start_offset"],
"end": entity["end_offset"]
}
entity_label, entity_cls_label = _sep_cls_label(
entity["label"], separator)
# Define the prompt prefix for entity-level classification
entity_cls_prompt_prefix = entity_name + "的" + prompt_prefix
if entity_cls_label is not None:
entity_cls_example = generate_cls_example(
text, entity_cls_label, entity_cls_prompt_prefix,
options)
entity_cls_examples.append(entity_cls_example)
result = {
"text": entity_name,
"start": entity["start_offset"],
"end": entity["end_offset"]
}
if entity_label not in entity_example_map.keys():
entity_example_map[entity_label] = {
"content": text,
"result_list": [result],
"prompt": entity_label
}
else:
entity_example_map[entity_label]["result_list"].append(
result)
if entity_label not in entity_label_set:
entity_label_set.append(entity_label)
if entity_name not in entity_name_set:
entity_name_set.append(entity_name)
entity_prompt.append(entity_label)
for v in entity_example_map.values():
entity_example.append(v)
entity_examples.append(entity_example)
entity_prompts.append(entity_prompt)
subject_golden = [] # Golden entity inputs
relation_example = []
relation_prompt = []
relation_example_map = {}
inverse_relation = []
predicates = []
for relation in relations:
predicate = relation["type"]
subject_id = relation["from_id"]
object_id = relation["to_id"]
# The relation prompt is constructed as follows:
# subject + "的" + predicate
prompt = entity_map[subject_id]["name"] + "的" + predicate
if entity_map[subject_id]["name"] not in subject_golden:
subject_golden.append(entity_map[subject_id]["name"])
result = {
"text": entity_map[object_id]["name"],
"start": entity_map[object_id]["start"],
"end": entity_map[object_id]["end"]
}
inverse_negative = entity_map[object_id][
"name"] + "的" + predicate
inverse_relation.append(inverse_negative)
predicates.append(predicate)
if prompt not in relation_example_map.keys():
relation_example_map[prompt] = {
"content": text,
"result_list": [result],
"prompt": prompt
}
else:
relation_example_map[prompt]["result_list"].append(result)
if predicate not in predicate_set:
predicate_set.append(predicate)
relation_prompt.append(prompt)
for v in relation_example_map.values():
relation_example.append(v)
relation_examples.append(relation_example)
relation_prompts.append(relation_prompt)
subject_goldens.append(subject_golden)
inverse_relation_list.append(inverse_relation)
predicate_list.append(predicates)
pbar.update(1)
logger.info(f"Adding negative samples for first stage prompt...")
positive_examples, negative_examples = add_entity_negative_example(
entity_examples, texts, entity_prompts, entity_label_set,
negative_ratio)
if len(positive_examples) == 0:
all_entity_examples = []
else:
all_entity_examples = positive_examples + negative_examples
all_relation_examples = []
if len(predicate_set) != 0:
logger.info(f"Adding negative samples for second stage prompt...")
if is_train:
positive_examples = []
negative_examples = []
per_n_ratio = negative_ratio // 3
with tqdm(total=len(texts)) as pbar:
for i, text in enumerate(texts):
negative_example = []
collects = []
num_positive = len(relation_examples[i])
# 1. inverse_relation_list
redundants1 = inverse_relation_list[i]
# 2. entity_name_set ^ subject_goldens[i]
redundants2 = []
if len(predicate_list[i]) != 0:
nonentity_list = list(
set(entity_name_set) ^ set(subject_goldens[i]))
nonentity_list.sort()
redundants2 = [
nonentity + "的" +
predicate_list[i][random.randrange(
len(predicate_list[i]))]
for nonentity in nonentity_list
]
# 3. entity_label_set ^ entity_prompts[i]
redundants3 = []
if len(subject_goldens[i]) != 0:
non_ent_label_list = list(
set(entity_label_set) ^ set(entity_prompts[i]))
non_ent_label_list.sort()
redundants3 = [
subject_goldens[i][random.randrange(
len(subject_goldens[i]))] + "的" + non_ent_label
for non_ent_label in non_ent_label_list
]
redundants_list = [redundants1, redundants2, redundants3]
for redundants in redundants_list:
added, rest = add_relation_negative_example(
redundants,
texts[i],
num_positive,
per_n_ratio,
)
negative_example.extend(added)
collects.extend(rest)
num_sup = num_positive * negative_ratio - len(
negative_example)
if num_sup > 0 and collects:
if num_sup > len(collects):
idxs = [k for k in range(len(collects))]
else:
idxs = random.sample(range(0, len(collects)),
num_sup)
for idx in idxs:
negative_example.append(collects[idx])
positive_examples.extend(relation_examples[i])
negative_examples.extend(negative_example)
pbar.update(1)
all_relation_examples = positive_examples + negative_examples
else:
relation_examples = add_full_negative_example(
relation_examples, texts, relation_prompts, predicate_set,
subject_goldens)
all_relation_examples = [
r for relation_example in relation_examples
for r in relation_example
]
return all_entity_examples, all_relation_examples, entity_cls_examples
def get_path_from_url(url,
root_dir,
check_exist=True,
decompress=True):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
from url and decompress it, return the path.
Args:
url (str): download url
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
decompress (bool): decompress zip or tar file. Default is `True`
Returns:
str: a local path to save downloaded models & weights & datasets.
"""
import os.path
import os
import tarfile
import zipfile
def is_url(path):
"""
Whether path is URL.
Args:
path (string): URL string or not.
"""
return path.startswith('http://') or path.startswith('https://')
def _map_path(url, root_dir):
# parse path after download under root_dir
fname = os.path.split(url)[-1]
fpath = fname
return os.path.join(root_dir, fpath)
def _get_download(url, fullname):
import requests
# using requests.get method
fname = os.path.basename(fullname)
try:
req = requests.get(url, stream=True)
except Exception as e: # requests.exceptions.ConnectionError
logger.info("Downloading {} from {} failed with exception {}".format(
fname, url, str(e)))
return False
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
# For protecting download interupted, download to
# tmp_fullname firstly, move tmp_fullname to fullname
# after download finished
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
with tqdm(total=(int(total_size) + 1023) // 1024, unit='KB') as pbar:
for chunk in req.iter_content(chunk_size=1024):
f.write(chunk)
pbar.update(1)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
def _download(url, path):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
"""
if not os.path.exists(path):
os.makedirs(path)
fname = os.path.split(url)[-1]
fullname = os.path.join(path, fname)
retry_cnt = 0
logger.info("Downloading {} from {}".format(fname, url))
DOWNLOAD_RETRY_LIMIT = 3
while not os.path.exists(fullname):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
if not _get_download(url, fullname):
time.sleep(1)
continue
return fullname
def _uncompress_file_zip(filepath):
with zipfile.ZipFile(filepath, 'r') as files:
file_list = files.namelist()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
files.extractall(file_dir)
elif _is_a_single_dir(file_list):
# `strip(os.sep)` to remove `os.sep` in the tail of path
rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
files.extractall(file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
files.extractall(os.path.join(file_dir, rootpath))
return uncompressed_path
def _is_a_single_file(file_list):
if len(file_list) == 1 and file_list[0].find(os.sep) < 0:
return True
return False
def _is_a_single_dir(file_list):
new_file_list = []
for file_path in file_list:
if '/' in file_path:
file_path = file_path.replace('/', os.sep)
elif '\\' in file_path:
file_path = file_path.replace('\\', os.sep)
new_file_list.append(file_path)
file_name = new_file_list[0].split(os.sep)[0]
for i in range(1, len(new_file_list)):
if file_name != new_file_list[i].split(os.sep)[0]:
return False
return True
def _uncompress_file_tar(filepath, mode="r:*"):
with tarfile.open(filepath, mode) as files:
file_list = files.getnames()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
files.extractall(file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
files.extractall(file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
files.extractall(os.path.join(file_dir, rootpath))
return uncompressed_path
def _decompress(fname):
"""
Decompress for zip and tar file
"""
logger.info("Decompressing {}...".format(fname))
# For protecting decompressing interupted,
# decompress to fpath_tmp directory firstly, if decompress
# successed, move decompress files to fpath and delete
# fpath_tmp and remove download compress file.
if tarfile.is_tarfile(fname):
uncompressed_path = _uncompress_file_tar(fname)
elif zipfile.is_zipfile(fname):
uncompressed_path = _uncompress_file_zip(fname)
else:
raise TypeError("Unsupport compress file type {}".format(fname))
return uncompressed_path
assert is_url(url), "downloading from {} not a url".format(url)
fullpath = _map_path(url, root_dir)
if os.path.exists(fullpath) and check_exist:
logger.info("Found {}".format(fullpath))
else:
fullpath = _download(url, root_dir)
if decompress and (tarfile.is_tarfile(fullpath) or
zipfile.is_zipfile(fullpath)):
fullpath = _decompress(fullpath)
return fullpath
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/good-as-water/uie_pytorch.git
git@gitee.com:good-as-water/uie_pytorch.git
good-as-water
uie_pytorch
uie_pytorch
main

搜索帮助

0d507c66 1850385 C8b1a773 1850385