From 4574e42e0e04eb712ce4226dc618518f5db665e7 Mon Sep 17 00:00:00 2001
From: Codersheepchen
Date: Thu, 13 Jun 2024 07:52:09 +0000
Subject: [PATCH 1/2] cpp-2

Signed-off-by: Codersheepchen
---
 .../aarch64/kdnn/kdnn_jit_sve_conv_kernel.cpp | 729 +++++++++++++++++-
 1 file changed, 728 insertions(+), 1 deletion(-)

diff --git a/oneDNN-3.4/src/cpu/aarch64/kdnn/kdnn_jit_sve_conv_kernel.cpp b/oneDNN-3.4/src/cpu/aarch64/kdnn/kdnn_jit_sve_conv_kernel.cpp
index b5946f4..d5babb5 100644
--- a/oneDNN-3.4/src/cpu/aarch64/kdnn/kdnn_jit_sve_conv_kernel.cpp
+++ b/oneDNN-3.4/src/cpu/aarch64/kdnn/kdnn_jit_sve_conv_kernel.cpp
@@ -519,4 +519,731 @@ void kdnn_jit_sve_conv_fwd_kernel<isa>::compute_loop_fma_core(
         if (icb_loop_in_compute_function) mov(aux_reg_ker_d, aux_reg_ker_d_org);
         mov(reg_out, reg_out_org);
     }
