Skip to content

Commit

Permalink
Break the dependency between qd8-f32-qc8w gemm & igemms and enable the optimal gemm assembly microkernel on vnni machines
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 727847132
  • Loading branch information
alankelly authored and xnnpack-bot committed Feb 17, 2025
1 parent e27a4b0 commit b33f8b5
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 26 deletions.
6 changes: 3 additions & 3 deletions cmake/gen/amd64_microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ SET(PROD_AMD64_ASM_MICROKERNEL_SRCS
src/bf16-f32-gemm/gen/bf16-f32-gemm-1x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/bf16-f32-gemm/gen/bf16-f32-gemm-11x32c2-minmax-asm-amd64-avx512bf16-broadcast.S
src/f32-gemm/gen/f32-gemm-1x32c2-minmax-asm-amd64-avx512f-broadcast.S
src/f32-gemm/gen/f32-gemm-5x32c2-minmax-asm-amd64-avx512f-broadcast.S)
src/f32-gemm/gen/f32-gemm-5x32c2-minmax-asm-amd64-avx512f-broadcast.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x64-minmax-asm-amd64-avx512vnni.S)

SET(NON_PROD_AMD64_ASM_MICROKERNEL_SRCS
src/bf16-f32-gemm/gen/bf16-f32-gemm-1x16c2-minmax-asm-amd64-avx512bf16-broadcast.S
Expand Down Expand Up @@ -84,7 +86,6 @@ SET(NON_PROD_AMD64_ASM_MICROKERNEL_SRCS
src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-amd64-avx512vnni.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x32-minmax-asm-amd64-avx512vnni.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x64-minmax-asm-amd64-avx512vnni.S
Expand All @@ -96,7 +97,6 @@ SET(NON_PROD_AMD64_ASM_MICROKERNEL_SRCS
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x64-minmax-asm-amd64-avx512vnni.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16-minmax-asm-amd64-avx512vnni.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x32-minmax-asm-amd64-avx512vnni.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x64-minmax-asm-amd64-avx512vnni.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16-minmax-asm-amd64-avx512vnni.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x32-minmax-asm-amd64-avx512vnni.S
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16-minmax-asm-amd64-avx512vnni.S
Expand Down
4 changes: 2 additions & 2 deletions cmake/gen/avx512vnni_microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ SET(PROD_AVX512VNNI_MICROKERNEL_SRCS
src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-8x16c8-minmax-avx512vnni-prfm.c
src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512vnni-prfm.c
src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512vnni-prfm.c
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-avx512vnni-prfm.c
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c8-minmax-avx512vnni-prfm.c
src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c8-minmax-avx512vnni-prfm.c
src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x16c8-minmax-avx512vnni-prfm.c
src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-avx512vnni-prfm.c
Expand Down Expand Up @@ -73,6 +71,7 @@ SET(NON_PROD_AVX512VNNI_MICROKERNEL_SRCS
src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512vnni.c
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512vnni-prfm.c
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512vnni.c
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-avx512vnni-prfm.c
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-avx512vnni.c
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-avx512vnni-prfm.c
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-avx512vnni.c
Expand All @@ -94,6 +93,7 @@ SET(NON_PROD_AVX512VNNI_MICROKERNEL_SRCS
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16c8-minmax-avx512vnni.c
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c4-minmax-avx512vnni-prfm.c
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c4-minmax-avx512vnni.c
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c8-minmax-avx512vnni-prfm.c
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c8-minmax-avx512vnni.c
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x16c4-minmax-avx512vnni-prfm.c
src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x16c4-minmax-avx512vnni.c
Expand Down
4 changes: 2 additions & 2 deletions gen/amd64_microkernels.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ PROD_AMD64_ASM_MICROKERNEL_SRCS = [
"src/bf16-f32-gemm/gen/bf16-f32-gemm-11x32c2-minmax-asm-amd64-avx512bf16-broadcast.S",
"src/f32-gemm/gen/f32-gemm-1x32c2-minmax-asm-amd64-avx512f-broadcast.S",
"src/f32-gemm/gen/f32-gemm-5x32c2-minmax-asm-amd64-avx512f-broadcast.S",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x64-minmax-asm-amd64-avx512vnni.S",
]

NON_PROD_AMD64_ASM_MICROKERNEL_SRCS = [
Expand Down Expand Up @@ -81,7 +83,6 @@ NON_PROD_AMD64_ASM_MICROKERNEL_SRCS = [
"src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-amd64-avx512vnni.S",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x32-minmax-asm-amd64-avx512vnni.S",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x64-minmax-asm-amd64-avx512vnni.S",
Expand All @@ -93,7 +94,6 @@ NON_PROD_AMD64_ASM_MICROKERNEL_SRCS = [
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x64-minmax-asm-amd64-avx512vnni.S",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16-minmax-asm-amd64-avx512vnni.S",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x32-minmax-asm-amd64-avx512vnni.S",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x64-minmax-asm-amd64-avx512vnni.S",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16-minmax-asm-amd64-avx512vnni.S",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x32-minmax-asm-amd64-avx512vnni.S",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16-minmax-asm-amd64-avx512vnni.S",
Expand Down
4 changes: 2 additions & 2 deletions gen/avx512vnni_microkernels.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ PROD_AVX512VNNI_MICROKERNEL_SRCS = [
"src/qd8-f32-qb4w-gemm/gen/qd8-f32-qb4w-gemm-8x16c8-minmax-avx512vnni-prfm.c",
"src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-1x16c8-minmax-avx512vnni-prfm.c",
"src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-8x16c8-minmax-avx512vnni-prfm.c",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-avx512vnni-prfm.c",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c8-minmax-avx512vnni-prfm.c",
"src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-1x16c8-minmax-avx512vnni-prfm.c",
"src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-10x16c8-minmax-avx512vnni-prfm.c",
"src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x16c8-minmax-fp32-avx512vnni-prfm.c",
Expand Down Expand Up @@ -70,6 +68,7 @@ NON_PROD_AVX512VNNI_MICROKERNEL_SRCS = [
"src/qd8-f32-qc4w-gemm/gen/qd8-f32-qc4w-gemm-14x16c8-minmax-avx512vnni.c",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512vnni-prfm.c",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c4-minmax-avx512vnni.c",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-avx512vnni-prfm.c",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16c8-minmax-avx512vnni.c",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-avx512vnni-prfm.c",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16c4-minmax-avx512vnni.c",
Expand All @@ -91,6 +90,7 @@ NON_PROD_AVX512VNNI_MICROKERNEL_SRCS = [
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16c8-minmax-avx512vnni.c",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c4-minmax-avx512vnni-prfm.c",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c4-minmax-avx512vnni.c",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c8-minmax-avx512vnni-prfm.c",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16c8-minmax-avx512vnni.c",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x16c4-minmax-avx512vnni-prfm.c",
"src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-12x16c4-minmax-avx512vnni.c",
Expand Down
86 changes: 72 additions & 14 deletions src/configs/gemm-config.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ static struct xnn_gemm_config qp8_f32_qc8w_gemm_config = {0};
static struct xnn_gemm_config qp8_f32_qb4w_gemm_config = {0};
static struct xnn_gemm_config qdu8_f32_qc4w_gemm_config = {0};
static struct xnn_gemm_config qdu8_f16_qc8w_gemm_config = {0};
static struct xnn_gemm_config qdu8_f32_qc8w_igemm_config = {0};
static struct xnn_gemm_config qdu8_f32_qc8w_gemm_config = {0};
static struct xnn_gemm_config qdu8_f32_qb4w_gemm_config = {0};
static struct xnn_gemm_config qdu8_f16_qc4w_gemm_config = {0};
Expand Down Expand Up @@ -78,6 +79,7 @@ XNN_INIT_ONCE_GUARD(qp8_f32_qb4w_gemm);
XNN_INIT_ONCE_GUARD(qdu8_f32_qc4w_gemm);
XNN_INIT_ONCE_GUARD(qdu8_f16_qc8w_gemm);
XNN_INIT_ONCE_GUARD(qdu8_f32_qc8w_gemm);
XNN_INIT_ONCE_GUARD(qdu8_f32_qc8w_igemm);
XNN_INIT_ONCE_GUARD(qdu8_f32_qb4w_gemm);
XNN_INIT_ONCE_GUARD(qdu8_f16_qc4w_gemm);
XNN_INIT_ONCE_GUARD(qs8_qc8w_gemm);
Expand Down Expand Up @@ -2907,31 +2909,23 @@ static void init_qdu8_f32_qc8w_gemm_config(void) {
#if XNN_ENABLE_AVX512VNNI
if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512vnni) {
qdu8_f32_qc8w_gemm_config.arch = xnn_arch_x86_avx512vnni;
qdu8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c8__avx512vnni_prfm);
qdu8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(10)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c8__avx512vnni_prfm);
qdu8_f32_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c8__avx512vnni_prfm);
qdu8_f32_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(10)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c8__avx512vnni_prfm);
qdu8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__asm_amd64_avx512vnni);
qdu8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x64c4__asm_amd64_avx512vnni);
qdu8_f32_qc8w_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params;
qdu8_f32_qc8w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function.
qdu8_f32_qc8w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function.
qdu8_f32_qc8w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w;
#if XNN_ENABLE_AVX256VNNI
qdu8_f32_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm;
#else
qdu8_f32_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x16c8__scalar;
#endif
qdu8_f32_qc8w_gemm_config.mr = 10;
qdu8_f32_qc8w_gemm_config.nr = 16;
qdu8_f32_qc8w_gemm_config.log2_kr = 3;
qdu8_f32_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_gemm_goi_w;
qdu8_f32_qc8w_gemm_config.mr = 5;
qdu8_f32_qc8w_gemm_config.nr = 64;
qdu8_f32_qc8w_gemm_config.log2_kr = 2;
} else
#endif
#if XNN_ENABLE_AVXVNNI
if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avxvnni) {
qdu8_f32_qc8w_gemm_config.arch = xnn_arch_x86_avxvnni;
qdu8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c8__avxvnni_prfm);
qdu8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x8c8__avxvnni_prfm);
qdu8_f32_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni_prfm);
qdu8_f32_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni_prfm);
qdu8_f32_qc8w_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params;
qdu8_f32_qc8w_gemm_config.pack_weights_and_biases = NULL; // Override the default packing function.
qdu8_f32_qc8w_gemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function.
Expand All @@ -2951,6 +2945,62 @@ static void init_qdu8_f32_qc8w_gemm_config(void) {
#endif //XNN_ARCH_X86 || XNN_ARCH_X86_64
}

// Fills in qdu8_f32_qc8w_igemm_config, the configuration that carries only the
// dqigemm (indirect GEMM) microkernels for the qdu8-f32-qc8w kernel family.
//
// Generic qs8 packing functions are installed first. On x86/x86-64 builds the
// first ISA branch that is both compiled in and reported by the hardware
// config (AVX512-VNNI, then AVX-VNNI) registers its dqigemm microkernels,
// replaces the packing hooks and sets the mr/nr/log2_kr tile parameters.
// If no branch is taken, only the generic packing functions are written here.
static void init_qdu8_f32_qc8w_igemm_config(void) {
// Use the same packing function throughout.
qdu8_f32_qc8w_igemm_config.pack_weights_and_biases =
(xnn_pack_weights_and_biases_fn)xnn_pack_qs8_weights_and_biases;
qdu8_f32_qc8w_igemm_config.packed_stride_weights_and_biases =
(xnn_packed_stride_weights_and_biases_fn)
xnn_packed_stride_qs8_weights_and_biases;
qdu8_f32_qc8w_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w;
qdu8_f32_qc8w_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_gemm_goi_w;
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
assert(hardware_config != NULL);
(void) hardware_config; // May be unused.
#if XNN_ENABLE_AVX512VNNI
// AVX512-VNNI path (non-mobile only): 1x16c8 and 10x16c8 IGEMM microkernels.
if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512vnni) {
qdu8_f32_qc8w_igemm_config.arch = xnn_arch_x86_avx512vnni;
qdu8_f32_qc8w_igemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c8__avx512vnni_prfm);
qdu8_f32_qc8w_igemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(10)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_10x16c8__avx512vnni_prfm);
qdu8_f32_qc8w_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params;
qdu8_f32_qc8w_igemm_config.pack_weights_and_biases = NULL; // Override the default packing function.
qdu8_f32_qc8w_igemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function.
qdu8_f32_qc8w_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w;
// The GOI packer is chosen at compile time: the AVX256-VNNI x16c8 variant
// when that ISA is enabled, otherwise the scalar x16c8 variant.
#if XNN_ENABLE_AVX256VNNI
qdu8_f32_qc8w_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x16c8__avx256vnni_prfm;
#else
qdu8_f32_qc8w_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x16c8__scalar;
#endif
// Tile parameters matching the 10x16c8 microkernel registered above.
qdu8_f32_qc8w_igemm_config.mr = 10;
qdu8_f32_qc8w_igemm_config.nr = 16;
qdu8_f32_qc8w_igemm_config.log2_kr = 3;
// This `else` binds to whichever statement the preprocessor keeps next: the
// AVX-VNNI `if` below, or the empty block in its #else arm.
} else
#endif
#if XNN_ENABLE_AVXVNNI
// AVX-VNNI path (non-mobile only): 1x8c8 and 5x8c8 IGEMM microkernels.
if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avxvnni) {
qdu8_f32_qc8w_igemm_config.arch = xnn_arch_x86_avxvnni;
qdu8_f32_qc8w_igemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x8c8__avxvnni_prfm);
qdu8_f32_qc8w_igemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_5x8c8__avxvnni_prfm);
qdu8_f32_qc8w_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params;
qdu8_f32_qc8w_igemm_config.pack_weights_and_biases = NULL; // Override the default packing function.
qdu8_f32_qc8w_igemm_config.packed_stride_weights_and_biases = NULL; // Override the default packing function.
qdu8_f32_qc8w_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w;
qdu8_f32_qc8w_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm;
// Tile parameters matching the 5x8c8 microkernel registered above.
qdu8_f32_qc8w_igemm_config.mr = 5;
qdu8_f32_qc8w_igemm_config.nr = 8;
qdu8_f32_qc8w_igemm_config.log2_kr = 3;
}
#else
// Empty statement block: gives a preceding dangling `else` (if any) a body
// when the AVX-VNNI branch is compiled out.
{
;
}
#endif
// Sanity-check the selected row tile against the compile-time limits.
assert(qdu8_f32_qc8w_igemm_config.mr <= XNN_MAX_MR);
assert(qdu8_f32_qc8w_igemm_config.mr <= (XNN_EXTRA_QUANTIZATION_PARAMS + 1));
#endif //XNN_ARCH_X86 || XNN_ARCH_X86_64
}

