From edaf51dd99bb979f15fa4f2774ba068cfec0c09e Mon Sep 17 00:00:00 2001
From: Annop Wongwathanarat
Date: Wed, 26 Feb 2025 12:47:11 +0000
Subject: [PATCH] Add sbgemv_t_bfdot kernel for ARM64

This improves performance for sbgemv_t by up to 100x on NEOVERSEV1.
The geometric mean speedup is ~61x for M=N=[2,512].
---
 CONTRIBUTORS.md                |   1 +
 kernel/arm64/KERNEL.NEOVERSEN2 |   1 +
 kernel/arm64/KERNEL.NEOVERSEV1 |   1 +
 kernel/arm64/KERNEL.NEOVERSEV2 |   4 +
 kernel/arm64/sbgemv_t_bfdot.c  | 207 +++++++++++++++++++++++++++++++++
 5 files changed, 214 insertions(+)
 create mode 100644 kernel/arm64/sbgemv_t_bfdot.c

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 99166f5203..9ce5e37de3 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -236,6 +236,7 @@ In chronological order:
 * Annop Wongwathanarat
   * [2025-01-10] Add thread throttling profile for SGEMM on NEOVERSEV1
   * [2025-01-21] Optimize gemv_t_sve_v1x3 kernel
+  * [2025-02-26] Add sbgemv_t_bfdot kernel
 
 * Marek Michalowski
   * [2025-01-21] Add thread throttling profile for SGEMV on `NEOVERSEV1`
diff --git a/kernel/arm64/KERNEL.NEOVERSEN2 b/kernel/arm64/KERNEL.NEOVERSEN2
index 2f7400113b..e4e1cfde31 100644
--- a/kernel/arm64/KERNEL.NEOVERSEN2
+++ b/kernel/arm64/KERNEL.NEOVERSEN2
@@ -198,3 +198,4 @@ SBGEMMINCOPYOBJ = sbgemm_incopy$(TSUFFIX).$(SUFFIX)
 SBGEMMITCOPYOBJ = sbgemm_itcopy$(TSUFFIX).$(SUFFIX)
 SBGEMMONCOPYOBJ = sbgemm_oncopy$(TSUFFIX).$(SUFFIX)
 SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX)
+SBGEMVTKERNEL = sbgemv_t_bfdot.c
\ No newline at end of file
diff --git a/kernel/arm64/KERNEL.NEOVERSEV1 b/kernel/arm64/KERNEL.NEOVERSEV1
index 8845e6860a..374acb35b8 100644
--- a/kernel/arm64/KERNEL.NEOVERSEV1
+++ b/kernel/arm64/KERNEL.NEOVERSEV1
@@ -15,4 +15,5 @@ SBGEMMONCOPY = sbgemm_ncopy_$(SBGEMM_UNROLL_N)_neoversev1.c
 SBGEMMOTCOPY = sbgemm_tcopy_$(SBGEMM_UNROLL_N)_neoversev1.c
 SBGEMMONCOPYOBJ = sbgemm_oncopy$(TSUFFIX).$(SUFFIX)
 SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX)
+SBGEMVTKERNEL = sbgemv_t_bfdot.c
 endif
\ No newline at end of file
diff --git a/kernel/arm64/KERNEL.NEOVERSEV2 b/kernel/arm64/KERNEL.NEOVERSEV2
index bc59990979..4d866f8584 100644
--- a/kernel/arm64/KERNEL.NEOVERSEV2
+++ b/kernel/arm64/KERNEL.NEOVERSEV2
@@ -1 +1,5 @@
 include $(KERNELDIR)/KERNEL.ARMV8SVE
+
+ifeq ($(BUILD_BFLOAT16), 1)
+SBGEMVTKERNEL = sbgemv_t_bfdot.c
+endif
\ No newline at end of file
diff --git a/kernel/arm64/sbgemv_t_bfdot.c b/kernel/arm64/sbgemv_t_bfdot.c
new file mode 100644
index 0000000000..0751690fcd
--- /dev/null
+++ b/kernel/arm64/sbgemv_t_bfdot.c
@@ -0,0 +1,207 @@
+/***************************************************************************
+Copyright (c) 2025, The OpenBLAS Project
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+   1. Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+
+   2. Redistributions in binary form must reproduce the above copyright
+      notice, this list of conditions and the following disclaimer in
+      the documentation and/or other materials provided with the
+      distribution.
+   3. Neither the name of the OpenBLAS project nor the names of
+      its contributors may be used to endorse or promote products
+      derived from this software without specific prior written
+      permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+#include <arm_neon.h>
+#include "common.h"
+
+static inline float bf16_to_fp32(bfloat16 bf16) {
+  uint32_t fp32 = (uint32_t)bf16 << 16;
+  return *((float*)&fp32);
+}
+
+int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float *y, BLASLONG incy)
+{
+  if (m < 1 || n < 1) return(0);
+  BLASLONG i;
+  BLASLONG ix,iy;
+  BLASLONG j;
+  bfloat16_t *a_ptr;
+  bfloat16_t *x_ptr;
+  float *y_ptr;
+  float temp;
+
+  iy = 0;
+  a_ptr = (bfloat16_t*)(a);
+  x_ptr = (bfloat16_t*)(x);
+
+  if (incx == 1) {
+    BLASLONG width = n / 4;
+
+    bfloat16_t *a0_ptr = a_ptr + lda * width * 0;
+    bfloat16_t *a1_ptr = a_ptr + lda * width * 1;
+    bfloat16_t *a2_ptr = a_ptr + lda * width * 2;
+    bfloat16_t *a3_ptr = a_ptr + lda * width * 3;
+
+    float *y0_ptr = y + incy * width * 0;
+    float *y1_ptr = y + incy * width * 1;
+    float *y2_ptr = y + incy * width * 2;
+    float *y3_ptr = y + incy * width * 3;
+
+    for (j = 0; j < width; j++) {
+      float32x4_t temp0_vec = vdupq_n_f32(0.0f);
+      float32x4_t temp1_vec = vdupq_n_f32(0.0f);
+      float32x4_t temp2_vec = vdupq_n_f32(0.0f);
+      float32x4_t temp3_vec = vdupq_n_f32(0.0f);
+
+      i = 0;
+      while (i + 7 < m) {
+        bfloat16x8_t x_vec = vld1q_bf16(x_ptr + i);
+
+        bfloat16x8_t a0_vec = vld1q_bf16(a0_ptr + i);
+        bfloat16x8_t a1_vec = vld1q_bf16(a1_ptr + i);
+        bfloat16x8_t a2_vec = vld1q_bf16(a2_ptr + i);
+        bfloat16x8_t a3_vec = vld1q_bf16(a3_ptr + i);
+
+        temp0_vec = vbfdotq_f32(temp0_vec, a0_vec, x_vec);
+        temp1_vec = vbfdotq_f32(temp1_vec, a1_vec, x_vec);
+        temp2_vec = vbfdotq_f32(temp2_vec, a2_vec, x_vec);
+        temp3_vec = vbfdotq_f32(temp3_vec, a3_vec, x_vec);
+
+        i += 8;
+      }
+      if (i + 3 < m) {
+        float32x2_t t0 = vdup_n_f32(0.0f);
+        float32x2_t t1 = vdup_n_f32(0.0f);
+        float32x2_t t2 = vdup_n_f32(0.0f);
+        float32x2_t t3 = vdup_n_f32(0.0f);
+
+        bfloat16x4_t x_vec = vld1_bf16(x_ptr + i);
+
+        bfloat16x4_t a0_vec = vld1_bf16(a0_ptr + i);
+        bfloat16x4_t a1_vec = vld1_bf16(a1_ptr + i);
+        bfloat16x4_t a2_vec = vld1_bf16(a2_ptr + i);
+        bfloat16x4_t a3_vec = vld1_bf16(a3_ptr + i);
+
+        t0 = vbfdot_f32(t0, a0_vec, x_vec);
+        t1 = vbfdot_f32(t1, a1_vec, x_vec);
+        t2 = vbfdot_f32(t2, a2_vec, x_vec);
+        t3 = vbfdot_f32(t3, a3_vec, x_vec);
+
+        float32x2_t temp0_vec_low = vget_low_f32(temp0_vec);
+        float32x2_t temp1_vec_low = vget_low_f32(temp1_vec);
+        float32x2_t temp2_vec_low = vget_low_f32(temp2_vec);
+        float32x2_t temp3_vec_low = vget_low_f32(temp3_vec);
+
+        temp0_vec = vcombine_f32(vadd_f32(t0, temp0_vec_low), vget_high_f32(temp0_vec));
+        temp1_vec = vcombine_f32(vadd_f32(t1, temp1_vec_low), vget_high_f32(temp1_vec));
+        temp2_vec = vcombine_f32(vadd_f32(t2, temp2_vec_low), vget_high_f32(temp2_vec));
+        temp3_vec = vcombine_f32(vadd_f32(t3, temp3_vec_low), vget_high_f32(temp3_vec));
+
+        i += 4;
+      }
+      if (beta == 0.0f) {
+        y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec);
+        y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec);
+        y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec);
+        y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec);
+      }
+      else {
+        y0_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y0_ptr[iy];
+        y1_ptr[iy] = alpha * vaddvq_f32(temp1_vec) + beta * y1_ptr[iy];
+        y2_ptr[iy] = alpha * vaddvq_f32(temp2_vec) + beta * y2_ptr[iy];
+        y3_ptr[iy] = alpha * vaddvq_f32(temp3_vec) + beta * y3_ptr[iy];
+      }
+
+      for (; i < m; ++i) {
+        y0_ptr[iy] += alpha * a0_ptr[i] * x_ptr[i];
+        y1_ptr[iy] += alpha * a1_ptr[i] * x_ptr[i];
+        y2_ptr[iy] += alpha * a2_ptr[i] * x_ptr[i];
+        y3_ptr[iy] += alpha * a3_ptr[i] * x_ptr[i];
+      }
+
+      iy += incy;
+
+      a0_ptr += lda;
+      a1_ptr += lda;
+      a2_ptr += lda;
+      a3_ptr += lda;
+    }
+
+    a_ptr = a3_ptr;
+    y_ptr = y3_ptr;
+    for (j = width * 4; j < n; j++) {
+      float32x4_t temp0_vec = vdupq_n_f32(0.0f);
+      i = 0;
+      while (i + 7 < m) {
+        bfloat16x8_t x_vec = vld1q_bf16(x_ptr + i);
+        bfloat16x8_t a0_vec = vld1q_bf16(a_ptr + i);
+        temp0_vec = vbfdotq_f32(temp0_vec, a0_vec, x_vec);
+
+        i += 8;
+      }
+      if (i + 3 < m) {
+        float32x2_t t0 = vdup_n_f32(0.0f);
+        bfloat16x4_t x_vec = vld1_bf16(x_ptr + i);
+        bfloat16x4_t a0_vec = vld1_bf16(a_ptr + i);
+
+        t0 = vbfdot_f32(t0, a0_vec, x_vec);
+        float32x2_t temp0_vec_low = vget_low_f32(temp0_vec);
+        temp0_vec = vcombine_f32(vadd_f32(t0, temp0_vec_low), vget_high_f32(temp0_vec));
+
+        i += 4;
+      }
+      if (beta == 0.0f) {
+        y_ptr[iy] = alpha * vaddvq_f32(temp0_vec);
+      }
+      else {
+        y_ptr[iy] = alpha * vaddvq_f32(temp0_vec) + beta * y_ptr[iy];
+      }
+
+      for (; i < m; ++i) {
+        y_ptr[iy] += alpha * a_ptr[i] * x_ptr[i];
+      }
+
+      iy += incy;
+
+      a_ptr += lda;
+    }
+    return(0);
+  }
+
+  for (j = 0; j < n; j++) {
+    temp = 0.0;
+    ix = 0;
+    for (i = 0; i < m; i++) {
+      temp += bf16_to_fp32(a[i]) * bf16_to_fp32(x[ix]);
+      ix += incx;
+    }
+    if (beta == 0.0f) {
+      y[iy] = alpha * temp;
+    }
+    else {
+      y[iy] = alpha * temp + beta * y[iy];
+    }
+    iy += incy;
+    a += lda;
+  }
+  return (0);
+}
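
Not part of the patch: the sketch below is one way to smoke-test the new transposed path through the CBLAS interface. It assumes an OpenBLAS build with BUILD_BFLOAT16=1, so that cblas.h exposes the bfloat16 typedef and the cblas_sbgemv prototype; the f32_to_bf16/bf16_to_f32 helpers, sizes, and fill values are illustrative only. It fills a column-major m x n bfloat16 matrix, computes y := alpha*A^T*x + beta*y with CblasTrans (the case served by sbgemv_t), and compares against a plain fp32 reference loop. Build with something like: cc test_sbgemv.c -lopenblas -lm

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <math.h>
#include <cblas.h>

/* Truncate a float to bfloat16 by dropping the low 16 mantissa bits
   (good enough for a smoke test using exactly representable values). */
static bfloat16 f32_to_bf16(float f) {
  uint32_t u;
  memcpy(&u, &f, sizeof(u));
  return (bfloat16)(u >> 16);
}

static float bf16_to_f32(bfloat16 b) {
  uint32_t u = (uint32_t)b << 16;
  float f;
  memcpy(&f, &u, sizeof(f));
  return f;
}

int main(void) {
  const int m = 129, n = 67;            /* odd sizes exercise the 8-, 4- and 1-element tails */
  const float alpha = 1.1f, beta = 0.3f;

  bfloat16 *a = malloc(sizeof(bfloat16) * m * n);
  bfloat16 *x = malloc(sizeof(bfloat16) * m);
  float *y = malloc(sizeof(float) * n);
  float *y_ref = malloc(sizeof(float) * n);

  for (int i = 0; i < m * n; i++) a[i] = f32_to_bf16((float)(i % 7) - 3.0f);
  for (int i = 0; i < m; i++)     x[i] = f32_to_bf16((float)(i % 5) * 0.25f);
  for (int j = 0; j < n; j++)     y[j] = y_ref[j] = (float)j * 0.01f;

  /* y := alpha * A^T * x + beta * y, column-major A of size m x n. */
  cblas_sbgemv(CblasColMajor, CblasTrans, m, n, alpha, a, m, x, 1, beta, y, 1);

  /* Scalar fp32 reference. */
  for (int j = 0; j < n; j++) {
    float sum = 0.0f;
    for (int i = 0; i < m; i++)
      sum += bf16_to_f32(a[i + (size_t)j * m]) * bf16_to_f32(x[i]);
    y_ref[j] = alpha * sum + beta * y_ref[j];
  }

  float max_err = 0.0f;
  for (int j = 0; j < n; j++) max_err = fmaxf(max_err, fabsf(y[j] - y_ref[j]));
  printf("max abs error vs. fp32 reference: %g\n", max_err);

  free(a); free(x); free(y); free(y_ref);
  return 0;
}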