diff --git a/bench/f32-gemm-minmax.cc b/bench/f32-gemm-minmax.cc index 1ea1fd46474..ee57530ed44 100644 --- a/bench/f32-gemm-minmax.cc +++ b/bench/f32-gemm-minmax.cc @@ -1583,6 +1583,182 @@ } BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_1x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_1x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/1, /*nr=*/16, /*kr=*/2, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_1x16c2__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_2x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_2x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/2, /*nr=*/16, /*kr=*/2, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_2x16c2__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_3x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_3x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/3, /*nr=*/16, /*kr=*/2, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_3x16c2__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_4x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_4x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/4, /*nr=*/16, /*kr=*/2, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x16c2__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_5x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_5x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/5, /*nr=*/16, /*kr=*/2, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x16c2__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_6x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_6x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/6, /*nr=*/16, /*kr=*/2, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_6x16c2__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_7x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_7x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/7, /*nr=*/16, /*kr=*/2, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_7x16c2__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_8x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_8x16c2__asm_amd64_avx512f_broadcast, + 
xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/8, /*nr=*/16, /*kr=*/2, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_8x16c2__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_9x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_9x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/9, /*nr=*/16, /*kr=*/2, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_9x16c2__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_10x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_10x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/10, /*nr=*/16, /*kr=*/2, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_10x16c2__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_11x16c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_11x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/11, /*nr=*/16, /*kr=*/2, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_11x16c2__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/1, /*nr=*/32, /*kr=*/2, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_2x32c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_2x32c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/2, /*nr=*/32, /*kr=*/2, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_2x32c2__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_3x32c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_3x32c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/3, /*nr=*/32, /*kr=*/2, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_3x32c2__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_4x32c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_4x32c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w, + /*mr=*/4, /*nr=*/32, /*kr=*/2, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x32c2__asm_amd64_avx512f_broadcast) + + static void f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast(benchmark::State& state, const char* net) { + GEMMBenchmark(state, + xnn_f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + 
xnn_pack_f32_gemm_goi_w, + /*mr=*/5, /*nr=*/32, /*kr=*/2, /*sr=*/1, + benchmark::utils::CheckAVX512F); + } + + BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast) #endif // XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY diff --git a/cmake/gen/amd64_microkernels.cmake b/cmake/gen/amd64_microkernels.cmake index 61d20574ef8..909b2fbf932 100644 --- a/cmake/gen/amd64_microkernels.cmake +++ b/cmake/gen/amd64_microkernels.cmake @@ -11,7 +11,9 @@ SET(PROD_AMD64_ASM_MICROKERNEL_SRCS src/bf16-f32-gemm/gen/bf16-f32-gemm-1x32c2-minmax-asm-amd64-avx512bf16-broadcast.S - src/bf16-f32-gemm/gen/bf16-f32-gemm-7x32c2-minmax-asm-amd64-avx512bf16-broadcast.S) + src/bf16-f32-gemm/gen/bf16-f32-gemm-7x32c2-minmax-asm-amd64-avx512bf16-broadcast.S + src/f32-gemm/gen/f32-gemm-1x32c2-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-5x32c2-minmax-asm-amd64-avx512f-broadcast.S) SET(NON_PROD_AMD64_ASM_MICROKERNEL_SRCS src/bf16-f32-gemm/gen/bf16-f32-gemm-1x16c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -40,31 +42,45 @@ SET(NON_PROD_AMD64_ASM_MICROKERNEL_SRCS src/bf16-f32-gemm/gen/bf16-f32-gemm-11x16c2-minmax-asm-amd64-avx512bf16-broadcast.S src/bf16-f32-gemm/gen/bf16-f32-gemm-11x32c2-minmax-asm-amd64-avx512bf16-broadcast.S src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-1x16c2-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-2x16c2-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-2x32c2-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-3x16c2-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-3x32c2-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-4x16c2-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-4x32c2-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-5x16c2-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-6x16c2-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-7x16c2-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-8x16c2-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S 
+ src/f32-gemm/gen/f32-gemm-9x16c2-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-10x16c2-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S + src/f32-gemm/gen/f32-gemm-11x16c2-minmax-asm-amd64-avx512f-broadcast.S src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S diff --git a/gemm_compiler/avx512bf16_template.py b/gemm_compiler/avx512bf16_template.py index b5d7415d10e..fb787573ae8 100644 --- a/gemm_compiler/avx512bf16_template.py +++ b/gemm_compiler/avx512bf16_template.py @@ -12,7 +12,7 @@ class Avx512Bf16(isa.Avx512F): def __init__(self): - pass # Empty constructor + self.c = 2 def isa(self): return 'avx512bf16' @@ -33,27 +33,28 @@ def compute_asm(self): def function_name(self, M, N, isa): return f'xnn_bf16_f32_gemm_minmax_ukernel_{M}x{N}c2__asm_amd64_{isa}_broadcast' + def init_accumulators(self, M, N): + asm_string = super().init_accumulators(M, N) + asm_string += """ + # Are there at least 4 bytes? + cmp rdx, 4 + js inner_loop_tail\n""" + + return asm_string + def outer_loop_prepare(self, M, N): k_register = self.k_register() kc_register = self.kc_register() offset = M * 16 + self.c_ptr_stack_offset() + kmask = self.k_mask() asm_string = f""" # Copy k and flip bit. mov {k_register}, rdx and {k_register}, 0x2 - and {kc_register}, 0xFFFFFFFFFFFFFFFD + and {kc_register}, {kmask} mov [rsp + {offset}], {k_register}\n""" return asm_string - def init_accumulators(self, M, N): - asm_string = super().init_accumulators(M, N) - asm_string += """ - # Are there at least 4 bytes? 
- cmp rdx, 4 - js inner_loop_tail\n""" - - return asm_string - def inner_loop_tail(self, M, N): k_register = self.k_register() nc_register = self.nc_register() @@ -75,3 +76,9 @@ def inner_loop_tail(self, M, N): else: asm_string += self.inner_loop_small_M_N(M=M, N=N, tail=True) return asm_string + + def element_size(self): + return 2 + + def k_mask(self): + return "0xFFFFFFFFFFFFFFFD" diff --git a/gemm_compiler/avx512f_template.py b/gemm_compiler/avx512f_template.py index 33e88a5b1b7..f3fa818f8b3 100644 --- a/gemm_compiler/avx512f_template.py +++ b/gemm_compiler/avx512f_template.py @@ -14,7 +14,7 @@ class Avx512F(isa.Fma3): def __init__(self): - pass # Empty constructor + self.c = 1 def isa(self): return 'avx512f' @@ -31,7 +31,7 @@ def a_registers(self, idx): return registers[idx] def w_registers(self): - return ['zmm10', 'zmm11', 'zmm12', 'zmm13'] + return ['zmm7', 'zmm8', 'zmm9', 'zmm10'] def n_step(self): return 16 @@ -48,9 +48,6 @@ def compute_asm(self): } return c_asm - def outer_loop_prepare(self, M, N): - return '' - def inner_loop_spill_gp(self, M, N, tail=False): N_COUNT = N // self.n_step() # weights @@ -91,6 +88,7 @@ def inner_loop_spill_gp(self, M, N, tail=False): W=self.w_registers()[nr], A=self.a_registers(0), ACC=self.acc_registers()[M * nr + mr], + mask=self.mask() ) return asm_string @@ -134,6 +132,7 @@ def inner_loop_small_M_N(self, M, N, tail=False): W=self.w_registers()[nr], A=self.a_registers(mr), ACC=self.acc_registers()[M * nr + mr], + mask=self.mask() ) return asm_string @@ -301,3 +300,180 @@ def stack_size(self, M): size = M * 16 + 64 # round up to multiple of 64. return math.ceil(size / 64) * 64 + +class Avx512FC(Avx512F): + def __init__(self, c): + self.c = c + + def input_asm(self): + in_asm = { + 'loop': [ + 'vbroadcastsd {AM}, QWORD PTR [{AM_ptr} + {a_offset}]\n', + ] + } + return in_asm + + def pre_header(self): + return '''.PERMUTATION: + .long 0 + .long 2 + .long 4 + .long 6 + .long 8 + .long 10 + .long 12 + .long 14 + .long 16 + .long 18 + .long 20 + .long 22 + .long 24 + .long 26 + .long 28 + .long 30 + ''' + + def function_name(self, M, N, isa): + c = self.c + return f'xnn_f32_gemm_minmax_ukernel_{M}x{N}c{c}__asm_amd64_{isa}_broadcast' + + def dequantize(self, M, N, W): + shift_add = '''vpsrlq {tmp}, z{acc}, 32 + vaddps z{acc}, z{acc}, {tmp}\n''' + asm_string = '' + accumulators = self.acc_registers() + for nr in range(0, N * 2): + for mr in range(0, M): + asm_string += shift_add.format(acc=accumulators[M * nr + mr], tmp=self.w_registers()[0]) + perm_reg = self.w_registers()[0] + asm_string += f'vmovups {perm_reg}, zmmword ptr [rip + .PERMUTATION]\n' + perm = 'vpermt2ps z{acc0}, {perm_reg}, z{acc1}\n' + for nr in range(0, N): + for mr in range(0, M): + asm_string += perm.format(perm_reg=perm_reg, acc0=accumulators[2 * M * nr + mr], acc1=accumulators[2 * M * nr + M + mr]) + return asm_string + + def inner_loop_increment(self): + return self.c * self.element_size() + + def k_mask(self): + return "0xFFFFFFFFFFFFFFFB" + + def init_accumulators(self, M, N): + W = self.w_ptr_register() + accumulators = self.acc_registers() + bias_registers = self.bias_registers() + bias = 'vmovaps {ACC}, [{W} + {offset}]\n' + asm_string = '' + for nr in range(0, N): + asm_string += bias.format( + W=W, ACC=bias_registers[nr], offset=self.register_bytes() * nr + ) + + c = self.c * self.element_size() + asm_string += '# Interleave with zeros.\n' + unpack_lo = 'vpmovzxdq z{acc0}, y{acc1}\n' + unpack_hi = '''vextracti64x4 y{acc1}, z{acc1}, 1 + vpmovzxdq z{acc0}, y{acc1} + ''' + for 
nr in range(0, N): + asm_string += unpack_lo.format( + acc0=accumulators[2 * M * nr], + acc1=bias_registers[nr][1:], + ) + asm_string += unpack_hi.format( + acc0=accumulators[2 * M * nr + M], + acc1=bias_registers[nr][1:], + ) + for nr in range(0, N * 2): + for mr in range(1, M): + asm_string += self.copy_simd_register( + prefix=self.prefix(), + src=accumulators[M * nr], + dst=accumulators[M * nr + mr], + ) + asm_string += self.increment_ptr(ptr=W, step=self.register_bytes() * N) + asm_string += f""" + # Are there at least {c} bytes? + cmp rdx, {c} + js inner_loop_tail\n""" + + return asm_string + + def inner_loop(self, M, N): + return super().inner_loop(M, N * 2) + + def inner_loop_tail(self, M, N): + k_register = self.k_register() + nc_register = self.nc_register() + offset = M * 16 + self.c_ptr_stack_offset() + nc_offset = offset + 8 + asm_string = f""" + # Store nc_register. + mov [rsp + {nc_offset}], {nc_register} + # Load odd k bit. + mov {nc_register}, [rsp + {offset}] + # Check if channels are odd. + test {nc_register}, {nc_register} + mov {nc_register}, [rsp + {nc_offset}] + jz inner_loop_end + + inner_loop_tail:\n""" + if M > self.max_M_before_spilling(): + asm_string += self.inner_loop_spill_gp(M=M, N=N, tail=True) + else: + asm_string += self.inner_loop_small_M_N(M=M, N=N, tail=True) + return asm_string + + def compute_asm(self): + c_asm = { + 'loop': ['vfmadd231ps z{ACC}, {A}, {W}\n'], + 'loop_tail': ['vfmadd231ps z{ACC}{{{mask}}}, {A}, {W}\n'], + } + return c_asm + + def outer_loop_prepare(self, M, N): + k_register = self.k_register() + kc_register = self.kc_register() + offset = M * 16 + self.c_ptr_stack_offset() + element_size = self.element_size() + k_mask = self.k_mask() + mask = self.mask() + asm_string = f""" + # Copy k and flip bit. + mov {k_register}, rdx + and {k_register}, 0x{element_size} + and {kc_register}, {k_mask} + mov [rsp + {offset}], {k_register} + mov r11, 0x5555 + kmovw {mask}, r11d\n""" + return asm_string + + def mask(self): + return 'k3' + + def bias_registers(self): + return self.w_registers() + + def clamp_min(self, reg, prefix, other_reg): + min_reg = self.max_register() + return f'vminps {prefix}{reg}, {prefix}{min_reg}, {prefix}{other_reg}\n' + + def clamp(self, M, N): + ''' + Clamp output registers while handling rotation to match standard registers. 
+ ''' + num_horizontal_registers = int(N / self.n_step()) + acc_registers = self.acc_registers() + asm_string = '' + for nr in range(0, num_horizontal_registers): + for mr in range(0, M): + asm_string += self.clamp_min( + reg=acc_registers[M * nr + mr], prefix=self.prefix(), other_reg=acc_registers[M * nr + mr + nr * M], + ) + for nr in range(0, num_horizontal_registers): + for mr in range(0, M): + asm_string += self.clamp_max( + reg=acc_registers[M * nr + mr], prefix=self.prefix() + ) + return asm_string diff --git a/gemm_compiler/base_architecture.py b/gemm_compiler/base_architecture.py index 9133ef58eb5..8f56164ff6f 100644 --- a/gemm_compiler/base_architecture.py +++ b/gemm_compiler/base_architecture.py @@ -161,3 +161,19 @@ def epilogue(self, M, N, isa): def inner_loop(self, M, N): """Returns the assemebly for the microkernel's inner loop.""" raise NotImplementedError + + def clamp(self, M, N): + num_horizontal_registers = int(N / self.n_step()) + acc_registers = self.acc_registers() + asm_string = '' + for nr in range(0, num_horizontal_registers): + for mr in range(0, M): + asm_string += self.clamp_min( + reg=acc_registers[M * nr + mr], prefix=self.prefix() + ) + for nr in range(0, num_horizontal_registers): + for mr in range(0, M): + asm_string += self.clamp_max( + reg=acc_registers[M * nr + mr], prefix=self.prefix() + ) + return asm_string diff --git a/gemm_compiler/generate.py b/gemm_compiler/generate.py index 6151464b316..7183b13baa9 100644 --- a/gemm_compiler/generate.py +++ b/gemm_compiler/generate.py @@ -59,16 +59,7 @@ def generate_gemm_microkernel( ## min/max clamping asm_string += '# Min/max clamping.\n' - for nr in range(0, num_horizontal_registers): - for mr in range(0, M): - asm_string += isa.clamp_min( - reg=acc_registers[M * nr + mr], prefix=isa.prefix() - ) - for nr in range(0, num_horizontal_registers): - for mr in range(0, M): - asm_string += isa.clamp_max( - reg=acc_registers[M * nr + mr], prefix=isa.prefix() - ) + asm_string += isa.clamp(M, N) ## store asm_string += isa.store( diff --git a/gemm_compiler/generate_f32_gemm_microkernels.py b/gemm_compiler/generate_f32_gemm_microkernels.py index 65e82dbad2d..dbb8f862951 100644 --- a/gemm_compiler/generate_f32_gemm_microkernels.py +++ b/gemm_compiler/generate_f32_gemm_microkernels.py @@ -71,3 +71,28 @@ def generate_f32_gemm_microkernels(): f'f32-gemm-{mr}x8-minmax-asm-aarch64-neonfma-ld{decrement}-2.S', ), ) + + # Generate C2 variants. 
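+ # Note (descriptive comment): the Avx512FC(c=2) template below emits kernels that
+ # consume two f32 channels (8 bytes) per inner-loop iteration via vbroadcastsd,
+ # keeping even/odd partial sums in interleaved accumulators that are recombined
+ # with a shift-add and vpermt2ps before clamping and storing.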
+ for mr in range(1, 12): + generate.generate_gemm_microkernel( + M=mr, + N=16, + isa=avx512f_template.Avx512FC(c=2), + output_file=os.path.join( + output_base, + f'f32-gemm-{mr}x16c2-minmax-asm-amd64-avx512f-broadcast.S', + ), + ) + + # not enough SIMD registers to go above 5x32 + for mr in range(1, 6): + generate.generate_gemm_microkernel( + M=mr, + N=32, + isa=avx512f_template.Avx512FC(c=2), + output_file=os.path.join( + output_base, + f'f32-gemm-{mr}x32c2-minmax-asm-amd64-avx512f-broadcast.S', + ), + ) + diff --git a/gemm_compiler/x64_template.py b/gemm_compiler/x64_template.py index 9f4c1db00dc..06d5be059d1 100644 --- a/gemm_compiler/x64_template.py +++ b/gemm_compiler/x64_template.py @@ -11,6 +11,18 @@ class X64(base_architecture.BaseArchitecture): + def __init__(self): + self.c = 1 + + def element_size(self): + return 4 + + def mask(self): + return '' + + def outer_loop_prepare(self, M, N): + return '' + def astride_register(self): return 'r8' @@ -48,9 +60,9 @@ def cm_registers(self): def acc_registers(self): return [ - 'mm7', - 'mm8', - 'mm9', + 'mm11', + 'mm12', + 'mm13', 'mm14', 'mm15', 'mm16', @@ -68,10 +80,13 @@ def acc_registers(self): 'mm28', 'mm29', 'mm30', - 'mm12', - 'mm13', + 'mm9', + 'mm10', ] + def bias_registers(self): + return self.acc_registers() + def w_ptr_register(self): return 'r9' @@ -143,14 +158,22 @@ def function_name(self, M, N, isa): def params_offset(self): return 96 - def header(self, M, N, prefix, isa): - HEADER = """// Copyright 2025 Google LLC + def pre_header(self): + return '' + + def copyright(self): + return '''// Copyright 2025 Google LLC // // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. #include "xnnpack/assembly.h" +''' + def header(self, M, N, prefix, isa): + HEADER = self.copyright() + HEADER += self.pre_header() + HEADER += """ BEGIN_FUNCTION {function_name} .intel_syntax noprefix @@ -302,7 +325,8 @@ def initialize_k_register(self, reg): return f'mov {reg}, 0\n' def inner_loop_increment(self): - return 4 + return self.c * self.element_size() + def cmp_k_and_jump_if_less(self, label): kc_register = self.kc_register() diff --git a/gen/amd64_microkernels.bzl b/gen/amd64_microkernels.bzl index c92421c9b60..3cdd52cd1b7 100644 --- a/gen/amd64_microkernels.bzl +++ b/gen/amd64_microkernels.bzl @@ -8,6 +8,8 @@ Auto-generated file. Do not edit! 
PROD_AMD64_ASM_MICROKERNEL_SRCS = [ "src/bf16-f32-gemm/gen/bf16-f32-gemm-1x32c2-minmax-asm-amd64-avx512bf16-broadcast.S", "src/bf16-f32-gemm/gen/bf16-f32-gemm-7x32c2-minmax-asm-amd64-avx512bf16-broadcast.S", + "src/f32-gemm/gen/f32-gemm-1x32c2-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-5x32c2-minmax-asm-amd64-avx512f-broadcast.S", ] NON_PROD_AMD64_ASM_MICROKERNEL_SRCS = [ @@ -37,31 +39,45 @@ NON_PROD_AMD64_ASM_MICROKERNEL_SRCS = [ "src/bf16-f32-gemm/gen/bf16-f32-gemm-11x16c2-minmax-asm-amd64-avx512bf16-broadcast.S", "src/bf16-f32-gemm/gen/bf16-f32-gemm-11x32c2-minmax-asm-amd64-avx512bf16-broadcast.S", "src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-1x16c2-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-2x16c2-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-2x32c2-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-3x16c2-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-3x32c2-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-4x16c2-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-4x32c2-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-5x16c2-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-6x16c2-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-7x16c2-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-8x16c2-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-9x16c2-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-10x16c2-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S", "src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S", + "src/f32-gemm/gen/f32-gemm-11x16c2-minmax-asm-amd64-avx512f-broadcast.S", 
"src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S", "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S", "src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S", diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-10x16c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-10x16c2-minmax-asm-amd64-avx512bf16-broadcast.S index c3757784ddf..0df08405ce9 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-10x16c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-10x16c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -173,16 +173,16 @@ outer_loop: mov r8, [rsp + 160] # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 - vmovaps zmm17, zmm7 - vmovaps zmm18, zmm7 - vmovaps zmm19, zmm7 - vmovaps zmm20, zmm7 + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm18, zmm11 + vmovaps zmm19, zmm11 + vmovaps zmm20, zmm11 add r9, 64 # Are there at least 4 bytes? @@ -190,28 +190,28 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm2, zmm10 + vdpbf16ps zmm12, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm2, zmm10 + vdpbf16ps zmm13, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vdpbf16ps zmm14, zmm2, zmm10 + vdpbf16ps zmm14, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vdpbf16ps zmm15, zmm2, zmm10 + vdpbf16ps zmm15, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vdpbf16ps zmm16, zmm2, zmm10 + vdpbf16ps zmm16, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r13 + r11] - vdpbf16ps zmm17, zmm2, zmm10 + vdpbf16ps zmm17, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rbx + r11] - vdpbf16ps zmm18, zmm2, zmm10 + vdpbf16ps zmm18, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rbp + r11] - vdpbf16ps zmm19, zmm2, zmm10 + vdpbf16ps zmm19, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r8 + r11] - vdpbf16ps zmm20, zmm2, zmm10 + vdpbf16ps zmm20, zmm2, zmm7 add r11, 4 cmp rdx, r11 @@ -227,63 +227,63 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rax + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm8, zmm2, zmm10 + vdpbf16ps zmm12, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r15 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm9, zmm2, zmm10 + vdpbf16ps zmm13, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r14 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm14, zmm2, zmm10 + vdpbf16ps zmm14, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r12 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm15, zmm2, zmm10 + vdpbf16ps zmm15, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r10 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm16, zmm2, zmm10 + vdpbf16ps zmm16, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r13 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm17, zmm2, zmm10 + vdpbf16ps zmm17, zmm2, zmm7 
vbroadcastss zmm2, DWORD PTR [rbx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm18, zmm2, zmm10 + vdpbf16ps zmm18, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rbp + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm19, zmm2, zmm10 + vdpbf16ps zmm19, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r8 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm20, zmm2, zmm10 + vdpbf16ps zmm20, zmm2, zmm7 inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -291,9 +291,9 @@ inner_loop_end: vminps zmm18, zmm1, zmm18 vminps zmm19, zmm1, zmm19 vminps zmm20, zmm1, zmm20 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -318,9 +318,9 @@ inner_loop_end: cmp rsi, 16 jl tail - vmovups [rcx], zmm7 - vmovups [rax], zmm8 - vmovups [r15], zmm9 + vmovups [rcx], zmm11 + vmovups [rax], zmm12 + vmovups [r15], zmm13 vmovups [r14], zmm14 vmovups [r12], zmm15 vmovups [r10], zmm16 @@ -360,9 +360,9 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r12]{k1}, zmm15 vmovups ZMMWORD PTR [r10]{k1}, zmm16 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-10x32c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-10x32c2-minmax-asm-amd64-avx512bf16-broadcast.S index 96395d0b3fc..5752f568de9 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-10x32c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-10x32c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -173,17 +173,17 @@ outer_loop: mov r8, [rsp + 160] # Initialize accumulators with the biases. 
- vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm21, [r9 + 64] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 - vmovaps zmm17, zmm7 - vmovaps zmm18, zmm7 - vmovaps zmm19, zmm7 - vmovaps zmm20, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm18, zmm11 + vmovaps zmm19, zmm11 + vmovaps zmm20, zmm11 vmovaps zmm22, zmm21 vmovaps zmm23, zmm21 vmovaps zmm24, zmm21 @@ -200,39 +200,39 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 - vdpbf16ps zmm21, zmm2, zmm11 + vdpbf16ps zmm11, zmm2, zmm7 + vdpbf16ps zmm21, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm2, zmm10 - vdpbf16ps zmm22, zmm2, zmm11 + vdpbf16ps zmm12, zmm2, zmm7 + vdpbf16ps zmm22, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm2, zmm10 - vdpbf16ps zmm23, zmm2, zmm11 + vdpbf16ps zmm13, zmm2, zmm7 + vdpbf16ps zmm23, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vdpbf16ps zmm14, zmm2, zmm10 - vdpbf16ps zmm24, zmm2, zmm11 + vdpbf16ps zmm14, zmm2, zmm7 + vdpbf16ps zmm24, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vdpbf16ps zmm15, zmm2, zmm10 - vdpbf16ps zmm25, zmm2, zmm11 + vdpbf16ps zmm15, zmm2, zmm7 + vdpbf16ps zmm25, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vdpbf16ps zmm16, zmm2, zmm10 - vdpbf16ps zmm26, zmm2, zmm11 + vdpbf16ps zmm16, zmm2, zmm7 + vdpbf16ps zmm26, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r13 + r11] - vdpbf16ps zmm17, zmm2, zmm10 - vdpbf16ps zmm27, zmm2, zmm11 + vdpbf16ps zmm17, zmm2, zmm7 + vdpbf16ps zmm27, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rbx + r11] - vdpbf16ps zmm18, zmm2, zmm10 - vdpbf16ps zmm28, zmm2, zmm11 + vdpbf16ps zmm18, zmm2, zmm7 + vdpbf16ps zmm28, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rbp + r11] - vdpbf16ps zmm19, zmm2, zmm10 - vdpbf16ps zmm29, zmm2, zmm11 + vdpbf16ps zmm19, zmm2, zmm7 + vdpbf16ps zmm29, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r8 + r11] - vdpbf16ps zmm20, zmm2, zmm10 - vdpbf16ps zmm30, zmm2, zmm11 + vdpbf16ps zmm20, zmm2, zmm7 + vdpbf16ps zmm30, zmm2, zmm8 add r11, 4 cmp rdx, r11 @@ -248,104 +248,104 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm21, zmm2, zmm11 + vdpbf16ps zmm21, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rax + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm8, zmm2, zmm10 + vdpbf16ps zmm12, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm22, zmm2, zmm11 + vdpbf16ps zmm22, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r15 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm9, zmm2, zmm10 + vdpbf16ps zmm13, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm23, zmm2, zmm11 + vdpbf16ps zmm23, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r14 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm14, zmm2, zmm10 + vdpbf16ps zmm14, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm24, zmm2, zmm11 + vdpbf16ps zmm24, 
zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r12 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm15, zmm2, zmm10 + vdpbf16ps zmm15, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm25, zmm2, zmm11 + vdpbf16ps zmm25, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r10 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm16, zmm2, zmm10 + vdpbf16ps zmm16, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm26, zmm2, zmm11 + vdpbf16ps zmm26, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r13 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm17, zmm2, zmm10 + vdpbf16ps zmm17, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm27, zmm2, zmm11 + vdpbf16ps zmm27, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rbx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm18, zmm2, zmm10 + vdpbf16ps zmm18, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm28, zmm2, zmm11 + vdpbf16ps zmm28, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rbp + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm19, zmm2, zmm10 + vdpbf16ps zmm19, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm29, zmm2, zmm11 + vdpbf16ps zmm29, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r8 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm20, zmm2, zmm10 + vdpbf16ps zmm20, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm30, zmm2, zmm11 + vdpbf16ps zmm30, zmm2, zmm8 inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -363,9 +363,9 @@ inner_loop_end: vminps zmm28, zmm1, zmm28 vminps zmm29, zmm1, zmm29 vminps zmm30, zmm1, zmm30 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -400,11 +400,11 @@ inner_loop_end: cmp rsi, 32 jl tail - vmovups [rcx], zmm7 + vmovups [rcx], zmm11 vmovups [rcx + 64], zmm21 - vmovups [rax], zmm8 + vmovups [rax], zmm12 vmovups [rax + 64], zmm22 - vmovups [r15], zmm9 + vmovups [r15], zmm13 vmovups [r15 + 64], zmm23 vmovups [r14], zmm14 vmovups [r14 + 64], zmm24 @@ -454,11 +454,11 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 vmovups ZMMWORD PTR [rcx + 64]{k2}, zmm21 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 vmovups ZMMWORD PTR [rax + 64]{k2}, zmm22 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm23 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm24 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-11x16c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-11x16c2-minmax-asm-amd64-avx512bf16-broadcast.S index fa76aec54fc..447657c63fd 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-11x16c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-11x16c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -186,17 +186,17 @@ outer_loop: mov rdi, [rsp + 176] # Initialize accumulators with the biases. 
- vmovaps zmm7, [r9 + 0] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 - vmovaps zmm17, zmm7 - vmovaps zmm18, zmm7 - vmovaps zmm19, zmm7 - vmovaps zmm20, zmm7 - vmovaps zmm21, zmm7 + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm18, zmm11 + vmovaps zmm19, zmm11 + vmovaps zmm20, zmm11 + vmovaps zmm21, zmm11 add r9, 64 # Are there at least 4 bytes? @@ -204,30 +204,30 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm2, zmm10 + vdpbf16ps zmm12, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm2, zmm10 + vdpbf16ps zmm13, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vdpbf16ps zmm14, zmm2, zmm10 + vdpbf16ps zmm14, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vdpbf16ps zmm15, zmm2, zmm10 + vdpbf16ps zmm15, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vdpbf16ps zmm16, zmm2, zmm10 + vdpbf16ps zmm16, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r13 + r11] - vdpbf16ps zmm17, zmm2, zmm10 + vdpbf16ps zmm17, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rbx + r11] - vdpbf16ps zmm18, zmm2, zmm10 + vdpbf16ps zmm18, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rbp + r11] - vdpbf16ps zmm19, zmm2, zmm10 + vdpbf16ps zmm19, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r8 + r11] - vdpbf16ps zmm20, zmm2, zmm10 + vdpbf16ps zmm20, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rdi + r11] - vdpbf16ps zmm21, zmm2, zmm10 + vdpbf16ps zmm21, zmm2, zmm7 add r11, 4 cmp rdx, r11 @@ -243,68 +243,68 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rax + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm8, zmm2, zmm10 + vdpbf16ps zmm12, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r15 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm9, zmm2, zmm10 + vdpbf16ps zmm13, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r14 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm14, zmm2, zmm10 + vdpbf16ps zmm14, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r12 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm15, zmm2, zmm10 + vdpbf16ps zmm15, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r10 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm16, zmm2, zmm10 + vdpbf16ps zmm16, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r13 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm17, zmm2, zmm10 + vdpbf16ps zmm17, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rbx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm18, zmm2, zmm10 + vdpbf16ps zmm18, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rbp + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm19, zmm2, zmm10 + vdpbf16ps zmm19, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r8 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm20, zmm2, zmm10 + vdpbf16ps zmm20, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rdi + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm21, zmm2, zmm10 + vdpbf16ps zmm21, zmm2, zmm7 
inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -313,9 +313,9 @@ inner_loop_end: vminps zmm19, zmm1, zmm19 vminps zmm20, zmm1, zmm20 vminps zmm21, zmm1, zmm21 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -342,9 +342,9 @@ inner_loop_end: cmp rsi, 16 jl tail - vmovups [rcx], zmm7 - vmovups [rax], zmm8 - vmovups [r15], zmm9 + vmovups [rcx], zmm11 + vmovups [rax], zmm12 + vmovups [r15], zmm13 vmovups [r14], zmm14 vmovups [r12], zmm15 vmovups [r10], zmm16 @@ -387,9 +387,9 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r12]{k1}, zmm15 vmovups ZMMWORD PTR [r10]{k1}, zmm16 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-11x32c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-11x32c2-minmax-asm-amd64-avx512bf16-broadcast.S index 0e3ae1360ea..2e3aa88dc7f 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-11x32c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-11x32c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -186,18 +186,18 @@ outer_loop: mov rdi, [rsp + 176] # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm22, [r9 + 64] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 - vmovaps zmm17, zmm7 - vmovaps zmm18, zmm7 - vmovaps zmm19, zmm7 - vmovaps zmm20, zmm7 - vmovaps zmm21, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm18, zmm11 + vmovaps zmm19, zmm11 + vmovaps zmm20, zmm11 + vmovaps zmm21, zmm11 vmovaps zmm23, zmm22 vmovaps zmm24, zmm22 vmovaps zmm25, zmm22 @@ -206,8 +206,8 @@ outer_loop: vmovaps zmm28, zmm22 vmovaps zmm29, zmm22 vmovaps zmm30, zmm22 - vmovaps zmm12, zmm22 - vmovaps zmm13, zmm22 + vmovaps zmm9, zmm22 + vmovaps zmm10, zmm22 add r9, 128 # Are there at least 4 bytes? 
@@ -215,42 +215,42 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 - vdpbf16ps zmm22, zmm2, zmm11 + vdpbf16ps zmm11, zmm2, zmm7 + vdpbf16ps zmm22, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm2, zmm10 - vdpbf16ps zmm23, zmm2, zmm11 + vdpbf16ps zmm12, zmm2, zmm7 + vdpbf16ps zmm23, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm2, zmm10 - vdpbf16ps zmm24, zmm2, zmm11 + vdpbf16ps zmm13, zmm2, zmm7 + vdpbf16ps zmm24, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vdpbf16ps zmm14, zmm2, zmm10 - vdpbf16ps zmm25, zmm2, zmm11 + vdpbf16ps zmm14, zmm2, zmm7 + vdpbf16ps zmm25, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vdpbf16ps zmm15, zmm2, zmm10 - vdpbf16ps zmm26, zmm2, zmm11 + vdpbf16ps zmm15, zmm2, zmm7 + vdpbf16ps zmm26, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vdpbf16ps zmm16, zmm2, zmm10 - vdpbf16ps zmm27, zmm2, zmm11 + vdpbf16ps zmm16, zmm2, zmm7 + vdpbf16ps zmm27, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r13 + r11] - vdpbf16ps zmm17, zmm2, zmm10 - vdpbf16ps zmm28, zmm2, zmm11 + vdpbf16ps zmm17, zmm2, zmm7 + vdpbf16ps zmm28, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rbx + r11] - vdpbf16ps zmm18, zmm2, zmm10 - vdpbf16ps zmm29, zmm2, zmm11 + vdpbf16ps zmm18, zmm2, zmm7 + vdpbf16ps zmm29, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rbp + r11] - vdpbf16ps zmm19, zmm2, zmm10 - vdpbf16ps zmm30, zmm2, zmm11 + vdpbf16ps zmm19, zmm2, zmm7 + vdpbf16ps zmm30, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r8 + r11] - vdpbf16ps zmm20, zmm2, zmm10 - vdpbf16ps zmm12, zmm2, zmm11 + vdpbf16ps zmm20, zmm2, zmm7 + vdpbf16ps zmm9, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rdi + r11] - vdpbf16ps zmm21, zmm2, zmm10 - vdpbf16ps zmm13, zmm2, zmm11 + vdpbf16ps zmm21, zmm2, zmm7 + vdpbf16ps zmm10, zmm2, zmm8 add r11, 4 cmp rdx, r11 @@ -266,113 +266,113 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm22, zmm2, zmm11 + vdpbf16ps zmm22, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rax + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm8, zmm2, zmm10 + vdpbf16ps zmm12, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm23, zmm2, zmm11 + vdpbf16ps zmm23, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r15 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm9, zmm2, zmm10 + vdpbf16ps zmm13, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm24, zmm2, zmm11 + vdpbf16ps zmm24, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r14 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm14, zmm2, zmm10 + vdpbf16ps zmm14, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm25, zmm2, zmm11 + vdpbf16ps zmm25, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r12 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm15, zmm2, zmm10 + vdpbf16ps zmm15, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm26, zmm2, zmm11 + vdpbf16ps zmm26, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r10 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm16, zmm2, zmm10 + 
vdpbf16ps zmm16, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm27, zmm2, zmm11 + vdpbf16ps zmm27, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r13 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm17, zmm2, zmm10 + vdpbf16ps zmm17, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm28, zmm2, zmm11 + vdpbf16ps zmm28, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rbx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm18, zmm2, zmm10 + vdpbf16ps zmm18, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm29, zmm2, zmm11 + vdpbf16ps zmm29, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rbp + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm19, zmm2, zmm10 + vdpbf16ps zmm19, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm30, zmm2, zmm11 + vdpbf16ps zmm30, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r8 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm20, zmm2, zmm10 + vdpbf16ps zmm20, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm12, zmm2, zmm11 + vdpbf16ps zmm9, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rdi + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm21, zmm2, zmm10 + vdpbf16ps zmm21, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm13, zmm2, zmm11 + vdpbf16ps zmm10, zmm2, zmm8 inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -390,11 +390,11 @@ inner_loop_end: vminps zmm28, zmm1, zmm28 vminps zmm29, zmm1, zmm29 vminps zmm30, zmm1, zmm30 - vminps zmm12, zmm1, zmm12 - vminps zmm13, zmm1, zmm13 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vminps zmm9, zmm1, zmm9 + vminps zmm10, zmm1, zmm10 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -412,8 +412,8 @@ inner_loop_end: vmaxps zmm28, zmm0, zmm28 vmaxps zmm29, zmm0, zmm29 vmaxps zmm30, zmm0, zmm30 - vmaxps zmm12, zmm0, zmm12 - vmaxps zmm13, zmm0, zmm13 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm10, zmm0, zmm10 # Pop output pointers from the stack. 
mov rcx, [rsp + 24] @@ -432,11 +432,11 @@ inner_loop_end: cmp rsi, 32 jl tail - vmovups [rcx], zmm7 + vmovups [rcx], zmm11 vmovups [rcx + 64], zmm22 - vmovups [rax], zmm8 + vmovups [rax], zmm12 vmovups [rax + 64], zmm23 - vmovups [r15], zmm9 + vmovups [r15], zmm13 vmovups [r15 + 64], zmm24 vmovups [r14], zmm14 vmovups [r14 + 64], zmm25 @@ -451,9 +451,9 @@ inner_loop_end: vmovups [rbp], zmm19 vmovups [rbp + 64], zmm30 vmovups [r8], zmm20 - vmovups [r8 + 64], zmm12 + vmovups [r8 + 64], zmm9 vmovups [rdi], zmm21 - vmovups [rdi + 64], zmm13 + vmovups [rdi + 64], zmm10 add rcx, 128 add rax, 128 add r15, 128 @@ -490,11 +490,11 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 vmovups ZMMWORD PTR [rcx + 64]{k2}, zmm22 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 vmovups ZMMWORD PTR [rax + 64]{k2}, zmm23 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm24 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm25 @@ -509,9 +509,9 @@ tail: vmovups ZMMWORD PTR [rbp]{k1}, zmm19 vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm30 vmovups ZMMWORD PTR [r8]{k1}, zmm20 - vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm12 + vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm9 vmovups ZMMWORD PTR [rdi]{k1}, zmm21 - vmovups ZMMWORD PTR [rdi + 64]{k2}, zmm13 + vmovups ZMMWORD PTR [rdi + 64]{k2}, zmm10 return: add rsp, 256 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-1x16c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-1x16c2-minmax-asm-amd64-avx512bf16-broadcast.S index 748977c0815..c872c46d1ca 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-1x16c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-1x16c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -49,7 +49,7 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] add r9, 64 # Are there at least 4 bytes? @@ -57,10 +57,10 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 add r11, 4 cmp rdx, r11 @@ -76,23 +76,23 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vmaxps zmm7, zmm0, zmm7 + vminps zmm11, zmm1, zmm11 + vmaxps zmm11, zmm0, zmm11 # Check whether full or partial store. cmp rsi, 16 jl tail - vmovups [r10], zmm7 + vmovups [r10], zmm11 add r10, 64 sub rsi, 16 @@ -104,7 +104,7 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 return: add rsp, 128 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-1x32c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-1x32c2-minmax-asm-amd64-avx512bf16-broadcast.S index c77925b6605..d551d5bc0ea 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-1x32c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-1x32c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -49,8 +49,8 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. 
- vmovaps zmm7, [r9 + 0] - vmovaps zmm8, [r9 + 64] + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, [r9 + 64] add r9, 128 # Are there at least 4 bytes? @@ -58,12 +58,12 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 - vdpbf16ps zmm8, zmm2, zmm11 + vdpbf16ps zmm11, zmm2, zmm7 + vdpbf16ps zmm12, zmm2, zmm8 add r11, 4 cmp rdx, r11 @@ -79,31 +79,31 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm8, zmm2, zmm11 + vdpbf16ps zmm12, zmm2, zmm8 inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 # Check whether full or partial store. cmp rsi, 32 jl tail - vmovups [r10], zmm7 - vmovups [r10 + 64], zmm8 + vmovups [r10], zmm11 + vmovups [r10 + 64], zmm12 add r10, 128 sub rsi, 32 @@ -117,8 +117,8 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 - vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm8 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm12 return: add rsp, 128 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-1x64c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-1x64c2-minmax-asm-amd64-avx512bf16-broadcast.S index bb8dc420bc9..224adc86156 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-1x64c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-1x64c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -49,9 +49,9 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. 
- vmovaps zmm7, [r9 + 0] - vmovaps zmm8, [r9 + 64] - vmovaps zmm9, [r9 + 128] + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, [r9 + 64] + vmovaps zmm13, [r9 + 128] vmovaps zmm14, [r9 + 192] add r9, 256 @@ -60,16 +60,16 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] - vmovaps zmm12, [r9 + 128] - vmovaps zmm13, [r9 + 192] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] add r9, 256 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 - vdpbf16ps zmm8, zmm2, zmm11 - vdpbf16ps zmm9, zmm2, zmm12 - vdpbf16ps zmm14, zmm2, zmm13 + vdpbf16ps zmm11, zmm2, zmm7 + vdpbf16ps zmm12, zmm2, zmm8 + vdpbf16ps zmm13, zmm2, zmm9 + vdpbf16ps zmm14, zmm2, zmm10 add r11, 4 cmp rdx, r11 @@ -85,46 +85,46 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] - vmovaps zmm12, [r9 + 128] - vmovaps zmm13, [r9 + 192] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] add r9, 256 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm8, zmm2, zmm11 + vdpbf16ps zmm12, zmm2, zmm8 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm9, zmm2, zmm12 + vdpbf16ps zmm13, zmm2, zmm9 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm14, zmm2, zmm13 + vdpbf16ps zmm14, zmm2, zmm10 inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 # Check whether full or partial store. cmp rsi, 64 jl tail - vmovups [r10], zmm7 - vmovups [r10 + 64], zmm8 - vmovups [r10 + 128], zmm9 + vmovups [r10], zmm11 + vmovups [r10 + 64], zmm12 + vmovups [r10 + 128], zmm13 vmovups [r10 + 192], zmm14 add r10, 256 @@ -144,9 +144,9 @@ tail: shr r11, 16 kmovw k4, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 - vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm8 - vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm9 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm12 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm13 vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm14 return: diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-2x16c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-2x16c2-minmax-asm-amd64-avx512bf16-broadcast.S index c6ea3e4011c..af98665483f 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-2x16c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-2x16c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -58,8 +58,8 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] - vmovaps zmm8, zmm7 + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, zmm11 add r9, 64 # Are there at least 4 bytes? 
@@ -67,12 +67,12 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vbroadcastss zmm3, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm3, zmm10 + vdpbf16ps zmm12, zmm3, zmm7 add r11, 4 cmp rdx, r11 @@ -88,31 +88,31 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vbroadcastss zmm3, DWORD PTR [rax + r11] vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm8, zmm3, zmm10 + vdpbf16ps zmm12, zmm3, zmm7 inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 # Check whether full or partial store. cmp rsi, 16 jl tail - vmovups [r10], zmm7 - vmovups [r13], zmm8 + vmovups [r10], zmm11 + vmovups [r13], zmm12 add r10, 64 add r13, 64 @@ -125,8 +125,8 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 return: add rsp, 128 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-2x32c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-2x32c2-minmax-asm-amd64-avx512bf16-broadcast.S index 0a78cf58971..74545fd6490 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-2x32c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-2x32c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -58,10 +58,10 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] - vmovaps zmm9, [r9 + 64] - vmovaps zmm8, zmm7 - vmovaps zmm14, zmm9 + vmovaps zmm11, [r9 + 0] + vmovaps zmm13, [r9 + 64] + vmovaps zmm12, zmm11 + vmovaps zmm14, zmm13 add r9, 128 # Are there at least 4 bytes? @@ -69,15 +69,15 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 - vdpbf16ps zmm9, zmm2, zmm11 + vdpbf16ps zmm11, zmm2, zmm7 + vdpbf16ps zmm13, zmm2, zmm8 vbroadcastss zmm3, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm3, zmm10 - vdpbf16ps zmm14, zmm3, zmm11 + vdpbf16ps zmm12, zmm3, zmm7 + vdpbf16ps zmm14, zmm3, zmm8 add r11, 4 cmp rdx, r11 @@ -93,45 +93,45 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm9, zmm2, zmm11 + vdpbf16ps zmm13, zmm2, zmm8 vbroadcastss zmm3, DWORD PTR [rax + r11] vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm8, zmm3, zmm10 + vdpbf16ps zmm12, zmm3, zmm7 vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm14, zmm3, zmm11 + vdpbf16ps zmm14, zmm3, zmm8 inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 # Check whether full or partial store. cmp rsi, 32 jl tail - vmovups [r10], zmm7 - vmovups [r10 + 64], zmm9 - vmovups [r13], zmm8 + vmovups [r10], zmm11 + vmovups [r10 + 64], zmm13 + vmovups [r13], zmm12 vmovups [r13 + 64], zmm14 add r10, 128 add r13, 128 @@ -147,9 +147,9 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 - vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm9 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm13 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm14 return: diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-2x64c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-2x64c2-minmax-asm-amd64-avx512bf16-broadcast.S index ac7eb82528a..d0a503ee1ae 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-2x64c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-2x64c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -58,12 +58,12 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] - vmovaps zmm9, [r9 + 64] + vmovaps zmm11, [r9 + 0] + vmovaps zmm13, [r9 + 64] vmovaps zmm15, [r9 + 128] vmovaps zmm17, [r9 + 192] - vmovaps zmm8, zmm7 - vmovaps zmm14, zmm9 + vmovaps zmm12, zmm11 + vmovaps zmm14, zmm13 vmovaps zmm16, zmm15 vmovaps zmm18, zmm17 add r9, 256 @@ -73,21 +73,21 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] - vmovaps zmm12, [r9 + 128] - vmovaps zmm13, [r9 + 192] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] add r9, 256 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 - vdpbf16ps zmm9, zmm2, zmm11 - vdpbf16ps zmm15, zmm2, zmm12 - vdpbf16ps zmm17, zmm2, zmm13 + vdpbf16ps zmm11, zmm2, zmm7 + vdpbf16ps zmm13, zmm2, zmm8 + vdpbf16ps zmm15, zmm2, zmm9 + vdpbf16ps zmm17, zmm2, zmm10 vbroadcastss zmm3, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm3, zmm10 - vdpbf16ps zmm14, zmm3, zmm11 - vdpbf16ps zmm16, zmm3, zmm12 - vdpbf16ps zmm18, zmm3, zmm13 + vdpbf16ps zmm12, zmm3, zmm7 + vdpbf16ps zmm14, zmm3, zmm8 + vdpbf16ps zmm16, zmm3, zmm9 + vdpbf16ps zmm18, zmm3, zmm10 add r11, 4 cmp rdx, r11 @@ -103,58 +103,58 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] - vmovaps zmm12, [r9 + 128] - vmovaps zmm13, [r9 + 192] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] add r9, 256 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm9, zmm2, zmm11 + vdpbf16ps zmm13, zmm2, zmm8 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm15, zmm2, zmm12 + vdpbf16ps zmm15, zmm2, zmm9 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm17, zmm2, zmm13 + vdpbf16ps zmm17, zmm2, zmm10 vbroadcastss zmm3, DWORD PTR [rax + r11] vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm8, zmm3, zmm10 + 
vdpbf16ps zmm12, zmm3, zmm7 vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm14, zmm3, zmm11 + vdpbf16ps zmm14, zmm3, zmm8 vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm16, zmm3, zmm12 + vdpbf16ps zmm16, zmm3, zmm9 vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm18, zmm3, zmm13 + vdpbf16ps zmm18, zmm3, zmm10 inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 vminps zmm17, zmm1, zmm17 vminps zmm18, zmm1, zmm18 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -165,11 +165,11 @@ inner_loop_end: cmp rsi, 64 jl tail - vmovups [r10], zmm7 - vmovups [r10 + 64], zmm9 + vmovups [r10], zmm11 + vmovups [r10 + 64], zmm13 vmovups [r10 + 128], zmm15 vmovups [r10 + 192], zmm17 - vmovups [r13], zmm8 + vmovups [r13], zmm12 vmovups [r13 + 64], zmm14 vmovups [r13 + 128], zmm16 vmovups [r13 + 192], zmm18 @@ -192,11 +192,11 @@ tail: shr r11, 16 kmovw k4, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 - vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm9 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm13 vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm15 vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm17 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm14 vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm16 vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm18 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-3x16c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-3x16c2-minmax-asm-amd64-avx512bf16-broadcast.S index 32a3a59aba7..f28b2cc45ae 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-3x16c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-3x16c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -67,9 +67,9 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 add r9, 64 # Are there at least 4 bytes? @@ -77,14 +77,14 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vbroadcastss zmm3, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm3, zmm10 + vdpbf16ps zmm12, zmm3, zmm7 vbroadcastss zmm4, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm4, zmm10 + vdpbf16ps zmm13, zmm4, zmm7 add r11, 4 cmp rdx, r11 @@ -100,39 +100,39 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vbroadcastss zmm3, DWORD PTR [rax + r11] vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm8, zmm3, zmm10 + vdpbf16ps zmm12, zmm3, zmm7 vbroadcastss zmm4, DWORD PTR [r15 + r11] vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm9, zmm4, zmm10 + vdpbf16ps zmm13, zmm4, zmm7 inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 # Check whether full or partial store. cmp rsi, 16 jl tail - vmovups [r10], zmm7 - vmovups [r13], zmm8 - vmovups [rbx], zmm9 + vmovups [r10], zmm11 + vmovups [r13], zmm12 + vmovups [rbx], zmm13 add r10, 64 add r13, 64 add rbx, 64 @@ -146,9 +146,9 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 - vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 return: add rsp, 128 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-3x32c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-3x32c2-minmax-asm-amd64-avx512bf16-broadcast.S index 14c22c0bb07..998551a6a3f 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-3x32c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-3x32c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -67,10 +67,10 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm14, [r9 + 64] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 vmovaps zmm15, zmm14 vmovaps zmm16, zmm14 add r9, 128 @@ -80,18 +80,18 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 - vdpbf16ps zmm14, zmm2, zmm11 + vdpbf16ps zmm11, zmm2, zmm7 + vdpbf16ps zmm14, zmm2, zmm8 vbroadcastss zmm3, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm3, zmm10 - vdpbf16ps zmm15, zmm3, zmm11 + vdpbf16ps zmm12, zmm3, zmm7 + vdpbf16ps zmm15, zmm3, zmm8 vbroadcastss zmm4, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm4, zmm10 - vdpbf16ps zmm16, zmm4, zmm11 + vdpbf16ps zmm13, zmm4, zmm7 + vdpbf16ps zmm16, zmm4, zmm8 add r11, 4 cmp rdx, r11 @@ -107,47 +107,47 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm14, zmm2, zmm11 + vdpbf16ps zmm14, zmm2, zmm8 vbroadcastss zmm3, DWORD PTR [rax + r11] vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm8, zmm3, zmm10 + vdpbf16ps zmm12, zmm3, zmm7 vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm15, zmm3, zmm11 + vdpbf16ps zmm15, zmm3, zmm8 vbroadcastss zmm4, DWORD PTR [r15 + r11] vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm9, zmm4, zmm10 + vdpbf16ps zmm13, zmm4, zmm7 vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm16, zmm4, zmm11 + vdpbf16ps zmm16, zmm4, zmm8 inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -156,11 +156,11 @@ inner_loop_end: cmp rsi, 32 jl tail - vmovups [r10], zmm7 + vmovups [r10], zmm11 vmovups [r10 + 64], zmm14 - vmovups [r13], zmm8 + vmovups [r13], zmm12 vmovups [r13 + 64], zmm15 - vmovups [rbx], zmm9 + vmovups [rbx], zmm13 vmovups [rbx + 64], zmm16 add r10, 128 add r13, 128 @@ -177,11 +177,11 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm14 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm15 - vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm16 return: diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-3x64c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-3x64c2-minmax-asm-amd64-avx512bf16-broadcast.S index 2edf90b11ba..808204f0c7c 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-3x64c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-3x64c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -67,12 +67,12 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm14, [r9 + 64] vmovaps zmm17, [r9 + 128] vmovaps zmm20, [r9 + 192] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 vmovaps zmm15, zmm14 vmovaps zmm16, zmm14 vmovaps zmm18, zmm17 @@ -86,26 +86,26 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] - vmovaps zmm12, [r9 + 128] - vmovaps zmm13, [r9 + 192] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] add r9, 256 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 - vdpbf16ps zmm14, zmm2, zmm11 - vdpbf16ps zmm17, zmm2, zmm12 - vdpbf16ps zmm20, zmm2, zmm13 + vdpbf16ps zmm11, zmm2, zmm7 + vdpbf16ps zmm14, zmm2, zmm8 + vdpbf16ps zmm17, zmm2, zmm9 + vdpbf16ps zmm20, zmm2, zmm10 vbroadcastss zmm3, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm3, zmm10 - vdpbf16ps zmm15, zmm3, zmm11 - vdpbf16ps zmm18, zmm3, zmm12 - vdpbf16ps zmm21, zmm3, zmm13 + vdpbf16ps zmm12, zmm3, zmm7 + vdpbf16ps zmm15, zmm3, zmm8 + vdpbf16ps zmm18, zmm3, zmm9 + vdpbf16ps zmm21, zmm3, zmm10 vbroadcastss zmm4, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm4, zmm10 - vdpbf16ps zmm16, zmm4, zmm11 - vdpbf16ps zmm19, zmm4, zmm12 - vdpbf16ps zmm22, zmm4, zmm13 + vdpbf16ps zmm13, zmm4, zmm7 + vdpbf16ps zmm16, zmm4, zmm8 + vdpbf16ps zmm19, zmm4, zmm9 + vdpbf16ps zmm22, zmm4, zmm10 add r11, 4 cmp rdx, r11 @@ -121,67 +121,67 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] - vmovaps zmm12, [r9 + 128] - vmovaps zmm13, [r9 + 192] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] add r9, 256 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - 
vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm14, zmm2, zmm11 + vdpbf16ps zmm14, zmm2, zmm8 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm17, zmm2, zmm12 + vdpbf16ps zmm17, zmm2, zmm9 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm20, zmm2, zmm13 + vdpbf16ps zmm20, zmm2, zmm10 vbroadcastss zmm3, DWORD PTR [rax + r11] vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm8, zmm3, zmm10 + vdpbf16ps zmm12, zmm3, zmm7 vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm15, zmm3, zmm11 + vdpbf16ps zmm15, zmm3, zmm8 vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm18, zmm3, zmm12 + vdpbf16ps zmm18, zmm3, zmm9 vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm21, zmm3, zmm13 + vdpbf16ps zmm21, zmm3, zmm10 vbroadcastss zmm4, DWORD PTR [r15 + r11] vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm9, zmm4, zmm10 + vdpbf16ps zmm13, zmm4, zmm7 vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm16, zmm4, zmm11 + vdpbf16ps zmm16, zmm4, zmm8 vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm19, zmm4, zmm12 + vdpbf16ps zmm19, zmm4, zmm9 vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm22, zmm4, zmm13 + vdpbf16ps zmm22, zmm4, zmm10 inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -191,9 +191,9 @@ inner_loop_end: vminps zmm20, zmm1, zmm20 vminps zmm21, zmm1, zmm21 vminps zmm22, zmm1, zmm22 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -208,15 +208,15 @@ inner_loop_end: cmp rsi, 64 jl tail - vmovups [r10], zmm7 + vmovups [r10], zmm11 vmovups [r10 + 64], zmm14 vmovups [r10 + 128], zmm17 vmovups [r10 + 192], zmm20 - vmovups [r13], zmm8 + vmovups [r13], zmm12 vmovups [r13 + 64], zmm15 vmovups [r13 + 128], zmm18 vmovups [r13 + 192], zmm21 - vmovups [rbx], zmm9 + vmovups [rbx], zmm13 vmovups [rbx + 64], zmm16 vmovups [rbx + 128], zmm19 vmovups [rbx + 192], zmm22 @@ -240,15 +240,15 @@ tail: shr r11, 16 kmovw k4, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm14 vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm17 vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm20 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm15 vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm18 vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm21 - vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm16 vmovups ZMMWORD PTR [rbx + 128]{k3}, zmm19 vmovups ZMMWORD PTR [rbx + 192]{k4}, zmm22 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-4x16c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-4x16c2-minmax-asm-amd64-avx512bf16-broadcast.S index 093923931f5..1062ebaf61e 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-4x16c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-4x16c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -76,10 +76,10 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. 
- vmovaps zmm7, [r9 + 0] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 add r9, 64 # Are there at least 4 bytes? @@ -87,16 +87,16 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vbroadcastss zmm3, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm3, zmm10 + vdpbf16ps zmm12, zmm3, zmm7 vbroadcastss zmm4, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm4, zmm10 + vdpbf16ps zmm13, zmm4, zmm7 vbroadcastss zmm5, DWORD PTR [r14 + r11] - vdpbf16ps zmm14, zmm5, zmm10 + vdpbf16ps zmm14, zmm5, zmm7 add r11, 4 cmp rdx, r11 @@ -112,46 +112,46 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vbroadcastss zmm3, DWORD PTR [rax + r11] vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm8, zmm3, zmm10 + vdpbf16ps zmm12, zmm3, zmm7 vbroadcastss zmm4, DWORD PTR [r15 + r11] vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm9, zmm4, zmm10 + vdpbf16ps zmm13, zmm4, zmm7 vbroadcastss zmm5, DWORD PTR [r14 + r11] vpslld zmm5, zmm5, 16 vpsrad zmm5, zmm5, 16 - vdpbf16ps zmm14, zmm5, zmm10 + vdpbf16ps zmm14, zmm5, zmm7 inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 # Check whether full or partial store. cmp rsi, 16 jl tail - vmovups [r10], zmm7 - vmovups [r13], zmm8 - vmovups [rbx], zmm9 + vmovups [r10], zmm11 + vmovups [r13], zmm12 + vmovups [rbx], zmm13 vmovups [rbp], zmm14 add r10, 64 add r13, 64 @@ -167,9 +167,9 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 - vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 vmovups ZMMWORD PTR [rbp]{k1}, zmm14 return: diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-4x32c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-4x32c2-minmax-asm-amd64-avx512bf16-broadcast.S index 576774596fc..c343e42dfb4 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-4x32c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-4x32c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -76,11 +76,11 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. 
- vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm15, [r9 + 64] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 vmovaps zmm16, zmm15 vmovaps zmm17, zmm15 vmovaps zmm18, zmm15 @@ -91,21 +91,21 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 - vdpbf16ps zmm15, zmm2, zmm11 + vdpbf16ps zmm11, zmm2, zmm7 + vdpbf16ps zmm15, zmm2, zmm8 vbroadcastss zmm3, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm3, zmm10 - vdpbf16ps zmm16, zmm3, zmm11 + vdpbf16ps zmm12, zmm3, zmm7 + vdpbf16ps zmm16, zmm3, zmm8 vbroadcastss zmm4, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm4, zmm10 - vdpbf16ps zmm17, zmm4, zmm11 + vdpbf16ps zmm13, zmm4, zmm7 + vdpbf16ps zmm17, zmm4, zmm8 vbroadcastss zmm5, DWORD PTR [r14 + r11] - vdpbf16ps zmm14, zmm5, zmm10 - vdpbf16ps zmm18, zmm5, zmm11 + vdpbf16ps zmm14, zmm5, zmm7 + vdpbf16ps zmm18, zmm5, zmm8 add r11, 4 cmp rdx, r11 @@ -121,58 +121,58 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm15, zmm2, zmm11 + vdpbf16ps zmm15, zmm2, zmm8 vbroadcastss zmm3, DWORD PTR [rax + r11] vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm8, zmm3, zmm10 + vdpbf16ps zmm12, zmm3, zmm7 vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm16, zmm3, zmm11 + vdpbf16ps zmm16, zmm3, zmm8 vbroadcastss zmm4, DWORD PTR [r15 + r11] vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm9, zmm4, zmm10 + vdpbf16ps zmm13, zmm4, zmm7 vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm17, zmm4, zmm11 + vdpbf16ps zmm17, zmm4, zmm8 vbroadcastss zmm5, DWORD PTR [r14 + r11] vpslld zmm5, zmm5, 16 vpsrad zmm5, zmm5, 16 - vdpbf16ps zmm14, zmm5, zmm10 + vdpbf16ps zmm14, zmm5, zmm7 vpslld zmm5, zmm5, 16 vpsrad zmm5, zmm5, 16 - vdpbf16ps zmm18, zmm5, zmm11 + vdpbf16ps zmm18, zmm5, zmm8 inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 vminps zmm17, zmm1, zmm17 vminps zmm18, zmm1, zmm18 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -183,11 +183,11 @@ inner_loop_end: cmp rsi, 32 jl tail - vmovups [r10], zmm7 + vmovups [r10], zmm11 vmovups [r10 + 64], zmm15 - vmovups [r13], zmm8 + vmovups [r13], zmm12 vmovups [r13 + 64], zmm16 - vmovups [rbx], zmm9 + vmovups [rbx], zmm13 vmovups [rbx + 64], zmm17 vmovups [rbp], zmm14 vmovups [rbp + 64], zmm18 @@ -207,11 +207,11 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm15 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm16 - vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm17 vmovups ZMMWORD PTR [rbp]{k1}, zmm14 vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm18 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-4x64c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-4x64c2-minmax-asm-amd64-avx512bf16-broadcast.S index edbfe156564..27ac260f213 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-4x64c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-4x64c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -76,13 +76,13 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. 
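(Editorial note on the tail path shown above: the sequence `mov r11, -1` / `shlx r11, r11, rsi` / `not r11` leaves a value whose low `nc` bits are set, where `nc` is the number of columns remaining, and `kmovw` then takes 16 bits per mask register for the masked stores. A small C sketch of that computation for the 32-column kernels, with illustrative names:)

    #include <stdint.h>

    /* nc = remaining columns, 0 < nc < 32. k1 covers columns 0..15 stored at
     * [dst], k2 covers columns 16..31 stored at [dst + 64]. The 64-column
     * kernels continue shifting to derive k3 and k4 in the same way. */
    static void tail_masks_32(unsigned nc, uint16_t* k1, uint16_t* k2) {
      uint64_t bits = ~(~UINT64_C(0) << nc);  /* mov/shlx/not: low nc bits set */
      *k1 = (uint16_t) (bits & 0xFFFF);
      *k2 = (uint16_t) ((bits >> 16) & 0xFFFF);
    }
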
- vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm15, [r9 + 64] vmovaps zmm19, [r9 + 128] vmovaps zmm23, [r9 + 192] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 vmovaps zmm16, zmm15 vmovaps zmm17, zmm15 vmovaps zmm18, zmm15 @@ -99,31 +99,31 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] - vmovaps zmm12, [r9 + 128] - vmovaps zmm13, [r9 + 192] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] add r9, 256 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 - vdpbf16ps zmm15, zmm2, zmm11 - vdpbf16ps zmm19, zmm2, zmm12 - vdpbf16ps zmm23, zmm2, zmm13 + vdpbf16ps zmm11, zmm2, zmm7 + vdpbf16ps zmm15, zmm2, zmm8 + vdpbf16ps zmm19, zmm2, zmm9 + vdpbf16ps zmm23, zmm2, zmm10 vbroadcastss zmm3, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm3, zmm10 - vdpbf16ps zmm16, zmm3, zmm11 - vdpbf16ps zmm20, zmm3, zmm12 - vdpbf16ps zmm24, zmm3, zmm13 + vdpbf16ps zmm12, zmm3, zmm7 + vdpbf16ps zmm16, zmm3, zmm8 + vdpbf16ps zmm20, zmm3, zmm9 + vdpbf16ps zmm24, zmm3, zmm10 vbroadcastss zmm4, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm4, zmm10 - vdpbf16ps zmm17, zmm4, zmm11 - vdpbf16ps zmm21, zmm4, zmm12 - vdpbf16ps zmm25, zmm4, zmm13 + vdpbf16ps zmm13, zmm4, zmm7 + vdpbf16ps zmm17, zmm4, zmm8 + vdpbf16ps zmm21, zmm4, zmm9 + vdpbf16ps zmm25, zmm4, zmm10 vbroadcastss zmm5, DWORD PTR [r14 + r11] - vdpbf16ps zmm14, zmm5, zmm10 - vdpbf16ps zmm18, zmm5, zmm11 - vdpbf16ps zmm22, zmm5, zmm12 - vdpbf16ps zmm26, zmm5, zmm13 + vdpbf16ps zmm14, zmm5, zmm7 + vdpbf16ps zmm18, zmm5, zmm8 + vdpbf16ps zmm22, zmm5, zmm9 + vdpbf16ps zmm26, zmm5, zmm10 add r11, 4 cmp rdx, r11 @@ -139,84 +139,84 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] - vmovaps zmm12, [r9 + 128] - vmovaps zmm13, [r9 + 192] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] add r9, 256 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm15, zmm2, zmm11 + vdpbf16ps zmm15, zmm2, zmm8 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm19, zmm2, zmm12 + vdpbf16ps zmm19, zmm2, zmm9 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm23, zmm2, zmm13 + vdpbf16ps zmm23, zmm2, zmm10 vbroadcastss zmm3, DWORD PTR [rax + r11] vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm8, zmm3, zmm10 + vdpbf16ps zmm12, zmm3, zmm7 vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm16, zmm3, zmm11 + vdpbf16ps zmm16, zmm3, zmm8 vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm20, zmm3, zmm12 + vdpbf16ps zmm20, zmm3, zmm9 vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm24, zmm3, zmm13 + vdpbf16ps zmm24, zmm3, zmm10 vbroadcastss zmm4, DWORD PTR [r15 + r11] vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm9, zmm4, zmm10 + vdpbf16ps zmm13, zmm4, zmm7 vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm17, zmm4, zmm11 + vdpbf16ps zmm17, zmm4, zmm8 vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm21, zmm4, zmm12 + vdpbf16ps zmm21, zmm4, zmm9 vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm25, zmm4, zmm13 + vdpbf16ps zmm25, zmm4, zmm10 vbroadcastss zmm5, DWORD PTR [r14 + r11] vpslld zmm5, zmm5, 16 vpsrad zmm5, zmm5, 16 - vdpbf16ps 
zmm14, zmm5, zmm10 + vdpbf16ps zmm14, zmm5, zmm7 vpslld zmm5, zmm5, 16 vpsrad zmm5, zmm5, 16 - vdpbf16ps zmm18, zmm5, zmm11 + vdpbf16ps zmm18, zmm5, zmm8 vpslld zmm5, zmm5, 16 vpsrad zmm5, zmm5, 16 - vdpbf16ps zmm22, zmm5, zmm12 + vdpbf16ps zmm22, zmm5, zmm9 vpslld zmm5, zmm5, 16 vpsrad zmm5, zmm5, 16 - vdpbf16ps zmm26, zmm5, zmm13 + vdpbf16ps zmm26, zmm5, zmm10 inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -230,9 +230,9 @@ inner_loop_end: vminps zmm24, zmm1, zmm24 vminps zmm25, zmm1, zmm25 vminps zmm26, zmm1, zmm26 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -251,15 +251,15 @@ inner_loop_end: cmp rsi, 64 jl tail - vmovups [r10], zmm7 + vmovups [r10], zmm11 vmovups [r10 + 64], zmm15 vmovups [r10 + 128], zmm19 vmovups [r10 + 192], zmm23 - vmovups [r13], zmm8 + vmovups [r13], zmm12 vmovups [r13 + 64], zmm16 vmovups [r13 + 128], zmm20 vmovups [r13 + 192], zmm24 - vmovups [rbx], zmm9 + vmovups [rbx], zmm13 vmovups [rbx + 64], zmm17 vmovups [rbx + 128], zmm21 vmovups [rbx + 192], zmm25 @@ -288,15 +288,15 @@ tail: shr r11, 16 kmovw k4, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm15 vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm19 vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm23 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm16 vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm20 vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm24 - vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm17 vmovups ZMMWORD PTR [rbx + 128]{k3}, zmm21 vmovups ZMMWORD PTR [rbx + 192]{k4}, zmm25 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-5x16c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-5x16c2-minmax-asm-amd64-avx512bf16-broadcast.S index 25e48b57ba5..2dbafab097c 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-5x16c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-5x16c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -85,11 +85,11 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 add r9, 64 # Are there at least 4 bytes? 
@@ -97,18 +97,18 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vbroadcastss zmm3, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm3, zmm10 + vdpbf16ps zmm12, zmm3, zmm7 vbroadcastss zmm4, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm4, zmm10 + vdpbf16ps zmm13, zmm4, zmm7 vbroadcastss zmm5, DWORD PTR [r14 + r11] - vdpbf16ps zmm14, zmm5, zmm10 + vdpbf16ps zmm14, zmm5, zmm7 vbroadcastss zmm6, DWORD PTR [r12 + r11] - vdpbf16ps zmm15, zmm6, zmm10 + vdpbf16ps zmm15, zmm6, zmm7 add r11, 4 cmp rdx, r11 @@ -124,43 +124,43 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vbroadcastss zmm3, DWORD PTR [rax + r11] vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm8, zmm3, zmm10 + vdpbf16ps zmm12, zmm3, zmm7 vbroadcastss zmm4, DWORD PTR [r15 + r11] vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm9, zmm4, zmm10 + vdpbf16ps zmm13, zmm4, zmm7 vbroadcastss zmm5, DWORD PTR [r14 + r11] vpslld zmm5, zmm5, 16 vpsrad zmm5, zmm5, 16 - vdpbf16ps zmm14, zmm5, zmm10 + vdpbf16ps zmm14, zmm5, zmm7 vbroadcastss zmm6, DWORD PTR [r12 + r11] vpslld zmm6, zmm6, 16 vpsrad zmm6, zmm6, 16 - vdpbf16ps zmm15, zmm6, zmm10 + vdpbf16ps zmm15, zmm6, zmm7 inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 @@ -168,9 +168,9 @@ inner_loop_end: cmp rsi, 16 jl tail - vmovups [r10], zmm7 - vmovups [r13], zmm8 - vmovups [rbx], zmm9 + vmovups [r10], zmm11 + vmovups [r13], zmm12 + vmovups [rbx], zmm13 vmovups [rbp], zmm14 vmovups [r8], zmm15 add r10, 64 @@ -188,9 +188,9 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 - vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 vmovups ZMMWORD PTR [rbp]{k1}, zmm14 vmovups ZMMWORD PTR [r8]{k1}, zmm15 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-5x32c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-5x32c2-minmax-asm-amd64-avx512bf16-broadcast.S index 20963e3dcca..8b564ee49f5 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-5x32c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-5x32c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -85,12 +85,12 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. 
- vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm16, [r9 + 64] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 vmovaps zmm17, zmm16 vmovaps zmm18, zmm16 vmovaps zmm19, zmm16 @@ -102,24 +102,24 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 - vdpbf16ps zmm16, zmm2, zmm11 + vdpbf16ps zmm11, zmm2, zmm7 + vdpbf16ps zmm16, zmm2, zmm8 vbroadcastss zmm3, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm3, zmm10 - vdpbf16ps zmm17, zmm3, zmm11 + vdpbf16ps zmm12, zmm3, zmm7 + vdpbf16ps zmm17, zmm3, zmm8 vbroadcastss zmm4, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm4, zmm10 - vdpbf16ps zmm18, zmm4, zmm11 + vdpbf16ps zmm13, zmm4, zmm7 + vdpbf16ps zmm18, zmm4, zmm8 vbroadcastss zmm5, DWORD PTR [r14 + r11] - vdpbf16ps zmm14, zmm5, zmm10 - vdpbf16ps zmm19, zmm5, zmm11 + vdpbf16ps zmm14, zmm5, zmm7 + vdpbf16ps zmm19, zmm5, zmm8 vbroadcastss zmm6, DWORD PTR [r12 + r11] - vdpbf16ps zmm15, zmm6, zmm10 - vdpbf16ps zmm20, zmm6, zmm11 + vdpbf16ps zmm15, zmm6, zmm7 + vdpbf16ps zmm20, zmm6, zmm8 add r11, 4 cmp rdx, r11 @@ -135,59 +135,59 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm16, zmm2, zmm11 + vdpbf16ps zmm16, zmm2, zmm8 vbroadcastss zmm3, DWORD PTR [rax + r11] vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm8, zmm3, zmm10 + vdpbf16ps zmm12, zmm3, zmm7 vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm17, zmm3, zmm11 + vdpbf16ps zmm17, zmm3, zmm8 vbroadcastss zmm4, DWORD PTR [r15 + r11] vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm9, zmm4, zmm10 + vdpbf16ps zmm13, zmm4, zmm7 vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm18, zmm4, zmm11 + vdpbf16ps zmm18, zmm4, zmm8 vbroadcastss zmm5, DWORD PTR [r14 + r11] vpslld zmm5, zmm5, 16 vpsrad zmm5, zmm5, 16 - vdpbf16ps zmm14, zmm5, zmm10 + vdpbf16ps zmm14, zmm5, zmm7 vpslld zmm5, zmm5, 16 vpsrad zmm5, zmm5, 16 - vdpbf16ps zmm19, zmm5, zmm11 + vdpbf16ps zmm19, zmm5, zmm8 vbroadcastss zmm6, DWORD PTR [r12 + r11] vpslld zmm6, zmm6, 16 vpsrad zmm6, zmm6, 16 - vdpbf16ps zmm15, zmm6, zmm10 + vdpbf16ps zmm15, zmm6, zmm7 vpslld zmm6, zmm6, 16 vpsrad zmm6, zmm6, 16 - vdpbf16ps zmm20, zmm6, zmm11 + vdpbf16ps zmm20, zmm6, zmm8 inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -195,9 +195,9 @@ inner_loop_end: vminps zmm18, zmm1, zmm18 vminps zmm19, zmm1, zmm19 vminps zmm20, zmm1, zmm20 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -210,11 +210,11 @@ inner_loop_end: cmp rsi, 32 jl tail - vmovups [r10], zmm7 + vmovups [r10], zmm11 vmovups [r10 + 64], zmm16 - vmovups [r13], zmm8 + vmovups [r13], zmm12 vmovups [r13 + 64], zmm17 - vmovups [rbx], zmm9 + vmovups [rbx], zmm13 vmovups [rbx + 64], zmm18 vmovups [rbp], zmm14 vmovups [rbp + 64], zmm19 @@ -237,11 +237,11 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm16 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm17 - vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm18 vmovups ZMMWORD PTR [rbp]{k1}, zmm14 vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm19 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-5x64c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-5x64c2-minmax-asm-amd64-avx512bf16-broadcast.S index afe88c74b0b..990cb269449 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-5x64c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-5x64c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -85,14 +85,14 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. 
- vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm16, [r9 + 64] vmovaps zmm21, [r9 + 128] vmovaps zmm26, [r9 + 192] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 vmovaps zmm17, zmm16 vmovaps zmm18, zmm16 vmovaps zmm19, zmm16 @@ -112,36 +112,36 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] - vmovaps zmm12, [r9 + 128] - vmovaps zmm13, [r9 + 192] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] add r9, 256 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 - vdpbf16ps zmm16, zmm2, zmm11 - vdpbf16ps zmm21, zmm2, zmm12 - vdpbf16ps zmm26, zmm2, zmm13 + vdpbf16ps zmm11, zmm2, zmm7 + vdpbf16ps zmm16, zmm2, zmm8 + vdpbf16ps zmm21, zmm2, zmm9 + vdpbf16ps zmm26, zmm2, zmm10 vbroadcastss zmm3, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm3, zmm10 - vdpbf16ps zmm17, zmm3, zmm11 - vdpbf16ps zmm22, zmm3, zmm12 - vdpbf16ps zmm27, zmm3, zmm13 + vdpbf16ps zmm12, zmm3, zmm7 + vdpbf16ps zmm17, zmm3, zmm8 + vdpbf16ps zmm22, zmm3, zmm9 + vdpbf16ps zmm27, zmm3, zmm10 vbroadcastss zmm4, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm4, zmm10 - vdpbf16ps zmm18, zmm4, zmm11 - vdpbf16ps zmm23, zmm4, zmm12 - vdpbf16ps zmm28, zmm4, zmm13 + vdpbf16ps zmm13, zmm4, zmm7 + vdpbf16ps zmm18, zmm4, zmm8 + vdpbf16ps zmm23, zmm4, zmm9 + vdpbf16ps zmm28, zmm4, zmm10 vbroadcastss zmm5, DWORD PTR [r14 + r11] - vdpbf16ps zmm14, zmm5, zmm10 - vdpbf16ps zmm19, zmm5, zmm11 - vdpbf16ps zmm24, zmm5, zmm12 - vdpbf16ps zmm29, zmm5, zmm13 + vdpbf16ps zmm14, zmm5, zmm7 + vdpbf16ps zmm19, zmm5, zmm8 + vdpbf16ps zmm24, zmm5, zmm9 + vdpbf16ps zmm29, zmm5, zmm10 vbroadcastss zmm6, DWORD PTR [r12 + r11] - vdpbf16ps zmm15, zmm6, zmm10 - vdpbf16ps zmm20, zmm6, zmm11 - vdpbf16ps zmm25, zmm6, zmm12 - vdpbf16ps zmm30, zmm6, zmm13 + vdpbf16ps zmm15, zmm6, zmm7 + vdpbf16ps zmm20, zmm6, zmm8 + vdpbf16ps zmm25, zmm6, zmm9 + vdpbf16ps zmm30, zmm6, zmm10 add r11, 4 cmp rdx, r11 @@ -157,101 +157,101 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] - vmovaps zmm12, [r9 + 128] - vmovaps zmm13, [r9 + 192] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] add r9, 256 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm16, zmm2, zmm11 + vdpbf16ps zmm16, zmm2, zmm8 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm21, zmm2, zmm12 + vdpbf16ps zmm21, zmm2, zmm9 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm26, zmm2, zmm13 + vdpbf16ps zmm26, zmm2, zmm10 vbroadcastss zmm3, DWORD PTR [rax + r11] vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm8, zmm3, zmm10 + vdpbf16ps zmm12, zmm3, zmm7 vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm17, zmm3, zmm11 + vdpbf16ps zmm17, zmm3, zmm8 vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm22, zmm3, zmm12 + vdpbf16ps zmm22, zmm3, zmm9 vpslld zmm3, zmm3, 16 vpsrad zmm3, zmm3, 16 - vdpbf16ps zmm27, zmm3, zmm13 + vdpbf16ps zmm27, zmm3, zmm10 vbroadcastss zmm4, DWORD PTR [r15 + r11] vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm9, zmm4, zmm10 + vdpbf16ps zmm13, zmm4, zmm7 vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm18, zmm4, zmm11 + 
vdpbf16ps zmm18, zmm4, zmm8 vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm23, zmm4, zmm12 + vdpbf16ps zmm23, zmm4, zmm9 vpslld zmm4, zmm4, 16 vpsrad zmm4, zmm4, 16 - vdpbf16ps zmm28, zmm4, zmm13 + vdpbf16ps zmm28, zmm4, zmm10 vbroadcastss zmm5, DWORD PTR [r14 + r11] vpslld zmm5, zmm5, 16 vpsrad zmm5, zmm5, 16 - vdpbf16ps zmm14, zmm5, zmm10 + vdpbf16ps zmm14, zmm5, zmm7 vpslld zmm5, zmm5, 16 vpsrad zmm5, zmm5, 16 - vdpbf16ps zmm19, zmm5, zmm11 + vdpbf16ps zmm19, zmm5, zmm8 vpslld zmm5, zmm5, 16 vpsrad zmm5, zmm5, 16 - vdpbf16ps zmm24, zmm5, zmm12 + vdpbf16ps zmm24, zmm5, zmm9 vpslld zmm5, zmm5, 16 vpsrad zmm5, zmm5, 16 - vdpbf16ps zmm29, zmm5, zmm13 + vdpbf16ps zmm29, zmm5, zmm10 vbroadcastss zmm6, DWORD PTR [r12 + r11] vpslld zmm6, zmm6, 16 vpsrad zmm6, zmm6, 16 - vdpbf16ps zmm15, zmm6, zmm10 + vdpbf16ps zmm15, zmm6, zmm7 vpslld zmm6, zmm6, 16 vpsrad zmm6, zmm6, 16 - vdpbf16ps zmm20, zmm6, zmm11 + vdpbf16ps zmm20, zmm6, zmm8 vpslld zmm6, zmm6, 16 vpsrad zmm6, zmm6, 16 - vdpbf16ps zmm25, zmm6, zmm12 + vdpbf16ps zmm25, zmm6, zmm9 vpslld zmm6, zmm6, 16 vpsrad zmm6, zmm6, 16 - vdpbf16ps zmm30, zmm6, zmm13 + vdpbf16ps zmm30, zmm6, zmm10 inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -269,9 +269,9 @@ inner_loop_end: vminps zmm28, zmm1, zmm28 vminps zmm29, zmm1, zmm29 vminps zmm30, zmm1, zmm30 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -294,15 +294,15 @@ inner_loop_end: cmp rsi, 64 jl tail - vmovups [r10], zmm7 + vmovups [r10], zmm11 vmovups [r10 + 64], zmm16 vmovups [r10 + 128], zmm21 vmovups [r10 + 192], zmm26 - vmovups [r13], zmm8 + vmovups [r13], zmm12 vmovups [r13 + 64], zmm17 vmovups [r13 + 128], zmm22 vmovups [r13 + 192], zmm27 - vmovups [rbx], zmm9 + vmovups [rbx], zmm13 vmovups [rbx + 64], zmm18 vmovups [rbx + 128], zmm23 vmovups [rbx + 192], zmm28 @@ -336,15 +336,15 @@ tail: shr r11, 16 kmovw k4, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm16 vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm21 vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm26 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm17 vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm22 vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm27 - vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm18 vmovups ZMMWORD PTR [rbx + 128]{k3}, zmm23 vmovups ZMMWORD PTR [rbx + 192]{k4}, zmm28 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-6x16c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-6x16c2-minmax-asm-amd64-avx512bf16-broadcast.S index 651449d9306..5a49aa7e343 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-6x16c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-6x16c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -121,12 +121,12 @@ outer_loop: mov r10, [rsp + 96] # Initialize accumulators with the biases. 
- vmovaps zmm7, [r9 + 0] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 add r9, 64 # Are there at least 4 bytes? @@ -134,20 +134,20 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm2, zmm10 + vdpbf16ps zmm12, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm2, zmm10 + vdpbf16ps zmm13, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vdpbf16ps zmm14, zmm2, zmm10 + vdpbf16ps zmm14, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vdpbf16ps zmm15, zmm2, zmm10 + vdpbf16ps zmm15, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vdpbf16ps zmm16, zmm2, zmm10 + vdpbf16ps zmm16, zmm2, zmm7 add r11, 4 cmp rdx, r11 @@ -163,49 +163,49 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rax + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm8, zmm2, zmm10 + vdpbf16ps zmm12, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r15 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm9, zmm2, zmm10 + vdpbf16ps zmm13, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r14 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm14, zmm2, zmm10 + vdpbf16ps zmm14, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r12 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm15, zmm2, zmm10 + vdpbf16ps zmm15, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r10 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm16, zmm2, zmm10 + vdpbf16ps zmm16, zmm2, zmm7 inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -222,9 +222,9 @@ inner_loop_end: cmp rsi, 16 jl tail - vmovups [rcx], zmm7 - vmovups [rax], zmm8 - vmovups [r15], zmm9 + vmovups [rcx], zmm11 + vmovups [rax], zmm12 + vmovups [r15], zmm13 vmovups [r14], zmm14 vmovups [r12], zmm15 vmovups [r10], zmm16 @@ -252,9 +252,9 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r12]{k1}, zmm15 vmovups ZMMWORD PTR [r10]{k1}, zmm16 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-6x32c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-6x32c2-minmax-asm-amd64-avx512bf16-broadcast.S index 2ffe0cdddec..a7e4d616bef 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-6x32c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-6x32c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -121,13 +121,13 @@ outer_loop: mov r10, [rsp + 96] # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm17, [r9 + 64] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 vmovaps zmm18, zmm17 vmovaps zmm19, zmm17 vmovaps zmm20, zmm17 @@ -140,27 +140,27 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 - vdpbf16ps zmm17, zmm2, zmm11 + vdpbf16ps zmm11, zmm2, zmm7 + vdpbf16ps zmm17, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm2, zmm10 - vdpbf16ps zmm18, zmm2, zmm11 + vdpbf16ps zmm12, zmm2, zmm7 + vdpbf16ps zmm18, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm2, zmm10 - vdpbf16ps zmm19, zmm2, zmm11 + vdpbf16ps zmm13, zmm2, zmm7 + vdpbf16ps zmm19, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vdpbf16ps zmm14, zmm2, zmm10 - vdpbf16ps zmm20, zmm2, zmm11 + vdpbf16ps zmm14, zmm2, zmm7 + vdpbf16ps zmm20, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vdpbf16ps zmm15, zmm2, zmm10 - vdpbf16ps zmm21, zmm2, zmm11 + vdpbf16ps zmm15, zmm2, zmm7 + vdpbf16ps zmm21, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vdpbf16ps zmm16, zmm2, zmm10 - vdpbf16ps zmm22, zmm2, zmm11 + vdpbf16ps zmm16, zmm2, zmm7 + vdpbf16ps zmm22, zmm2, zmm8 add r11, 4 cmp rdx, r11 @@ -176,68 +176,68 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 
- vdpbf16ps zmm17, zmm2, zmm11 + vdpbf16ps zmm17, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rax + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm8, zmm2, zmm10 + vdpbf16ps zmm12, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm18, zmm2, zmm11 + vdpbf16ps zmm18, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r15 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm9, zmm2, zmm10 + vdpbf16ps zmm13, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm19, zmm2, zmm11 + vdpbf16ps zmm19, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r14 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm14, zmm2, zmm10 + vdpbf16ps zmm14, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm20, zmm2, zmm11 + vdpbf16ps zmm20, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r12 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm15, zmm2, zmm10 + vdpbf16ps zmm15, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm21, zmm2, zmm11 + vdpbf16ps zmm21, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r10 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm16, zmm2, zmm10 + vdpbf16ps zmm16, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm22, zmm2, zmm11 + vdpbf16ps zmm22, zmm2, zmm8 inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -247,9 +247,9 @@ inner_loop_end: vminps zmm20, zmm1, zmm20 vminps zmm21, zmm1, zmm21 vminps zmm22, zmm1, zmm22 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -272,11 +272,11 @@ inner_loop_end: cmp rsi, 32 jl tail - vmovups [rcx], zmm7 + vmovups [rcx], zmm11 vmovups [rcx + 64], zmm17 - vmovups [rax], zmm8 + vmovups [rax], zmm12 vmovups [rax + 64], zmm18 - vmovups [r15], zmm9 + vmovups [r15], zmm13 vmovups [r15 + 64], zmm19 vmovups [r14], zmm14 vmovups [r14 + 64], zmm20 @@ -310,11 +310,11 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 vmovups ZMMWORD PTR [rcx + 64]{k2}, zmm17 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 vmovups ZMMWORD PTR [rax + 64]{k2}, zmm18 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm19 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm20 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-7x16c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-7x16c2-minmax-asm-amd64-avx512bf16-broadcast.S index c24eed24e11..babd9c5251e 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-7x16c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-7x16c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -134,13 +134,13 @@ outer_loop: mov r13, [rsp + 112] # Initialize accumulators with the biases. 
- vmovaps zmm7, [r9 + 0] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 - vmovaps zmm17, zmm7 + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 add r9, 64 # Are there at least 4 bytes? @@ -148,22 +148,22 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm2, zmm10 + vdpbf16ps zmm12, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm2, zmm10 + vdpbf16ps zmm13, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vdpbf16ps zmm14, zmm2, zmm10 + vdpbf16ps zmm14, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vdpbf16ps zmm15, zmm2, zmm10 + vdpbf16ps zmm15, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vdpbf16ps zmm16, zmm2, zmm10 + vdpbf16ps zmm16, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r13 + r11] - vdpbf16ps zmm17, zmm2, zmm10 + vdpbf16ps zmm17, zmm2, zmm7 add r11, 4 cmp rdx, r11 @@ -179,55 +179,55 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rax + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm8, zmm2, zmm10 + vdpbf16ps zmm12, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r15 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm9, zmm2, zmm10 + vdpbf16ps zmm13, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r14 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm14, zmm2, zmm10 + vdpbf16ps zmm14, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r12 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm15, zmm2, zmm10 + vdpbf16ps zmm15, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r10 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm16, zmm2, zmm10 + vdpbf16ps zmm16, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r13 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm17, zmm2, zmm10 + vdpbf16ps zmm17, zmm2, zmm7 inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 vminps zmm17, zmm1, zmm17 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -246,9 +246,9 @@ inner_loop_end: cmp rsi, 16 jl tail - vmovups [rcx], zmm7 - vmovups [rax], zmm8 - vmovups [r15], zmm9 + vmovups [rcx], zmm11 + vmovups [rax], zmm12 + vmovups [r15], zmm13 vmovups [r14], zmm14 vmovups [r12], zmm15 vmovups [r10], zmm16 @@ -279,9 +279,9 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r12]{k1}, zmm15 vmovups ZMMWORD PTR [r10]{k1}, zmm16 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-7x32c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-7x32c2-minmax-asm-amd64-avx512bf16-broadcast.S index 6d204b66c81..e3678339596 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-7x32c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-7x32c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -134,14 +134,14 @@ outer_loop: mov r13, [rsp + 112] # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm18, [r9 + 64] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 - vmovaps zmm17, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 vmovaps zmm19, zmm18 vmovaps zmm20, zmm18 vmovaps zmm21, zmm18 @@ -155,30 +155,30 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 - vdpbf16ps zmm18, zmm2, zmm11 + vdpbf16ps zmm11, zmm2, zmm7 + vdpbf16ps zmm18, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm2, zmm10 - vdpbf16ps zmm19, zmm2, zmm11 + vdpbf16ps zmm12, zmm2, zmm7 + vdpbf16ps zmm19, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm2, zmm10 - vdpbf16ps zmm20, zmm2, zmm11 + vdpbf16ps zmm13, zmm2, zmm7 + vdpbf16ps zmm20, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vdpbf16ps zmm14, zmm2, zmm10 - vdpbf16ps zmm21, zmm2, zmm11 + vdpbf16ps zmm14, zmm2, zmm7 + vdpbf16ps zmm21, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vdpbf16ps zmm15, zmm2, zmm10 - vdpbf16ps zmm22, zmm2, zmm11 + vdpbf16ps zmm15, zmm2, zmm7 + vdpbf16ps zmm22, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vdpbf16ps zmm16, zmm2, zmm10 - vdpbf16ps zmm23, zmm2, zmm11 + vdpbf16ps zmm16, zmm2, zmm7 + vdpbf16ps zmm23, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r13 + r11] - vdpbf16ps zmm17, zmm2, zmm10 - vdpbf16ps zmm24, zmm2, zmm11 + vdpbf16ps zmm17, zmm2, zmm7 + vdpbf16ps zmm24, zmm2, zmm8 add r11, 4 cmp rdx, r11 @@ -194,77 +194,77 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, 
[r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm18, zmm2, zmm11 + vdpbf16ps zmm18, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rax + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm8, zmm2, zmm10 + vdpbf16ps zmm12, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm19, zmm2, zmm11 + vdpbf16ps zmm19, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r15 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm9, zmm2, zmm10 + vdpbf16ps zmm13, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm20, zmm2, zmm11 + vdpbf16ps zmm20, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r14 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm14, zmm2, zmm10 + vdpbf16ps zmm14, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm21, zmm2, zmm11 + vdpbf16ps zmm21, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r12 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm15, zmm2, zmm10 + vdpbf16ps zmm15, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm22, zmm2, zmm11 + vdpbf16ps zmm22, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r10 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm16, zmm2, zmm10 + vdpbf16ps zmm16, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm23, zmm2, zmm11 + vdpbf16ps zmm23, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r13 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm17, zmm2, zmm10 + vdpbf16ps zmm17, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm24, zmm2, zmm11 + vdpbf16ps zmm24, zmm2, zmm8 inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -276,9 +276,9 @@ inner_loop_end: vminps zmm22, zmm1, zmm22 vminps zmm23, zmm1, zmm23 vminps zmm24, zmm1, zmm24 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -304,11 +304,11 @@ inner_loop_end: cmp rsi, 32 jl tail - vmovups [rcx], zmm7 + vmovups [rcx], zmm11 vmovups [rcx + 64], zmm18 - vmovups [rax], zmm8 + vmovups [rax], zmm12 vmovups [rax + 64], zmm19 - vmovups [r15], zmm9 + vmovups [r15], zmm13 vmovups [r15 + 64], zmm20 vmovups [r14], zmm14 vmovups [r14 + 64], zmm21 @@ -346,11 +346,11 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 vmovups ZMMWORD PTR [rcx + 64]{k2}, zmm18 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 vmovups ZMMWORD PTR [rax + 64]{k2}, zmm19 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm20 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm21 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-8x16c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-8x16c2-minmax-asm-amd64-avx512bf16-broadcast.S index 56e5f15c4b4..2a916964249 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-8x16c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-8x16c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -147,14 +147,14 @@ outer_loop: mov rbx, [rsp + 128] # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 - vmovaps zmm17, zmm7 - vmovaps zmm18, zmm7 + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm18, zmm11 add r9, 64 # Are there at least 4 bytes? 
@@ -162,24 +162,24 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm2, zmm10 + vdpbf16ps zmm12, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm2, zmm10 + vdpbf16ps zmm13, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vdpbf16ps zmm14, zmm2, zmm10 + vdpbf16ps zmm14, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vdpbf16ps zmm15, zmm2, zmm10 + vdpbf16ps zmm15, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vdpbf16ps zmm16, zmm2, zmm10 + vdpbf16ps zmm16, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r13 + r11] - vdpbf16ps zmm17, zmm2, zmm10 + vdpbf16ps zmm17, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rbx + r11] - vdpbf16ps zmm18, zmm2, zmm10 + vdpbf16ps zmm18, zmm2, zmm7 add r11, 4 cmp rdx, r11 @@ -195,61 +195,61 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rax + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm8, zmm2, zmm10 + vdpbf16ps zmm12, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r15 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm9, zmm2, zmm10 + vdpbf16ps zmm13, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r14 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm14, zmm2, zmm10 + vdpbf16ps zmm14, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r12 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm15, zmm2, zmm10 + vdpbf16ps zmm15, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r10 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm16, zmm2, zmm10 + vdpbf16ps zmm16, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r13 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm17, zmm2, zmm10 + vdpbf16ps zmm17, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rbx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm18, zmm2, zmm10 + vdpbf16ps zmm18, zmm2, zmm7 inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 vminps zmm17, zmm1, zmm17 vminps zmm18, zmm1, zmm18 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -270,9 +270,9 @@ inner_loop_end: cmp rsi, 16 jl tail - vmovups [rcx], zmm7 - vmovups [rax], zmm8 - vmovups [r15], zmm9 + vmovups [rcx], zmm11 + vmovups [rax], zmm12 + vmovups [r15], zmm13 vmovups [r14], zmm14 vmovups [r12], zmm15 vmovups [r10], zmm16 @@ -306,9 +306,9 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r12]{k1}, zmm15 vmovups ZMMWORD PTR [r10]{k1}, zmm16 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-8x32c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-8x32c2-minmax-asm-amd64-avx512bf16-broadcast.S index f8d7c77dcd8..85ba626bb14 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-8x32c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-8x32c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -147,15 +147,15 @@ outer_loop: mov rbx, [rsp + 128] # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm19, [r9 + 64] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 - vmovaps zmm17, zmm7 - vmovaps zmm18, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm18, zmm11 vmovaps zmm20, zmm19 vmovaps zmm21, zmm19 vmovaps zmm22, zmm19 @@ -170,33 +170,33 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 - vdpbf16ps zmm19, zmm2, zmm11 + vdpbf16ps zmm11, zmm2, zmm7 + vdpbf16ps zmm19, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm2, zmm10 - vdpbf16ps zmm20, zmm2, zmm11 + vdpbf16ps zmm12, zmm2, zmm7 + vdpbf16ps zmm20, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm2, zmm10 - vdpbf16ps zmm21, zmm2, zmm11 + vdpbf16ps zmm13, zmm2, zmm7 + vdpbf16ps zmm21, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vdpbf16ps zmm14, zmm2, zmm10 - vdpbf16ps zmm22, zmm2, zmm11 + vdpbf16ps zmm14, zmm2, zmm7 + vdpbf16ps zmm22, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vdpbf16ps zmm15, zmm2, zmm10 - vdpbf16ps zmm23, zmm2, zmm11 + vdpbf16ps zmm15, zmm2, zmm7 + vdpbf16ps zmm23, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vdpbf16ps zmm16, zmm2, zmm10 - vdpbf16ps zmm24, zmm2, zmm11 + vdpbf16ps zmm16, zmm2, zmm7 + vdpbf16ps zmm24, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r13 + r11] - vdpbf16ps zmm17, zmm2, zmm10 - vdpbf16ps zmm25, zmm2, zmm11 + vdpbf16ps zmm17, zmm2, zmm7 + vdpbf16ps zmm25, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rbx + r11] - vdpbf16ps zmm18, zmm2, zmm10 - vdpbf16ps zmm26, 
zmm2, zmm11 + vdpbf16ps zmm18, zmm2, zmm7 + vdpbf16ps zmm26, zmm2, zmm8 add r11, 4 cmp rdx, r11 @@ -212,86 +212,86 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm19, zmm2, zmm11 + vdpbf16ps zmm19, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rax + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm8, zmm2, zmm10 + vdpbf16ps zmm12, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm20, zmm2, zmm11 + vdpbf16ps zmm20, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r15 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm9, zmm2, zmm10 + vdpbf16ps zmm13, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm21, zmm2, zmm11 + vdpbf16ps zmm21, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r14 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm14, zmm2, zmm10 + vdpbf16ps zmm14, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm22, zmm2, zmm11 + vdpbf16ps zmm22, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r12 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm15, zmm2, zmm10 + vdpbf16ps zmm15, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm23, zmm2, zmm11 + vdpbf16ps zmm23, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r10 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm16, zmm2, zmm10 + vdpbf16ps zmm16, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm24, zmm2, zmm11 + vdpbf16ps zmm24, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r13 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm17, zmm2, zmm10 + vdpbf16ps zmm17, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm25, zmm2, zmm11 + vdpbf16ps zmm25, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rbx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm18, zmm2, zmm10 + vdpbf16ps zmm18, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm26, zmm2, zmm11 + vdpbf16ps zmm26, zmm2, zmm8 inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -305,9 +305,9 @@ inner_loop_end: vminps zmm24, zmm1, zmm24 vminps zmm25, zmm1, zmm25 vminps zmm26, zmm1, zmm26 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -336,11 +336,11 @@ inner_loop_end: cmp rsi, 32 jl tail - vmovups [rcx], zmm7 + vmovups [rcx], zmm11 vmovups [rcx + 64], zmm19 - vmovups [rax], zmm8 + vmovups [rax], zmm12 vmovups [rax + 64], zmm20 - vmovups [r15], zmm9 + vmovups [r15], zmm13 vmovups [r15 + 64], zmm21 vmovups [r14], zmm14 vmovups [r14 + 64], zmm22 @@ -382,11 +382,11 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 vmovups ZMMWORD PTR [rcx + 64]{k2}, zmm19 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 vmovups ZMMWORD PTR [rax + 64]{k2}, zmm20 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm21 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm22 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-9x16c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-9x16c2-minmax-asm-amd64-avx512bf16-broadcast.S index c004ba2b5f6..8bad949c56e 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-9x16c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-9x16c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -160,15 +160,15 @@ outer_loop: mov rbp, [rsp + 144] # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 - vmovaps zmm17, zmm7 - vmovaps zmm18, zmm7 - vmovaps zmm19, zmm7 + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm18, zmm11 + vmovaps zmm19, zmm11 add r9, 64 # Are there at least 4 bytes? 
@@ -176,26 +176,26 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm2, zmm10 + vdpbf16ps zmm12, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm2, zmm10 + vdpbf16ps zmm13, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vdpbf16ps zmm14, zmm2, zmm10 + vdpbf16ps zmm14, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vdpbf16ps zmm15, zmm2, zmm10 + vdpbf16ps zmm15, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vdpbf16ps zmm16, zmm2, zmm10 + vdpbf16ps zmm16, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r13 + r11] - vdpbf16ps zmm17, zmm2, zmm10 + vdpbf16ps zmm17, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rbx + r11] - vdpbf16ps zmm18, zmm2, zmm10 + vdpbf16ps zmm18, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rbp + r11] - vdpbf16ps zmm19, zmm2, zmm10 + vdpbf16ps zmm19, zmm2, zmm7 add r11, 4 cmp rdx, r11 @@ -211,67 +211,67 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rax + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm8, zmm2, zmm10 + vdpbf16ps zmm12, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r15 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm9, zmm2, zmm10 + vdpbf16ps zmm13, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r14 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm14, zmm2, zmm10 + vdpbf16ps zmm14, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r12 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm15, zmm2, zmm10 + vdpbf16ps zmm15, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r10 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm16, zmm2, zmm10 + vdpbf16ps zmm16, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r13 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm17, zmm2, zmm10 + vdpbf16ps zmm17, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rbx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm18, zmm2, zmm10 + vdpbf16ps zmm18, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rbp + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm19, zmm2, zmm10 + vdpbf16ps zmm19, zmm2, zmm7 inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 vminps zmm17, zmm1, zmm17 vminps zmm18, zmm1, zmm18 vminps zmm19, zmm1, zmm19 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -294,9 +294,9 @@ inner_loop_end: cmp rsi, 16 jl tail - vmovups [rcx], zmm7 - vmovups [rax], zmm8 - vmovups [r15], zmm9 + vmovups [rcx], zmm11 + vmovups [rax], zmm12 + vmovups [r15], zmm13 vmovups [r14], zmm14 vmovups [r12], zmm15 vmovups [r10], zmm16 @@ -333,9 +333,9 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r12]{k1}, zmm15 vmovups ZMMWORD PTR [r10]{k1}, zmm16 diff --git a/src/bf16-f32-gemm/gen/bf16-f32-gemm-9x32c2-minmax-asm-amd64-avx512bf16-broadcast.S b/src/bf16-f32-gemm/gen/bf16-f32-gemm-9x32c2-minmax-asm-amd64-avx512bf16-broadcast.S index 43ec9506b9c..dc6d02dd3ce 100644 --- a/src/bf16-f32-gemm/gen/bf16-f32-gemm-9x32c2-minmax-asm-amd64-avx512bf16-broadcast.S +++ b/src/bf16-f32-gemm/gen/bf16-f32-gemm-9x32c2-minmax-asm-amd64-avx512bf16-broadcast.S @@ -160,16 +160,16 @@ outer_loop: mov rbp, [rsp + 144] # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm20, [r9 + 64] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 - vmovaps zmm17, zmm7 - vmovaps zmm18, zmm7 - vmovaps zmm19, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm18, zmm11 + vmovaps zmm19, zmm11 vmovaps zmm21, zmm20 vmovaps zmm22, zmm20 vmovaps zmm23, zmm20 @@ -185,36 +185,36 @@ outer_loop: js inner_loop_tail inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vdpbf16ps zmm7, zmm2, zmm10 - vdpbf16ps zmm20, zmm2, zmm11 + vdpbf16ps zmm11, zmm2, zmm7 + vdpbf16ps zmm20, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rax + r11] - vdpbf16ps zmm8, zmm2, zmm10 - vdpbf16ps zmm21, zmm2, zmm11 + vdpbf16ps zmm12, zmm2, zmm7 + vdpbf16ps zmm21, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vdpbf16ps zmm9, zmm2, zmm10 - vdpbf16ps zmm22, zmm2, zmm11 + vdpbf16ps zmm13, zmm2, zmm7 + vdpbf16ps zmm22, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vdpbf16ps zmm14, zmm2, zmm10 - vdpbf16ps zmm23, zmm2, zmm11 + vdpbf16ps zmm14, zmm2, zmm7 + vdpbf16ps zmm23, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vdpbf16ps zmm15, zmm2, zmm10 - vdpbf16ps zmm24, zmm2, zmm11 + vdpbf16ps zmm15, zmm2, zmm7 + vdpbf16ps zmm24, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vdpbf16ps zmm16, zmm2, zmm10 - vdpbf16ps zmm25, zmm2, zmm11 + vdpbf16ps zmm16, zmm2, zmm7 + vdpbf16ps zmm25, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r13 + r11] - vdpbf16ps zmm17, zmm2, zmm10 - vdpbf16ps zmm26, zmm2, zmm11 + vdpbf16ps zmm17, zmm2, zmm7 + vdpbf16ps zmm26, zmm2, zmm8 vbroadcastss zmm2, 
DWORD PTR [rbx + r11] - vdpbf16ps zmm18, zmm2, zmm10 - vdpbf16ps zmm27, zmm2, zmm11 + vdpbf16ps zmm18, zmm2, zmm7 + vdpbf16ps zmm27, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rbp + r11] - vdpbf16ps zmm19, zmm2, zmm10 - vdpbf16ps zmm28, zmm2, zmm11 + vdpbf16ps zmm19, zmm2, zmm7 + vdpbf16ps zmm28, zmm2, zmm8 add r11, 4 cmp rdx, r11 @@ -230,95 +230,95 @@ inner_loop: jz inner_loop_end inner_loop_tail: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm7, zmm2, zmm10 + vdpbf16ps zmm11, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm20, zmm2, zmm11 + vdpbf16ps zmm20, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rax + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm8, zmm2, zmm10 + vdpbf16ps zmm12, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm21, zmm2, zmm11 + vdpbf16ps zmm21, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r15 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm9, zmm2, zmm10 + vdpbf16ps zmm13, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm22, zmm2, zmm11 + vdpbf16ps zmm22, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r14 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm14, zmm2, zmm10 + vdpbf16ps zmm14, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm23, zmm2, zmm11 + vdpbf16ps zmm23, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r12 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm15, zmm2, zmm10 + vdpbf16ps zmm15, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm24, zmm2, zmm11 + vdpbf16ps zmm24, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r10 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm16, zmm2, zmm10 + vdpbf16ps zmm16, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm25, zmm2, zmm11 + vdpbf16ps zmm25, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r13 + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm17, zmm2, zmm10 + vdpbf16ps zmm17, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm26, zmm2, zmm11 + vdpbf16ps zmm26, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rbx + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm18, zmm2, zmm10 + vdpbf16ps zmm18, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm27, zmm2, zmm11 + vdpbf16ps zmm27, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rbp + r11] vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm19, zmm2, zmm10 + vdpbf16ps zmm19, zmm2, zmm7 vpslld zmm2, zmm2, 16 vpsrad zmm2, zmm2, 16 - vdpbf16ps zmm28, zmm2, zmm11 + vdpbf16ps zmm28, zmm2, zmm8 inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -334,9 +334,9 @@ inner_loop_end: vminps zmm26, zmm1, zmm26 vminps zmm27, zmm1, zmm27 vminps zmm28, zmm1, zmm28 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -368,11 +368,11 @@ inner_loop_end: cmp rsi, 32 jl tail - vmovups [rcx], zmm7 + vmovups [rcx], zmm11 vmovups [rcx + 64], zmm20 - vmovups [rax], zmm8 + vmovups [rax], zmm12 vmovups [rax + 64], zmm21 - vmovups [r15], zmm9 + vmovups [r15], zmm13 vmovups [r15 + 64], zmm22 vmovups [r14], zmm14 vmovups [r14 + 64], zmm23 @@ -418,11 +418,11 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 vmovups ZMMWORD PTR [rcx + 64]{k2}, zmm20 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 vmovups ZMMWORD PTR [rax + 64]{k2}, zmm21 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm22 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm23 diff --git a/src/configs/gemm-config.c b/src/configs/gemm-config.c index 21fe175070d..c2e9b6a5b11 100644 --- a/src/configs/gemm-config.c +++ b/src/configs/gemm-config.c @@ -5,6 +5,7 @@ #include #include +#include #if XNN_ENABLE_CPUINFO #include @@ -32,6 +33,7 @@ static struct xnn_gemm_config bf16_f32_gemm_config = {0}; static struct xnn_gemm_config f16_gemm_config = {0}; static struct xnn_gemm_config f32_gemm_config = {0}; +static struct xnn_gemm_config f32_igemm_config = {0}; static struct xnn_gemm_config f32_gemm_nr2_config = {0}; static struct xnn_gemm_config f32_qc4w_gemm_config = {0}; static struct xnn_gemm_config f32_qc8w_gemm_config = {0}; @@ -57,6 +59,7 @@ static struct xnn_gemm_config qu8_gemm_config = {0}; XNN_INIT_ONCE_GUARD(bf16_f32_gemm); XNN_INIT_ONCE_GUARD(f16_gemm); +XNN_INIT_ONCE_GUARD(f32_igemm); XNN_INIT_ONCE_GUARD(f32_gemm); XNN_INIT_ONCE_GUARD(f32_gemm_nr2); XNN_INIT_ONCE_GUARD(f32_qc4w_gemm); @@ -352,7 +355,7 @@ static void init_pqs8_qc8w_gemm_config(void) { #endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI } -static void init_f32_gemm_config(void) { +static void init_f32_igemm_config(void) { #if XNN_ARCH_ARM const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); @@ -364,83 +367,83 @@ static void init_f32_gemm_config(void) { case cpuinfo_uarch_cortex_a7: case cpuinfo_uarch_krait: case cpuinfo_uarch_kryo: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53_prfm); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a7); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a7); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; 
- f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 4; - f32_gemm_config.nr = 8; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53_prfm); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a7); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53_prfm); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a7); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_igemm_config.mr = 4; + f32_igemm_config.nr = 8; break; case cpuinfo_uarch_cortex_a53: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53_prfm); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53_prfm); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 4; - f32_gemm_config.nr = 8; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53_prfm); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53_prfm); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53_prfm); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53_prfm); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_igemm_config.mr = 4; + f32_igemm_config.nr = 8; break; case cpuinfo_uarch_cortex_a55r0: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53); - 
f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 4; - f32_gemm_config.nr = 8; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_igemm_config.mr = 4; + f32_igemm_config.nr = 8; break; case cpuinfo_uarch_cortex_a32: case cpuinfo_uarch_cortex_a35: case cpuinfo_uarch_cortex_a55: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a55); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a55); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 4; - f32_gemm_config.nr = 8; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a55); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) 
xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a55); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_igemm_config.mr = 4; + f32_igemm_config.nr = 8; break; case cpuinfo_uarch_cortex_a57: case cpuinfo_uarch_cortex_a72: case cpuinfo_uarch_cortex_a73: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53_prfm); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a75_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a75_prfm); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 4; - f32_gemm_config.nr = 8; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53_prfm); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a75_prfm); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53_prfm); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a75_prfm); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_igemm_config.mr = 4; + f32_igemm_config.nr = 8; break; default: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a75); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a75); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 4; - f32_gemm_config.nr = 8; + 
f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a75); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a75); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_igemm_config.mr = 4; + f32_igemm_config.nr = 8; break; } #if XNN_MAX_UARCH_TYPES > 1 { /* Choose micro-kernels for little cores according to micro-kernel specification for the big core */ - const uint32_t mr = f32_gemm_config.mr; - const uint32_t nr = f32_gemm_config.nr; + const uint32_t mr = f32_igemm_config.mr; + const uint32_t nr = f32_igemm_config.nr; for (size_t i = 1; i < XNN_MAX_UARCH_TYPES; i++) { const struct cpuinfo_uarch_info* uarch_info = cpuinfo_get_uarch(i); if (uarch_info == NULL) { @@ -451,26 +454,26 @@ static void init_f32_gemm_config(void) { switch (uarch_info->uarch) { case cpuinfo_uarch_cortex_a53: if (mr == 4 && nr == 8) { - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53_prfm; - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53_prfm; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53_prfm; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53_prfm; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53_prfm; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53_prfm; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53_prfm; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53_prfm; } break; case cpuinfo_uarch_cortex_a55r0: if (mr == 4 && nr == 8) { - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53; - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53; + 
f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a53; } break; case cpuinfo_uarch_cortex_a55: if (mr == 4 && nr == 8) { - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53; - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a55; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a55; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a55; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch32_neon_cortex_a53; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch32_neon_cortex_a55; } break; default: @@ -480,48 +483,48 @@ static void init_f32_gemm_config(void) { } #endif // XNN_MAX_UARCH_TYPES > 1 #else // XNN_ENABLE_ASSEMBLY - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__neon_lane_ld64); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__neon_lane_ld128); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__neon_lane_ld64); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__neon_lane_ld128); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 4; - f32_gemm_config.nr = 8; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__neon_lane_ld64); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__neon_lane_ld128); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__neon_lane_ld64); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) 
xnn_f32_igemm_minmax_ukernel_4x8__neon_lane_ld128); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_igemm_config.mr = 4; + f32_igemm_config.nr = 8; #endif // XNN_ENABLE_ASSEMBLY } else if (!XNN_PLATFORM_MOBILE) { - f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_1x4__scalar); - f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_4x4__scalar); - f32_gemm_config.linear.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_1x4__scalar); - f32_gemm_config.linear.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_4x4__scalar); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x4__scalar); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x4__scalar); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x4__scalar); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x4__scalar); - f32_gemm_config.relu.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_1x4__scalar); - f32_gemm_config.relu.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_4x4__scalar); - f32_gemm_config.relu.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_1x4__scalar); - f32_gemm_config.relu.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_4x4__scalar); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x4__scalar_float_u4; - f32_gemm_config.mr = 4; - f32_gemm_config.nr = 4; + f32_igemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_1x4__scalar); + f32_igemm_config.linear.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_4x4__scalar); + f32_igemm_config.linear.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_1x4__scalar); + f32_igemm_config.linear.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_4x4__scalar); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x4__scalar); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x4__scalar); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x4__scalar); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) 
xnn_f32_igemm_minmax_ukernel_4x4__scalar); + f32_igemm_config.relu.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_1x4__scalar); + f32_igemm_config.relu.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_4x4__scalar); + f32_igemm_config.relu.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_1x4__scalar); + f32_igemm_config.relu.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_4x4__scalar); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x4__scalar_float_u4; + f32_igemm_config.mr = 4; + f32_igemm_config.nr = 4; } #elif XNN_ARCH_ARM64 #if XNN_ENABLE_ASSEMBLY && !XNN_PLATFORM_IOS && !XNN_PLATFORM_MAC switch (cpuinfo_get_core(0)->uarch) { case cpuinfo_uarch_cortex_a72: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75_prfm); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 4; - f32_gemm_config.nr = 8; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75_prfm); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75_prfm); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_igemm_config.mr = 4; + f32_igemm_config.nr = 8; break; case cpuinfo_uarch_cortex_a57: case cpuinfo_uarch_cortex_a75: @@ -529,119 +532,119 @@ static void init_f32_gemm_config(void) { case cpuinfo_uarch_exynos_m3: case cpuinfo_uarch_exynos_m4: case cpuinfo_uarch_neoverse_n1: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); - 
f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a75_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a75_prfm); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a75_prfm); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a75_prfm); #if XNN_ENABLE_GEMM_M_SPECIALIZATION - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75_prfm); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75_prfm); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75_prfm); #endif - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 6; - f32_gemm_config.nr = 8; + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_igemm_config.mr = 6; + f32_igemm_config.nr = 8; break; case cpuinfo_uarch_exynos_m1: case cpuinfo_uarch_exynos_m2: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8s4__neonfma); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8s4__neonfma); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8s4__neonfma); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8s4__neonfma); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8s4__neonfma); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) 
xnn_f32_gemm_minmax_ukernel_6x8s4__neonfma); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8s4__neonfma); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8s4__neonfma); #if XNN_ENABLE_GEMM_M_SPECIALIZATION - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8s4__neonfma); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8s4__neonfma); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8s4__neonfma); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8s4__neonfma); #endif - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8s4__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 6; - f32_gemm_config.nr = 8; - f32_gemm_config.log2_sr = 2; + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8s4__neon_ld4lane_u4_prfm; + f32_igemm_config.mr = 6; + f32_igemm_config.nr = 8; + f32_igemm_config.log2_sr = 2; break; case cpuinfo_uarch_cortex_a53: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53_prfm); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53_prfm); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53_prfm); #if XNN_ENABLE_GEMM_M_SPECIALIZATION - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm); + 
f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm); #endif - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 6; - f32_gemm_config.nr = 8; + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_igemm_config.mr = 6; + f32_igemm_config.nr = 8; break; case cpuinfo_uarch_cortex_a55r0: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53); #if XNN_ENABLE_GEMM_M_SPECIALIZATION - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53); #endif - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 6; - f32_gemm_config.nr = 8; + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = 
(xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_igemm_config.mr = 6; + f32_igemm_config.nr = 8; break; case cpuinfo_uarch_cortex_a35: case cpuinfo_uarch_cortex_a55: case cpuinfo_uarch_kryo: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a55); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a55); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a55); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a55); #if XNN_ENABLE_GEMM_M_SPECIALIZATION - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55); #endif - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 6; - f32_gemm_config.nr = 8; + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_igemm_config.mr = 6; + f32_igemm_config.nr = 8; break; case cpuinfo_uarch_cortex_a73: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a73); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) 
xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a73); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 6; - f32_gemm_config.nr = 8; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a73); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75_prfm); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a73); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_igemm_config.mr = 6; + f32_igemm_config.nr = 8; break; case cpuinfo_uarch_cortex_a77: case cpuinfo_uarch_exynos_m5: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 4; - f32_gemm_config.nr = 8; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a75); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a75); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = 
(xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_igemm_config.mr = 4; + f32_igemm_config.nr = 8; break; case cpuinfo_uarch_cortex_x3: case cpuinfo_uarch_neoverse_v2: // TODO(fbarchard): Implement asm with indexed inputs - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_acc2); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__aarch64_neonfma_lane_ld128); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__aarch64_neonfma_lane_ld128); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 6; - f32_gemm_config.nr = 8; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_acc2); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__aarch64_neonfma_lane_ld128); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__aarch64_neonfma_lane_ld128); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_igemm_config.mr = 6; + f32_igemm_config.nr = 8; break; case cpuinfo_uarch_cortex_a78: case cpuinfo_uarch_cortex_a510: @@ -652,27 +655,27 @@ static void init_f32_gemm_config(void) { case cpuinfo_uarch_neoverse_n2: case cpuinfo_uarch_neoverse_v1: default: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_acc4); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_ld128); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_ld128); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_acc4); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_ld128); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) 
xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_ld128); #if XNN_ENABLE_GEMM_M_SPECIALIZATION - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld128); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld128); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld128); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_ld128); #endif - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 6; - f32_gemm_config.nr = 8; + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_igemm_config.mr = 6; + f32_igemm_config.nr = 8; break; } #if XNN_MAX_UARCH_TYPES > 1 { /* Choose micro-kernels for little cores according to micro-kernel specification for the big core */ - const uint32_t mr = f32_gemm_config.mr; - const uint32_t nr = f32_gemm_config.nr; - const uint32_t log2_sr = f32_gemm_config.log2_sr; + const uint32_t mr = f32_igemm_config.mr; + const uint32_t nr = f32_igemm_config.nr; + const uint32_t log2_sr = f32_igemm_config.log2_sr; for (size_t i = 1; i < XNN_MAX_UARCH_TYPES; i++) { const struct cpuinfo_uarch_info* uarch_info = cpuinfo_get_uarch(i); if (uarch_info == NULL) { @@ -683,53 +686,53 @@ static void init_f32_gemm_config(void) { switch (uarch_info->uarch) { case cpuinfo_uarch_cortex_a53: if (mr == 6 && nr == 8 && log2_sr == 0) { - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm; - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53_prfm; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53_prfm; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53_prfm; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_igemm_ukernel_fn) 
xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53_prfm; #if XNN_ENABLE_GEMM_M_SPECIALIZATION - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm; #endif } else if (mr == 4 && nr == 8 && log2_sr == 0) { - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm; - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53_prfm; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53_prfm; } break; case cpuinfo_uarch_cortex_a55r0: if (mr == 6 && nr == 8 && log2_sr == 0) { - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a53; #if XNN_ENABLE_GEMM_M_SPECIALIZATION - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) 
xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53; #endif } else if (mr == 4 && nr == 8 && log2_sr == 0) { - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a53; } break; case cpuinfo_uarch_cortex_a55: if (mr == 6 && nr == 8 && log2_sr == 0) { - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a55; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a55; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a55; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_cortex_a55; #if XNN_ENABLE_GEMM_M_SPECIALIZATION - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) 
xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55; #endif } else if (mr == 4 && nr == 8 && log2_sr == 0) { - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_cortex_a53; + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)].function[i] = (xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__asm_aarch64_neonfma_cortex_a55; } break; default: @@ -740,25 +743,25 @@ static void init_f32_gemm_config(void) { #endif // XNN_MAX_UARCH_TYPES > 1 #else // XNN_ENABLE_ASSEMBLY && !XNN_PLATFORM_IOS && !XNN_PLATFORM_MAC #if XNN_ENABLE_ASSEMBLY - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_acc4); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_ld128); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_ld128); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 6; - f32_gemm_config.nr = 8; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld128_acc4); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__asm_aarch64_neonfma_ld128); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__asm_aarch64_neonfma_ld64); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__asm_aarch64_neonfma_ld128); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = 
(xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_igemm_config.mr = 6; + f32_igemm_config.nr = 8; #else // !XNN_ENABLE_ASSEMBLY - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__aarch64_neonfma_lane_ld128); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__aarch64_neonfma_lane_ld128); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__aarch64_neonfma_lane_ld128); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__aarch64_neonfma_lane_ld128); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; - f32_gemm_config.mr = 6; - f32_gemm_config.nr = 8; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__aarch64_neonfma_lane_ld128); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__aarch64_neonfma_lane_ld128); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__aarch64_neonfma_lane_ld128); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__aarch64_neonfma_lane_ld128); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__neon_ld4lane_u4_prfm; + f32_igemm_config.mr = 6; + f32_igemm_config.nr = 8; #endif // XNN_ENABLE_ASSEMBLY #endif // XNN_ENABLE_ASSEMBLY && !XNN_PLATFORM_IOS && !XNN_PLATFORM_MAC #elif XNN_ARCH_X86 || XNN_ARCH_X86_64 @@ -767,64 +770,64 @@ static void init_f32_gemm_config(void) { (void) hardware_config; // May be unused. 
#if XNN_ENABLE_AVX512F if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512f) { - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x32__avx512f_broadcast); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_7x32__avx512f_broadcast); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x32__avx512f_broadcast); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_7x32__avx512f_broadcast); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_x32_packw_gemm_gio_ukernel_x32__avx512f_u8; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x32__avx512f_u4_prfm; - f32_gemm_config.mr = 7; - f32_gemm_config.nr = 32; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x32__avx512f_broadcast); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_7x32__avx512f_broadcast); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x32__avx512f_broadcast); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_7x32__avx512f_broadcast); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_x32_packw_gemm_gio_ukernel_x32__avx512f_u8; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x32__avx512f_u4_prfm; + f32_igemm_config.mr = 7; + f32_igemm_config.nr = 32; } else #endif if (hardware_config->use_x86_fma3) { switch (cpuinfo_get_core(0)->uarch) { case cpuinfo_uarch_zen: case cpuinfo_uarch_dhyana: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x16s4__fma3_broadcast); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x16s4__fma3_broadcast); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x16s4__fma3_broadcast); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x16s4__fma3_broadcast); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x16s4__avx_u4; - f32_gemm_config.mr = 4; - f32_gemm_config.nr = 16; - f32_gemm_config.log2_sr = 2; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x16s4__fma3_broadcast); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x16s4__fma3_broadcast); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = 
xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x16s4__fma3_broadcast); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x16s4__fma3_broadcast); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x16s4__avx_u4; + f32_igemm_config.mr = 4; + f32_igemm_config.nr = 16; + f32_igemm_config.log2_sr = 2; break; default: - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x16__fma3_broadcast); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_5x16__fma3_broadcast); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x16__fma3_broadcast); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_5x16__fma3_broadcast_prfm); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x16__avx_u4; - f32_gemm_config.mr = 5; - f32_gemm_config.nr = 16; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x16__fma3_broadcast); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_5x16__fma3_broadcast); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x16__fma3_broadcast); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_5x16__fma3_broadcast_prfm); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x16__avx_u4; + f32_igemm_config.mr = 5; + f32_igemm_config.nr = 16; break; } } else if (hardware_config->use_x86_avx) { - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x16__avx_broadcast); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_5x16__avx_broadcast); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x16__avx_broadcast); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_5x16__avx_broadcast); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x16__avx_u4; - f32_gemm_config.mr = 5; - f32_gemm_config.nr = 16; + 
f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x16__avx_broadcast); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_5x16__avx_broadcast); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x16__avx_broadcast); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_5x16__avx_broadcast); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_x32_packw_gemm_gio_ukernel_x16__avx_u8; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x16__avx_u4; + f32_igemm_config.mr = 5; + f32_igemm_config.nr = 16; } else { - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__sse_load1); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__sse_load1); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__sse_load1); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__sse_load1); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__sse2_u4; - f32_gemm_config.mr = 4; - f32_gemm_config.nr = 8; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__sse_load1); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__sse_load1); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__sse_load1); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__sse_load1); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__sse2_u4; + f32_igemm_config.mr = 4; + f32_igemm_config.nr = 8; } #elif XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); @@ -832,90 +835,90 @@ static void init_f32_gemm_config(void) { (void) hardware_config; // May be unused. 
if (hardware_config->is_x86) { #if XNN_ARCH_WASMRELAXEDSIMD - f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat); - f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat); - f32_gemm_config.linear.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat); - f32_gemm_config.linear.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat); - f32_gemm_config.relu.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat); - f32_gemm_config.relu.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat); - f32_gemm_config.relu.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat); - f32_gemm_config.relu.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat); + f32_igemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat); + f32_igemm_config.linear.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat); + f32_igemm_config.linear.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat); + f32_igemm_config.linear.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat); + f32_igemm_config.relu.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat); + f32_igemm_config.relu.gemm[XNN_MR_TO_INDEX(4)] = 
xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat); + f32_igemm_config.relu.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_1x8__wasmrelaxedsimd_fma_loadsplat); + f32_igemm_config.relu.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_4x8__wasmrelaxedsimd_fma_loadsplat); #else if (hardware_concurrency() > kCoreCountThresholdForAdaptiveAvxOptimization) { - f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_1x8__wasmsimd_loadsplat); - f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_4x8__wasmsimd_loadsplat); - f32_gemm_config.linear.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_1x8__wasmsimd_loadsplat); - f32_gemm_config.linear.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_4x8__wasmsimd_loadsplat); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__wasmsimd_x86_loadsplat); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__wasmsimd_x86_loadsplat); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__wasmsimd_x86_loadsplat); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__wasmsimd_x86_loadsplat); - f32_gemm_config.relu.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_1x8__wasmsimd_loadsplat); - f32_gemm_config.relu.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_4x8__wasmsimd_loadsplat); - f32_gemm_config.relu.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_1x8__wasmsimd_loadsplat); - f32_gemm_config.relu.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_4x8__wasmsimd_loadsplat); + f32_igemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_1x8__wasmsimd_loadsplat); + f32_igemm_config.linear.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_4x8__wasmsimd_loadsplat); + f32_igemm_config.linear.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_1x8__wasmsimd_loadsplat); + f32_igemm_config.linear.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_4x8__wasmsimd_loadsplat); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__wasmsimd_x86_loadsplat); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__wasmsimd_x86_loadsplat); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__wasmsimd_x86_loadsplat); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = 
xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__wasmsimd_x86_loadsplat); + f32_igemm_config.relu.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_1x8__wasmsimd_loadsplat); + f32_igemm_config.relu.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_4x8__wasmsimd_loadsplat); + f32_igemm_config.relu.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_1x8__wasmsimd_loadsplat); + f32_igemm_config.relu.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_4x8__wasmsimd_loadsplat); } else { - f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_1x8__wasmsimd_splat); - f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_4x8__wasmsimd_splat); - f32_gemm_config.linear.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_1x8__wasmsimd_splat); - f32_gemm_config.linear.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_4x8__wasmsimd_splat); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__wasmsimd_x86_splat); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__wasmsimd_x86_splat); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__wasmsimd_x86_splat); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__wasmsimd_x86_splat); - f32_gemm_config.relu.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_1x8__wasmsimd_splat); - f32_gemm_config.relu.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_4x8__wasmsimd_splat); - f32_gemm_config.relu.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_1x8__wasmsimd_splat); - f32_gemm_config.relu.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_4x8__wasmsimd_splat); + f32_igemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_1x8__wasmsimd_splat); + f32_igemm_config.linear.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_4x8__wasmsimd_splat); + f32_igemm_config.linear.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_1x8__wasmsimd_splat); + f32_igemm_config.linear.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_4x8__wasmsimd_splat); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__wasmsimd_x86_splat); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x8__wasmsimd_x86_splat); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) 
xnn_f32_igemm_minmax_ukernel_1x8__wasmsimd_x86_splat); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x8__wasmsimd_x86_splat); + f32_igemm_config.relu.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_1x8__wasmsimd_splat); + f32_igemm_config.relu.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_4x8__wasmsimd_splat); + f32_igemm_config.relu.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_1x8__wasmsimd_splat); + f32_igemm_config.relu.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_4x8__wasmsimd_splat); } #endif - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__wasmsimd_u4; - f32_gemm_config.mr = 4; - f32_gemm_config.nr = 8; + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__wasmsimd_u4; + f32_igemm_config.mr = 4; + f32_igemm_config.nr = 8; } else { #if XNN_ARCH_WASMRELAXEDSIMD - f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_1x8__wasmrelaxedsimd_fma_splat); - f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_6x8__wasmrelaxedsimd_fma_splat); - f32_gemm_config.linear.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_1x8__wasmrelaxedsimd_fma_splat); - f32_gemm_config.linear.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_6x8__wasmrelaxedsimd_fma_splat); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__wasmrelaxedsimd_fma_splat); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__wasmrelaxedsimd_fma_splat); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__wasmrelaxedsimd_fma_splat); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__wasmrelaxedsimd_fma_splat); - f32_gemm_config.relu.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_1x8__wasmrelaxedsimd_fma_splat); - f32_gemm_config.relu.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_6x8__wasmrelaxedsimd_fma_splat); - f32_gemm_config.relu.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_1x8__wasmrelaxedsimd_fma_splat); - f32_gemm_config.relu.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_6x8__wasmrelaxedsimd_fma_splat); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) 
xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__wasmsimd_u4; - f32_gemm_config.mr = 6; - f32_gemm_config.nr = 8; + f32_igemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_1x8__wasmrelaxedsimd_fma_splat); + f32_igemm_config.linear.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_6x8__wasmrelaxedsimd_fma_splat); + f32_igemm_config.linear.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_1x8__wasmrelaxedsimd_fma_splat); + f32_igemm_config.linear.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_6x8__wasmrelaxedsimd_fma_splat); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__wasmrelaxedsimd_fma_splat); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_6x8__wasmrelaxedsimd_fma_splat); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__wasmrelaxedsimd_fma_splat); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_6x8__wasmrelaxedsimd_fma_splat); + f32_igemm_config.relu.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_1x8__wasmrelaxedsimd_fma_splat); + f32_igemm_config.relu.gemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_6x8__wasmrelaxedsimd_fma_splat); + f32_igemm_config.relu.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_1x8__wasmrelaxedsimd_fma_splat); + f32_igemm_config.relu.igemm[XNN_MR_TO_INDEX(6)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_6x8__wasmrelaxedsimd_fma_splat); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__wasmsimd_u4; + f32_igemm_config.mr = 6; + f32_igemm_config.nr = 8; #else - f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_1x8__wasmsimd_splat); - f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_5x8__wasmsimd_splat); - f32_gemm_config.linear.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_1x8__wasmsimd_splat); - f32_gemm_config.linear.igemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_5x8__wasmsimd_splat); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__wasmsimd_arm_splat); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_5x8__wasmsimd_arm_splat); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__wasmsimd_arm_splat); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(5)] = 
xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_5x8__wasmsimd_arm_splat); - f32_gemm_config.relu.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_1x8__wasmsimd_splat); - f32_gemm_config.relu.gemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_5x8__wasmsimd_splat); - f32_gemm_config.relu.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_1x8__wasmsimd_splat); - f32_gemm_config.relu.igemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_5x8__wasmsimd_splat); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__wasmsimd_u4; - f32_gemm_config.mr = 5; - f32_gemm_config.nr = 8; + f32_igemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_1x8__wasmsimd_splat); + f32_igemm_config.linear.gemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_5x8__wasmsimd_splat); + f32_igemm_config.linear.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_1x8__wasmsimd_splat); + f32_igemm_config.linear.igemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_5x8__wasmsimd_splat); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x8__wasmsimd_arm_splat); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_5x8__wasmsimd_arm_splat); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x8__wasmsimd_arm_splat); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_5x8__wasmsimd_arm_splat); + f32_igemm_config.relu.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_1x8__wasmsimd_splat); + f32_igemm_config.relu.gemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_5x8__wasmsimd_splat); + f32_igemm_config.relu.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_1x8__wasmsimd_splat); + f32_igemm_config.relu.igemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_5x8__wasmsimd_splat); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x8__wasmsimd_u4; + f32_igemm_config.mr = 5; + f32_igemm_config.nr = 8; #endif } #elif XNN_ARCH_WASM @@ -923,81 +926,104 @@ static void init_f32_gemm_config(void) { assert(hardware_config != NULL); (void) hardware_config; // May be unused. 
if (hardware_config->is_x86) { - f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_1x4__scalar); - f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(2)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_2x4__scalar); - f32_gemm_config.linear.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_1x4__scalar); - f32_gemm_config.linear.igemm[XNN_MR_TO_INDEX(2)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_2x4__scalar); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x4__wasm); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(2)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_2x4__scalar); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x4__wasm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(2)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_2x4__scalar); - f32_gemm_config.relu.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_1x4__wasm); - f32_gemm_config.relu.gemm[XNN_MR_TO_INDEX(2)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_2x4__scalar); - f32_gemm_config.relu.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_1x4__wasm); - f32_gemm_config.relu.igemm[XNN_MR_TO_INDEX(2)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_2x4__scalar); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x4__scalar_float_u4; - f32_gemm_config.mr = 2; - f32_gemm_config.nr = 4; + f32_igemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_1x4__scalar); + f32_igemm_config.linear.gemm[XNN_MR_TO_INDEX(2)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_2x4__scalar); + f32_igemm_config.linear.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_1x4__scalar); + f32_igemm_config.linear.igemm[XNN_MR_TO_INDEX(2)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_2x4__scalar); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x4__wasm); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(2)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_2x4__scalar); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x4__wasm); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(2)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_2x4__scalar); + f32_igemm_config.relu.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_1x4__wasm); + f32_igemm_config.relu.gemm[XNN_MR_TO_INDEX(2)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_2x4__scalar); + f32_igemm_config.relu.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) 
xnn_f32_igemm_relu_ukernel_1x4__wasm); + f32_igemm_config.relu.igemm[XNN_MR_TO_INDEX(2)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_2x4__scalar); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x4__scalar_float_u4; + f32_igemm_config.mr = 2; + f32_igemm_config.nr = 4; } else { - f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_1x4__scalar); - f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_4x4__scalar); - f32_gemm_config.linear.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_1x4__scalar); - f32_gemm_config.linear.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_4x4__scalar); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x4__wasm); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x4__wasm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x4__wasm); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x4__wasm); - f32_gemm_config.relu.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_1x4__wasm); - f32_gemm_config.relu.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_4x4__wasm); - f32_gemm_config.relu.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_1x4__wasm); - f32_gemm_config.relu.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_4x4__wasm); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x4__scalar_float_u4; - f32_gemm_config.mr = 4; - f32_gemm_config.nr = 4; + f32_igemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_1x4__scalar); + f32_igemm_config.linear.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_4x4__scalar); + f32_igemm_config.linear.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_1x4__scalar); + f32_igemm_config.linear.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_4x4__scalar); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x4__wasm); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x4__wasm); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x4__wasm); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = 
xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x4__wasm); + f32_igemm_config.relu.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_1x4__wasm); + f32_igemm_config.relu.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_4x4__wasm); + f32_igemm_config.relu.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_1x4__wasm); + f32_igemm_config.relu.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_4x4__wasm); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x4__scalar_float_u4; + f32_igemm_config.mr = 4; + f32_igemm_config.nr = 4; } #elif XNN_ARCH_RISCV && XNN_ENABLE_RISCV_VECTOR const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); assert(hardware_config != NULL); (void) hardware_config; // May be unused. if (hardware_config->use_riscv_vector) { - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_7x4v__rvv); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x4v__rvv); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_7x4v__rvv); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x4v__rvv); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x4v__rvv_u8; - f32_gemm_config.mr = 7; + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_7x4v__rvv); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x4v__rvv); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_7x4v__rvv); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x4v__rvv); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x4v__rvv_u8; + f32_igemm_config.mr = 7; // nr is set to vlen * 4 / sizeof(float) = 4 * VLENB * 8 / 32 = VLENB - f32_gemm_config.nr = hardware_config->vlenb; + f32_igemm_config.nr = hardware_config->vlenb; return; } #else - f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_1x4__scalar); - f32_gemm_config.linear.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_4x4__scalar); - f32_gemm_config.linear.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_1x4__scalar); - 
f32_gemm_config.linear.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_4x4__scalar); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x4__scalar); - f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x4__scalar); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x4__scalar); - f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x4__scalar); - f32_gemm_config.relu.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_1x4__scalar); - f32_gemm_config.relu.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_4x4__scalar); - f32_gemm_config.relu.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_1x4__scalar); - f32_gemm_config.relu.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_4x4__scalar); - f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; - f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; - f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x4__scalar_float_u4; - f32_gemm_config.mr = 4; - f32_gemm_config.nr = 4; + f32_igemm_config.linear.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_1x4__scalar); + f32_igemm_config.linear.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_ukernel_4x4__scalar); + f32_igemm_config.linear.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_1x4__scalar); + f32_igemm_config.linear.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_ukernel_4x4__scalar); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x4__scalar); + f32_igemm_config.minmax.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_4x4__scalar); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x4__scalar); + f32_igemm_config.minmax.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_4x4__scalar); + f32_igemm_config.relu.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_1x4__scalar); + f32_igemm_config.relu.gemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_relu_ukernel_4x4__scalar); + f32_igemm_config.relu.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_1x4__scalar); + f32_igemm_config.relu.igemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_relu_ukernel_4x4__scalar); + f32_igemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_igemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_igemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) 
xnn_x32_packw_gemm_goi_ukernel_x4__scalar_float_u4; + f32_igemm_config.mr = 4; + f32_igemm_config.nr = 4; #endif assert(f32_gemm_config.mr <= XNN_MAX_MR); } +static void init_f32_gemm_config(void) { + init_f32_igemm_config(); + memcpy(&f32_gemm_config, &f32_igemm_config, sizeof(f32_gemm_config)); + #if XNN_ARCH_X86 || XNN_ARCH_X86_64 + const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); + assert(hardware_config != NULL); + (void) hardware_config; // May be unused. + #if XNN_ENABLE_AVX512F + if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512f) { + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast); + f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(5)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast); + f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params; + f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w; + f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_f32_gemm_goi_w; + f32_gemm_config.mr = 5; + f32_gemm_config.nr = 32; + f32_gemm_config.log2_kr = 1; + } + #endif + #endif +} + static void init_f32_gemm_nr2_config(void) { #if XNN_ARCH_ARM const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config(); @@ -1334,9 +1360,9 @@ static void init_f32_qc8w_gemm_config(void) { } #if XNN_MAX_UARCH_TYPES > 1 /* Choose micro-kernels for little cores according to micro-kernel specification for the big core */ - const uint32_t mr = f32_gemm_config.mr; - const uint32_t nr = f32_gemm_config.nr; - const uint32_t log2_sr = f32_gemm_config.log2_sr; + const uint32_t mr = f32_igemm_config.mr; + const uint32_t nr = f32_igemm_config.nr; + const uint32_t log2_sr = f32_igemm_config.log2_sr; // TODO(fbarchard): fill in with microkernels. (void) mr; (void) nr; @@ -4353,6 +4379,14 @@ const struct xnn_gemm_config* xnn_init_f32_gemm_config() { return &f32_gemm_config; } +const struct xnn_gemm_config* xnn_init_f32_igemm_config() { + if (xnn_init_hardware_config() == NULL) { + return NULL; + } + XNN_INIT_ONCE(f32_igemm); + return &f32_igemm_config; +} + const struct xnn_gemm_config* xnn_init_f32_gemm_nr2_config() { if (xnn_init_hardware_config() == NULL) { return NULL; diff --git a/src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S index 5aeef004b89..2ea73a895d9 100644 --- a/src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S @@ -167,50 +167,50 @@ outer_loop: mov r8, [rsp + 160] # Initialize accumulators with the biases.
- vmovaps zmm7, [r9 + 0] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 - vmovaps zmm17, zmm7 - vmovaps zmm18, zmm7 - vmovaps zmm19, zmm7 - vmovaps zmm20, zmm7 + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm18, zmm11 + vmovaps zmm19, zmm11 + vmovaps zmm20, zmm11 add r9, 64 inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm11, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm2, zmm10 + vfmadd231ps zmm12, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm2, zmm10 + vfmadd231ps zmm13, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vfmadd231ps zmm14, zmm2, zmm10 + vfmadd231ps zmm14, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vfmadd231ps zmm15, zmm2, zmm10 + vfmadd231ps zmm15, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vfmadd231ps zmm16, zmm2, zmm10 + vfmadd231ps zmm16, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r13 + r11] - vfmadd231ps zmm17, zmm2, zmm10 + vfmadd231ps zmm17, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rbx + r11] - vfmadd231ps zmm18, zmm2, zmm10 + vfmadd231ps zmm18, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rbp + r11] - vfmadd231ps zmm19, zmm2, zmm10 + vfmadd231ps zmm19, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r8 + r11] - vfmadd231ps zmm20, zmm2, zmm10 + vfmadd231ps zmm20, zmm2, zmm7 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -218,9 +218,9 @@ inner_loop_end: vminps zmm18, zmm1, zmm18 vminps zmm19, zmm1, zmm19 vminps zmm20, zmm1, zmm20 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -245,9 +245,9 @@ inner_loop_end: cmp rsi, 16 jl tail - vmovups [rcx], zmm7 - vmovups [rax], zmm8 - vmovups [r15], zmm9 + vmovups [rcx], zmm11 + vmovups [rax], zmm12 + vmovups [r15], zmm13 vmovups [r14], zmm14 vmovups [r12], zmm15 vmovups [r10], zmm16 @@ -287,9 +287,9 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r12]{k1}, zmm15 vmovups ZMMWORD PTR [r10]{k1}, zmm16 diff --git a/src/f32-gemm/gen/f32-gemm-10x16c2-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-10x16c2-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 00000000000..b253a1ea931 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-10x16c2-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,470 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
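+// Note on the c2 (kr = 2) scheme used in this kernel: the packed weights interleave two
+// consecutive k values per output column, so each 128-byte step of B covers 16 columns x 2
+// k elements. Every inner-loop step broadcasts a pair of A values with vbroadcastsd and
+// accumulates both into the same lane pair; after the loop the pairs are folded together
+// (vpsrlq + vaddps) and compacted back to 16 contiguous floats through the .PERMUTATION
+// table below before min/max clamping and the store.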
+ +#include "xnnpack/assembly.h" +.PERMUTATION: + .long 0 + .long 2 + .long 4 + .long 6 + .long 8 + .long 10 + .long 12 + .long 14 + .long 16 + .long 18 + .long 20 + .long 22 + .long 24 + .long 26 + .long 28 + .long 30 + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_10x16c2__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Align the stack pointer. + mov r13, rsp + sub rsp, 64 + and rsp, 0xFFFFFFFFFFFFFFC0 + # Store the old stack pointer containing the return address + mov [rsp], r13 + + # Allocate some space on the stack. + sub rsp, 256 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp + 16], rcx + # Write r10 (c pointer) to the stack as we need the register. + mov [rsp + 24], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 32], rax + mov [rsp + 40], r13 + + # Clamp a & c pointers if mr <= 2 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 48], rcx + mov [rsp + 56], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 64], rax + mov [rsp + 72], r13 + + # Clamp a & c pointers if mr <= 4 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 80], rcx + mov [rsp + 88], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 96], rax + mov [rsp + 104], r13 + + # Clamp a & c pointers if mr <= 6 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 112], rcx + mov [rsp + 120], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 128], rax + mov [rsp + 136], r13 + + # Clamp a & c pointers if mr <= 8 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 144], rcx + mov [rsp + 152], r10 + + # Clamp a & c pointers if mr <= 9 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 9 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 160], rax + mov [rsp + 168], r13 + + # Copy k and flip bit. + mov r11, rdx + and r11, 0x4 + and rdx, 0xFFFFFFFFFFFFFFFB + mov [rsp + 184], r11 + mov r11, 0x5555 + kmovw k3, r11d + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rcx, [rsp + 16] + mov rax, [rsp + 32] + mov r15, [rsp + 48] + mov r14, [rsp + 64] + mov r12, [rsp + 80] + mov r10, [rsp + 96] + mov r13, [rsp + 112] + mov rbx, [rsp + 128] + mov rbp, [rsp + 144] + mov r8, [rsp + 160] + + vmovaps zmm7, [r9 + 0] + # Interleave with zeros. 
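+ # (vpmovzxdq zero-extends each 32-bit bias to 64 bits, so every bias lands in an even lane
+ # with a zero in the odd lane above it; each accumulator pair therefore starts as {bias, 0}
+ # and the bias is counted exactly once in the final pairwise reduction.)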
+ vpmovzxdq zmm11, ymm7 + vextracti64x4 ymm7, zmm7, 1 + vpmovzxdq zmm21, ymm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm18, zmm11 + vmovaps zmm19, zmm11 + vmovaps zmm20, zmm11 + vmovaps zmm22, zmm21 + vmovaps zmm23, zmm21 + vmovaps zmm24, zmm21 + vmovaps zmm25, zmm21 + vmovaps zmm26, zmm21 + vmovaps zmm27, zmm21 + vmovaps zmm28, zmm21 + vmovaps zmm29, zmm21 + vmovaps zmm30, zmm21 + add r9, 64 + + # Are there at least 8 bytes? + cmp rdx, 8 + js inner_loop_tail + +inner_loop: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm21, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rax + r11] + vfmadd231ps zmm12, zmm2, zmm7 + vfmadd231ps zmm22, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r15 + r11] + vfmadd231ps zmm13, zmm2, zmm7 + vfmadd231ps zmm23, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm7 + vfmadd231ps zmm24, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm7 + vfmadd231ps zmm25, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm7 + vfmadd231ps zmm26, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm7 + vfmadd231ps zmm27, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rbx + r11] + vfmadd231ps zmm18, zmm2, zmm7 + vfmadd231ps zmm28, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rbp + r11] + vfmadd231ps zmm19, zmm2, zmm7 + vfmadd231ps zmm29, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r8 + r11] + vfmadd231ps zmm20, zmm2, zmm7 + vfmadd231ps zmm30, zmm2, zmm8 + + add r11, 8 + cmp rdx, r11 + jne inner_loop + + # Store nc_register. + mov [rsp + 192], rsi + # Load odd k bit. + mov rsi, [rsp + 184] + # Check if channels are odd. 
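+ # (The flag saved at [rsp + 184] is kc & 4; kc is in bytes, so a set bit means one trailing
+ # k element is left over. In that case fall through to one last iteration in which mask
+ # k3 = 0x5555 keeps only the even lanes, so only the valid first element of each pair is
+ # accumulated.)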
+ test rsi, rsi + mov rsi, [rsp + 192] + jz inner_loop_end + +inner_loop_tail: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11{k3}, zmm2, zmm7 + vfmadd231ps zmm21{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rax + r11] + vfmadd231ps zmm12{k3}, zmm2, zmm7 + vfmadd231ps zmm22{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r15 + r11] + vfmadd231ps zmm13{k3}, zmm2, zmm7 + vfmadd231ps zmm23{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r14 + r11] + vfmadd231ps zmm14{k3}, zmm2, zmm7 + vfmadd231ps zmm24{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r12 + r11] + vfmadd231ps zmm15{k3}, zmm2, zmm7 + vfmadd231ps zmm25{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r10 + r11] + vfmadd231ps zmm16{k3}, zmm2, zmm7 + vfmadd231ps zmm26{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r13 + r11] + vfmadd231ps zmm17{k3}, zmm2, zmm7 + vfmadd231ps zmm27{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rbx + r11] + vfmadd231ps zmm18{k3}, zmm2, zmm7 + vfmadd231ps zmm28{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rbp + r11] + vfmadd231ps zmm19{k3}, zmm2, zmm7 + vfmadd231ps zmm29{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r8 + r11] + vfmadd231ps zmm20{k3}, zmm2, zmm7 + vfmadd231ps zmm30{k3}, zmm2, zmm8 +inner_loop_end: + vpsrlq zmm7, zmm11, 32 + vaddps zmm11, zmm11, zmm7 + vpsrlq zmm7, zmm12, 32 + vaddps zmm12, zmm12, zmm7 + vpsrlq zmm7, zmm13, 32 + vaddps zmm13, zmm13, zmm7 + vpsrlq zmm7, zmm14, 32 + vaddps zmm14, zmm14, zmm7 + vpsrlq zmm7, zmm15, 32 + vaddps zmm15, zmm15, zmm7 + vpsrlq zmm7, zmm16, 32 + vaddps zmm16, zmm16, zmm7 + vpsrlq zmm7, zmm17, 32 + vaddps zmm17, zmm17, zmm7 + vpsrlq zmm7, zmm18, 32 + vaddps zmm18, zmm18, zmm7 + vpsrlq zmm7, zmm19, 32 + vaddps zmm19, zmm19, zmm7 + vpsrlq zmm7, zmm20, 32 + vaddps zmm20, zmm20, zmm7 + vpsrlq zmm7, zmm21, 32 + vaddps zmm21, zmm21, zmm7 + vpsrlq zmm7, zmm22, 32 + vaddps zmm22, zmm22, zmm7 + vpsrlq zmm7, zmm23, 32 + vaddps zmm23, zmm23, zmm7 + vpsrlq zmm7, zmm24, 32 + vaddps zmm24, zmm24, zmm7 + vpsrlq zmm7, zmm25, 32 + vaddps zmm25, zmm25, zmm7 + vpsrlq zmm7, zmm26, 32 + vaddps zmm26, zmm26, zmm7 + vpsrlq zmm7, zmm27, 32 + vaddps zmm27, zmm27, zmm7 + vpsrlq zmm7, zmm28, 32 + vaddps zmm28, zmm28, zmm7 + vpsrlq zmm7, zmm29, 32 + vaddps zmm29, zmm29, zmm7 + vpsrlq zmm7, zmm30, 32 + vaddps zmm30, zmm30, zmm7 + vmovups zmm7, zmmword ptr [rip + .PERMUTATION] + vpermt2ps zmm11, zmm7, zmm21 + vpermt2ps zmm12, zmm7, zmm22 + vpermt2ps zmm13, zmm7, zmm23 + vpermt2ps zmm14, zmm7, zmm24 + vpermt2ps zmm15, zmm7, zmm25 + vpermt2ps zmm16, zmm7, zmm26 + vpermt2ps zmm17, zmm7, zmm27 + vpermt2ps zmm18, zmm7, zmm28 + vpermt2ps zmm19, zmm7, zmm29 + vpermt2ps zmm20, zmm7, zmm30 + # Min/max clamping. + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + + # Pop output pointers from the stack. 
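+ # (These per-row c pointers were clamped against mr in the prologue, so rows beyond the
+ # requested mr alias the previous row and the extra stores below simply rewrite that row
+ # with the same values.)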
+ mov rcx, [rsp + 24] + mov rax, [rsp + 40] + mov r15, [rsp + 56] + mov r14, [rsp + 72] + mov r12, [rsp + 88] + mov r10, [rsp + 104] + mov r13, [rsp + 120] + mov rbx, [rsp + 136] + mov rbp, [rsp + 152] + mov r8, [rsp + 168] + + # Check whether full or partial store. + cmp rsi, 16 + jl tail + + vmovups [rcx], zmm11 + vmovups [rax], zmm12 + vmovups [r15], zmm13 + vmovups [r14], zmm14 + vmovups [r12], zmm15 + vmovups [r10], zmm16 + vmovups [r13], zmm17 + vmovups [rbx], zmm18 + vmovups [rbp], zmm19 + vmovups [r8], zmm20 + add rcx, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + add r8, 64 + + # Write output pointers to the stack. + mov [rsp + 24], rcx + mov [rsp + 40], rax + mov [rsp + 56], r15 + mov [rsp + 72], r14 + mov [rsp + 88], r12 + mov [rsp + 104], r10 + mov [rsp + 120], r13 + mov [rsp + 136], rbx + mov [rsp + 152], rbp + mov [rsp + 168], r8 + + sub rsi, 16 + jne outer_loop + jmp return + +tail: + mov r11, -1 + shlx r11, r11, rsi + not r11 + kmovw k1, r11d + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [rbx]{k1}, zmm18 + vmovups ZMMWORD PTR [rbp]{k1}, zmm19 + vmovups ZMMWORD PTR [r8]{k1}, zmm20 + +return: + add rsp, 256 + mov r13, [rsp] + mov rsp, r13 + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_10x16c2__asm_amd64_avx512f_broadcast + + #ifdef __has_feature + #if __has_feature(dataflow_sanitizer) +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_10x16c2__asm_amd64_avx512f_broadcast.dfsan + .intel_syntax noprefix + # We could implement this by calling a function that implements the dfsan instrumentation. + # For now, just break, so if someone tries to use this, they'll know where the problem is. + int 3 + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_10x16c2__asm_amd64_avx512f_broadcast.dfsan + #endif + #endif \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S index 273754bbe56..58040d4e460 100644 --- a/src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S @@ -167,17 +167,17 @@ outer_loop: mov r8, [rsp + 160] # Initialize accumulators with the biases. 
- vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm21, [r9 + 64] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 - vmovaps zmm17, zmm7 - vmovaps zmm18, zmm7 - vmovaps zmm19, zmm7 - vmovaps zmm20, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm18, zmm11 + vmovaps zmm19, zmm11 + vmovaps zmm20, zmm11 vmovaps zmm22, zmm21 vmovaps zmm23, zmm21 vmovaps zmm24, zmm21 @@ -190,48 +190,48 @@ outer_loop: add r9, 128 inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 - vfmadd231ps zmm21, zmm2, zmm11 + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm21, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm2, zmm10 - vfmadd231ps zmm22, zmm2, zmm11 + vfmadd231ps zmm12, zmm2, zmm7 + vfmadd231ps zmm22, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm2, zmm10 - vfmadd231ps zmm23, zmm2, zmm11 + vfmadd231ps zmm13, zmm2, zmm7 + vfmadd231ps zmm23, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vfmadd231ps zmm14, zmm2, zmm10 - vfmadd231ps zmm24, zmm2, zmm11 + vfmadd231ps zmm14, zmm2, zmm7 + vfmadd231ps zmm24, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vfmadd231ps zmm15, zmm2, zmm10 - vfmadd231ps zmm25, zmm2, zmm11 + vfmadd231ps zmm15, zmm2, zmm7 + vfmadd231ps zmm25, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vfmadd231ps zmm16, zmm2, zmm10 - vfmadd231ps zmm26, zmm2, zmm11 + vfmadd231ps zmm16, zmm2, zmm7 + vfmadd231ps zmm26, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r13 + r11] - vfmadd231ps zmm17, zmm2, zmm10 - vfmadd231ps zmm27, zmm2, zmm11 + vfmadd231ps zmm17, zmm2, zmm7 + vfmadd231ps zmm27, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rbx + r11] - vfmadd231ps zmm18, zmm2, zmm10 - vfmadd231ps zmm28, zmm2, zmm11 + vfmadd231ps zmm18, zmm2, zmm7 + vfmadd231ps zmm28, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rbp + r11] - vfmadd231ps zmm19, zmm2, zmm10 - vfmadd231ps zmm29, zmm2, zmm11 + vfmadd231ps zmm19, zmm2, zmm7 + vfmadd231ps zmm29, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r8 + r11] - vfmadd231ps zmm20, zmm2, zmm10 - vfmadd231ps zmm30, zmm2, zmm11 + vfmadd231ps zmm20, zmm2, zmm7 + vfmadd231ps zmm30, zmm2, zmm8 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -249,9 +249,9 @@ inner_loop_end: vminps zmm28, zmm1, zmm28 vminps zmm29, zmm1, zmm29 vminps zmm30, zmm1, zmm30 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -286,11 +286,11 @@ inner_loop_end: cmp rsi, 32 jl tail - vmovups [rcx], zmm7 + vmovups [rcx], zmm11 vmovups [rcx + 64], zmm21 - vmovups [rax], zmm8 + vmovups [rax], zmm12 vmovups [rax + 64], zmm22 - vmovups [r15], zmm9 + vmovups [r15], zmm13 vmovups [r15 + 64], zmm23 vmovups [r14], zmm14 vmovups [r14 + 64], zmm24 @@ -340,11 +340,11 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 vmovups ZMMWORD PTR [rcx + 64]{k2}, zmm21 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 vmovups ZMMWORD PTR [rax + 64]{k2}, zmm22 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm23 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm24 diff --git a/src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S index eac6ea58e23..3361d9a70bc 100644 --- a/src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S @@ -180,53 +180,53 @@ outer_loop: mov rdi, [rsp + 176] # Initialize accumulators with the biases. 
- vmovaps zmm7, [r9 + 0] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 - vmovaps zmm17, zmm7 - vmovaps zmm18, zmm7 - vmovaps zmm19, zmm7 - vmovaps zmm20, zmm7 - vmovaps zmm21, zmm7 + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm18, zmm11 + vmovaps zmm19, zmm11 + vmovaps zmm20, zmm11 + vmovaps zmm21, zmm11 add r9, 64 inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm11, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm2, zmm10 + vfmadd231ps zmm12, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm2, zmm10 + vfmadd231ps zmm13, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vfmadd231ps zmm14, zmm2, zmm10 + vfmadd231ps zmm14, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vfmadd231ps zmm15, zmm2, zmm10 + vfmadd231ps zmm15, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vfmadd231ps zmm16, zmm2, zmm10 + vfmadd231ps zmm16, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r13 + r11] - vfmadd231ps zmm17, zmm2, zmm10 + vfmadd231ps zmm17, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rbx + r11] - vfmadd231ps zmm18, zmm2, zmm10 + vfmadd231ps zmm18, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rbp + r11] - vfmadd231ps zmm19, zmm2, zmm10 + vfmadd231ps zmm19, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r8 + r11] - vfmadd231ps zmm20, zmm2, zmm10 + vfmadd231ps zmm20, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rdi + r11] - vfmadd231ps zmm21, zmm2, zmm10 + vfmadd231ps zmm21, zmm2, zmm7 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -235,9 +235,9 @@ inner_loop_end: vminps zmm19, zmm1, zmm19 vminps zmm20, zmm1, zmm20 vminps zmm21, zmm1, zmm21 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -264,9 +264,9 @@ inner_loop_end: cmp rsi, 16 jl tail - vmovups [rcx], zmm7 - vmovups [rax], zmm8 - vmovups [r15], zmm9 + vmovups [rcx], zmm11 + vmovups [rax], zmm12 + vmovups [r15], zmm13 vmovups [r14], zmm14 vmovups [r12], zmm15 vmovups [r10], zmm16 @@ -309,9 +309,9 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r12]{k1}, zmm15 vmovups ZMMWORD PTR [r10]{k1}, zmm16 diff --git a/src/f32-gemm/gen/f32-gemm-11x16c2-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-11x16c2-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 00000000000..ca5bd294914 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-11x16c2-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,503 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "xnnpack/assembly.h" +.PERMUTATION: + .long 0 + .long 2 + .long 4 + .long 6 + .long 8 + .long 10 + .long 12 + .long 14 + .long 16 + .long 18 + .long 20 + .long 22 + .long 24 + .long 26 + .long 28 + .long 30 + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_11x16c2__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Align the stack pointer. + mov r13, rsp + sub rsp, 64 + and rsp, 0xFFFFFFFFFFFFFFC0 + # Store the old stack pointer containing the return address + mov [rsp], r13 + + # Allocate some space on the stack. + sub rsp, 256 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp + 16], rcx + # Write r10 (c pointer) to the stack as we need the register. 
+ mov [rsp + 24], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 32], rax + mov [rsp + 40], r13 + + # Clamp a & c pointers if mr <= 2 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 48], rcx + mov [rsp + 56], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 64], rax + mov [rsp + 72], r13 + + # Clamp a & c pointers if mr <= 4 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 80], rcx + mov [rsp + 88], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 96], rax + mov [rsp + 104], r13 + + # Clamp a & c pointers if mr <= 6 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 112], rcx + mov [rsp + 120], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 128], rax + mov [rsp + 136], r13 + + # Clamp a & c pointers if mr <= 8 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 144], rcx + mov [rsp + 152], r10 + + # Clamp a & c pointers if mr <= 9 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 9 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 160], rax + mov [rsp + 168], r13 + + # Clamp a & c pointers if mr <= 10 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 10 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 176], rcx + mov [rsp + 184], r10 + + # Copy k and flip bit. + mov r11, rdx + and r11, 0x4 + and rdx, 0xFFFFFFFFFFFFFFFB + mov [rsp + 200], r11 + mov r11, 0x5555 + kmovw k3, r11d + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rcx, [rsp + 16] + mov rax, [rsp + 32] + mov r15, [rsp + 48] + mov r14, [rsp + 64] + mov r12, [rsp + 80] + mov r10, [rsp + 96] + mov r13, [rsp + 112] + mov rbx, [rsp + 128] + mov rbp, [rsp + 144] + mov r8, [rsp + 160] + mov rdi, [rsp + 176] + + vmovaps zmm7, [r9 + 0] + # Interleave with zeros. + vpmovzxdq zmm11, ymm7 + vextracti64x4 ymm7, zmm7, 1 + vpmovzxdq zmm22, ymm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm18, zmm11 + vmovaps zmm19, zmm11 + vmovaps zmm20, zmm11 + vmovaps zmm21, zmm11 + vmovaps zmm23, zmm22 + vmovaps zmm24, zmm22 + vmovaps zmm25, zmm22 + vmovaps zmm26, zmm22 + vmovaps zmm27, zmm22 + vmovaps zmm28, zmm22 + vmovaps zmm29, zmm22 + vmovaps zmm30, zmm22 + vmovaps zmm9, zmm22 + vmovaps zmm10, zmm22 + add r9, 64 + + # Are there at least 8 bytes? 
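+ # (kc is in bytes and its odd-element bit was cleared earlier in the prologue, so fewer than
+ # 8 bytes means no full k pair remains and only the masked tail iteration runs.)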
+ cmp rdx, 8 + js inner_loop_tail + +inner_loop: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm22, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rax + r11] + vfmadd231ps zmm12, zmm2, zmm7 + vfmadd231ps zmm23, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r15 + r11] + vfmadd231ps zmm13, zmm2, zmm7 + vfmadd231ps zmm24, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm7 + vfmadd231ps zmm25, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm7 + vfmadd231ps zmm26, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm7 + vfmadd231ps zmm27, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm7 + vfmadd231ps zmm28, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rbx + r11] + vfmadd231ps zmm18, zmm2, zmm7 + vfmadd231ps zmm29, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rbp + r11] + vfmadd231ps zmm19, zmm2, zmm7 + vfmadd231ps zmm30, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r8 + r11] + vfmadd231ps zmm20, zmm2, zmm7 + vfmadd231ps zmm9, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rdi + r11] + vfmadd231ps zmm21, zmm2, zmm7 + vfmadd231ps zmm10, zmm2, zmm8 + + add r11, 8 + cmp rdx, r11 + jne inner_loop + + # Store nc_register. + mov [rsp + 208], rsi + # Load odd k bit. + mov rsi, [rsp + 200] + # Check if channels are odd. + test rsi, rsi + mov rsi, [rsp + 208] + jz inner_loop_end + +inner_loop_tail: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11{k3}, zmm2, zmm7 + vfmadd231ps zmm22{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rax + r11] + vfmadd231ps zmm12{k3}, zmm2, zmm7 + vfmadd231ps zmm23{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r15 + r11] + vfmadd231ps zmm13{k3}, zmm2, zmm7 + vfmadd231ps zmm24{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r14 + r11] + vfmadd231ps zmm14{k3}, zmm2, zmm7 + vfmadd231ps zmm25{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r12 + r11] + vfmadd231ps zmm15{k3}, zmm2, zmm7 + vfmadd231ps zmm26{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r10 + r11] + vfmadd231ps zmm16{k3}, zmm2, zmm7 + vfmadd231ps zmm27{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r13 + r11] + vfmadd231ps zmm17{k3}, zmm2, zmm7 + vfmadd231ps zmm28{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rbx + r11] + vfmadd231ps zmm18{k3}, zmm2, zmm7 + vfmadd231ps zmm29{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rbp + r11] + vfmadd231ps zmm19{k3}, zmm2, zmm7 + vfmadd231ps zmm30{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r8 + r11] + vfmadd231ps zmm20{k3}, zmm2, zmm7 + vfmadd231ps zmm9{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rdi + r11] + vfmadd231ps zmm21{k3}, zmm2, zmm7 + vfmadd231ps zmm10{k3}, zmm2, zmm8 +inner_loop_end: + vpsrlq zmm7, zmm11, 32 + vaddps zmm11, zmm11, zmm7 + vpsrlq zmm7, zmm12, 32 + vaddps zmm12, zmm12, zmm7 + vpsrlq zmm7, zmm13, 32 + vaddps zmm13, zmm13, zmm7 + vpsrlq zmm7, zmm14, 32 + vaddps zmm14, zmm14, zmm7 + vpsrlq zmm7, zmm15, 32 + vaddps zmm15, zmm15, zmm7 + vpsrlq zmm7, zmm16, 32 + vaddps zmm16, zmm16, zmm7 + vpsrlq zmm7, zmm17, 32 + vaddps zmm17, zmm17, zmm7 + vpsrlq zmm7, zmm18, 32 + vaddps zmm18, zmm18, zmm7 + vpsrlq zmm7, zmm19, 32 + vaddps zmm19, zmm19, zmm7 + vpsrlq zmm7, zmm20, 32 + vaddps zmm20, zmm20, zmm7 + vpsrlq zmm7, zmm21, 32 + vaddps zmm21, zmm21, zmm7 + vpsrlq zmm7, zmm22, 32 + vaddps zmm22, zmm22, zmm7 + vpsrlq zmm7, zmm23, 32 + vaddps 
zmm23, zmm23, zmm7 + vpsrlq zmm7, zmm24, 32 + vaddps zmm24, zmm24, zmm7 + vpsrlq zmm7, zmm25, 32 + vaddps zmm25, zmm25, zmm7 + vpsrlq zmm7, zmm26, 32 + vaddps zmm26, zmm26, zmm7 + vpsrlq zmm7, zmm27, 32 + vaddps zmm27, zmm27, zmm7 + vpsrlq zmm7, zmm28, 32 + vaddps zmm28, zmm28, zmm7 + vpsrlq zmm7, zmm29, 32 + vaddps zmm29, zmm29, zmm7 + vpsrlq zmm7, zmm30, 32 + vaddps zmm30, zmm30, zmm7 + vpsrlq zmm7, zmm9, 32 + vaddps zmm9, zmm9, zmm7 + vpsrlq zmm7, zmm10, 32 + vaddps zmm10, zmm10, zmm7 + vmovups zmm7, zmmword ptr [rip + .PERMUTATION] + vpermt2ps zmm11, zmm7, zmm22 + vpermt2ps zmm12, zmm7, zmm23 + vpermt2ps zmm13, zmm7, zmm24 + vpermt2ps zmm14, zmm7, zmm25 + vpermt2ps zmm15, zmm7, zmm26 + vpermt2ps zmm16, zmm7, zmm27 + vpermt2ps zmm17, zmm7, zmm28 + vpermt2ps zmm18, zmm7, zmm29 + vpermt2ps zmm19, zmm7, zmm30 + vpermt2ps zmm20, zmm7, zmm9 + vpermt2ps zmm21, zmm7, zmm10 + # Min/max clamping. + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vminps zmm20, zmm1, zmm20 + vminps zmm21, zmm1, zmm21 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + vmaxps zmm21, zmm0, zmm21 + + # Pop output pointers from the stack. + mov rcx, [rsp + 24] + mov rax, [rsp + 40] + mov r15, [rsp + 56] + mov r14, [rsp + 72] + mov r12, [rsp + 88] + mov r10, [rsp + 104] + mov r13, [rsp + 120] + mov rbx, [rsp + 136] + mov rbp, [rsp + 152] + mov r8, [rsp + 168] + mov rdi, [rsp + 184] + + # Check whether full or partial store. + cmp rsi, 16 + jl tail + + vmovups [rcx], zmm11 + vmovups [rax], zmm12 + vmovups [r15], zmm13 + vmovups [r14], zmm14 + vmovups [r12], zmm15 + vmovups [r10], zmm16 + vmovups [r13], zmm17 + vmovups [rbx], zmm18 + vmovups [rbp], zmm19 + vmovups [r8], zmm20 + vmovups [rdi], zmm21 + add rcx, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + add r8, 64 + add rdi, 64 + + # Write output pointers to the stack. + mov [rsp + 24], rcx + mov [rsp + 40], rax + mov [rsp + 56], r15 + mov [rsp + 72], r14 + mov [rsp + 88], r12 + mov [rsp + 104], r10 + mov [rsp + 120], r13 + mov [rsp + 136], rbx + mov [rsp + 152], rbp + mov [rsp + 168], r8 + mov [rsp + 184], rdi + + sub rsi, 16 + jne outer_loop + jmp return + +tail: + mov r11, -1 + shlx r11, r11, rsi + not r11 + kmovw k1, r11d + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [rbx]{k1}, zmm18 + vmovups ZMMWORD PTR [rbp]{k1}, zmm19 + vmovups ZMMWORD PTR [r8]{k1}, zmm20 + vmovups ZMMWORD PTR [rdi]{k1}, zmm21 + +return: + add rsp, 256 + mov r13, [rsp] + mov rsp, r13 + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_11x16c2__asm_amd64_avx512f_broadcast + + #ifdef __has_feature + #if __has_feature(dataflow_sanitizer) +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_11x16c2__asm_amd64_avx512f_broadcast.dfsan + .intel_syntax noprefix + # We could implement this by calling a function that implements the dfsan instrumentation. + # For now, just break, so if someone tries to use this, they'll know where the problem is. + int 3 + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_11x16c2__asm_amd64_avx512f_broadcast.dfsan + #endif + #endif \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S index e5f766bedc6..87d49db266b 100644 --- a/src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S @@ -180,18 +180,18 @@ outer_loop: mov rdi, [rsp + 176] # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm22, [r9 + 64] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 - vmovaps zmm17, zmm7 - vmovaps zmm18, zmm7 - vmovaps zmm19, zmm7 - vmovaps zmm20, zmm7 - vmovaps zmm21, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm18, zmm11 + vmovaps zmm19, zmm11 + vmovaps zmm20, zmm11 + vmovaps zmm21, zmm11 vmovaps zmm23, zmm22 vmovaps zmm24, zmm22 vmovaps zmm25, zmm22 @@ -200,56 +200,56 @@ outer_loop: vmovaps zmm28, zmm22 vmovaps zmm29, zmm22 vmovaps zmm30, zmm22 - vmovaps zmm12, zmm22 - vmovaps zmm13, zmm22 + vmovaps zmm9, zmm22 + vmovaps zmm10, zmm22 add r9, 128 inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 - vfmadd231ps zmm22, zmm2, zmm11 + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm22, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm2, zmm10 - vfmadd231ps zmm23, zmm2, zmm11 + vfmadd231ps zmm12, zmm2, zmm7 + vfmadd231ps zmm23, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm2, zmm10 - vfmadd231ps zmm24, zmm2, zmm11 + vfmadd231ps zmm13, zmm2, zmm7 + vfmadd231ps zmm24, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vfmadd231ps zmm14, zmm2, zmm10 - vfmadd231ps zmm25, zmm2, zmm11 + vfmadd231ps zmm14, zmm2, zmm7 + vfmadd231ps zmm25, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vfmadd231ps zmm15, zmm2, zmm10 - vfmadd231ps zmm26, zmm2, zmm11 + vfmadd231ps zmm15, zmm2, zmm7 + vfmadd231ps zmm26, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vfmadd231ps zmm16, zmm2, zmm10 - vfmadd231ps zmm27, zmm2, zmm11 + vfmadd231ps zmm16, zmm2, zmm7 + vfmadd231ps zmm27, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r13 + r11] - vfmadd231ps zmm17, zmm2, zmm10 - vfmadd231ps zmm28, zmm2, zmm11 + vfmadd231ps zmm17, zmm2, zmm7 + vfmadd231ps zmm28, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rbx + r11] - vfmadd231ps zmm18, zmm2, zmm10 - vfmadd231ps zmm29, zmm2, zmm11 + vfmadd231ps zmm18, zmm2, zmm7 + vfmadd231ps zmm29, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rbp + r11] - vfmadd231ps zmm19, zmm2, zmm10 - vfmadd231ps zmm30, zmm2, zmm11 + vfmadd231ps zmm19, zmm2, zmm7 + vfmadd231ps zmm30, zmm2, zmm8 
vbroadcastss zmm2, DWORD PTR [r8 + r11] - vfmadd231ps zmm20, zmm2, zmm10 - vfmadd231ps zmm12, zmm2, zmm11 + vfmadd231ps zmm20, zmm2, zmm7 + vfmadd231ps zmm9, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rdi + r11] - vfmadd231ps zmm21, zmm2, zmm10 - vfmadd231ps zmm13, zmm2, zmm11 + vfmadd231ps zmm21, zmm2, zmm7 + vfmadd231ps zmm10, zmm2, zmm8 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -267,11 +267,11 @@ inner_loop_end: vminps zmm28, zmm1, zmm28 vminps zmm29, zmm1, zmm29 vminps zmm30, zmm1, zmm30 - vminps zmm12, zmm1, zmm12 - vminps zmm13, zmm1, zmm13 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vminps zmm9, zmm1, zmm9 + vminps zmm10, zmm1, zmm10 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -289,8 +289,8 @@ inner_loop_end: vmaxps zmm28, zmm0, zmm28 vmaxps zmm29, zmm0, zmm29 vmaxps zmm30, zmm0, zmm30 - vmaxps zmm12, zmm0, zmm12 - vmaxps zmm13, zmm0, zmm13 + vmaxps zmm9, zmm0, zmm9 + vmaxps zmm10, zmm0, zmm10 # Pop output pointers from the stack. mov rcx, [rsp + 24] @@ -309,11 +309,11 @@ inner_loop_end: cmp rsi, 32 jl tail - vmovups [rcx], zmm7 + vmovups [rcx], zmm11 vmovups [rcx + 64], zmm22 - vmovups [rax], zmm8 + vmovups [rax], zmm12 vmovups [rax + 64], zmm23 - vmovups [r15], zmm9 + vmovups [r15], zmm13 vmovups [r15 + 64], zmm24 vmovups [r14], zmm14 vmovups [r14 + 64], zmm25 @@ -328,9 +328,9 @@ inner_loop_end: vmovups [rbp], zmm19 vmovups [rbp + 64], zmm30 vmovups [r8], zmm20 - vmovups [r8 + 64], zmm12 + vmovups [r8 + 64], zmm9 vmovups [rdi], zmm21 - vmovups [rdi + 64], zmm13 + vmovups [rdi + 64], zmm10 add rcx, 128 add rax, 128 add r15, 128 @@ -367,11 +367,11 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 vmovups ZMMWORD PTR [rcx + 64]{k2}, zmm22 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 vmovups ZMMWORD PTR [rax + 64]{k2}, zmm23 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm24 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm25 @@ -386,9 +386,9 @@ tail: vmovups ZMMWORD PTR [rbp]{k1}, zmm19 vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm30 vmovups ZMMWORD PTR [r8]{k1}, zmm20 - vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm12 + vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm9 vmovups ZMMWORD PTR [rdi]{k1}, zmm21 - vmovups ZMMWORD PTR [rdi + 64]{k2}, zmm13 + vmovups ZMMWORD PTR [rdi + 64]{k2}, zmm10 return: add rsp, 256 diff --git a/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S index 60f05033b58..514fc9cf96c 100644 --- a/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S @@ -43,28 +43,28 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. 
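# The hunks below only renumber registers: the accumulators move up to zmm11
# and above, and the streamed B tiles move down to zmm7-zmm10, which appears
# to keep the register assignment consistent with the new c2 kernels added in
# this change.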
- vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] add r9, 64 inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm11, zmm2, zmm7 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vmaxps zmm7, zmm0, zmm7 + vminps zmm11, zmm1, zmm11 + vmaxps zmm11, zmm0, zmm11 # Check whether full or partial store. cmp rsi, 16 jl tail - vmovups [r10], zmm7 + vmovups [r10], zmm11 add r10, 64 sub rsi, 16 @@ -76,7 +76,7 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 return: add rsp, 128 diff --git a/src/f32-gemm/gen/f32-gemm-1x16c2-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-1x16c2-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 00000000000..342ba0e5439 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-1x16c2-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,160 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "xnnpack/assembly.h" +.PERMUTATION: + .long 0 + .long 2 + .long 4 + .long 6 + .long 8 + .long 10 + .long 12 + .long 14 + .long 16 + .long 18 + .long 20 + .long 22 + .long 24 + .long 26 + .long 28 + .long 30 + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_1x16c2__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Align the stack pointer. + mov r13, rsp + sub rsp, 64 + and rsp, 0xFFFFFFFFFFFFFFC0 + # Store the old stack pointer containing the return address + mov [rsp], r13 + + # Allocate some space on the stack. + sub rsp, 64 + + # Copy k and flip bit. + mov r11, rdx + and r11, 0x4 + and rdx, 0xFFFFFFFFFFFFFFFB + mov [rsp + 40], r11 + mov r11, 0x5555 + kmovw k3, r11d + +outer_loop: + # Initialize k counter. + mov r11, 0 + vmovaps zmm7, [r9 + 0] + # Interleave with zeros. + vpmovzxdq zmm11, ymm7 + vextracti64x4 ymm7, zmm7, 1 + vpmovzxdq zmm12, ymm7 + add r9, 64 + + # Are there at least 8 bytes? + cmp rdx, 8 + js inner_loop_tail + +inner_loop: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm12, zmm2, zmm8 + + add r11, 8 + cmp rdx, r11 + jne inner_loop + + # Store nc_register. + mov [rsp + 48], rsi + # Load odd k bit. + mov rsi, [rsp + 40] + # Check if channels are odd. + test rsi, rsi + mov rsi, [rsp + 48] + jz inner_loop_end + +inner_loop_tail: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11{k3}, zmm2, zmm7 + vfmadd231ps zmm12{k3}, zmm2, zmm8 +inner_loop_end: + vpsrlq zmm7, zmm11, 32 + vaddps zmm11, zmm11, zmm7 + vpsrlq zmm7, zmm12, 32 + vaddps zmm12, zmm12, zmm7 + vmovups zmm7, zmmword ptr [rip + .PERMUTATION] + vpermt2ps zmm11, zmm7, zmm12 + # Min/max clamping. + vminps zmm11, zmm1, zmm11 + vmaxps zmm11, zmm0, zmm11 + + # Check whether full or partial store. 
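# rsi holds the remaining output columns (nc). A full 16-wide tile is stored
# with plain vmovups; otherwise the tail builds the store mask ~(-1 << nc),
# e.g. nc = 5 gives k1 = 0b11111, so only five of the sixteen floats are
# written.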
+ cmp rsi, 16 + jl tail + + vmovups [r10], zmm11 + add r10, 64 + + sub rsi, 16 + jne outer_loop + jmp return + +tail: + mov r11, -1 + shlx r11, r11, rsi + not r11 + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + +return: + add rsp, 64 + mov r13, [rsp] + mov rsp, r13 + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_1x16c2__asm_amd64_avx512f_broadcast + + #ifdef __has_feature + #if __has_feature(dataflow_sanitizer) +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_1x16c2__asm_amd64_avx512f_broadcast.dfsan + .intel_syntax noprefix + # We could implement this by calling a function that implements the dfsan instrumentation. + # For now, just break, so if someone tries to use this, they'll know where the problem is. + int 3 + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_1x16c2__asm_amd64_avx512f_broadcast.dfsan + #endif + #endif \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S index 7754f033105..03f5391c609 100644 --- a/src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S @@ -43,34 +43,34 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] - vmovaps zmm8, [r9 + 64] + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, [r9 + 64] add r9, 128 inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 - vfmadd231ps zmm8, zmm2, zmm11 + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm12, zmm2, zmm8 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 # Check whether full or partial store. cmp rsi, 32 jl tail - vmovups [r10], zmm7 - vmovups [r10 + 64], zmm8 + vmovups [r10], zmm11 + vmovups [r10 + 64], zmm12 add r10, 128 sub rsi, 32 @@ -84,8 +84,8 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 - vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm8 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm12 return: add rsp, 128 diff --git a/src/f32-gemm/gen/f32-gemm-1x32c2-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-1x32c2-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 00000000000..d31aa78e101 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-1x32c2-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,183 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "xnnpack/assembly.h" +.PERMUTATION: + .long 0 + .long 2 + .long 4 + .long 6 + .long 8 + .long 10 + .long 12 + .long 14 + .long 16 + .long 18 + .long 20 + .long 22 + .long 24 + .long 26 + .long 28 + .long 30 + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. 
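# rbx, rbp and r12-r15 are the SysV callee-saved GP registers, so they are
# pushed here and popped in reverse order at return; rsp is then rounded down
# to a 64-byte boundary, with the original value kept at [rsp] so the epilogue
# can restore it.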
+ push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Align the stack pointer. + mov r13, rsp + sub rsp, 64 + and rsp, 0xFFFFFFFFFFFFFFC0 + # Store the old stack pointer containing the return address + mov [rsp], r13 + + # Allocate some space on the stack. + sub rsp, 64 + + # Copy k and flip bit. + mov r11, rdx + and r11, 0x4 + and rdx, 0xFFFFFFFFFFFFFFFB + mov [rsp + 40], r11 + mov r11, 0x5555 + kmovw k3, r11d + +outer_loop: + # Initialize k counter. + mov r11, 0 + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + # Interleave with zeros. + vpmovzxdq zmm11, ymm7 + vextracti64x4 ymm7, zmm7, 1 + vpmovzxdq zmm12, ymm7 + vpmovzxdq zmm13, ymm8 + vextracti64x4 ymm8, zmm8, 1 + vpmovzxdq zmm14, ymm8 + add r9, 128 + + # Are there at least 8 bytes? + cmp rdx, 8 + js inner_loop_tail + +inner_loop: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] + add r9, 256 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm12, zmm2, zmm8 + vfmadd231ps zmm13, zmm2, zmm9 + vfmadd231ps zmm14, zmm2, zmm10 + + add r11, 8 + cmp rdx, r11 + jne inner_loop + + # Store nc_register. + mov [rsp + 48], rsi + # Load odd k bit. + mov rsi, [rsp + 40] + # Check if channels are odd. + test rsi, rsi + mov rsi, [rsp + 48] + jz inner_loop_end + +inner_loop_tail: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] + add r9, 256 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11{k3}, zmm2, zmm7 + vfmadd231ps zmm12{k3}, zmm2, zmm8 + vfmadd231ps zmm13{k3}, zmm2, zmm9 + vfmadd231ps zmm14{k3}, zmm2, zmm10 +inner_loop_end: + vpsrlq zmm7, zmm11, 32 + vaddps zmm11, zmm11, zmm7 + vpsrlq zmm7, zmm12, 32 + vaddps zmm12, zmm12, zmm7 + vpsrlq zmm7, zmm13, 32 + vaddps zmm13, zmm13, zmm7 + vpsrlq zmm7, zmm14, 32 + vaddps zmm14, zmm14, zmm7 + vmovups zmm7, zmmword ptr [rip + .PERMUTATION] + vpermt2ps zmm11, zmm7, zmm12 + vpermt2ps zmm13, zmm7, zmm14 + # Min/max clamping. + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm13 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + + # Check whether full or partial store. + cmp rsi, 32 + jl tail + + vmovups [r10], zmm11 + vmovups [r10 + 64], zmm12 + add r10, 128 + + sub rsi, 32 + jne outer_loop + jmp return + +tail: + mov r11, -1 + shlx r11, r11, rsi + not r11 + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm12 + +return: + add rsp, 64 + mov r13, [rsp] + mov rsp, r13 + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast + + #ifdef __has_feature + #if __has_feature(dataflow_sanitizer) +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast.dfsan + .intel_syntax noprefix + # We could implement this by calling a function that implements the dfsan instrumentation. + # For now, just break, so if someone tries to use this, they'll know where the problem is. 
+ int 3 + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast.dfsan + #endif + #endif \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S index 1f18be2bed9..0c0cc97b2f6 100644 --- a/src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S @@ -43,45 +43,45 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] - vmovaps zmm8, [r9 + 64] - vmovaps zmm9, [r9 + 128] + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, [r9 + 64] + vmovaps zmm13, [r9 + 128] vmovaps zmm14, [r9 + 192] add r9, 256 inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] - vmovaps zmm12, [r9 + 128] - vmovaps zmm13, [r9 + 192] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] add r9, 256 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 - vfmadd231ps zmm8, zmm2, zmm11 - vfmadd231ps zmm9, zmm2, zmm12 - vfmadd231ps zmm14, zmm2, zmm13 + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm12, zmm2, zmm8 + vfmadd231ps zmm13, zmm2, zmm9 + vfmadd231ps zmm14, zmm2, zmm10 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 # Check whether full or partial store. cmp rsi, 64 jl tail - vmovups [r10], zmm7 - vmovups [r10 + 64], zmm8 - vmovups [r10 + 128], zmm9 + vmovups [r10], zmm11 + vmovups [r10 + 64], zmm12 + vmovups [r10 + 128], zmm13 vmovups [r10 + 192], zmm14 add r10, 256 @@ -101,9 +101,9 @@ tail: shr r11, 16 kmovw k4, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 - vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm8 - vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm9 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm12 + vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm13 vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm14 return: diff --git a/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S index e7c5cfd8204..ed9a5adabb8 100644 --- a/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S @@ -52,34 +52,34 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] - vmovaps zmm8, zmm7 + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, zmm11 add r9, 64 inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm11, zmm2, zmm7 vbroadcastss zmm3, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm3, zmm10 + vfmadd231ps zmm12, zmm3, zmm7 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. 
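# zmm1 holds the output max and zmm0 the output min (both broadcast from the
# params block at the top of the function), so the min-then-max pair below
# clamps every accumulator into [output_min, output_max].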
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 # Check whether full or partial store. cmp rsi, 16 jl tail - vmovups [r10], zmm7 - vmovups [r13], zmm8 + vmovups [r10], zmm11 + vmovups [r13], zmm12 add r10, 64 add r13, 64 @@ -92,8 +92,8 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 return: add rsp, 128 diff --git a/src/f32-gemm/gen/f32-gemm-2x16c2-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-2x16c2-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 00000000000..d2015697f28 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-2x16c2-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,187 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "xnnpack/assembly.h" +.PERMUTATION: + .long 0 + .long 2 + .long 4 + .long 6 + .long 8 + .long 10 + .long 12 + .long 14 + .long 16 + .long 18 + .long 20 + .long 22 + .long 24 + .long 26 + .long 28 + .long 30 + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_2x16c2__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Align the stack pointer. + mov r13, rsp + sub rsp, 64 + and rsp, 0xFFFFFFFFFFFFFFC0 + # Store the old stack pointer containing the return address + mov [rsp], r13 + + # Allocate some space on the stack. + sub rsp, 128 + + # Clamp a & c pointers if mr <= 1 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rcx + cmovle r13, r10 + + # Copy k and flip bit. + mov r11, rdx + and r11, 0x4 + and rdx, 0xFFFFFFFFFFFFFFFB + mov [rsp + 56], r11 + mov r11, 0x5555 + kmovw k3, r11d + +outer_loop: + # Initialize k counter. + mov r11, 0 + vmovaps zmm7, [r9 + 0] + # Interleave with zeros. + vpmovzxdq zmm11, ymm7 + vextracti64x4 ymm7, zmm7, 1 + vpmovzxdq zmm13, ymm7 + vmovaps zmm12, zmm11 + vmovaps zmm14, zmm13 + add r9, 64 + + # Are there at least 8 bytes? + cmp rdx, 8 + js inner_loop_tail + +inner_loop: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm13, zmm2, zmm8 + vbroadcastsd zmm3, QWORD PTR [rax + r11] + vfmadd231ps zmm12, zmm3, zmm7 + vfmadd231ps zmm14, zmm3, zmm8 + + add r11, 8 + cmp rdx, r11 + jne inner_loop + + # Store nc_register. + mov [rsp + 64], rsi + # Load odd k bit. + mov rsi, [rsp + 56] + # Check if channels are odd. 
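# Only rsi is free to test the flag here, so nc is parked on the stack, the
# odd-k flag (bit 2 of kc in bytes, saved before outer_loop) is tested, and
# rsi is restored before branching. When the flag is set, one trailing fp32
# per row is processed in inner_loop_tail under k3 = 0x5555, which keeps just
# the even 32-bit lanes so the extra lane of each broadcast pair is ignored.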
+ test rsi, rsi + mov rsi, [rsp + 64] + jz inner_loop_end + +inner_loop_tail: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11{k3}, zmm2, zmm7 + vfmadd231ps zmm13{k3}, zmm2, zmm8 + vbroadcastsd zmm3, QWORD PTR [rax + r11] + vfmadd231ps zmm12{k3}, zmm3, zmm7 + vfmadd231ps zmm14{k3}, zmm3, zmm8 +inner_loop_end: + vpsrlq zmm7, zmm11, 32 + vaddps zmm11, zmm11, zmm7 + vpsrlq zmm7, zmm12, 32 + vaddps zmm12, zmm12, zmm7 + vpsrlq zmm7, zmm13, 32 + vaddps zmm13, zmm13, zmm7 + vpsrlq zmm7, zmm14, 32 + vaddps zmm14, zmm14, zmm7 + vmovups zmm7, zmmword ptr [rip + .PERMUTATION] + vpermt2ps zmm11, zmm7, zmm13 + vpermt2ps zmm12, zmm7, zmm14 + # Min/max clamping. + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + + # Check whether full or partial store. + cmp rsi, 16 + jl tail + + vmovups [r10], zmm11 + vmovups [r13], zmm12 + add r10, 64 + add r13, 64 + + sub rsi, 16 + jne outer_loop + jmp return + +tail: + mov r11, -1 + shlx r11, r11, rsi + not r11 + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + +return: + add rsp, 128 + mov r13, [rsp] + mov rsp, r13 + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_2x16c2__asm_amd64_avx512f_broadcast + + #ifdef __has_feature + #if __has_feature(dataflow_sanitizer) +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_2x16c2__asm_amd64_avx512f_broadcast.dfsan + .intel_syntax noprefix + # We could implement this by calling a function that implements the dfsan instrumentation. + # For now, just break, so if someone tries to use this, they'll know where the problem is. + int 3 + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_2x16c2__asm_amd64_avx512f_broadcast.dfsan + #endif + #endif \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S index 4c307a0f2a0..8b2c1920d8a 100644 --- a/src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S @@ -52,44 +52,44 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] - vmovaps zmm9, [r9 + 64] - vmovaps zmm8, zmm7 - vmovaps zmm14, zmm9 + vmovaps zmm11, [r9 + 0] + vmovaps zmm13, [r9 + 64] + vmovaps zmm12, zmm11 + vmovaps zmm14, zmm13 add r9, 128 inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 - vfmadd231ps zmm9, zmm2, zmm11 + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm13, zmm2, zmm8 vbroadcastss zmm3, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm3, zmm10 - vfmadd231ps zmm14, zmm3, zmm11 + vfmadd231ps zmm12, zmm3, zmm7 + vfmadd231ps zmm14, zmm3, zmm8 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 # Check whether full or partial store. 
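# For the 32-wide variants the tail needs two write masks: ~(-1 << nc) is
# split into k1 for the first 16 columns and k2 for the rest, e.g. nc = 20
# gives k1 = 0xffff and k2 = 0x000f.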
cmp rsi, 32 jl tail - vmovups [r10], zmm7 - vmovups [r10 + 64], zmm9 - vmovups [r13], zmm8 + vmovups [r10], zmm11 + vmovups [r10 + 64], zmm13 + vmovups [r13], zmm12 vmovups [r13 + 64], zmm14 add r10, 128 add r13, 128 @@ -105,9 +105,9 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 - vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm9 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm13 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm14 return: diff --git a/src/f32-gemm/gen/f32-gemm-2x32c2-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-2x32c2-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 00000000000..36f8a00ef01 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-2x32c2-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,225 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "xnnpack/assembly.h" +.PERMUTATION: + .long 0 + .long 2 + .long 4 + .long 6 + .long 8 + .long 10 + .long 12 + .long 14 + .long 16 + .long 18 + .long 20 + .long 22 + .long 24 + .long 26 + .long 28 + .long 30 + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_2x32c2__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Align the stack pointer. + mov r13, rsp + sub rsp, 64 + and rsp, 0xFFFFFFFFFFFFFFC0 + # Store the old stack pointer containing the return address + mov [rsp], r13 + + # Allocate some space on the stack. + sub rsp, 128 + + # Clamp a & c pointers if mr <= 1 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rcx + cmovle r13, r10 + + # Copy k and flip bit. + mov r11, rdx + and r11, 0x4 + and rdx, 0xFFFFFFFFFFFFFFFB + mov [rsp + 56], r11 + mov r11, 0x5555 + kmovw k3, r11d + +outer_loop: + # Initialize k counter. + mov r11, 0 + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + # Interleave with zeros. + vpmovzxdq zmm11, ymm7 + vextracti64x4 ymm7, zmm7, 1 + vpmovzxdq zmm13, ymm7 + vpmovzxdq zmm15, ymm8 + vextracti64x4 ymm8, zmm8, 1 + vpmovzxdq zmm17, ymm8 + vmovaps zmm12, zmm11 + vmovaps zmm14, zmm13 + vmovaps zmm16, zmm15 + vmovaps zmm18, zmm17 + add r9, 128 + + # Are there at least 8 bytes? + cmp rdx, 8 + js inner_loop_tail + +inner_loop: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] + add r9, 256 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm13, zmm2, zmm8 + vfmadd231ps zmm15, zmm2, zmm9 + vfmadd231ps zmm17, zmm2, zmm10 + vbroadcastsd zmm3, QWORD PTR [rax + r11] + vfmadd231ps zmm12, zmm3, zmm7 + vfmadd231ps zmm14, zmm3, zmm8 + vfmadd231ps zmm16, zmm3, zmm9 + vfmadd231ps zmm18, zmm3, zmm10 + + add r11, 8 + cmp rdx, r11 + jne inner_loop + + # Store nc_register. + mov [rsp + 64], rsi + # Load odd k bit. + mov rsi, [rsp + 56] + # Check if channels are odd. 
+ test rsi, rsi + mov rsi, [rsp + 64] + jz inner_loop_end + +inner_loop_tail: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] + add r9, 256 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11{k3}, zmm2, zmm7 + vfmadd231ps zmm13{k3}, zmm2, zmm8 + vfmadd231ps zmm15{k3}, zmm2, zmm9 + vfmadd231ps zmm17{k3}, zmm2, zmm10 + vbroadcastsd zmm3, QWORD PTR [rax + r11] + vfmadd231ps zmm12{k3}, zmm3, zmm7 + vfmadd231ps zmm14{k3}, zmm3, zmm8 + vfmadd231ps zmm16{k3}, zmm3, zmm9 + vfmadd231ps zmm18{k3}, zmm3, zmm10 +inner_loop_end: + vpsrlq zmm7, zmm11, 32 + vaddps zmm11, zmm11, zmm7 + vpsrlq zmm7, zmm12, 32 + vaddps zmm12, zmm12, zmm7 + vpsrlq zmm7, zmm13, 32 + vaddps zmm13, zmm13, zmm7 + vpsrlq zmm7, zmm14, 32 + vaddps zmm14, zmm14, zmm7 + vpsrlq zmm7, zmm15, 32 + vaddps zmm15, zmm15, zmm7 + vpsrlq zmm7, zmm16, 32 + vaddps zmm16, zmm16, zmm7 + vpsrlq zmm7, zmm17, 32 + vaddps zmm17, zmm17, zmm7 + vpsrlq zmm7, zmm18, 32 + vaddps zmm18, zmm18, zmm7 + vmovups zmm7, zmmword ptr [rip + .PERMUTATION] + vpermt2ps zmm11, zmm7, zmm13 + vpermt2ps zmm12, zmm7, zmm14 + vpermt2ps zmm15, zmm7, zmm17 + vpermt2ps zmm16, zmm7, zmm18 + # Min/max clamping. + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm15 + vminps zmm14, zmm1, zmm16 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 + vmaxps zmm14, zmm0, zmm14 + + # Check whether full or partial store. + cmp rsi, 32 + jl tail + + vmovups [r10], zmm11 + vmovups [r10 + 64], zmm13 + vmovups [r13], zmm12 + vmovups [r13 + 64], zmm14 + add r10, 128 + add r13, 128 + + sub rsi, 32 + jne outer_loop + jmp return + +tail: + mov r11, -1 + shlx r11, r11, rsi + not r11 + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm13 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm14 + +return: + add rsp, 128 + mov r13, [rsp] + mov rsp, r13 + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_2x32c2__asm_amd64_avx512f_broadcast + + #ifdef __has_feature + #if __has_feature(dataflow_sanitizer) +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_2x32c2__asm_amd64_avx512f_broadcast.dfsan + .intel_syntax noprefix + # We could implement this by calling a function that implements the dfsan instrumentation. + # For now, just break, so if someone tries to use this, they'll know where the problem is. + int 3 + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_2x32c2__asm_amd64_avx512f_broadcast.dfsan + #endif + #endif \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S index 2dfd9185c5d..2fa249a533b 100644 --- a/src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S @@ -52,49 +52,49 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. 
- vmovaps zmm7, [r9 + 0] - vmovaps zmm9, [r9 + 64] + vmovaps zmm11, [r9 + 0] + vmovaps zmm13, [r9 + 64] vmovaps zmm15, [r9 + 128] vmovaps zmm17, [r9 + 192] - vmovaps zmm8, zmm7 - vmovaps zmm14, zmm9 + vmovaps zmm12, zmm11 + vmovaps zmm14, zmm13 vmovaps zmm16, zmm15 vmovaps zmm18, zmm17 add r9, 256 inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] - vmovaps zmm12, [r9 + 128] - vmovaps zmm13, [r9 + 192] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] add r9, 256 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 - vfmadd231ps zmm9, zmm2, zmm11 - vfmadd231ps zmm15, zmm2, zmm12 - vfmadd231ps zmm17, zmm2, zmm13 + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm13, zmm2, zmm8 + vfmadd231ps zmm15, zmm2, zmm9 + vfmadd231ps zmm17, zmm2, zmm10 vbroadcastss zmm3, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm3, zmm10 - vfmadd231ps zmm14, zmm3, zmm11 - vfmadd231ps zmm16, zmm3, zmm12 - vfmadd231ps zmm18, zmm3, zmm13 + vfmadd231ps zmm12, zmm3, zmm7 + vfmadd231ps zmm14, zmm3, zmm8 + vfmadd231ps zmm16, zmm3, zmm9 + vfmadd231ps zmm18, zmm3, zmm10 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 vminps zmm17, zmm1, zmm17 vminps zmm18, zmm1, zmm18 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -105,11 +105,11 @@ inner_loop_end: cmp rsi, 64 jl tail - vmovups [r10], zmm7 - vmovups [r10 + 64], zmm9 + vmovups [r10], zmm11 + vmovups [r10 + 64], zmm13 vmovups [r10 + 128], zmm15 vmovups [r10 + 192], zmm17 - vmovups [r13], zmm8 + vmovups [r13], zmm12 vmovups [r13 + 64], zmm14 vmovups [r13 + 128], zmm16 vmovups [r13 + 192], zmm18 @@ -132,11 +132,11 @@ tail: shr r11, 16 kmovw k4, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 - vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm9 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm13 vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm15 vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm17 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm14 vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm16 vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm18 diff --git a/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S index 00b6de470a2..b0dace48264 100644 --- a/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S @@ -61,40 +61,40 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. 
- vmovaps zmm7, [r9 + 0] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 add r9, 64 inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm11, zmm2, zmm7 vbroadcastss zmm3, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm3, zmm10 + vfmadd231ps zmm12, zmm3, zmm7 vbroadcastss zmm4, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm4, zmm10 + vfmadd231ps zmm13, zmm4, zmm7 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 # Check whether full or partial store. cmp rsi, 16 jl tail - vmovups [r10], zmm7 - vmovups [r13], zmm8 - vmovups [rbx], zmm9 + vmovups [r10], zmm11 + vmovups [r13], zmm12 + vmovups [rbx], zmm13 add r10, 64 add r13, 64 add rbx, 64 @@ -108,9 +108,9 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 - vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 return: add rsp, 128 diff --git a/src/f32-gemm/gen/f32-gemm-3x16c2-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-3x16c2-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 00000000000..d1c56d99e5e --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-3x16c2-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,214 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "xnnpack/assembly.h" +.PERMUTATION: + .long 0 + .long 2 + .long 4 + .long 6 + .long 8 + .long 10 + .long 12 + .long 14 + .long 16 + .long 18 + .long 20 + .long 22 + .long 24 + .long 26 + .long 28 + .long 30 + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_3x16c2__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Align the stack pointer. + mov r13, rsp + sub rsp, 64 + and rsp, 0xFFFFFFFFFFFFFFC0 + # Store the old stack pointer containing the return address + mov [rsp], r13 + + # Allocate some space on the stack. + sub rsp, 128 + + # Clamp a & c pointers if mr <= 1 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rcx + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Copy k and flip bit. + mov r11, rdx + and r11, 0x4 + and rdx, 0xFFFFFFFFFFFFFFFB + mov [rsp + 72], r11 + mov r11, 0x5555 + kmovw k3, r11d + +outer_loop: + # Initialize k counter. + mov r11, 0 + vmovaps zmm7, [r9 + 0] + # Interleave with zeros. 
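# vpmovzxdq widens each 32-bit bias to 64 bits, so every bias lands in the
# even lane of a pair with the odd lane zeroed; this matches the even/odd
# layout the paired accumulators use during the k loop and makes the bias a
# valid starting value for the pairwise reduction at inner_loop_end, where it
# is counted exactly once.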
+ vpmovzxdq zmm11, ymm7 + vextracti64x4 ymm7, zmm7, 1 + vpmovzxdq zmm14, ymm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm15, zmm14 + vmovaps zmm16, zmm14 + add r9, 64 + + # Are there at least 8 bytes? + cmp rdx, 8 + js inner_loop_tail + +inner_loop: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm14, zmm2, zmm8 + vbroadcastsd zmm3, QWORD PTR [rax + r11] + vfmadd231ps zmm12, zmm3, zmm7 + vfmadd231ps zmm15, zmm3, zmm8 + vbroadcastsd zmm4, QWORD PTR [r15 + r11] + vfmadd231ps zmm13, zmm4, zmm7 + vfmadd231ps zmm16, zmm4, zmm8 + + add r11, 8 + cmp rdx, r11 + jne inner_loop + + # Store nc_register. + mov [rsp + 80], rsi + # Load odd k bit. + mov rsi, [rsp + 72] + # Check if channels are odd. + test rsi, rsi + mov rsi, [rsp + 80] + jz inner_loop_end + +inner_loop_tail: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11{k3}, zmm2, zmm7 + vfmadd231ps zmm14{k3}, zmm2, zmm8 + vbroadcastsd zmm3, QWORD PTR [rax + r11] + vfmadd231ps zmm12{k3}, zmm3, zmm7 + vfmadd231ps zmm15{k3}, zmm3, zmm8 + vbroadcastsd zmm4, QWORD PTR [r15 + r11] + vfmadd231ps zmm13{k3}, zmm4, zmm7 + vfmadd231ps zmm16{k3}, zmm4, zmm8 +inner_loop_end: + vpsrlq zmm7, zmm11, 32 + vaddps zmm11, zmm11, zmm7 + vpsrlq zmm7, zmm12, 32 + vaddps zmm12, zmm12, zmm7 + vpsrlq zmm7, zmm13, 32 + vaddps zmm13, zmm13, zmm7 + vpsrlq zmm7, zmm14, 32 + vaddps zmm14, zmm14, zmm7 + vpsrlq zmm7, zmm15, 32 + vaddps zmm15, zmm15, zmm7 + vpsrlq zmm7, zmm16, 32 + vaddps zmm16, zmm16, zmm7 + vmovups zmm7, zmmword ptr [rip + .PERMUTATION] + vpermt2ps zmm11, zmm7, zmm14 + vpermt2ps zmm12, zmm7, zmm15 + vpermt2ps zmm13, zmm7, zmm16 + # Min/max clamping. + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 + + # Check whether full or partial store. + cmp rsi, 16 + jl tail + + vmovups [r10], zmm11 + vmovups [r13], zmm12 + vmovups [rbx], zmm13 + add r10, 64 + add r13, 64 + add rbx, 64 + + sub rsi, 16 + jne outer_loop + jmp return + +tail: + mov r11, -1 + shlx r11, r11, rsi + not r11 + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 + +return: + add rsp, 128 + mov r13, [rsp] + mov rsp, r13 + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_3x16c2__asm_amd64_avx512f_broadcast + + #ifdef __has_feature + #if __has_feature(dataflow_sanitizer) +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_3x16c2__asm_amd64_avx512f_broadcast.dfsan + .intel_syntax noprefix + # We could implement this by calling a function that implements the dfsan instrumentation. + # For now, just break, so if someone tries to use this, they'll know where the problem is. + int 3 + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_3x16c2__asm_amd64_avx512f_broadcast.dfsan + #endif + #endif \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S index cb9e8750856..83ff0b29a9e 100644 --- a/src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S @@ -61,42 +61,42 @@ outer_loop: # Initialize k counter. 
mov r11, 0 # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm14, [r9 + 64] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 vmovaps zmm15, zmm14 vmovaps zmm16, zmm14 add r9, 128 inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 - vfmadd231ps zmm14, zmm2, zmm11 + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm14, zmm2, zmm8 vbroadcastss zmm3, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm3, zmm10 - vfmadd231ps zmm15, zmm3, zmm11 + vfmadd231ps zmm12, zmm3, zmm7 + vfmadd231ps zmm15, zmm3, zmm8 vbroadcastss zmm4, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm4, zmm10 - vfmadd231ps zmm16, zmm4, zmm11 + vfmadd231ps zmm13, zmm4, zmm7 + vfmadd231ps zmm16, zmm4, zmm8 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -105,11 +105,11 @@ inner_loop_end: cmp rsi, 32 jl tail - vmovups [r10], zmm7 + vmovups [r10], zmm11 vmovups [r10 + 64], zmm14 - vmovups [r13], zmm8 + vmovups [r13], zmm12 vmovups [r13 + 64], zmm15 - vmovups [rbx], zmm9 + vmovups [rbx], zmm13 vmovups [rbx + 64], zmm16 add r10, 128 add r13, 128 @@ -126,11 +126,11 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm14 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm15 - vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm16 return: diff --git a/src/f32-gemm/gen/f32-gemm-3x32c2-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-3x32c2-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 00000000000..24954a34f7f --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-3x32c2-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,267 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "xnnpack/assembly.h" +.PERMUTATION: + .long 0 + .long 2 + .long 4 + .long 6 + .long 8 + .long 10 + .long 12 + .long 14 + .long 16 + .long 18 + .long 20 + .long 22 + .long 24 + .long 26 + .long 28 + .long 30 + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_3x32c2__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Align the stack pointer. + mov r13, rsp + sub rsp, 64 + and rsp, 0xFFFFFFFFFFFFFFC0 + # Store the old stack pointer containing the return address + mov [rsp], r13 + + # Allocate some space on the stack. 
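# The 128 bytes reserved below provide this kernel's scratch slots, notably
# the odd-k flag at [rsp + 72] and the spilled nc value at [rsp + 80].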
+ sub rsp, 128 + + # Clamp a & c pointers if mr <= 1 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rcx + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Copy k and flip bit. + mov r11, rdx + and r11, 0x4 + and rdx, 0xFFFFFFFFFFFFFFFB + mov [rsp + 72], r11 + mov r11, 0x5555 + kmovw k3, r11d + +outer_loop: + # Initialize k counter. + mov r11, 0 + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + # Interleave with zeros. + vpmovzxdq zmm11, ymm7 + vextracti64x4 ymm7, zmm7, 1 + vpmovzxdq zmm14, ymm7 + vpmovzxdq zmm17, ymm8 + vextracti64x4 ymm8, zmm8, 1 + vpmovzxdq zmm20, ymm8 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm15, zmm14 + vmovaps zmm16, zmm14 + vmovaps zmm18, zmm17 + vmovaps zmm19, zmm17 + vmovaps zmm21, zmm20 + vmovaps zmm22, zmm20 + add r9, 128 + + # Are there at least 8 bytes? + cmp rdx, 8 + js inner_loop_tail + +inner_loop: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] + add r9, 256 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm14, zmm2, zmm8 + vfmadd231ps zmm17, zmm2, zmm9 + vfmadd231ps zmm20, zmm2, zmm10 + vbroadcastsd zmm3, QWORD PTR [rax + r11] + vfmadd231ps zmm12, zmm3, zmm7 + vfmadd231ps zmm15, zmm3, zmm8 + vfmadd231ps zmm18, zmm3, zmm9 + vfmadd231ps zmm21, zmm3, zmm10 + vbroadcastsd zmm4, QWORD PTR [r15 + r11] + vfmadd231ps zmm13, zmm4, zmm7 + vfmadd231ps zmm16, zmm4, zmm8 + vfmadd231ps zmm19, zmm4, zmm9 + vfmadd231ps zmm22, zmm4, zmm10 + + add r11, 8 + cmp rdx, r11 + jne inner_loop + + # Store nc_register. + mov [rsp + 80], rsi + # Load odd k bit. + mov rsi, [rsp + 72] + # Check if channels are odd. + test rsi, rsi + mov rsi, [rsp + 80] + jz inner_loop_end + +inner_loop_tail: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] + add r9, 256 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11{k3}, zmm2, zmm7 + vfmadd231ps zmm14{k3}, zmm2, zmm8 + vfmadd231ps zmm17{k3}, zmm2, zmm9 + vfmadd231ps zmm20{k3}, zmm2, zmm10 + vbroadcastsd zmm3, QWORD PTR [rax + r11] + vfmadd231ps zmm12{k3}, zmm3, zmm7 + vfmadd231ps zmm15{k3}, zmm3, zmm8 + vfmadd231ps zmm18{k3}, zmm3, zmm9 + vfmadd231ps zmm21{k3}, zmm3, zmm10 + vbroadcastsd zmm4, QWORD PTR [r15 + r11] + vfmadd231ps zmm13{k3}, zmm4, zmm7 + vfmadd231ps zmm16{k3}, zmm4, zmm8 + vfmadd231ps zmm19{k3}, zmm4, zmm9 + vfmadd231ps zmm22{k3}, zmm4, zmm10 +inner_loop_end: + vpsrlq zmm7, zmm11, 32 + vaddps zmm11, zmm11, zmm7 + vpsrlq zmm7, zmm12, 32 + vaddps zmm12, zmm12, zmm7 + vpsrlq zmm7, zmm13, 32 + vaddps zmm13, zmm13, zmm7 + vpsrlq zmm7, zmm14, 32 + vaddps zmm14, zmm14, zmm7 + vpsrlq zmm7, zmm15, 32 + vaddps zmm15, zmm15, zmm7 + vpsrlq zmm7, zmm16, 32 + vaddps zmm16, zmm16, zmm7 + vpsrlq zmm7, zmm17, 32 + vaddps zmm17, zmm17, zmm7 + vpsrlq zmm7, zmm18, 32 + vaddps zmm18, zmm18, zmm7 + vpsrlq zmm7, zmm19, 32 + vaddps zmm19, zmm19, zmm7 + vpsrlq zmm7, zmm20, 32 + vaddps zmm20, zmm20, zmm7 + vpsrlq zmm7, zmm21, 32 + vaddps zmm21, zmm21, zmm7 + vpsrlq zmm7, zmm22, 32 + vaddps zmm22, zmm22, zmm7 + vmovups zmm7, zmmword ptr [rip + .PERMUTATION] + vpermt2ps zmm11, zmm7, zmm14 + vpermt2ps zmm12, zmm7, zmm15 + vpermt2ps zmm13, zmm7, zmm16 + vpermt2ps zmm17, zmm7, zmm20 + vpermt2ps zmm18, zmm7, zmm21 + vpermt2ps zmm19, zmm7, zmm22 + # Min/max clamping. 
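# At this point each accumulator pair has been reduced: vpsrlq/vaddps summed
# the odd 32-bit lane into the even one, and vpermt2ps with .PERMUTATION
# (indices 0, 2, ..., 30) gathered the even lanes of two registers into one
# 16-float result. The min/max step below also funnels the upper-half results
# from zmm17-zmm19 into zmm14-zmm16, so the store and tail code can reuse the
# same register layout as the 32-wide kernel above.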
+ vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 + vminps zmm14, zmm1, zmm17 + vminps zmm15, zmm1, zmm18 + vminps zmm16, zmm1, zmm19 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + + # Check whether full or partial store. + cmp rsi, 32 + jl tail + + vmovups [r10], zmm11 + vmovups [r10 + 64], zmm14 + vmovups [r13], zmm12 + vmovups [r13 + 64], zmm15 + vmovups [rbx], zmm13 + vmovups [rbx + 64], zmm16 + add r10, 128 + add r13, 128 + add rbx, 128 + + sub rsi, 32 + jne outer_loop + jmp return + +tail: + mov r11, -1 + shlx r11, r11, rsi + not r11 + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm14 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm15 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm16 + +return: + add rsp, 128 + mov r13, [rsp] + mov rsp, r13 + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_3x32c2__asm_amd64_avx512f_broadcast + + #ifdef __has_feature + #if __has_feature(dataflow_sanitizer) +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_3x32c2__asm_amd64_avx512f_broadcast.dfsan + .intel_syntax noprefix + # We could implement this by calling a function that implements the dfsan instrumentation. + # For now, just break, so if someone tries to use this, they'll know where the problem is. + int 3 + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_3x32c2__asm_amd64_avx512f_broadcast.dfsan + #endif + #endif \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S index dd09eb5aefe..df159066664 100644 --- a/src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S @@ -61,12 +61,12 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. 
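# (After this change the accumulators live in zmm11-zmm22, freeing zmm7-zmm10 for the
# weight loads inside the inner loop; the arithmetic itself is unchanged.)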
- vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm14, [r9 + 64] vmovaps zmm17, [r9 + 128] vmovaps zmm20, [r9 + 192] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 vmovaps zmm15, zmm14 vmovaps zmm16, zmm14 vmovaps zmm18, zmm17 @@ -76,35 +76,35 @@ outer_loop: add r9, 256 inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] - vmovaps zmm12, [r9 + 128] - vmovaps zmm13, [r9 + 192] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] add r9, 256 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 - vfmadd231ps zmm14, zmm2, zmm11 - vfmadd231ps zmm17, zmm2, zmm12 - vfmadd231ps zmm20, zmm2, zmm13 + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm14, zmm2, zmm8 + vfmadd231ps zmm17, zmm2, zmm9 + vfmadd231ps zmm20, zmm2, zmm10 vbroadcastss zmm3, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm3, zmm10 - vfmadd231ps zmm15, zmm3, zmm11 - vfmadd231ps zmm18, zmm3, zmm12 - vfmadd231ps zmm21, zmm3, zmm13 + vfmadd231ps zmm12, zmm3, zmm7 + vfmadd231ps zmm15, zmm3, zmm8 + vfmadd231ps zmm18, zmm3, zmm9 + vfmadd231ps zmm21, zmm3, zmm10 vbroadcastss zmm4, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm4, zmm10 - vfmadd231ps zmm16, zmm4, zmm11 - vfmadd231ps zmm19, zmm4, zmm12 - vfmadd231ps zmm22, zmm4, zmm13 + vfmadd231ps zmm13, zmm4, zmm7 + vfmadd231ps zmm16, zmm4, zmm8 + vfmadd231ps zmm19, zmm4, zmm9 + vfmadd231ps zmm22, zmm4, zmm10 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -114,9 +114,9 @@ inner_loop_end: vminps zmm20, zmm1, zmm20 vminps zmm21, zmm1, zmm21 vminps zmm22, zmm1, zmm22 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -131,15 +131,15 @@ inner_loop_end: cmp rsi, 64 jl tail - vmovups [r10], zmm7 + vmovups [r10], zmm11 vmovups [r10 + 64], zmm14 vmovups [r10 + 128], zmm17 vmovups [r10 + 192], zmm20 - vmovups [r13], zmm8 + vmovups [r13], zmm12 vmovups [r13 + 64], zmm15 vmovups [r13 + 128], zmm18 vmovups [r13 + 192], zmm21 - vmovups [rbx], zmm9 + vmovups [rbx], zmm13 vmovups [rbx + 64], zmm16 vmovups [rbx + 128], zmm19 vmovups [rbx + 192], zmm22 @@ -163,15 +163,15 @@ tail: shr r11, 16 kmovw k4, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm14 vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm17 vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm20 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm15 vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm18 vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm21 - vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm16 vmovups ZMMWORD PTR [rbx + 128]{k3}, zmm19 vmovups ZMMWORD PTR [rbx + 192]{k4}, zmm22 diff --git a/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S index 5b632e4bdff..106b4a4d4f5 100644 --- a/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S +++ 
b/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S @@ -70,45 +70,45 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 add r9, 64 inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm11, zmm2, zmm7 vbroadcastss zmm3, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm3, zmm10 + vfmadd231ps zmm12, zmm3, zmm7 vbroadcastss zmm4, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm4, zmm10 + vfmadd231ps zmm13, zmm4, zmm7 vbroadcastss zmm5, DWORD PTR [r14 + r11] - vfmadd231ps zmm14, zmm5, zmm10 + vfmadd231ps zmm14, zmm5, zmm7 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 # Check whether full or partial store. cmp rsi, 16 jl tail - vmovups [r10], zmm7 - vmovups [r13], zmm8 - vmovups [rbx], zmm9 + vmovups [r10], zmm11 + vmovups [r13], zmm12 + vmovups [rbx], zmm13 vmovups [rbp], zmm14 add r10, 64 add r13, 64 @@ -124,9 +124,9 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 - vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 vmovups ZMMWORD PTR [rbp]{k1}, zmm14 return: diff --git a/src/f32-gemm/gen/f32-gemm-4x16c2-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-4x16c2-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 00000000000..733cd6aa078 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-4x16c2-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,241 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "xnnpack/assembly.h" +.PERMUTATION: + .long 0 + .long 2 + .long 4 + .long 6 + .long 8 + .long 10 + .long 12 + .long 14 + .long 16 + .long 18 + .long 20 + .long 22 + .long 24 + .long 26 + .long 28 + .long 30 + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_4x16c2__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Align the stack pointer. + mov r13, rsp + sub rsp, 64 + and rsp, 0xFFFFFFFFFFFFFFC0 + # Store the old stack pointer containing the return address + mov [rsp], r13 + + # Allocate some space on the stack. 
+ sub rsp, 128 + + # Clamp a & c pointers if mr <= 1 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rcx + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Copy k and flip bit. + mov r11, rdx + and r11, 0x4 + and rdx, 0xFFFFFFFFFFFFFFFB + mov [rsp + 88], r11 + mov r11, 0x5555 + kmovw k3, r11d + +outer_loop: + # Initialize k counter. + mov r11, 0 + vmovaps zmm7, [r9 + 0] + # Interleave with zeros. + vpmovzxdq zmm11, ymm7 + vextracti64x4 ymm7, zmm7, 1 + vpmovzxdq zmm15, ymm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm16, zmm15 + vmovaps zmm17, zmm15 + vmovaps zmm18, zmm15 + add r9, 64 + + # Are there at least 8 bytes? + cmp rdx, 8 + js inner_loop_tail + +inner_loop: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm15, zmm2, zmm8 + vbroadcastsd zmm3, QWORD PTR [rax + r11] + vfmadd231ps zmm12, zmm3, zmm7 + vfmadd231ps zmm16, zmm3, zmm8 + vbroadcastsd zmm4, QWORD PTR [r15 + r11] + vfmadd231ps zmm13, zmm4, zmm7 + vfmadd231ps zmm17, zmm4, zmm8 + vbroadcastsd zmm5, QWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm5, zmm7 + vfmadd231ps zmm18, zmm5, zmm8 + + add r11, 8 + cmp rdx, r11 + jne inner_loop + + # Store nc_register. + mov [rsp + 96], rsi + # Load odd k bit. + mov rsi, [rsp + 88] + # Check if channels are odd. + test rsi, rsi + mov rsi, [rsp + 96] + jz inner_loop_end + +inner_loop_tail: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11{k3}, zmm2, zmm7 + vfmadd231ps zmm15{k3}, zmm2, zmm8 + vbroadcastsd zmm3, QWORD PTR [rax + r11] + vfmadd231ps zmm12{k3}, zmm3, zmm7 + vfmadd231ps zmm16{k3}, zmm3, zmm8 + vbroadcastsd zmm4, QWORD PTR [r15 + r11] + vfmadd231ps zmm13{k3}, zmm4, zmm7 + vfmadd231ps zmm17{k3}, zmm4, zmm8 + vbroadcastsd zmm5, QWORD PTR [r14 + r11] + vfmadd231ps zmm14{k3}, zmm5, zmm7 + vfmadd231ps zmm18{k3}, zmm5, zmm8 +inner_loop_end: + vpsrlq zmm7, zmm11, 32 + vaddps zmm11, zmm11, zmm7 + vpsrlq zmm7, zmm12, 32 + vaddps zmm12, zmm12, zmm7 + vpsrlq zmm7, zmm13, 32 + vaddps zmm13, zmm13, zmm7 + vpsrlq zmm7, zmm14, 32 + vaddps zmm14, zmm14, zmm7 + vpsrlq zmm7, zmm15, 32 + vaddps zmm15, zmm15, zmm7 + vpsrlq zmm7, zmm16, 32 + vaddps zmm16, zmm16, zmm7 + vpsrlq zmm7, zmm17, 32 + vaddps zmm17, zmm17, zmm7 + vpsrlq zmm7, zmm18, 32 + vaddps zmm18, zmm18, zmm7 + vmovups zmm7, zmmword ptr [rip + .PERMUTATION] + vpermt2ps zmm11, zmm7, zmm15 + vpermt2ps zmm12, zmm7, zmm16 + vpermt2ps zmm13, zmm7, zmm17 + vpermt2ps zmm14, zmm7, zmm18 + # Min/max clamping. + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 + vminps zmm14, zmm1, zmm14 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 + vmaxps zmm14, zmm0, zmm14 + + # Check whether full or partial store. 
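+ # nc (rsi) counts the remaining output columns. With 16 or more columns left, each
+ # row is stored with a plain vmovups and the c pointers advance by 64 bytes;
+ # otherwise the tail builds a column mask (~0 shifted left by nc, then inverted)
+ # in k1 and performs masked stores.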
+ cmp rsi, 16 + jl tail + + vmovups [r10], zmm11 + vmovups [r13], zmm12 + vmovups [rbx], zmm13 + vmovups [rbp], zmm14 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + + sub rsi, 16 + jne outer_loop + jmp return + +tail: + mov r11, -1 + shlx r11, r11, rsi + not r11 + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 + vmovups ZMMWORD PTR [rbp]{k1}, zmm14 + +return: + add rsp, 128 + mov r13, [rsp] + mov rsp, r13 + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_4x16c2__asm_amd64_avx512f_broadcast + + #ifdef __has_feature + #if __has_feature(dataflow_sanitizer) +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_4x16c2__asm_amd64_avx512f_broadcast.dfsan + .intel_syntax noprefix + # We could implement this by calling a function that implements the dfsan instrumentation. + # For now, just break, so if someone tries to use this, they'll know where the problem is. + int 3 + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_4x16c2__asm_amd64_avx512f_broadcast.dfsan + #endif + #endif \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S index 6e400276543..87b7f31db31 100644 --- a/src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S @@ -70,49 +70,49 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm15, [r9 + 64] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 vmovaps zmm16, zmm15 vmovaps zmm17, zmm15 vmovaps zmm18, zmm15 add r9, 128 inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 - vfmadd231ps zmm15, zmm2, zmm11 + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm15, zmm2, zmm8 vbroadcastss zmm3, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm3, zmm10 - vfmadd231ps zmm16, zmm3, zmm11 + vfmadd231ps zmm12, zmm3, zmm7 + vfmadd231ps zmm16, zmm3, zmm8 vbroadcastss zmm4, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm4, zmm10 - vfmadd231ps zmm17, zmm4, zmm11 + vfmadd231ps zmm13, zmm4, zmm7 + vfmadd231ps zmm17, zmm4, zmm8 vbroadcastss zmm5, DWORD PTR [r14 + r11] - vfmadd231ps zmm14, zmm5, zmm10 - vfmadd231ps zmm18, zmm5, zmm11 + vfmadd231ps zmm14, zmm5, zmm7 + vfmadd231ps zmm18, zmm5, zmm8 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 vminps zmm17, zmm1, zmm17 vminps zmm18, zmm1, zmm18 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -123,11 +123,11 @@ inner_loop_end: cmp rsi, 32 jl tail - vmovups [r10], zmm7 + vmovups [r10], zmm11 vmovups [r10 + 64], zmm15 - vmovups [r13], zmm8 + vmovups [r13], zmm12 vmovups [r13 + 64], zmm16 - vmovups [rbx], zmm9 + vmovups [rbx], zmm13 vmovups [rbx + 64], zmm17 vmovups [rbp], zmm14 vmovups [rbp + 64], zmm18 @@ -147,11 +147,11 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm15 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm16 - vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm17 vmovups ZMMWORD PTR [rbp]{k1}, zmm14 vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm18 diff --git a/src/f32-gemm/gen/f32-gemm-4x32c2-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-4x32c2-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 00000000000..5f0d903785d --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-4x32c2-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,309 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "xnnpack/assembly.h" +.PERMUTATION: + .long 0 + .long 2 + .long 4 + .long 6 + .long 8 + .long 10 + .long 12 + .long 14 + .long 16 + .long 18 + .long 20 + .long 22 + .long 24 + .long 26 + .long 28 + .long 30 + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_4x32c2__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Align the stack pointer. + mov r13, rsp + sub rsp, 64 + and rsp, 0xFFFFFFFFFFFFFFC0 + # Store the old stack pointer containing the return address + mov [rsp], r13 + + # Allocate some space on the stack. + sub rsp, 128 + + # Clamp a & c pointers if mr <= 1 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rcx + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Copy k and flip bit. + mov r11, rdx + and r11, 0x4 + and rdx, 0xFFFFFFFFFFFFFFFB + mov [rsp + 88], r11 + mov r11, 0x5555 + kmovw k3, r11d + +outer_loop: + # Initialize k counter. + mov r11, 0 + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + # Interleave with zeros. 
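+ # vpmovzxdq widens each 32-bit bias lane to 64 bits, leaving the bias in the even
+ # float lane of every qword and zero in the odd lane. This matches the kr=2 layout:
+ # the inner loop accumulates products for the even k of each pair in even lanes and
+ # for the odd k in odd lanes, and the two halves are only summed and compacted
+ # after the k loop.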
+ vpmovzxdq zmm11, ymm7 + vextracti64x4 ymm7, zmm7, 1 + vpmovzxdq zmm15, ymm7 + vpmovzxdq zmm19, ymm8 + vextracti64x4 ymm8, zmm8, 1 + vpmovzxdq zmm23, ymm8 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm16, zmm15 + vmovaps zmm17, zmm15 + vmovaps zmm18, zmm15 + vmovaps zmm20, zmm19 + vmovaps zmm21, zmm19 + vmovaps zmm22, zmm19 + vmovaps zmm24, zmm23 + vmovaps zmm25, zmm23 + vmovaps zmm26, zmm23 + add r9, 128 + + # Are there at least 8 bytes? + cmp rdx, 8 + js inner_loop_tail + +inner_loop: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] + add r9, 256 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm15, zmm2, zmm8 + vfmadd231ps zmm19, zmm2, zmm9 + vfmadd231ps zmm23, zmm2, zmm10 + vbroadcastsd zmm3, QWORD PTR [rax + r11] + vfmadd231ps zmm12, zmm3, zmm7 + vfmadd231ps zmm16, zmm3, zmm8 + vfmadd231ps zmm20, zmm3, zmm9 + vfmadd231ps zmm24, zmm3, zmm10 + vbroadcastsd zmm4, QWORD PTR [r15 + r11] + vfmadd231ps zmm13, zmm4, zmm7 + vfmadd231ps zmm17, zmm4, zmm8 + vfmadd231ps zmm21, zmm4, zmm9 + vfmadd231ps zmm25, zmm4, zmm10 + vbroadcastsd zmm5, QWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm5, zmm7 + vfmadd231ps zmm18, zmm5, zmm8 + vfmadd231ps zmm22, zmm5, zmm9 + vfmadd231ps zmm26, zmm5, zmm10 + + add r11, 8 + cmp rdx, r11 + jne inner_loop + + # Store nc_register. + mov [rsp + 96], rsi + # Load odd k bit. + mov rsi, [rsp + 88] + # Check if channels are odd. + test rsi, rsi + mov rsi, [rsp + 96] + jz inner_loop_end + +inner_loop_tail: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] + add r9, 256 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11{k3}, zmm2, zmm7 + vfmadd231ps zmm15{k3}, zmm2, zmm8 + vfmadd231ps zmm19{k3}, zmm2, zmm9 + vfmadd231ps zmm23{k3}, zmm2, zmm10 + vbroadcastsd zmm3, QWORD PTR [rax + r11] + vfmadd231ps zmm12{k3}, zmm3, zmm7 + vfmadd231ps zmm16{k3}, zmm3, zmm8 + vfmadd231ps zmm20{k3}, zmm3, zmm9 + vfmadd231ps zmm24{k3}, zmm3, zmm10 + vbroadcastsd zmm4, QWORD PTR [r15 + r11] + vfmadd231ps zmm13{k3}, zmm4, zmm7 + vfmadd231ps zmm17{k3}, zmm4, zmm8 + vfmadd231ps zmm21{k3}, zmm4, zmm9 + vfmadd231ps zmm25{k3}, zmm4, zmm10 + vbroadcastsd zmm5, QWORD PTR [r14 + r11] + vfmadd231ps zmm14{k3}, zmm5, zmm7 + vfmadd231ps zmm18{k3}, zmm5, zmm8 + vfmadd231ps zmm22{k3}, zmm5, zmm9 + vfmadd231ps zmm26{k3}, zmm5, zmm10 +inner_loop_end: + vpsrlq zmm7, zmm11, 32 + vaddps zmm11, zmm11, zmm7 + vpsrlq zmm7, zmm12, 32 + vaddps zmm12, zmm12, zmm7 + vpsrlq zmm7, zmm13, 32 + vaddps zmm13, zmm13, zmm7 + vpsrlq zmm7, zmm14, 32 + vaddps zmm14, zmm14, zmm7 + vpsrlq zmm7, zmm15, 32 + vaddps zmm15, zmm15, zmm7 + vpsrlq zmm7, zmm16, 32 + vaddps zmm16, zmm16, zmm7 + vpsrlq zmm7, zmm17, 32 + vaddps zmm17, zmm17, zmm7 + vpsrlq zmm7, zmm18, 32 + vaddps zmm18, zmm18, zmm7 + vpsrlq zmm7, zmm19, 32 + vaddps zmm19, zmm19, zmm7 + vpsrlq zmm7, zmm20, 32 + vaddps zmm20, zmm20, zmm7 + vpsrlq zmm7, zmm21, 32 + vaddps zmm21, zmm21, zmm7 + vpsrlq zmm7, zmm22, 32 + vaddps zmm22, zmm22, zmm7 + vpsrlq zmm7, zmm23, 32 + vaddps zmm23, zmm23, zmm7 + vpsrlq zmm7, zmm24, 32 + vaddps zmm24, zmm24, zmm7 + vpsrlq zmm7, zmm25, 32 + vaddps zmm25, zmm25, zmm7 + vpsrlq zmm7, zmm26, 32 + vaddps zmm26, zmm26, zmm7 + vmovups zmm7, zmmword ptr [rip + .PERMUTATION] + vpermt2ps zmm11, zmm7, zmm15 + vpermt2ps zmm12, zmm7, zmm16 + vpermt2ps zmm13, zmm7, zmm17 + vpermt2ps zmm14, zmm7, zmm18 + vpermt2ps zmm19, zmm7, zmm23 + vpermt2ps zmm20, zmm7, zmm24 
+ vpermt2ps zmm21, zmm7, zmm25 + vpermt2ps zmm22, zmm7, zmm26 + # Min/max clamping. + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm19 + vminps zmm16, zmm1, zmm20 + vminps zmm17, zmm1, zmm21 + vminps zmm18, zmm1, zmm22 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + + # Check whether full or partial store. + cmp rsi, 32 + jl tail + + vmovups [r10], zmm11 + vmovups [r10 + 64], zmm15 + vmovups [r13], zmm12 + vmovups [r13 + 64], zmm16 + vmovups [rbx], zmm13 + vmovups [rbx + 64], zmm17 + vmovups [rbp], zmm14 + vmovups [rbp + 64], zmm18 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + + sub rsi, 32 + jne outer_loop + jmp return + +tail: + mov r11, -1 + shlx r11, r11, rsi + not r11 + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm15 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm16 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [rbp]{k1}, zmm14 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm18 + +return: + add rsp, 128 + mov r13, [rsp] + mov rsp, r13 + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_4x32c2__asm_amd64_avx512f_broadcast + + #ifdef __has_feature + #if __has_feature(dataflow_sanitizer) +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_4x32c2__asm_amd64_avx512f_broadcast.dfsan + .intel_syntax noprefix + # We could implement this by calling a function that implements the dfsan instrumentation. + # For now, just break, so if someone tries to use this, they'll know where the problem is. + int 3 + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_4x32c2__asm_amd64_avx512f_broadcast.dfsan + #endif + #endif \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S index 9b6218bb392..3ba0eb14058 100644 --- a/src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S @@ -70,13 +70,13 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. 
- vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm15, [r9 + 64] vmovaps zmm19, [r9 + 128] vmovaps zmm23, [r9 + 192] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 vmovaps zmm16, zmm15 vmovaps zmm17, zmm15 vmovaps zmm18, zmm15 @@ -89,40 +89,40 @@ outer_loop: add r9, 256 inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] - vmovaps zmm12, [r9 + 128] - vmovaps zmm13, [r9 + 192] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] add r9, 256 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 - vfmadd231ps zmm15, zmm2, zmm11 - vfmadd231ps zmm19, zmm2, zmm12 - vfmadd231ps zmm23, zmm2, zmm13 + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm15, zmm2, zmm8 + vfmadd231ps zmm19, zmm2, zmm9 + vfmadd231ps zmm23, zmm2, zmm10 vbroadcastss zmm3, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm3, zmm10 - vfmadd231ps zmm16, zmm3, zmm11 - vfmadd231ps zmm20, zmm3, zmm12 - vfmadd231ps zmm24, zmm3, zmm13 + vfmadd231ps zmm12, zmm3, zmm7 + vfmadd231ps zmm16, zmm3, zmm8 + vfmadd231ps zmm20, zmm3, zmm9 + vfmadd231ps zmm24, zmm3, zmm10 vbroadcastss zmm4, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm4, zmm10 - vfmadd231ps zmm17, zmm4, zmm11 - vfmadd231ps zmm21, zmm4, zmm12 - vfmadd231ps zmm25, zmm4, zmm13 + vfmadd231ps zmm13, zmm4, zmm7 + vfmadd231ps zmm17, zmm4, zmm8 + vfmadd231ps zmm21, zmm4, zmm9 + vfmadd231ps zmm25, zmm4, zmm10 vbroadcastss zmm5, DWORD PTR [r14 + r11] - vfmadd231ps zmm14, zmm5, zmm10 - vfmadd231ps zmm18, zmm5, zmm11 - vfmadd231ps zmm22, zmm5, zmm12 - vfmadd231ps zmm26, zmm5, zmm13 + vfmadd231ps zmm14, zmm5, zmm7 + vfmadd231ps zmm18, zmm5, zmm8 + vfmadd231ps zmm22, zmm5, zmm9 + vfmadd231ps zmm26, zmm5, zmm10 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -136,9 +136,9 @@ inner_loop_end: vminps zmm24, zmm1, zmm24 vminps zmm25, zmm1, zmm25 vminps zmm26, zmm1, zmm26 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -157,15 +157,15 @@ inner_loop_end: cmp rsi, 64 jl tail - vmovups [r10], zmm7 + vmovups [r10], zmm11 vmovups [r10 + 64], zmm15 vmovups [r10 + 128], zmm19 vmovups [r10 + 192], zmm23 - vmovups [r13], zmm8 + vmovups [r13], zmm12 vmovups [r13 + 64], zmm16 vmovups [r13 + 128], zmm20 vmovups [r13 + 192], zmm24 - vmovups [rbx], zmm9 + vmovups [rbx], zmm13 vmovups [rbx + 64], zmm17 vmovups [rbx + 128], zmm21 vmovups [rbx + 192], zmm25 @@ -194,15 +194,15 @@ tail: shr r11, 16 kmovw k4, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm15 vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm19 vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm23 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm16 vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm20 vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm24 - vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm17 vmovups ZMMWORD PTR [rbx + 128]{k3}, zmm21 vmovups ZMMWORD PTR [rbx + 192]{k4}, zmm25 diff --git a/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S index 01e56a207d4..5bb7869ae41 100644 --- a/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S @@ -79,40 +79,40 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 add r9, 64 inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm11, zmm2, zmm7 vbroadcastss zmm3, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm3, zmm10 + vfmadd231ps zmm12, zmm3, zmm7 vbroadcastss zmm4, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm4, zmm10 + vfmadd231ps zmm13, zmm4, zmm7 vbroadcastss zmm5, DWORD PTR [r14 + r11] - vfmadd231ps zmm14, zmm5, zmm10 + vfmadd231ps zmm14, zmm5, zmm7 vbroadcastss zmm6, DWORD PTR [r12 + r11] - vfmadd231ps zmm15, zmm6, zmm10 + vfmadd231ps zmm15, zmm6, zmm7 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 @@ -120,9 +120,9 @@ inner_loop_end: cmp rsi, 16 jl tail - vmovups [r10], zmm7 - vmovups [r13], zmm8 - vmovups [rbx], zmm9 + vmovups [r10], zmm11 + vmovups [r13], zmm12 + vmovups [rbx], zmm13 vmovups [rbp], zmm14 vmovups [r8], zmm15 add r10, 64 @@ -140,9 +140,9 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 - vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 vmovups ZMMWORD PTR [rbp]{k1}, zmm14 vmovups ZMMWORD PTR [r8]{k1}, zmm15 diff --git a/src/f32-gemm/gen/f32-gemm-5x16c2-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-5x16c2-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 00000000000..e0bc7920255 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-5x16c2-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,268 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "xnnpack/assembly.h" +.PERMUTATION: + .long 0 + .long 2 + .long 4 + .long 6 + .long 8 + .long 10 + .long 12 + .long 14 + .long 16 + .long 18 + .long 20 + .long 22 + .long 24 + .long 26 + .long 28 + .long 30 + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x16c2__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Align the stack pointer. + mov r13, rsp + sub rsp, 64 + and rsp, 0xFFFFFFFFFFFFFFC0 + # Store the old stack pointer containing the return address + mov [rsp], r13 + + # Allocate some space on the stack. + sub rsp, 128 + + # Clamp a & c pointers if mr <= 1 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rcx + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Clamp a & c pointers if mr <= 4 + mov r12, r14 + add r12, r8 + mov r8, rbp + add r8, r11 + cmp rdi, 4 + cmovle r12, r14 + cmovle r8, rbp + + # Copy k and flip bit. + mov r11, rdx + and r11, 0x4 + and rdx, 0xFFFFFFFFFFFFFFFB + mov [rsp + 104], r11 + mov r11, 0x5555 + kmovw k3, r11d + +outer_loop: + # Initialize k counter. + mov r11, 0 + vmovaps zmm7, [r9 + 0] + # Interleave with zeros. + vpmovzxdq zmm11, ymm7 + vextracti64x4 ymm7, zmm7, 1 + vpmovzxdq zmm16, ymm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm17, zmm16 + vmovaps zmm18, zmm16 + vmovaps zmm19, zmm16 + vmovaps zmm20, zmm16 + add r9, 64 + + # Are there at least 8 bytes? 
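+ # k (rdx) is counted in bytes and has already had its odd-float bit cleared, so the
+ # main loop consumes two floats (8 bytes) of a per iteration via vbroadcastsd. If
+ # fewer than 8 bytes remain, the code branches directly to the single-float tail,
+ # where the k3 = 0x5555 mask updates only the even lanes so the missing odd a value
+ # never contributes.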
+ cmp rdx, 8 + js inner_loop_tail + +inner_loop: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm16, zmm2, zmm8 + vbroadcastsd zmm3, QWORD PTR [rax + r11] + vfmadd231ps zmm12, zmm3, zmm7 + vfmadd231ps zmm17, zmm3, zmm8 + vbroadcastsd zmm4, QWORD PTR [r15 + r11] + vfmadd231ps zmm13, zmm4, zmm7 + vfmadd231ps zmm18, zmm4, zmm8 + vbroadcastsd zmm5, QWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm5, zmm7 + vfmadd231ps zmm19, zmm5, zmm8 + vbroadcastsd zmm6, QWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm6, zmm7 + vfmadd231ps zmm20, zmm6, zmm8 + + add r11, 8 + cmp rdx, r11 + jne inner_loop + + # Store nc_register. + mov [rsp + 112], rsi + # Load odd k bit. + mov rsi, [rsp + 104] + # Check if channels are odd. + test rsi, rsi + mov rsi, [rsp + 112] + jz inner_loop_end + +inner_loop_tail: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11{k3}, zmm2, zmm7 + vfmadd231ps zmm16{k3}, zmm2, zmm8 + vbroadcastsd zmm3, QWORD PTR [rax + r11] + vfmadd231ps zmm12{k3}, zmm3, zmm7 + vfmadd231ps zmm17{k3}, zmm3, zmm8 + vbroadcastsd zmm4, QWORD PTR [r15 + r11] + vfmadd231ps zmm13{k3}, zmm4, zmm7 + vfmadd231ps zmm18{k3}, zmm4, zmm8 + vbroadcastsd zmm5, QWORD PTR [r14 + r11] + vfmadd231ps zmm14{k3}, zmm5, zmm7 + vfmadd231ps zmm19{k3}, zmm5, zmm8 + vbroadcastsd zmm6, QWORD PTR [r12 + r11] + vfmadd231ps zmm15{k3}, zmm6, zmm7 + vfmadd231ps zmm20{k3}, zmm6, zmm8 +inner_loop_end: + vpsrlq zmm7, zmm11, 32 + vaddps zmm11, zmm11, zmm7 + vpsrlq zmm7, zmm12, 32 + vaddps zmm12, zmm12, zmm7 + vpsrlq zmm7, zmm13, 32 + vaddps zmm13, zmm13, zmm7 + vpsrlq zmm7, zmm14, 32 + vaddps zmm14, zmm14, zmm7 + vpsrlq zmm7, zmm15, 32 + vaddps zmm15, zmm15, zmm7 + vpsrlq zmm7, zmm16, 32 + vaddps zmm16, zmm16, zmm7 + vpsrlq zmm7, zmm17, 32 + vaddps zmm17, zmm17, zmm7 + vpsrlq zmm7, zmm18, 32 + vaddps zmm18, zmm18, zmm7 + vpsrlq zmm7, zmm19, 32 + vaddps zmm19, zmm19, zmm7 + vpsrlq zmm7, zmm20, 32 + vaddps zmm20, zmm20, zmm7 + vmovups zmm7, zmmword ptr [rip + .PERMUTATION] + vpermt2ps zmm11, zmm7, zmm16 + vpermt2ps zmm12, zmm7, zmm17 + vpermt2ps zmm13, zmm7, zmm18 + vpermt2ps zmm14, zmm7, zmm19 + vpermt2ps zmm15, zmm7, zmm20 + # Min/max clamping. + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + + # Check whether full or partial store. + cmp rsi, 16 + jl tail + + vmovups [r10], zmm11 + vmovups [r13], zmm12 + vmovups [rbx], zmm13 + vmovups [rbp], zmm14 + vmovups [r8], zmm15 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + add r8, 64 + + sub rsi, 16 + jne outer_loop + jmp return + +tail: + mov r11, -1 + shlx r11, r11, rsi + not r11 + kmovw k1, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 + vmovups ZMMWORD PTR [rbp]{k1}, zmm14 + vmovups ZMMWORD PTR [r8]{k1}, zmm15 + +return: + add rsp, 128 + mov r13, [rsp] + mov rsp, r13 + # Restore the callee saved registers. 
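+ # (rsp was realigned to 64 bytes in the prologue, so the original stack pointer
+ # saved there is reloaded above; the registers pushed in the prologue are then
+ # popped in reverse order.)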
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_5x16c2__asm_amd64_avx512f_broadcast + + #ifdef __has_feature + #if __has_feature(dataflow_sanitizer) +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x16c2__asm_amd64_avx512f_broadcast.dfsan + .intel_syntax noprefix + # We could implement this by calling a function that implements the dfsan instrumentation. + # For now, just break, so if someone tries to use this, they'll know where the problem is. + int 3 + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_5x16c2__asm_amd64_avx512f_broadcast.dfsan + #endif + #endif \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S index cb1d574cd68..cccbac02b84 100644 --- a/src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S @@ -79,12 +79,12 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm16, [r9 + 64] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 vmovaps zmm17, zmm16 vmovaps zmm18, zmm16 vmovaps zmm19, zmm16 @@ -92,33 +92,33 @@ outer_loop: add r9, 128 inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 - vfmadd231ps zmm16, zmm2, zmm11 + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm16, zmm2, zmm8 vbroadcastss zmm3, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm3, zmm10 - vfmadd231ps zmm17, zmm3, zmm11 + vfmadd231ps zmm12, zmm3, zmm7 + vfmadd231ps zmm17, zmm3, zmm8 vbroadcastss zmm4, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm4, zmm10 - vfmadd231ps zmm18, zmm4, zmm11 + vfmadd231ps zmm13, zmm4, zmm7 + vfmadd231ps zmm18, zmm4, zmm8 vbroadcastss zmm5, DWORD PTR [r14 + r11] - vfmadd231ps zmm14, zmm5, zmm10 - vfmadd231ps zmm19, zmm5, zmm11 + vfmadd231ps zmm14, zmm5, zmm7 + vfmadd231ps zmm19, zmm5, zmm8 vbroadcastss zmm6, DWORD PTR [r12 + r11] - vfmadd231ps zmm15, zmm6, zmm10 - vfmadd231ps zmm20, zmm6, zmm11 + vfmadd231ps zmm15, zmm6, zmm7 + vfmadd231ps zmm20, zmm6, zmm8 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -126,9 +126,9 @@ inner_loop_end: vminps zmm18, zmm1, zmm18 vminps zmm19, zmm1, zmm19 vminps zmm20, zmm1, zmm20 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -141,11 +141,11 @@ inner_loop_end: cmp rsi, 32 jl tail - vmovups [r10], zmm7 + vmovups [r10], zmm11 vmovups [r10 + 64], zmm16 - vmovups [r13], zmm8 + vmovups [r13], zmm12 vmovups [r13 + 64], zmm17 - vmovups [rbx], zmm9 + vmovups [rbx], zmm13 vmovups [rbx + 64], zmm18 vmovups [rbp], zmm14 vmovups [rbp + 64], zmm19 @@ -168,11 +168,11 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm16 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm17 - vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm18 vmovups ZMMWORD PTR [rbp]{k1}, zmm14 vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm19 diff --git a/src/f32-gemm/gen/f32-gemm-5x32c2-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-5x32c2-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 00000000000..a2096959051 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-5x32c2-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,351 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "xnnpack/assembly.h" +.PERMUTATION: + .long 0 + .long 2 + .long 4 + .long 6 + .long 8 + .long 10 + .long 12 + .long 14 + .long 16 + .long 18 + .long 20 + .long 22 + .long 24 + .long 26 + .long 28 + .long 30 + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Align the stack pointer. + mov r13, rsp + sub rsp, 64 + and rsp, 0xFFFFFFFFFFFFFFC0 + # Store the old stack pointer containing the return address + mov [rsp], r13 + + # Allocate some space on the stack. + sub rsp, 128 + + # Clamp a & c pointers if mr <= 1 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rcx + cmovle r13, r10 + + # Clamp a & c pointers if mr <= 2 + mov r15, rax + add r15, r8 + mov rbx, r13 + add rbx, r11 + cmp rdi, 2 + cmovle r15, rax + cmovle rbx, r13 + + # Clamp a & c pointers if mr <= 3 + mov r14, r15 + add r14, r8 + mov rbp, rbx + add rbp, r11 + cmp rdi, 3 + cmovle r14, r15 + cmovle rbp, rbx + + # Clamp a & c pointers if mr <= 4 + mov r12, r14 + add r12, r8 + mov r8, rbp + add r8, r11 + cmp rdi, 4 + cmovle r12, r14 + cmovle r8, rbp + + # Copy k and flip bit. + mov r11, rdx + and r11, 0x4 + and rdx, 0xFFFFFFFFFFFFFFFB + mov [rsp + 104], r11 + mov r11, 0x5555 + kmovw k3, r11d + +outer_loop: + # Initialize k counter. 
+ mov r11, 0 + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + # Interleave with zeros. + vpmovzxdq zmm11, ymm7 + vextracti64x4 ymm7, zmm7, 1 + vpmovzxdq zmm16, ymm7 + vpmovzxdq zmm21, ymm8 + vextracti64x4 ymm8, zmm8, 1 + vpmovzxdq zmm26, ymm8 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm17, zmm16 + vmovaps zmm18, zmm16 + vmovaps zmm19, zmm16 + vmovaps zmm20, zmm16 + vmovaps zmm22, zmm21 + vmovaps zmm23, zmm21 + vmovaps zmm24, zmm21 + vmovaps zmm25, zmm21 + vmovaps zmm27, zmm26 + vmovaps zmm28, zmm26 + vmovaps zmm29, zmm26 + vmovaps zmm30, zmm26 + add r9, 128 + + # Are there at least 8 bytes? + cmp rdx, 8 + js inner_loop_tail + +inner_loop: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] + add r9, 256 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm16, zmm2, zmm8 + vfmadd231ps zmm21, zmm2, zmm9 + vfmadd231ps zmm26, zmm2, zmm10 + vbroadcastsd zmm3, QWORD PTR [rax + r11] + vfmadd231ps zmm12, zmm3, zmm7 + vfmadd231ps zmm17, zmm3, zmm8 + vfmadd231ps zmm22, zmm3, zmm9 + vfmadd231ps zmm27, zmm3, zmm10 + vbroadcastsd zmm4, QWORD PTR [r15 + r11] + vfmadd231ps zmm13, zmm4, zmm7 + vfmadd231ps zmm18, zmm4, zmm8 + vfmadd231ps zmm23, zmm4, zmm9 + vfmadd231ps zmm28, zmm4, zmm10 + vbroadcastsd zmm5, QWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm5, zmm7 + vfmadd231ps zmm19, zmm5, zmm8 + vfmadd231ps zmm24, zmm5, zmm9 + vfmadd231ps zmm29, zmm5, zmm10 + vbroadcastsd zmm6, QWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm6, zmm7 + vfmadd231ps zmm20, zmm6, zmm8 + vfmadd231ps zmm25, zmm6, zmm9 + vfmadd231ps zmm30, zmm6, zmm10 + + add r11, 8 + cmp rdx, r11 + jne inner_loop + + # Store nc_register. + mov [rsp + 112], rsi + # Load odd k bit. + mov rsi, [rsp + 104] + # Check if channels are odd. 
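+ # The flag saved at [rsp + 104] records whether kc had a trailing odd float. rsi
+ # normally holds nc, so it is spilled around this test and restored before the
+ # branch; when k is even the masked tail iteration is skipped entirely.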
+ test rsi, rsi + mov rsi, [rsp + 112] + jz inner_loop_end + +inner_loop_tail: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] + add r9, 256 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11{k3}, zmm2, zmm7 + vfmadd231ps zmm16{k3}, zmm2, zmm8 + vfmadd231ps zmm21{k3}, zmm2, zmm9 + vfmadd231ps zmm26{k3}, zmm2, zmm10 + vbroadcastsd zmm3, QWORD PTR [rax + r11] + vfmadd231ps zmm12{k3}, zmm3, zmm7 + vfmadd231ps zmm17{k3}, zmm3, zmm8 + vfmadd231ps zmm22{k3}, zmm3, zmm9 + vfmadd231ps zmm27{k3}, zmm3, zmm10 + vbroadcastsd zmm4, QWORD PTR [r15 + r11] + vfmadd231ps zmm13{k3}, zmm4, zmm7 + vfmadd231ps zmm18{k3}, zmm4, zmm8 + vfmadd231ps zmm23{k3}, zmm4, zmm9 + vfmadd231ps zmm28{k3}, zmm4, zmm10 + vbroadcastsd zmm5, QWORD PTR [r14 + r11] + vfmadd231ps zmm14{k3}, zmm5, zmm7 + vfmadd231ps zmm19{k3}, zmm5, zmm8 + vfmadd231ps zmm24{k3}, zmm5, zmm9 + vfmadd231ps zmm29{k3}, zmm5, zmm10 + vbroadcastsd zmm6, QWORD PTR [r12 + r11] + vfmadd231ps zmm15{k3}, zmm6, zmm7 + vfmadd231ps zmm20{k3}, zmm6, zmm8 + vfmadd231ps zmm25{k3}, zmm6, zmm9 + vfmadd231ps zmm30{k3}, zmm6, zmm10 +inner_loop_end: + vpsrlq zmm7, zmm11, 32 + vaddps zmm11, zmm11, zmm7 + vpsrlq zmm7, zmm12, 32 + vaddps zmm12, zmm12, zmm7 + vpsrlq zmm7, zmm13, 32 + vaddps zmm13, zmm13, zmm7 + vpsrlq zmm7, zmm14, 32 + vaddps zmm14, zmm14, zmm7 + vpsrlq zmm7, zmm15, 32 + vaddps zmm15, zmm15, zmm7 + vpsrlq zmm7, zmm16, 32 + vaddps zmm16, zmm16, zmm7 + vpsrlq zmm7, zmm17, 32 + vaddps zmm17, zmm17, zmm7 + vpsrlq zmm7, zmm18, 32 + vaddps zmm18, zmm18, zmm7 + vpsrlq zmm7, zmm19, 32 + vaddps zmm19, zmm19, zmm7 + vpsrlq zmm7, zmm20, 32 + vaddps zmm20, zmm20, zmm7 + vpsrlq zmm7, zmm21, 32 + vaddps zmm21, zmm21, zmm7 + vpsrlq zmm7, zmm22, 32 + vaddps zmm22, zmm22, zmm7 + vpsrlq zmm7, zmm23, 32 + vaddps zmm23, zmm23, zmm7 + vpsrlq zmm7, zmm24, 32 + vaddps zmm24, zmm24, zmm7 + vpsrlq zmm7, zmm25, 32 + vaddps zmm25, zmm25, zmm7 + vpsrlq zmm7, zmm26, 32 + vaddps zmm26, zmm26, zmm7 + vpsrlq zmm7, zmm27, 32 + vaddps zmm27, zmm27, zmm7 + vpsrlq zmm7, zmm28, 32 + vaddps zmm28, zmm28, zmm7 + vpsrlq zmm7, zmm29, 32 + vaddps zmm29, zmm29, zmm7 + vpsrlq zmm7, zmm30, 32 + vaddps zmm30, zmm30, zmm7 + vmovups zmm7, zmmword ptr [rip + .PERMUTATION] + vpermt2ps zmm11, zmm7, zmm16 + vpermt2ps zmm12, zmm7, zmm17 + vpermt2ps zmm13, zmm7, zmm18 + vpermt2ps zmm14, zmm7, zmm19 + vpermt2ps zmm15, zmm7, zmm20 + vpermt2ps zmm21, zmm7, zmm26 + vpermt2ps zmm22, zmm7, zmm27 + vpermt2ps zmm23, zmm7, zmm28 + vpermt2ps zmm24, zmm7, zmm29 + vpermt2ps zmm25, zmm7, zmm30 + # Min/max clamping. + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm21 + vminps zmm17, zmm1, zmm22 + vminps zmm18, zmm1, zmm23 + vminps zmm19, zmm1, zmm24 + vminps zmm20, zmm1, zmm25 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + vmaxps zmm20, zmm0, zmm20 + + # Check whether full or partial store. 
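+ # Each row produces 32 output floats in two zmm registers. If fewer than 32 columns
+ # remain, a 32-bit column mask is built from nc: its low 16 bits go to k1 for the
+ # first vector of each row and the high 16 bits to k2 for the second.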
+ cmp rsi, 32 + jl tail + + vmovups [r10], zmm11 + vmovups [r10 + 64], zmm16 + vmovups [r13], zmm12 + vmovups [r13 + 64], zmm17 + vmovups [rbx], zmm13 + vmovups [rbx + 64], zmm18 + vmovups [rbp], zmm14 + vmovups [rbp + 64], zmm19 + vmovups [r8], zmm15 + vmovups [r8 + 64], zmm20 + add r10, 128 + add r13, 128 + add rbx, 128 + add rbp, 128 + add r8, 128 + + sub rsi, 32 + jne outer_loop + jmp return + +tail: + mov r11, -1 + shlx r11, r11, rsi + not r11 + kmovw k1, r11d + shr r11d, 16 + kmovw k2, r11d + vmovups ZMMWORD PTR [r10]{k1}, zmm11 + vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm16 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 + vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm17 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 + vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm18 + vmovups ZMMWORD PTR [rbp]{k1}, zmm14 + vmovups ZMMWORD PTR [rbp + 64]{k2}, zmm19 + vmovups ZMMWORD PTR [r8]{k1}, zmm15 + vmovups ZMMWORD PTR [r8 + 64]{k2}, zmm20 + +return: + add rsp, 128 + mov r13, [rsp] + mov rsp, r13 + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast + + #ifdef __has_feature + #if __has_feature(dataflow_sanitizer) +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast.dfsan + .intel_syntax noprefix + # We could implement this by calling a function that implements the dfsan instrumentation. + # For now, just break, so if someone tries to use this, they'll know where the problem is. + int 3 + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast.dfsan + #endif + #endif \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S index b5e229846a0..b6ca2852f67 100644 --- a/src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S @@ -79,14 +79,14 @@ outer_loop: # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. 
- vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm16, [r9 + 64] vmovaps zmm21, [r9 + 128] vmovaps zmm26, [r9 + 192] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 vmovaps zmm17, zmm16 vmovaps zmm18, zmm16 vmovaps zmm19, zmm16 @@ -102,45 +102,45 @@ outer_loop: add r9, 256 inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] - vmovaps zmm12, [r9 + 128] - vmovaps zmm13, [r9 + 192] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + vmovaps zmm9, [r9 + 128] + vmovaps zmm10, [r9 + 192] add r9, 256 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 - vfmadd231ps zmm16, zmm2, zmm11 - vfmadd231ps zmm21, zmm2, zmm12 - vfmadd231ps zmm26, zmm2, zmm13 + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm16, zmm2, zmm8 + vfmadd231ps zmm21, zmm2, zmm9 + vfmadd231ps zmm26, zmm2, zmm10 vbroadcastss zmm3, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm3, zmm10 - vfmadd231ps zmm17, zmm3, zmm11 - vfmadd231ps zmm22, zmm3, zmm12 - vfmadd231ps zmm27, zmm3, zmm13 + vfmadd231ps zmm12, zmm3, zmm7 + vfmadd231ps zmm17, zmm3, zmm8 + vfmadd231ps zmm22, zmm3, zmm9 + vfmadd231ps zmm27, zmm3, zmm10 vbroadcastss zmm4, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm4, zmm10 - vfmadd231ps zmm18, zmm4, zmm11 - vfmadd231ps zmm23, zmm4, zmm12 - vfmadd231ps zmm28, zmm4, zmm13 + vfmadd231ps zmm13, zmm4, zmm7 + vfmadd231ps zmm18, zmm4, zmm8 + vfmadd231ps zmm23, zmm4, zmm9 + vfmadd231ps zmm28, zmm4, zmm10 vbroadcastss zmm5, DWORD PTR [r14 + r11] - vfmadd231ps zmm14, zmm5, zmm10 - vfmadd231ps zmm19, zmm5, zmm11 - vfmadd231ps zmm24, zmm5, zmm12 - vfmadd231ps zmm29, zmm5, zmm13 + vfmadd231ps zmm14, zmm5, zmm7 + vfmadd231ps zmm19, zmm5, zmm8 + vfmadd231ps zmm24, zmm5, zmm9 + vfmadd231ps zmm29, zmm5, zmm10 vbroadcastss zmm6, DWORD PTR [r12 + r11] - vfmadd231ps zmm15, zmm6, zmm10 - vfmadd231ps zmm20, zmm6, zmm11 - vfmadd231ps zmm25, zmm6, zmm12 - vfmadd231ps zmm30, zmm6, zmm13 + vfmadd231ps zmm15, zmm6, zmm7 + vfmadd231ps zmm20, zmm6, zmm8 + vfmadd231ps zmm25, zmm6, zmm9 + vfmadd231ps zmm30, zmm6, zmm10 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -158,9 +158,9 @@ inner_loop_end: vminps zmm28, zmm1, zmm28 vminps zmm29, zmm1, zmm29 vminps zmm30, zmm1, zmm30 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -183,15 +183,15 @@ inner_loop_end: cmp rsi, 64 jl tail - vmovups [r10], zmm7 + vmovups [r10], zmm11 vmovups [r10 + 64], zmm16 vmovups [r10 + 128], zmm21 vmovups [r10 + 192], zmm26 - vmovups [r13], zmm8 + vmovups [r13], zmm12 vmovups [r13 + 64], zmm17 vmovups [r13 + 128], zmm22 vmovups [r13 + 192], zmm27 - vmovups [rbx], zmm9 + vmovups [rbx], zmm13 vmovups [rbx + 64], zmm18 vmovups [rbx + 128], zmm23 vmovups [rbx + 192], zmm28 @@ -225,15 +225,15 @@ tail: shr r11, 16 kmovw k4, r11d - vmovups ZMMWORD PTR [r10]{k1}, zmm7 + vmovups ZMMWORD PTR [r10]{k1}, zmm11 vmovups ZMMWORD PTR [r10 + 64]{k2}, zmm16 vmovups ZMMWORD PTR [r10 + 128]{k3}, zmm21 vmovups ZMMWORD PTR [r10 + 192]{k4}, zmm26 - vmovups ZMMWORD PTR [r13]{k1}, zmm8 + vmovups ZMMWORD PTR [r13]{k1}, zmm12 vmovups ZMMWORD PTR [r13 + 64]{k2}, zmm17 vmovups ZMMWORD PTR [r13 + 128]{k3}, zmm22 vmovups ZMMWORD PTR [r13 + 192]{k4}, zmm27 - vmovups ZMMWORD PTR [rbx]{k1}, zmm9 + vmovups ZMMWORD PTR [rbx]{k1}, zmm13 vmovups ZMMWORD PTR [rbx + 64]{k2}, zmm18 vmovups ZMMWORD PTR [rbx + 128]{k3}, zmm23 vmovups ZMMWORD PTR [rbx + 192]{k4}, zmm28 diff --git a/src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S index b6b713fa7b8..8c2e881f0fd 100644 --- a/src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S @@ -115,44 +115,44 @@ outer_loop: mov r10, [rsp + 96] # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 add r9, 64 inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm11, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm2, zmm10 + vfmadd231ps zmm12, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm2, zmm10 + vfmadd231ps zmm13, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vfmadd231ps zmm14, zmm2, zmm10 + vfmadd231ps zmm14, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vfmadd231ps zmm15, zmm2, zmm10 + vfmadd231ps zmm15, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vfmadd231ps zmm16, zmm2, zmm10 + vfmadd231ps zmm16, zmm2, zmm7 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -169,9 +169,9 @@ inner_loop_end: cmp rsi, 16 jl tail - vmovups [rcx], zmm7 - vmovups [rax], zmm8 - vmovups [r15], zmm9 + vmovups [rcx], zmm11 + vmovups [rax], zmm12 + vmovups [r15], zmm13 vmovups [r14], zmm14 vmovups [r12], zmm15 vmovups [r10], zmm16 @@ -199,9 +199,9 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r12]{k1}, zmm15 vmovups ZMMWORD PTR [r10]{k1}, zmm16 diff --git a/src/f32-gemm/gen/f32-gemm-6x16c2-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-6x16c2-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 00000000000..862e64cf3a5 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-6x16c2-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,338 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "xnnpack/assembly.h" +.PERMUTATION: + .long 0 + .long 2 + .long 4 + .long 6 + .long 8 + .long 10 + .long 12 + .long 14 + .long 16 + .long 18 + .long 20 + .long 22 + .long 24 + .long 26 + .long 28 + .long 30 + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_6x16c2__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Align the stack pointer. + mov r13, rsp + sub rsp, 64 + and rsp, 0xFFFFFFFFFFFFFFC0 + # Store the old stack pointer containing the return address + mov [rsp], r13 + + # Allocate some space on the stack. + sub rsp, 192 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp + 16], rcx + # Write r10 (c pointer) to the stack as we need the register. 
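+ # (With six rows there are not enough general-purpose registers to keep every a and
+ # c pointer live, so the clamped pointers are kept in this stack frame and reloaded
+ # at the top of each outer_loop iteration and again before the stores.)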
+ mov [rsp + 24], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 32], rax + mov [rsp + 40], r13 + + # Clamp a & c pointers if mr <= 2 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 48], rcx + mov [rsp + 56], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 64], rax + mov [rsp + 72], r13 + + # Clamp a & c pointers if mr <= 4 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 80], rcx + mov [rsp + 88], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 96], rax + mov [rsp + 104], r13 + + # Copy k and flip bit. + mov r11, rdx + and r11, 0x4 + and rdx, 0xFFFFFFFFFFFFFFFB + mov [rsp + 120], r11 + mov r11, 0x5555 + kmovw k3, r11d + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rcx, [rsp + 16] + mov rax, [rsp + 32] + mov r15, [rsp + 48] + mov r14, [rsp + 64] + mov r12, [rsp + 80] + mov r10, [rsp + 96] + + vmovaps zmm7, [r9 + 0] + # Interleave with zeros. + vpmovzxdq zmm11, ymm7 + vextracti64x4 ymm7, zmm7, 1 + vpmovzxdq zmm17, ymm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm18, zmm17 + vmovaps zmm19, zmm17 + vmovaps zmm20, zmm17 + vmovaps zmm21, zmm17 + vmovaps zmm22, zmm17 + add r9, 64 + + # Are there at least 8 bytes? + cmp rdx, 8 + js inner_loop_tail + +inner_loop: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm17, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rax + r11] + vfmadd231ps zmm12, zmm2, zmm7 + vfmadd231ps zmm18, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r15 + r11] + vfmadd231ps zmm13, zmm2, zmm7 + vfmadd231ps zmm19, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm7 + vfmadd231ps zmm20, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm7 + vfmadd231ps zmm21, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm7 + vfmadd231ps zmm22, zmm2, zmm8 + + add r11, 8 + cmp rdx, r11 + jne inner_loop + + # Store nc_register. + mov [rsp + 128], rsi + # Load odd k bit. + mov rsi, [rsp + 120] + # Check if channels are odd. 
+ test rsi, rsi + mov rsi, [rsp + 128] + jz inner_loop_end + +inner_loop_tail: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11{k3}, zmm2, zmm7 + vfmadd231ps zmm17{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rax + r11] + vfmadd231ps zmm12{k3}, zmm2, zmm7 + vfmadd231ps zmm18{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r15 + r11] + vfmadd231ps zmm13{k3}, zmm2, zmm7 + vfmadd231ps zmm19{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r14 + r11] + vfmadd231ps zmm14{k3}, zmm2, zmm7 + vfmadd231ps zmm20{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r12 + r11] + vfmadd231ps zmm15{k3}, zmm2, zmm7 + vfmadd231ps zmm21{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r10 + r11] + vfmadd231ps zmm16{k3}, zmm2, zmm7 + vfmadd231ps zmm22{k3}, zmm2, zmm8 +inner_loop_end: + vpsrlq zmm7, zmm11, 32 + vaddps zmm11, zmm11, zmm7 + vpsrlq zmm7, zmm12, 32 + vaddps zmm12, zmm12, zmm7 + vpsrlq zmm7, zmm13, 32 + vaddps zmm13, zmm13, zmm7 + vpsrlq zmm7, zmm14, 32 + vaddps zmm14, zmm14, zmm7 + vpsrlq zmm7, zmm15, 32 + vaddps zmm15, zmm15, zmm7 + vpsrlq zmm7, zmm16, 32 + vaddps zmm16, zmm16, zmm7 + vpsrlq zmm7, zmm17, 32 + vaddps zmm17, zmm17, zmm7 + vpsrlq zmm7, zmm18, 32 + vaddps zmm18, zmm18, zmm7 + vpsrlq zmm7, zmm19, 32 + vaddps zmm19, zmm19, zmm7 + vpsrlq zmm7, zmm20, 32 + vaddps zmm20, zmm20, zmm7 + vpsrlq zmm7, zmm21, 32 + vaddps zmm21, zmm21, zmm7 + vpsrlq zmm7, zmm22, 32 + vaddps zmm22, zmm22, zmm7 + vmovups zmm7, zmmword ptr [rip + .PERMUTATION] + vpermt2ps zmm11, zmm7, zmm17 + vpermt2ps zmm12, zmm7, zmm18 + vpermt2ps zmm13, zmm7, zmm19 + vpermt2ps zmm14, zmm7, zmm20 + vpermt2ps zmm15, zmm7, zmm21 + vpermt2ps zmm16, zmm7, zmm22 + # Min/max clamping. + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + + # Pop output pointers from the stack. + mov rcx, [rsp + 24] + mov rax, [rsp + 40] + mov r15, [rsp + 56] + mov r14, [rsp + 72] + mov r12, [rsp + 88] + mov r10, [rsp + 104] + + # Check whether full or partial store. + cmp rsi, 16 + jl tail + + vmovups [rcx], zmm11 + vmovups [rax], zmm12 + vmovups [r15], zmm13 + vmovups [r14], zmm14 + vmovups [r12], zmm15 + vmovups [r10], zmm16 + add rcx, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + + # Write output pointers to the stack. + mov [rsp + 24], rcx + mov [rsp + 40], rax + mov [rsp + 56], r15 + mov [rsp + 72], r14 + mov [rsp + 88], r12 + mov [rsp + 104], r10 + + sub rsi, 16 + jne outer_loop + jmp return + +tail: + mov r11, -1 + shlx r11, r11, rsi + not r11 + kmovw k1, r11d + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + +return: + add rsp, 192 + mov r13, [rsp] + mov rsp, r13 + # Restore the callee saved registers. 
+ pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_6x16c2__asm_amd64_avx512f_broadcast + + #ifdef __has_feature + #if __has_feature(dataflow_sanitizer) +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_6x16c2__asm_amd64_avx512f_broadcast.dfsan + .intel_syntax noprefix + # We could implement this by calling a function that implements the dfsan instrumentation. + # For now, just break, so if someone tries to use this, they'll know where the problem is. + int 3 + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_6x16c2__asm_amd64_avx512f_broadcast.dfsan + #endif + #endif \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S index fb9fb962695..e09db90eac2 100644 --- a/src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S @@ -115,13 +115,13 @@ outer_loop: mov r10, [rsp + 96] # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm17, [r9 + 64] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 vmovaps zmm18, zmm17 vmovaps zmm19, zmm17 vmovaps zmm20, zmm17 @@ -130,36 +130,36 @@ outer_loop: add r9, 128 inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 - vfmadd231ps zmm17, zmm2, zmm11 + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm17, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm2, zmm10 - vfmadd231ps zmm18, zmm2, zmm11 + vfmadd231ps zmm12, zmm2, zmm7 + vfmadd231ps zmm18, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm2, zmm10 - vfmadd231ps zmm19, zmm2, zmm11 + vfmadd231ps zmm13, zmm2, zmm7 + vfmadd231ps zmm19, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vfmadd231ps zmm14, zmm2, zmm10 - vfmadd231ps zmm20, zmm2, zmm11 + vfmadd231ps zmm14, zmm2, zmm7 + vfmadd231ps zmm20, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vfmadd231ps zmm15, zmm2, zmm10 - vfmadd231ps zmm21, zmm2, zmm11 + vfmadd231ps zmm15, zmm2, zmm7 + vfmadd231ps zmm21, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vfmadd231ps zmm16, zmm2, zmm10 - vfmadd231ps zmm22, zmm2, zmm11 + vfmadd231ps zmm16, zmm2, zmm7 + vfmadd231ps zmm22, zmm2, zmm8 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -169,9 +169,9 @@ inner_loop_end: vminps zmm20, zmm1, zmm20 vminps zmm21, zmm1, zmm21 vminps zmm22, zmm1, zmm22 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -194,11 +194,11 @@ inner_loop_end: cmp rsi, 32 jl tail - vmovups [rcx], zmm7 + vmovups [rcx], zmm11 vmovups [rcx + 64], zmm17 - vmovups [rax], zmm8 + vmovups [rax], zmm12 vmovups [rax + 64], zmm18 - vmovups [r15], zmm9 + vmovups [r15], zmm13 vmovups [r15 + 64], zmm19 vmovups [r14], zmm14 vmovups [r14 + 64], zmm20 @@ -232,11 +232,11 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 vmovups ZMMWORD PTR [rcx + 64]{k2}, zmm17 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 vmovups ZMMWORD PTR [rax + 64]{k2}, zmm18 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm19 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm20 diff --git a/src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S index 123b10706b0..74d97a55d1b 100644 --- a/src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S @@ -128,48 +128,48 @@ outer_loop: mov r13, [rsp + 112] # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 - vmovaps zmm17, zmm7 + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 add r9, 64 inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm11, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm2, zmm10 + vfmadd231ps zmm12, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm2, zmm10 + vfmadd231ps zmm13, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vfmadd231ps zmm14, zmm2, zmm10 + vfmadd231ps zmm14, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vfmadd231ps zmm15, zmm2, zmm10 + vfmadd231ps zmm15, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vfmadd231ps zmm16, zmm2, zmm10 + vfmadd231ps zmm16, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r13 + r11] - vfmadd231ps zmm17, zmm2, zmm10 + vfmadd231ps zmm17, zmm2, zmm7 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 vminps zmm17, zmm1, zmm17 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -188,9 +188,9 @@ inner_loop_end: cmp rsi, 16 jl tail - vmovups [rcx], zmm7 - vmovups [rax], zmm8 - vmovups [r15], zmm9 + vmovups [rcx], zmm11 + vmovups [rax], zmm12 + vmovups [r15], zmm13 vmovups [r14], zmm14 vmovups [r12], zmm15 vmovups [r10], zmm16 @@ -221,9 +221,9 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r12]{k1}, zmm15 vmovups ZMMWORD PTR [r10]{k1}, zmm16 diff --git a/src/f32-gemm/gen/f32-gemm-7x16c2-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-7x16c2-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 00000000000..907ff8b98bd --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-7x16c2-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,371 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "xnnpack/assembly.h" +.PERMUTATION: + .long 0 + .long 2 + .long 4 + .long 6 + .long 8 + .long 10 + .long 12 + .long 14 + .long 16 + .long 18 + .long 20 + .long 22 + .long 24 + .long 26 + .long 28 + .long 30 + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_7x16c2__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Align the stack pointer. + mov r13, rsp + sub rsp, 64 + and rsp, 0xFFFFFFFFFFFFFFC0 + # Store the old stack pointer containing the return address + mov [rsp], r13 + + # Allocate some space on the stack. + sub rsp, 192 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp + 16], rcx + # Write r10 (c pointer) to the stack as we need the register. 
+ mov [rsp + 24], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 32], rax + mov [rsp + 40], r13 + + # Clamp a & c pointers if mr <= 2 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 48], rcx + mov [rsp + 56], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 64], rax + mov [rsp + 72], r13 + + # Clamp a & c pointers if mr <= 4 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 80], rcx + mov [rsp + 88], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 96], rax + mov [rsp + 104], r13 + + # Clamp a & c pointers if mr <= 6 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 112], rcx + mov [rsp + 120], r10 + + # Copy k and flip bit. + mov r11, rdx + and r11, 0x4 + and rdx, 0xFFFFFFFFFFFFFFFB + mov [rsp + 136], r11 + mov r11, 0x5555 + kmovw k3, r11d + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rcx, [rsp + 16] + mov rax, [rsp + 32] + mov r15, [rsp + 48] + mov r14, [rsp + 64] + mov r12, [rsp + 80] + mov r10, [rsp + 96] + mov r13, [rsp + 112] + + vmovaps zmm7, [r9 + 0] + # Interleave with zeros. + vpmovzxdq zmm11, ymm7 + vextracti64x4 ymm7, zmm7, 1 + vpmovzxdq zmm18, ymm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm19, zmm18 + vmovaps zmm20, zmm18 + vmovaps zmm21, zmm18 + vmovaps zmm22, zmm18 + vmovaps zmm23, zmm18 + vmovaps zmm24, zmm18 + add r9, 64 + + # Are there at least 8 bytes? + cmp rdx, 8 + js inner_loop_tail + +inner_loop: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm18, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rax + r11] + vfmadd231ps zmm12, zmm2, zmm7 + vfmadd231ps zmm19, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r15 + r11] + vfmadd231ps zmm13, zmm2, zmm7 + vfmadd231ps zmm20, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm7 + vfmadd231ps zmm21, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm7 + vfmadd231ps zmm22, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm7 + vfmadd231ps zmm23, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm7 + vfmadd231ps zmm24, zmm2, zmm8 + + add r11, 8 + cmp rdx, r11 + jne inner_loop + + # Store nc_register. + mov [rsp + 144], rsi + # Load odd k bit. + mov rsi, [rsp + 136] + # Check if channels are odd. 
+ test rsi, rsi + mov rsi, [rsp + 144] + jz inner_loop_end + +inner_loop_tail: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11{k3}, zmm2, zmm7 + vfmadd231ps zmm18{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rax + r11] + vfmadd231ps zmm12{k3}, zmm2, zmm7 + vfmadd231ps zmm19{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r15 + r11] + vfmadd231ps zmm13{k3}, zmm2, zmm7 + vfmadd231ps zmm20{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r14 + r11] + vfmadd231ps zmm14{k3}, zmm2, zmm7 + vfmadd231ps zmm21{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r12 + r11] + vfmadd231ps zmm15{k3}, zmm2, zmm7 + vfmadd231ps zmm22{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r10 + r11] + vfmadd231ps zmm16{k3}, zmm2, zmm7 + vfmadd231ps zmm23{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r13 + r11] + vfmadd231ps zmm17{k3}, zmm2, zmm7 + vfmadd231ps zmm24{k3}, zmm2, zmm8 +inner_loop_end: + vpsrlq zmm7, zmm11, 32 + vaddps zmm11, zmm11, zmm7 + vpsrlq zmm7, zmm12, 32 + vaddps zmm12, zmm12, zmm7 + vpsrlq zmm7, zmm13, 32 + vaddps zmm13, zmm13, zmm7 + vpsrlq zmm7, zmm14, 32 + vaddps zmm14, zmm14, zmm7 + vpsrlq zmm7, zmm15, 32 + vaddps zmm15, zmm15, zmm7 + vpsrlq zmm7, zmm16, 32 + vaddps zmm16, zmm16, zmm7 + vpsrlq zmm7, zmm17, 32 + vaddps zmm17, zmm17, zmm7 + vpsrlq zmm7, zmm18, 32 + vaddps zmm18, zmm18, zmm7 + vpsrlq zmm7, zmm19, 32 + vaddps zmm19, zmm19, zmm7 + vpsrlq zmm7, zmm20, 32 + vaddps zmm20, zmm20, zmm7 + vpsrlq zmm7, zmm21, 32 + vaddps zmm21, zmm21, zmm7 + vpsrlq zmm7, zmm22, 32 + vaddps zmm22, zmm22, zmm7 + vpsrlq zmm7, zmm23, 32 + vaddps zmm23, zmm23, zmm7 + vpsrlq zmm7, zmm24, 32 + vaddps zmm24, zmm24, zmm7 + vmovups zmm7, zmmword ptr [rip + .PERMUTATION] + vpermt2ps zmm11, zmm7, zmm18 + vpermt2ps zmm12, zmm7, zmm19 + vpermt2ps zmm13, zmm7, zmm20 + vpermt2ps zmm14, zmm7, zmm21 + vpermt2ps zmm15, zmm7, zmm22 + vpermt2ps zmm16, zmm7, zmm23 + vpermt2ps zmm17, zmm7, zmm24 + # Min/max clamping. + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + + # Pop output pointers from the stack. + mov rcx, [rsp + 24] + mov rax, [rsp + 40] + mov r15, [rsp + 56] + mov r14, [rsp + 72] + mov r12, [rsp + 88] + mov r10, [rsp + 104] + mov r13, [rsp + 120] + + # Check whether full or partial store. + cmp rsi, 16 + jl tail + + vmovups [rcx], zmm11 + vmovups [rax], zmm12 + vmovups [r15], zmm13 + vmovups [r14], zmm14 + vmovups [r12], zmm15 + vmovups [r10], zmm16 + vmovups [r13], zmm17 + add rcx, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + + # Write output pointers to the stack. 
+ mov [rsp + 24], rcx + mov [rsp + 40], rax + mov [rsp + 56], r15 + mov [rsp + 72], r14 + mov [rsp + 88], r12 + mov [rsp + 104], r10 + mov [rsp + 120], r13 + + sub rsi, 16 + jne outer_loop + jmp return + +tail: + mov r11, -1 + shlx r11, r11, rsi + not r11 + kmovw k1, r11d + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + +return: + add rsp, 192 + mov r13, [rsp] + mov rsp, r13 + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_7x16c2__asm_amd64_avx512f_broadcast + + #ifdef __has_feature + #if __has_feature(dataflow_sanitizer) +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_7x16c2__asm_amd64_avx512f_broadcast.dfsan + .intel_syntax noprefix + # We could implement this by calling a function that implements the dfsan instrumentation. + # For now, just break, so if someone tries to use this, they'll know where the problem is. + int 3 + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_7x16c2__asm_amd64_avx512f_broadcast.dfsan + #endif + #endif \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S index f44bda77593..07190c33160 100644 --- a/src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S @@ -128,14 +128,14 @@ outer_loop: mov r13, [rsp + 112] # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm18, [r9 + 64] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 - vmovaps zmm17, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 vmovaps zmm19, zmm18 vmovaps zmm20, zmm18 vmovaps zmm21, zmm18 @@ -145,39 +145,39 @@ outer_loop: add r9, 128 inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 - vfmadd231ps zmm18, zmm2, zmm11 + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm18, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm2, zmm10 - vfmadd231ps zmm19, zmm2, zmm11 + vfmadd231ps zmm12, zmm2, zmm7 + vfmadd231ps zmm19, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm2, zmm10 - vfmadd231ps zmm20, zmm2, zmm11 + vfmadd231ps zmm13, zmm2, zmm7 + vfmadd231ps zmm20, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vfmadd231ps zmm14, zmm2, zmm10 - vfmadd231ps zmm21, zmm2, zmm11 + vfmadd231ps zmm14, zmm2, zmm7 + vfmadd231ps zmm21, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vfmadd231ps zmm15, zmm2, zmm10 - vfmadd231ps zmm22, zmm2, zmm11 + vfmadd231ps zmm15, zmm2, zmm7 + vfmadd231ps zmm22, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vfmadd231ps zmm16, zmm2, zmm10 - vfmadd231ps zmm23, zmm2, zmm11 + vfmadd231ps zmm16, zmm2, zmm7 + vfmadd231ps zmm23, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r13 + r11] - vfmadd231ps zmm17, zmm2, zmm10 - vfmadd231ps zmm24, zmm2, zmm11 + vfmadd231ps zmm17, zmm2, zmm7 + vfmadd231ps zmm24, zmm2, zmm8 add r11, 4 cmp rdx, r11 jne inner_loop 
inner_loop_end: # Min/max clamping. - vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -189,9 +189,9 @@ inner_loop_end: vminps zmm22, zmm1, zmm22 vminps zmm23, zmm1, zmm23 vminps zmm24, zmm1, zmm24 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -217,11 +217,11 @@ inner_loop_end: cmp rsi, 32 jl tail - vmovups [rcx], zmm7 + vmovups [rcx], zmm11 vmovups [rcx + 64], zmm18 - vmovups [rax], zmm8 + vmovups [rax], zmm12 vmovups [rax + 64], zmm19 - vmovups [r15], zmm9 + vmovups [r15], zmm13 vmovups [r15 + 64], zmm20 vmovups [r14], zmm14 vmovups [r14 + 64], zmm21 @@ -259,11 +259,11 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 vmovups ZMMWORD PTR [rcx + 64]{k2}, zmm18 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 vmovups ZMMWORD PTR [rax + 64]{k2}, zmm19 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm20 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm21 diff --git a/src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S index a30416d23b1..37b62a8baae 100644 --- a/src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S @@ -141,52 +141,52 @@ outer_loop: mov rbx, [rsp + 128] # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 - vmovaps zmm17, zmm7 - vmovaps zmm18, zmm7 + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm18, zmm11 add r9, 64 inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm11, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm2, zmm10 + vfmadd231ps zmm12, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm2, zmm10 + vfmadd231ps zmm13, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vfmadd231ps zmm14, zmm2, zmm10 + vfmadd231ps zmm14, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vfmadd231ps zmm15, zmm2, zmm10 + vfmadd231ps zmm15, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vfmadd231ps zmm16, zmm2, zmm10 + vfmadd231ps zmm16, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r13 + r11] - vfmadd231ps zmm17, zmm2, zmm10 + vfmadd231ps zmm17, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rbx + r11] - vfmadd231ps zmm18, zmm2, zmm10 + vfmadd231ps zmm18, zmm2, zmm7 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 vminps zmm17, zmm1, zmm17 vminps zmm18, zmm1, zmm18 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -207,9 +207,9 @@ inner_loop_end: cmp rsi, 16 jl tail - vmovups [rcx], zmm7 - vmovups [rax], zmm8 - vmovups [r15], zmm9 + vmovups [rcx], zmm11 + vmovups [rax], zmm12 + vmovups [r15], zmm13 vmovups [r14], zmm14 vmovups [r12], zmm15 vmovups [r10], zmm16 @@ -243,9 +243,9 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r12]{k1}, zmm15 vmovups ZMMWORD PTR [r10]{k1}, zmm16 diff --git a/src/f32-gemm/gen/f32-gemm-8x16c2-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-8x16c2-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 00000000000..61df00ab0c1 --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-8x16c2-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,404 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "xnnpack/assembly.h" +.PERMUTATION: + .long 0 + .long 2 + .long 4 + .long 6 + .long 8 + .long 10 + .long 12 + .long 14 + .long 16 + .long 18 + .long 20 + .long 22 + .long 24 + .long 26 + .long 28 + .long 30 + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_8x16c2__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Align the stack pointer. + mov r13, rsp + sub rsp, 64 + and rsp, 0xFFFFFFFFFFFFFFC0 + # Store the old stack pointer containing the return address + mov [rsp], r13 + + # Allocate some space on the stack. + sub rsp, 192 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp + 16], rcx + # Write r10 (c pointer) to the stack as we need the register. 
+ mov [rsp + 24], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 32], rax + mov [rsp + 40], r13 + + # Clamp a & c pointers if mr <= 2 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 48], rcx + mov [rsp + 56], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 64], rax + mov [rsp + 72], r13 + + # Clamp a & c pointers if mr <= 4 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 80], rcx + mov [rsp + 88], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 96], rax + mov [rsp + 104], r13 + + # Clamp a & c pointers if mr <= 6 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 112], rcx + mov [rsp + 120], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 128], rax + mov [rsp + 136], r13 + + # Copy k and flip bit. + mov r11, rdx + and r11, 0x4 + and rdx, 0xFFFFFFFFFFFFFFFB + mov [rsp + 152], r11 + mov r11, 0x5555 + kmovw k3, r11d + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rcx, [rsp + 16] + mov rax, [rsp + 32] + mov r15, [rsp + 48] + mov r14, [rsp + 64] + mov r12, [rsp + 80] + mov r10, [rsp + 96] + mov r13, [rsp + 112] + mov rbx, [rsp + 128] + + vmovaps zmm7, [r9 + 0] + # Interleave with zeros. + vpmovzxdq zmm11, ymm7 + vextracti64x4 ymm7, zmm7, 1 + vpmovzxdq zmm19, ymm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm18, zmm11 + vmovaps zmm20, zmm19 + vmovaps zmm21, zmm19 + vmovaps zmm22, zmm19 + vmovaps zmm23, zmm19 + vmovaps zmm24, zmm19 + vmovaps zmm25, zmm19 + vmovaps zmm26, zmm19 + add r9, 64 + + # Are there at least 8 bytes? + cmp rdx, 8 + js inner_loop_tail + +inner_loop: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm19, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rax + r11] + vfmadd231ps zmm12, zmm2, zmm7 + vfmadd231ps zmm20, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r15 + r11] + vfmadd231ps zmm13, zmm2, zmm7 + vfmadd231ps zmm21, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm7 + vfmadd231ps zmm22, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm7 + vfmadd231ps zmm23, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm7 + vfmadd231ps zmm24, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm7 + vfmadd231ps zmm25, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rbx + r11] + vfmadd231ps zmm18, zmm2, zmm7 + vfmadd231ps zmm26, zmm2, zmm8 + + add r11, 8 + cmp rdx, r11 + jne inner_loop + + # Store nc_register. + mov [rsp + 160], rsi + # Load odd k bit. + mov rsi, [rsp + 152] + # Check if channels are odd. 
+ test rsi, rsi + mov rsi, [rsp + 160] + jz inner_loop_end + +inner_loop_tail: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11{k3}, zmm2, zmm7 + vfmadd231ps zmm19{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rax + r11] + vfmadd231ps zmm12{k3}, zmm2, zmm7 + vfmadd231ps zmm20{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r15 + r11] + vfmadd231ps zmm13{k3}, zmm2, zmm7 + vfmadd231ps zmm21{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r14 + r11] + vfmadd231ps zmm14{k3}, zmm2, zmm7 + vfmadd231ps zmm22{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r12 + r11] + vfmadd231ps zmm15{k3}, zmm2, zmm7 + vfmadd231ps zmm23{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r10 + r11] + vfmadd231ps zmm16{k3}, zmm2, zmm7 + vfmadd231ps zmm24{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r13 + r11] + vfmadd231ps zmm17{k3}, zmm2, zmm7 + vfmadd231ps zmm25{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rbx + r11] + vfmadd231ps zmm18{k3}, zmm2, zmm7 + vfmadd231ps zmm26{k3}, zmm2, zmm8 +inner_loop_end: + vpsrlq zmm7, zmm11, 32 + vaddps zmm11, zmm11, zmm7 + vpsrlq zmm7, zmm12, 32 + vaddps zmm12, zmm12, zmm7 + vpsrlq zmm7, zmm13, 32 + vaddps zmm13, zmm13, zmm7 + vpsrlq zmm7, zmm14, 32 + vaddps zmm14, zmm14, zmm7 + vpsrlq zmm7, zmm15, 32 + vaddps zmm15, zmm15, zmm7 + vpsrlq zmm7, zmm16, 32 + vaddps zmm16, zmm16, zmm7 + vpsrlq zmm7, zmm17, 32 + vaddps zmm17, zmm17, zmm7 + vpsrlq zmm7, zmm18, 32 + vaddps zmm18, zmm18, zmm7 + vpsrlq zmm7, zmm19, 32 + vaddps zmm19, zmm19, zmm7 + vpsrlq zmm7, zmm20, 32 + vaddps zmm20, zmm20, zmm7 + vpsrlq zmm7, zmm21, 32 + vaddps zmm21, zmm21, zmm7 + vpsrlq zmm7, zmm22, 32 + vaddps zmm22, zmm22, zmm7 + vpsrlq zmm7, zmm23, 32 + vaddps zmm23, zmm23, zmm7 + vpsrlq zmm7, zmm24, 32 + vaddps zmm24, zmm24, zmm7 + vpsrlq zmm7, zmm25, 32 + vaddps zmm25, zmm25, zmm7 + vpsrlq zmm7, zmm26, 32 + vaddps zmm26, zmm26, zmm7 + vmovups zmm7, zmmword ptr [rip + .PERMUTATION] + vpermt2ps zmm11, zmm7, zmm19 + vpermt2ps zmm12, zmm7, zmm20 + vpermt2ps zmm13, zmm7, zmm21 + vpermt2ps zmm14, zmm7, zmm22 + vpermt2ps zmm15, zmm7, zmm23 + vpermt2ps zmm16, zmm7, zmm24 + vpermt2ps zmm17, zmm7, zmm25 + vpermt2ps zmm18, zmm7, zmm26 + # Min/max clamping. + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + + # Pop output pointers from the stack. + mov rcx, [rsp + 24] + mov rax, [rsp + 40] + mov r15, [rsp + 56] + mov r14, [rsp + 72] + mov r12, [rsp + 88] + mov r10, [rsp + 104] + mov r13, [rsp + 120] + mov rbx, [rsp + 136] + + # Check whether full or partial store. + cmp rsi, 16 + jl tail + + vmovups [rcx], zmm11 + vmovups [rax], zmm12 + vmovups [r15], zmm13 + vmovups [r14], zmm14 + vmovups [r12], zmm15 + vmovups [r10], zmm16 + vmovups [r13], zmm17 + vmovups [rbx], zmm18 + add rcx, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + add rbx, 64 + + # Write output pointers to the stack. 
+ mov [rsp + 24], rcx + mov [rsp + 40], rax + mov [rsp + 56], r15 + mov [rsp + 72], r14 + mov [rsp + 88], r12 + mov [rsp + 104], r10 + mov [rsp + 120], r13 + mov [rsp + 136], rbx + + sub rsi, 16 + jne outer_loop + jmp return + +tail: + mov r11, -1 + shlx r11, r11, rsi + not r11 + kmovw k1, r11d + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [rbx]{k1}, zmm18 + +return: + add rsp, 192 + mov r13, [rsp] + mov rsp, r13 + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_8x16c2__asm_amd64_avx512f_broadcast + + #ifdef __has_feature + #if __has_feature(dataflow_sanitizer) +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_8x16c2__asm_amd64_avx512f_broadcast.dfsan + .intel_syntax noprefix + # We could implement this by calling a function that implements the dfsan instrumentation. + # For now, just break, so if someone tries to use this, they'll know where the problem is. + int 3 + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_8x16c2__asm_amd64_avx512f_broadcast.dfsan + #endif + #endif \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S index b8f108d8116..644883ba7a2 100644 --- a/src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S @@ -141,15 +141,15 @@ outer_loop: mov rbx, [rsp + 128] # Initialize accumulators with the biases. 
- vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm19, [r9 + 64] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 - vmovaps zmm17, zmm7 - vmovaps zmm18, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm18, zmm11 vmovaps zmm20, zmm19 vmovaps zmm21, zmm19 vmovaps zmm22, zmm19 @@ -160,42 +160,42 @@ outer_loop: add r9, 128 inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 - vfmadd231ps zmm19, zmm2, zmm11 + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm19, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm2, zmm10 - vfmadd231ps zmm20, zmm2, zmm11 + vfmadd231ps zmm12, zmm2, zmm7 + vfmadd231ps zmm20, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm2, zmm10 - vfmadd231ps zmm21, zmm2, zmm11 + vfmadd231ps zmm13, zmm2, zmm7 + vfmadd231ps zmm21, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vfmadd231ps zmm14, zmm2, zmm10 - vfmadd231ps zmm22, zmm2, zmm11 + vfmadd231ps zmm14, zmm2, zmm7 + vfmadd231ps zmm22, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vfmadd231ps zmm15, zmm2, zmm10 - vfmadd231ps zmm23, zmm2, zmm11 + vfmadd231ps zmm15, zmm2, zmm7 + vfmadd231ps zmm23, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vfmadd231ps zmm16, zmm2, zmm10 - vfmadd231ps zmm24, zmm2, zmm11 + vfmadd231ps zmm16, zmm2, zmm7 + vfmadd231ps zmm24, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r13 + r11] - vfmadd231ps zmm17, zmm2, zmm10 - vfmadd231ps zmm25, zmm2, zmm11 + vfmadd231ps zmm17, zmm2, zmm7 + vfmadd231ps zmm25, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rbx + r11] - vfmadd231ps zmm18, zmm2, zmm10 - vfmadd231ps zmm26, zmm2, zmm11 + vfmadd231ps zmm18, zmm2, zmm7 + vfmadd231ps zmm26, zmm2, zmm8 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -209,9 +209,9 @@ inner_loop_end: vminps zmm24, zmm1, zmm24 vminps zmm25, zmm1, zmm25 vminps zmm26, zmm1, zmm26 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -240,11 +240,11 @@ inner_loop_end: cmp rsi, 32 jl tail - vmovups [rcx], zmm7 + vmovups [rcx], zmm11 vmovups [rcx + 64], zmm19 - vmovups [rax], zmm8 + vmovups [rax], zmm12 vmovups [rax + 64], zmm20 - vmovups [r15], zmm9 + vmovups [r15], zmm13 vmovups [r15 + 64], zmm21 vmovups [r14], zmm14 vmovups [r14 + 64], zmm22 @@ -286,11 +286,11 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 vmovups ZMMWORD PTR [rcx + 64]{k2}, zmm19 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 vmovups ZMMWORD PTR [rax + 64]{k2}, zmm20 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm21 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm22 diff --git a/src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S index e3ca086e250..08cc600217c 100644 --- a/src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S @@ -154,56 +154,56 @@ outer_loop: mov rbp, [rsp + 144] # Initialize accumulators with the biases. - vmovaps zmm7, [r9 + 0] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 - vmovaps zmm17, zmm7 - vmovaps zmm18, zmm7 - vmovaps zmm19, zmm7 + vmovaps zmm11, [r9 + 0] + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm18, zmm11 + vmovaps zmm19, zmm11 add r9, 64 inner_loop: - vmovaps zmm10, [r9 + 0] + vmovaps zmm7, [r9 + 0] add r9, 64 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 + vfmadd231ps zmm11, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm2, zmm10 + vfmadd231ps zmm12, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm2, zmm10 + vfmadd231ps zmm13, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vfmadd231ps zmm14, zmm2, zmm10 + vfmadd231ps zmm14, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vfmadd231ps zmm15, zmm2, zmm10 + vfmadd231ps zmm15, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vfmadd231ps zmm16, zmm2, zmm10 + vfmadd231ps zmm16, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [r13 + r11] - vfmadd231ps zmm17, zmm2, zmm10 + vfmadd231ps zmm17, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rbx + r11] - vfmadd231ps zmm18, zmm2, zmm10 + vfmadd231ps zmm18, zmm2, zmm7 vbroadcastss zmm2, DWORD PTR [rbp + r11] - vfmadd231ps zmm19, zmm2, zmm10 + vfmadd231ps zmm19, zmm2, zmm7 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 vminps zmm17, zmm1, zmm17 vminps zmm18, zmm1, zmm18 vminps zmm19, zmm1, zmm19 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -226,9 +226,9 @@ inner_loop_end: cmp rsi, 16 jl tail - vmovups [rcx], zmm7 - vmovups [rax], zmm8 - vmovups [r15], zmm9 + vmovups [rcx], zmm11 + vmovups [rax], zmm12 + vmovups [r15], zmm13 vmovups [r14], zmm14 vmovups [r12], zmm15 vmovups [r10], zmm16 @@ -265,9 +265,9 @@ tail: shlx r11, r11, rsi not r11 kmovw k1, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r12]{k1}, zmm15 vmovups ZMMWORD PTR [r10]{k1}, zmm16 diff --git a/src/f32-gemm/gen/f32-gemm-9x16c2-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-9x16c2-minmax-asm-amd64-avx512f-broadcast.S new file mode 100644 index 00000000000..8ea685aa04e --- /dev/null +++ b/src/f32-gemm/gen/f32-gemm-9x16c2-minmax-asm-amd64-avx512f-broadcast.S @@ -0,0 +1,437 @@ +// Copyright 2025 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "xnnpack/assembly.h" +.PERMUTATION: + .long 0 + .long 2 + .long 4 + .long 6 + .long 8 + .long 10 + .long 12 + .long 14 + .long 16 + .long 18 + .long 20 + .long 22 + .long 24 + .long 26 + .long 28 + .long 30 + +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_9x16c2__asm_amd64_avx512f_broadcast + + .intel_syntax noprefix + + # Free up GP registers. + push rbx + push rbp + push r15 + push r14 + push r13 + push r12 + + # load params to free up a GP registers + mov r13, [rsp + 80] # params + vbroadcastss zmm0, DWORD PTR [r13] + vbroadcastss zmm1, DWORD PTR [r13 + 4] + + # Load c pointer. + mov r10, [rsp + 56] + # Load cm_stride. + mov r11, [rsp + 64] + + # Align the stack pointer. + mov r13, rsp + sub rsp, 64 + and rsp, 0xFFFFFFFFFFFFFFC0 + # Store the old stack pointer containing the return address + mov [rsp], r13 + + # Allocate some space on the stack. + sub rsp, 192 + # Write rsi (a pointer) to the stack as we need the register. + mov [rsp + 16], rcx + # Write r10 (c pointer) to the stack as we need the register. 
+ mov [rsp + 24], r10 + + # Clamp a & c pointers if mr <= 1 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 1 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 32], rax + mov [rsp + 40], r13 + + # Clamp a & c pointers if mr <= 2 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 2 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 48], rcx + mov [rsp + 56], r10 + + # Clamp a & c pointers if mr <= 3 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 3 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 64], rax + mov [rsp + 72], r13 + + # Clamp a & c pointers if mr <= 4 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 4 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 80], rcx + mov [rsp + 88], r10 + + # Clamp a & c pointers if mr <= 5 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 5 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 96], rax + mov [rsp + 104], r13 + + # Clamp a & c pointers if mr <= 6 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 6 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 112], rcx + mov [rsp + 120], r10 + + # Clamp a & c pointers if mr <= 7 + mov rax, rcx + add rax, r8 + mov r13, r10 + add r13, r11 + cmp rdi, 7 + cmovle rax, rcx + cmovle r13, r10 + + mov [rsp + 128], rax + mov [rsp + 136], r13 + + # Clamp a & c pointers if mr <= 8 + mov rcx, rax + add rcx, r8 + mov r10, r13 + add r10, r11 + cmp rdi, 8 + cmovle rcx, rax + cmovle r10, r13 + + mov [rsp + 144], rcx + mov [rsp + 152], r10 + + # Copy k and flip bit. + mov r11, rdx + and r11, 0x4 + and rdx, 0xFFFFFFFFFFFFFFFB + mov [rsp + 168], r11 + mov r11, 0x5555 + kmovw k3, r11d + +outer_loop: + # Initialize k counter. + mov r11, 0 + # Read a pointers from stack into GP registers. + mov rcx, [rsp + 16] + mov rax, [rsp + 32] + mov r15, [rsp + 48] + mov r14, [rsp + 64] + mov r12, [rsp + 80] + mov r10, [rsp + 96] + mov r13, [rsp + 112] + mov rbx, [rsp + 128] + mov rbp, [rsp + 144] + + vmovaps zmm7, [r9 + 0] + # Interleave with zeros. + vpmovzxdq zmm11, ymm7 + vextracti64x4 ymm7, zmm7, 1 + vpmovzxdq zmm20, ymm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm18, zmm11 + vmovaps zmm19, zmm11 + vmovaps zmm21, zmm20 + vmovaps zmm22, zmm20 + vmovaps zmm23, zmm20 + vmovaps zmm24, zmm20 + vmovaps zmm25, zmm20 + vmovaps zmm26, zmm20 + vmovaps zmm27, zmm20 + vmovaps zmm28, zmm20 + add r9, 64 + + # Are there at least 8 bytes? 
+ cmp rdx, 8 + js inner_loop_tail + +inner_loop: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm20, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rax + r11] + vfmadd231ps zmm12, zmm2, zmm7 + vfmadd231ps zmm21, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r15 + r11] + vfmadd231ps zmm13, zmm2, zmm7 + vfmadd231ps zmm22, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r14 + r11] + vfmadd231ps zmm14, zmm2, zmm7 + vfmadd231ps zmm23, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r12 + r11] + vfmadd231ps zmm15, zmm2, zmm7 + vfmadd231ps zmm24, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r10 + r11] + vfmadd231ps zmm16, zmm2, zmm7 + vfmadd231ps zmm25, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r13 + r11] + vfmadd231ps zmm17, zmm2, zmm7 + vfmadd231ps zmm26, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rbx + r11] + vfmadd231ps zmm18, zmm2, zmm7 + vfmadd231ps zmm27, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rbp + r11] + vfmadd231ps zmm19, zmm2, zmm7 + vfmadd231ps zmm28, zmm2, zmm8 + + add r11, 8 + cmp rdx, r11 + jne inner_loop + + # Store nc_register. + mov [rsp + 176], rsi + # Load odd k bit. + mov rsi, [rsp + 168] + # Check if channels are odd. + test rsi, rsi + mov rsi, [rsp + 176] + jz inner_loop_end + +inner_loop_tail: + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] + add r9, 128 + vbroadcastsd zmm2, QWORD PTR [rcx + r11] + vfmadd231ps zmm11{k3}, zmm2, zmm7 + vfmadd231ps zmm20{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rax + r11] + vfmadd231ps zmm12{k3}, zmm2, zmm7 + vfmadd231ps zmm21{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r15 + r11] + vfmadd231ps zmm13{k3}, zmm2, zmm7 + vfmadd231ps zmm22{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r14 + r11] + vfmadd231ps zmm14{k3}, zmm2, zmm7 + vfmadd231ps zmm23{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r12 + r11] + vfmadd231ps zmm15{k3}, zmm2, zmm7 + vfmadd231ps zmm24{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r10 + r11] + vfmadd231ps zmm16{k3}, zmm2, zmm7 + vfmadd231ps zmm25{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [r13 + r11] + vfmadd231ps zmm17{k3}, zmm2, zmm7 + vfmadd231ps zmm26{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rbx + r11] + vfmadd231ps zmm18{k3}, zmm2, zmm7 + vfmadd231ps zmm27{k3}, zmm2, zmm8 + vbroadcastsd zmm2, QWORD PTR [rbp + r11] + vfmadd231ps zmm19{k3}, zmm2, zmm7 + vfmadd231ps zmm28{k3}, zmm2, zmm8 +inner_loop_end: + vpsrlq zmm7, zmm11, 32 + vaddps zmm11, zmm11, zmm7 + vpsrlq zmm7, zmm12, 32 + vaddps zmm12, zmm12, zmm7 + vpsrlq zmm7, zmm13, 32 + vaddps zmm13, zmm13, zmm7 + vpsrlq zmm7, zmm14, 32 + vaddps zmm14, zmm14, zmm7 + vpsrlq zmm7, zmm15, 32 + vaddps zmm15, zmm15, zmm7 + vpsrlq zmm7, zmm16, 32 + vaddps zmm16, zmm16, zmm7 + vpsrlq zmm7, zmm17, 32 + vaddps zmm17, zmm17, zmm7 + vpsrlq zmm7, zmm18, 32 + vaddps zmm18, zmm18, zmm7 + vpsrlq zmm7, zmm19, 32 + vaddps zmm19, zmm19, zmm7 + vpsrlq zmm7, zmm20, 32 + vaddps zmm20, zmm20, zmm7 + vpsrlq zmm7, zmm21, 32 + vaddps zmm21, zmm21, zmm7 + vpsrlq zmm7, zmm22, 32 + vaddps zmm22, zmm22, zmm7 + vpsrlq zmm7, zmm23, 32 + vaddps zmm23, zmm23, zmm7 + vpsrlq zmm7, zmm24, 32 + vaddps zmm24, zmm24, zmm7 + vpsrlq zmm7, zmm25, 32 + vaddps zmm25, zmm25, zmm7 + vpsrlq zmm7, zmm26, 32 + vaddps zmm26, zmm26, zmm7 + vpsrlq zmm7, zmm27, 32 + vaddps zmm27, zmm27, zmm7 + vpsrlq zmm7, zmm28, 32 + vaddps zmm28, zmm28, zmm7 + vmovups zmm7, zmmword ptr [rip + .PERMUTATION] + vpermt2ps zmm11, zmm7, zmm20 + vpermt2ps zmm12, zmm7, zmm21 + vpermt2ps zmm13, zmm7, zmm22 + vpermt2ps 
zmm14, zmm7, zmm23 + vpermt2ps zmm15, zmm7, zmm24 + vpermt2ps zmm16, zmm7, zmm25 + vpermt2ps zmm17, zmm7, zmm26 + vpermt2ps zmm18, zmm7, zmm27 + vpermt2ps zmm19, zmm7, zmm28 + # Min/max clamping. + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 + vminps zmm14, zmm1, zmm14 + vminps zmm15, zmm1, zmm15 + vminps zmm16, zmm1, zmm16 + vminps zmm17, zmm1, zmm17 + vminps zmm18, zmm1, zmm18 + vminps zmm19, zmm1, zmm19 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 + vmaxps zmm14, zmm0, zmm14 + vmaxps zmm15, zmm0, zmm15 + vmaxps zmm16, zmm0, zmm16 + vmaxps zmm17, zmm0, zmm17 + vmaxps zmm18, zmm0, zmm18 + vmaxps zmm19, zmm0, zmm19 + + # Pop output pointers from the stack. + mov rcx, [rsp + 24] + mov rax, [rsp + 40] + mov r15, [rsp + 56] + mov r14, [rsp + 72] + mov r12, [rsp + 88] + mov r10, [rsp + 104] + mov r13, [rsp + 120] + mov rbx, [rsp + 136] + mov rbp, [rsp + 152] + + # Check whether full or partial store. + cmp rsi, 16 + jl tail + + vmovups [rcx], zmm11 + vmovups [rax], zmm12 + vmovups [r15], zmm13 + vmovups [r14], zmm14 + vmovups [r12], zmm15 + vmovups [r10], zmm16 + vmovups [r13], zmm17 + vmovups [rbx], zmm18 + vmovups [rbp], zmm19 + add rcx, 64 + add rax, 64 + add r15, 64 + add r14, 64 + add r12, 64 + add r10, 64 + add r13, 64 + add rbx, 64 + add rbp, 64 + + # Write output pointers to the stack. + mov [rsp + 24], rcx + mov [rsp + 40], rax + mov [rsp + 56], r15 + mov [rsp + 72], r14 + mov [rsp + 88], r12 + mov [rsp + 104], r10 + mov [rsp + 120], r13 + mov [rsp + 136], rbx + mov [rsp + 152], rbp + + sub rsi, 16 + jne outer_loop + jmp return + +tail: + mov r11, -1 + shlx r11, r11, rsi + not r11 + kmovw k1, r11d + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 + vmovups ZMMWORD PTR [r14]{k1}, zmm14 + vmovups ZMMWORD PTR [r12]{k1}, zmm15 + vmovups ZMMWORD PTR [r10]{k1}, zmm16 + vmovups ZMMWORD PTR [r13]{k1}, zmm17 + vmovups ZMMWORD PTR [rbx]{k1}, zmm18 + vmovups ZMMWORD PTR [rbp]{k1}, zmm19 + +return: + add rsp, 192 + mov r13, [rsp] + mov rsp, r13 + # Restore the callee saved registers. + pop r12 + pop r13 + pop r14 + pop r15 + pop rbp + pop rbx + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_9x16c2__asm_amd64_avx512f_broadcast + + #ifdef __has_feature + #if __has_feature(dataflow_sanitizer) +BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_9x16c2__asm_amd64_avx512f_broadcast.dfsan + .intel_syntax noprefix + # We could implement this by calling a function that implements the dfsan instrumentation. + # For now, just break, so if someone tries to use this, they'll know where the problem is. + int 3 + ret +END_FUNCTION xnn_f32_gemm_minmax_ukernel_9x16c2__asm_amd64_avx512f_broadcast.dfsan + #endif + #endif \ No newline at end of file diff --git a/src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S index 05aeb9a282e..85acfa7e95e 100644 --- a/src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S @@ -154,16 +154,16 @@ outer_loop: mov rbp, [rsp + 144] # Initialize accumulators with the biases. 
- vmovaps zmm7, [r9 + 0] + vmovaps zmm11, [r9 + 0] vmovaps zmm20, [r9 + 64] - vmovaps zmm8, zmm7 - vmovaps zmm9, zmm7 - vmovaps zmm14, zmm7 - vmovaps zmm15, zmm7 - vmovaps zmm16, zmm7 - vmovaps zmm17, zmm7 - vmovaps zmm18, zmm7 - vmovaps zmm19, zmm7 + vmovaps zmm12, zmm11 + vmovaps zmm13, zmm11 + vmovaps zmm14, zmm11 + vmovaps zmm15, zmm11 + vmovaps zmm16, zmm11 + vmovaps zmm17, zmm11 + vmovaps zmm18, zmm11 + vmovaps zmm19, zmm11 vmovaps zmm21, zmm20 vmovaps zmm22, zmm20 vmovaps zmm23, zmm20 @@ -175,45 +175,45 @@ outer_loop: add r9, 128 inner_loop: - vmovaps zmm10, [r9 + 0] - vmovaps zmm11, [r9 + 64] + vmovaps zmm7, [r9 + 0] + vmovaps zmm8, [r9 + 64] add r9, 128 vbroadcastss zmm2, DWORD PTR [rcx + r11] - vfmadd231ps zmm7, zmm2, zmm10 - vfmadd231ps zmm20, zmm2, zmm11 + vfmadd231ps zmm11, zmm2, zmm7 + vfmadd231ps zmm20, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rax + r11] - vfmadd231ps zmm8, zmm2, zmm10 - vfmadd231ps zmm21, zmm2, zmm11 + vfmadd231ps zmm12, zmm2, zmm7 + vfmadd231ps zmm21, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r15 + r11] - vfmadd231ps zmm9, zmm2, zmm10 - vfmadd231ps zmm22, zmm2, zmm11 + vfmadd231ps zmm13, zmm2, zmm7 + vfmadd231ps zmm22, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r14 + r11] - vfmadd231ps zmm14, zmm2, zmm10 - vfmadd231ps zmm23, zmm2, zmm11 + vfmadd231ps zmm14, zmm2, zmm7 + vfmadd231ps zmm23, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r12 + r11] - vfmadd231ps zmm15, zmm2, zmm10 - vfmadd231ps zmm24, zmm2, zmm11 + vfmadd231ps zmm15, zmm2, zmm7 + vfmadd231ps zmm24, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r10 + r11] - vfmadd231ps zmm16, zmm2, zmm10 - vfmadd231ps zmm25, zmm2, zmm11 + vfmadd231ps zmm16, zmm2, zmm7 + vfmadd231ps zmm25, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [r13 + r11] - vfmadd231ps zmm17, zmm2, zmm10 - vfmadd231ps zmm26, zmm2, zmm11 + vfmadd231ps zmm17, zmm2, zmm7 + vfmadd231ps zmm26, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rbx + r11] - vfmadd231ps zmm18, zmm2, zmm10 - vfmadd231ps zmm27, zmm2, zmm11 + vfmadd231ps zmm18, zmm2, zmm7 + vfmadd231ps zmm27, zmm2, zmm8 vbroadcastss zmm2, DWORD PTR [rbp + r11] - vfmadd231ps zmm19, zmm2, zmm10 - vfmadd231ps zmm28, zmm2, zmm11 + vfmadd231ps zmm19, zmm2, zmm7 + vfmadd231ps zmm28, zmm2, zmm8 add r11, 4 cmp rdx, r11 jne inner_loop inner_loop_end: # Min/max clamping. 
- vminps zmm7, zmm1, zmm7 - vminps zmm8, zmm1, zmm8 - vminps zmm9, zmm1, zmm9 + vminps zmm11, zmm1, zmm11 + vminps zmm12, zmm1, zmm12 + vminps zmm13, zmm1, zmm13 vminps zmm14, zmm1, zmm14 vminps zmm15, zmm1, zmm15 vminps zmm16, zmm1, zmm16 @@ -229,9 +229,9 @@ inner_loop_end: vminps zmm26, zmm1, zmm26 vminps zmm27, zmm1, zmm27 vminps zmm28, zmm1, zmm28 - vmaxps zmm7, zmm0, zmm7 - vmaxps zmm8, zmm0, zmm8 - vmaxps zmm9, zmm0, zmm9 + vmaxps zmm11, zmm0, zmm11 + vmaxps zmm12, zmm0, zmm12 + vmaxps zmm13, zmm0, zmm13 vmaxps zmm14, zmm0, zmm14 vmaxps zmm15, zmm0, zmm15 vmaxps zmm16, zmm0, zmm16 @@ -263,11 +263,11 @@ inner_loop_end: cmp rsi, 32 jl tail - vmovups [rcx], zmm7 + vmovups [rcx], zmm11 vmovups [rcx + 64], zmm20 - vmovups [rax], zmm8 + vmovups [rax], zmm12 vmovups [rax + 64], zmm21 - vmovups [r15], zmm9 + vmovups [r15], zmm13 vmovups [r15 + 64], zmm22 vmovups [r14], zmm14 vmovups [r14 + 64], zmm23 @@ -313,11 +313,11 @@ tail: kmovw k1, r11d shr r11d, 16 kmovw k2, r11d - vmovups ZMMWORD PTR [rcx]{k1}, zmm7 + vmovups ZMMWORD PTR [rcx]{k1}, zmm11 vmovups ZMMWORD PTR [rcx + 64]{k2}, zmm20 - vmovups ZMMWORD PTR [rax]{k1}, zmm8 + vmovups ZMMWORD PTR [rax]{k1}, zmm12 vmovups ZMMWORD PTR [rax + 64]{k2}, zmm21 - vmovups ZMMWORD PTR [r15]{k1}, zmm9 + vmovups ZMMWORD PTR [r15]{k1}, zmm13 vmovups ZMMWORD PTR [r15 + 64]{k2}, zmm22 vmovups ZMMWORD PTR [r14]{k1}, zmm14 vmovups ZMMWORD PTR [r14 + 64]{k2}, zmm23 diff --git a/src/operators/convolution-nhwc.c b/src/operators/convolution-nhwc.c index 3fb791c48a3..e35edcae00a 100644 --- a/src/operators/convolution-nhwc.c +++ b/src/operators/convolution-nhwc.c @@ -1824,7 +1824,7 @@ enum xnn_status xnn_create_convolution2d_nhwc_f32( xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, xnn_operator_t* convolution_op_out) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); if (gemm_config == NULL) { xnn_log_error("failed to create %s operator: unsupported hardware configuration", xnn_operator_type_to_string(xnn_operator_type_convolution_nhwc_f32)); diff --git a/src/operators/deconvolution-nhwc.c b/src/operators/deconvolution-nhwc.c index 21c3066ea21..10bbb7f9c1e 100644 --- a/src/operators/deconvolution-nhwc.c +++ b/src/operators/deconvolution-nhwc.c @@ -1107,7 +1107,7 @@ enum xnn_status xnn_create_deconvolution2d_nhwc_f32( return xnn_status_invalid_parameter; } - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); if (gemm_config == NULL) { xnn_log_error("failed to create %s operator: unsupported hardware configuration", xnn_operator_type_to_string(xnn_operator_type_deconvolution_nhwc_f32)); diff --git a/src/xnnpack/config.h b/src/xnnpack/config.h index 68251b7441c..f822f7df5cf 100644 --- a/src/xnnpack/config.h +++ b/src/xnnpack/config.h @@ -244,6 +244,7 @@ XNN_INTERNAL const struct xnn_gemm_config* xnn_init_bf16_f32_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_f16_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_f32_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_f32_gemm_nr2_config(); +XNN_INTERNAL const struct xnn_gemm_config* xnn_init_f32_igemm_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_f32_qc8w_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_f32_qc4w_gemm_config(); XNN_INTERNAL const struct xnn_gemm_config* xnn_init_pf16_gemm_config(); diff --git 
a/src/xnnpack/gemm.h b/src/xnnpack/gemm.h index 21877732c43..75dc40bd686 100644 --- a/src/xnnpack/gemm.h +++ b/src/xnnpack/gemm.h @@ -488,6 +488,23 @@ DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION( DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION( xnn_f32_gemm_minmax_ukernel_5x16__asm_aarch64_neonfma_ld128_2) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_1x16c2__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_2x16c2__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_3x16c2__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_4x16c2__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_5x16c2__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_6x16c2__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_7x16c2__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_8x16c2__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_9x16c2__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_10x16c2__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_11x16c2__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_2x32c2__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_3x32c2__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_4x32c2__asm_amd64_avx512f_broadcast) +DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast) + DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_1x64__asm_amd64_avx512f_broadcast) DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_2x64__asm_amd64_avx512f_broadcast) DECLARE_F32_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_f32_gemm_minmax_ukernel_3x64__asm_amd64_avx512f_broadcast) diff --git a/test/deconvolution-nhwc.cc b/test/deconvolution-nhwc.cc index 1a27358f2c0..c98d2c29e7d 100644 --- a/test/deconvolution-nhwc.cc +++ b/test/deconvolution-nhwc.cc @@ -13005,7 +13005,7 @@ TEST(DECONVOLUTION_NHWC_F16, kernel_2x2s2_setup_changing_width) { /**************************** Future GEMM path ****************************/ TEST(DECONVOLUTION_NHWC_F32, kernel_1x1) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -13016,7 +13016,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_1x1) { } TEST(DECONVOLUTION_NHWC_F32, kernel_1x1_varying_input_width) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_height = kUnstridedInputHeight - 2; input_height <= kUnstridedInputHeight + 2; input_height++) { DeconvolutionOperatorTester() @@
-13029,7 +13029,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_1x1_varying_input_width) { } TEST(DECONVOLUTION_NHWC_F32, kernel_1x1_varying_input_height) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_width = kUnstridedInputWidth - 2; input_width <= kUnstridedInputWidth + 2; input_width++) { DeconvolutionOperatorTester() @@ -13042,7 +13042,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_1x1_varying_input_height) { } TEST(DECONVOLUTION_NHWC_F32, kernel_1x1_varying_input_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_channels = 1; input_channels <= 16; input_channels *= 4) { DeconvolutionOperatorTester() @@ -13055,7 +13055,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_1x1_varying_input_channels) { } TEST(DECONVOLUTION_NHWC_F32, kernel_1x1_varying_output_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t output_channels = 1; output_channels <= gemm_config->nr * 2; output_channels *= 2) { DeconvolutionOperatorTester() @@ -13068,7 +13068,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_1x1_varying_output_channels) { } TEST(DECONVOLUTION_NHWC_F32, kernel_1x1_with_input_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -13080,7 +13080,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_1x1_with_input_stride) { } TEST(DECONVOLUTION_NHWC_F32, kernel_1x1_with_output_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -13092,7 +13092,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_1x1_with_output_stride) { } TEST(DECONVOLUTION_NHWC_F32, kernel_1x1_with_qmin) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -13104,7 +13104,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_1x1_with_qmin) { } TEST(DECONVOLUTION_NHWC_F32, kernel_1x1_with_qmax) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -13116,7 +13116,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_1x1_with_qmax) { } TEST(DECONVOLUTION_NHWC_F32, kernel_1x1_without_bias) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .has_bias(false) @@ -13130,7 +13130,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_1x1_without_bias) { /**************************** Future GEMM path, grouped 
****************************/ TEST(DECONVOLUTION_NHWC_F32, grouped_1x1) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -13142,7 +13142,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_1x1) { } TEST(DECONVOLUTION_NHWC_F32, grouped_1x1_varying_input_width) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_height = kUnstridedInputHeight - 2; input_height <= kUnstridedInputHeight + 2; input_height++) { DeconvolutionOperatorTester() @@ -13156,7 +13156,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_1x1_varying_input_width) { } TEST(DECONVOLUTION_NHWC_F32, grouped_1x1_varying_input_height) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_width = kUnstridedInputWidth - 2; input_width <= kUnstridedInputWidth + 2; input_width++) { DeconvolutionOperatorTester() @@ -13170,7 +13170,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_1x1_varying_input_height) { } TEST(DECONVOLUTION_NHWC_F32, grouped_1x1_varying_input_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_channels = 1; input_channels <= 16; input_channels *= 4) { DeconvolutionOperatorTester() @@ -13184,7 +13184,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_1x1_varying_input_channels) { } TEST(DECONVOLUTION_NHWC_F32, grouped_1x1_varying_output_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t output_channels = 1; output_channels <= gemm_config->nr * 2; output_channels *= 2) { DeconvolutionOperatorTester() @@ -13198,7 +13198,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_1x1_varying_output_channels) { } TEST(DECONVOLUTION_NHWC_F32, grouped_1x1_with_input_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -13211,7 +13211,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_1x1_with_input_stride) { } TEST(DECONVOLUTION_NHWC_F32, grouped_1x1_with_output_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -13224,7 +13224,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_1x1_with_output_stride) { } TEST(DECONVOLUTION_NHWC_F32, grouped_1x1_with_qmin) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -13237,7 +13237,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_1x1_with_qmin) { } 
TEST(DECONVOLUTION_NHWC_F32, grouped_1x1_with_qmax) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -13250,7 +13250,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_1x1_with_qmax) { } TEST(DECONVOLUTION_NHWC_F32, grouped_1x1_without_bias) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .has_bias(false) @@ -13265,7 +13265,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_1x1_without_bias) { /**************************** Future GEMM path, batched ****************************/ TEST(DECONVOLUTION_NHWC_F32, batched_1x1) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -13277,7 +13277,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_1x1) { } TEST(DECONVOLUTION_NHWC_F32, batched_1x1_varying_input_width) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_height = kUnstridedInputHeight - 2; input_height <= kUnstridedInputHeight + 2; input_height++) { DeconvolutionOperatorTester() @@ -13291,7 +13291,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_1x1_varying_input_width) { } TEST(DECONVOLUTION_NHWC_F32, batched_1x1_varying_input_height) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_width = kUnstridedInputWidth - 2; input_width <= kUnstridedInputWidth + 2; input_width++) { DeconvolutionOperatorTester() @@ -13305,7 +13305,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_1x1_varying_input_height) { } TEST(DECONVOLUTION_NHWC_F32, batched_1x1_varying_input_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_channels = 1; input_channels <= 16; input_channels *= 4) { DeconvolutionOperatorTester() @@ -13319,7 +13319,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_1x1_varying_input_channels) { } TEST(DECONVOLUTION_NHWC_F32, batched_1x1_varying_output_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t output_channels = 1; output_channels <= gemm_config->nr * 2; output_channels *= 2) { DeconvolutionOperatorTester() @@ -13333,7 +13333,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_1x1_varying_output_channels) { } TEST(DECONVOLUTION_NHWC_F32, batched_1x1_with_input_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -13346,7 +13346,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_1x1_with_input_stride) { } TEST(DECONVOLUTION_NHWC_F32, batched_1x1_with_output_stride) { - const struct xnn_gemm_config* 
gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -13359,7 +13359,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_1x1_with_output_stride) { } TEST(DECONVOLUTION_NHWC_F32, batched_1x1_with_qmin) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -13372,7 +13372,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_1x1_with_qmin) { } TEST(DECONVOLUTION_NHWC_F32, batched_1x1_with_qmax) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -13385,7 +13385,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_1x1_with_qmax) { } TEST(DECONVOLUTION_NHWC_F32, batched_1x1_without_bias) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .has_bias(false) @@ -13400,7 +13400,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_1x1_without_bias) { /**************************** Future GEMM path, batched, grouped ****************************/ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_1x1) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -13413,7 +13413,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_1x1) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_1x1_varying_input_width) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_height = kUnstridedInputHeight - 2; input_height <= kUnstridedInputHeight + 2; input_height++) { DeconvolutionOperatorTester() @@ -13428,7 +13428,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_1x1_varying_input_width) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_1x1_varying_input_height) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_width = kUnstridedInputWidth - 2; input_width <= kUnstridedInputWidth + 2; input_width++) { DeconvolutionOperatorTester() @@ -13443,7 +13443,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_1x1_varying_input_height) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_1x1_varying_input_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_channels = 1; input_channels <= 16; input_channels *= 4) { DeconvolutionOperatorTester() @@ -13458,7 +13458,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_1x1_varying_input_channels) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_1x1_varying_output_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t 
output_channels = 1; output_channels <= gemm_config->nr * 2; output_channels *= 2) { DeconvolutionOperatorTester() @@ -13473,7 +13473,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_1x1_varying_output_channels) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_1x1_with_input_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -13487,7 +13487,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_1x1_with_input_stride) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_1x1_with_output_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -13501,7 +13501,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_1x1_with_output_stride) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_1x1_with_qmin) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -13515,7 +13515,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_1x1_with_qmin) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_1x1_with_qmax) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -13529,7 +13529,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_1x1_with_qmax) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_1x1_without_bias) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .has_bias(false) @@ -13545,7 +13545,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_1x1_without_bias) { /**************************** CONV path ****************************/ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -13557,7 +13557,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3) { } TEST(DECONVOLUTION_NHWC_F32, Kx3) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_height = 1; kernel_height <= 4; kernel_height *= 2) { DeconvolutionOperatorTester() @@ -13571,7 +13571,7 @@ TEST(DECONVOLUTION_NHWC_F32, Kx3) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3xK) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_width = 1; kernel_width <= 4; kernel_width *= 2) { DeconvolutionOperatorTester() @@ -13585,7 +13585,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3xK) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_varying_height_padding) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config 
= xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t padding_top = 0; padding_top <= 2; padding_top++) { for (size_t padding_bottom = 0; padding_bottom <= 2; padding_bottom++) { @@ -13603,7 +13603,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_varying_height_padding) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_varying_width_padding) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t padding_left = 0; padding_left <= 2; padding_left++) { for (size_t padding_right = 0; padding_right <= 2; padding_right++) { @@ -13621,7 +13621,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_varying_width_padding) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_varying_height_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t adjustment_height = 1; adjustment_height <= 2; adjustment_height++) { DeconvolutionOperatorTester() @@ -13637,7 +13637,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_varying_height_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_varying_width_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t adjustment_width = 1; adjustment_width <= 2; adjustment_width++) { DeconvolutionOperatorTester() @@ -13653,7 +13653,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_varying_width_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_varying_input_height) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_height = kUnstridedInputHeight - 2; input_height <= kUnstridedInputHeight + 2; input_height++) { DeconvolutionOperatorTester() @@ -13667,7 +13667,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_varying_input_height) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_varying_input_width) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_width = kUnstridedInputWidth - 2; input_width <= kUnstridedInputWidth + 2; input_width++) { DeconvolutionOperatorTester() @@ -13681,7 +13681,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_varying_input_width) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_varying_input_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_channels = 1; input_channels <= 16; input_channels *= 4) { DeconvolutionOperatorTester() @@ -13695,7 +13695,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_varying_input_channels) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_varying_output_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t output_channels = 1; output_channels <= gemm_config->nr * 2; output_channels *= 2) { DeconvolutionOperatorTester() @@ -13709,7 +13709,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_varying_output_channels) { } 
TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_with_height_dilation) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t dilation_height = 2; dilation_height <= 3; dilation_height++) { DeconvolutionOperatorTester() @@ -13724,7 +13724,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_with_height_dilation) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_with_width_dilation) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t dilation_width = 2; dilation_width <= 3; dilation_width++) { DeconvolutionOperatorTester() @@ -13739,7 +13739,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_with_width_dilation) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_with_height_dilation_and_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -13753,7 +13753,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_with_height_dilation_and_stride) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_with_width_dilation_and_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -13767,7 +13767,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_with_width_dilation_and_stride) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_with_input_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -13780,7 +13780,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_with_input_stride) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_with_output_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -13793,7 +13793,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_with_output_stride) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_with_qmin) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -13806,7 +13806,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_with_qmin) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_with_qmax) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -13819,7 +13819,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_with_qmax) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_without_bias) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct 
xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .has_bias(false) @@ -13832,7 +13832,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_without_bias) { } TEST(DECONVOLUTION_NHWC_F32, weights_cache_3x3) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -13847,7 +13847,7 @@ TEST(DECONVOLUTION_NHWC_F32, weights_cache_3x3) { /**************************** CONV path, grouped ****************************/ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -13860,7 +13860,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3) { } TEST(DECONVOLUTION_NHWC_F32, grouped_Kx3) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_height = 1; kernel_height <= 4; kernel_height *= 2) { DeconvolutionOperatorTester() @@ -13875,7 +13875,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_Kx3) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3xK) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_width = 1; kernel_width <= 4; kernel_width *= 2) { DeconvolutionOperatorTester() @@ -13890,7 +13890,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3xK) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_varying_height_padding) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t padding_top = 0; padding_top <= 2; padding_top++) { for (size_t padding_bottom = 0; padding_bottom <= 2; padding_bottom++) { @@ -13909,7 +13909,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_varying_height_padding) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_varying_width_padding) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t padding_left = 0; padding_left <= 2; padding_left++) { for (size_t padding_right = 0; padding_right <= 2; padding_right++) { @@ -13928,7 +13928,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_varying_width_padding) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_varying_height_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t adjustment_height = 1; adjustment_height <= 2; adjustment_height++) { DeconvolutionOperatorTester() @@ -13945,7 +13945,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_varying_height_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_varying_width_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t 
adjustment_width = 1; adjustment_width <= 2; adjustment_width++) { DeconvolutionOperatorTester() @@ -13962,7 +13962,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_varying_width_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_varying_input_height) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_height = kUnstridedInputHeight - 2; input_height <= kUnstridedInputHeight + 2; input_height++) { DeconvolutionOperatorTester() @@ -13977,7 +13977,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_varying_input_height) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_varying_input_width) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_width = kUnstridedInputWidth - 2; input_width <= kUnstridedInputWidth + 2; input_width++) { DeconvolutionOperatorTester() @@ -13992,7 +13992,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_varying_input_width) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_varying_input_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_channels = 1; input_channels <= 16; input_channels *= 4) { DeconvolutionOperatorTester() @@ -14007,7 +14007,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_varying_input_channels) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_varying_output_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t output_channels = 1; output_channels <= gemm_config->nr * 2; output_channels *= 2) { DeconvolutionOperatorTester() @@ -14022,7 +14022,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_varying_output_channels) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_with_height_dilation) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t dilation_height = 2; dilation_height <= 3; dilation_height++) { DeconvolutionOperatorTester() @@ -14038,7 +14038,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_with_height_dilation) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_with_width_dilation) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t dilation_width = 2; dilation_width <= 3; dilation_width++) { DeconvolutionOperatorTester() @@ -14054,7 +14054,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_with_width_dilation) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_with_height_dilation_and_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -14069,7 +14069,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_with_height_dilation_and_stride) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_with_width_dilation_and_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + 
const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -14084,7 +14084,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_with_width_dilation_and_stride) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_with_input_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -14098,7 +14098,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_with_input_stride) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_with_output_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -14112,7 +14112,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_with_output_stride) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_with_qmin) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -14126,7 +14126,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_with_qmin) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_with_qmax) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -14140,7 +14140,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_with_qmax) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_without_bias) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .has_bias(false) @@ -14154,7 +14154,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3_without_bias) { } TEST(DECONVOLUTION_NHWC_F32, weights_cache_grouped_3x3) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kUnstridedInputHeight, kUnstridedInputWidth) @@ -14170,7 +14170,7 @@ TEST(DECONVOLUTION_NHWC_F32, weights_cache_grouped_3x3) { /**************************** CONV path, batched ****************************/ TEST(DECONVOLUTION_NHWC_F32, batched_3x3) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -14183,7 +14183,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3) { } TEST(DECONVOLUTION_NHWC_F32, batched_Kx3) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_height = 1; kernel_height <= 4; kernel_height *= 2) { DeconvolutionOperatorTester() @@ -14198,7 +14198,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_Kx3) { } 
TEST(DECONVOLUTION_NHWC_F32, batched_3xK) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_width = 1; kernel_width <= 4; kernel_width *= 2) { DeconvolutionOperatorTester() @@ -14213,7 +14213,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3xK) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3_varying_height_padding) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t padding_top = 0; padding_top <= 2; padding_top++) { for (size_t padding_bottom = 0; padding_bottom <= 2; padding_bottom++) { @@ -14232,7 +14232,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3_varying_height_padding) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3_varying_width_padding) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t padding_left = 0; padding_left <= 2; padding_left++) { for (size_t padding_right = 0; padding_right <= 2; padding_right++) { @@ -14251,7 +14251,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3_varying_width_padding) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3_varying_height_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t adjustment_height = 1; adjustment_height <= 2; adjustment_height++) { DeconvolutionOperatorTester() @@ -14268,7 +14268,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3_varying_height_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3_varying_width_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t adjustment_width = 1; adjustment_width <= 2; adjustment_width++) { DeconvolutionOperatorTester() @@ -14285,7 +14285,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3_varying_width_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3_varying_input_height) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_height = kUnstridedInputHeight - 2; input_height <= kUnstridedInputHeight + 2; input_height++) { DeconvolutionOperatorTester() @@ -14300,7 +14300,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3_varying_input_height) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3_varying_input_width) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_width = kUnstridedInputWidth - 2; input_width <= kUnstridedInputWidth + 2; input_width++) { DeconvolutionOperatorTester() @@ -14315,7 +14315,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3_varying_input_width) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3_varying_input_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_channels = 1; input_channels <= 16; input_channels *= 4) { 
DeconvolutionOperatorTester() @@ -14330,7 +14330,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3_varying_input_channels) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3_varying_output_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t output_channels = 1; output_channels <= gemm_config->nr * 2; output_channels *= 2) { DeconvolutionOperatorTester() @@ -14345,7 +14345,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3_varying_output_channels) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3_with_height_dilation) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t dilation_height = 2; dilation_height <= 3; dilation_height++) { DeconvolutionOperatorTester() @@ -14361,7 +14361,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3_with_height_dilation) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3_with_width_dilation) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t dilation_width = 2; dilation_width <= 3; dilation_width++) { DeconvolutionOperatorTester() @@ -14377,7 +14377,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3_with_width_dilation) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3_with_height_dilation_and_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -14392,7 +14392,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3_with_height_dilation_and_stride) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3_with_width_dilation_and_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -14407,7 +14407,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3_with_width_dilation_and_stride) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3_with_input_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -14421,7 +14421,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3_with_input_stride) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3_with_output_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -14435,7 +14435,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3_with_output_stride) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3_with_qmin) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -14449,7 +14449,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3_with_qmin) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3_with_qmax) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct 
xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -14463,7 +14463,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3_with_qmax) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3_without_bias) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .has_bias(false) @@ -14477,7 +14477,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3_without_bias) { } TEST(DECONVOLUTION_NHWC_F32, weights_cache_batched_3x3) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -14494,7 +14494,7 @@ TEST(DECONVOLUTION_NHWC_F32, weights_cache_batched_3x3) { /**************************** CONV path, grouped, batched ****************************/ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -14508,7 +14508,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_Kx3) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_height = 1; kernel_height <= 4; kernel_height *= 2) { DeconvolutionOperatorTester() @@ -14524,7 +14524,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_Kx3) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3xK) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_width = 1; kernel_width <= 4; kernel_width *= 2) { DeconvolutionOperatorTester() @@ -14540,7 +14540,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3xK) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_varying_height_padding) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t padding_top = 0; padding_top <= 2; padding_top++) { for (size_t padding_bottom = 0; padding_bottom <= 2; padding_bottom++) { @@ -14560,7 +14560,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_varying_height_padding) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_varying_width_padding) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t padding_left = 0; padding_left <= 2; padding_left++) { for (size_t padding_right = 0; padding_right <= 2; padding_right++) { @@ -14580,7 +14580,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_varying_width_padding) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_varying_height_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t adjustment_height = 1; adjustment_height <= 2; adjustment_height++) 
{ DeconvolutionOperatorTester() @@ -14598,7 +14598,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_varying_height_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_varying_width_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t adjustment_width = 1; adjustment_width <= 2; adjustment_width++) { DeconvolutionOperatorTester() @@ -14616,7 +14616,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_varying_width_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_varying_input_height) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_height = kUnstridedInputHeight - 2; input_height <= kUnstridedInputHeight + 2; input_height++) { DeconvolutionOperatorTester() @@ -14632,7 +14632,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_varying_input_height) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_varying_input_width) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_width = kUnstridedInputWidth - 2; input_width <= kUnstridedInputWidth + 2; input_width++) { DeconvolutionOperatorTester() @@ -14648,7 +14648,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_varying_input_width) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_varying_input_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_channels = 1; input_channels <= 16; input_channels *= 4) { DeconvolutionOperatorTester() @@ -14664,7 +14664,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_varying_input_channels) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_varying_output_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t output_channels = 1; output_channels <= gemm_config->nr * 2; output_channels *= 2) { DeconvolutionOperatorTester() @@ -14680,7 +14680,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_varying_output_channels) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_with_height_dilation) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t dilation_height = 2; dilation_height <= 3; dilation_height++) { DeconvolutionOperatorTester() @@ -14697,7 +14697,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_with_height_dilation) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_with_width_dilation) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t dilation_width = 2; dilation_width <= 3; dilation_width++) { DeconvolutionOperatorTester() @@ -14714,7 +14714,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_with_width_dilation) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_with_height_dilation_and_stride) { - 
const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -14730,7 +14730,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_with_height_dilation_and_stride } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_with_width_dilation_and_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -14746,7 +14746,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_with_width_dilation_and_stride) } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_with_input_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -14761,7 +14761,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_with_input_stride) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_with_output_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -14776,7 +14776,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_with_output_stride) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_with_qmin) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -14791,7 +14791,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_with_qmin) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_with_qmax) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -14806,7 +14806,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_with_qmax) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_without_bias) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .has_bias(false) @@ -14821,7 +14821,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3_without_bias) { } TEST(DECONVOLUTION_NHWC_F32, weights_cache_batched_grouped_3x3) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -14882,7 +14882,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3_setup_changing_width) { /**************************** SUBCONV2D/IGEMM path ****************************/ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -14895,7 +14895,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2) { } TEST(DECONVOLUTION_NHWC_F32, Kx3s2) 
{ - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_height = 2; kernel_height <= 5; kernel_height++) { DeconvolutionOperatorTester() @@ -14910,7 +14910,7 @@ TEST(DECONVOLUTION_NHWC_F32, Kx3s2) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3xKs2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_width = 2; kernel_width <= 5; kernel_width++) { DeconvolutionOperatorTester() @@ -14925,7 +14925,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3xKs2) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3sSx1) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t stride_height = 2; stride_height <= 3; stride_height++) { DeconvolutionOperatorTester() @@ -14941,7 +14941,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3sSx1) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s1xS) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t stride_width = 2; stride_width <= 3; stride_width++) { DeconvolutionOperatorTester() @@ -14957,7 +14957,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s1xS) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_varying_height_padding) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t padding_top = 0; padding_top <= 2; padding_top++) { for (size_t padding_bottom = 0; padding_bottom <= 2; padding_bottom++) { @@ -14976,7 +14976,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_varying_height_padding) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_varying_width_padding) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t padding_left = 0; padding_left <= 2; padding_left++) { for (size_t padding_right = 0; padding_right <= 2; padding_right++) { @@ -14995,7 +14995,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_varying_width_padding) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_varying_height_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t adjustment_height = 0; adjustment_height <= 1; adjustment_height++) { DeconvolutionOperatorTester() @@ -15011,7 +15011,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_varying_height_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_varying_width_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t adjustment_width = 0; adjustment_width <= 1; adjustment_width++) { DeconvolutionOperatorTester() @@ -15027,7 +15027,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_varying_width_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_varying_input_height) { - const struct xnn_gemm_config* gemm_config = 
xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_height = kStridedInputHeight - 2; input_height <= kStridedInputHeight + 2; input_height++) { DeconvolutionOperatorTester() @@ -15042,7 +15042,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_varying_input_height) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_varying_input_width) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_width = kStridedInputWidth - 2; input_width <= kStridedInputWidth + 2; input_width++) { DeconvolutionOperatorTester() @@ -15057,7 +15057,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_varying_input_width) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_varying_input_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_channels = 1; input_channels <= 16; input_channels *= 4) { DeconvolutionOperatorTester() @@ -15072,7 +15072,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_varying_input_channels) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_varying_output_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t output_channels = 1; output_channels <= gemm_config->nr * 2; output_channels *= 2) { DeconvolutionOperatorTester() @@ -15087,7 +15087,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_varying_output_channels) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_with_input_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -15101,7 +15101,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_with_input_stride) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_with_output_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -15115,7 +15115,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_with_output_stride) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_with_qmin) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -15129,7 +15129,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_with_qmin) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_with_qmax) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -15143,7 +15143,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_with_qmax) { } TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_without_bias) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct 
xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .has_bias(false) @@ -15157,7 +15157,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_without_bias) { } TEST(DECONVOLUTION_NHWC_F32, weights_cache_3x3s2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -15171,7 +15171,7 @@ TEST(DECONVOLUTION_NHWC_F32, weights_cache_3x3s2) { } TEST(DECONVOLUTION_NHWC_F32, stress_weights_cache_5x5s4) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -15186,7 +15186,7 @@ TEST(DECONVOLUTION_NHWC_F32, stress_weights_cache_5x5s4) { /**************************** SUBCONV2D/IGEMM path, grouped ****************************/ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -15200,7 +15200,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2) { } TEST(DECONVOLUTION_NHWC_F32, grouped_Kx3s2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_height = 2; kernel_height <= 5; kernel_height++) { DeconvolutionOperatorTester() @@ -15216,7 +15216,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_Kx3s2) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3xKs2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_width = 2; kernel_width <= 5; kernel_width++) { DeconvolutionOperatorTester() @@ -15232,7 +15232,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3xKs2) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3sSx1) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t stride_height = 2; stride_height <= 3; stride_height++) { DeconvolutionOperatorTester() @@ -15249,7 +15249,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3sSx1) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s1xS) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t stride_width = 2; stride_width <= 3; stride_width++) { DeconvolutionOperatorTester() @@ -15266,7 +15266,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s1xS) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_varying_height_padding) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t padding_top = 0; padding_top <= 2; padding_top++) { for (size_t padding_bottom = 0; padding_bottom <= 2; padding_bottom++) { @@ -15286,7 +15286,7 @@ TEST(DECONVOLUTION_NHWC_F32, 
grouped_3x3s2_varying_height_padding) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_varying_width_padding) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t padding_left = 0; padding_left <= 2; padding_left++) { for (size_t padding_right = 0; padding_right <= 2; padding_right++) { @@ -15306,7 +15306,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_varying_width_padding) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_varying_height_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t adjustment_height = 0; adjustment_height <= 1; adjustment_height++) { DeconvolutionOperatorTester() @@ -15323,7 +15323,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_varying_height_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_varying_width_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t adjustment_width = 0; adjustment_width <= 1; adjustment_width++) { DeconvolutionOperatorTester() @@ -15340,7 +15340,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_varying_width_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_varying_input_height) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_height = kStridedInputHeight - 2; input_height <= kStridedInputHeight + 2; input_height++) { DeconvolutionOperatorTester() @@ -15356,7 +15356,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_varying_input_height) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_varying_input_width) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_width = kStridedInputWidth - 2; input_width <= kStridedInputWidth + 2; input_width++) { DeconvolutionOperatorTester() @@ -15372,7 +15372,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_varying_input_width) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_varying_input_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_channels = 14; input_channels <= 20; input_channels++) { DeconvolutionOperatorTester() @@ -15388,7 +15388,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_varying_input_channels) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_varying_output_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t output_channels = 1; output_channels <= gemm_config->nr * 2; output_channels *= 2) { DeconvolutionOperatorTester() @@ -15404,7 +15404,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_varying_output_channels) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_with_input_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); 
ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -15419,7 +15419,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_with_input_stride) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_with_output_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -15434,7 +15434,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_with_output_stride) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_with_qmin) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -15449,7 +15449,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_with_qmin) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_with_qmax) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -15464,7 +15464,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_with_qmax) { } TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_without_bias) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .has_bias(false) @@ -15479,7 +15479,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_3x3s2_without_bias) { } TEST(DECONVOLUTION_NHWC_F32, weights_cache_grouped_3x3s2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -15496,7 +15496,7 @@ TEST(DECONVOLUTION_NHWC_F32, weights_cache_grouped_3x3s2) { /**************************** SUBCONV2D/IGEMM path, batched ****************************/ TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -15510,7 +15510,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2) { } TEST(DECONVOLUTION_NHWC_F32, batched_Kx3s2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_height = 2; kernel_height <= 5; kernel_height++) { DeconvolutionOperatorTester() @@ -15526,7 +15526,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_Kx3s2) { } TEST(DECONVOLUTION_NHWC_F32, batched_3xKs2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_width = 2; kernel_width <= 5; kernel_width++) { DeconvolutionOperatorTester() @@ -15542,7 +15542,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3xKs2) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3sSx1) { - const struct xnn_gemm_config* gemm_config = 
xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t stride_height = 2; stride_height <= 3; stride_height++) { DeconvolutionOperatorTester() @@ -15559,7 +15559,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3sSx1) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3s1xS) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t stride_width = 2; stride_width <= 3; stride_width++) { DeconvolutionOperatorTester() @@ -15576,7 +15576,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3s1xS) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_varying_height_padding) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t padding_top = 0; padding_top <= 2; padding_top++) { for (size_t padding_bottom = 0; padding_bottom <= 2; padding_bottom++) { @@ -15596,7 +15596,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_varying_height_padding) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_varying_width_padding) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t padding_left = 0; padding_left <= 2; padding_left++) { for (size_t padding_right = 0; padding_right <= 2; padding_right++) { @@ -15616,7 +15616,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_varying_width_padding) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_varying_height_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t adjustment_height = 0; adjustment_height <= 1; adjustment_height++) { DeconvolutionOperatorTester() @@ -15633,7 +15633,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_varying_height_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_varying_width_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t adjustment_width = 0; adjustment_width <= 1; adjustment_width++) { DeconvolutionOperatorTester() @@ -15650,7 +15650,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_varying_width_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_varying_input_height) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_height = kStridedInputHeight - 2; input_height <= kStridedInputHeight + 2; input_height++) { DeconvolutionOperatorTester() @@ -15666,7 +15666,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_varying_input_height) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_varying_input_width) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_width = kStridedInputWidth - 2; input_width <= kStridedInputWidth + 2; input_width++) { DeconvolutionOperatorTester() @@ -15682,7 +15682,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_varying_input_width) { 
} TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_varying_input_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_channels = 1; input_channels <= 16; input_channels *= 4) { DeconvolutionOperatorTester() @@ -15698,7 +15698,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_varying_input_channels) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_varying_output_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t output_channels = 1; output_channels <= gemm_config->nr * 2; output_channels *= 2) { DeconvolutionOperatorTester() @@ -15714,7 +15714,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_varying_output_channels) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_with_input_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -15729,7 +15729,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_with_input_stride) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_with_output_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -15744,7 +15744,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_with_output_stride) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_with_qmin) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -15759,7 +15759,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_with_qmin) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_with_qmax) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -15774,7 +15774,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_with_qmax) { } TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_without_bias) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .has_bias(false) @@ -15789,7 +15789,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_3x3s2_without_bias) { } TEST(DECONVOLUTION_NHWC_F32, weights_cache_batched_3x3s2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -15806,7 +15806,7 @@ TEST(DECONVOLUTION_NHWC_F32, weights_cache_batched_3x3s2) { /**************************** SUBCONV2D/IGEMM path, grouped, batched ****************************/ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ 
-15821,7 +15821,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_Kx3s2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_height = 2; kernel_height <= 5; kernel_height++) { DeconvolutionOperatorTester() @@ -15838,7 +15838,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_Kx3s2) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3xKs2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_width = 2; kernel_width <= 5; kernel_width++) { DeconvolutionOperatorTester() @@ -15855,7 +15855,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3xKs2) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3sSx1) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t stride_height = 2; stride_height <= 3; stride_height++) { DeconvolutionOperatorTester() @@ -15873,7 +15873,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3sSx1) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s1xS) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t stride_width = 2; stride_width <= 3; stride_width++) { DeconvolutionOperatorTester() @@ -15891,7 +15891,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s1xS) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_varying_height_padding) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t padding_top = 0; padding_top <= 2; padding_top++) { for (size_t padding_bottom = 0; padding_bottom <= 2; padding_bottom++) { @@ -15912,7 +15912,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_varying_height_padding) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_varying_width_padding) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t padding_left = 0; padding_left <= 2; padding_left++) { for (size_t padding_right = 0; padding_right <= 2; padding_right++) { @@ -15933,7 +15933,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_varying_width_padding) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_varying_height_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t adjustment_height = 0; adjustment_height <= 1; adjustment_height++) { DeconvolutionOperatorTester() @@ -15951,7 +15951,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_varying_height_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_varying_width_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t adjustment_width = 0; adjustment_width <= 1; 
adjustment_width++) { DeconvolutionOperatorTester() @@ -15969,7 +15969,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_varying_width_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_varying_input_height) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_height = kStridedInputHeight - 2; input_height <= kStridedInputHeight + 2; input_height++) { DeconvolutionOperatorTester() @@ -15986,7 +15986,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_varying_input_height) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_varying_input_width) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_width = kStridedInputWidth - 2; input_width <= kStridedInputWidth + 2; input_width++) { DeconvolutionOperatorTester() @@ -16003,7 +16003,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_varying_input_width) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_varying_input_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_channels = 14; input_channels <= 20; input_channels++) { DeconvolutionOperatorTester() @@ -16020,7 +16020,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_varying_input_channels) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_varying_output_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t output_channels = 1; output_channels <= gemm_config->nr * 2; output_channels *= 2) { DeconvolutionOperatorTester() @@ -16037,7 +16037,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_varying_output_channels) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_with_input_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -16053,7 +16053,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_with_input_stride) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_with_output_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -16069,7 +16069,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_with_output_stride) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_with_qmin) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -16085,7 +16085,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_with_qmin) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_with_qmax) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); 
DeconvolutionOperatorTester() .batch_size(2) @@ -16101,7 +16101,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_with_qmax) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_without_bias) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .has_bias(false) @@ -16117,7 +16117,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_3x3s2_without_bias) { } TEST(DECONVOLUTION_NHWC_F32, weights_cache_batched_grouped_3x3s2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -16179,7 +16179,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_3x3s2_setup_changing_width) { /**************************** SUBCONV2D/GEMM path ****************************/ TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -16191,7 +16191,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2) { } TEST(DECONVOLUTION_NHWC_F32, Kx2sKx2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_height = 3; kernel_height <= 5; kernel_height++) { DeconvolutionOperatorTester() @@ -16205,7 +16205,7 @@ TEST(DECONVOLUTION_NHWC_F32, Kx2sKx2) { } TEST(DECONVOLUTION_NHWC_F32, kernel_2xKs2xK) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_width = 3; kernel_width <= 5; kernel_width++) { DeconvolutionOperatorTester() @@ -16219,7 +16219,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_2xKs2xK) { } TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_height_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -16232,7 +16232,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_height_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_width_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -16245,7 +16245,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_width_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_varying_input_height) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_height = kStridedInputHeight - 2; input_height <= kStridedInputHeight + 2; input_height++) { DeconvolutionOperatorTester() @@ -16259,7 +16259,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_varying_input_height) { } TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_varying_input_width) { - const struct 
xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_width = kStridedInputWidth - 2; input_width <= kStridedInputWidth + 2; input_width++) { DeconvolutionOperatorTester() @@ -16273,7 +16273,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_varying_input_width) { } TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_varying_input_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_channels = 1; input_channels <= 16; input_channels *= 4) { DeconvolutionOperatorTester() @@ -16287,7 +16287,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_varying_input_channels) { } TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_varying_output_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t output_channels = 1; output_channels <= gemm_config->nr * 2; output_channels *= 2) { DeconvolutionOperatorTester() @@ -16301,7 +16301,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_varying_output_channels) { } TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_with_input_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -16314,7 +16314,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_with_input_stride) { } TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_with_output_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -16327,7 +16327,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_with_output_stride) { } TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_with_qmin) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -16340,7 +16340,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_with_qmin) { } TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_with_qmax) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -16353,7 +16353,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_with_qmax) { } TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_without_bias) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .has_bias(false) @@ -16366,7 +16366,7 @@ TEST(DECONVOLUTION_NHWC_F32, kernel_2x2s2_without_bias) { } TEST(DECONVOLUTION_NHWC_F32, weights_cache_2x2s2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, 
nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -16381,7 +16381,7 @@ TEST(DECONVOLUTION_NHWC_F32, weights_cache_2x2s2) { /**************************** SUBCONV2D/GEMM path, grouped ****************************/ TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -16394,7 +16394,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2) { } TEST(DECONVOLUTION_NHWC_F32, grouped_Kx2sKx2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_height = 3; kernel_height <= 5; kernel_height++) { DeconvolutionOperatorTester() @@ -16409,7 +16409,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_Kx2sKx2) { } TEST(DECONVOLUTION_NHWC_F32, grouped_2xKs2xK) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_width = 3; kernel_width <= 5; kernel_width++) { DeconvolutionOperatorTester() @@ -16424,7 +16424,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_2xKs2xK) { } TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_height_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -16438,7 +16438,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_height_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_width_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -16452,7 +16452,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_width_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_varying_input_height) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_height = kStridedInputHeight - 2; input_height <= kStridedInputHeight + 2; input_height++) { DeconvolutionOperatorTester() @@ -16467,7 +16467,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_varying_input_height) { } TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_varying_input_width) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_width = kStridedInputWidth - 2; input_width <= kStridedInputWidth + 2; input_width++) { DeconvolutionOperatorTester() @@ -16482,7 +16482,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_varying_input_width) { } TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_varying_input_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_channels = 14; input_channels <= 20; 
input_channels++) { DeconvolutionOperatorTester() @@ -16497,7 +16497,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_varying_input_channels) { } TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_varying_output_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t output_channels = 1; output_channels <= gemm_config->nr * 2; output_channels *= 2) { DeconvolutionOperatorTester() @@ -16512,7 +16512,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_varying_output_channels) { } TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_with_input_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -16526,7 +16526,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_with_input_stride) { } TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_with_output_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -16540,7 +16540,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_with_output_stride) { } TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_with_qmin) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -16554,7 +16554,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_with_qmin) { } TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_with_qmax) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -16568,7 +16568,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_with_qmax) { } TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_without_bias) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .has_bias(false) @@ -16582,7 +16582,7 @@ TEST(DECONVOLUTION_NHWC_F32, grouped_2x2s2_without_bias) { } TEST(DECONVOLUTION_NHWC_F32, weights_cache_grouped_2x2s2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .input_size(kStridedInputHeight, kStridedInputWidth) @@ -16598,7 +16598,7 @@ TEST(DECONVOLUTION_NHWC_F32, weights_cache_grouped_2x2s2) { /**************************** SUBCONV2D/GEMM path, batched ****************************/ TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -16611,7 +16611,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2) { } TEST(DECONVOLUTION_NHWC_F32, batched_Kx2sKx2) { - const struct 
xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_height = 3; kernel_height <= 5; kernel_height++) { DeconvolutionOperatorTester() @@ -16626,7 +16626,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_Kx2sKx2) { } TEST(DECONVOLUTION_NHWC_F32, batched_2xKs2xK) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_width = 3; kernel_width <= 5; kernel_width++) { DeconvolutionOperatorTester() @@ -16641,7 +16641,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_2xKs2xK) { } TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_height_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -16655,7 +16655,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_height_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_width_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -16669,7 +16669,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_width_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_varying_input_height) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_height = kStridedInputHeight - 2; input_height <= kStridedInputHeight + 2; input_height++) { DeconvolutionOperatorTester() @@ -16684,7 +16684,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_varying_input_height) { } TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_varying_input_width) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_width = kStridedInputWidth - 2; input_width <= kStridedInputWidth + 2; input_width++) { DeconvolutionOperatorTester() @@ -16699,7 +16699,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_varying_input_width) { } TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_varying_input_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_channels = 1; input_channels <= 16; input_channels *= 4) { DeconvolutionOperatorTester() @@ -16714,7 +16714,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_varying_input_channels) { } TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_varying_output_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t output_channels = 1; output_channels <= gemm_config->nr * 2; output_channels *= 2) { DeconvolutionOperatorTester() @@ -16729,7 +16729,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_varying_output_channels) { } TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_with_input_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct 
xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -16743,7 +16743,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_with_input_stride) { } TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_with_output_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -16757,7 +16757,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_with_output_stride) { } TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_with_qmin) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -16771,7 +16771,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_with_qmin) { } TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_with_qmax) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -16785,7 +16785,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_with_qmax) { } TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_without_bias) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .has_bias(false) @@ -16799,7 +16799,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_2x2s2_without_bias) { } TEST(DECONVOLUTION_NHWC_F32, weights_cache_batched_2x2s2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -16815,7 +16815,7 @@ TEST(DECONVOLUTION_NHWC_F32, weights_cache_batched_2x2s2) { /**************************** SUBCONV2D/GEMM path, grouped, batched ****************************/ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -16829,7 +16829,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_Kx2sKx2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_height = 3; kernel_height <= 5; kernel_height++) { DeconvolutionOperatorTester() @@ -16845,7 +16845,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_Kx2sKx2) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2xKs2xK) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t kernel_width = 3; kernel_width <= 5; kernel_width++) { DeconvolutionOperatorTester() @@ -16861,7 +16861,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2xKs2xK) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_height_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct 
xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -16876,7 +16876,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_height_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_width_adjustment) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -16891,7 +16891,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_width_adjustment) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_varying_input_height) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_height = kStridedInputHeight - 2; input_height <= kStridedInputHeight + 2; input_height++) { DeconvolutionOperatorTester() @@ -16907,7 +16907,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_varying_input_height) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_varying_input_width) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_width = kStridedInputWidth - 2; input_width <= kStridedInputWidth + 2; input_width++) { DeconvolutionOperatorTester() @@ -16923,7 +16923,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_varying_input_width) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_varying_input_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t input_channels = 14; input_channels <= 20; input_channels++) { DeconvolutionOperatorTester() @@ -16939,7 +16939,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_varying_input_channels) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_varying_output_channels) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); for (size_t output_channels = 1; output_channels <= gemm_config->nr * 2; output_channels *= 2) { DeconvolutionOperatorTester() @@ -16955,7 +16955,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_varying_output_channels) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_with_input_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -16970,7 +16970,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_with_input_stride) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_with_output_stride) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -16985,7 +16985,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_with_output_stride) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_with_qmin) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* 
gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -17000,7 +17000,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_with_qmin) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_with_qmax) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) @@ -17015,7 +17015,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_with_qmax) { } TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_without_bias) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .has_bias(false) @@ -17030,7 +17030,7 @@ TEST(DECONVOLUTION_NHWC_F32, batched_grouped_2x2s2_without_bias) { } TEST(DECONVOLUTION_NHWC_F32, weights_cache_batched_grouped_2x2s2) { - const struct xnn_gemm_config* gemm_config = xnn_init_f32_gemm_config(); + const struct xnn_gemm_config* gemm_config = xnn_init_f32_igemm_config(); ASSERT_NE(gemm_config, nullptr); DeconvolutionOperatorTester() .batch_size(2) diff --git a/test/f32-gemm-minmax-2.cc b/test/f32-gemm-minmax-2.cc index 72f94aba39e..7d5b17a60d3 100644 --- a/test/f32-gemm-minmax-2.cc +++ b/test/f32-gemm-minmax-2.cc @@ -2724,6 +2724,174 @@ std::vector CreateTests2( [](const testing::TestParamInfo& info) { return info.param.test_name; }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_1X16C2__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/2, + /*adj_k_block=*/2, + /*mr=*/1, /*nr=*/16, /*kr=*/2, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/1, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_1x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_4X16C2__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/2, + /*adj_k_block=*/2, + /*mr=*/4, /*nr=*/16, /*kr=*/2, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/1, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_4x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_5X16C2__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/2, + /*adj_k_block=*/2, + /*mr=*/5, /*nr=*/16, /*kr=*/2, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/1, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_5x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_10X16C2__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/2, + /*adj_k_block=*/2, + /*mr=*/10, /*nr=*/16, /*kr=*/2, /*sr=*/1, + 
/*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/1, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_10x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_11X16C2__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/2, + /*adj_k_block=*/2, + /*mr=*/11, /*nr=*/16, /*kr=*/2, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/1, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_11x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_1X32C2__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/2, + /*adj_k_block=*/2, + /*mr=*/1, /*nr=*/32, /*kr=*/2, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/1, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_4X32C2__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/2, + /*adj_k_block=*/2, + /*mr=*/4, /*nr=*/32, /*kr=*/2, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/1, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_4x32c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_5X32C2__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/2, + /*adj_k_block=*/2, + /*mr=*/5, /*nr=*/32, /*kr=*/2, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/1, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); #endif // XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY diff --git a/test/f32-gemm-minmax.cc b/test/f32-gemm-minmax.cc index 3ae8e4c0cff..7f495204336 100644 --- a/test/f32-gemm-minmax.cc +++ b/test/f32-gemm-minmax.cc @@ -3108,6 +3108,174 @@ std::vector CreateTests2( [](const testing::TestParamInfo& info) { return info.param.test_name; }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_2X16C2__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/2, + /*adj_k_block=*/2, + /*mr=*/2, /*nr=*/16, /*kr=*/2, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/1, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_2x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + 
}, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_3X16C2__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/2, + /*adj_k_block=*/2, + /*mr=*/3, /*nr=*/16, /*kr=*/2, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/1, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_3x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_6X16C2__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/2, + /*adj_k_block=*/2, + /*mr=*/6, /*nr=*/16, /*kr=*/2, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/1, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_6x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_7X16C2__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/2, + /*adj_k_block=*/2, + /*mr=*/7, /*nr=*/16, /*kr=*/2, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/1, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_7x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_8X16C2__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/2, + /*adj_k_block=*/2, + /*mr=*/8, /*nr=*/16, /*kr=*/2, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/1, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_8x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_9X16C2__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/2, + /*adj_k_block=*/2, + /*mr=*/9, /*nr=*/16, /*kr=*/2, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/1, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_9x16c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_2X32C2__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/2, + /*adj_k_block=*/2, + /*mr=*/2, /*nr=*/32, /*kr=*/2, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/1, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_2x32c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + 
TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); + + INSTANTIATE_TEST_SUITE_P( + F32_GEMM_MINMAX_3X32C2__ASM_AMD64_AVX512F_BROADCAST, GemmTest, + testing::ValuesIn(CreateTests1( + /*k_block=*/2, + /*adj_k_block=*/2, + /*mr=*/3, /*nr=*/32, /*kr=*/2, /*sr=*/1, + /*is_igemm=*/false, + /*unsigned_inputs=*/false, + /*planes=*/1, + [](GemmMicrokernelTester& tester) { + tester.Test(xnn_f32_gemm_minmax_ukernel_3x32c2__asm_amd64_avx512f_broadcast, + xnn_init_f32_minmax_scalar_params, + xnn_pack_f32_gemm_goi_w); + }, + []() { + TEST_REQUIRES_X86_AVX512F; + })), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); #endif // XNN_ENABLE_AVX512F && XNN_ARCH_X86_64 && XNN_ENABLE_ASSEMBLY diff --git a/test/f32-gemm-minmax.yaml b/test/f32-gemm-minmax.yaml index 9b41a7f4e50..72680f9e275 100644 --- a/test/f32-gemm-minmax.yaml +++ b/test/f32-gemm-minmax.yaml @@ -836,6 +836,72 @@ init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_f32_gemm_goi_w k-block: 1 + +- name: xnn_f32_gemm_minmax_ukernel_1x16c2__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 2 +- name: xnn_f32_gemm_minmax_ukernel_2x16c2__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 2 +- name: xnn_f32_gemm_minmax_ukernel_3x16c2__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 2 +- name: xnn_f32_gemm_minmax_ukernel_4x16c2__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 2 +- name: xnn_f32_gemm_minmax_ukernel_5x16c2__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 2 +- name: xnn_f32_gemm_minmax_ukernel_6x16c2__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 2 +- name: xnn_f32_gemm_minmax_ukernel_7x16c2__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 2 +- name: xnn_f32_gemm_minmax_ukernel_8x16c2__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 2 +- name: xnn_f32_gemm_minmax_ukernel_9x16c2__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 2 +- name: xnn_f32_gemm_minmax_ukernel_10x16c2__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 2 +- name: xnn_f32_gemm_minmax_ukernel_11x16c2__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 2 +- name: xnn_f32_gemm_minmax_ukernel_1x32c2__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 2 +- name: xnn_f32_gemm_minmax_ukernel_2x32c2__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 2 +- name: xnn_f32_gemm_minmax_ukernel_3x32c2__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 2 +- name: xnn_f32_gemm_minmax_ukernel_4x32c2__asm_amd64_avx512f_broadcast + init: xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 2 +- name: xnn_f32_gemm_minmax_ukernel_5x32c2__asm_amd64_avx512f_broadcast + init: 
xnn_init_f32_minmax_scalar_params + pack: xnn_pack_f32_gemm_goi_w + k-block: 2 + - name: xnn_f32_gemm_minmax_ukernel_4x8__fma3_broadcast init: xnn_init_f32_minmax_scalar_params pack: xnn_pack_f32_gemm_goi_w diff --git a/test/fully-connected.cc b/test/fully-connected.cc index 37b36f59c67..657fba50bb8 100644 --- a/test/fully-connected.cc +++ b/test/fully-connected.cc @@ -4183,3 +4183,92 @@ TEST_F(FullyConnectedTestF32, reshape) { ASSERT_EQ(runtime->values[fc_node->outputs[0]].size, num_output_elements * sizeof(float)); } + + +TEST_F(FullyConnectedTestF32, matches_operator_api) { + ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr)); + + xnn_operator_t op = nullptr; + + std::generate(input.begin(), input.end(), [&]() { return f32dist(rng); }); + std::generate(kernel.begin(), kernel.end(), [&]() { return f32dist(rng); }); + std::generate(bias.begin(), bias.end(), [&]() { return f32dist(rng); }); + + // Call operator API. + const xnn_status status = xnn_create_fully_connected_nc_f32( + input_channels, output_channels, input_channels, output_channels, + kernel.data(), bias.data(), output_min, output_max, + /*flags=*/0, nullptr, nullptr, &op); + std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_op( + op, xnn_delete_operator); + + if (status == xnn_status_unsupported_hardware) { + GTEST_SKIP(); + } + + ASSERT_EQ(xnn_status_success, status); + ASSERT_NE(nullptr, op); + ASSERT_EQ(xnn_status_success, xnn_reshape_fully_connected_nc_f32( + op, batch_size, /*threadpool=*/nullptr)); + ASSERT_EQ(xnn_status_success, xnn_setup_fully_connected_nc_f32( + op, input.data(), operator_output.data())); + + ASSERT_EQ(xnn_status_success, xnn_run_operator(op, /*threadpool=*/nullptr)); + + // Call subgraph API. + xnn_subgraph_t subgraph = nullptr; + ASSERT_EQ(xnn_status_success, xnn_create_subgraph(4, /*flags=*/0, &subgraph)); + std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph( + subgraph, xnn_delete_subgraph); + + uint32_t input_id = XNN_INVALID_VALUE_ID; + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, input_dims.size(), + input_dims.data(), nullptr, + /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id)); + ASSERT_NE(input_id, XNN_INVALID_VALUE_ID); + + uint32_t kernel_id = XNN_INVALID_VALUE_ID; + ASSERT_EQ( + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, kernel_dims.size(), + kernel_dims.data(), kernel.data(), + /*external_id=*/1, /*flags=*/0, &kernel_id)); + + uint32_t bias_id = XNN_INVALID_VALUE_ID; + ASSERT_EQ( + xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, bias_dims.size(), + bias_dims.data(), bias.data(), + /*external_id=*/2, /*flags=*/0, &bias_id)); + + uint32_t output_id = XNN_INVALID_VALUE_ID; + ASSERT_EQ(xnn_status_success, + xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, output_dims.size(), + output_dims.data(), nullptr, + /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id)); + ASSERT_NE(output_id, XNN_INVALID_VALUE_ID); + ASSERT_EQ( + xnn_status_success, + xnn_define_fully_connected(subgraph, output_min, output_max, input_id, + kernel_id, bias_id, output_id, /*flags=*/0)); + + xnn_runtime_t runtime = nullptr; + ASSERT_EQ(xnn_status_success, + xnn_create_runtime_v3(subgraph, nullptr, nullptr, + xnn_test_runtime_flags(), &runtime)); + ASSERT_NE(nullptr, runtime); + std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime( + runtime, xnn_delete_runtime); + std::array<xnn_external_value, 2> external = { + xnn_external_value{input_id, input.data()}, + xnn_external_value{output_id, subgraph_output.data()}}; + ASSERT_EQ(xnn_status_success, +
xnn_setup_runtime(runtime, external.size(), external.data())); + ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime)); + + // Check outputs match. + EXPECT_THAT(subgraph_output, ElementsAreArray(operator_output)); +}
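
Note on the "c2" suffix of the new microkernels: every registration above passes /*kr=*/2 (and the YAML entries record k-block: 2), meaning the packed weight stream groups two consecutive K elements per output channel inside each nr-wide tile. The sketch below is only a simplified, hypothetical model of that GOI-to-packed layout for the sr=1 case; pack_goi_sketch is an illustrative name, not part of XNNPACK, and it deliberately ignores alignment, extra bytes, and the real xnn_pack_f32_gemm_goi_w signature.

// Illustrative sketch only: an assumed model of GOI weight packing for an
// f32 GEMM microkernel with parameters like nr=16, kr=2, sr=1 (the new
// *16c2* kernels). Not the actual xnn_pack_f32_gemm_goi_w implementation.
#include <cstddef>
#include <vector>

std::vector<float> pack_goi_sketch(const std::vector<float>& k,     // weights, [nc][kc] row-major
                                   const std::vector<float>& bias,  // [nc]
                                   size_t nc, size_t kc,
                                   size_t nr, size_t kr) {
  const size_t kc_padded = (kc + kr - 1) / kr * kr;  // round K up to a multiple of kr
  std::vector<float> packed;
  // Output channels are consumed by the kernel in tiles of nr.
  for (size_t n0 = 0; n0 < nc; n0 += nr) {
    // 1) nr bias values, zero-padded past nc.
    for (size_t i = 0; i < nr; i++) {
      packed.push_back(n0 + i < nc ? bias[n0 + i] : 0.0f);
    }
    // 2) kr-wide K slices, interleaved across the nr channels of the tile.
    for (size_t kb = 0; kb < kc_padded; kb += kr) {
      for (size_t i = 0; i < nr; i++) {
        for (size_t kk = 0; kk < kr; kk++) {
          const size_t n = n0 + i;
          const size_t kidx = kb + kk;
          packed.push_back((n < nc && kidx < kc) ? k[n * kc + kidx] : 0.0f);
        }
      }
    }
  }
  return packed;
}

The benchmark and test registrations above hand the same mr/nr/kr/sr values to GEMMBenchmark, CreateTests1, and xnn_pack_f32_gemm_goi_w, which is what keeps the packed layout consumed by the asm kernels consistent with the layout produced by the packing routine.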