From 704811fba229ed1aacb4922ea86d7c78a8839edf Mon Sep 17 00:00:00 2001 From: Tomasz Czeszun Date: Sat, 14 Dec 2024 08:10:35 -0800 Subject: [PATCH 1/2] tests: benchdnn: add conv per_dim_1 src zp coverage --- tests/benchdnn/inputs/conv/test_conv_ci | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/benchdnn/inputs/conv/test_conv_ci b/tests/benchdnn/inputs/conv/test_conv_ci index 25cdf0d5476..ab9999a7544 100644 --- a/tests/benchdnn/inputs/conv/test_conv_ci +++ b/tests/benchdnn/inputs/conv/test_conv_ci @@ -63,7 +63,7 @@ --attr-zero-points= --batch=shapes_basic --attr-post-ops= ---attr-zero-points=,src:common:2+dst:common:1,src:per_dim_1+dst:per_dim_1 +--attr-zero-points=,src:common:2+dst:common:1,src:per_dim_1+dst:per_dim_1,src:per_dim_1+dst:common:1 --batch=shapes_basic ### Signed input --dt=s8:s8:s8 @@ -77,7 +77,7 @@ --attr-zero-points= --batch=shapes_basic --attr-post-ops= ---attr-zero-points=,src:common:2+dst:common:1,src:per_dim_1+dst:per_dim_1 +--attr-zero-points=,src:common:2+dst:common:1,src:per_dim_1+dst:per_dim_1,src:per_dim_1+dst:common:1 --batch=shapes_basic # BF32 --reset From e1edbbaaf7c54a46cd9e83196dbd0c3d23cc2973 Mon Sep 17 00:00:00 2001 From: Tomasz Czeszun Date: Sat, 14 Dec 2024 07:51:20 -0800 Subject: [PATCH 2/2] x64: conv: brdgmm: enable zps per group --- src/common/primitive_attr.hpp | 1 + src/cpu/x64/brgemm/brgemm.cpp | 26 +++-- src/cpu/x64/brgemm/brgemm_types.hpp | 8 +- src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp | 120 +++++++++++++++++------ src/cpu/x64/brgemm/jit_brdgmm_kernel.hpp | 5 +- src/cpu/x64/jit_brdgmm_dw_conv.cpp | 12 ++- 6 files changed, 128 insertions(+), 44 deletions(-) diff --git a/src/common/primitive_attr.hpp b/src/common/primitive_attr.hpp index 5e1496978ed..c187d6e6cde 100644 --- a/src/common/primitive_attr.hpp +++ b/src/common/primitive_attr.hpp @@ -387,6 +387,7 @@ struct zero_points_t : public c_compatible { // arg-specific checks bool common(int arg) const { return get_mask(arg) == 0; } + bool 
per_dim_1(int arg) const { return get_mask(arg) == 2; } bool defined(int arg) const { return has_default_values(arg); } bool has_default_values(int arg) const { return is_set(arg) == false && has_default_data_type(arg); diff --git a/src/cpu/x64/brgemm/brgemm.cpp b/src/cpu/x64/brgemm/brgemm.cpp index 56d5c83a47c..347dcb6b5fa 100644 --- a/src/cpu/x64/brgemm/brgemm.cpp +++ b/src/cpu/x64/brgemm/brgemm.cpp @@ -205,6 +205,7 @@ void brgemm_kernel_execute_postops(const brgemm_kernel_t *brg_kernel, int bs, brgemm_p.first_mb_matrix_addr_off = post_ops_data.first_mb_matrix_addr_off; brgemm_p.a_zp_compensations = post_ops_data.a_zp_compensations; brgemm_p.b_zp_compensations = post_ops_data.b_zp_compensations; + brgemm_p.a_zp_values = post_ops_data.a_zp_values; brgemm_p.c_zp_values = post_ops_data.c_zp_values; brgemm_p.ptr_dst_scales = post_ops_data.dst_scales; if (dynamic_values) { @@ -457,19 +458,30 @@ status_t brgemm_desc_set_postops(brgemm_desc_t *brg, auto zero_points = attr->zero_points_; // common zero point type is supported for now - if (!zero_points.common(mem_arg)) return status::unimplemented; + const bool is_per_dim_1_bcast = zero_points.per_dim_1(mem_arg); + const bool is_common_bcast = zero_points.common(mem_arg); + if (!is_common_bcast && !is_per_dim_1_bcast) + return status::unimplemented; const bool skip_zero_point = mem_arg == DNNL_ARG_WEIGHTS && brg->skip_zp_b_compensation; - zp_type = zero_points.has_default_values(mem_arg) || skip_zero_point - ? 
brgemm_broadcast_t::none - : brgemm_broadcast_t::per_tensor; + + zp_type = brgemm_broadcast_t::none; + const bool is_any_bcast + = !(zero_points.has_default_values(mem_arg) || skip_zero_point); + if (is_any_bcast) { + if (is_common_bcast) + zp_type = brgemm_broadcast_t::per_tensor; + else if (is_per_dim_1_bcast) + zp_type = brgemm_broadcast_t::per_n; + } + return status::success; }; - init_zp_type(brg->zp_type_a, DNNL_ARG_SRC); - init_zp_type(brg->zp_type_b, DNNL_ARG_WEIGHTS); - init_zp_type(brg->zp_type_c, DNNL_ARG_DST); + CHECK(init_zp_type(brg->zp_type_a, DNNL_ARG_SRC)); + CHECK(init_zp_type(brg->zp_type_b, DNNL_ARG_WEIGHTS)); + CHECK(init_zp_type(brg->zp_type_c, DNNL_ARG_DST)); // Post-ops may use vector registers so brgemm/brdgmm blocking may need to // be updated diff --git a/src/cpu/x64/brgemm/brgemm_types.hpp b/src/cpu/x64/brgemm/brgemm_types.hpp index b22309d4331..02904815cb0 100644 --- a/src/cpu/x64/brgemm/brgemm_types.hpp +++ b/src/cpu/x64/brgemm/brgemm_types.hpp @@ -472,6 +472,7 @@ struct brgemm_kernel_params_t { const void *a_zp_compensations = nullptr; const void *b_zp_compensations = nullptr; + const void *a_zp_values = nullptr; const void *c_zp_values = nullptr; size_t skip_accm = 0; int32_t zp_a_val = 1; @@ -582,7 +583,8 @@ struct brgemm_post_ops_data_t { const void *b_zp_compensations = nullptr, const void *c_zp_values = nullptr, bool skip_accumulation = false, int32_t zp_a_val = 1, bool do_only_comp = false, - bool do_only_zp_a_val = false, const float *dst_scales = nullptr) + bool do_only_zp_a_val = false, const float *dst_scales = nullptr, + const void *a_zp_values = nullptr) : bias(bias) , scales(scales) , binary_post_ops_rhs(binary_post_ops_rhs) @@ -597,7 +599,8 @@ struct brgemm_post_ops_data_t { , zp_a_val {zp_a_val} , do_only_comp {do_only_comp} , do_only_zp_a_val {do_only_zp_a_val} - , dst_scales(dst_scales) {} + , dst_scales(dst_scales) + , a_zp_values(a_zp_values) {} const void *bias = nullptr; const float *scales = nullptr; @@ -614,6 
+617,7 @@ struct brgemm_post_ops_data_t { const bool do_only_comp = false; const bool do_only_zp_a_val = false; const float *dst_scales = nullptr; + const void *a_zp_values = nullptr; }; } // namespace x64 diff --git a/src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp b/src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp index 71a1c32492c..e9ccb5dae1c 100644 --- a/src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp +++ b/src/cpu/x64/brgemm/jit_brdgmm_kernel.cpp @@ -45,6 +45,7 @@ jit_brdgmm_kernel_base_t::jit_brdgmm_kernel_base_t( , max_vmms_(isa_num_vregs(brg.isa_impl)) , compute_dst_zp_(brg.zp_type_c != brgemm_broadcast_t::none) , compute_src_zp_(brg.zp_type_a != brgemm_broadcast_t::none) + , is_src_zp_bcast_(brg.zp_type_a == brgemm_broadcast_t::per_tensor) , compute_compensation_(compute_src_zp_ || brg.req_s8s8_compensation) , has_vpad_(brg.brgattr.max_top_vpad > 0 || brg.brgattr.max_bottom_vpad > 0) , has_bpad_(brg.brgattr.max_top_bpad > 0 || brg.brgattr.max_bottom_bpad > 0) @@ -147,7 +148,7 @@ void jit_brdgmm_kernel_base_t::read_params() { } if (compute_src_zp_) { - mov(reg_tmp, ptr[param1 + GET_OFF(zp_a_val)]); + mov(reg_tmp, ptr[param1 + GET_OFF(a_zp_values)]); mov(ptr[rsp + src_zp_value_], reg_tmp); mov(reg_tmp, ptr[param1 + GET_OFF(a_zp_compensations)]); @@ -609,6 +610,17 @@ void jit_brdgmm_kernel_base_t::maybe_transpose_interleaved_vnni_to_plain( } } +template +void jit_brdgmm_kernel_base_t::load_src_zp() { + mov(reg_src_zero_point, ptr[rsp + src_zp_value_]); + lea(reg_src_zero_point, + is_src_zp_bcast_ + ? 
ptr_b[reg_src_zero_point] + : ptr[reg_src_zero_point + reg_aux_N * sizeof(int32_t)]); + if (!is_superset(brg.isa_impl, avx512_core) && is_src_zp_bcast_) + uni_vpbroadcastd(vmm_bcast(), ptr[reg_src_zero_point]); +} + template void jit_brdgmm_kernel_base_t::compute_int8_compensation( int m_blocks, int n_blocks, bool has_n_tail) { @@ -620,12 +632,10 @@ void jit_brdgmm_kernel_base_t::compute_int8_compensation( lea(reg_s8s8_comp, ptr[reg_s8s8_comp + reg_aux_N * sizeof(int32_t)]); } if (compute_src_zp_) { - lea(reg_src_zero_point, ptr[rsp + src_zp_value_]); + load_src_zp(); mov(reg_zp_compensation, ptr[rsp + zp_compensation_]); lea(reg_zp_compensation, ptr[reg_zp_compensation + reg_aux_N * sizeof(int32_t)]); - if (!is_superset(brg.isa_impl, avx512_core)) - uni_vpbroadcastd(vmm_bcast(), ptr[reg_src_zero_point]); } for_(int v_i = 0; v_i < v_substep; ++v_i) @@ -640,16 +650,35 @@ void jit_brdgmm_kernel_base_t::compute_int8_compensation( } if (compute_src_zp_) { // zero_point: conv(src_x8, wei_s8) - src_shift_s32 * compensation_s32 - const Vmm vmm_zp = vmm_zp_comp(); - vmovups(vmm_zp, - maybe_EVEX_compress_addr(reg_zp_compensation, offset)); - if (is_superset(brg.isa_impl, avx512_core)) { - const bool src_zp_is_common = true; - vpmulld(vmm_zp, vmm_zp, - maybe_EVEX_compress_addr( - reg_src_zero_point, 0, src_zp_is_common)); + const bool is_tail + = n + 1 == n_blocks && has_n_tail && substep_simd < simd_w_; + const Vmm vmm_zp = isa_has_masks(brg.isa_impl) + ? 
maybe_mask(vmm_zp_comp(), is_tail, false) + : vmm_zp_comp(); + if (IMPLICATION(is_tail, isa_has_masks(brg.isa_impl))) { + vmovups(vmm_zp, + maybe_EVEX_compress_addr(reg_zp_compensation, offset)); + if (is_src_zp_bcast_) { + if (is_superset(brg.isa_impl, avx512_core)) + vpmulld(vmm_zp, vmm_zp, + maybe_EVEX_compress_addr( + reg_src_zero_point, 0, true)); + else + vpmulld(vmm_zp, vmm_zp, vmm_bcast()); + } else + vpmulld(vmm_zp, vmm_zp, + maybe_EVEX_compress_addr( + reg_src_zero_point, offset)); } else { - vpmulld(vmm_zp, vmm_zp, vmm_bcast()); + const int tail_size = tail_length(); + const Vmm ymm_tmp + = vmm_bcast(); // used for bcast or tail processing in avx2 + load_data(data_type::s32, vmm_zp, + ptr[reg_zp_compensation + offset], tail_size); + if (!is_src_zp_bcast_) + load_data(data_type::s32, ymm_tmp, + ptr[reg_src_zero_point + offset], tail_size); + vpmulld(vmm_zp, vmm_zp, ymm_tmp); } } for (int m = 0; m < m_blocks; m++) { @@ -795,7 +824,8 @@ void jit_brdgmm_kernel_base_t::load_b( template void jit_brdgmm_kernel_base_t::comp_dot_product( - compute_pad_kernel_t kernel_type, Vmm vmm_acc, Vmm vmmb) { + compute_pad_kernel_t kernel_type, Vmm vmm_acc, Vmm vmmb, int n, + bool is_tail_block) { switch (kernel_type) { case compute_pad_kernel_t::s8s8_kernel: vpdpbusd(vmm_acc, vmm_shift(), vmmb, @@ -803,16 +833,39 @@ void jit_brdgmm_kernel_base_t::comp_dot_product( ? Xbyak::EvexEncoding : Xbyak::VexEncoding); break; - case compute_pad_kernel_t::zero_point_kernel: - if (is_superset(brg.isa_impl, avx512_core)) { - vpmulld(vmm_zp_comp(), vmmb, - maybe_EVEX_compress_addr(reg_src_zero_point, 0, true)); + case compute_pad_kernel_t::zero_point_kernel: { + const Vmm vmm_zp = isa_has_masks(brg.isa_impl) + ? 
maybe_mask(vmm_zp_comp(), is_tail_block, false) + : vmm_zp_comp(); + const size_t offset = comp_offset(n); + if (IMPLICATION(is_tail_block, isa_has_masks(brg.isa_impl))) { + if (is_src_zp_bcast_) { + if (is_superset(brg.isa_impl, avx512_core)) + vpmulld(vmm_zp, vmmb, + maybe_EVEX_compress_addr( + reg_src_zero_point, 0, true)); + else + vpmulld(vmm_zp, vmmb, vmm_bcast()); + } else { + const Xbyak::Address src_zp_addr = maybe_EVEX_compress_addr( + reg_src_zero_point, offset); + if (is_fast_vnni_int8()) { + vmovups(vmm_zp, src_zp_addr); + vpermd(vmm_zp, vmm_permute(), vmm_zp); + vpmulld(vmm_zp, vmmb, vmm_zp); + } else + vpmulld(vmm_zp, vmmb, src_zp_addr); + } } else { - uni_vpbroadcastd(vmm_bcast(), ptr[reg_src_zero_point]); - vpmulld(vmm_zp_comp(), vmmb, vmm_bcast()); + const Vmm ymm_tmp + = vmm_bcast(); // used for bcast or tail processing in avx2 + if (!is_src_zp_bcast_) + load_data(data_type::s32, ymm_tmp, + ptr[reg_src_zero_point + offset], tail_length()); + vpmulld(vmm_zp, vmmb, ymm_tmp); } vpaddd(vmm_acc, vmm_acc, vmm_zp_comp()); - break; + } break; default: assert(!"unsupported comp_kernel type"); } } @@ -853,21 +906,25 @@ void jit_brdgmm_kernel_base_t::pad_comp_kernel( for (int pad_i = max_m_unroll; pad_i > 0; --pad_i) { L(jmp_table_labels[pad_i]); - if (is_zero_point_kernel) - lea(reg_src_zero_point, ptr[rsp + src_zp_value_]); + if (is_zero_point_kernel) load_src_zp(); if (pad_i > m_blocks) continue; const int m_i = get_mi(pad_i); int p_b_i = 0; for (int n_i = 0; n_i < n_blocks; ++n_i, ++p_b_i) { - if (get_substep_simd(n_i, 0, has_tail) <= 0) continue; + const int substep_simd = get_substep_simd(n_i, 0, has_tail); + if (substep_simd <= 0) continue; const Vmm vmm_acc = accm(m_blocks, n_blocks, m_i, n_i, 0); + const bool is_tail_block + = n_i + 1 == n_blocks && has_tail && substep_simd < simd_w_; if (p_b_i < n_preload_b_vmms) { - comp_dot_product(kernel_type, vmm_acc, vmm_b(p_b_i)); + comp_dot_product( + kernel_type, vmm_acc, vmm_b(p_b_i), n_i, is_tail_block); 
} else { // preloaded vmm_b not available const Vmm vmm_wei = vmm_b(max_bvmms - 1); load_b(vmm_wei, n_i, 0, has_tail, load_broadcast_wei); - comp_dot_product(kernel_type, vmm_acc, vmm_wei); + comp_dot_product( + kernel_type, vmm_acc, vmm_wei, n_i, is_tail_block); } } } @@ -885,8 +942,7 @@ void jit_brdgmm_kernel_base_t::batch_pad_kernel( auto kernel_body = [&](compute_pad_kernel_t kernel_type) { const bool is_zero_point_kernel = kernel_type == compute_pad_kernel_t::zero_point_kernel; - if (is_zero_point_kernel) - lea(reg_src_zero_point, ptr[rsp + src_zp_value_]); + if (is_zero_point_kernel) load_src_zp(); for (int nb_i = 0; nb_i < n_blocks; nb_i += max_bvmms) { const int n_e = nstl::min(nb_i + max_bvmms, n_blocks) - nb_i; for (int i = 0; i < n_e; ++i) { @@ -898,9 +954,13 @@ void jit_brdgmm_kernel_base_t::batch_pad_kernel( for_(int m_i = 0; m_i < m_blocks; ++m_i) for (int i = 0; i < n_e; ++i) { const int n_i = nb_i + i; - if (get_substep_simd(n_i, 0, has_tail) <= 0) continue; + const int substep_simd = get_substep_simd(n_i, 0, has_tail); + if (substep_simd <= 0) continue; const Vmm vmm_acc = accm(m_blocks, n_blocks, m_i, n_i, 0); - comp_dot_product(kernel_type, vmm_acc, vmm_b(i)); + const bool is_tail_block + = n_i + 1 == n_blocks && has_tail && substep_simd < simd_w_; + comp_dot_product( + kernel_type, vmm_acc, vmm_b(i), n_i, is_tail_block); } } }; diff --git a/src/cpu/x64/brgemm/jit_brdgmm_kernel.hpp b/src/cpu/x64/brgemm/jit_brdgmm_kernel.hpp index e3d6138dd5e..49d4771a9b8 100644 --- a/src/cpu/x64/brgemm/jit_brdgmm_kernel.hpp +++ b/src/cpu/x64/brgemm/jit_brdgmm_kernel.hpp @@ -230,6 +230,7 @@ struct jit_brdgmm_kernel_base_t : public jit_generator { const int simd_w_; const int max_vmms_; const bool compute_dst_zp_, compute_src_zp_; + const bool is_src_zp_bcast_; const bool compute_compensation_; // code-path for either s8s8 or src_zp const bool has_vpad_; // vertical padding w.r.t. 
M dimension const bool has_bpad_; // batch pad is computed for the overlap between the @@ -341,7 +342,8 @@ struct jit_brdgmm_kernel_base_t : public jit_generator { void load_b( Vmm vmmb, int n_i, int v_i, bool has_n_tail, bool wei_zp = false); void comp_dot_product(compute_pad_kernel_t kernel_type, Vmm vmm_acc, - Vmm vmmb); // int8 compensation dot_product (zp and s8s8) + Vmm vmmb, int n, + bool is_tail_block); // int8 compensation dot_product (zp and s8s8) void pad_comp_kernel(compute_pad_kernel_t kernel_type, int m_blocks, int n_blocks, int padding, const Xbyak::Reg64 reg_pad, const std::function &get_mi, bool has_tail = false); @@ -360,6 +362,7 @@ struct jit_brdgmm_kernel_base_t : public jit_generator { void apply_post_ops(int m_blocks, int n_blocks, bool has_n_tail); void maybe_transpose_interleaved_vnni_to_plain( int m_blocks, int n_blocks, bool has_n_tail); + void load_src_zp(); void compute_int8_compensation(int m_blocks, int n_blocks, bool has_n_tail); void store_accumulators(int m_blocks, int n_blocks, bool has_n_tail); void store_accumulators_without_post_ops( diff --git a/src/cpu/x64/jit_brdgmm_dw_conv.cpp b/src/cpu/x64/jit_brdgmm_dw_conv.cpp index 75a9481b2a8..e391fcaadf7 100644 --- a/src/cpu/x64/jit_brdgmm_dw_conv.cpp +++ b/src/cpu/x64/jit_brdgmm_dw_conv.cpp @@ -255,7 +255,8 @@ status_t brdgmm_dw_convolution_fwd_t::pd_t::init(engine_t *engine) { const bool params_ok = IMPLICATION(has_zero_points, utils::one_of(jcp.src_dt, u8, s8)) && IMPLICATION(jcp.src_zero_point, - attr()->zero_points_.common(DNNL_ARG_SRC)) + attr()->zero_points_.common(DNNL_ARG_SRC) + || attr()->zero_points_.per_dim_1(DNNL_ARG_SRC)) && IMPLICATION(jcp.dst_zero_point, attr()->zero_points_.common(DNNL_ARG_DST)); VDISPATCH_CONV(params_ok, VERBOSE_UNSUPPORTED_ZP_CFG); @@ -583,7 +584,7 @@ status_t brdgmm_dw_convolution_fwd_t::execute(const exec_ctx_t &ctx) const { DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); - 
DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC); + DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC); DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST); const int wei_scale_mask @@ -753,8 +754,11 @@ status_t brdgmm_dw_convolution_fwd_t::execute(const exec_ctx_t &ctx) const { post_ops_data.scales = &oscales[jcp.is_oc_scale * ch]; post_ops_data.oc_logical_off = ch; post_ops_data.dst_scales = dst_scales; - post_ops_data.zp_a_val - = jcp.src_zero_point ? src_zero_point : 1; + const bool is_bcast_zp + = pd()->attr()->zero_points_.common(DNNL_ARG_SRC); + post_ops_data.a_zp_values = jcp.src_zero_point + ? src_zero_point + ch * !is_bcast_zp + : nullptr; post_ops_data.c_zp_values = jcp.dst_zero_point ? dst_zero_point : nullptr; post_ops_data.a_zp_compensations