Skip to content

Commit 0241d51

Browse files
authored
Merge pull request #5220 from iha-taisei/sdgemv_n_unroll
Further performance improvements to non-transposed [SD]GEMV kernels for A64FX and Neoverse V1.
2 parents afb6645 + f1e628b commit 0241d51

File tree

4 files changed

+349
-2
lines changed

4 files changed

+349
-2
lines changed

kernel/arm64/KERNEL.A64FX

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
include $(KERNELDIR)/KERNEL.ARMV8SVE
22

3-
SGEMVNKERNEL = gemv_n_sve.c
4-
DGEMVNKERNEL = gemv_n_sve.c
3+
SGEMVNKERNEL = gemv_n_sve_v4x3.c
4+
DGEMVNKERNEL = gemv_n_sve_v4x3.c
55
SGEMVTKERNEL = gemv_t_sve_v4x3.c
66
DGEMVTKERNEL = gemv_t_sve_v4x3.c

kernel/arm64/KERNEL.NEOVERSEV1

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
include $(KERNELDIR)/KERNEL.ARMV8SVE
22

3+
SGEMVNKERNEL = gemv_n_sve_v1x3.c
4+
DGEMVNKERNEL = gemv_n_sve_v1x3.c
35
SGEMVTKERNEL = gemv_t_sve_v1x3.c
46
DGEMVTKERNEL = gemv_t_sve_v1x3.c
57
ifeq ($(BUILD_BFLOAT16), 1)

kernel/arm64/gemv_n_sve_v1x3.c

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
/***************************************************************************
2+
Copyright (c) 2025, The OpenBLAS Project
3+
All rights reserved.
4+
5+
Redistribution and use in source and binary forms, with or without
6+
modification, are permitted provided that the following conditions are
7+
met:
8+
9+
1. Redistributions of source code must retain the above copyright
10+
notice, this list of conditions and the following disclaimer.
11+
12+
2. Redistributions in binary form must reproduce the above copyright
13+
notice, this list of conditions and the following disclaimer in
14+
the documentation and/or other materials provided with the
15+
distribution.
16+
3. Neither the name of the OpenBLAS project nor the names of
17+
its contributors may be used to endorse or promote products
18+
derived from this software without specific prior written
19+
permission.
20+
21+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24+
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
25+
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
30+
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31+
*****************************************************************************/
32+
33+
#include <arm_sve.h>
34+
35+
#include "common.h"
36+
37+
/*
 * Precision-dependent aliases for the SVE intrinsics used by this kernel.
 * Building with DOUBLE defined selects the 64-bit element variants;
 * otherwise the 32-bit element variants are used.
 */
#if defined(DOUBLE)
#define SV_TYPE svfloat64_t       /* vector of 64-bit floats */
#define SV_DUP svdup_f64          /* broadcast a scalar to every lane */
#define SV_COUNT svcntd           /* number of 64-bit lanes per vector */
#define SV_TRUE svptrue_b64       /* predicate with every 64-bit lane active */
#define SV_WHILE svwhilelt_b64_s64 /* lanes active while index < bound */
#else
#define SV_TYPE svfloat32_t       /* vector of 32-bit floats */
#define SV_DUP svdup_f32          /* broadcast a scalar to every lane */
#define SV_COUNT svcntw           /* number of 32-bit lanes per vector */
#define SV_TRUE svptrue_b32       /* predicate with every 32-bit lane active */
#define SV_WHILE svwhilelt_b32_s64 /* lanes active while index < bound */
#endif
50+
51+
/*
 * Non-transposed GEMV update: y += alpha * A * x.
 *
 * Element (i, j) of A is read from a[i + j * lda], element j of x from
 * x[j * inc_x], and element i of y is updated at y[i * inc_y].
 *
 *   m, n       number of rows / columns of A that are processed
 *   dummy1     unused
 *   alpha      scalar multiplied into every x element
 *   a, lda     matrix storage and the distance between consecutive columns
 *   x, inc_x   input vector and its stride
 *   y, inc_y   vector that is accumulated into, and its stride
 *   buffer     unused
 *
 * Always returns 0.
 *
 * When inc_y == 1 an SVE path is used: the columns are split into three
 * groups of `width` consecutive columns each, and every pass over y
 * accumulates one column from each group (one vector of rows at a time).
 * Any other inc_y falls back to a plain scalar loop.
 */
