代码拉取完成,页面将自动刷新
import os
import copy
import json
import logging
import torch
from torch.utils.data import TensorDataset
from bert_finetune_ner.utils import get_slot_labels
logger = logging.getLogger(__name__)
class InputExample(object):
"""
定义一个样本类
一个样本完全可以用一个dict来表示,但是使用 InputExample类,有很多方便之处
Args:
guid: Unique id for the example.
words: list. The words of the sequence.
slot_labels: (Optional) list. The slot labels of the example.
"""
def __init__(self, guid, words, slot_labels=None):
# 每个样本的独特的序号
self.guid = guid
# 样本的输入序列
self.words = words
# 样本的命名实体标签
self.slot_labels = slot_labels
def __repr__(self):
"""
这个魔法方法默认为: “类名+object at+内存地址”这样的信息表示这个实例;
这里我们重写成了想要输出的信息,当print(input_example) 时候显示这些信息
"""
return str(self.to_json_string())
def to_dict(self):
"""
将此实例序列化到Python字典中
__dict__:类的静态函数、类函数、普通函数、全局变量以及一些内置的属性都是放在类__dict__里的
对象实例的__dict__中存储了一些self.xxx的一些东西
"""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""
类的属性等信息(字典格式)dump进入json string
json.dumps()函数将python对象编码成JSON字符串
indent=2.文件格式中加入了换行与缩进
"""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class InputFeatures(object):
"""
定义一个样本类,构造模型特征
"""
def __init__(self, input_ids, attention_mask, token_type_ids, slot_labels_ids):
# 输入样本序列在bert词表里的索引,可以直接喂给nn.embedding
self.input_ids = input_ids
# 注意力mask,padding的部分为0,其他为1
self.attention_mask = attention_mask
# 表示每个token属于句子1还是句子2
self.token_type_ids = token_type_ids
# 命名实体标签索引
self.slot_labels_ids = slot_labels_ids
def __repr__(self):
"""
这个魔法方法默认为: “类名+object at+内存地址”这样的信息表示这个实例;
这里我们重写成了想要输出的信息,当print(input_example) 时候显示这些信息
"""
return str(self.to_json_string())
def to_dict(self):
"""
将此实例序列化到Python字典中
__dict__:类的静态函数、类函数、普通函数、全局变量以及一些内置的属性都是放在类__dict__里的
对象实例的__dict__中存储了一些self.xxx的一些东西
"""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""
类的属性等信息(字典格式)dump进入json string
json.dumps()函数将python对象编码成JSON字符串
indent=2.文件格式中加入了换行与缩进
"""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class NerProcessor(object):
"""NER项目的数据处理器 """
def __init__(self, args):
# 参数
self.args = args
# 获取命名实体标签
self.slot_labels = get_slot_labels(args)
# 输入样本文件名
self.input_text_file = 'seq.in'
# 输入的命名实体标签文件名
self.slot_labels_file = 'seq.out'
@classmethod # 类方法
def _read_file(cls, input_file, quotechar=None):
"""
按行读取文件
:param input_file: 输入文件名
:param quotechar:
:return: 返回句子列表
"""
# 按行读取文件并添加到lines文件中
with open(input_file, "r", encoding="utf-8") as f:
lines = []
for line in f:
lines.append(line.strip())
return lines
def _create_examples(self, texts, slots, set_type):
"""
为训练集与验证集创建示例
:param texts:需要处理的文本组成的句子列表
:param slots:需要处理的文本组成的命名实体列表
:param set_type:区分训练/开发/测试集
:return:
"""
examples = []
for i, (text, slot) in enumerate(zip(texts, slots)):
# 给每个样本一个编号
guid = "%s-%s" % (set_type, i)
# 1. 每一句话依据空格与空白字符进行分割
words = text.split()
# 2. slot
slot_labels = []
# 对每一句话的命名实体标签进行循环,如果在self.slot_labels中,则添加其索引,否则添加"UNK"的索引
for s in slot.split():
slot_labels.append(self.slot_labels.index(s) if s in self.slot_labels else self.slot_labels.index("UNK"))
# 断言,确保单词与命名实体个数相同
assert len(words) == len(slot_labels)
# 将InputExample类中添加到examples列表中
examples.append(InputExample(guid=guid, words=words, slot_labels=slot_labels))
return examples
def get_examples(self, mode):
"""
获得examples数据
Args:
mode: train, dev, test
"""
# 文件路径拼接
data_path = os.path.join(self.args.data_dir, self.args.task, mode)
# 日志文件
logger.info("LOOKING AT {}".format(data_path))
return self._create_examples(texts=self._read_file(os.path.join(data_path, self.input_text_file)),
slots=self._read_file(os.path.join(data_path, self.slot_labels_file)),
set_type=mode)
# 如果有多个数据集,则数据集的processor可以通过映射得到
processors = {
"atis": NerProcessor,
"snips": NerProcessor
}
def convert_examples_to_features(examples, max_seq_len, tokenizer,
pad_token_label_id=-100,
cls_token_segment_id=0,
pad_token_segment_id=0,
sequence_a_segment_id=0,
mask_padding_with_zero=True,
slot_label_lst=None,
):
"""
将输入样本转化为bert能够读取的features
:param examples: 输入样本
:param max_seq_len: 最大长度
:param tokenizer:文本处理
:param pad_token_label_id:-100
:param cls_token_segment_id:0
:param pad_token_segment_id:0
:param sequence_a_segment_id:0
:param mask_padding_with_zero:0
:param slot_label_lst:0
:return:
"""
# 基于当前模型的设置
# [CLS]
cls_token = tokenizer.cls_token
# [SEP]
sep_token = tokenizer.sep_token
# [UNK]
unk_token = tokenizer.unk_token
# [PAD]
pad_token_id = tokenizer.pad_token_id
features = []
for (ex_index, example) in enumerate(examples):
# 每5000条数据,写入一条日志
if ex_index % 5000 == 0:
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
# Tokenize word by word (for NER)
tokens = []
slot_labels_ids = []
# 循环遍历每个词与label
for word, slot_label in zip(example.words, example.slot_labels):
# 使用 tokenize() 函数对文本进行 tokenization之后,返回的分词的 token 词
word_tokens = tokenizer.tokenize(word)
# 处理错误编码的单词
if not word_tokens:
word_tokens = [unk_token]
# 用word_tokens列表拓展tokens列表
tokens.extend(word_tokens)
# A——word : B-PER;
# A-word --> BERT tokenize --> A_1, A_2, A_3: B-PER, PAD, PAD;
# A-word --> BERT tokenize --> A_1, A_2, A_3: B-PER, I-PER, I-PER;
# 对单词的第一个token使用真实标签id,对其余token使用填充id
slot_labels_ids.extend([int(slot_label)] + [pad_token_label_id] * (len(word_tokens) - 1))
# pad_token_label_id: -100, loss function 忽略的label编号
# Account for [CLS] and [SEP]
special_tokens_count = 2
# 如果句子长了就截断,命名实体标签索引列表也进行截断
if len(tokens) > max_seq_len - special_tokens_count:
tokens = tokens[:(max_seq_len - special_tokens_count)]
slot_labels_ids = slot_labels_ids[:(max_seq_len - special_tokens_count)]
# Add [SEP] token
tokens += [sep_token]
# [SEP] label id: pad_token_label_id
slot_labels_ids += [pad_token_label_id]
token_type_ids = [sequence_a_segment_id] * len(tokens)
# Add [CLS] token
tokens = [cls_token] + tokens
# [CLS] label id: pad_token_label_id
slot_labels_ids = [pad_token_label_id] + slot_labels_ids
token_type_ids = [cls_token_segment_id] + token_type_ids
# 把tokens转化为bert词表中的id
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
# Zero-pad up to the sequence length.
# 长度补齐,保证长度满足最大序列长度
# 需要填充序列的长度
padding_length = max_seq_len - len(input_ids)
# 输入样本序列在bert词表里的索引
input_ids = input_ids + ([pad_token_id] * padding_length)
# 注意力mask,padding的部分为0,其他为1
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
# token_type_ids表示每个token属于句子1还是句子2
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
# slot_labels_ids:命名实体标签索引,剩余部分用pad_token_label_id来填充
slot_labels_ids = slot_labels_ids + ([pad_token_label_id] * padding_length)
# 验证长度是否填充至最长序列
assert len(input_ids) == max_seq_len, "Error with input length {} vs {}".format(len(input_ids), max_seq_len)
assert len(attention_mask) == max_seq_len, "Error with attention mask length {} vs {}".format(len(attention_mask), max_seq_len)
assert len(token_type_ids) == max_seq_len, "Error with token type length {} vs {}".format(len(token_type_ids), max_seq_len)
assert len(slot_labels_ids) == max_seq_len, "Error with slot labels length {} vs {}".format(len(slot_labels_ids), max_seq_len)
# 第1199句话日志的写法
if 1198 < ex_index < 1200:
logger.info("*** Example ***")
logger.info("guid: %s" % example.guid)
logger.info("original words: %s" % " ".join([str(x) for x in example.words]))
logger.info("original slot_labels: %s" % " ".join([slot_label_lst[int(x)] for x in example.slot_labels]))
logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids]))
logger.info("slot_labels: %s" % " ".join([str(x) for x in slot_labels_ids]))
# 每句话以InputFeatures类数据添加到features列表中
features.append(
InputFeatures(input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
slot_labels_ids=slot_labels_ids
))
return features
def load_and_cache_examples(args, tokenizer, mode):
"""
将数据转化为cache文件,方便下一次快速加载
:param args:参数
:param tokenizer:分词
:param mode:区分训练验证测试集
:return:数据集
"""
# 通过映射得到数据集的处理器
processor = processors[args.task](args)
# 从cache 或 dataset 文件加载数据
# cach路径拼接
cached_features_file = os.path.join(
args.data_dir,
'cached_{}_{}_{}_{}'.format(
mode,
args.task,
list(filter(None, args.model_name_or_path.split("/"))).pop(),
args.max_seq_len
)
)
# 如果cach文件路径存在
if os.path.exists(cached_features_file):
logger.info("Loading features from cached file %s", cached_features_file)
# 加载cach文件
features = torch.load(cached_features_file)
else:
# 读取命名实体标签
slot_label_lst = get_slot_labels(args)
# Load data features from dataset file
logger.info("Creating features from dataset file at %s", args.data_dir)
# 通过数据处理器获得examples数据
if mode == "train":
examples = processor.get_examples("train")
elif mode == "dev":
examples = processor.get_examples("dev")
elif mode == "test":
examples = processor.get_examples("test")
else:
raise Exception("For mode, Only train, dev, test is available")
# Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
pad_token_label_id = args.ignore_index
# 将输入样本转化为bert能够读取的features
# 将之前读取的数据进行添加[CLS],[SEP]标记,padding等操作
features = convert_examples_to_features(examples, args.max_seq_len, tokenizer,
pad_token_label_id=pad_token_label_id,
slot_label_lst=slot_label_lst)
logger.info("Saving features into cached file %s", cached_features_file)
# 将处理后的数据保存至cach文件中
torch.save(features, cached_features_file)
# 将InputFeatures类数据转化为tensor
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
all_slot_labels_ids = torch.tensor([f.slot_labels_ids for f in features], dtype=torch.long)
# 构造数据集
dataset = TensorDataset(all_input_ids, all_attention_mask,
all_token_type_ids, all_slot_labels_ids
)
# 返回数据集
return dataset
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。