From 7b8db6083b34520688dbc71f341f7aeaf156bf17 Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev <ezhulenev@google.com>
Date: Fri, 19 Mar 2021 16:16:41 -0700
Subject: [PATCH] Implement grouped convolution on CPU
To get better compute resource utilization, the group-compute loop has to be parallelized, but that involves a lot of changes in the Conv2D primitives. Will address that later if it turns out to be critical for some users.
Fix for: https://github.com/tensorflow/tensorflow/issues/29005
PiperOrigin-RevId: 363991782
Change-Id: I97f375b1133833c4de5181199316be7cbf4ebee0
---
tensorflow/core/kernels/BUILD | 1 +
tensorflow/core/kernels/conv_2d.h | 54 +++++++
tensorflow/core/kernels/conv_ops.cc | 133 ++++++++++++++++--
.../python/kernel_tests/conv_ops_test.py | 20 +--
4 files changed, 189 insertions(+), 19 deletions(-)
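For context, here is what the patch enables at the Python API level. This is a minimal sketch (not part of the patch) using the same shapes as the updated test: an input depth of 16 with a filter input depth of 4 implies 4 groups.

```python
import tensorflow as tf

# Grouped Conv2D: input depth 16, filter input depth 4 => 4 groups.
# Before this patch the call below raised Unimplemented on CPU; with it,
# the same shapes also run on CPU (NHWC format only).
x = tf.random.normal([10, 32, 32, 16])
f = tf.random.normal([3, 3, 4, 8])  # out depth 8 divides evenly into 4 groups
with tf.device("/CPU:0"):
    y = tf.nn.conv2d(x, f, strides=[1, 1, 1, 1], padding="SAME")
print(y.shape)  # (10, 32, 32, 8)
```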
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 8e49f1e0a5caf..bc455626f4322 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3818,6 +3818,7 @@ tf_kernel_library(
":ops_util",
"@com_google_absl//absl/base:dynamic_annotations",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
"//third_party/eigen3",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h
index b9a8c977e11ee..87df4a848dd56 100644
--- a/tensorflow/core/kernels/conv_2d.h
+++ b/tensorflow/core/kernels/conv_2d.h
@@ -43,6 +43,9 @@ void SpatialConvolutionFunc(const Device& d, Output output, Input input,
padding_bottom);
}
+// TODO(ezhulenev): Non-templated `operator()` overloads are required by explicit
+// template instantiations for the GPU device. However, they are almost certainly
+// not used in any of the kernel implementations. Check if they can be removed.
template <typename Device, typename T,
typename OutputKernel = const Eigen::NoOpOutputKernel>
struct SpatialConvolution {
@@ -55,6 +58,16 @@ struct SpatialConvolution {
SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride,
row_dilation, col_dilation, padding, output_kernel);
}
+
+ template <typename Input, typename Filter, typename Output>
+ void operator()(const Device& d, Output output, Input input, Filter filter,
+ int row_stride, int col_stride, int row_dilation,
+ int col_dilation, const Eigen::PaddingType& padding,
+ const OutputKernel& output_kernel = OutputKernel()) {
+ SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride,
+ row_dilation, col_dilation, padding, output_kernel);
+ }
+
void operator()(const Device& d, typename TTypes<T, 4>::Tensor output,
typename TTypes<T, 4>::ConstTensor input,
typename TTypes<T, 4>::ConstTensor filter, int row_stride,
@@ -67,6 +80,18 @@ struct SpatialConvolution {
col_dilation, Eigen::PaddingType::PADDING_VALID, output_kernel,
padding_top, padding_bottom, padding_left, padding_right);
}
+
+ template <typename Input, typename Filter, typename Output>
+ void operator()(const Device& d, Output output, Input input, Filter filter,
+ int row_stride, int col_stride, int row_dilation,
+ int col_dilation, int padding_top, int padding_bottom,
+ int padding_left, int padding_right,
+ const OutputKernel& output_kernel = OutputKernel()) {
+ SpatialConvolutionFunc(
+ d, output, input, filter, row_stride, col_stride, row_dilation,
+ col_dilation, Eigen::PaddingType::PADDING_VALID, output_kernel,
+ padding_top, padding_bottom, padding_left, padding_right);
+ }
};
template <typename Device, typename OutputKernel>
@@ -84,6 +109,20 @@ struct SpatialConvolution<Device, Eigen::half, OutputKernel> {
row_dilation, output_kernel)
.template cast<Eigen::half>();
}
+
+ template <typename Input, typename Filter, typename Output>
+ void operator()(const Device& d, Output output, Input input, Filter filter,
+ int row_stride, int col_stride, int row_dilation,
+ int col_dilation, const Eigen::PaddingType& padding,
+ const OutputKernel& output_kernel = OutputKernel()) {
+ output.device(d) =
+ Eigen::SpatialConvolution(input.template cast<float>(),
+ filter.template cast<float>(), col_stride,
+ row_stride, padding, col_dilation,
+ row_dilation, output_kernel)
+ .template cast<Eigen::half>();
+ }
+
void operator()(const Device& d,
typename TTypes<Eigen::half, 4>::Tensor output,
typename TTypes<Eigen::half, 4>::ConstTensor input,
@@ -100,6 +139,21 @@ struct SpatialConvolution<Device, Eigen::half, OutputKernel> {
padding_bottom)
.template cast<Eigen::half>();
}
+
+ template <typename Input, typename Filter, typename Output>
+ void operator()(const Device& d, Output output, Input input, Filter filter,
+ int row_stride, int col_stride, int row_dilation,
+ int col_dilation, int padding_top, int padding_bottom,
+ int padding_left, int padding_right,
+ const OutputKernel& output_kernel = OutputKernel()) {
+ output.device(d) =
+ Eigen::SpatialConvolution(
+ input.template cast<float>(), filter.template cast<float>(),
+ col_stride, row_stride, Eigen::PaddingType::PADDING_VALID,
+ col_dilation, row_dilation, output_kernel, padding_left,
+ padding_right, padding_top, padding_bottom)
+ .template cast<Eigen::half>();
+ }
};
template <typename Device, typename T>
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index 025a8e37a94e9..8fdfe04bd1c67 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include <map>
#include <vector>
+#include "absl/synchronization/blocking_counter.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
@@ -138,6 +139,98 @@ struct LaunchGeneric {
}
}
};
+
+// Compute grouped 2D convolutions on CPU. Unlike the grouped convolution
+// implementation in cuDNN, this is far from optimal and needs more work
+// to deliver competitive performance. Currently it exists to close the feature
+// parity gap between convolution operations on different devices.
+template <typename T>
+struct LaunchGrouped {
+ void operator()(OpKernelContext* ctx, const Tensor& input,
+ const Tensor& filter, int row_stride, int col_stride,
+ int row_dilation, int col_dilation, const Padding& padding,
+ const std::vector<int64>& explicit_paddings, Tensor* output,
+ TensorFormat data_format) {
+ DCHECK(data_format == FORMAT_NHWC)
+ << "Grouped conv implementation only "
+ "supports NHWC tensor format for now.";
+
+ const int64 in_depth = input.dim_size(3);
+ const int64 patch_depth = filter.dim_size(2);
+ const int64 num_groups = in_depth / patch_depth;
+
+ // Shuffle input/filter tensors to have group as a leading dimension.
+ std::array<int64, 5> shuffle({3, 0, 1, 2, 4});
+
+ // Compute pre-shuffle dimensions.
+ auto pre_shuffle = [&](const Tensor& tensor) -> std::array<int64, 5> {
+ return {tensor.dim_size(0), tensor.dim_size(1), tensor.dim_size(2),
+ num_groups, tensor.dim_size(3) / num_groups};
+ };
+
+ // Compute post-shuffle dimensions.
+ auto post_shuffle = [&](const Tensor& tensor) -> std::array<int64, 5> {
+ return {num_groups, tensor.dim_size(0), tensor.dim_size(1),
+ tensor.dim_size(2), tensor.dim_size(3) / num_groups};
+ };
+
+ auto& device = ctx->eigen_device<CPUDevice>();
+
+ absl::BlockingCounter shuffles_completed(2);
+ auto on_shuffled = [&]() { shuffles_completed.DecrementCount(); };
+
+ // Shuffle input into temporary tensor.
+ Tensor input_shuffled(input.dtype(), TensorShape(post_shuffle(input)));
+ input_shuffled.tensor<T, 5>().device(device, on_shuffled) =
+ input.shaped<T, 5>(pre_shuffle(input)).shuffle(shuffle);
+
+ // Shuffle filter into temporary tensor.
+ Tensor filter_shuffled(filter.dtype(), TensorShape(post_shuffle(filter)));
+ filter_shuffled.tensor<T, 5>().device(device, on_shuffled) =
+ filter.shaped<T, 5>(pre_shuffle(filter)).shuffle(shuffle);
+
+ // Wait for the completion of input/filter shuffles.
+ shuffles_completed.Wait();
+
+ // Write group convolution results into temporary output tensor.
+ Tensor output_shuffled(output->dtype(), TensorShape(post_shuffle(*output)));
+
+ for (int64 i = 0; i < num_groups; ++i) {
+ // TODO(ezhulenev): Run this loop using `parallelFor` (regular parallelFor
+ // will lead to deadlock, SpatialConvolution has to use async Eigen
+ // assignment). This requires small changes to Eigen to support async
+ // execution for the tensor chipping operation.
+
+ // TODO(ezhulenev): Grouped convolution should also support 1x1 filter
+ // optimization.
+
+ auto input_slice = input_shuffled.tensor<T, 5>().template chip<0>(i);
+ auto filter_slice = filter_shuffled.tensor<T, 5>().template chip<0>(i);
+ auto output_slice = output_shuffled.tensor<T, 5>().template chip<0>(i);
+
+ if (padding == EXPLICIT) {
+ functor::SpatialConvolution<CPUDevice, T>()(
+ ctx->eigen_device<CPUDevice>(), output_slice, input_slice,
+ filter_slice, row_stride, col_stride, row_dilation, col_dilation,
+ static_cast<int>(explicit_paddings[2]),
+ static_cast<int>(explicit_paddings[3]),
+ static_cast<int>(explicit_paddings[4]),
+ static_cast<int>(explicit_paddings[5]));
+ } else {
+ functor::SpatialConvolution<CPUDevice, T>()(
+ ctx->eigen_device<CPUDevice>(), output_slice, input_slice,
+ filter_slice, row_stride, col_stride, row_dilation, col_dilation,
+ BrainPadding2EigenPadding(padding));
+ }
+ }
+
+ // Shuffle temporary output back into pre-shuffled shape.
+ std::array<int64, 5> rev_shuffle({1, 2, 3, 0, 4});
+ output->shaped<T, 5>(pre_shuffle(*output)).device(device) =
+ output_shuffled.tensor<T, 5>().shuffle(rev_shuffle);
+ }
+};
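The shuffle above moves the group dimension to the front so that each group can be convolved as an ordinary 4-D tensor. A NumPy sketch of the same index gymnastics (illustration only, not part of the patch):

```python
import numpy as np

def group_shuffle(t, num_groups):
    # pre_shuffle: view [N, H, W, C] as [N, H, W, G, C/G] ...
    n, h, w, c = t.shape
    t5 = t.reshape(n, h, w, num_groups, c // num_groups)
    # ... then shuffle {3, 0, 1, 2, 4} puts the group first: [G, N, H, W, C/G].
    return np.transpose(t5, (3, 0, 1, 2, 4))

x = np.zeros((10, 32, 32, 16))
print(group_shuffle(x, 4).shape)  # (4, 10, 32, 32, 4)
```

The reverse shuffle `{1, 2, 3, 0, 4}` at the end of `LaunchGrouped` is the inverse permutation, restoring the `[N, H, W, C]` layout of the output.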
+
} // namespace
template <typename T>
@@ -155,14 +248,6 @@ struct LaunchConv2DOp<CPUDevice, T> {
ToString(data_format)));
return;
}
- const int64 in_depth = GetTensorDim(input, data_format, 'C');
- OP_REQUIRES(ctx, in_depth == filter.dim_size(2),
- errors::Unimplemented(
- "The Conv2D op currently does not support grouped "
- "convolutions on the CPU. A grouped convolution was "
- "attempted to be run because the input depth of ",
- in_depth, " does not match the filter input depth of ",
- filter.dim_size(2)));
for (int64 explicit_padding : explicit_paddings) {
if (!FastBoundsCheck(explicit_padding, std::numeric_limits<int>::max())) {
@@ -170,9 +255,35 @@ struct LaunchConv2DOp<CPUDevice, T> {
return;
}
}
- LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
- row_dilation, col_dilation, padding,
- explicit_paddings, output, data_format);
+
+ const int64 in_depth = input.dim_size(3);
+ const int64 out_depth = output->dim_size(3);
+ const int64 patch_depth = filter.dim_size(2);
+
+ if (in_depth % patch_depth != 0) {
+ ctx->SetStatus(errors::InvalidArgument(
+ "input depth must be evenly divisible by filter depth: ", in_depth,
+ " vs ", patch_depth));
+ return;
+ }
+
+ const int64 num_groups = in_depth / patch_depth;
+ if (out_depth % num_groups != 0 || out_depth < num_groups) {
+ ctx->SetStatus(errors::InvalidArgument(
+ "output depth must be evenly divisible by number of groups: ",
+ out_depth, " vs ", num_groups));
+ return;
+ }
+
+ if (in_depth != patch_depth) {
+ LaunchGrouped<T>()(ctx, input, filter, row_stride, col_stride,
+ row_dilation, col_dilation, padding, explicit_paddings,
+ output, data_format);
+ } else {
+ LaunchGeneric<CPUDevice, T>()(ctx, input, filter, row_stride, col_stride,
+ row_dilation, col_dilation, padding,
+ explicit_paddings, output, data_format);
+ }
}
};
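The dispatch logic added above reduces to two divisibility checks plus a depth comparison. Restated in Python for clarity (a sketch mirroring the error messages in the patch, not TensorFlow code):

```python
def conv2d_cpu_dispatch(in_depth, patch_depth, out_depth):
    # Mirrors the checks added to LaunchConv2DOp<CPUDevice, T>.
    if in_depth % patch_depth != 0:
        raise ValueError("input depth must be evenly divisible by filter "
                         f"depth: {in_depth} vs {patch_depth}")
    num_groups = in_depth // patch_depth
    if out_depth % num_groups != 0 or out_depth < num_groups:
        raise ValueError("output depth must be evenly divisible by number "
                         f"of groups: {out_depth} vs {num_groups}")
    # in_depth != patch_depth selects the grouped path; otherwise the
    # existing generic launcher handles the (single-group) convolution.
    return "LaunchGrouped" if in_depth != patch_depth else "LaunchGeneric"
```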
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index 44a67ccc55f0a..92af04359caa9 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -834,17 +834,21 @@ def MakeConv2d(inputs, filters):
results[0], results[1], atol=tol_to_use, rtol=tol_to_use)
@test_util.run_in_graph_and_eager_modes
- @test_util.run_cuda_only
def testConv2DGroupConvFwd(self):
- for data_format in ["NHWC", "NCHW"]:
+ if test.is_gpu_available(cuda_only=True):
+ data_formats = ["NHWC", "NCHW"]
+ else:
+ data_formats = ["NHWC"]
+ for data_format in data_formats:
for dilation in [1, 2]:
for stride in [1, 2]:
- self._VerifyGroupConvFwd([10, 32, 32, 16], [3, 3, 4, 8],
- dilations=[dilation, dilation],
- strides=[stride, stride],
- padding="SAME",
- data_format=data_format,
- dtype=dtypes.float32)
+ for filter_dims in [[3, 3, 4, 8], [1, 1, 2, 16]]:
+ self._VerifyGroupConvFwd([10, 32, 32, 16], filter_dims,
+ dilations=[dilation, dilation],
+ strides=[stride, stride],
+ padding="SAME",
+ data_format=data_format,
+ dtype=dtypes.float32)
@test_util.deprecated_graph_mode_only
@test_util.run_cuda_only
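The updated test now exercises grouped convolution on CPU as well as GPU, including a 1x1 filter case. A standalone sketch of the property `_VerifyGroupConvFwd` checks (the split-based reference below is an assumption about the helper, not its actual code): a grouped convolution must equal per-group convolutions concatenated along the channel axis.

```python
import tensorflow as tf

x = tf.random.normal([10, 32, 32, 16])
f = tf.random.normal([3, 3, 4, 8])
num_groups = 16 // 4  # input depth / filter input depth

y = tf.nn.conv2d(x, f, strides=[1, 1, 1, 1], padding="SAME")

# Reference: split input and filter into groups, convolve each group
# independently, and concatenate the results along channels.
xs = tf.split(x, num_groups, axis=3)
fs = tf.split(f, num_groups, axis=3)
y_ref = tf.concat(
    [tf.nn.conv2d(xi, fi, strides=[1, 1, 1, 1], padding="SAME")
     for xi, fi in zip(xs, fs)], axis=3)

tf.debugging.assert_near(y, y_ref, atol=1e-4)
```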