Skip to content

Commit a0526d0

Browse files
fbarchard authored and xnnpack-bot committed
Enable QS8-GEMM for HVX
PiperOrigin-RevId: 743053358
1 parent 95b13c3 commit a0526d0

9 files changed

+270
-81
lines changed

bench/qs8-qc8w-gemm-fp32.cc

Lines changed: 47 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,53 @@
2020
#include "src/xnnpack/packw.h"
2121

2222

23+
#if XNN_ENABLE_HVX && XNN_ARCH_HEXAGON
24+
static void qs8_qc8w_gemm_minmax_fp32_ukernel_1x32c4__hvx(benchmark::State& state, const char* net) {
25+
GEMMBenchmark(state,
26+
xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x32c4__hvx,
27+
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
28+
xnn_pack_qs8_gemm_goi_w,
29+
/*mr=*/1, /*nr=*/32, /*kr=*/4, /*sr=*/1,
30+
benchmark::utils::CheckHVX);
31+
}
32+
33+
BENCHMARK_GEMM(qs8_qc8w_gemm_minmax_fp32_ukernel_1x32c4__hvx)
34+
35+
static void qs8_qc8w_gemm_minmax_fp32_ukernel_4x32c4__hvx(benchmark::State& state, const char* net) {
36+
GEMMBenchmark(state,
37+
xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x32c4__hvx,
38+
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
39+
xnn_pack_qs8_gemm_goi_w,
40+
/*mr=*/4, /*nr=*/32, /*kr=*/4, /*sr=*/1,
41+
benchmark::utils::CheckHVX);
42+
}
43+
44+
BENCHMARK_GEMM(qs8_qc8w_gemm_minmax_fp32_ukernel_4x32c4__hvx)
45+
46+
static void qs8_qc8w_gemm_minmax_fp32_ukernel_8x32c4__hvx(benchmark::State& state, const char* net) {
47+
GEMMBenchmark(state,
48+
xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x32c4__hvx,
49+
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
50+
xnn_pack_qs8_gemm_goi_w,
51+
/*mr=*/8, /*nr=*/32, /*kr=*/4, /*sr=*/1,
52+
benchmark::utils::CheckHVX);
53+
}
54+
55+
BENCHMARK_GEMM(qs8_qc8w_gemm_minmax_fp32_ukernel_8x32c4__hvx)
56+
57+
static void qs8_qc8w_gemm_minmax_fp32_ukernel_16x32c4__hvx(benchmark::State& state, const char* net) {
58+
GEMMBenchmark(state,
59+
xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x32c4__hvx,
60+
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
61+
xnn_pack_qs8_gemm_goi_w,
62+
/*mr=*/16, /*nr=*/32, /*kr=*/4, /*sr=*/1,
63+
benchmark::utils::CheckHVX);
64+
}
65+
66+
BENCHMARK_GEMM(qs8_qc8w_gemm_minmax_fp32_ukernel_16x32c4__hvx)
67+
#endif // XNN_ENABLE_HVX && XNN_ARCH_HEXAGON
68+
69+
2370
#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
2471
static void qs8_qc8w_gemm_minmax_fp32_ukernel_1x4v__rvv(benchmark::State& state, const char* net) {
2572
GEMMBenchmark(state,

src/qs8-gemm/MRx32c4-hvx.c.in

Lines changed: 16 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@ $assert REQUANTIZATION == "FP32" or not REQUANTIZATION
77
$assert DATATYPE in ["QC8"]
88

99
#include <assert.h>
10+
#include <string.h> // for memcpy
1011

1112
#include <hexagon_types.h>
1213
#include <hexagon_protos.h>
@@ -18,6 +19,11 @@ $assert DATATYPE in ["QC8"]
1819
#include "src/xnnpack/unaligned.h"
1920

2021

22+
static XNN_INTRINSIC void xnn_Q6_V_vstu_variable(void* addr, uint32_t n,
23+
HVX_Vector vin) {
24+
memcpy(addr, &vin, n);
25+
}
26+
2127
$DATATYPE_SPEC = {"QC8": "qs8_qc8w", "QD8": "qd8_f32_qc8w", "QS8": "qs8", "QU8": "qu8", "QC4": "qd8_f32_qc4w"}[DATATYPE]
2228
$REQUANTIZATION_SPEC = "" if DATATYPE in ["QD8", "QC4"] else "_" + REQUANTIZATION.lower()
2329
$PARAMS_STRUCT = REQUANTIZATION.lower() + "_scalar" if REQUANTIZATION else "scalar"
@@ -74,23 +80,23 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x32c4__
7480
const HVX_Vector voutput_min = Q6_Vb_vsplat_R(params->${PARAMS_STRUCT}.output_min);
7581

7682
do {
77-
HVX_Vector vacc0x32 = *((HVX_Vector*)w);
83+
HVX_Vector vacc0x32 = *((HVX_UVector*)w);
7884
HVX_Vector vacc1x0x32 = Q6_V_vsplat_R(0);
7985
$for M in range(1, MR):
80-
HVX_Vector vacc${M}x32 = *((HVX_Vector*)w);
81-
HVX_Vector vacc1x${M}x32 = Q6_V_vsplat_R(0);
86+
HVX_Vector vacc${M}x32 = *((HVX_UVector*)w);
87+
HVX_Vector vacc1x${M}x32 = Q6_V_vsplat_R(0);
8288

8389
w = (const int32_t*) w + 32;
8490

8591
size_t k = kc;
8692
while (k >= 8 * sizeof(${XINT8_T})) {
8793
$for M in range(MR):
88-
const HVX_Vector va${M}x0123 = Q6_V_vsplat_R(unaligned_load_s32(a${M}));
89-
const HVX_Vector va${M}x4567 = Q6_V_vsplat_R(unaligned_load_s32(a${M}+4));
90-
a${M} += 8;
94+
const HVX_Vector va${M}x0123 = Q6_V_vsplat_R(unaligned_load_s32(a${M}));
95+
const HVX_Vector va${M}x4567 = Q6_V_vsplat_R(unaligned_load_s32(a${M}+4));
96+
a${M} += 8;
9197

92-
const HVX_Vector vb32x0123 = *((HVX_Vector *)((${XINT8_T} *)w));
93-
const HVX_Vector vb32x4567 = *((HVX_Vector *)((${XINT8_T} *)w + 128));
98+
const HVX_Vector vb32x0123 = *((HVX_UVector *)((${XINT8_T} *)w));
99+
const HVX_Vector vb32x4567 = *((HVX_UVector *)((${XINT8_T} *)w + 128));
94100
$for M in range(MR):
95101
vacc${M}x32 = Q6_Vw_vrmpyacc_VwVbVb(vacc${M}x32, va${M}x0123, vb32x0123);
96102
vacc1x${M}x32 = Q6_Vw_vrmpyacc_VwVbVb(vacc1x${M}x32, va${M}x4567, vb32x4567);
@@ -107,7 +113,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x32c4__
107113
const HVX_Vector va${M}x0123 = Q6_V_vsplat_R(unaligned_load_s32(a${M}));
108114
a${M} += 4;
109115

110-
const HVX_Vector vb32x0123 = *((HVX_Vector *)((${XINT8_T} *)w));
116+
const HVX_Vector vb32x0123 = *((HVX_UVector *)((${XINT8_T} *)w));
111117
$for M in range(MR):
112118
vacc${M}x32 = Q6_Vw_vrmpyacc_VwVbVb(vacc${M}x32, va${M}x0123, vb32x0123);
113119

@@ -153,7 +159,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax${REQUANTIZATION_SPEC}_ukernel_${MR}x32c4__
153159
} else {
154160
// Prepare mask for valid 8-bit elements (depends on nc).
155161
$for M in range(MR):
156-
Q6_V_vstu_variable(c${M}, nc, vout${M}x32);
162+
xnn_Q6_V_vstu_variable(c${M}, nc, vout${M}x32);
157163
nc = 0;
158164
}
159165
} while (nc != 0);

src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-16x32c4-minmax-fp32-hvx.c

Lines changed: 41 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010

1111

1212
#include <assert.h>
13+
#include <string.h> // for memcpy
1314

1415
#include <hexagon_types.h>
1516
#include <hexagon_protos.h>
@@ -21,6 +22,11 @@
2122
#include "src/xnnpack/unaligned.h"
2223

2324

25+
static XNN_INTRINSIC void xnn_Q6_V_vstu_variable(void* addr, uint32_t n,
26+
HVX_Vector vin) {
27+
memcpy(addr, &vin, n);
28+
}
29+
2430

2531
void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x32c4__hvx(
2632
size_t mr,
@@ -143,37 +149,37 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x32c4__hvx(
143149
const HVX_Vector voutput_min = Q6_Vb_vsplat_R(params->fp32_scalar.output_min);
144150

145151
do {
146-
HVX_Vector vacc0x32 = *((HVX_Vector*)w);
152+
HVX_Vector vacc0x32 = *((HVX_UVector*)w);
147153
HVX_Vector vacc1x0x32 = Q6_V_vsplat_R(0);
148-
HVX_Vector vacc1x32 = *((HVX_Vector*)w);
154+
HVX_Vector vacc1x32 = *((HVX_UVector*)w);
149155
HVX_Vector vacc1x1x32 = Q6_V_vsplat_R(0);
150-
HVX_Vector vacc2x32 = *((HVX_Vector*)w);
156+
HVX_Vector vacc2x32 = *((HVX_UVector*)w);
151157
HVX_Vector vacc1x2x32 = Q6_V_vsplat_R(0);
152-
HVX_Vector vacc3x32 = *((HVX_Vector*)w);
158+
HVX_Vector vacc3x32 = *((HVX_UVector*)w);
153159
HVX_Vector vacc1x3x32 = Q6_V_vsplat_R(0);
154-
HVX_Vector vacc4x32 = *((HVX_Vector*)w);
160+
HVX_Vector vacc4x32 = *((HVX_UVector*)w);
155161
HVX_Vector vacc1x4x32 = Q6_V_vsplat_R(0);
156-
HVX_Vector vacc5x32 = *((HVX_Vector*)w);
162+
HVX_Vector vacc5x32 = *((HVX_UVector*)w);
157163
HVX_Vector vacc1x5x32 = Q6_V_vsplat_R(0);
158-
HVX_Vector vacc6x32 = *((HVX_Vector*)w);
164+
HVX_Vector vacc6x32 = *((HVX_UVector*)w);
159165
HVX_Vector vacc1x6x32 = Q6_V_vsplat_R(0);
160-
HVX_Vector vacc7x32 = *((HVX_Vector*)w);
166+
HVX_Vector vacc7x32 = *((HVX_UVector*)w);
161167
HVX_Vector vacc1x7x32 = Q6_V_vsplat_R(0);
162-
HVX_Vector vacc8x32 = *((HVX_Vector*)w);
168+
HVX_Vector vacc8x32 = *((HVX_UVector*)w);
163169
HVX_Vector vacc1x8x32 = Q6_V_vsplat_R(0);
164-
HVX_Vector vacc9x32 = *((HVX_Vector*)w);
170+
HVX_Vector vacc9x32 = *((HVX_UVector*)w);
165171
HVX_Vector vacc1x9x32 = Q6_V_vsplat_R(0);
166-
HVX_Vector vacc10x32 = *((HVX_Vector*)w);
172+
HVX_Vector vacc10x32 = *((HVX_UVector*)w);
167173
HVX_Vector vacc1x10x32 = Q6_V_vsplat_R(0);
168-
HVX_Vector vacc11x32 = *((HVX_Vector*)w);
174+
HVX_Vector vacc11x32 = *((HVX_UVector*)w);
169175
HVX_Vector vacc1x11x32 = Q6_V_vsplat_R(0);
170-
HVX_Vector vacc12x32 = *((HVX_Vector*)w);
176+
HVX_Vector vacc12x32 = *((HVX_UVector*)w);
171177
HVX_Vector vacc1x12x32 = Q6_V_vsplat_R(0);
172-
HVX_Vector vacc13x32 = *((HVX_Vector*)w);
178+
HVX_Vector vacc13x32 = *((HVX_UVector*)w);
173179
HVX_Vector vacc1x13x32 = Q6_V_vsplat_R(0);
174-
HVX_Vector vacc14x32 = *((HVX_Vector*)w);
180+
HVX_Vector vacc14x32 = *((HVX_UVector*)w);
175181
HVX_Vector vacc1x14x32 = Q6_V_vsplat_R(0);
176-
HVX_Vector vacc15x32 = *((HVX_Vector*)w);
182+
HVX_Vector vacc15x32 = *((HVX_UVector*)w);
177183
HVX_Vector vacc1x15x32 = Q6_V_vsplat_R(0);
178184

179185
w = (const int32_t*) w + 32;
@@ -229,8 +235,8 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x32c4__hvx(
229235
const HVX_Vector va15x4567 = Q6_V_vsplat_R(unaligned_load_s32(a15+4));
230236
a15 += 8;
231237

232-
const HVX_Vector vb32x0123 = *((HVX_Vector *)((int8_t *)w));
233-
const HVX_Vector vb32x4567 = *((HVX_Vector *)((int8_t *)w + 128));
238+
const HVX_Vector vb32x0123 = *((HVX_UVector *)((int8_t *)w));
239+
const HVX_Vector vb32x4567 = *((HVX_UVector *)((int8_t *)w + 128));
234240
vacc0x32 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x32, va0x0123, vb32x0123);
235241
vacc1x0x32 = Q6_Vw_vrmpyacc_VwVbVb(vacc1x0x32, va0x4567, vb32x4567);
236242
vacc1x32 = Q6_Vw_vrmpyacc_VwVbVb(vacc1x32, va1x0123, vb32x0123);
@@ -319,7 +325,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x32c4__hvx(
319325
const HVX_Vector va15x0123 = Q6_V_vsplat_R(unaligned_load_s32(a15));
320326
a15 += 4;
321327

322-
const HVX_Vector vb32x0123 = *((HVX_Vector *)((int8_t *)w));
328+
const HVX_Vector vb32x0123 = *((HVX_UVector *)((int8_t *)w));
323329
vacc0x32 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x32, va0x0123, vb32x0123);
324330
vacc1x32 = Q6_Vw_vrmpyacc_VwVbVb(vacc1x32, va1x0123, vb32x0123);
325331
vacc2x32 = Q6_Vw_vrmpyacc_VwVbVb(vacc2x32, va2x0123, vb32x0123);
@@ -548,22 +554,22 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_16x32c4__hvx(
548554
nc -= 32;
549555
} else {
550556
// Prepare mask for valid 8-bit elements (depends on nc).
551-
Q6_V_vstu_variable(c0, nc, vout0x32);
552-
Q6_V_vstu_variable(c1, nc, vout1x32);
553-
Q6_V_vstu_variable(c2, nc, vout2x32);
554-
Q6_V_vstu_variable(c3, nc, vout3x32);
555-
Q6_V_vstu_variable(c4, nc, vout4x32);
556-
Q6_V_vstu_variable(c5, nc, vout5x32);
557-
Q6_V_vstu_variable(c6, nc, vout6x32);
558-
Q6_V_vstu_variable(c7, nc, vout7x32);
559-
Q6_V_vstu_variable(c8, nc, vout8x32);
560-
Q6_V_vstu_variable(c9, nc, vout9x32);
561-
Q6_V_vstu_variable(c10, nc, vout10x32);
562-
Q6_V_vstu_variable(c11, nc, vout11x32);
563-
Q6_V_vstu_variable(c12, nc, vout12x32);
564-
Q6_V_vstu_variable(c13, nc, vout13x32);
565-
Q6_V_vstu_variable(c14, nc, vout14x32);
566-
Q6_V_vstu_variable(c15, nc, vout15x32);
557+
xnn_Q6_V_vstu_variable(c0, nc, vout0x32);
558+
xnn_Q6_V_vstu_variable(c1, nc, vout1x32);
559+
xnn_Q6_V_vstu_variable(c2, nc, vout2x32);
560+
xnn_Q6_V_vstu_variable(c3, nc, vout3x32);
561+
xnn_Q6_V_vstu_variable(c4, nc, vout4x32);
562+
xnn_Q6_V_vstu_variable(c5, nc, vout5x32);
563+
xnn_Q6_V_vstu_variable(c6, nc, vout6x32);
564+
xnn_Q6_V_vstu_variable(c7, nc, vout7x32);
565+
xnn_Q6_V_vstu_variable(c8, nc, vout8x32);
566+
xnn_Q6_V_vstu_variable(c9, nc, vout9x32);
567+
xnn_Q6_V_vstu_variable(c10, nc, vout10x32);
568+
xnn_Q6_V_vstu_variable(c11, nc, vout11x32);
569+
xnn_Q6_V_vstu_variable(c12, nc, vout12x32);
570+
xnn_Q6_V_vstu_variable(c13, nc, vout13x32);
571+
xnn_Q6_V_vstu_variable(c14, nc, vout14x32);
572+
xnn_Q6_V_vstu_variable(c15, nc, vout15x32);
567573
nc = 0;
568574
}
569575
} while (nc != 0);

src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x32c4-minmax-fp32-hvx.c

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010

1111

1212
#include <assert.h>
13+
#include <string.h> // for memcpy
1314

1415
#include <hexagon_types.h>
1516
#include <hexagon_protos.h>
@@ -21,6 +22,11 @@
2122
#include "src/xnnpack/unaligned.h"
2223

2324

25+
static XNN_INTRINSIC void xnn_Q6_V_vstu_variable(void* addr, uint32_t n,
26+
HVX_Vector vin) {
27+
memcpy(addr, &vin, n);
28+
}
29+
2430

2531
void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x32c4__hvx(
2632
size_t mr,
@@ -53,7 +59,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x32c4__hvx(
5359
const HVX_Vector voutput_min = Q6_Vb_vsplat_R(params->fp32_scalar.output_min);
5460

5561
do {
56-
HVX_Vector vacc0x32 = *((HVX_Vector*)w);
62+
HVX_Vector vacc0x32 = *((HVX_UVector*)w);
5763
HVX_Vector vacc1x0x32 = Q6_V_vsplat_R(0);
5864

5965
w = (const int32_t*) w + 32;
@@ -64,8 +70,8 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x32c4__hvx(
6470
const HVX_Vector va0x4567 = Q6_V_vsplat_R(unaligned_load_s32(a0+4));
6571
a0 += 8;
6672

67-
const HVX_Vector vb32x0123 = *((HVX_Vector *)((int8_t *)w));
68-
const HVX_Vector vb32x4567 = *((HVX_Vector *)((int8_t *)w + 128));
73+
const HVX_Vector vb32x0123 = *((HVX_UVector *)((int8_t *)w));
74+
const HVX_Vector vb32x4567 = *((HVX_UVector *)((int8_t *)w + 128));
6975
vacc0x32 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x32, va0x0123, vb32x0123);
7076
vacc1x0x32 = Q6_Vw_vrmpyacc_VwVbVb(vacc1x0x32, va0x4567, vb32x4567);
7177

@@ -79,7 +85,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x32c4__hvx(
7985
const HVX_Vector va0x0123 = Q6_V_vsplat_R(unaligned_load_s32(a0));
8086
a0 += 4;
8187

82-
const HVX_Vector vb32x0123 = *((HVX_Vector *)((int8_t *)w));
88+
const HVX_Vector vb32x0123 = *((HVX_UVector *)((int8_t *)w));
8389
vacc0x32 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x32, va0x0123, vb32x0123);
8490

8591
w = (const int8_t*) w + 128;
@@ -113,7 +119,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x32c4__hvx(
113119
nc -= 32;
114120
} else {
115121
// Prepare mask for valid 8-bit elements (depends on nc).
116-
Q6_V_vstu_variable(c0, nc, vout0x32);
122+
xnn_Q6_V_vstu_variable(c0, nc, vout0x32);
117123
nc = 0;
118124
}
119125
} while (nc != 0);

src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-4x32c4-minmax-fp32-hvx.c

Lines changed: 17 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010

1111

1212
#include <assert.h>
13+
#include <string.h> // for memcpy
1314

1415
#include <hexagon_types.h>
1516
#include <hexagon_protos.h>
@@ -21,6 +22,11 @@
2122
#include "src/xnnpack/unaligned.h"
2223

2324

25+
static XNN_INTRINSIC void xnn_Q6_V_vstu_variable(void* addr, uint32_t n,
26+
HVX_Vector vin) {
27+
memcpy(addr, &vin, n);
28+
}
29+
2430

2531
void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x32c4__hvx(
2632
size_t mr,
@@ -71,13 +77,13 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x32c4__hvx(
7177
const HVX_Vector voutput_min = Q6_Vb_vsplat_R(params->fp32_scalar.output_min);
7278

7379
do {
74-
HVX_Vector vacc0x32 = *((HVX_Vector*)w);
80+
HVX_Vector vacc0x32 = *((HVX_UVector*)w);
7581
HVX_Vector vacc1x0x32 = Q6_V_vsplat_R(0);
76-
HVX_Vector vacc1x32 = *((HVX_Vector*)w);
82+
HVX_Vector vacc1x32 = *((HVX_UVector*)w);
7783
HVX_Vector vacc1x1x32 = Q6_V_vsplat_R(0);
78-
HVX_Vector vacc2x32 = *((HVX_Vector*)w);
84+
HVX_Vector vacc2x32 = *((HVX_UVector*)w);
7985
HVX_Vector vacc1x2x32 = Q6_V_vsplat_R(0);
80-
HVX_Vector vacc3x32 = *((HVX_Vector*)w);
86+
HVX_Vector vacc3x32 = *((HVX_UVector*)w);
8187
HVX_Vector vacc1x3x32 = Q6_V_vsplat_R(0);
8288

8389
w = (const int32_t*) w + 32;
@@ -97,8 +103,8 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x32c4__hvx(
97103
const HVX_Vector va3x4567 = Q6_V_vsplat_R(unaligned_load_s32(a3+4));
98104
a3 += 8;
99105

100-
const HVX_Vector vb32x0123 = *((HVX_Vector *)((int8_t *)w));
101-
const HVX_Vector vb32x4567 = *((HVX_Vector *)((int8_t *)w + 128));
106+
const HVX_Vector vb32x0123 = *((HVX_UVector *)((int8_t *)w));
107+
const HVX_Vector vb32x4567 = *((HVX_UVector *)((int8_t *)w + 128));
102108
vacc0x32 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x32, va0x0123, vb32x0123);
103109
vacc1x0x32 = Q6_Vw_vrmpyacc_VwVbVb(vacc1x0x32, va0x4567, vb32x4567);
104110
vacc1x32 = Q6_Vw_vrmpyacc_VwVbVb(vacc1x32, va1x0123, vb32x0123);
@@ -127,7 +133,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x32c4__hvx(
127133
const HVX_Vector va3x0123 = Q6_V_vsplat_R(unaligned_load_s32(a3));
128134
a3 += 4;
129135

130-
const HVX_Vector vb32x0123 = *((HVX_Vector *)((int8_t *)w));
136+
const HVX_Vector vb32x0123 = *((HVX_UVector *)((int8_t *)w));
131137
vacc0x32 = Q6_Vw_vrmpyacc_VwVbVb(vacc0x32, va0x0123, vb32x0123);
132138
vacc1x32 = Q6_Vw_vrmpyacc_VwVbVb(vacc1x32, va1x0123, vb32x0123);
133139
vacc2x32 = Q6_Vw_vrmpyacc_VwVbVb(vacc2x32, va2x0123, vb32x0123);
@@ -200,10 +206,10 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x32c4__hvx(
200206
nc -= 32;
201207
} else {
202208
// Prepare mask for valid 8-bit elements (depends on nc).
203-
Q6_V_vstu_variable(c0, nc, vout0x32);
204-
Q6_V_vstu_variable(c1, nc, vout1x32);
205-
Q6_V_vstu_variable(c2, nc, vout2x32);
206-
Q6_V_vstu_variable(c3, nc, vout3x32);
209+
xnn_Q6_V_vstu_variable(c0, nc, vout0x32);
210+
xnn_Q6_V_vstu_variable(c1, nc, vout1x32);
211+
xnn_Q6_V_vstu_variable(c2, nc, vout2x32);
212+
xnn_Q6_V_vstu_variable(c3, nc, vout3x32);
207213
nc = 0;
208214
}
209215
} while (nc != 0);

0 commit comments

Comments
 (0)