代码拉取完成,页面将自动刷新
同步操作将从 src-openEuler/tensorflow 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
From b761c9b652af2107cfbc33efd19be0ce41daa33e Mon Sep 17 00:00:00 2001
From: Amit Patankar <amitpatankar@google.com>
Date: Thu, 15 Apr 2021 13:28:49 -0700
Subject: [PATCH] Fix `tf.raw_ops.RaggedTensorToTensor` failing CHECK.
PiperOrigin-RevId: 368706628
Change-Id: I5c9ea4833f38835ee183ca50d63251dc89c9f3bc
---
.../kernels/ragged_tensor_to_tensor_op.cc | 20 ++++++++++---------
1 file changed, 11 insertions(+), 9 deletions(-)
diff --git a/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc b/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc
index 433d910f6090c..434c853b63daa 100644
--- a/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc
+++ b/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc
@@ -208,7 +208,7 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
}
void CalculateOutputIndexRowSplit(
- const RowPartitionTensor& row_split,
+ OpKernelContext* context, const RowPartitionTensor& row_split,
const vector<INDEX_TYPE>& parent_output_index,
INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size,
vector<INDEX_TYPE>* result) {
@@ -233,7 +233,8 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
}
}
if (row_split_size > 0) {
- DCHECK_EQ(result->size(), row_split(row_split_size - 1));
+ OP_REQUIRES(context, result->size() == row_split(row_split_size - 1),
+ errors::InvalidArgument("Invalid row split size."));
}
}
@@ -259,7 +260,7 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
// result[7] = -1 because parent_output_index[value_rowids[6]] == -1
// result[8] = parent_output_index[value_rowids[7]]
void CalculateOutputIndexValueRowID(
- const RowPartitionTensor& value_rowids,
+ OpKernelContext* context, const RowPartitionTensor& value_rowids,
const vector<INDEX_TYPE>& parent_output_index,
INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size,
vector<INDEX_TYPE>* result) {
@@ -293,7 +294,8 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
}
result->push_back(current_output_index);
}
- DCHECK_EQ(result->size(), value_rowids.size());
+ OP_REQUIRES(context, result->size() == value_rowids.size(),
+ errors::InvalidArgument("Invalid row ids."));
}
Status CalculateOutputIndex(OpKernelContext* context, int dimension,
@@ -307,13 +309,13 @@ class RaggedTensorToTensorBaseOp : public OpKernel {
switch (partition_type) {
case RowPartitionType::VALUE_ROWIDS:
CalculateOutputIndexValueRowID(
- row_partition_tensor, parent_output_index, output_index_multiplier,
- output_size, result);
+ context, row_partition_tensor, parent_output_index,
+ output_index_multiplier, output_size, result);
return tensorflow::Status::OK();
case RowPartitionType::ROW_SPLITS:
- CalculateOutputIndexRowSplit(row_partition_tensor, parent_output_index,
- output_index_multiplier, output_size,
- result);
+ CalculateOutputIndexRowSplit(
+ context, row_partition_tensor, parent_output_index,
+ output_index_multiplier, output_size, result);
return tensorflow::Status::OK();
default:
return errors::InvalidArgument(
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。