1 Star 5 Fork 1

zhanbiao2023/Summary

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
Tokenize.py 5.72 KB
一键复制 编辑 原始数据 按行查看 历史
biao242626 提交于 2021-01-12 15:44 . i
import unicodedata
def load_vocab(dict_path, encoding='utf-8', simplified=False, startswith=None):
"""从bert的词典文件中读取词典
"""
token_dict = {}
with open(dict_path, encoding=encoding) as reader:
for line in reader:
token = line.split()
token = token[0] if token else line.strip()
token_dict[token] = len(token_dict)
return token_dict
class Tokenize:
def __init__(
self,
token_dict=None,
normalize=True,
token_start='[CLS]',
token_end='[SEP]',
pre_tokenize=None,
token_translate=None
):
"""参数说明:
pre_tokenize:外部传入的分词函数,用作对文本进行预分词。如果传入
pre_tokenize,则先执行pre_tokenize(text),然后在它
的基础上执行原本的tokenize函数;
token_translate:映射字典,主要用在tokenize之后,将某些特殊的token
替换为对应的token。
"""
token_dict = load_vocab(token_dict)
self._normalize = normalize
self._token_pad = '[PAD]'
self._token_unk = '[UNK]'
self._token_mask = '[MASK]'
self._token_start = token_start
self._token_end = token_end
self._pre_tokenize = pre_tokenize
self._token_dict = token_dict
self._vocab_size = len(token_dict)
self._token_dict_inv = {v: k for k, v in token_dict.items()}
self._token_translate = token_translate or {}
self._token_translate_inv = {
v: k
for k, v in self._token_translate.items()
}
def tokenize(self, text, pre_tokenize=True):
"""分词函数
"""
if self._normalize:
text = str(text)
text = text.lower()
text = unicodedata.normalize('NFD', text)
text = ''.join([
ch for ch in text if unicodedata.category(ch) != 'Mn'
])
E_pun = u',.!?[]()<>"\'"\'.'
C_pun = u',。!?【】()《》“‘”’…'
table = {ord(f): ord(t) for f, t in zip(C_pun, E_pun)}
text = text.translate(table)
if pre_tokenize and self._pre_tokenize is not None:
tokens = []
for token in self._pre_tokenize(text):
if token in self._token_dict:
tokens.append(token)
else:
tokens.extend(self.tokenize(token, False))
return tokens
spaced = ''
for ch in text:
if self._is_punctuation(ch) or self._is_cjk_character(ch):
spaced += ' ' + ch + ' '
elif self._is_space(ch):
spaced += ' '
elif ord(ch) == 0 or ord(ch) == 0xfffd or self._is_control(ch):
continue
else:
spaced += ch
tokens = []
for word in spaced.strip().split():
tokens.extend(self._word_piece_tokenize(word))
return tokens
def _word_piece_tokenize(self, word):
"""word内分成subword
"""
if word in self._token_dict:
return [word]
tokens = []
start, stop = 0, 0
while start < len(word):
stop = len(word)
while stop > start:
sub = word[start:stop]
if start > 0:
sub = '##' + sub
if sub in self._token_dict:
break
stop -= 1
# if start == stop and start != 0 and stop!=0:
# stop += 1
if stop == 0:
break
tokens.append(sub)
start = stop
return tokens
def token_to_id(self, token):
return self._token_dict.get(token, self._token_unk)
def tokens_to_ids(self, tokens):
# 首先分词
tokens = self.tokenize(tokens)
return [self.token_to_id(token) for token in tokens]
def id_to_token(self, id):
return self._token_dict_inv[id]
def ids_to_tokens(self, ids):
return [self.id_to_token(i) for i in ids]
@staticmethod
def _is_punctuation(ch):
"""标点符号类字符判断(全/半角均在此内)
提醒:unicodedata.category这个函数在py2和py3下的
表现可能不一样,比如u'§'字符,在py2下的结果为'So',
在py3下的结果是'Po'。
"""
code = ord(ch)
return 33 <= code <= 47 or \
58 <= code <= 64 or \
91 <= code <= 96 or \
123 <= code <= 126 or \
unicodedata.category(ch).startswith('P')
@staticmethod
def _is_cjk_character(ch):
"""CJK类字符判断(包括中文字符也在此列)
参考:https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
"""
code = ord(ch)
return 0x4E00 <= code <= 0x9FFF or \
0x3400 <= code <= 0x4DBF or \
0x20000 <= code <= 0x2A6DF or \
0x2A700 <= code <= 0x2B73F or \
0x2B740 <= code <= 0x2B81F or \
0x2B820 <= code <= 0x2CEAF or \
0xF900 <= code <= 0xFAFF or \
0x2F800 <= code <= 0x2FA1F
@staticmethod
def _is_control(ch):
"""控制类字符判断
"""
return unicodedata.category(ch) in ('Cc', 'Cf')
@staticmethod
def _is_special(ch):
"""判断是不是有特殊含义的符号
"""
return bool(ch) and (ch[0] == '[') and (ch[-1] == ']')
@staticmethod
def _is_space(ch):
"""空格类字符判断
"""
return ch == ' ' or ch == '\n' or ch == '\r' or ch == '\t' or \
unicodedata.category(ch) == 'Zs'
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/zhanbiao2023/summary.git
git@gitee.com:zhanbiao2023/summary.git
zhanbiao2023
summary
Summary
master

搜索帮助