代码拉取完成,页面将自动刷新
同步操作将从 src-openEuler/tensorflow 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
From aab9998916c2ffbd8f0592059fad352622f89cda Mon Sep 17 00:00:00 2001
From: Reed Wanderman-Milne <reedwm@google.com>
Date: Wed, 29 Sep 2021 13:00:50 -0700
Subject: [PATCH] Add shape checks to FusedBatchNorm kernels.
---
.../core/kernels/fused_batch_norm_op.cc | 38 +++++-
.../python/ops/nn_fused_batchnorm_test.py | 122 ++++++++++++++++++
2 files changed, 153 insertions(+), 7 deletions(-)
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index bd5dab36..b19323f0 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -1279,18 +1279,20 @@ class FusedBatchNormOpBase : public OpKernel {
errors::InvalidArgument("offset must have the same number of elements "
"as the channels of x, got ",
offset.NumElements(), " and ", num_channels));
- if (estimated_mean.NumElements() != 0) {
+ if (!is_training_ || exponential_avg_factor_ != 1.) {
+ std::string prefix_msg = is_training_ ? "When exponential_avg_factor != 1"
+ : "When is_training=false";
OP_REQUIRES(context, estimated_mean.NumElements() == num_channels,
errors::InvalidArgument(
- "mean must be empty or have the same number of "
- "elements as the channels of x, got ",
+ prefix_msg,
+ ", mean must have the same number "
+ "of elements as the channels of x, got ",
estimated_mean.NumElements(), " and ", num_channels));
- }
- if (estimated_variance.NumElements() != 0) {
OP_REQUIRES(context, estimated_variance.NumElements() == num_channels,
errors::InvalidArgument(
- "variance must be empty or have the same number of "
- "elements as the channels of x, got ",
+ prefix_msg,
+ ", variance must have the same "
+ "number of elements as the channels of x, got ",
estimated_variance.NumElements(), " and ", num_channels));
}
@@ -1434,6 +1436,28 @@ class FusedBatchNormGradOpBase : public OpKernel {
errors::InvalidArgument(
"saved variance must be 1-dimensional",
saved_maybe_inv_var_or_pop_var.shape().DebugString()));
+ OP_REQUIRES(
+ context, x.shape() == y_backprop.shape(),
+ errors::InvalidArgument(
+ "x and y_backprop must have same shape, but x has shape ",
+ x.shape(), " and y_backprop has shape ", y_backprop.shape()));
+
+ const auto num_channels = GetTensorDim(x, tensor_format_, 'C');
+ OP_REQUIRES(
+ context, scale.NumElements() == num_channels,
+ errors::InvalidArgument("scale must have the same number of elements "
+ "as the channels of x, got ",
+ scale.NumElements(), " and ", num_channels));
+ OP_REQUIRES(
+ context, saved_mean_or_pop_mean.NumElements() == num_channels,
+ errors::InvalidArgument("reserve_space_1 must have the same number of "
+ "elements as the channels of x, got ",
+                              saved_mean_or_pop_mean.NumElements(), " and ", num_channels));
+ OP_REQUIRES(
+ context, saved_maybe_inv_var_or_pop_var.NumElements() == num_channels,
+ errors::InvalidArgument("reserve_space_2 must have the same number of "
+ "elements as the channels of x, got ",
+                              saved_maybe_inv_var_or_pop_var.NumElements(), " and ", num_channels));
Tensor* x_backprop = nullptr;
OP_REQUIRES_OK(context,
diff --git a/tensorflow/python/ops/nn_fused_batchnorm_test.py b/tensorflow/python/ops/nn_fused_batchnorm_test.py
index 1742a919..8fecd1c7 100644
--- a/tensorflow/python/ops/nn_fused_batchnorm_test.py
+++ b/tensorflow/python/ops/nn_fused_batchnorm_test.py
@@ -20,10 +20,13 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
@@ -610,6 +613,125 @@ class BatchNormalizationTest(test.TestCase):
}
self._testBatchNormGradGrad(config)
+ def testEagerShapeErrors(self):
+ with context.eager_mode():
+ x = array_ops.ones((2, 2, 2, 2))
+ scale = array_ops.ones((3,))
+ offset = array_ops.ones((2,))
+ with self.assertRaisesRegex(
+ errors_impl.InvalidArgumentError,
+ 'scale must have the same number of elements'):
+ nn_impl.fused_batch_norm(x, scale, offset)
+
+ x = array_ops.ones((2, 2, 2, 2))
+ scale = array_ops.ones((2,))
+ offset = array_ops.ones((3,))
+ with self.assertRaisesRegex(
+ errors_impl.InvalidArgumentError,
+ 'offset must have the same number of elements'):
+ nn_impl.fused_batch_norm(x, scale, offset)
+
+ x = array_ops.ones((2, 2, 2, 2))
+ scale = array_ops.ones((2,))
+ offset = array_ops.ones((2,))
+ mean = array_ops.ones((0,))
+ variance = array_ops.ones((2,))
+ with self.assertRaisesRegex(
+ errors_impl.InvalidArgumentError,
+ 'When is_training=false, mean must have the same number of elements'):
+ nn_impl.fused_batch_norm(
+ x, scale, offset, mean=mean, variance=variance, is_training=False)
+
+ x = array_ops.ones((2, 2, 2, 2))
+ scale = array_ops.ones((2,))
+ offset = array_ops.ones((2,))
+ mean = array_ops.ones((2,))
+ variance = array_ops.ones((0,))
+ with self.assertRaisesRegex(
+ errors_impl.InvalidArgumentError,
+        'When is_training=false, variance must have the same number of elements'):
+ nn_impl.fused_batch_norm(
+ x, scale, offset, mean=mean, variance=variance, is_training=False)
+
+ x = array_ops.ones((2, 2, 2, 2))
+ scale = array_ops.ones((2,))
+ offset = array_ops.ones((2,))
+ mean = array_ops.ones((0,))
+ variance = array_ops.ones((2,))
+ with self.assertRaisesRegex(
+ errors_impl.InvalidArgumentError,
+ 'When exponential_avg_factor != 1, mean must have the same number of '
+ 'elements'):
+ nn_impl.fused_batch_norm(
+ x,
+ scale,
+ offset,
+ mean=mean,
+ variance=variance,
+ exponential_avg_factor=0.5)
+
+ x = array_ops.ones((2, 2, 2, 2))
+ scale = array_ops.ones((2,))
+ offset = array_ops.ones((2,))
+ mean = array_ops.ones((2,))
+ variance = array_ops.ones((0,))
+ with self.assertRaisesRegex(
+ errors_impl.InvalidArgumentError,
+ 'When exponential_avg_factor != 1, variance must have the same '
+ 'number of elements'):
+ nn_impl.fused_batch_norm(
+ x,
+ scale,
+ offset,
+ mean=mean,
+ variance=variance,
+ exponential_avg_factor=0.5)
+
+ def testEagerShapeGradErrors(self):
+ with context.eager_mode():
+ y_backprop = array_ops.ones((2, 2, 2, 3))
+ x = array_ops.ones((2, 2, 2, 2))
+ scale = array_ops.ones((2,))
+ reserve_space_1 = array_ops.ones((2,))
+ reserve_space_2 = array_ops.ones((2,))
+ with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
+ 'x and y_backprop must have same shape,'):
+ gen_nn_ops.fused_batch_norm_grad_v2(y_backprop, x, scale,
+ reserve_space_1, reserve_space_2)
+
+ y_backprop = array_ops.ones((2, 2, 2, 2))
+ x = array_ops.ones((2, 2, 2, 2))
+ scale = array_ops.ones((3,))
+ reserve_space_1 = array_ops.ones((2,))
+ reserve_space_2 = array_ops.ones((2,))
+ with self.assertRaisesRegex(
+ errors_impl.InvalidArgumentError,
+ 'scale must have the same number of elements'):
+ gen_nn_ops.fused_batch_norm_grad_v2(y_backprop, x, scale,
+ reserve_space_1, reserve_space_2)
+
+ y_backprop = array_ops.ones((2, 2, 2, 2))
+ x = array_ops.ones((2, 2, 2, 2))
+ scale = array_ops.ones((2,))
+ reserve_space_1 = array_ops.ones((3,))
+ reserve_space_2 = array_ops.ones((2,))
+ with self.assertRaisesRegex(
+ errors_impl.InvalidArgumentError,
+ 'reserve_space_1 must have the same number of elements'):
+ gen_nn_ops.fused_batch_norm_grad_v2(y_backprop, x, scale,
+ reserve_space_1, reserve_space_2)
+
+ y_backprop = array_ops.ones((2, 2, 2, 2))
+ x = array_ops.ones((2, 2, 2, 2))
+ scale = array_ops.ones((2,))
+ reserve_space_1 = array_ops.ones((2,))
+ reserve_space_2 = array_ops.ones((3,))
+ with self.assertRaisesRegex(
+ errors_impl.InvalidArgumentError,
+ 'reserve_space_2 must have the same number of elements'):
+ gen_nn_ops.fused_batch_norm_grad_v2(y_backprop, x, scale,
+ reserve_space_1, reserve_space_2)
+
if __name__ == '__main__':
test.main()
--
2.23.0
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。