int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a,
          BLASLONG lda, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y,
          FLOAT *buffer)
{
  BLASLONG i;
  BLASLONG ix, iy;
  BLASLONG j;
  FLOAT *a_ptr;
  FLOAT temp;

  ix = 0;
  a_ptr = a;

  if (inc_y == 1) {
    /* Columns per group, rounded up so that 3 * width >= n. */
    BLASLONG width = (n + 3 - 1) / 3;
    BLASLONG sve_size = SV_COUNT();

    /* First column of each group. */
    FLOAT *a0_ptr = a_ptr + lda * width * 0;
    FLOAT *a1_ptr = a_ptr + lda * width * 1;
    FLOAT *a2_ptr = a_ptr + lda * width * 2;

    /*
     * First x element of each group.  Because width is rounded up, the
     * later groups may run past column n - 1; these pointers must only be
     * dereferenced for columns that are known to be in range.
     */
    FLOAT *x0_ptr = x + inc_x * width * 0;
    FLOAT *x1_ptr = x + inc_x * width * 1;
    FLOAT *x2_ptr = x + inc_x * width * 2;

    for (j = 0; j < width; j++) {
      /* Does column j of each group exist? */
      const int has_col0 = (j + width * 0) < n;
      const int has_col1 = (j + width * 1) < n;
      const int has_col2 = (j + width * 2) < n;

      /* All-true or all-false predicate per column group. */
      svbool_t pg00 = has_col0 ? SV_TRUE() : svpfalse();
      svbool_t pg01 = has_col1 ? SV_TRUE() : svpfalse();
      svbool_t pg02 = has_col2 ? SV_TRUE() : svpfalse();

      /*
       * Broadcast alpha * x[col].  x is only read when the column exists:
       * an unconditional read would access memory past the end of x
       * whenever n is not a multiple of 3.  The value used for a missing
       * column is irrelevant because every operation that consumes it is
       * predicated off.
       */
      SV_TYPE temp0_vec = SV_DUP(has_col0 ? alpha * x0_ptr[ix] : (FLOAT)0);
      SV_TYPE temp1_vec = SV_DUP(has_col1 ? alpha * x1_ptr[ix] : (FLOAT)0);
      SV_TYPE temp2_vec = SV_DUP(has_col2 ? alpha * x2_ptr[ix] : (FLOAT)0);

      /* Full vectors of rows. */
      i = 0;
      while ((i + sve_size * 1 - 1) < m) {
        SV_TYPE y0_vec = svld1_vnum(SV_TRUE(), y + i, 0);

        SV_TYPE a00_vec = svld1_vnum(pg00, a0_ptr + i, 0);
        SV_TYPE a01_vec = svld1_vnum(pg01, a1_ptr + i, 0);
        SV_TYPE a02_vec = svld1_vnum(pg02, a2_ptr + i, 0);

        /* Merging form: lanes with a false predicate keep y0_vec. */
        y0_vec = svmla_m(pg00, y0_vec, temp0_vec, a00_vec);
        y0_vec = svmla_m(pg01, y0_vec, temp1_vec, a01_vec);
        y0_vec = svmla_m(pg02, y0_vec, temp2_vec, a02_vec);

        svst1_vnum(SV_TRUE(), y + i, 0, y0_vec);
        i += sve_size * 1;
      }

      /* Remaining rows that do not fill a whole vector. */
      if (i < m) {
        svbool_t pg0 = SV_WHILE(i + sve_size * 0, m);

        /* Active only where both the row and the column exist. */
        pg00 = svand_z(SV_TRUE(), pg0, pg00);
        pg01 = svand_z(SV_TRUE(), pg0, pg01);
        pg02 = svand_z(SV_TRUE(), pg0, pg02);

        SV_TYPE y0_vec = svld1_vnum(pg0, y + i, 0);

        SV_TYPE a00_vec = svld1_vnum(pg00, a0_ptr + i, 0);
        SV_TYPE a01_vec = svld1_vnum(pg01, a1_ptr + i, 0);
        SV_TYPE a02_vec = svld1_vnum(pg02, a2_ptr + i, 0);

        y0_vec = svmla_m(pg00, y0_vec, temp0_vec, a00_vec);
        y0_vec = svmla_m(pg01, y0_vec, temp1_vec, a01_vec);
        y0_vec = svmla_m(pg02, y0_vec, temp2_vec, a02_vec);

        svst1_vnum(pg0, y + i, 0, y0_vec);
      }

      /* Advance every group to its next column. */
      a0_ptr += lda;
      a1_ptr += lda;
      a2_ptr += lda;
      ix += inc_x;
    }
    return (0);
  }

  /* Scalar fallback for a strided y. */
  for (j = 0; j < n; j++) {
    temp = alpha * x[ix];
    iy = 0;
    for (i = 0; i < m; i++) {
      y[iy] += temp * a_ptr[i];
      iy += inc_y;
    }
    a_ptr += lda;
    ix += inc_x;
  }
  return (0);
}

