代码拉取完成,页面将自动刷新
同步操作将从 src-openEuler/tensorflow 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
From 7edb8c9b83ad583616406af61e0de61393996a3e Mon Sep 17 00:00:00 2001
From: Yong Tang <yong.tang.github@outlook.com>
Date: Sat, 6 Feb 2021 20:24:54 +0000
Subject: [PATCH] Fix crash of tf.strings.substr when pos and len have
different shapes
This PR tries to address the issue raised in 46900 where
tf.strings.substr will crash when pos and len have different shapes.
According to the documentation of tf.strings.substr, ValueError
should be raised instead when pos and len does not have the same shape.
This PR add shape check in kernel to allows grace error throw (instead of crash).
This PR fixes 46900.
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
---
tensorflow/core/kernels/substr_op.cc | 6 ++++++
tensorflow/python/kernel_tests/substr_op_test.py | 10 ++++++++++
2 files changed, 16 insertions(+)
diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc
index e382381e12232..0c94ba35b249a 100644
--- a/tensorflow/core/kernels/substr_op.cc
+++ b/tensorflow/core/kernels/substr_op.cc
@@ -51,6 +51,12 @@ class SubstrOp : public OpKernel {
const Tensor& len_tensor = context->input(2);
const TensorShape& input_shape = input_tensor.shape();
const TensorShape& pos_shape = pos_tensor.shape();
+ const TensorShape& len_shape = len_tensor.shape();
+ OP_REQUIRES(
+ context, (pos_shape == len_shape),
+ errors::InvalidArgument("pos and len should have the same shape, got: ",
+ pos_shape.DebugString(), " vs. ",
+ len_shape.DebugString()));
bool is_scalar = TensorShapeUtils::IsScalar(pos_shape);
diff --git a/tensorflow/python/kernel_tests/substr_op_test.py b/tensorflow/python/kernel_tests/substr_op_test.py
index 9302152e82bfa..ad7b6050c2901 100644
--- a/tensorflow/python/kernel_tests/substr_op_test.py
+++ b/tensorflow/python/kernel_tests/substr_op_test.py
@@ -492,6 +492,16 @@ def testInvalidUnit(self):
with self.assertRaises(ValueError):
string_ops.substr(b"test", 3, 1, unit="UTF8")
+ def testInvalidPos(self):
+ # Test case for GitHub issue 46900.
+ with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
+ x = string_ops.substr(b"abc", len=1, pos=[1, -1])
+ self.evaluate(x)
+
+ with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
+ x = string_ops.substr(b"abc", len=1, pos=[1, 2])
+ self.evaluate(x)
+
if __name__ == "__main__":
test.main()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。