Skip to content

Commit 4e6ac98

Browse files
committed
Optimize gemv_n_sve kernel
1 parent ef9e3f7 commit 4e6ac98

File tree

2 files changed

+74
-14
lines changed

2 files changed

+74
-14
lines changed

kernel/arm64/KERNEL.ARMV8SVE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ DSCALKERNEL = scal.S
7474
CSCALKERNEL = zscal.S
7575
ZSCALKERNEL = zscal.S
7676

77-
SGEMVNKERNEL = gemv_n.S
77+
SGEMVNKERNEL = gemv_n_sve.c
7878
DGEMVNKERNEL = gemv_n.S
7979
CGEMVNKERNEL = zgemv_n.S
8080
ZGEMVNKERNEL = zgemv_n.S

kernel/arm64/gemv_n_sve.c

Lines changed: 73 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/***************************************************************************
2-
Copyright (c) 2024, The OpenBLAS Project
2+
Copyright (c) 2024-2025, The OpenBLAS Project
33
All rights reserved.
44
55
Redistribution and use in source and binary forms, with or without
@@ -57,25 +57,85 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
5757

5858
ix = 0;
5959
a_ptr = a;
60-
6160
if (inc_y == 1) {
61+
62+
BLASLONG width = n / 3;
6263
uint64_t sve_size = SV_COUNT();
63-
for (j = 0; j < n; j++) {
64-
SV_TYPE temp_vec = SV_DUP(alpha * x[ix]);
65-
i = 0;
66-
svbool_t pg = SV_WHILE(i, m);
67-
while (svptest_any(SV_TRUE(), pg)) {
68-
SV_TYPE a_vec = svld1(pg, a_ptr + i);
64+
svbool_t pg_true = SV_TRUE();
65+
svbool_t pg = SV_WHILE(0, m % sve_size);
66+
67+
FLOAT *a0_ptr = a + lda * width * 0;
68+
FLOAT *a1_ptr = a + lda * width * 1;
69+
FLOAT *a2_ptr = a + lda * width * 2;
70+
71+
for (j = 0; j < width; j++) {
72+
for (i = 0; (i + sve_size - 1) < m; i += sve_size) {
73+
ix = j * inc_x;
74+
75+
SV_TYPE x0_vec = SV_DUP(alpha * x[ix + (inc_x * width * 0)]);
76+
SV_TYPE x1_vec = SV_DUP(alpha * x[ix + (inc_x * width * 1)]);
77+
SV_TYPE x2_vec = SV_DUP(alpha * x[ix + (inc_x * width * 2)]);
78+
79+
SV_TYPE a00_vec = svld1(pg_true, a0_ptr + i);
80+
SV_TYPE a01_vec = svld1(pg_true, a1_ptr + i);
81+
SV_TYPE a02_vec = svld1(pg_true, a2_ptr + i);
82+
83+
SV_TYPE y_vec = svld1(pg_true, y + i);
84+
y_vec = svmla_lane(y_vec, a00_vec, x0_vec, 0);
85+
y_vec = svmla_lane(y_vec, a01_vec, x1_vec, 0);
86+
y_vec = svmla_lane(y_vec, a02_vec, x2_vec, 0);
87+
88+
svst1(pg_true, y + i, y_vec);
89+
}
90+
91+
if (i < m) {
92+
SV_TYPE x0_vec = SV_DUP(alpha * x[ix + (inc_x * width * 0)]);
93+
SV_TYPE x1_vec = SV_DUP(alpha * x[ix + (inc_x * width * 1)]);
94+
SV_TYPE x2_vec = SV_DUP(alpha * x[ix + (inc_x * width * 2)]);
95+
96+
SV_TYPE a00_vec = svld1(pg, a0_ptr + i);
97+
SV_TYPE a01_vec = svld1(pg, a1_ptr + i);
98+
SV_TYPE a02_vec = svld1(pg, a2_ptr + i);
99+
69100
SV_TYPE y_vec = svld1(pg, y + i);
70-
y_vec = svmla_x(pg, y_vec, temp_vec, a_vec);
101+
y_vec = svmla_m(pg, y_vec, a00_vec, x0_vec);
102+
y_vec = svmla_m(pg, y_vec, a01_vec, x1_vec);
103+
y_vec = svmla_m(pg, y_vec, a02_vec, x2_vec);
104+
105+
ix += inc_x;
106+
71107
svst1(pg, y + i, y_vec);
72-
i += sve_size;
73-
pg = SV_WHILE(i, m);
74108
}
109+
110+
a0_ptr += lda;
111+
a1_ptr += lda;
112+
a2_ptr += lda;
113+
}
114+
115+
a_ptr = a2_ptr;
116+
for (j = width * 3; j < n; j++) {
117+
ix = j * inc_x;
118+
for (i = 0; (i + sve_size - 1) < m; i += sve_size) {
119+
SV_TYPE y_vec = svld1(pg_true, y + i);
120+
SV_TYPE x_vec = SV_DUP(alpha * x[(ix)]);
121+
SV_TYPE a_vec = svld1(pg_true, a_ptr + i);
122+
y_vec = svmla_x(pg_true, y_vec, a_vec, x_vec);
123+
svst1(pg_true, y + i, y_vec);
124+
}
125+
126+
if (i < m) {
127+
SV_TYPE y_vec = svld1(pg, y + i);
128+
SV_TYPE x_vec = SV_DUP(alpha * x[(ix)]);
129+
SV_TYPE a_vec = svld1(pg, a_ptr + i);
130+
y_vec = svmla_m(pg, y_vec, a_vec, x_vec);
131+
svst1(pg, y + i, y_vec);
132+
}
133+
75134
a_ptr += lda;
76135
ix += inc_x;
77136
}
78-
return(0);
137+
138+
return (0);
79139
}
80140

81141
for (j = 0; j < n; j++) {
@@ -89,4 +149,4 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
89149
ix += inc_x;
90150
}
91151
return (0);
92-
}
152+
}

0 commit comments

Comments
 (0)