kernel/arm64/gemv_n_sve_v4x3.c

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
/***************************************************************************
2+
Copyright (c) 2025, The OpenBLAS Project
3+
All rights reserved.
4+
5+
Redistribution and use in source and binary forms, with or without
6+
modification, are permitted provided that the following conditions are
7+
met:
8+
9+
1. Redistributions of source code must retain the above copyright
10+
notice, this list of conditions and the following disclaimer.
11+
12+
2. Redistributions in binary form must reproduce the above copyright
13+
notice, this list of conditions and the following disclaimer in
14+
the documentation and/or other materials provided with the
15+
distribution.
16+
3. Neither the name of the OpenBLAS project nor the names of
17+
its contributors may be used to endorse or promote products
18+
derived from this software without specific prior written
19+
permission.
20+
21+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24+
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
25+
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
30+
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31+
*****************************************************************************/
32+
33+
#include <arm_sve.h>
34+
35+
#include "common.h"
36+
37+
/*
 * Precision-dependent aliases for the SVE intrinsics used by this kernel.
 * Building with DOUBLE defined selects the 64-bit element variants;
 * otherwise the 32-bit element variants are used.
 */
#if defined(DOUBLE)
#define SV_TYPE svfloat64_t       /* vector of 64-bit floats */
#define SV_DUP svdup_f64          /* broadcast a scalar to every lane */
#define SV_COUNT svcntd           /* number of 64-bit lanes per vector */
#define SV_TRUE svptrue_b64       /* predicate with every 64-bit lane active */
#define SV_WHILE svwhilelt_b64_s64 /* lanes active while index < bound */
#else
#define SV_TYPE svfloat32_t       /* vector of 32-bit floats */
#define SV_DUP svdup_f32          /* broadcast a scalar to every lane */
#define SV_COUNT svcntw           /* number of 32-bit lanes per vector */
#define SV_TRUE svptrue_b32       /* predicate with every 32-bit lane active */
#define SV_WHILE svwhilelt_b32_s64 /* lanes active while index < bound */
#endif
50+
51+
/*
 * Non-transposed GEMV update: y += alpha * A * x.
 *
 * Element (i, j) of A is read from a[i + j * lda], element j of x from
 * x[j * inc_x], and element i of y is updated at y[i * inc_y].
 *
 *   m, n       number of rows / columns of A that are processed
 *   dummy1     unused
 *   alpha      scalar multiplied into every x element
 *   a, lda     matrix storage and the distance between consecutive columns
 *   x, inc_x   input vector and its stride
 *   y, inc_y   vector that is accumulated into, and its stride
 *   buffer     unused
 *
 * Always returns 0.
 *
 * When inc_y == 1 an SVE path is used: the columns are split into three
 * groups of `width` consecutive columns each, and every pass over y
 * accumulates one column from each group, four vectors of rows at a time.
 * Any other inc_y falls back to a plain scalar loop.
 *
 * Vector/predicate naming in the SVE path: in aRC_vec and pgRC, R is the
 * index (0..3) of the row vector inside the 4-vector step and C is the
 * column group (0..2).
 */
