1 Star 0 Fork 45

小松鼠/tensorflow

forked from src-openEuler/tensorflow 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
CVE-2021-29617.patch 2.45 KB
一键复制 编辑 原始数据 按行查看 历史
From 7edb8c9b83ad583616406af61e0de61393996a3e Mon Sep 17 00:00:00 2001
From: Yong Tang <yong.tang.github@outlook.com>
Date: Sat, 6 Feb 2021 20:24:54 +0000
Subject: [PATCH] Fix crash of tf.strings.substr when pos and len have
different shapes
This PR tries to address the issue raised in 46900 where
tf.strings.substr will crash when pos and len have different shapes.
According to the documentation of tf.strings.substr, ValueError
should be raised instead when pos and len does not have the same shape.
This PR add shape check in kernel to allows grace error throw (instead of crash).
This PR fixes 46900.
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
---
tensorflow/core/kernels/substr_op.cc | 6 ++++++
tensorflow/python/kernel_tests/substr_op_test.py | 10 ++++++++++
2 files changed, 16 insertions(+)
diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc
index e382381e12232..0c94ba35b249a 100644
--- a/tensorflow/core/kernels/substr_op.cc
+++ b/tensorflow/core/kernels/substr_op.cc
@@ -51,6 +51,12 @@ class SubstrOp : public OpKernel {
const Tensor& len_tensor = context->input(2);
const TensorShape& input_shape = input_tensor.shape();
const TensorShape& pos_shape = pos_tensor.shape();
+ const TensorShape& len_shape = len_tensor.shape();
+ OP_REQUIRES(
+ context, (pos_shape == len_shape),
+ errors::InvalidArgument("pos and len should have the same shape, got: ",
+ pos_shape.DebugString(), " vs. ",
+ len_shape.DebugString()));
bool is_scalar = TensorShapeUtils::IsScalar(pos_shape);
diff --git a/tensorflow/python/kernel_tests/substr_op_test.py b/tensorflow/python/kernel_tests/substr_op_test.py
index 9302152e82bfa..ad7b6050c2901 100644
--- a/tensorflow/python/kernel_tests/substr_op_test.py
+++ b/tensorflow/python/kernel_tests/substr_op_test.py
@@ -492,6 +492,16 @@ def testInvalidUnit(self):
with self.assertRaises(ValueError):
string_ops.substr(b"test", 3, 1, unit="UTF8")
+ def testInvalidPos(self):
+ # Test case for GitHub issue 46900.
+ with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
+ x = string_ops.substr(b"abc", len=1, pos=[1, -1])
+ self.evaluate(x)
+
+ with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
+ x = string_ops.substr(b"abc", len=1, pos=[1, 2])
+ self.evaluate(x)
+
if __name__ == "__main__":
test.main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/wang_songsong/tensorflow.git
git@gitee.com:wang_songsong/tensorflow.git
wang_songsong
tensorflow
tensorflow
master

搜索帮助