static void init_qdu8_f32_qc4w_gemm_config(void) {
// Use the same packing function throughout.
qdu8_f32_qc4w_gemm_config.pack_weights_and_biases = (xnn_pack_weights_and_biases_fn) xnn_pack_qs4_weights_and_biases;
Expand Down Expand Up @@ -4866,6 +4916,14 @@ const struct xnn_gemm_config* xnn_init_qdu8_f16_qc8w_gemm_config() {
return &qdu8_f16_qc8w_gemm_config;
}

// Accessor for the qdu8-f32-qc8w IGEMM configuration.
//
// Returns NULL when xnn_init_hardware_config() yields no hardware config;
// otherwise runs the XNN_INIT_ONCE step for this config and returns a pointer
// to the file-level qdu8_f32_qc8w_igemm_config object.
const struct xnn_gemm_config* xnn_init_qdu8_f32_qc8w_igemm_config() {
  const struct xnn_gemm_config* config = NULL;
  if (xnn_init_hardware_config() != NULL) {
    XNN_INIT_ONCE(qdu8_f32_qc8w_igemm);
    config = &qdu8_f32_qc8w_igemm_config;
  }
  return config;
}

const struct xnn_gemm_config* xnn_init_qdu8_f32_qc8w_gemm_config() {
if (xnn_init_hardware_config() == NULL) {
return NULL;
Expand Down
2 changes: 1 addition & 1 deletion src/operators/convolution-nhwc.c
Original file line number Diff line number Diff line change
Expand Up @@ -1059,7 +1059,7 @@ enum xnn_status xnn_create_convolution2d_nhwc_qdu8_f32_qc8w(
xnn_weights_cache_t weights_cache,
xnn_operator_t* convolution_op_out)
{
const struct xnn_gemm_config* gemm_config = xnn_init_qdu8_f32_qc8w_gemm_config();
const struct xnn_gemm_config* gemm_config = xnn_init_qdu8_f32_qc8w_igemm_config();
return create_convolution2d_nhwc_qx8_f32_qc8w(input_padding_top,
input_padding_right,
input_padding_bottom,
Expand Down
2 changes: 1 addition & 1 deletion src/operators/deconvolution-nhwc.c
Original file line number Diff line number Diff line change
Expand Up @@ -1047,7 +1047,7 @@ enum xnn_status xnn_create_deconvolution2d_nhwc_qdu8_f32_qc8w(
xnn_weights_cache_t weights_cache,
xnn_operator_t* deconvolution_op_out)
{
const struct xnn_gemm_config* gemm_config = xnn_init_qdu8_f32_qc8w_gemm_config();
const struct xnn_gemm_config* gemm_config = xnn_init_qdu8_f32_qc8w_igemm_config();
return create_deconvolution2d_nhwc_qx8_f32_qc8w(output_padding_top, output_padding_right,
output_padding_bottom, output_padding_left,
kernel_height, kernel_width,
Expand Down
13 changes: 12 additions & 1 deletion src/subgraph.c
Original file line number Diff line number Diff line change
Expand Up @@ -1605,7 +1605,18 @@ void xnn_subgraph_optimize_dynamic_quantization_ops(xnn_subgraph_t subgraph) {
unsigned_config = xnn_init_qdu8_f32_qc4w_gemm_config();
} else if (weights_type == xnn_weights_type_qc8w) {
original_config = xnn_init_qd8_f32_qc8w_gemm_config();
unsigned_config = xnn_init_qdu8_f32_qc8w_gemm_config();
switch (consumer_type) {
case xnn_consumer_type_batch_mat_mul:
case xnn_consumer_type_fully_connected:
unsigned_config = xnn_init_qdu8_f32_qc8w_gemm_config();
break;
case xnn_consumer_type_convolution_2d:
case xnn_consumer_type_deconvolution:
unsigned_config = xnn_init_qdu8_f32_qc8w_igemm_config();
break;
default:
XNN_UNREACHABLE;
}
} else if (weights_type == xnn_weights_type_qb4w) {
original_config = xnn_init_qd8_f32_qb4w_gemm_config();
unsigned_config = xnn_init_qdu8_f32_qb4w_gemm_config();
Expand Down
1 change: 1 addition & 0 deletions src/xnnpack/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qdu8_f16_qc8w_gemm_config();
XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qdu8_f32_qc8w_gemm_config();
XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qdu8_f32_qb4w_gemm_config();
XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qdu8_f16_qc4w_gemm_config();
XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qdu8_f32_qc8w_igemm_config();
XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qd8_f16_qc8w_gemm_config();
XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qs8_qc8w_gemm_config();
XNN_INTERNAL const struct xnn_gemm_config* xnn_init_qu8_gemm_config();
Expand Down

0 comments on commit b33f8b5

Please sign in to comment.