代码拉取完成,页面将自动刷新
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz <kerlomz@gmail.com>
import tensorflow as tf
from config import ModelConfig
class Loss(object):
"""损失函数生成器"""
@staticmethod
def cross_entropy(labels, logits):
"""交叉熵损失函数"""
# return tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=labels)
# return tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
# return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
# return tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)
target = tf.sparse.to_dense(labels)
# target = labels
print('logits', logits.shape)
print('target', target.shape)
# logits = tf.reshape(tensor=logits, shape=[tf.shape(labels)[0], None])
return tf.keras.backend.sparse_categorical_crossentropy(
target=target,
output=logits,
from_logits=True,
)
@staticmethod
def ctc(labels, logits, sequence_length):
"""CTC 损失函数"""
return tf.nn.ctc_loss_v2(
labels=labels,
logits=logits,
logit_length=sequence_length,
label_length=sequence_length,
blank_index=-1,
logits_time_major=True
)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。