From 913eb5bdebfed41021b507fd9874c9d1c8f7a63e Mon Sep 17 00:00:00 2001 From: l00916585 <1410167690@qq.com> Date: Fri, 29 Nov 2024 16:32:05 +0800 Subject: [PATCH] merge convolution --- .../kdnn_jit_sve_256_1x1_convolution_f16.cpp | 381 ++++++++++++++++++ .../kdnn_jit_sve_256_1x1_convolution_f16.hpp | 234 +++++++++++ .../src/cpu/aarch64/kdnn/kdnn_utils_conv.cpp | 104 +++++ .../cpu/aarch64/kdnn/kdnn_utils_deconv.cpp | 102 +++++ 4 files changed, 821 insertions(+) create mode 100644 oneDNN-3.4/src/cpu/aarch64/kdnn/jit/kdnn_jit_sve_256_1x1_convolution_f16.cpp create mode 100644 oneDNN-3.4/src/cpu/aarch64/kdnn/jit/kdnn_jit_sve_256_1x1_convolution_f16.hpp create mode 100644 oneDNN-3.4/src/cpu/aarch64/kdnn/kdnn_utils_conv.cpp create mode 100644 oneDNN-3.4/src/cpu/aarch64/kdnn/kdnn_utils_deconv.cpp diff --git a/oneDNN-3.4/src/cpu/aarch64/kdnn/jit/kdnn_jit_sve_256_1x1_convolution_f16.cpp b/oneDNN-3.4/src/cpu/aarch64/kdnn/jit/kdnn_jit_sve_256_1x1_convolution_f16.cpp new file mode 100644 index 0000000..13b8686 --- /dev/null +++ b/oneDNN-3.4/src/cpu/aarch64/kdnn/jit/kdnn_jit_sve_256_1x1_convolution_f16.cpp @@ -0,0 +1,381 @@ +/******************************************************************************* +* Copyright 2021-2023 Intel Corporation +* Copyright 2021-2023 FUJITSU LIMITED +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "common/c_types_map.hpp" +#include "common/dnnl_thread.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" +#include "cpu/aarch64/kdnn/jit/kdnn_jit_generator.hpp" + +#include "cpu/aarch64/kdnn/jit/kdnn_jit_sve_256_1x1_convolution_f16.hpp" + + + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { + +using namespace dnnl::impl::status; +using namespace dnnl::impl::memory_tracking::names; +using namespace dnnl::impl::utils; + +#define data_blk_off(f, n, c, d, h, w) \ + ((ndims == 3) ? (f).blk_off(n, c, w) \ + : ((ndims == 4) ? 
(f).blk_off(n, c, h, w) \
+                       : (f).blk_off(n, c, d, h, w)))
+/* convolution forward */
+
+template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
+void kdnn_jit_sve_256_1x1_convolution_fwd_f16_t<src_type, wei_type, dst_type>::execute_forward(const exec_ctx_t &ctx) const {
+    const auto &jcp = kernel_->jcp;
+    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
+    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
+    auto bias = CTX_IN_MEM(const dst_data_t *, DNNL_ARG_BIAS);
+    auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
+    auto weights_dw = CTX_IN_MEM(
+            const wei_data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS);
+    auto bias_dw = CTX_IN_MEM(
+            const dst_data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS);
+    const auto post_ops_binary_rhs_arg_vec
+            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);
+
+    auto scratchpad = ctx.get_scratchpad_grantor();
+
+    if (pd()->wants_padded_bias()) {
+        auto padded_bias
+                = scratchpad.template get<dst_data_t>(key_conv_padded_bias);
+        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
+        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+                jcp.oc - jcp.oc_without_padding);
+        bias = padded_bias;
+    }
+
+    auto dat_tag_nxc = jcp.ndims == 3 ? memory::format_tag::nwc : memory::format_tag::nhwc;
+    auto dat_tag_ncx = jcp.ndims == 3 ? memory::format_tag::ncw : memory::format_tag::nchw;
+    const int N = jcp.mb, IH = jcp.ih * pd()->src_stride_, IW = jcp.iw * pd()->src_stride_, IC = jcp.ic, OC = jcp.oc, OH = jcp.oh, OW = jcp.ow;
+    const memory::dims conv_src_2d_sizes = {N, IC, IW}, conv_src_3d_sizes = {N, IC, IH, IW};
+    const memory::dims conv_dst_2d_sizes = {N, OC, OW}, conv_dst_3d_sizes = {N, OC, OH, OW};
+    engine eng(engine::kind::cpu, 0);
+    stream s(eng);
+
+    if (is_src_need_change_layout_) { // if the src tag is ncx, change the layout to nxc
+        void* p_src_data = nullptr;
+        auto reorder_src_func = [&](const memory::dims& conv_src_sizes) { // do reorder for src
+            auto src_mem = memory({conv_src_sizes, memory::data_type::f16, dat_tag_ncx}, eng);
+            auto src_mem_changed = memory({conv_src_sizes, memory::data_type::f16, dat_tag_nxc}, eng);
+            src_mem.set_data_handle((void*)const_cast<src_data_t *>(src));
+
+            auto reorder_src = reorder(src_mem, src_mem_changed);
+            reorder_src.execute(
+                    s, {{DNNL_ARG_FROM, src_mem}, {DNNL_ARG_TO, src_mem_changed}});
+            s.wait(); // wait for the reorder to complete
+            p_src_data = src_mem_changed.get_data_handle();
+            assert(p_src_data != nullptr);
+            if (is_dst_need_change_layout_) { // for input = ncx and output = ncx
+                auto conv_dst_sizes = jcp.ndims == 3 ?
                        conv_dst_2d_sizes : conv_dst_3d_sizes;
+                auto dst_mem_changed = memory({conv_dst_sizes, memory::data_type::f16, dat_tag_nxc}, eng);
+                void* p_dst_data = dst_mem_changed.get_data_handle();
+                assert(p_dst_data != nullptr);
+                parallel(jcp.nthr, [&](const int ithr, const int nthr) {
+                    execute_forward_thr(ithr, nthr, (const src_data_t *)p_src_data, weights, bias, weights_dw, bias_dw,
+                            (dst_data_t *)p_dst_data, scratchpad, post_ops_binary_rhs_arg_vec.data(),
+                            nullptr/*post_ops_binary_rhs_arg_vec_dw.data()*/);
+                });
+                auto dst_mem = memory({conv_dst_sizes, memory::data_type::f16, dat_tag_ncx}, eng);
+                dst_mem.set_data_handle((void*)dst);
+                auto reorder_dst = reorder(dst_mem_changed, dst_mem);
+                reorder_dst.execute(
+                        s, {{DNNL_ARG_FROM, dst_mem_changed}, {DNNL_ARG_TO, dst_mem}});
+                s.wait(); // wait for the reorder to complete
+            }
+            else { // for input = ncx but output != ncx
+                parallel(jcp.nthr, [&](const int ithr, const int nthr) {
+                    execute_forward_thr(ithr, nthr, (const src_data_t *)p_src_data, weights, bias, weights_dw, bias_dw,
+                            dst, scratchpad, post_ops_binary_rhs_arg_vec.data(),
+                            nullptr/*post_ops_binary_rhs_arg_vec_dw.data()*/);
+                });
+            }
+        };
+        jcp.ndims == 3 ? reorder_src_func(conv_src_2d_sizes) : reorder_src_func(conv_src_3d_sizes);
+    }
+    else {
+        parallel(jcp.nthr, [&](const int ithr, const int nthr) {
+            execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw,
+                    dst, scratchpad, post_ops_binary_rhs_arg_vec.data(),
+                    nullptr/*post_ops_binary_rhs_arg_vec_dw.data()*/);
+        });
+    }
+
+    if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST);
+}
+
+template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
+void kdnn_jit_sve_256_1x1_convolution_fwd_f16_t<src_type, wei_type, dst_type>::execute_forward_thr(const int ithr, const int nthr,
+        const src_data_t *src, const wei_data_t *weights,
+        const dst_data_t *bias, const wei_data_t *weights_dw,
+        const dst_data_t *bias_dw, dst_data_t *dst,
+        const memory_tracking::grantor_t &scratchpad,
+        const void *post_ops_binary_rhs_arg_vec,
+        const void *post_ops_binary_rhs_arg_vec_dw) const {
+    const memory_desc_wrapper weights_d(pd()->weights_md(0));
+    const memory_desc_wrapper dw_weights_d(
+            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS));
+    const memory_desc_wrapper dw_bias_d(
+            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS));
+
+    const auto &jcp = kernel_->jcp;
+    auto rtus_space = pd()->rtus_.reduce_src_
+            ? scratchpad.get<src_data_t>(key_conv_rtus_space)
+            : nullptr;
+    dnnl_memory_desc_t conv_src_md, conv_dst_md;
+    if (is_src_need_change_layout_) {
+        auto src_dnnl_tag = pd()->src_md()->ndims == 4 ? dnnl_nhwc : dnnl_nwc;
+        dnnl_memory_desc_create_with_tag(&conv_src_md, pd()->src_md()->ndims,
+                pd()->src_md()->dims, dnnl_f16, src_dnnl_tag); // only supports 2d or 3d
+    }
+    else {
+        conv_src_md = const_cast<memory_desc_t *>(pd()->src_md());
+    }
+
+    if (is_dst_need_change_layout_) {
+        auto dst_dnnl_tag = pd()->src_md()->ndims == 4 ? dnnl_nhwc : dnnl_nwc;
+        dnnl_memory_desc_create_with_tag(&conv_dst_md, pd()->dst_md()->ndims,
+                pd()->dst_md()->dims, dnnl_f16, dst_dnnl_tag); // only supports 2d or 3d
+    }
+    else {
+        conv_dst_md = const_cast<memory_desc_t *>(pd()->dst_md());
+    }
+
+    const memory_desc_wrapper src_d(conv_src_md);
+    const memory_desc_wrapper dst_d(conv_dst_md);
+    const int ndims = conv_src_md->ndims;
+    const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1;
+    const int stride_h = (ndims == 3) ?
1 : pd()->desc()->strides[ndims - 4]; + const int stride_w = pd()->desc()->strides[ndims - 3]; + + auto step = [](int default_step, int remaining, int tail_step) { + assert(default_step <= tail_step); + return remaining < tail_step ? remaining : default_step; + }; + auto p = kdnn_jit_1x1_conv_call_s(); + + auto rp = kdnn_rtus_driver_t::call_params_t(); + const int nb_oc = jcp.nb_load; + const int nb_ic = jcp.nb_reduce; + const int nb_ic_blocking = jcp.nb_reduce_blocking; + + // override some constants for fused dw_conv + const int os_block = jcp.with_dw_conv ? jcp.ow : jcp.bcast_block; + const int nb_bcast = jcp.with_dw_conv ? jcp.oh : jcp.nb_bcast; + const int nb_bcast_blocking = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking; + const int nb_bcast_blocking_max + = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking_max; + const int nb_load_blocking = jcp.nb_load_blocking; + const int nb_load_blocking_max = jcp.with_dw_conv + ? jcp.nb_load_blocking + : jcp.nb_load_blocking_max; + const bool is_dst_layout_nxc = utils::one_of( + jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); + const bool is_src_layout_nxc = utils::one_of( + jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); + + auto init_bcast = [&](int iwork, int bcast_end, int &n, int &g, + int &bcast_step, int &od, int &oh, int &ow, + int &id, int &ih, int &iw) { + int osb {0}; + nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast); + bcast_step = step( + nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max); + bcast_step = nstl::min(bcast_step, bcast_end - iwork); + + const int os = osb * os_block; + od = os / (jcp.oh * jcp.ow); + int os_2d = os % (jcp.oh * jcp.ow); + oh = os_2d / jcp.ow; + ow = os_2d % jcp.ow; + + id = od * stride_d; + ih = oh * stride_h; + iw = ow * stride_w; + rp.iw_start = iw; + + p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block); + rp.os = p.bcast_dim; + }; + + auto init_load = [&](int ocb, int ocb_end, int &load_step) { + load_step = step(nb_load_blocking, ocb_end - ocb, nb_load_blocking_max); + const auto max_oc + = nstl::min(ocb_end * jcp.oc_block, jcp.oc_without_padding); + p.load_dim = this_block_size( + ocb * jcp.oc_block, max_oc, load_step * jcp.oc_block); + }; + + auto init_reduce = [&](int icb) { + const int nb_ic_blocking_step + = nstl::min(icb + nb_ic_blocking, nb_ic) - icb; + p.first_last_flag = 0 | (icb == 0 ? kdnn_FLAG_REDUCE_FIRST : 0) + | (icb + nb_ic_blocking_step >= nb_ic ? kdnn_FLAG_REDUCE_LAST : 0); + + p.reduce_dim = this_block_size( + icb * jcp.ic_block, jcp.ic, nb_ic_blocking_step * jcp.ic_block); + rp.icb = p.reduce_dim; + }; + + auto ker_1x1 = [&](int ocb, int ocb_start, int icb, int n, int g, int od, + int oh, int ow, int id, int ih, int iw) { + const int oc_off_idx = is_dst_layout_nxc + ? g * jcp.oc + ocb * jcp.oc_block + : g * nb_oc + ocb; + const size_t dst_off = data_blk_off(dst_d, n, oc_off_idx, od, oh, ow); + + p.output_data = &dst[dst_off]; + + p.bias_data = bias + ? &bias[oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block)] + : nullptr; + + p.load_data + = &weights[pd()->with_groups() ? weights_d.blk_off(g, ocb, icb) + : weights_d.blk_off(ocb, icb)]; + + const int ic_off_idx = is_src_layout_nxc + ? g * jcp.ic + icb * jcp.ic_block + : g * nb_ic + icb; + if (pd()->rtus_.reduce_src_) { + rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_ + + (is_src_layout_nxc ? 
ic_off_idx + : jcp.is * ic_off_idx * jcp.ic_block); + if (ocb == ocb_start) { + rp.src = src + data_blk_off(src_d, n, ic_off_idx, id, ih, iw); + (*rtus_driver_)(&rp); + } + p.bcast_data = rp.ws; + } else + p.bcast_data = src + data_blk_off(src_d, n, ic_off_idx, id, ih, iw); + + p.oc_l_off = oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block); + p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; + p.dst_orig = dst; + + (*kernel_)(&p); + }; + auto conv_1x1 = [&](int bcast_start, int bcast_end, int ocb_start, + int ocb_end) { + if (bcast_start >= bcast_end || ocb_start >= ocb_end) return; + + if (jcp.loop_order == kdnn_loop_rlb) { + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, ocb_end, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, + ow {0}, id {0}, ih {0}, iw {0}; + init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, + ow, id, ih, iw); + ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih, + iw); + iwork += bcast_step; + } + ocb += load_step; + } + } + } else if (jcp.loop_order == kdnn_loop_lbr) { + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, ocb_end, load_step); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0}, + id {0}, ih {0}, iw {0}; + init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, + id, ih, iw); + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih, + iw); + } + iwork += bcast_step; + } + ocb += load_step; + } + } else if (jcp.loop_order == kdnn_loop_rbl) { + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + int iwork = bcast_start; + while (iwork < bcast_end) { + int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0}, + id {0}, ih {0}, iw {0}; + init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, + id, ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, ocb_end, load_step); + ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih, + iw); + ocb += load_step; + } + iwork += bcast_step; + } + } + } else if (jcp.loop_order == kdnn_loop_blr) { + int iwork = bcast_start; + while (iwork < bcast_end) { + int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0}, + id {0}, ih {0}, iw {0}; + init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, id, + ih, iw); + int ocb = ocb_start; + while (ocb < ocb_end) { + int load_step; + init_load(ocb, ocb_end, load_step); + for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { + init_reduce(icb); + ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih, + iw); + } + ocb += load_step; + } + iwork += bcast_step; + } + } else { + assert(!"unsupported loop order"); + } + }; + + const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; + int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0}; + balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, jcp.nb_load, + ocb_start, ocb_end, jcp.load_grp_count); + + conv_1x1(bcast_start, bcast_end, ocb_start, ocb_end); +} +template struct kdnn_jit_sve_256_1x1_convolution_fwd_f16_t; + +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/oneDNN-3.4/src/cpu/aarch64/kdnn/jit/kdnn_jit_sve_256_1x1_convolution_f16.hpp b/oneDNN-3.4/src/cpu/aarch64/kdnn/jit/kdnn_jit_sve_256_1x1_convolution_f16.hpp 
new file mode 100644 index 0000000..35e49bf --- /dev/null +++ b/oneDNN-3.4/src/cpu/aarch64/kdnn/jit/kdnn_jit_sve_256_1x1_convolution_f16.hpp @@ -0,0 +1,234 @@ +/******************************************************************************* +* Copyright 2021-2023 Intel Corporation +* Copyright 2021-2023 FUJITSU LIMITED +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef CPU_AARCH64_JIT_SVE_256_1X1_CONVOLUTION_F16_HPP +#define CPU_AARCH64_JIT_SVE_256_1X1_CONVOLUTION_F16_HPP + +#include "common/c_types_map.hpp" +#include "common/dnnl_thread.hpp" +#include "common/memory_tracking.hpp" +#include "common/primitive.hpp" +#include "common/primitive_hashing.hpp" +#include "common/utils.hpp" + +#include "cpu/cpu_convolution_pd.hpp" +#include "cpu/platform.hpp" + +#include "cpu/aarch64/cpu_reducer.hpp" +#include "cpu/aarch64/kdnn/jit/kdnn_jit_sve_256_1x1_conv_kernel_f16.hpp" +#include "cpu/aarch64/kdnn/jit/kdnn_jit_uni_1x1_conv_utils.hpp" + +#include "oneapi/dnnl/dnnl.hpp" +#include "oneapi/dnnl/dnnl_types.h" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { + +template +struct kdnn_jit_sve_256_1x1_convolution_fwd_f16_t : public primitive_t { + struct pd_t : public cpu_convolution_fwd_pd_t { + pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, + const typename pd_t::base_class *hint_fwd_pd) + : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) + , jcp_() + , rtus_() + , src_stride_(1) {} + pd_t(const pd_t &other) : cpu_convolution_fwd_pd_t(other) { + if (copy(other) != status::success) is_initialized_ = false; + } + + DECLARE_COMMON_PD_T("kdnn_jit_1x1:", + kdnn_jit_sve_256_1x1_convolution_fwd_f16_t); + + status_t init(engine_t *engine) { + using namespace utils; + using namespace format_tag; + + bool ok = true && is_fwd() + && set_default_alg_kind(alg_kind::convolution_direct) + && expect_data_types(src_type, wei_type, dst_type, dst_type, + data_type::undef) + && attr()->has_default_values( + primitive_attr_t::skip_mask_t::post_ops, dst_type) + && !has_zero_dim_memory() && set_default_formats() + && attr_.set_default_formats(dst_md(0)) == status::success; + if (!ok) return status::unimplemented; + + const convolution_desc_t *conv_d = desc(); + const memory_desc_t *src_d = src_md(); + assert(ndims() >= 3); + src_stride_ = conv_d->strides[ndims() - 3]; //stride w, for h=w, need to get w stride + const memory_desc_wrapper src_desc_w(src_d); + const memory_desc_wrapper dst_desc_w(dst_md()); + + const auto dat_tag_ncx = utils::pick(ndims() - 3, ncw, nchw, ncdhw); + if(src_desc_w.matches_one_of_tag(dat_tag_ncx) == dat_tag_ncx && + dst_desc_w.matches_one_of_tag(dat_tag_ncx) == dat_tag_ncx) {//for layout is ncx, need change layout to nxc + dnnl_memory_desc_t conv_src_md, conv_dst_md; + auto dnnl_tag = src_md()->ndims == 4 ? 
dnnl_nhwc : dnnl_nwc; + dnnl_memory_desc_create_with_tag(&conv_src_md, src_md()->ndims, + src_md()->dims, dnnl_f16, dnnl_tag); + dnnl_memory_desc_create_with_tag(&conv_dst_md, dst_md()->ndims, + dst_md()->dims, dnnl_f16, dnnl_tag); + const memory_desc_t* p_changed_src = conv_src_md; + const memory_desc_t* p_changed_dst = conv_dst_md; + + kdnn_rtus_prepare(this, conv_d, p_changed_src, p_changed_dst); + const memory_desc_t* weights_d = weights_md(); + + CHECK(kdnn_jit_sve_256_1x1_conv_kernel_f16::init_conf(jcp_, *conv_d, const_cast(*p_changed_src), + const_cast(*weights_d), const_cast(*p_changed_dst), *attr(), dnnl_get_max_threads(), + rtus_.reduce_src_)); + } + else { + kdnn_rtus_prepare(this, conv_d, src_d, dst_md()); + const memory_desc_t* dst_d = dst_md(); + const memory_desc_t* weights_d = weights_md(); + + CHECK(kdnn_jit_sve_256_1x1_conv_kernel_f16::init_conf(jcp_, *conv_d, const_cast(*src_d), + const_cast(*weights_d), const_cast(*dst_d), *attr(), dnnl_get_max_threads(), + rtus_.reduce_src_)); + } + auto scratchpad = scratchpad_registry().registrar(); + kdnn_jit_sve_256_1x1_conv_kernel_f16::init_scratchpad(scratchpad, (const kdnn_jit_1x1_conv_conf_t &)jcp_); + + kdnn_rtus_prepare_space_info(this, scratchpad, jcp_.nthr); + return status::success; + } + + const memory_desc_t *dst_md( + int index = 0, bool user_input = false) const override { + return cpu_convolution_fwd_pd_t::dst_md(index, user_input); + } + + const memory_desc_t *arg_md( + int arg, bool user_input = false) const override { + return convolution_fwd_pd_t::arg_md(arg, user_input); + } + + arg_usage_t arg_usage(int arg) const override { + if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS)) + return arg_usage_t::input; + + if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS) + && attr_post_op_dw_inputs() > 1) + return arg_usage_t::input; + + return convolution_fwd_pd_t::arg_usage(arg); + } + + kdnn_jit_1x1_conv_conf_t jcp_; + kdnn_reduce_to_unit_stride_t rtus_; + int src_stride_; + + protected: + bool set_default_formats() { + using namespace format_tag; + + const memory_desc_wrapper src_d(&src_md_); + const memory_desc_wrapper dst_d(&dst_md_); + + const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc); + const auto dat_tag_nCx16c + = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); + const auto curr_src_tag + = src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c); + const auto curr_dst_tag + = dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx16c); + const auto is_data_layout_nxc + = IMPLICATION(curr_src_tag != dat_tag_nxc, + src_d.format_kind() == format_kind::any) + && IMPLICATION(curr_dst_tag != dat_tag_nxc, + dst_d.format_kind() == format_kind::any) + && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag); + auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx16c; + auto wei_tag = with_groups() + ? 
                utils::pick(ndims() - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o)
+                : utils::pick(ndims() - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o);
+
+            return set_default_formats_common(dat_tag, wei_tag, dat_tag);
+        }
+        status_t copy(const pd_t &other) {
+            jcp_ = other.jcp_;
+            rtus_ = other.rtus_;
+            src_stride_ = other.src_stride_;
+            return status::success;
+        }
+    };
+
+    template
+    friend status_t kdnn_init_rtus_driver(conv_t *self);
+
+    kdnn_jit_sve_256_1x1_convolution_fwd_f16_t(const pd_t *apd)
+        : primitive_t(apd)
+        , is_src_need_change_layout_(false)
+        , is_dst_need_change_layout_(false) {}
+
+    typedef typename prec_traits<src_type>::type src_data_t;
+    typedef typename prec_traits<wei_type>::type wei_data_t;
+    typedef typename prec_traits<dst_type>::type dst_data_t;
+
+    status_t init(engine_t *engine) override {
+        using namespace format_tag;
+        CHECK(safe_ptr_assign(kernel_,
+                new kdnn_jit_sve_256_1x1_conv_kernel_f16(
+                        pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
+        CHECK(kernel_->create_kernel());
+
+        CHECK(kdnn_init_rtus_driver(this));
+
+        const memory_desc_wrapper src_d(pd()->src_md()), dst_d(pd()->dst_md());
+        const auto dat_tag_ncx = utils::pick(pd()->ndims() - 3, ncw, nchw, ncdhw);
+        if (src_d.matches_one_of_tag(dat_tag_ncx) == dat_tag_ncx) is_src_need_change_layout_ = true;
+        if (dst_d.matches_one_of_tag(dat_tag_ncx) == dat_tag_ncx) is_dst_need_change_layout_ = true;
+
+        return status::success;
+    }
+
+    status_t execute(const exec_ctx_t &ctx) const override {
+        execute_forward(ctx);
+        return status::success;
+    }
+
+private:
+    void execute_forward(const exec_ctx_t &ctx) const;
+    void execute_forward_thr(const int ithr, const int nthr,
+            const src_data_t *src, const wei_data_t *weights,
+            const dst_data_t *bias, const wei_data_t *weights_dw,
+            const dst_data_t *bias_dw, dst_data_t *dst,
+            const memory_tracking::grantor_t &scratchpad,
+            const void *post_ops_binary_rhs_arg_vec,
+            const void *post_ops_binary_rhs_arg_vec_dw) const;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
+
+    std::unique_ptr<kdnn_jit_sve_256_1x1_conv_kernel_f16> kernel_;
+    std::unique_ptr> rtus_driver_;
+    std::atomic<bool> is_src_need_change_layout_, is_dst_need_change_layout_;
+};
+
+using kdnn_jit_sve_256_1x1_convolution_fwd_f16
+        = kdnn_jit_sve_256_1x1_convolution_fwd_f16_t<data_type::f16, data_type::f16, data_type::f16>;
+
+} // namespace aarch64
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
+
+#endif
diff --git a/oneDNN-3.4/src/cpu/aarch64/kdnn/kdnn_utils_conv.cpp b/oneDNN-3.4/src/cpu/aarch64/kdnn/kdnn_utils_conv.cpp
new file mode 100644
index 0000000..d8efff3
--- /dev/null
+++ b/oneDNN-3.4/src/cpu/aarch64/kdnn/kdnn_utils_conv.cpp
@@ -0,0 +1,104 @@
+#include "kdnn.hpp"
+
+#include "cpu/aarch64/kdnn/kdnn_utils.hpp"
+#include "common/type_helpers.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace aarch64 {
+
+namespace kdnn_utils {
+
+using namespace dnnl::impl::alg_kind;
+using namespace data_type;
+
+std::pair<bool, KDNN::ConvolutionLayerFWD*> convert_to_kdnn_conv_fwd(const memory_desc_wrapper& mem_desc_src,
+    const memory_desc_wrapper& mem_desc_wei, const memory_desc_wrapper& mem_desc_dst,
+    const memory_desc_wrapper& mem_desc_bia, const convolution_desc_t &cd, const alg_kind_t &alg) noexcept(false) {
+    if (!common_tensor_checks(mem_desc_src, mem_desc_wei, mem_desc_dst)) {
+        return {false, nullptr};
+    }
+
+    KDNN::TensorInfo src = get_kdnn_tensor_info(mem_desc_src);
+    KDNN::TensorInfo wei = get_kdnn_tensor_info(mem_desc_wei);
+    KDNN::TensorInfo dst = get_kdnn_tensor_info(mem_desc_dst);
+    KDNN::TensorInfo bias = KDNN::TensorInfo({0}, wei.GetType(), KDNN::Layout::A);
+    if (mem_desc_bia != &glob_zero_md) {
+        if
(!common_tensor_checks(mem_desc_bia)) { + return {false, nullptr}; + } + bias = get_kdnn_tensor_info(mem_desc_bia); + } + KDNN::Shape strides = get_kdnn_shape(cd.strides, mem_desc_src.ndims() - 2); + KDNN::Shape paddingL = get_kdnn_shape(cd.padding[0], mem_desc_src.ndims() - 2); + KDNN::Shape paddingR = get_kdnn_shape(cd.padding[1], mem_desc_src.ndims() - 2); + KDNN::Shape dilates = get_kdnn_shape(cd.dilates, mem_desc_src.ndims() - 2); + if (KDNN::Status::SUCCESS != KDNN::ConvolutionLayerFWD::ValidateInput(src, wei, dst, bias, + strides, dilates, paddingL, paddingR, get_kdnn_conv_alg(alg))) { + return {false, nullptr}; + } else { + return {true, new KDNN::ConvolutionLayerFWD{src, wei, dst, bias, + strides, dilates, paddingL, paddingR, get_kdnn_conv_alg(alg)}}; + } +} + +std::pair convert_to_kdnn_conv_bwd_data(const memory_desc_wrapper& mem_desc_diff_dst, + const memory_desc_wrapper& mem_desc_wei, const memory_desc_wrapper& mem_desc_diff_src, + const convolution_desc_t &cd, const alg_kind_t &alg) noexcept(false) { + if (!common_tensor_checks(mem_desc_diff_dst, mem_desc_wei, mem_desc_diff_src)) { + return {false, nullptr}; + } + KDNN::TensorInfo diff_dst = get_kdnn_tensor_info(mem_desc_diff_dst); + KDNN::TensorInfo wei = get_kdnn_tensor_info(mem_desc_wei); + KDNN::TensorInfo diff_src = get_kdnn_tensor_info(mem_desc_diff_src); + KDNN::Shape strides = get_kdnn_shape(cd.strides, mem_desc_diff_dst.ndims() - 2); + KDNN::Shape paddingL = get_kdnn_shape(cd.padding[0], mem_desc_diff_dst.ndims() - 2); + KDNN::Shape paddingR = get_kdnn_shape(cd.padding[1], mem_desc_diff_dst.ndims() - 2); + KDNN::Shape dilates = get_kdnn_shape(cd.dilates, mem_desc_diff_dst.ndims() - 2); + if (KDNN::Status::SUCCESS != KDNN::ConvolutionLayerBWDData::ValidateInput(diff_dst, wei, diff_src, + strides, dilates, paddingL, paddingR, get_kdnn_conv_alg(alg))) { + return {false, nullptr}; + } else { + return {true, new KDNN::ConvolutionLayerBWDData{diff_dst, wei, diff_src, + strides, dilates, paddingL, paddingR, get_kdnn_conv_alg(alg)}}; + } +} + +std::pair convert_to_kdnn_conv_bwd_weights(const memory_desc_wrapper& mem_desc_diff_dst, + const memory_desc_wrapper& mem_desc_src, const memory_desc_wrapper& mem_desc_diff_wei, + const memory_desc_wrapper& mem_desc_diff_bia, const convolution_desc_t &cd, const alg_kind_t &alg) noexcept(false) +{ + if (!common_tensor_checks(mem_desc_diff_dst, mem_desc_src, mem_desc_diff_wei)) { + return {false, nullptr}; + } + + KDNN::TensorInfo diff_dst = get_kdnn_tensor_info(mem_desc_diff_dst); + KDNN::TensorInfo src = get_kdnn_tensor_info(mem_desc_src); + KDNN::TensorInfo diff_wei = get_kdnn_tensor_info(mem_desc_diff_wei); + KDNN::TensorInfo diff_bias = KDNN::TensorInfo({0}, diff_wei.GetType(), KDNN::Layout::A); + if (mem_desc_diff_bia != &glob_zero_md) { + if (!common_tensor_checks(mem_desc_diff_bia)) { + return {false, nullptr}; + } + diff_bias = get_kdnn_tensor_info(mem_desc_diff_bia); + } + KDNN::Shape strides = get_kdnn_shape(cd.strides, mem_desc_src.ndims() - 2); + KDNN::Shape paddingL = get_kdnn_shape(cd.padding[0], mem_desc_src.ndims() - 2); + KDNN::Shape paddingR = get_kdnn_shape(cd.padding[1], mem_desc_src.ndims() - 2); + KDNN::Shape dilates = get_kdnn_shape(cd.dilates, mem_desc_src.ndims() - 2); + if (KDNN::Status::SUCCESS != KDNN::ConvolutionLayerBWDWeights::ValidateInput(diff_dst, src, diff_wei, diff_bias, + strides, dilates, paddingL, paddingR, get_kdnn_conv_alg(alg))) { + return {false, nullptr}; + } else { + return {true, new KDNN::ConvolutionLayerBWDWeights{diff_dst, src, diff_wei, 
diff_bias, + strides, dilates, paddingL, paddingR, get_kdnn_conv_alg(alg)}}; + } +} + +} // namespace kdnn_utils + +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/oneDNN-3.4/src/cpu/aarch64/kdnn/kdnn_utils_deconv.cpp b/oneDNN-3.4/src/cpu/aarch64/kdnn/kdnn_utils_deconv.cpp new file mode 100644 index 0000000..c58dd94 --- /dev/null +++ b/oneDNN-3.4/src/cpu/aarch64/kdnn/kdnn_utils_deconv.cpp @@ -0,0 +1,102 @@ +#include "kdnn.hpp" + +#include "cpu/aarch64/kdnn/kdnn_utils.hpp" +#include "common/type_helpers.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { + +namespace kdnn_utils { + +using namespace dnnl::impl::alg_kind; +using namespace data_type; + +std::pair convert_to_kdnn_deconv_fwd(const memory_desc_wrapper& mem_desc_src, + const memory_desc_wrapper& mem_desc_wei, const memory_desc_wrapper& mem_desc_dst, + const memory_desc_wrapper& mem_desc_bia, const deconvolution_desc_t &dd, const alg_kind_t &alg) noexcept(false) { + if (!common_tensor_checks(mem_desc_src, mem_desc_wei, mem_desc_dst)) { + return {false, nullptr}; + } + KDNN::TensorInfo src = get_kdnn_tensor_info(mem_desc_src); + KDNN::TensorInfo wei = get_kdnn_tensor_info(mem_desc_wei); + KDNN::TensorInfo dst = get_kdnn_tensor_info(mem_desc_dst); + KDNN::TensorInfo bias = KDNN::TensorInfo({0}, wei.GetType(), KDNN::Layout::A); + if (mem_desc_bia != &glob_zero_md) { + if (!common_tensor_checks(mem_desc_bia)) { + return {false, nullptr}; + } + bias = get_kdnn_tensor_info(mem_desc_bia); + } + KDNN::Shape strides = get_kdnn_shape(dd.strides, mem_desc_src.ndims() - 2); + KDNN::Shape paddingL = get_kdnn_shape(dd.padding[0], mem_desc_src.ndims() - 2); + KDNN::Shape paddingR = get_kdnn_shape(dd.padding[1], mem_desc_src.ndims() - 2); + KDNN::Shape dilates = get_kdnn_shape(dd.dilates, mem_desc_src.ndims() - 2); + if (KDNN::Status::SUCCESS != KDNN::DeconvolutionLayerFWD::ValidateInput(src, wei, dst, bias, + strides, dilates, paddingL, paddingR, get_kdnn_deconv_alg(alg))) { + return {false, nullptr}; + } else { + return {true, new KDNN::DeconvolutionLayerFWD{src, wei, dst, bias, + strides, dilates, paddingL, paddingR, get_kdnn_deconv_alg(alg)}}; + } +} + +std::pair convert_to_kdnn_deconv_bwd_data(const memory_desc_wrapper& mem_desc_diff_dst, + const memory_desc_wrapper& mem_desc_wei, const memory_desc_wrapper& mem_desc_diff_src, + const deconvolution_desc_t &dd, const alg_kind_t &alg) noexcept(false) { + if (!common_tensor_checks(mem_desc_diff_dst, mem_desc_wei, mem_desc_diff_src)) { + return {false, nullptr}; + } + KDNN::TensorInfo diff_dst = get_kdnn_tensor_info(mem_desc_diff_dst); + KDNN::TensorInfo wei = get_kdnn_tensor_info(mem_desc_wei); + KDNN::TensorInfo diff_src = get_kdnn_tensor_info(mem_desc_diff_src); + KDNN::Shape strides = get_kdnn_shape(dd.strides, mem_desc_diff_src.ndims() - 2); + KDNN::Shape paddingL = get_kdnn_shape(dd.padding[0], mem_desc_diff_src.ndims() - 2); + KDNN::Shape paddingR = get_kdnn_shape(dd.padding[1], mem_desc_diff_src.ndims() - 2); + KDNN::Shape dilates = get_kdnn_shape(dd.dilates, mem_desc_diff_src.ndims() - 2); + if (KDNN::Status::SUCCESS != KDNN::DeconvolutionLayerBWDData::ValidateInput(diff_dst, + wei, diff_src, strides, dilates, paddingL, paddingR, get_kdnn_deconv_alg(alg))) { + return {false, nullptr}; + } else { + return {true, new KDNN::DeconvolutionLayerBWDData{diff_dst, wei, diff_src, + strides, dilates, paddingL, paddingR, get_kdnn_deconv_alg(alg)}}; + } +} + +std::pair convert_to_kdnn_deconv_bwd_weights(const 
memory_desc_wrapper& mem_desc_diff_dst, + const memory_desc_wrapper& mem_desc_src, const memory_desc_wrapper& mem_desc_diff_wei, + const memory_desc_wrapper& mem_desc_diff_bia, const deconvolution_desc_t &cd, const alg_kind_t &alg) noexcept(false) +{ + if (!common_tensor_checks(mem_desc_diff_dst, mem_desc_src, mem_desc_diff_wei)) { + return {false, nullptr}; + } + KDNN::TensorInfo diff_dst = get_kdnn_tensor_info(mem_desc_diff_dst); + KDNN::TensorInfo src = get_kdnn_tensor_info(mem_desc_src); + KDNN::TensorInfo diff_wei = get_kdnn_tensor_info(mem_desc_diff_wei); + KDNN::TensorInfo diff_bias = KDNN::TensorInfo({0}, diff_wei.GetType(), KDNN::Layout::A); + if (mem_desc_diff_bia != &glob_zero_md) { + if (!common_tensor_checks(mem_desc_diff_bia)) { + return {false, nullptr}; + } + diff_bias = get_kdnn_tensor_info(mem_desc_diff_bia); + } + KDNN::Shape strides = get_kdnn_shape(cd.strides, mem_desc_src.ndims() - 2); + KDNN::Shape paddingL = get_kdnn_shape(cd.padding[0], mem_desc_src.ndims() - 2); + KDNN::Shape paddingR = get_kdnn_shape(cd.padding[1], mem_desc_src.ndims() - 2); + KDNN::Shape dilates = get_kdnn_shape(cd.dilates, mem_desc_src.ndims() - 2); + if (KDNN::Status::SUCCESS != KDNN::DeconvolutionLayerBWDWeights::ValidateInput(diff_dst, + src, diff_wei, diff_bias, strides, dilates, paddingL, paddingR, get_kdnn_deconv_alg(alg))) { + return {false, nullptr}; + } else { + return {true, new KDNN::DeconvolutionLayerBWDWeights{diff_dst, src, diff_wei, diff_bias, + strides, dilates, paddingL, paddingR, get_kdnn_deconv_alg(alg)}}; + } +} + +} // namespace kdnn_utils + +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl -- Gitee
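
For reference, below is a minimal standalone sketch of the ncx -> nxc reorder pattern that execute_forward() in the first file relies on, written against oneDNN's public C++ API. The tensor shape, f32 data type, and locally allocated buffer are illustrative assumptions only; the patch itself operates on f16 data and on the primitive's real source/destination buffers.

    // Sketch only: reorder an nchw (ncx) buffer into an nhwc (nxc) memory object.
    #include <vector>
    #include "oneapi/dnnl/dnnl.hpp"

    int main() {
        using namespace dnnl;
        engine eng(engine::kind::cpu, 0);
        stream s(eng);

        const memory::dims sizes = {1, 8, 4, 4}; // N, C, H, W (illustrative)
        std::vector<float> src_buf(1 * 8 * 4 * 4, 1.f);

        // Describe the same logical tensor twice: once as nchw, once as nhwc.
        memory src_mem({sizes, memory::data_type::f32, memory::format_tag::nchw},
                eng, src_buf.data());
        memory dst_mem({sizes, memory::data_type::f32, memory::format_tag::nhwc}, eng);

        // The reorder primitive performs the physical layout change.
        reorder(src_mem, dst_mem)
                .execute(s, {{DNNL_ARG_FROM, src_mem}, {DNNL_ARG_TO, dst_mem}});
        s.wait(); // make sure the reorder finished before dst_mem is consumed
        return 0;
    }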