Skip to content

Commit 6d2b569

Browse files
fbarchard authored and xnnpack-bot committed
MRx2 GEMM/IGEMM fma3 variant
PiperOrigin-RevId: 719127729
1 parent a108468 commit 6d2b569

25 files changed

+1484
-70
lines changed

bench/f32-gemm-minmax.cc

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3421,6 +3421,17 @@
34213421

34223422
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x2c4__sse)
34233423

3424+
// Benchmark wrapper for xnn_f32_gemm_minmax_ukernel_6x2c4__sse.
// Forwards `state` to GEMMBenchmark together with the microkernel,
// xnn_init_f32_minmax_scalar_params, xnn_pack_f32_gemm_goi_w and the tile
// configuration mr=6, nr=2, kr=4, sr=1. `net` is not referenced in the body.
static void f32_gemm_minmax_ukernel_6x2c4__sse(benchmark::State& state, const char* net) {
3425+
GEMMBenchmark(state,
3426+
xnn_f32_gemm_minmax_ukernel_6x2c4__sse,
3427+
xnn_init_f32_minmax_scalar_params,
3428+
xnn_pack_f32_gemm_goi_w,
3429+
/*mr=*/6, /*nr=*/2, /*kr=*/4, /*sr=*/1,
3430+
// nullptr: no ISA check callback is supplied for this variant.
/*isa_check=*/nullptr);
3431+
}
3432+
3433+
// Hands the wrapper above to the BENCHMARK_GEMM macro (defined outside this view).
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_6x2c4__sse)
3434+
34243435
static void f32_gemm_minmax_ukernel_4x8__sse_dup(benchmark::State& state, const char* net) {
34253436
GEMMBenchmark(state,
34263437
xnn_f32_gemm_minmax_ukernel_4x8__sse_dup,
@@ -3487,17 +3498,6 @@
34873498

34883499
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x8s4__sse)
34893500

3490-
// Benchmark wrapper for xnn_f32_gemm_minmax_ukernel_6x2c4__sse.
// Forwards `state` to GEMMBenchmark together with the microkernel,
// xnn_init_f32_minmax_scalar_params, xnn_pack_f32_gemm_goi_w and the tile
// configuration mr=6, nr=2, kr=4, sr=1. `net` is not referenced in the body.
static void f32_gemm_minmax_ukernel_6x2c4__sse(benchmark::State& state, const char* net) {
3491-
GEMMBenchmark(state,
3492-
xnn_f32_gemm_minmax_ukernel_6x2c4__sse,
3493-
xnn_init_f32_minmax_scalar_params,
3494-
xnn_pack_f32_gemm_goi_w,
3495-
/*mr=*/6, /*nr=*/2, /*kr=*/4, /*sr=*/1,
3496-
// nullptr: no ISA check callback is supplied for this variant.
/*isa_check=*/nullptr);
3497-
}
3498-
3499-
// Hands the wrapper above to the BENCHMARK_GEMM macro (defined outside this view).
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_6x2c4__sse)
3500-
35013501
static void f32_gemm_minmax_ukernel_6x8__sse_dup(benchmark::State& state, const char* net) {
35023502
GEMMBenchmark(state,
35033503
xnn_f32_gemm_minmax_ukernel_6x8__sse_dup,
@@ -3531,6 +3531,28 @@
35313531

35323532
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_6x8s4__sse)
35333533

3534+
// Benchmark wrapper for xnn_f32_gemm_minmax_ukernel_4x2c4__fma3.
// Forwards `state` to GEMMBenchmark together with the microkernel,
// xnn_init_f32_minmax_scalar_params, xnn_pack_f32_gemm_goi_w and the tile
// configuration mr=4, nr=2, kr=4, sr=1. `net` is not referenced in the body.
static void f32_gemm_minmax_ukernel_4x2c4__fma3(benchmark::State& state, const char* net) {
3535+
GEMMBenchmark(state,
3536+
xnn_f32_gemm_minmax_ukernel_4x2c4__fma3,
3537+
xnn_init_f32_minmax_scalar_params,
3538+
xnn_pack_f32_gemm_goi_w,
3539+
/*mr=*/4, /*nr=*/2, /*kr=*/4, /*sr=*/1,
3540+
// Passed in the final argument slot (labelled isa_check in sibling benchmarks);
// presumably gates the run on FMA3 support -- confirm against benchmark::utils.
benchmark::utils::CheckFMA3);
3541+
}
3542+
3543+
// Hands the wrapper above to the BENCHMARK_GEMM macro (defined outside this view).
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x2c4__fma3)
3544+
3545+
// Benchmark wrapper for xnn_f32_gemm_minmax_ukernel_6x2c4__fma3.
// Forwards `state` to GEMMBenchmark together with the microkernel,
// xnn_init_f32_minmax_scalar_params, xnn_pack_f32_gemm_goi_w and the tile
// configuration mr=6, nr=2, kr=4, sr=1. `net` is not referenced in the body.
static void f32_gemm_minmax_ukernel_6x2c4__fma3(benchmark::State& state, const char* net) {
3546+
GEMMBenchmark(state,
3547+
xnn_f32_gemm_minmax_ukernel_6x2c4__fma3,
3548+
xnn_init_f32_minmax_scalar_params,
3549+
xnn_pack_f32_gemm_goi_w,
3550+
/*mr=*/6, /*nr=*/2, /*kr=*/4, /*sr=*/1,
3551+
// Passed in the final argument slot (labelled isa_check in sibling benchmarks);
// presumably gates the run on FMA3 support -- confirm against benchmark::utils.
benchmark::utils::CheckFMA3);
3552+
}
3553+
3554+
// Hands the wrapper above to the BENCHMARK_GEMM macro (defined outside this view).
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_6x2c4__fma3)
3555+
35343556
static void f32_gemm_minmax_ukernel_1x8__fma3_broadcast(benchmark::State& state, const char* net) {
35353557
GEMMBenchmark(state,
35363558
xnn_f32_gemm_minmax_ukernel_1x8__fma3_broadcast,

cmake/gen/fma3_microkernels.cmake

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,10 @@ SET(NON_PROD_FMA3_MICROKERNEL_SRCS
125125
src/f32-dwconv/gen/f32-dwconv-25p16c-minmax-fma3.c
126126
src/f32-gemm/gen/f32-gemm-3x16-minmax-fma3-broadcast.c
127127
src/f32-gemm/gen/f32-gemm-3x16s4-minmax-fma3-broadcast.c
128+
src/f32-gemm/gen/f32-gemm-4x2c4-minmax-fma3.c
128129
src/f32-gemm/gen/f32-gemm-4x16-minmax-fma3-broadcast.c
129130
src/f32-gemm/gen/f32-gemm-5x16s4-minmax-fma3-broadcast.c
131+
src/f32-gemm/gen/f32-gemm-6x2c4-minmax-fma3.c
130132
src/f32-gemm/gen/f32-gemm-6x8-minmax-fma3-broadcast.c
131133
src/f32-gemm/gen/f32-gemm-6x16-minmax-fma3-broadcast.c
132134
src/f32-gemm/gen/f32-gemm-6x16s4-minmax-fma3-broadcast.c
@@ -150,9 +152,11 @@ SET(NON_PROD_FMA3_MICROKERNEL_SRCS
150152
src/f32-gemminc/gen/f32-gemminc-8x8-minmax-fma3-broadcast.c
151153
src/f32-igemm/gen/f32-igemm-3x16-minmax-fma3-broadcast.c
152154
src/f32-igemm/gen/f32-igemm-3x16s4-minmax-fma3-broadcast.c
155+
src/f32-igemm/gen/f32-igemm-4x2c4-minmax-fma3.c
153156
src/f32-igemm/gen/f32-igemm-4x16-minmax-fma3-broadcast.c
154157
src/f32-igemm/gen/f32-igemm-5x16-minmax-fma3-broadcast.c
155158
src/f32-igemm/gen/f32-igemm-5x16s4-minmax-fma3-broadcast.c
159+
src/f32-igemm/gen/f32-igemm-6x2c4-minmax-fma3.c
156160
src/f32-igemm/gen/f32-igemm-6x8-minmax-fma3-broadcast.c
157161
src/f32-igemm/gen/f32-igemm-6x16-minmax-fma3-broadcast-prfm.c
158162
src/f32-igemm/gen/f32-igemm-6x16-minmax-fma3-broadcast.c
@@ -167,7 +171,9 @@ SET(NON_PROD_FMA3_MICROKERNEL_SRCS
167171
src/f32-qc4w-gemm/gen/f32-qc4w-gemm-8x16-minmax-fma3-broadcast.c
168172
src/f32-qc8w-gemm/gen/f32-qc8w-gemm-2x16-minmax-fma3-broadcast.c
169173
src/f32-qc8w-gemm/gen/f32-qc8w-gemm-3x16-minmax-fma3-broadcast.c
174+
src/f32-qc8w-gemm/gen/f32-qc8w-gemm-4x2c4-minmax-fma3.c
170175
src/f32-qc8w-gemm/gen/f32-qc8w-gemm-4x16-minmax-fma3-broadcast.c
176+
src/f32-qc8w-gemm/gen/f32-qc8w-gemm-6x2c4-minmax-fma3.c
171177
src/f32-qc8w-gemm/gen/f32-qc8w-gemm-6x16-minmax-fma3-broadcast.c
172178
src/f32-qc8w-gemm/gen/f32-qc8w-gemm-7x16-minmax-fma3-broadcast.c
173179
src/f32-qc8w-gemm/gen/f32-qc8w-gemm-8x16-minmax-fma3-broadcast.c

gen/fma3_microkernels.bzl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,10 @@ NON_PROD_FMA3_MICROKERNEL_SRCS = [
122122
"src/f32-dwconv/gen/f32-dwconv-25p16c-minmax-fma3.c",
123123
"src/f32-gemm/gen/f32-gemm-3x16-minmax-fma3-broadcast.c",
124124
"src/f32-gemm/gen/f32-gemm-3x16s4-minmax-fma3-broadcast.c",
125+
"src/f32-gemm/gen/f32-gemm-4x2c4-minmax-fma3.c",
125126
"src/f32-gemm/gen/f32-gemm-4x16-minmax-fma3-broadcast.c",
126127
"src/f32-gemm/gen/f32-gemm-5x16s4-minmax-fma3-broadcast.c",
128+
"src/f32-gemm/gen/f32-gemm-6x2c4-minmax-fma3.c",
127129
"src/f32-gemm/gen/f32-gemm-6x8-minmax-fma3-broadcast.c",
128130
"src/f32-gemm/gen/f32-gemm-6x16-minmax-fma3-broadcast.c",
129131
"src/f32-gemm/gen/f32-gemm-6x16s4-minmax-fma3-broadcast.c",
@@ -147,9 +149,11 @@ NON_PROD_FMA3_MICROKERNEL_SRCS = [
147149
"src/f32-gemminc/gen/f32-gemminc-8x8-minmax-fma3-broadcast.c",
148150
"src/f32-igemm/gen/f32-igemm-3x16-minmax-fma3-broadcast.c",
149151
"src/f32-igemm/gen/f32-igemm-3x16s4-minmax-fma3-broadcast.c",
152+
"src/f32-igemm/gen/f32-igemm-4x2c4-minmax-fma3.c",
150153
"src/f32-igemm/gen/f32-igemm-4x16-minmax-fma3-broadcast.c",
151154
"src/f32-igemm/gen/f32-igemm-5x16-minmax-fma3-broadcast.c",
152155
"src/f32-igemm/gen/f32-igemm-5x16s4-minmax-fma3-broadcast.c",
156+
"src/f32-igemm/gen/f32-igemm-6x2c4-minmax-fma3.c",
153157
"src/f32-igemm/gen/f32-igemm-6x8-minmax-fma3-broadcast.c",
154158
"src/f32-igemm/gen/f32-igemm-6x16-minmax-fma3-broadcast-prfm.c",
155159
"src/f32-igemm/gen/f32-igemm-6x16-minmax-fma3-broadcast.c",
@@ -164,7 +168,9 @@ NON_PROD_FMA3_MICROKERNEL_SRCS = [
164168
"src/f32-qc4w-gemm/gen/f32-qc4w-gemm-8x16-minmax-fma3-broadcast.c",
165169
"src/f32-qc8w-gemm/gen/f32-qc8w-gemm-2x16-minmax-fma3-broadcast.c",
166170
"src/f32-qc8w-gemm/gen/f32-qc8w-gemm-3x16-minmax-fma3-broadcast.c",
171+
"src/f32-qc8w-gemm/gen/f32-qc8w-gemm-4x2c4-minmax-fma3.c",
167172
"src/f32-qc8w-gemm/gen/f32-qc8w-gemm-4x16-minmax-fma3-broadcast.c",
173+
"src/f32-qc8w-gemm/gen/f32-qc8w-gemm-6x2c4-minmax-fma3.c",
168174
"src/f32-qc8w-gemm/gen/f32-qc8w-gemm-6x16-minmax-fma3-broadcast.c",
169175
"src/f32-qc8w-gemm/gen/f32-qc8w-gemm-7x16-minmax-fma3-broadcast.c",
170176
"src/f32-qc8w-gemm/gen/f32-qc8w-gemm-8x16-minmax-fma3-broadcast.c",

scripts/generate-f32-gemm.sh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -534,8 +534,10 @@ tools/xngen src/f32-gemm/sse-shuffle.c.in -D MR=6 -D NR=8 -D INC=0 -D SSE=1 -D D
534534
tools/xngen src/f32-gemm/sse-shuffle.c.in -D MR=6 -D NR=8 -D INC=1 -D SSE=1 -D DATATYPE=F32 -o src/f32-gemminc/gen/f32-gemminc-6x8s4-minmax-sse.c &
535535

536536
### MRx2 micro-kernels
537-
tools/xngen src/f32-gemm/MRx2c4-sse.c.in -D MR=4 -D NR=2 -D SSE=1 -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-4x2c4-minmax-sse.c &
538-
tools/xngen src/f32-gemm/MRx2c4-sse.c.in -D MR=6 -D NR=2 -D SSE=1 -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-6x2c4-minmax-sse.c &
537+
tools/xngen src/f32-gemm/MRx2c4-sse.c.in -D MR=4 -D NR=2 -D SSE=1 -D FMA=0 -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-4x2c4-minmax-sse.c &
538+
tools/xngen src/f32-gemm/MRx2c4-sse.c.in -D MR=6 -D NR=2 -D SSE=1 -D FMA=0 -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-6x2c4-minmax-sse.c &
539+
tools/xngen src/f32-gemm/MRx2c4-sse.c.in -D MR=4 -D NR=2 -D SSE=1 -D FMA=3 -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-4x2c4-minmax-fma3.c &
540+
tools/xngen src/f32-gemm/MRx2c4-sse.c.in -D MR=6 -D NR=2 -D SSE=1 -D FMA=3 -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-6x2c4-minmax-fma3.c &
539541

540542
################################### x86 AVX ###################################
541543
### AVX+BROADCAST micro-kernels

scripts/generate-f32-igemm.sh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,8 +315,10 @@ tools/xngen src/f32-igemm/sse-shuffle.c.in -D MR=5 -D NR=8 -o src/f32-igemm/gen/
315315
tools/xngen src/f32-igemm/sse-shuffle.c.in -D MR=6 -D NR=8 -o src/f32-igemm/gen/f32-igemm-6x8s4-minmax-sse.c &
316316

317317
### MRx2 micro-kernels
318-
tools/xngen src/f32-igemm/MRx2c4-sse.c.in -D MR=4 -D NR=2 -o src/f32-igemm/gen/f32-igemm-4x2c4-minmax-sse.c &
319-
tools/xngen src/f32-igemm/MRx2c4-sse.c.in -D MR=6 -D NR=2 -o src/f32-igemm/gen/f32-igemm-6x2c4-minmax-sse.c &
318+
tools/xngen src/f32-igemm/MRx2c4-sse.c.in -D MR=4 -D NR=2 -D FMA=0 -o src/f32-igemm/gen/f32-igemm-4x2c4-minmax-sse.c &
319+
tools/xngen src/f32-igemm/MRx2c4-sse.c.in -D MR=6 -D NR=2 -D FMA=0 -o src/f32-igemm/gen/f32-igemm-6x2c4-minmax-sse.c &
320+
tools/xngen src/f32-igemm/MRx2c4-sse.c.in -D MR=4 -D NR=2 -D FMA=3 -o src/f32-igemm/gen/f32-igemm-4x2c4-minmax-fma3.c &
321+
tools/xngen src/f32-igemm/MRx2c4-sse.c.in -D MR=6 -D NR=2 -D FMA=3 -o src/f32-igemm/gen/f32-igemm-6x2c4-minmax-fma3.c &
320322

321323
################################### x86 AVX ###################################
322324
### AVX+BROADCAST micro-kernels

scripts/generate-f32-qc8w-gemm.sh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,10 @@ tools/xngen src/f32-gemm/sse-shuffle.c.in -D MR=5 -D NR=8 -D INC=0 -D SSE=4 -D D
272272
tools/xngen src/f32-gemm/sse-shuffle.c.in -D MR=6 -D NR=8 -D INC=0 -D SSE=4 -D DATATYPE=QC8 -o src/f32-qc8w-gemm/gen/f32-qc8w-gemm-6x8s4-minmax-sse41.c &
273273

274274
### MRx2 micro-kernels
275-
tools/xngen src/f32-gemm/MRx2c4-sse.c.in -D MR=4 -D NR=2 -D SSE=4 -D DATATYPE=QC8 -o src/f32-qc8w-gemm/gen/f32-qc8w-gemm-4x2c4-minmax-sse41.c &
276-
tools/xngen src/f32-gemm/MRx2c4-sse.c.in -D MR=6 -D NR=2 -D SSE=4 -D DATATYPE=QC8 -o src/f32-qc8w-gemm/gen/f32-qc8w-gemm-6x2c4-minmax-sse41.c &
275+
tools/xngen src/f32-gemm/MRx2c4-sse.c.in -D MR=4 -D NR=2 -D SSE=4 -D FMA=0 -D DATATYPE=QC8 -o src/f32-qc8w-gemm/gen/f32-qc8w-gemm-4x2c4-minmax-sse41.c &
276+
tools/xngen src/f32-gemm/MRx2c4-sse.c.in -D MR=6 -D NR=2 -D SSE=4 -D FMA=0 -D DATATYPE=QC8 -o src/f32-qc8w-gemm/gen/f32-qc8w-gemm-6x2c4-minmax-sse41.c &
277+
tools/xngen src/f32-gemm/MRx2c4-sse.c.in -D MR=4 -D NR=2 -D SSE=4 -D FMA=3 -D DATATYPE=QC8 -o src/f32-qc8w-gemm/gen/f32-qc8w-gemm-4x2c4-minmax-fma3.c &
278+
tools/xngen src/f32-gemm/MRx2c4-sse.c.in -D MR=6 -D NR=2 -D SSE=4 -D FMA=3 -D DATATYPE=QC8 -o src/f32-qc8w-gemm/gen/f32-qc8w-gemm-6x2c4-minmax-fma3.c &
277279

278280
################################### x86 AVX ###################################
279281
### AVX BROADCAST micro-kernels

src/f32-gemm/MRx2c4-sse.c.in

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ $if DATATYPE == "QC8":
2020

2121
$RANGE_MRX2 = list(range(0, MR, 2))
2222
$if DATATYPE in ["QC8", "QC4"]:
23-
$ISA = {2: "sse2", 4: "sse41"}[SSE]
23+
$ISA = "fma3" if FMA else {2: "sse2", 4: "sse41"}[SSE]
2424
$else:
25-
$ISA = "sse"
25+
$ISA = "fma3" if FMA else "sse"
2626
$DATATYPE_SPEC = {"F32": "f32", "QC8": "f32_qc8w", "QC4": "f32_qc4w"}[DATATYPE]
2727
void xnn_${DATATYPE_SPEC}_gemm_minmax_ukernel_${MR}x${NR}c4__${ISA}(
2828
size_t mr,
@@ -121,7 +121,10 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax_ukernel_${MR}x${NR}c4__${ISA}(
121121

122122
$for M in range(MR):
123123
$for N in range(NR):
124-
vacc${M}x${N}c4 = _mm_add_ps(vacc${M}x${N}c4, _mm_mul_ps(va${M}, vb${N}));
124+
$if FMA == 3:
125+
vacc${M}x${N}c4 = _mm_fmadd_ps(va${M}, vb${N}, vacc${M}x${N}c4);
126+
$else:
127+
vacc${M}x${N}c4 = _mm_add_ps(vacc${M}x${N}c4, _mm_mul_ps(va${M}, vb${N}));
125128
}
126129
if XNN_UNLIKELY(k != 0) {
127130
$for M in range(MR):
@@ -154,7 +157,10 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax_ukernel_${MR}x${NR}c4__${ISA}(
154157

155158
$for M in range(MR):
156159
$for N in range(NR):
157-
vacc${M}x${N}c4 = _mm_add_ps(vacc${M}x${N}c4, _mm_mul_ps(_mm_andnot_ps(vmask${N}, va${M}), vb${N}));
160+
$if FMA == 3:
161+
vacc${M}x${N}c4 = _mm_fmadd_ps(_mm_andnot_ps(vmask${N}, va${M}), vb${N}, vacc${M}x${N}c4);
162+
$else:
163+
vacc${M}x${N}c4 = _mm_add_ps(vacc${M}x${N}c4, _mm_mul_ps(_mm_andnot_ps(vmask${N}, va${M}), vb${N}));
158164
}
159165

160166
$for M in range(MR):
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
// Auto-generated file. Do not edit!
2+
// Template: src/f32-gemm/MRx2c4-sse.c.in
3+
// Generator: tools/xngen
4+
//
5+
// Copyright 2019 Google LLC
6+
//
7+
// This source code is licensed under the BSD-style license found in the
8+
// LICENSE file in the root directory of this source tree.
9+
10+
#include <assert.h>
11+
12+
#include <xmmintrin.h>
13+
14+
#include "xnnpack/gemm.h"
15+
16+
17+
// F32 GEMM microkernel with min/max clamping that produces a tile of up to
// 4 rows by 2 columns of C per outer-loop iteration, accumulating with
// 128-bit fused multiply-add (_mm_fmadd_ps).
//
//   mr        - number of valid rows, 1..4 (asserted).
//   nc        - number of output columns still to produce; consumed 2 at a time.
//   kc        - reduction length in BYTES; must be a non-zero multiple of
//               sizeof(float) (asserted).
//   a         - first row of the A matrix.
//   a_stride  - byte offset between consecutive rows of A.
//   w         - weight stream read sequentially: for every column pair, 2 floats
//               that seed the accumulators, then 8 floats per 4-wide K step
//               (4 for column 0 followed by 4 for column 1).
//   c         - first row of the output.
//   cm_stride - byte offset between consecutive output rows.
//   cn_stride - byte offset applied to every output row pointer after a full
//               column pair has been stored.
//   params    - supplies scalar.min and scalar.max used for clamping.
//
// The function is annotated XNN_OOB_READS: the K-remainder path below loads a
// full 4-float vector from each A row even when fewer than 4 floats remain.
//
// NOTE(review): _mm_fmadd_ps is declared by <immintrin.h> and uintptr_t by
// <stdint.h>, but this file directly includes only <assert.h>, <xmmintrin.h>
// and "xnnpack/gemm.h" -- confirm both are reached transitively, or fix the
// include selection in the template (this file is auto-generated).
void xnn_f32_gemm_minmax_ukernel_4x2c4__fma3(
18+
size_t mr,
19+
size_t nc,
20+
size_t kc,
21+
const float* restrict a,
22+
size_t a_stride,
23+
const float* restrict w,
24+
float* restrict c,
25+
size_t cm_stride,
26+
size_t cn_stride,
27+
const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
28+
{
29+
assert(mr != 0);
30+
assert(mr <= 4);
31+
assert(nc != 0);
32+
assert(kc != 0);
33+
assert(kc % sizeof(float) == 0);
34+
assert(a != NULL);
35+
assert(w != NULL);
36+
assert(c != NULL);
37+
38+
// Set up per-row input/output pointers. Rows at or beyond mr alias the
// previous row, so they read the same A data and store the same result to the
// same address instead of touching memory outside the valid rows.
const float* a0 = a;
39+
float* c0 = c;
40+
const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
41+
float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
42+
if XNN_UNPREDICTABLE(mr < 2) {
43+
a1 = a0;
44+
c1 = c0;
45+
}
46+
const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
47+
float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
48+
if XNN_UNPREDICTABLE(mr <= 2) {
49+
a2 = a1;
50+
c2 = c1;
51+
}
52+
const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
53+
float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
54+
if XNN_UNPREDICTABLE(mr != 4) {
55+
a3 = a2;
56+
c3 = c2;
57+
}
58+
59+
// Broadcast the clamping bounds once, outside the column loop.
const __m128 vmin = _mm_set1_ps(params->scalar.min);
60+
const __m128 vmax = _mm_set1_ps(params->scalar.max);
61+
XNN_FORCE_REALIZATION(vmin);
62+
XNN_FORCE_REALIZATION(vmax);
63+
64+
do {
65+
// Seed the accumulators: _mm_load_ss places w[0] / w[1] in lane 0 and zeroes
// lanes 1..3, so after the horizontal reduction each column total starts
// from that value (presumably the packed bias -- confirm against the packer).
// All four rows share the same starting values.
__m128 vacc0x0c4 = _mm_load_ss(w);
66+
__m128 vacc0x1c4 = _mm_load_ss(w + 1);
67+
__m128 vacc1x0c4 = vacc0x0c4;
68+
__m128 vacc1x1c4 = vacc0x1c4;
69+
__m128 vacc2x0c4 = vacc0x0c4;
70+
__m128 vacc2x1c4 = vacc0x1c4;
71+
__m128 vacc3x0c4 = vacc0x0c4;
72+
__m128 vacc3x1c4 = vacc0x1c4;
73+
w += 2;
74+
75+
// Main K loop: k counts bytes; each iteration consumes 4 floats from every A
// row and 8 floats of weights (4 per output column). Each accumulator keeps
// 4 lane-wise partial sums that are reduced after the loop.
size_t k = kc;
76+
for (; k >= 4 * sizeof(float); k -= 4 * sizeof(float)) {
77+
const __m128 va0 = _mm_loadu_ps(a0);
78+
a0 += 4;
79+
const __m128 va1 = _mm_loadu_ps(a1);
80+
a1 += 4;
81+
const __m128 va2 = _mm_loadu_ps(a2);
82+
a2 += 4;
83+
const __m128 va3 = _mm_loadu_ps(a3);
84+
a3 += 4;
85+
86+
const __m128 vb0 = _mm_loadu_ps(w);
87+
const __m128 vb1 = _mm_loadu_ps(w + 4);
88+
w += 8;
89+
90+
vacc0x0c4 = _mm_fmadd_ps(va0, vb0, vacc0x0c4);
91+
vacc0x1c4 = _mm_fmadd_ps(va0, vb1, vacc0x1c4);
92+
vacc1x0c4 = _mm_fmadd_ps(va1, vb0, vacc1x0c4);
93+
vacc1x1c4 = _mm_fmadd_ps(va1, vb1, vacc1x1c4);
94+
vacc2x0c4 = _mm_fmadd_ps(va2, vb0, vacc2x0c4);
95+
vacc2x1c4 = _mm_fmadd_ps(va2, vb1, vacc2x1c4);
96+
vacc3x0c4 = _mm_fmadd_ps(va3, vb0, vacc3x0c4);
97+
vacc3x1c4 = _mm_fmadd_ps(va3, vb1, vacc3x1c4);
98+
}
99+
// K remainder (1..3 floats left): a full 4-float vector is still loaded from
// each A row, but the row pointers advance by only the k remaining bytes.
if XNN_UNLIKELY(k != 0) {
100+
const __m128 va0 = _mm_loadu_ps(a0);
101+
a0 = (const float*) ((uintptr_t) a0 + k);
102+
const __m128 va1 = _mm_loadu_ps(a1);
103+
a1 = (const float*) ((uintptr_t) a1 + k);
104+
const __m128 va2 = _mm_loadu_ps(a2);
105+
a2 = (const float*) ((uintptr_t) a2 + k);
106+
const __m128 va3 = _mm_loadu_ps(a3);
107+
a3 = (const float*) ((uintptr_t) a3 + k);
108+
109+
const __m128 vb0 = _mm_loadu_ps(w);
110+
const __m128 vb1 = _mm_loadu_ps(w + 4);
111+
w += 8;
112+
113+
// vmaskN is all-ones in lanes where the weight equals zero; andnot then
// clears the matching A lanes before the multiply. Presumably the packed
// weights are zero-padded past the end of K, so this keeps out-of-bounds A
// values (e.g. NaN/Inf, where 0 * NaN != 0) out of the sum -- confirm.
const __m128 vmask0 = _mm_cmpeq_ps(_mm_setzero_ps(), vb0);
114+
const __m128 vmask1 = _mm_cmpeq_ps(_mm_setzero_ps(), vb1);
115+
116+
vacc0x0c4 = _mm_fmadd_ps(_mm_andnot_ps(vmask0, va0), vb0, vacc0x0c4);
117+
vacc0x1c4 = _mm_fmadd_ps(_mm_andnot_ps(vmask1, va0), vb1, vacc0x1c4);
118+
vacc1x0c4 = _mm_fmadd_ps(_mm_andnot_ps(vmask0, va1), vb0, vacc1x0c4);
119+
vacc1x1c4 = _mm_fmadd_ps(_mm_andnot_ps(vmask1, va1), vb1, vacc1x1c4);
120+
vacc2x0c4 = _mm_fmadd_ps(_mm_andnot_ps(vmask0, va2), vb0, vacc2x0c4);
121+
vacc2x1c4 = _mm_fmadd_ps(_mm_andnot_ps(vmask1, va2), vb1, vacc2x1c4);
122+
vacc3x0c4 = _mm_fmadd_ps(_mm_andnot_ps(vmask0, va3), vb0, vacc3x0c4);
123+
vacc3x1c4 = _mm_fmadd_ps(_mm_andnot_ps(vmask1, va3), vb1, vacc3x1c4);
124+
}
125+
126+
// Horizontal reduction, step 1: per row, interleave the two column
// accumulators and add, giving
//   [col0: l0+l2, col1: l0+l2, col0: l1+l3, col1: l1+l3].
const __m128 vacc0x01c2 = _mm_add_ps(_mm_unpacklo_ps(vacc0x0c4, vacc0x1c4), _mm_unpackhi_ps(vacc0x0c4, vacc0x1c4));
127+
const __m128 vacc1x01c2 = _mm_add_ps(_mm_unpacklo_ps(vacc1x0c4, vacc1x1c4), _mm_unpackhi_ps(vacc1x0c4, vacc1x1c4));
128+
const __m128 vacc2x01c2 = _mm_add_ps(_mm_unpacklo_ps(vacc2x0c4, vacc2x1c4), _mm_unpackhi_ps(vacc2x0c4, vacc2x1c4));
129+
const __m128 vacc3x01c2 = _mm_add_ps(_mm_unpacklo_ps(vacc3x0c4, vacc3x1c4), _mm_unpackhi_ps(vacc3x0c4, vacc3x1c4));
130+
131+
// Step 2: combine two rows per vector and add the low and high halves, giving
//   [rowA col0, rowA col1, rowB col0, rowB col1].
__m128 vacc01x01 = _mm_add_ps(_mm_movelh_ps(vacc0x01c2, vacc1x01c2), _mm_movehl_ps(vacc1x01c2, vacc0x01c2));
132+
__m128 vacc23x01 = _mm_add_ps(_mm_movelh_ps(vacc2x01c2, vacc3x01c2), _mm_movehl_ps(vacc3x01c2, vacc2x01c2));
133+
134+
// Clamp: upper bound first, then lower bound.
vacc01x01 = _mm_min_ps(vacc01x01, vmax);
135+
vacc23x01 = _mm_min_ps(vacc23x01, vmax);
136+
137+
vacc01x01 = _mm_max_ps(vacc01x01, vmin);
138+
vacc23x01 = _mm_max_ps(vacc23x01, vmin);
139+
140+
if XNN_LIKELY(nc >= 2) {
141+
// Full column pair: the low 64 bits hold the even row, the high 64 bits the
// odd row. After each store, step the output pointer by cn_stride bytes and
// rewind the A row pointer by the kc bytes consumed above so the next column
// pair re-reads the same row.
_mm_storel_pi((__m64*) c0, vacc01x01);
142+
c0 = (float*) ((uintptr_t) c0 + cn_stride);
143+
a0 = (const float*) ((uintptr_t) a0 - kc);
144+
_mm_storeh_pi((__m64*) c1, vacc01x01);
145+
c1 = (float*) ((uintptr_t) c1 + cn_stride);
146+
a1 = (const float*) ((uintptr_t) a1 - kc);
147+
_mm_storel_pi((__m64*) c2, vacc23x01);
148+
c2 = (float*) ((uintptr_t) c2 + cn_stride);
149+
a2 = (const float*) ((uintptr_t) a2 - kc);
150+
_mm_storeh_pi((__m64*) c3, vacc23x01);
151+
c3 = (float*) ((uintptr_t) c3 + cn_stride);
152+
a3 = (const float*) ((uintptr_t) a3 - kc);
153+
154+
nc -= 2;
155+
} else {
156+
// Single trailing column: store only column 0 of each row. movehl(v, v)
// moves lane 2 (the odd row's column 0) into lane 0 for the scalar store.
assert(nc == 1);
157+
_mm_store_ss(c0, vacc01x01);
158+
_mm_store_ss(c1, _mm_movehl_ps(vacc01x01, vacc01x01));
159+
_mm_store_ss(c2, vacc23x01);
160+
_mm_store_ss(c3, _mm_movehl_ps(vacc23x01, vacc23x01));
161+
162+
nc = 0;
163+
}
164+
} while (nc != 0);
165+
}

0 commit comments

Comments
 (0)