0 Star 1 Fork 0

gisfanmachel/pdcner

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
wcbert_modeling_nky.py 26.33 KB
一键复制 编辑 原始数据 按行查看 历史
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636
# -*- coding: utf-8 -*-
"""
implement of PDCNER
"""
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.activations import gelu, gelu_new, ACT2FN
from transformers.configuration_bert import BertConfig
from module.crf import CRF
from module.bilstm import BiLSTM
from transformers.modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from transformers.modeling_bert import BertAttention, BertIntermediate, BertOutput, load_tf_weights_in_bert, BertModel
BertLayerNorm = torch.nn.LayerNorm
from function.utils import gather_indexes
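# Defines the main components of the PDCNER model: the BERT embedding layer, the encoder, the pooler and the CRF-based heads for named entity recognition.
# BertEmbeddings builds the input embeddings by summing word, position and token-type embeddings.
# All three tables are nn.Embedding layers; `boundary_ids` is accepted by forward() but no boundary embedding is applied to it.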
class BertEmbeddings(nn.Module):
    """Sum word, position and token-type embeddings, then normalize and apply dropout.

    ``forward`` also accepts ``boundary_ids`` for interface compatibility, but
    this module owns no boundary embedding table and never reads that argument.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        # Lookup tables for token ids, absolute positions and segment (token-type) ids.
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden)

        # The attribute keeps its TensorFlow-style name (not snake_case) so that
        # TensorFlow checkpoint variable names still map onto it.
        self.LayerNorm = BertLayerNorm(hidden, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Buffer of shape (1, max_position_embeddings); it is serialized with the module.
        all_positions = torch.arange(config.max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", all_positions)

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, boundary_ids=None, inputs_embeds=None):
        """Embed a batch of sequences.

        Args:
            input_ids: token ids; used to look up word embeddings when
                ``inputs_embeds`` is not given.
            token_type_ids: segment ids; defaults to all zeros.
            position_ids: position ids; defaults to the first ``seq_length``
                entries of the registered ``position_ids`` buffer.
            boundary_ids: accepted but ignored by this implementation.
            inputs_embeds: pre-computed word embeddings replacing the lookup
                of ``input_ids``.

        Returns:
            The normalized, dropout-regularized sum of the three embeddings.
        """
        shape = inputs_embeds.size()[:-1] if input_ids is None else input_ids.size()
        seq_len = shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_len]
        if token_type_ids is None:
            token_type_ids = torch.zeros(shape, dtype=torch.long, device=self.position_ids.device)
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        summed = (
            inputs_embeds
            + self.position_embeddings(position_ids)
            + self.token_type_embeddings(token_type_ids)
        )
        return self.dropout(self.LayerNorm(summed))
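# BertLayer extends the standard BERT layer with an optional attention over the matched lexicon words.
# Its word_transform layer projects the word embeddings to the BERT hidden size before they are fused with the character representation.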
class BertLayer(nn.Module):
    """BERT transformer layer that can optionally fuse matched-word embeddings.

    When ``has_word_attn`` is True, the layer output is enriched with an
    attention-weighted sum of the (projected) embeddings of the words matched
    at every character position.
    """

    def __init__(self, config, has_word_attn=False):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
            self.crossattention = BertAttention(config)

        # Extra parameters for the attention over matched words.
        self.has_word_attn = has_word_attn
        if self.has_word_attn:
            self.dropout = nn.Dropout(config.hidden_dropout_prob)
            self.act = nn.Tanh()

            # Project word embeddings (word_embed_dim) to the BERT hidden size.
            self.word_transform = nn.Linear(config.word_embed_dim, config.hidden_size)
            self.word_word_weight = nn.Linear(config.hidden_size, config.hidden_size)
            # Bilinear attention matrix between character and word representations.
            attn_W = torch.zeros(config.hidden_size, config.hidden_size)
            self.attn_W = nn.Parameter(attn_W)
            self.attn_W.data.normal_(mean=0.0, std=config.initializer_range)
            self.fuse_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        input_word_embeddings=None,
        input_word_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False
    ):
        """Run self-attention, optional cross-attention, feed-forward and word fusion.

        code refer to: https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_bert.py

        Notation: N batch size, L sequence length, W matched words per
        position, D word embedding dim.

        Args:
            hidden_states: input of the layer.
            attention_mask: mask passed to the self-attention module.
            input_word_embeddings: [N, L, W, D] embeddings of the matched words;
                required when the layer was built with ``has_word_attn=True``.
            input_word_mask: [N, L, W] mask of the matched words (1 = real
                word); required when ``has_word_attn=True``.
            head_mask: head mask passed to the attention modules.
            encoder_hidden_states: encoder output for cross-attention (decoder only).
            encoder_attention_mask: mask for the cross-attention.
            output_attentions: whether attention weights are appended to the result.

        Returns:
            Tuple whose first item is the layer output, followed by whatever
            extra outputs the attention modules returned.
        """
        # 1. Character-level contextual representation.
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # self-attention weights, if requested

        # 2. A decoder additionally attends over the encoder output.
        if self.is_decoder and encoder_hidden_states is not None:
            assert hasattr(
                self, "crossattention"
            ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:]  # cross-attention weights, if requested

        # 3. Position-wise feed-forward network (optionally chunked).
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )

        # 4. Fuse the matched-word information into the layer output.
        if self.has_word_attn:
            assert input_word_mask is not None
            layer_output = self._fuse_word_features(layer_output, input_word_embeddings, input_word_mask)

        outputs = (layer_output,) + outputs
        return outputs

    def _fuse_word_features(self, layer_output, input_word_embeddings, input_word_mask):
        """Add the attention-weighted sum of the matched words to ``layer_output``.

        Args:
            layer_output: [N, L, H] character representations.
            input_word_embeddings: [N, L, W, D] matched-word embeddings.
            input_word_mask: [N, L, W] mask, 1 for real words and 0 for padding.

        Returns:
            [N, L, H] fused representation after dropout and layer normalization.
        """
        word_outputs = self.word_transform(input_word_embeddings)  # [N, L, W, H]
        word_outputs = self.act(word_outputs)
        word_outputs = self.word_word_weight(word_outputs)
        word_outputs = self.dropout(word_outputs)

        alpha = torch.matmul(layer_output.unsqueeze(2), self.attn_W)  # [N, L, 1, H]
        alpha = torch.matmul(alpha, torch.transpose(word_outputs, 2, 3))  # [N, L, 1, W]
        # Squeeze only the singleton query axis. A bare squeeze() would also
        # drop N, L or W whenever one of them equals 1 and break the mask
        # addition below.
        alpha = alpha.squeeze(2)  # [N, L, W]
        # Large negative bias so padded words get (almost) zero attention weight.
        alpha = alpha + (1 - input_word_mask.float()) * (-10000.0)
        alpha = torch.softmax(alpha, dim=-1)  # [N, L, W]
        alpha = alpha.unsqueeze(-1)  # [N, L, W, 1]
        weighted_word_embedding = torch.sum(word_outputs * alpha, dim=2)  # [N, L, H]

        layer_output = layer_output + weighted_word_embedding
        layer_output = self.dropout(layer_output)
        layer_output = self.fuse_layernorm(layer_output)
        return layer_output

    def feed_forward_chunk(self, attention_output):
        """Apply the intermediate and output sub-layers to one chunk."""
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
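# BertEncoder is the encoder part of the BERT model, made of a stack of BertLayer modules.
# Each layer contains a self-attention mechanism and a feed-forward network.
# The layers listed in config.add_layers additionally fuse the matched-word embeddings,
# supplied through input_word_embeddings and input_word_mask, into the BERT hidden states.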
class BertEncoder(nn.Module):
    """Stack of ``BertLayer`` modules; layers listed in ``config.add_layers`` fuse word information."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.add_layers = config.add_layers
        # One layer per hidden layer; word attention is enabled only for the
        # indices contained in `add_layers`.
        self.layer = nn.ModuleList(
            [BertLayer(config, idx in self.add_layers) for idx in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        input_word_embeddings=None,
        input_word_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        """Run every layer in order.

        ``return_dict`` is accepted for interface compatibility but is not
        used: the result is always a tuple.

        Returns:
            Tuple ``(last_hidden_state[, all_hidden_states][, all_attentions])``;
            the optional items are present only when the matching
            ``output_*`` flag is set.
        """
        collected_states = () if output_hidden_states else None
        collected_attentions = () if output_attentions else None
        use_checkpointing = getattr(self.config, "gradient_checkpointing", False)

        for idx, block in enumerate(self.layer):
            if output_hidden_states:
                collected_states = collected_states + (hidden_states,)

            block_head_mask = None if head_mask is None else head_mask[idx]
            block_args = (
                hidden_states,
                attention_mask,
                input_word_embeddings,
                input_word_mask,
                block_head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
            )

            if use_checkpointing:
                # checkpoint() forwards the arguments positionally, so the
                # output_attentions flag is appended by the wrapper.
                def run_block(*inputs, _block=block):
                    return _block(*inputs, output_attentions)

                block_outputs = torch.utils.checkpoint.checkpoint(run_block, *block_args)
            else:
                block_outputs = block(*block_args, output_attentions)

            hidden_states = block_outputs[0]
            if output_attentions:
                collected_attentions = collected_attentions + (block_outputs[1],)

        if output_hidden_states:
            collected_states = collected_states + (hidden_states,)

        candidates = (hidden_states, collected_states, collected_attentions)
        return tuple(filter(lambda item: item is not None, candidates))
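# BertPooler is a pooling layer that extracts one feature vector from the sequence output of the BERT model.
# It applies a linear layer followed by a Tanh activation to the hidden state of the first token,
# producing a representation of the whole sequence.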
class BertPooler(nn.Module):
    """Pool a sequence into one vector: first-token state -> linear layer -> tanh."""

    def __init__(self, config):
        super().__init__()
        size = config.hidden_size
        self.dense = nn.Linear(size, size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return the transformed hidden state of the first token of every sequence."""
        # "Pooling" here simply means picking index 0 along the sequence axis.
        return self.activation(self.dense(hidden_states[:, 0]))
# BertPreTrainedModel is an abstract base class that provides weight initialization
# and the interface for loading pretrained BERT models.
class BertPreTrainedModel(PreTrainedModel):
    """Abstract base class that handles weight initialization and offers a
    simple interface for downloading and loading pretrained models.
    """

    config_class = BertConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"
    authorized_missing_keys = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights of a single sub-module."""
        is_linear = isinstance(module, nn.Linear)

        if is_linear or isinstance(module, (nn.Embedding, nn.Parameter)):
            # Plain normal initialization; the TF version uses truncated_normal instead,
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            # Layer normalization starts as the identity transform.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        # Linear biases always start at zero.
        if is_linear and module.bias is not None:
            module.bias.data.zero_()
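# WCBertModel is the complete BERT model: it combines the embedding layer, the encoder and the pooler.
# It can load pretrained BERT weights.
# The pooling layer is optional and is controlled by the add_pooling_layer constructor argument.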
class WCBertModel(BertPreTrainedModel):
    """BERT model whose encoder can fuse matched-word embeddings.

    Combines ``BertEmbeddings``, ``BertEncoder`` and, unless
    ``add_pooling_layer`` is False, ``BertPooler``.
    """

    def __init__(self, config, add_pooling_layer=True):
        super(WCBertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None
        self.init_weights()

    def get_input_embeddings(self):
        """Return the word embedding table."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word embedding table."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        matched_word_embeddings=None,
        matched_word_mask=None,
        boundary_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            if the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the model is configured as a decoder.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

        Notation: N batch size, L sequence length, D dim, W word number,
        B boundary number.

        Args:
            input_ids: [N, L]
            attention_mask: [N, L]
            boundary_ids: [N, L, B]; forwarded to the embedding layer.
            matched_word_embeddings: [N, L, W, D]
            matched_word_mask: [N, L, W]

        Returns:
            With ``return_dict`` false, the tuple ``(sequence_output,
            pooled_output) + extra encoder outputs``; otherwise a
            ``BaseModelOutputWithPooling``. ``pooled_output`` is ``None`` when
            the model was built with ``add_pooling_layer=False``.

        Raises:
            ValueError: if both or neither of ``input_ids`` and
                ``inputs_embeds`` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
            boundary_ids=boundary_ids, inputs_embeds=inputs_embeds,
        )
        # The encoder always returns a plain tuple; request it explicitly so
        # the unpacking below never depends on `return_dict`.
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            input_word_embeddings=matched_word_embeddings,
            input_word_mask=matched_word_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=False,
        )
        sequence_output = encoder_outputs[0]
        # The pooler is optional (add_pooling_layer=False leaves it as None).
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        extra_outputs = tuple(encoder_outputs[1:])
        if not return_dict:
            return (sequence_output, pooled_output) + extra_outputs

        # Imported here because the module header does not import it.
        from transformers.modeling_outputs import BaseModelOutputWithPooling

        # The encoder appends hidden states first, then attentions, each only
        # when the matching flag is set.
        remaining = list(extra_outputs)
        all_hidden_states = remaining.pop(0) if output_hidden_states else None
        all_attentions = remaining.pop(0) if output_attentions else None

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
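# WCBertCRFForTokenClassification inherits from BertPreTrainedModel and is used for named entity recognition.
# It combines the BERT model with a CRF layer for sequence labeling.
# On top of the BERT model it adds a dropout layer, a linear tag projection and a CRF layer.
# Its forward method supports both the training ("Train") and the prediction ("Predict") mode.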
class WCBertCRFForTokenClassification(BertPreTrainedModel):
    """NER model: word-enhanced BERT followed by dropout, a linear tag projection and a CRF."""

    def __init__(self, config, pretrained_embeddings, num_labels):
        """Build the model.

        Args:
            config: model configuration; ``HP_dropout`` and ``hidden_size``
                are read here.
            pretrained_embeddings: 2-D numpy array of pretrained word
                embeddings, one row per word.
            num_labels: number of labels handled by the CRF.
        """
        super().__init__(config)

        word_vocab_size = pretrained_embeddings.shape[0]
        embed_dim = pretrained_embeddings.shape[1]
        self.word_embeddings = nn.Embedding(word_vocab_size, embed_dim)
        self.bert = WCBertModel(config)
        self.dropout = nn.Dropout(config.HP_dropout)
        self.num_labels = num_labels
        # Two extra output units beyond num_labels — presumably the CRF's
        # start/end tags; confirm against module.crf.
        self.hidden2tag = nn.Linear(config.hidden_size, num_labels + 2)
        self.crf = CRF(num_labels, torch.cuda.is_available())

        self.init_weights()

        # Overwrite the randomly initialized table with the pretrained vectors.
        self.word_embeddings.weight.data.copy_(torch.from_numpy(pretrained_embeddings))
        print("Load pretrained embedding from file.........")

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        matched_word_ids=None,
        matched_word_mask=None,
        boundary_ids=None,
        labels=None,
        flag="Train"
    ):
        """Compute the loss and/or the predicted tag sequences.

        Args:
            input_ids, attention_mask, token_type_ids: standard BERT inputs.
            matched_word_ids: ids of the matched words, looked up in
                ``word_embeddings``.
            matched_word_mask: mask of the matched words.
            boundary_ids: forwarded to the BERT model.
            labels: gold labels; required when ``flag == "Train"``.
            flag: ``"Train"`` or ``"Predict"``.

        Returns:
            ``(loss, preds, sequence_output, hidden_states)`` for ``"Train"``,
            where ``hidden_states`` holds the embedding output plus the output
            of every encoder layer; ``(preds,)`` for ``"Predict"``.

        Raises:
            ValueError: if ``flag`` is neither ``"Train"`` nor ``"Predict"``.
        """
        # Reject an unknown mode before running the model, instead of silently
        # returning None after a full forward pass.
        if flag not in ("Train", "Predict"):
            raise ValueError(f"flag must be 'Train' or 'Predict', got {flag!r}")
        is_training = flag == "Train"
        if is_training:
            assert labels is not None

        matched_word_embeddings = self.word_embeddings(matched_word_ids)
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            matched_word_embeddings=matched_word_embeddings,
            matched_word_mask=matched_word_mask,
            boundary_ids=boundary_ids,
            # Per-layer hidden states are returned in Train mode only (used
            # for embedding visualization), so they are requested only then.
            output_hidden_states=is_training,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.hidden2tag(sequence_output)

        _, preds = self.crf._viterbi_decode(logits, attention_mask)
        if not is_training:
            return (preds,)

        # outputs = (sequence_output, pooled_output, all_hidden_states).
        hidden_states = outputs[2]
        loss = self.crf.neg_log_likelihood_loss(logits, attention_mask, labels)
        return (loss, preds, sequence_output, hidden_states)
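# BertWordLSTMCRFForTokenClassification is another NER model; it combines BERT, a BiLSTM and a CRF layer.
# It is similar to WCBertCRFForTokenClassification, but adds a BiLSTM layer that further processes the BERT output.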
class BertWordLSTMCRFForTokenClassification(BertPreTrainedModel):
    """
    model-level fusion baseline
    concat bert vector with attention weighted sum word embedding
    and then input to LSTM-CRF
    """

    # CRF (conditional random field): sequence labeling layer for NER that can
    # learn dependencies between labels.
    # BiLSTM (bidirectional LSTM): learns a bidirectional representation of the sequence.
    def __init__(self, config, pretrained_embeddings, num_labels):
        """Build the model.

        Args:
            config: model configuration; ``HP_dropout``, ``word_embed_dim``,
                ``hidden_size``, ``lstm_size`` and ``initializer_range`` are
                read here.
            pretrained_embeddings: 2-D numpy array of pretrained word
                embeddings, one row per word.
            num_labels: number of labels handled by the CRF.
        """
        super().__init__(config)

        word_vocab_size = pretrained_embeddings.shape[0]
        embed_dim = pretrained_embeddings.shape[1]
        self.word_embeddings = nn.Embedding(word_vocab_size, embed_dim)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.HP_dropout)

        self.act = nn.Tanh()
        # Project word embeddings (word_embed_dim) to the BERT hidden size.
        self.word_transform = nn.Linear(config.word_embed_dim, config.hidden_size)
        self.word_word_weight = nn.Linear(config.hidden_size, config.hidden_size)
        # The BiLSTM input is the BERT vector concatenated with the fused word vector.
        self.bilstm = BiLSTM(config.hidden_size * 2, config.lstm_size, config.HP_dropout)

        # Bilinear attention matrix between character and word representations.
        attn_W = torch.zeros(config.hidden_size, config.hidden_size)
        self.attn_W = nn.Parameter(attn_W)
        self.attn_W.data.normal_(mean=0.0, std=config.initializer_range)

        self.num_labels = num_labels
        self.hidden2tag = nn.Linear(config.lstm_size * 2, num_labels + 2)
        self.crf = CRF(num_labels, torch.cuda.is_available())

        self.init_weights()

        # Overwrite the randomly initialized table with the pretrained vectors.
        self.word_embeddings.weight.data.copy_(torch.from_numpy(pretrained_embeddings))
        print("Load pretrained embedding from file.........")

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        matched_word_ids=None,
        matched_word_mask=None,
        boundary_ids=None,
        labels=None,
        flag="Train"
    ):
        """Compute the loss and/or the predicted tag sequences.

        Args:
            input_ids, attention_mask, token_type_ids: standard BERT inputs.
            matched_word_ids: ids of the matched words, looked up in
                ``word_embeddings``.
            matched_word_mask: mask of the matched words (1 = real word).
            boundary_ids: accepted for interface parity; not used by this model.
            labels: gold labels; required when ``flag == "Train"``.
            flag: ``"Train"`` or ``"Predict"``.

        Returns:
            ``(loss, preds)`` for ``"Train"``; ``(preds,)`` for ``"Predict"``.

        Raises:
            ValueError: if ``flag`` is neither ``"Train"`` nor ``"Predict"``.
        """
        # Reject an unknown mode before running the model, instead of silently
        # returning None after a full forward pass.
        if flag not in ("Train", "Predict"):
            raise ValueError(f"flag must be 'Train' or 'Predict', got {flag!r}")
        if flag == "Train":
            assert labels is not None

        matched_word_embeddings = self.word_embeddings(matched_word_ids)
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        sequence_output = outputs[0]

        matched_word_embeddings = self.word_transform(matched_word_embeddings)
        matched_word_embeddings = self.act(matched_word_embeddings)
        matched_word_embeddings = self.word_word_weight(matched_word_embeddings)
        matched_word_embeddings = self.dropout(matched_word_embeddings)

        alpha = torch.matmul(sequence_output.unsqueeze(2), self.attn_W)  # [N, L, 1, D]
        alpha = torch.matmul(alpha, torch.transpose(matched_word_embeddings, 2, 3))  # [N, L, 1, W]
        # Squeeze only the singleton query axis. A bare squeeze() would also
        # drop N, L or W whenever one of them equals 1 and break the mask
        # addition below.
        alpha = alpha.squeeze(2)  # [N, L, W]
        # Large negative bias so padded words get (almost) zero attention weight.
        alpha = alpha + (1 - matched_word_mask.float()) * (-2 ** 31 + 1)
        alpha = torch.softmax(alpha, dim=-1)  # [N, L, W]
        alpha = alpha.unsqueeze(-1)  # [N, L, W, 1]
        matched_word_embeddings = torch.sum(matched_word_embeddings * alpha, dim=2)  # [N, L, D]

        # Concatenate the BERT vector and the fused word vector along the feature axis.
        sequence_output = torch.cat((sequence_output, matched_word_embeddings), dim=-1)

        sequence_output = self.dropout(sequence_output)
        lstm_output = self.bilstm(sequence_output, attention_mask)
        logits = self.hidden2tag(lstm_output)

        _, preds = self.crf._viterbi_decode(logits, attention_mask)
        if flag == "Predict":
            return (preds,)

        loss = self.crf.neg_log_likelihood_loss(logits, attention_mask, labels)
        return (loss, preds)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/gisfanmachel/pdcner.git
git@gitee.com:gisfanmachel/pdcner.git
gisfanmachel
pdcner
pdcner
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385