-}
\ No newline at end of file
+}
+
+/**
+ * Emits the code for one step of `ur_w` output-width positions.
+ *
+ * The emitted sequence is prepare_output(ur_w), an optional loop over input
+ * channel blocks around compute_loop_fma_core(), and store_output(ur_w). When
+ * the depth/height padding can cover the whole filter, the compute part is
+ * skipped at run time if the kd_padding/kh_padding argument is <= 0.
+ *
+ * @param ur_w  number of output-width positions handled by this step
+ * @param pad_l left padding, forwarded to compute_loop_fma_core()
+ * @param pad_r right padding, forwarded to compute_loop_fma_core()
+ */
+template <cpu_isa_t isa>
+void kdnn_jit_sve_conv_fwd_kernel<isa>::compute_loop(
+        int ur_w, int pad_l, int pad_r) {
+
+    if (jcp.ndims == 5) mov(reg_oi_org, reg_oi);
+
+    prepare_output(ur_w);
+
+    Label skip_compute_loop;
+    if (jcp.ndims == 5) {
+        if ((jcp.dilate_d >= jcp.id)
+                || (jcp.kd - 1) * (jcp.dilate_d + 1)
+                        < nstl::max(jcp.f_pad, jcp.back_pad)) {
+            ldr(reg_kj, ptr(kdnn_abi_param1, KDNN_GET_OFF(kd_padding)));
+            cmp(reg_kj, 0);
+            b(LE, skip_compute_loop);
+        }
+    }
+    if ((jcp.dilate_h >= jcp.ih)
+            || (jcp.kh - 1) * (jcp.dilate_h + 1)
+                    < nstl::max(jcp.t_pad, jcp.b_pad)) {
+        ldr(reg_kj, ptr(kdnn_abi_param1, KDNN_GET_OFF(kh_padding)));
+        cmp(reg_kj, 0);
+        b(LE, skip_compute_loop);
+    }
+
+    Label ic_loop;
+    const bool generate_icb_loop = jcp.nb_ic > 1 && is_src_layout_nxc();
+    if (generate_icb_loop) {
+        mov(reg_inp_org, reg_inp);
+        mov(reg_ker_org, reg_ker);
+
+        ldr(reg_channel, ptr(param, KDNN_GET_OFF(reduce_work)));
+        L(ic_loop);
+    }
+
+    if (jcp.ver == kdnn_ver_fma) {
+        if (jcp.is_1stconv && jcp.kernel_kind != kdnn_expl_bcast) {
+            assert(!"STOP:jcp.is_1stconv && jcp.kernel_kind != expl_bcast");
+        } else if (jcp.kernel_kind == kdnn_embd_bcast
+                && jcp.nb_oc_blocking == 1) {
+            assert(!"STOP:jcp.kernel_kind == embd_bcast && jcp.nb_oc_blocking "
+                    "== 1");
+        } else {
+            compute_loop_fma_core(ur_w, pad_l, pad_r);
+        }
+    } else {
+        assert(!"unknown convolution version");
+    }
+
+    if (generate_icb_loop) {
+        assert(is_src_layout_nxc());
+        const int inp_shift = jcp.ic_block * jcp.typesize_in;
+        add_imm(reg_inp, reg_inp, inp_shift, reg_tmp_imm);
+        const int ker_shift = jcp.kd * jcp.kh * jcp.kw * jcp.ic_block
+                * jcp.oc_block * jcp.typesize_in;
+        add_imm(reg_ker, reg_ker, ker_shift, reg_tmp_imm);
+        sub_imm(reg_channel, reg_channel, jcp.ic_block, reg_tmp_imm);
+        // Set the condition flags explicitly: the branch below must test the
+        // updated channel counter (the oi loops in generate() likewise follow
+        // their counter update with an explicit cmp).
+        cmp(reg_channel, 0);
+        b(GT, ic_loop);
+        mov(reg_ker, reg_ker_org);
+        mov(reg_inp, reg_inp_org);
+    }
+
+    L(skip_compute_loop);
+    store_output(ur_w);
+    if (jcp.ndims == 5) mov(reg_oi, reg_oi_org);
+}
+
+/**
+ * Emits the whole kernel: loads the src/dst/filt/kh_padding arguments and
+ * walks the output width in steps of `ur_w`, calling compute_loop() with the
+ * padding that applies to each step.
+ *
+ * Without ow threading the full row is emitted (left-padded step, unpadded
+ * loop, right-padded step, tail). With ow threading a single ow block is
+ * processed and the `owb` argument selects, at run time, which of the
+ * first/middle/next-to-last/last block variants is executed.
+ */
+template <cpu_isa_t isa>
+void kdnn_jit_sve_conv_fwd_kernel<isa>::generate() {
+    int iw = jcp.iw;
+    int ow = jcp.ow;
+    int ow_block = jcp.ow_block;
+    int nb_ow = jcp.nb_ow;
+    int kw = jcp.kw;
+    int l_pad = jcp.l_pad;
+    int ur_w = jcp.ur_w;
+    int ur_w_tail = jcp.ur_w_tail;
+    int stride_w = jcp.stride_w;
+
+    int inp_mult = is_src_layout_nxc() ? jcp.ngroups * jcp.ic
+                                       : (jcp.is_1stconv ? 1 : jcp.ic_block);
+    int inp_shift_pad = jcp.typesize_in * (ur_w * stride_w - l_pad) * inp_mult;
+    int inp_shift = jcp.typesize_in * ur_w * stride_w * inp_mult;
+    int inp_shift_pad_second_block = -1 * jcp.typesize_in * l_pad * inp_mult;
+    int out_shift = jcp.typesize_out * ur_w
+            * (is_dst_layout_nxc() ? jcp.ngroups * jcp.oc : jcp.oc_block);
+
+    const int simd_w_ = cpu_isa_traits<isa>::vlen / sizeof(float);
+
+    preamble();
+
+    // TODO: rename the predicate register (P_ALL_ONE)
+    if (simd_w_ != cpu_sveLen / sizeof(float))
+        set_preg(P_ALL_ONE.s, simd_w_, X_TMP_0, X_TMP_1);
+
+    ldr(reg_inp, ptr(kdnn_abi_param1, KDNN_GET_OFF(src)));
+    ldr(reg_out, ptr(kdnn_abi_param1, KDNN_GET_OFF(dst)));
+    ldr(reg_ker, ptr(kdnn_abi_param1, KDNN_GET_OFF(filt)));
+    ldr(reg_kh, ptr(kdnn_abi_param1, KDNN_GET_OFF(kh_padding)));
+    if (jcp.ndims == 5) mov(aux_reg_ker_d_org, reg_ker);
+
+    int r_pad = nstl::max(0, jcp.r_pad);
+    int n_oi = ow / ur_w;
+    int r_pad1 = calculate_end_padding(l_pad, ur_w * n_oi, iw, stride_w,
+            calculate_extended_filter_size(kw, jcp.dilate_w));
+
+    if (!is_ow_threading_on(jcp)) { // nb_ow <= 1
+        // nb_ow is # of output width blocks ??
+
+        // ow is being processed as a whole - with left and right paddings
+        // n_oi is # of output width blocks ??
+        if (r_pad1 > 0) n_oi--;
+
+        if (ow == ur_w) {
+            ldr(reg_out_prf, ptr(kdnn_abi_param1, KDNN_GET_OFF(dst_prf)));
+            compute_loop(ur_w, l_pad, r_pad);
+        } else {
+            mov(reg_out_prf, reg_out);
+            if (n_oi == 0) {
+                add_imm(reg_out_prf, reg_out_prf, out_shift, reg_tmp_imm);
+                compute_loop(ur_w, l_pad, r_pad1);
+                add_imm(reg_inp, reg_inp, inp_shift_pad, reg_tmp_imm);
+                add_imm(reg_out, reg_out, out_shift, reg_tmp_imm);
+                if (ur_w_tail != 0) {
+                    add_imm(reg_out_prf, reg_out_prf, out_shift, reg_tmp_imm);
+                    compute_loop(ur_w_tail, 0, r_pad);
+                }
+            } else {
+                mov(reg_oi, 0);
+                if (l_pad > 0) {
+                    add_imm(reg_out_prf, reg_out_prf, out_shift, reg_tmp_imm);
+                    compute_loop(ur_w, l_pad, 0);
+                    add_imm(reg_inp, reg_inp, inp_shift_pad, reg_tmp_imm);
+                    add_imm(reg_out, reg_out, out_shift, reg_tmp_imm);
+                    add_imm(reg_oi, reg_oi, 1, reg_tmp_imm); // increment
+                }
+                if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) {
+                    Label ow_loop_label;
+                    L(ow_loop_label);
+                    {
+                        add_imm(reg_out_prf, reg_out_prf, out_shift,
+                                reg_tmp_imm);
+                        compute_loop(ur_w, 0, 0);
+
+                        add_imm(reg_inp, reg_inp, inp_shift, reg_tmp_imm);
+                        add_imm(reg_out, reg_out, out_shift, reg_tmp_imm);
+                        add_imm(reg_oi, reg_oi, 1, reg_tmp_imm); //inc(reg_oi);
+                        cmp_imm(reg_oi, n_oi, reg_tmp_imm);
+
+                        b(LT, ow_loop_label);
+                    }
+                }
+                if (r_pad1 > 0) {
+                    add_imm(reg_out_prf, reg_out_prf, out_shift, reg_tmp_imm);
+                    compute_loop(ur_w, 0, r_pad1);
+                    add_imm(reg_inp, reg_inp, inp_shift, reg_tmp_imm);
+                    add_imm(reg_out, reg_out, out_shift, reg_tmp_imm);
+                }
+                if (ur_w_tail != 0) {
+                    add_imm(reg_out_prf, reg_out_prf, out_shift, reg_tmp_imm);
+                    compute_loop(ur_w_tail, 0, r_pad);
+                }
+            }
+        }
+    } else {
+        // ow block is only processed.
+        // Number of block is passed as parameter owb,
+        // and padding processing depends on this number.
+
+        Label end_label, last_oi_label, middle_ow_blocks_label, tail_label;
+        Label oi_loop_label, oi_loop_start_label, oi_loop_end_label;
+
+        assert(ow_block % ur_w == 0);
+        int n_oi_not_last_ow_block = ow_block / ur_w;
+        // to simplify code (and general regs usage),
+        // size of ow block must be >= 2 * ur_w
+        assert(n_oi_not_last_ow_block > 1);
+        int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
+        int n_oi_first_ow_block = n_oi_not_last_ow_block;
+
+        int n_oi_last_ow_block = (ow - ow_block * (nb_ow - 1)) / ur_w;
+
+        // prepare right padding
+        bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
+        bool first_ow_block_padded
+                = next_last_ow_block_padded && jcp.nb_ow == 2;
+        bool last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block > 0;
+
+        if (last_ow_block_padded)
+            n_oi_last_ow_block--;
+        else if (first_ow_block_padded)
+            n_oi_first_ow_block--;
+        else if (next_last_ow_block_padded)
+            n_oi_next_last_ow_block--;
+
+        ldr(reg_owb, ptr(kdnn_abi_param1, KDNN_GET_OFF(owb)));
+        cmp(reg_owb, 0); // is that the first ow-block ?
+        b(GT, middle_ow_blocks_label);
+
+        // the first ow block, compute left padding
+
+        mov(reg_oi, n_oi_first_ow_block);
+        mov(reg_out_prf, reg_out);
+
+        if (l_pad > 0) {
+            add_imm(reg_out_prf, reg_out_prf, out_shift, reg_tmp_imm);
+            compute_loop(ur_w, l_pad, 0);
+            add_imm(reg_inp, reg_inp, inp_shift_pad, reg_tmp_imm);
+            add_imm(reg_out, reg_out, out_shift, reg_tmp_imm);
+            sub(reg_oi, reg_oi, 1); // decrement
+            cmp(reg_oi, 0);
+        }
+        b(oi_loop_label);
+
+        // middle or last ow block entry
+
+        L(middle_ow_blocks_label);
+
+        if (l_pad > 0) {
+            // just to consider left padding, not compute
+            add_imm(reg_inp, reg_inp, inp_shift_pad_second_block, reg_tmp_imm);
+        }
+
+        // set number of iteration for oi-loop
+        cmp_imm(reg_owb, jcp.nb_ow - 1, reg_tmp_imm); // last ow-block ?
+        mov(reg_oi, n_oi_last_ow_block);
+        b(EQ, oi_loop_label);
+        cmp_imm(reg_owb, jcp.nb_ow - 2, reg_tmp_imm); // next to last ow-block ?
+        mov(reg_oi, n_oi_next_last_ow_block);
+        b(EQ, oi_loop_label);
+        mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks
+
+        // oi loop w/o padding
+        L(oi_loop_label);
+        L(oi_loop_start_label);
+        cmp(reg_oi, 0);
+        b(LE, oi_loop_end_label);
+
+        add_imm(reg_out_prf, reg_out_prf, out_shift, reg_tmp_imm);
+        compute_loop(ur_w, 0, 0);
+        add_imm(reg_inp, reg_inp, inp_shift, reg_tmp_imm);
+        add_imm(reg_out, reg_out, out_shift, reg_tmp_imm);
+        sub(reg_oi, reg_oi, 1); // dec(reg_oi);
+        cmp(reg_oi, 0);
+        b(oi_loop_start_label);
+        L(oi_loop_end_label);
+
+        ldr(reg_owb, ptr(kdnn_abi_param1, KDNN_GET_OFF(owb)));
+
+        cmp(reg_owb, 0); // first ow-block ?
+        if (first_ow_block_padded) {
+            b(EQ, last_oi_label);
+        } else {
+            b(EQ, end_label);
+        }
+        cmp_imm(reg_owb, jcp.nb_ow - 2, reg_tmp_imm); // next to last ow-block ?
+        b(LT, end_label);
+        if (next_last_ow_block_padded) {
+            b(EQ, last_oi_label);
+        } else {
+            b(EQ, end_label);
+        }
+        // that is last block
+        if (!last_ow_block_padded) { b(tail_label); }
+
+        // last oi block with right padding
+        L(last_oi_label);
+        add_imm(reg_out_prf, reg_out_prf, out_shift, reg_tmp_imm);
+        compute_loop(ur_w, 0, r_pad1);
+        add_imm(reg_inp, reg_inp, inp_shift, reg_tmp_imm);
+        add_imm(reg_out, reg_out, out_shift, reg_tmp_imm);
+
+        ldr(reg_owb, ptr(kdnn_abi_param1, KDNN_GET_OFF(owb)));
+        cmp_imm(reg_owb, jcp.nb_ow - 1, reg_tmp_imm); // last ow_block?
+        b(LT, end_label);
+
+        L(tail_label);
+        if (ur_w_tail != 0) {
+            add_imm(reg_out_prf, reg_out_prf, out_shift, reg_tmp_imm);
+            compute_loop(ur_w_tail, 0, r_pad);
+        }
+        L(end_label);
+    }
+    postamble();
+
+}
+
+/**
+ * Checks that the post-op chain is one this kernel accepts: empty, a single
+ * eltwise or sum, or sum followed by eltwise.
+ *
+ * NOTE(review): init_conf() currently has the with_sum/with_eltwise setup
+ * commented out - confirm that accepted post-ops are really applied.
+ *
+ * @return true when the chain is supported.
+ */
+template <cpu_isa_t isa>
+bool kdnn_jit_sve_conv_fwd_kernel<isa>::post_ops_ok(
+        kdnn_jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
+    const auto &p = attr.post_ops_;
+
+    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
+
+    switch (p.len()) {
+        case 0: return true; // no post_ops
+        case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
+        case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
+        default: return false;
+    }
+}
+
+/**
+ * Fills `jcp` from the convolution descriptor and memory descriptors and
+ * selects the blocking parameters (ur_w, nb_oc_blocking, ow_block, h_blocking).
+ *
+ * Memory descriptors with format_kind::any are initialized to the layout
+ * required by `isa`.
+ *
+ * @return status::success, or status::unimplemented when the ISA, shapes,
+ *         layouts, data types or post-ops are not supported.
+ */
+template <cpu_isa_t isa>
+status_t kdnn_jit_sve_conv_fwd_kernel<isa>::init_conf(kdnn_jit_conv_conf_t &jcp,
+        const convolution_desc_t &cd, memory_desc_t &src_md,
+        memory_desc_t &weights_md, memory_desc_t &dst_md,
+        memory_desc_t &bias_md, const primitive_attr_t &attr, int nthreads) {
+    using namespace prop_kind;
+
+    if (!mayiuse(isa)) { return status::unimplemented; }
+
+    const memory_desc_wrapper src_d(&src_md);
+    const memory_desc_wrapper weights_d(&weights_md);
+    const memory_desc_wrapper dst_d(&dst_md);
+    const memory_desc_wrapper bias_d(&bias_md);
+
+    const int regs = 28;
+    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
+    int ndims = src_d.ndims();
+
+    jcp = zero<decltype(jcp)>();
+    jcp.nthr = jcp.aligned_threads = nthreads;
+    jcp.ndims = ndims;
+    jcp.prop_kind = cd.prop_kind;
+    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
+    jcp.mb = src_d.dims()[0];
+    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
+    jcp.oc_without_padding = jcp.oc;
+    jcp.ic = src_d.dims()[1] / jcp.ngroups;
+    jcp.ic_without_padding = jcp.ic;
+    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
+    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
+    jcp.iw = src_d.dims()[ndims - 1];
+    jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
+    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2];
+    jcp.ow = dst_d.dims()[ndims - 1];
+    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
+    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
+    jcp.kw = weights_d.dims()[with_groups + ndims - 1];
+    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
+    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
+    jcp.l_pad = cd.padding[0][ndims - 3];
+    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
+    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
+    jcp.stride_w = cd.strides[ndims - 3];
+
+    jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
+    jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims - 4];
+    jcp.dilate_w = cd.dilates[ndims - 3];
+
+    int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
+    int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
+    int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
+    jcp.r_pad = calculate_end_padding(
+            jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
+    jcp.b_pad = calculate_end_padding(
+            jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
+    jcp.back_pad = calculate_end_padding(
+            jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd);
+    bool kernel_outside_src = false || ext_kw <= jcp.l_pad
+            || ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad || ext_kh <= jcp.b_pad
+            || ext_kd <= jcp.f_pad || ext_kd <= jcp.back_pad;
+    if (kernel_outside_src) { return status::unimplemented; }
+
+    const auto dat_tag_nxc = pick(ndims - 3, nwc, nhwc, ndhwc);
+    const auto dat_tag_ncx = pick(ndims - 3, ncw, nchw, ncdhw);
+    const auto dat_tag_nCx8c = pick(ndims - 3, nCw8c, nChw8c, nCdhw8c);
+    const auto dat_tag_nCx16c = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
+    auto curr_src_tag = src_d.matches_one_of_tag(
+            dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c, dat_tag_ncx);
+    auto curr_dst_tag = dst_d.matches_one_of_tag(
+            dat_tag_nxc, dat_tag_nCx16c, dat_tag_nCx8c);
+    bool is_data_layout_nxc
+            = utils::everyone_is(dat_tag_nxc, curr_src_tag, curr_dst_tag);
+
+    /* 1st convolution check */
+    jcp.is_1stconv = is_1stconv(jcp);
+
+    /* Padding check (Channel) */
+    bool ok_to_pad_channels
+            = true && jcp.ngroups == 1 && src_d.data_type() == data_type::f32;
+
+    int full_simd_w = cpu_isa_traits<isa>::vlen / typesize;
+    jcp.simd_w = full_simd_w;
+    jcp.oc_block = jcp.simd_w;
+    jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;
+
+    /* Channel padding */
+    if (ok_to_pad_channels) {
+        jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
+        jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
+    }
+
+    /* Input and output channels must be multiples of simd_w */
+    if (!(jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0)) {
+        return status::unimplemented;
+    }
+    jcp.ic_tail = 0;
+    jcp.oc_tail = 0;
+
+    /* Post operation check */
+    if (!post_ops_ok(jcp, attr)) { return status::unimplemented; }
+
+    /* Eltwise operation check */
+    //const auto &p = attr.post_ops_;
+    //jcp.with_sum = p.find(primitive_kind::sum) != -1;
+    //const int eltwise_ind = p.find(primitive_kind::eltwise);
+    //jcp.with_eltwise = eltwise_ind != -1;
+    //if (jcp.with_eltwise) {
+    //    jcp.eltwise = p.entry_[eltwise_ind].eltwise;
+    //    if (!eltwise_injector::is_supported(isa, jcp.eltwise.alg))
+    //        return status::unimplemented;
+    //    if (dst_d.data_type() == data_type::s32) return status::unimplemented;
+    //}
+
+    // Start from undef so an unexpected `isa` can never leave the tags
+    // uninitialized.
+    format_tag_t src_tag = format_tag::undef, dst_tag = format_tag::undef,
+                 wei_tag = format_tag::undef;
+
+    switch (isa) {
+        case sve_512:
+            dst_tag = dat_tag_nCx16c;
+            src_tag = jcp.is_1stconv ? dat_tag_ncx : dat_tag_nCx16c;
+            wei_tag = pick(2 * ndims - 6 + with_groups, OIw16i16o, gOIw16i16o,
+                    OIhw16i16o, gOIhw16i16o, OIdhw16i16o, gOIdhw16i16o);
+            break;
+        case sve_256:
+            dst_tag = dat_tag_nCx8c;
+            src_tag = jcp.is_1stconv ? dat_tag_ncx : dat_tag_nCx8c;
+            wei_tag = pick(2 * ndims - 6 + with_groups, OIw8i8o, gOIw8i8o,
+                    OIhw8i8o, gOIhw8i8o, OIdhw8i8o, gOIdhw8i8o);
+            break;
+        default: return status::unimplemented;
+    }
+
+    if (src_md.format_kind == format_kind::any)
+        CHECK(memory_desc_init_by_tag(src_md, src_tag));
+    else if (curr_src_tag != src_tag)
+        return status::unimplemented;
+    jcp.src_tag = src_tag;
+
+    if (dst_md.format_kind == format_kind::any)
+        CHECK(memory_desc_init_by_tag(dst_md, dst_tag));
+    else if (curr_dst_tag != dst_tag)
+        return status::unimplemented;
+    jcp.dst_tag = dst_tag;
+
+    jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
+    if (jcp.with_bias) {
+        if (bias_d.format_kind() == format_kind::any)
+            CHECK(memory_desc_init_by_tag(bias_md, x));
+    }
+
+    if (mayiuse(isa) && src_d.data_type() == data_type::f32
+            && weights_d.data_type() == data_type::f32
+            && dst_d.data_type() == data_type::f32) {
+        jcp.ver = kdnn_ver_fma;
+        jcp.typesize_in = typesize;
+        jcp.typesize_out = typesize;
+
+        if (jcp.is_1stconv) {
+            switch (isa) {
+                case sve_512:
+                    wei_tag = with_groups
+                            ? pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)
+                            : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o);
+                    break;
+                case sve_256:
+                    wei_tag = with_groups
+                            ? pick(ndims - 3, gOwi8o, gOhwi8o, gOdhwi8o)
+                            : pick(ndims - 3, Owi8o, Ohwi8o, Odhwi8o);
+                    break;
+                default: break;
+            }
+        }
+    } else {
+        return status::unimplemented;
+    }
+
+    if (init_tag(jcp.wei_tag, weights_md, weights_d, wei_tag)
+            != status::success)
+        return status::unimplemented;
+
+    jcp.ur_w = nstl::min(jcp.ow, regs); // ur_w is min(output width, regs=28)
+
+    const int n_oi = jcp.ow / jcp.ur_w;
+    const int r_pad = calculate_end_padding(
+            jcp.l_pad, jcp.ur_w * n_oi, jcp.iw, jcp.stride_w, ext_kw);
+
+    /* Grouped channel offset to support 'non-blocked data' format for
+     * convolution sizes with '(input_channel / ngroups) < simd' */
+    jcp.nonblk_group_off
+            = (jcp.ngroups > 1 && one_of(jcp.src_tag, ncw, nchw, ncdhw))
+            ? jcp.ic
+            : 1;
+
+    jcp.nb_ic = div_up(jcp.ic, jcp.ic_block);
+    jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);
+    jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
+
+    jcp.ow_block = jcp.ow;
+
+    auto get_thr_eff = [=](int nb_oc_blocking, int ow_block) {
+        int nb_ow = div_up(jcp.ow, ow_block);
+        int nb_oc_chunks = div_up(jcp.nb_oc, nb_oc_blocking);
+        int work_amount = jcp.mb * jcp.oh * nb_oc_chunks * nb_ow;
+        float disbalance = (float)jcp.ow / rnd_up(jcp.ow, ow_block);
+        float thr_eff = disbalance * (float)work_amount
+                / rnd_up(work_amount, jcp.nthr);
+        return thr_eff;
+    };
+
+    auto get_ow_block = [=](int nb_oc_blocking, int ur_w, float &eff) {
+        int res_ow_block = jcp.ow;
+        eff = get_thr_eff(nb_oc_blocking, res_ow_block);
+
+        return res_ow_block;
+    };
+
+    if (jcp.ver == kdnn_ver_fma && mayiuse(isa)) {
+        // These conditions define a set of shapes with 'ow = 1' which
+        // have a very limited optimization space for performance. Try
+        // to optimize by using a larger 'nb_oc_blocking' size.
+        bool expl_bcast_condition
+                = everyone_is(1, jcp.ngroups, jcp.mb, jcp.stride_h, jcp.ow,
+                          jcp.stride_w, jcp.id, jcp.od, jcp.kd, jcp.stride_d)
+                && jcp.iw == jcp.kw && jcp.nb_oc > 1
+                && everyone_is(0, jcp.l_pad, jcp.r_pad, jcp.dilate_w, jcp.f_pad,
+                        jcp.back_pad, jcp.dilate_d)
+                && jcp.oh >= 60 && jcp.kh >= 3;
+
+        if (jcp.mb == 1) {
+            unsigned int inp_size = jcp.mb * div_up(jcp.ih, jcp.stride_h)
+                    * div_up(jcp.iw, jcp.stride_w) * jcp.ic;
+            unsigned int wei_size = jcp.ic * jcp.oc * jcp.kh * jcp.kw;
+
+            // Estimate whether we need to limit the number of threads
+            // and calculate this number. Includes some heuristic.
+            int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
+            int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh;
+            int job_size_min = work_amount / nthreads;
+            int job_size_max = div_up(work_amount, nthreads);
+            int ch_max = rnd_up(jcp.oh, job_size_max);
+            int ch_min = (job_size_min == 0) ? jcp.oh
+                                             : rnd_up(jcp.oh, job_size_min);
+            bool not_aligned_max = ch_max % jcp.oh != 0 && ch_max / jcp.oh < 2
+                    && (jcp.oh != 8 || ch_max / jcp.oh > 1);
+            bool not_aligned_min = ch_min % jcp.oh != 0 && ch_min / jcp.oh < 2
+                    && (jcp.oh != 8 || ch_min / jcp.oh > 1);
+            bool eligible_case = (jcp.stride_h == 1 && jcp.stride_w == 1)
+                    || nthreads > oc_chunks;
+            if (jcp.loop_order == kdnn_loop_cgn && oc_chunks > 1 && nthreads > 1
+                    && wei_size / inp_size > 24
+                    && (not_aligned_max || not_aligned_min) && eligible_case) {
+                // Try to find number of threads > nthreads / 2 such that
+                // oc_chunks is a multiple of nthreads, or nthreads is a
+                // multiple of oc_chunks. Otherwise, keep default value.
+                // TODO: implement a task-based alternative without throttling.
+                jcp.aligned_threads = jcp.nthr;
+                for (int i = jcp.nthr; i > jcp.nthr / 2; i--) {
+                    if (oc_chunks % i == 0 || i % oc_chunks == 0) {
+                        jcp.aligned_threads = i;
+                        break;
+                    }
+                }
+            }
+        }
+
+        const int max_nb_oc = 2;
+        {
+            jcp.kernel_kind = kdnn_expl_bcast;
+            jcp.nb_ic_blocking = 1;
+            if (IMPLICATION(jcp.is_1stconv, jcp.mb >= 1)
+                    || expl_bcast_condition) {
+                float best_thr_eff = 0.f;
+                int best_nb_oc_blocking = 1;
+                for (int i = nstl::min(jcp.nb_oc, max_nb_oc); i > 0; i--) {
+                    if (jcp.nb_oc % i == 0) {
+                        if (expl_bcast_condition) {
+                            best_nb_oc_blocking = i;
+                            break;
+                        } else {
+                            float thr_eff;
+                            int ur_w = nstl::min(jcp.ow, 31 / (i + 1));
+                            get_ow_block(i, ur_w, thr_eff);
+                            if (thr_eff > 1.05f * best_thr_eff) {
+                                best_nb_oc_blocking = i;
+                                best_thr_eff = thr_eff;
+                            }
+                        }
+                    }
+                }
+                jcp.nb_oc_blocking = best_nb_oc_blocking;
+                jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1));
+                if (jcp.l_pad > jcp.ur_w) {
+                    jcp.nb_oc_blocking = 1;
+                    jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1));
+                }
+                if (jcp.l_pad >= 16) { jcp.ur_w = nstl::min(jcp.l_pad, 29); }
+            }
+        }
+    }
+
+    jcp.ur_w_tail = jcp.ow % jcp.ur_w;
+
+    bool args_ok = true && jcp.l_pad <= jcp.ur_w
+            && jcp.ic <= src_d.padded_dims()[1]
+            && jcp.oc <= dst_d.padded_dims()[1]
+            && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
+            && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
+    if (!args_ok) return status::unimplemented;
+
+    int r_pad_no_tail = nstl::max(0,
+            calculate_end_padding(jcp.l_pad, jcp.ow - jcp.ur_w_tail, jcp.iw,
+                    jcp.stride_w, ext_kw));
+    if (r_pad_no_tail > jcp.ur_w) return status::unimplemented;
+
+    pick_loop_order(jcp);
+
+    jcp.nb_ic_L2 = jcp.nb_ic;
+
+    float thr_eff;
+    jcp.ow_block = get_ow_block(jcp.nb_oc_blocking, jcp.ur_w, thr_eff);
+    jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
+
+    const int L2_size = platform::get_per_core_cache_size(2) / sizeof(float);
+
+    // Source and output data needs to fit in L2,
+    // leaving some space for weights and prefetching.
+    int h_L2 = int(((0.6f * L2_size) / jcp.simd_w
+                           - nstl::min(0, jcp.kh - jcp.stride_h) * jcp.iw)
+            / (jcp.stride_h * jcp.iw + jcp.ow));
+    jcp.h_blocking = nstl::max(1, nstl::min(jcp.oh, h_L2));
+
+    if (is_data_layout_nxc) {
+        // TODO: improve L2 blocking for large IC
+        const int nb_ic_theshold_L2 = 32;
+        if (jcp.nb_ic > nb_ic_theshold_L2 && jcp.nb_ic < 2 * nb_ic_theshold_L2)
+            jcp.nb_ic_L2 = div_up(jcp.nb_ic, 2);
+        else
+            jcp.nb_ic_L2 = nstl::min(nb_ic_theshold_L2, jcp.nb_ic);
+    }
+
+    // A rough check on code size
+    // TODO: come up with a tighter bound
+    {
+        const int max_code_size = 256 * 1024; // default size of jit generator
+        int mult = 1 + (jcp.l_pad > 0) + (r_pad > 0);
+        const float max_instruction_size = 15;
+        float ur_fac
+                = (float)jcp.kw * jcp.ic_block * jcp.nb_oc_blocking * jcp.ur_w;
+        float code_size = mult * ur_fac * max_instruction_size;
+        if (code_size > max_code_size) return status::unimplemented;
+    }
+
+    return status::success;
+}
+
+/**
+ * Books a padded-bias buffer of jcp.oc elements when the kernel has a bias and
+ * jcp.oc differs from jcp.oc_without_padding.
+ */
+template <cpu_isa_t isa>
+void kdnn_jit_sve_conv_fwd_kernel<isa>::init_scratchpad(
+        memory_tracking::registrar_t &scratchpad, const kdnn_jit_conv_conf_t &jcp) {
+
+    if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
+        scratchpad.book(key_conv_padded_bias, jcp.oc, jcp.typesize_out);
+}
+
+/*struct
instantiation*/
+template struct kdnn_jit_sve_conv_fwd_kernel; // NOTE(review): the explicit template argument list (<...>) was lost in extraction - restore it before applying
+
+} // namespace aarch64
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
+
-- 
Gitee


