1 Star 0 Fork 45

小松鼠/tensorflow

forked from src-openEuler/tensorflow 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
CVE-2021-29615.patch 3.19 KB
一键复制 编辑 原始数据 按行查看 历史
From e07e1c3d26492c06f078c7e5bf2d138043e199c1 Mon Sep 17 00:00:00 2001
From: Laura Pak <lpak@google.com>
Date: Fri, 23 Apr 2021 10:33:00 -0700
Subject: [PATCH] Prevent memory overflow in ParseAttrValue from nested
tensors.
PiperOrigin-RevId: 370108442
Change-Id: I84d64a5e8895a6aeffbf4749841b4c54d51b5889
---
tensorflow/core/framework/attr_value_util.cc | 58 +++++++++++++++++++-
1 file changed, 57 insertions(+), 1 deletion(-)
diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc
index 712e205c587c8..76fe36e7f1e2a 100644
--- a/tensorflow/core/framework/attr_value_util.cc
+++ b/tensorflow/core/framework/attr_value_util.cc
@@ -38,6 +38,9 @@ namespace {
// Do not construct large tensors to compute their hash or compare for equality.
constexpr int kMaxAttrValueTensorByteSize = 32 * 1024 * 1024; // 32mb
+// Limit nesting of tensors to 100 deep to prevent memory overflow.
+constexpr int kMaxTensorNestDepth = 100;
+
// Return the size of the tensor represented by this TensorProto. If shape is
// not fully defined return -1.
int64 TensorByteSize(const TensorProto& t) {
@@ -224,6 +227,54 @@ string SummarizeFunc(const NameAttrList& func) {
return strings::StrCat(func.name(), "[", absl::StrJoin(entries, ", "), "]");
}
+bool ParseAttrValueHelper_TensorNestsUnderLimit(int limit, string to_parse) {
+ int nests = 0;
+ int maxed_out = to_parse.length();
+ int open_curly = to_parse.find('{');
+ int open_bracket = to_parse.find('<');
+ int close_curly = to_parse.find('}');
+ int close_bracket = to_parse.find('>');
+ if (open_curly == -1) {
+ open_curly = maxed_out;
+ }
+ if (open_bracket == -1) {
+ open_bracket = maxed_out;
+ }
+ int min = std::min(open_curly, open_bracket);
+ do {
+ if (open_curly == maxed_out && open_bracket == maxed_out) {
+ return true;
+ }
+ if (min == open_curly) {
+ nests += 1;
+ open_curly = to_parse.find('{', open_curly + 1);
+ if (open_curly == -1) {
+ open_curly = maxed_out;
+ }
+ } else if (min == open_bracket) {
+ nests += 1;
+ open_bracket = to_parse.find('<', open_bracket + 1);
+ if (open_bracket == -1) {
+ open_bracket = maxed_out;
+ }
+ } else if (min == close_curly) {
+ nests -= 1;
+ close_curly = to_parse.find('}', close_curly + 1);
+ if (close_curly == -1) {
+ close_curly = maxed_out;
+ }
+ } else if (min == close_bracket) {
+ nests -= 1;
+ close_bracket = to_parse.find('>', close_bracket + 1);
+ if (close_bracket == -1) {
+ close_bracket = maxed_out;
+ }
+ }
+ min = std::min({open_curly, open_bracket, close_curly, close_bracket});
+ } while (nests < 100);
+ return false;
+}
+
} // namespace
string SummarizeAttrValue(const AttrValue& attr_value) {
@@ -448,7 +499,12 @@ bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
} else {
to_parse = strings::StrCat(field_name, ": ", text);
}
-
+ if (field_name == "tensor") {
+ if (!ParseAttrValueHelper_TensorNestsUnderLimit(kMaxTensorNestDepth,
+ to_parse)) {
+ return false;
+ }
+ }
return ProtoParseFromString(to_parse, out);
}
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/wang_songsong/tensorflow.git
git@gitee.com:wang_songsong/tensorflow.git
wang_songsong
tensorflow
tensorflow
master

搜索帮助