int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a,
          BLASLONG lda, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y,
          FLOAT *buffer)
{
  BLASLONG i;
  BLASLONG ix, iy;
  BLASLONG j;
  FLOAT *a_ptr;
  FLOAT temp;

  ix = 0;
  a_ptr = a;

  if (inc_y == 1) {
    /* Columns per group, rounded up so that 3 * width >= n. */
    BLASLONG width = (n + 3 - 1) / 3;
    BLASLONG sve_size = SV_COUNT();

    /* First column of each group. */
    FLOAT *a0_ptr = a_ptr + lda * width * 0;
    FLOAT *a1_ptr = a_ptr + lda * width * 1;
    FLOAT *a2_ptr = a_ptr + lda * width * 2;

    /*
     * First x element of each group.  Because width is rounded up, the
     * later groups may run past column n - 1; these pointers must only be
     * dereferenced for columns that are known to be in range.
     */
    FLOAT *x0_ptr = x + inc_x * width * 0;
    FLOAT *x1_ptr = x + inc_x * width * 1;
    FLOAT *x2_ptr = x + inc_x * width * 2;

    for (j = 0; j < width; j++) {
      /* Does column j of each group exist? */
      const int has_col0 = (j + width * 0) < n;
      const int has_col1 = (j + width * 1) < n;
      const int has_col2 = (j + width * 2) < n;

      /*
       * All-true or all-false predicate per column group.  In the main
       * loop every row vector of a group uses the same predicate.
       */
      svbool_t pc0 = has_col0 ? SV_TRUE() : svpfalse();
      svbool_t pc1 = has_col1 ? SV_TRUE() : svpfalse();
      svbool_t pc2 = has_col2 ? SV_TRUE() : svpfalse();

      /*
       * Broadcast alpha * x[col].  x is only read when the column exists:
       * an unconditional read would access memory past the end of x
       * whenever n is not a multiple of 3.  The value used for a missing
       * column is irrelevant because every operation that consumes it is
       * predicated off.
       */
      SV_TYPE temp0_vec = SV_DUP(has_col0 ? alpha * x0_ptr[ix] : (FLOAT)0);
      SV_TYPE temp1_vec = SV_DUP(has_col1 ? alpha * x1_ptr[ix] : (FLOAT)0);
      SV_TYPE temp2_vec = SV_DUP(has_col2 ? alpha * x2_ptr[ix] : (FLOAT)0);

      /* Steps of four full vectors of rows. */
      i = 0;
      while ((i + sve_size * 4 - 1) < m) {
        SV_TYPE y0_vec = svld1_vnum(SV_TRUE(), y + i, 0);
        SV_TYPE y1_vec = svld1_vnum(SV_TRUE(), y + i, 1);
        SV_TYPE y2_vec = svld1_vnum(SV_TRUE(), y + i, 2);
        SV_TYPE y3_vec = svld1_vnum(SV_TRUE(), y + i, 3);

        SV_TYPE a00_vec = svld1_vnum(pc0, a0_ptr + i, 0);
        SV_TYPE a10_vec = svld1_vnum(pc0, a0_ptr + i, 1);
        SV_TYPE a20_vec = svld1_vnum(pc0, a0_ptr + i, 2);
        SV_TYPE a30_vec = svld1_vnum(pc0, a0_ptr + i, 3);
        SV_TYPE a01_vec = svld1_vnum(pc1, a1_ptr + i, 0);
        SV_TYPE a11_vec = svld1_vnum(pc1, a1_ptr + i, 1);
        SV_TYPE a21_vec = svld1_vnum(pc1, a1_ptr + i, 2);
        SV_TYPE a31_vec = svld1_vnum(pc1, a1_ptr + i, 3);
        SV_TYPE a02_vec = svld1_vnum(pc2, a2_ptr + i, 0);
        SV_TYPE a12_vec = svld1_vnum(pc2, a2_ptr + i, 1);
        SV_TYPE a22_vec = svld1_vnum(pc2, a2_ptr + i, 2);
        SV_TYPE a32_vec = svld1_vnum(pc2, a2_ptr + i, 3);

        /* Merging form: lanes with a false predicate keep the y value. */
        y0_vec = svmla_m(pc0, y0_vec, temp0_vec, a00_vec);
        y1_vec = svmla_m(pc0, y1_vec, temp0_vec, a10_vec);
        y2_vec = svmla_m(pc0, y2_vec, temp0_vec, a20_vec);
        y3_vec = svmla_m(pc0, y3_vec, temp0_vec, a30_vec);
        y0_vec = svmla_m(pc1, y0_vec, temp1_vec, a01_vec);
        y1_vec = svmla_m(pc1, y1_vec, temp1_vec, a11_vec);
        y2_vec = svmla_m(pc1, y2_vec, temp1_vec, a21_vec);
        y3_vec = svmla_m(pc1, y3_vec, temp1_vec, a31_vec);
        y0_vec = svmla_m(pc2, y0_vec, temp2_vec, a02_vec);
        y1_vec = svmla_m(pc2, y1_vec, temp2_vec, a12_vec);
        y2_vec = svmla_m(pc2, y2_vec, temp2_vec, a22_vec);
        y3_vec = svmla_m(pc2, y3_vec, temp2_vec, a32_vec);

        svst1_vnum(SV_TRUE(), y + i, 0, y0_vec);
        svst1_vnum(SV_TRUE(), y + i, 1, y1_vec);
        svst1_vnum(SV_TRUE(), y + i, 2, y2_vec);
        svst1_vnum(SV_TRUE(), y + i, 3, y3_vec);
        i += sve_size * 4;
      }

      /* Remaining rows: fewer than four full vectors are left. */
      if (i < m) {
        /* Row predicates for each of the four vectors of the last step. */
        svbool_t pg0 = SV_WHILE(i + sve_size * 0, m);
        svbool_t pg1 = SV_WHILE(i + sve_size * 1, m);
        svbool_t pg2 = SV_WHILE(i + sve_size * 2, m);
        svbool_t pg3 = SV_WHILE(i + sve_size * 3, m);

        /* Active only where both the row and the column exist. */
        svbool_t pg00 = svand_z(SV_TRUE(), pg0, pc0);
        svbool_t pg10 = svand_z(SV_TRUE(), pg1, pc0);
        svbool_t pg20 = svand_z(SV_TRUE(), pg2, pc0);
        svbool_t pg30 = svand_z(SV_TRUE(), pg3, pc0);
        svbool_t pg01 = svand_z(SV_TRUE(), pg0, pc1);
        svbool_t pg11 = svand_z(SV_TRUE(), pg1, pc1);
        svbool_t pg21 = svand_z(SV_TRUE(), pg2, pc1);
        svbool_t pg31 = svand_z(SV_TRUE(), pg3, pc1);
        svbool_t pg02 = svand_z(SV_TRUE(), pg0, pc2);
        svbool_t pg12 = svand_z(SV_TRUE(), pg1, pc2);
        svbool_t pg22 = svand_z(SV_TRUE(), pg2, pc2);
        svbool_t pg32 = svand_z(SV_TRUE(), pg3, pc2);

        SV_TYPE y0_vec = svld1_vnum(pg0, y + i, 0);
        SV_TYPE y1_vec = svld1_vnum(pg1, y + i, 1);
        SV_TYPE y2_vec = svld1_vnum(pg2, y + i, 2);
        SV_TYPE y3_vec = svld1_vnum(pg3, y + i, 3);

        SV_TYPE a00_vec = svld1_vnum(pg00, a0_ptr + i, 0);
        SV_TYPE a10_vec = svld1_vnum(pg10, a0_ptr + i, 1);
        SV_TYPE a20_vec = svld1_vnum(pg20, a0_ptr + i, 2);
        SV_TYPE a30_vec = svld1_vnum(pg30, a0_ptr + i, 3);
        SV_TYPE a01_vec = svld1_vnum(pg01, a1_ptr + i, 0);
        SV_TYPE a11_vec = svld1_vnum(pg11, a1_ptr + i, 1);
        SV_TYPE a21_vec = svld1_vnum(pg21, a1_ptr + i, 2);
        SV_TYPE a31_vec = svld1_vnum(pg31, a1_ptr + i, 3);
        SV_TYPE a02_vec = svld1_vnum(pg02, a2_ptr + i, 0);
        SV_TYPE a12_vec = svld1_vnum(pg12, a2_ptr + i, 1);
        SV_TYPE a22_vec = svld1_vnum(pg22, a2_ptr + i, 2);
        SV_TYPE a32_vec = svld1_vnum(pg32, a2_ptr + i, 3);

        y0_vec = svmla_m(pg00, y0_vec, temp0_vec, a00_vec);
        y1_vec = svmla_m(pg10, y1_vec, temp0_vec, a10_vec);
        y2_vec = svmla_m(pg20, y2_vec, temp0_vec, a20_vec);
        y3_vec = svmla_m(pg30, y3_vec, temp0_vec, a30_vec);
        y0_vec = svmla_m(pg01, y0_vec, temp1_vec, a01_vec);
        y1_vec = svmla_m(pg11, y1_vec, temp1_vec, a11_vec);
        y2_vec = svmla_m(pg21, y2_vec, temp1_vec, a21_vec);
        y3_vec = svmla_m(pg31, y3_vec, temp1_vec, a31_vec);
        y0_vec = svmla_m(pg02, y0_vec, temp2_vec, a02_vec);
        y1_vec = svmla_m(pg12, y1_vec, temp2_vec, a12_vec);
        y2_vec = svmla_m(pg22, y2_vec, temp2_vec, a22_vec);
        y3_vec = svmla_m(pg32, y3_vec, temp2_vec, a32_vec);

        svst1_vnum(pg0, y + i, 0, y0_vec);
        svst1_vnum(pg1, y + i, 1, y1_vec);
        svst1_vnum(pg2, y + i, 2, y2_vec);
        svst1_vnum(pg3, y + i, 3, y3_vec);
      }

      /* Advance every group to its next column. */
      a0_ptr += lda;
      a1_ptr += lda;
      a2_ptr += lda;
      ix += inc_x;
    }
    return (0);
  }

  /* Scalar fallback for a strided y. */
  for (j = 0; j < n; j++) {
    temp = alpha * x[ix];
    iy = 0;
    for (i = 0; i < m; i++) {
      y[iy] += temp * a_ptr[i];
      iy += inc_y;
    }
    a_ptr += lda;
    ix += inc_x;
  }
  return (0);
}

0 commit comments

Comments
 (0)