From 3fe999333d6974410eaa7e2027f191fe98f3f71e Mon Sep 17 00:00:00 2001
From: Codersheepchen
Date: Thu, 13 Jun 2024 07:52:14 +0000
Subject: [PATCH 2/2] hpp

Signed-off-by: Codersheepchen
---
 .../aarch64/kdnn/kdnn_jit_sve_conv_kernel.hpp | 220 ++++++++++++++++++
 1 file changed, 220 insertions(+)
 create mode 100644 oneDNN-3.4/src/cpu/aarch64/kdnn/kdnn_jit_sve_conv_kernel.hpp

diff --git a/oneDNN-3.4/src/cpu/aarch64/kdnn/kdnn_jit_sve_conv_kernel.hpp b/oneDNN-3.4/src/cpu/aarch64/kdnn/kdnn_jit_sve_conv_kernel.hpp
new file mode 100644
index 0000000..8d6568b
--- /dev/null
+++ b/oneDNN-3.4/src/cpu/aarch64/kdnn/kdnn_jit_sve_conv_kernel.hpp
@@ -0,0 +1,220 @@
+#ifndef KDNN_JIT_SVE_CONV_KERNEL_HPP
+#define KDNN_JIT_SVE_CONV_KERNEL_HPP
+
+#include <cassert>
+#include <string>
+
+#include "common/c_types_map.hpp"
+#include "common/memory_tracking.hpp"
+
+#include "cpu/aarch64/kdnn/kdnn_jit_generator.hpp"
+#include "cpu/aarch64/kdnn/kdnn_jit_primitive_conf.hpp"
+#include "cpu/aarch64/kdnn/kdnn_jit_op_imm_check.hpp"
+
+#define kdnn_VL_OFS(ofs, isa) ((ofs) >> cpu_isa_traits<isa>::vlen_shift)
+
+using namespace Xbyak_aarch64;
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace aarch64 {
+
+/**
+ * JIT generator for the forward convolution kernel. `jcp` is copied from the
+ * configuration passed to the constructor; `attr_` is held by reference, so
+ * the attribute object must outlive the kernel.
+ */
+template <cpu_isa_t isa>
+struct kdnn_jit_sve_conv_fwd_kernel : public kdnn_jit_generator {
+
+    kdnn_jit_sve_conv_fwd_kernel(
+            const kdnn_jit_conv_conf_t &ajcp, const primitive_attr_t &attr)
+        : jcp(ajcp), attr_(attr) {}
+
+    DECLARE_KDNN_JIT_AUX_FUNCTIONS(kdnn_jit_sve_conv_fwd_kernel)
+
+    kdnn_jit_conv_conf_t jcp;
+    const primitive_attr_t &attr_;
+
+    static bool post_ops_ok(kdnn_jit_conv_conf_t &jcp, const primitive_attr_t &attr);
+    static status_t init_conf(kdnn_jit_conv_conf_t &jcp,
+            const convolution_desc_t &cd, memory_desc_t &src_pd,
+            memory_desc_t &weights_pd, memory_desc_t &dst_pd,
+            memory_desc_t &bias_pd, const primitive_attr_t &attr, int nthreads);
+    static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+            const kdnn_jit_conv_conf_t &jcp);
+
+private:
+    using reg64_t = const XReg;
+    enum {
+        typesize = sizeof(float),
+        ker_reg_base_idx = 28,
+    };
+
+    reg64_t param = kdnn_abi_param1;
+    reg64_t reg_inp = x1; // src base addr (2d)
+    reg64_t reg_ker = x2; // ker base addr (2d)
+    reg64_t aux_reg_ker_d = x2; // ker addr (3d)
+    reg64_t reg_out = x3; // dst base addr (2d)
+    reg64_t reg_ki = x3; // d-dim loop var? (3d)
+    reg64_t reg_owb = x5; // num of ow-block
+    reg64_t reg_out_prf = x6; // addr for prefetch
+
+    reg64_t aux_reg_inp = x7; // src addr (main loop)
+    reg64_t aux_reg_inp2 = x24; // src addr (main loop)
+    reg64_t aux_reg_inp3 = x25; // src addr (main loop)
+    reg64_t reg_out_ofs = x7; // dst addr (store_output)
+    reg64_t aux_reg_ker = x8; // ker addr (main loop)
+    reg64_t reg_channel = x9; // reduce workload
+    reg64_t reg_bias = x10; // bias addr (prepare_out)
+
+    reg64_t aux_reg_inp_d = x11; // src addr (3d)
+    reg64_t reg_oi = x11;
+
+    reg64_t reg_kh = x12; // ker h size
+    reg64_t reg_kj = x13; // ker h workload
+
+    /* Temporary registers for ARM insts */
+    reg64_t reg_tmp_addr = x14;
+    reg64_t reg_prev_bcast_addr = x15;
+    reg64_t reg_prev_wei_addr = x16;
+    reg64_t reg_tmp_imm = x17;
+
+    reg64_t reg_out_org = x27; // dst base addr (3d)
+    reg64_t reg_oi_org = x19; // base oi (3d)
+    reg64_t aux_reg_ker_d_org = x20;
+    reg64_t reg_ker_org = x21; // ker base addr (3d)
+    reg64_t reg_inp_org = x29; // src base addr (3d)
+
+    /**
+     * Emits a prefetch of `in + ofs`.
+     *
+     * Offsets whose low 8 bits are zero use PRFM, all other offsets use the
+     * SVE PRFW form; in both cases an out-of-range offset is first folded
+     * into reg_tmp_addr.
+     *
+     * @param prfop "LD" for a load prefetch, "ST" for a store prefetch
+     * @param level cache level, 1 to 3
+     * @param in    base address register
+     * @param ofs   byte offset from `in`
+     */
+    void prefetch(
+            const std::string &prfop, int level, reg64_t in, long long int ofs) {
+        bool for_load = false;
+        if (prfop == "LD") {
+            for_load = true;
+        } else if (prfop == "ST") {
+            for_load = false;
+        } else {
+            assert(!"invalid prfop");
+        }
+
+        const bool cacheline_aligned = (ofs & 0xFF) == 0;
+        if (cacheline_aligned) {
+            Prfop op = PLDL1KEEP;
+            switch (level) {
+                case 1: op = for_load ? PLDL1KEEP : PSTL1KEEP; break;
+                case 2: op = for_load ? PLDL2KEEP : PSTL2KEEP; break;
+                case 3: op = for_load ? PLDL3KEEP : PSTL3KEEP; break;
+                default: assert(!"invalid level"); break;
+            }
+
+            if ((ofs <= PRFMMAX) && (ofs >= 0)) {
+                prfm(op, ptr(in, static_cast<int32_t>(ofs)));
+            } else {
+                add_imm(reg_tmp_addr, in, ofs, reg_tmp_imm);
+                prfm(op, ptr(reg_tmp_addr));
+            }
+        } else {
+            PrfopSve op_sve = PLDL1KEEP_SVE;
+            switch (level) {
+                case 1:
+                    op_sve = for_load ? PLDL1KEEP_SVE : PSTL1KEEP_SVE;
+                    break;
+                case 2:
+                    op_sve = for_load ? PLDL2KEEP_SVE : PSTL2KEEP_SVE;
+                    break;
+                case 3:
+                    op_sve = for_load ? PLDL3KEEP_SVE : PSTL3KEEP_SVE;
+                    break;
+                default: assert(!"invalid level"); break;
+            }
+
+            if ((kdnn_VL_OFS(ofs, isa) < PRFWMAX)
+                    && (kdnn_VL_OFS(ofs, isa) >= (-1 * PRFWMAX))) {
+                prfw(op_sve, P_ALL_ONE,
+                        ptr(in, static_cast<int32_t>(kdnn_VL_OFS(ofs, isa))));
+            } else {
+                add_imm(reg_tmp_addr, in, ofs, reg_tmp_imm);
+                prfw(op_sve, P_ALL_ONE, ptr(reg_tmp_addr));
+            }
+        }
+    }
+
+    inline void prepare_output(int ur_w);
+    inline void store_output(int ur_w);
+    inline void compute_loop_fma_core(int ur_w, int pad_l, int pad_r);
+    inline void compute_loop(int ur_w, int pad_l, int pad_r);
+
+    void generate() override;
+
+    inline size_t get_output_offset(int oi, int n_oc_block) {
+        const bool is_nxc_layout = is_dst_layout_nxc();
+        size_t ow_str = is_nxc_layout ? jcp.ngroups * jcp.oc : jcp.oc_block;
+        size_t ocb_str = is_nxc_layout
+                ? jcp.oc_block
+                : (size_t)jcp.od * jcp.oh * jcp.ow * jcp.oc_block;
+
+        return jcp.typesize_out * (n_oc_block * ocb_str + oi * ow_str);
+    }
+
+    inline size_t get_input_offset(int ki, int ic, int oi, int pad_l) {
+        const bool is_nxc_layout = is_src_layout_nxc();
+        size_t iw_str = is_nxc_layout ? jcp.ngroups * jcp.ic
+                                      : (!jcp.is_1stconv ? jcp.ic_block : 1);
+        size_t ic_str = !jcp.is_1stconv || is_nxc_layout
+                ? 1
+                : (size_t)jcp.iw * jcp.ih * jcp.id;
+        size_t iw_idx = ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l;
+
+        return jcp.typesize_in * (iw_idx * iw_str + ic * ic_str);
+    }
+
+    inline int get_kernel_offset(
+            int ki, int ic, int n_oc_block, int ker_number) {
+        return jcp.typesize_in * jcp.oc_block
+                * (n_oc_block * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw
+                                * jcp.kd
+                        + (ic + ker_number) + ki * jcp.ic_block);
+    }
+
+    inline int get_ow_start(int ki, int pad_l) {
+        return nstl::max(0,
+                utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
+    }
+
+    inline int get_ow_end(int ur_w, int ki, int pad_r) {
+        return ur_w
+                - nstl::max(0,
+                        utils::div_up(
+                                pad_r - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1),
+                                jcp.stride_w));
+    }
+    inline bool is_src_layout_nxc() {
+        return utils::one_of(jcp.src_tag, format_tag::ndhwc, format_tag::nhwc,
+                format_tag::nwc);
+    }
+    inline bool is_dst_layout_nxc() {
+        return utils::one_of(jcp.dst_tag, format_tag::ndhwc, format_tag::nhwc,
+                format_tag::nwc);
+    }
+};
+
+} // namespace aarch64
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
+
+#endif
+
-- 
Gitee