diff --git a/kernel/arm64/KERNEL.ARMV8SVE b/kernel/arm64/KERNEL.ARMV8SVE index dc58e329fc..9adacce632 100644 --- a/kernel/arm64/KERNEL.ARMV8SVE +++ b/kernel/arm64/KERNEL.ARMV8SVE @@ -74,7 +74,7 @@ DSCALKERNEL = scal.S CSCALKERNEL = zscal.S ZSCALKERNEL = zscal.S -SGEMVNKERNEL = gemv_n.S +SGEMVNKERNEL = gemv_n_sve.c DGEMVNKERNEL = gemv_n.S CGEMVNKERNEL = zgemv_n.S ZGEMVNKERNEL = zgemv_n.S diff --git a/kernel/arm64/gemv_n_sve.c b/kernel/arm64/gemv_n_sve.c index 2950555615..59a5c85572 100644 --- a/kernel/arm64/gemv_n_sve.c +++ b/kernel/arm64/gemv_n_sve.c @@ -1,5 +1,5 @@ /*************************************************************************** -Copyright (c) 2024, The OpenBLAS Project +Copyright (c) 2024-2025, The OpenBLAS Project All rights reserved. Redistribution and use in source and binary forms, with or without @@ -59,23 +59,82 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO a_ptr = a; if (inc_y == 1) { + BLASLONG width = n / 3; uint64_t sve_size = SV_COUNT(); - for (j = 0; j < n; j++) { - SV_TYPE temp_vec = SV_DUP(alpha * x[ix]); - i = 0; - svbool_t pg = SV_WHILE(i, m); - while (svptest_any(SV_TRUE(), pg)) { - SV_TYPE a_vec = svld1(pg, a_ptr + i); + svbool_t pg_true = SV_TRUE(); + svbool_t pg = SV_WHILE(0, m % sve_size); + + FLOAT *a0_ptr = a + lda * width * 0; + FLOAT *a1_ptr = a + lda * width * 1; + FLOAT *a2_ptr = a + lda * width * 2; + + for (j = 0; j < width; j++) { + for (i = 0; (i + sve_size - 1) < m; i += sve_size) { + ix = j * inc_x; + + SV_TYPE x0_vec = SV_DUP(alpha * x[ix + (inc_x * width * 0)]); + SV_TYPE x1_vec = SV_DUP(alpha * x[ix + (inc_x * width * 1)]); + SV_TYPE x2_vec = SV_DUP(alpha * x[ix + (inc_x * width * 2)]); + + SV_TYPE a00_vec = svld1(pg_true, a0_ptr + i); + SV_TYPE a01_vec = svld1(pg_true, a1_ptr + i); + SV_TYPE a02_vec = svld1(pg_true, a2_ptr + i); + + SV_TYPE y_vec = svld1(pg_true, y + i); + y_vec = svmla_lane(y_vec, a00_vec, x0_vec, 0); + y_vec = svmla_lane(y_vec, a01_vec, x1_vec, 0); + y_vec = svmla_lane(y_vec, a02_vec, x2_vec, 0); + + svst1(pg_true, y + i, y_vec); + } + + if (i < m) { + SV_TYPE x0_vec = SV_DUP(alpha * x[ix + (inc_x * width * 0)]); + SV_TYPE x1_vec = SV_DUP(alpha * x[ix + (inc_x * width * 1)]); + SV_TYPE x2_vec = SV_DUP(alpha * x[ix + (inc_x * width * 2)]); + + SV_TYPE a00_vec = svld1(pg, a0_ptr + i); + SV_TYPE a01_vec = svld1(pg, a1_ptr + i); + SV_TYPE a02_vec = svld1(pg, a2_ptr + i); + SV_TYPE y_vec = svld1(pg, y + i); - y_vec = svmla_x(pg, y_vec, temp_vec, a_vec); + y_vec = svmla_m(pg, y_vec, a00_vec, x0_vec); + y_vec = svmla_m(pg, y_vec, a01_vec, x1_vec); + y_vec = svmla_m(pg, y_vec, a02_vec, x2_vec); + + ix += inc_x; + svst1(pg, y + i, y_vec); - i += sve_size; - pg = SV_WHILE(i, m); } + + a0_ptr += lda; + a1_ptr += lda; + a2_ptr += lda; + } + + a_ptr = a2_ptr; + for (j = width * 3; j < n; j++) { + ix = j * inc_x; + for (i = 0; (i + sve_size - 1) < m; i += sve_size) { + SV_TYPE y_vec = svld1(pg_true, y + i); + SV_TYPE x_vec = SV_DUP(alpha * x[(ix)]); + SV_TYPE a_vec = svld1(pg_true, a_ptr + i); + y_vec = svmla_x(pg_true, y_vec, a_vec, x_vec); + svst1(pg_true, y + i, y_vec); + } + + if (i < m) { + SV_TYPE y_vec = svld1(pg, y + i); + SV_TYPE x_vec = SV_DUP(alpha * x[(ix)]); + SV_TYPE a_vec = svld1(pg, a_ptr + i); + y_vec = svmla_m(pg, y_vec, a_vec, x_vec); + svst1(pg, y + i, y_vec); + } + a_ptr += lda; ix += inc_x; } - return(0); + return (0); } for (j = 0; j < n; j++) { @@ -89,4 +148,4 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO ix += inc_x; } return (0); -} +} \ No newline at end of file