RFC: Introduction of BGEMM and BGEMV for BFloat16 Matrix Operations in OpenBLAS
Author: Nikhil Gupta
Date: 2025-02-27
Status: Proposal / Draft
1. Abstract
This RFC proposes adding two new BLAS routines, BGEMM and BGEMV, to the OpenBLAS project. These routines perform matrix-matrix multiplication and matrix-vector multiplication, respectively, entirely in BFloat16 (BF16) precision. Unlike the existing `sbgemm` operation, which consumes BF16 inputs but produces FP32 outputs, the new operations will produce BF16 outputs. For architectures that lack native BF16 multiply–accumulate instructions with BF16 outputs, the implementation may perform accumulation in FP32 and subsequently convert the results to BF16.
2. Motivation
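- Increased Precision Consistency: Many modern deep learning applications use BF16 for both inputs and outputs to reduce memory bandwidth and storage costs while maintaining adequate precision. A routine that returns BF16 directly spares such callers a separate FP32-to-BF16 conversion pass over the output.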
- Hardware Evolution: With the ongoing evolution of hardware architectures, future systems might offer native BF16 operations. The proposed routines will provide a cleaner pathway to leverage such advancements without altering the API.
- Performance Optimization: A dedicated BF16 routine can be optimized for architectures that either support native BF16 arithmetic or that can benefit from mixed-precision strategies (e.g., FP32 accumulation with BF16 conversion).
3. Proposed Changes
Introduce two new BLAS routines into OpenBLAS:
- BGEMM: Performs matrix-matrix multiplication on BF16 matrices.
- BGEMV: Performs matrix-vector multiplication on a BF16 matrix and a BF16 vector.
Both routines will:
- Accept BF16 input matrices/vectors.
- Use BF16 scalars for the scaling factors (`alpha` and `beta`).
- Return the result in BF16 precision.
- Internally, if necessary, perform FP32 accumulation followed by a conversion to BF16.
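The FP32-accumulation fallback named in the last item can be illustrated with a minimal C sketch. It is not the proposed implementation. The type name `bf16_t` and the helpers `bf16_to_f32`, `f32_to_bf16`, and `dot_bf16_fp32acc` are hypothetical names introduced here, and the sketch assumes BF16 is stored as the upper 16 bits of an IEEE-754 binary32 value and narrowed with round-to-nearest-even.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Assumption for this sketch: BF16 is stored as the top 16 bits of an
 * IEEE-754 binary32 value. */
typedef uint16_t bf16_t;

/* Widening is exact: the 16 stored bits become the high half of a float. */
static float bf16_to_f32(bf16_t h)
{
    uint32_t bits = (uint32_t)h << 16;
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* Narrowing rounds to nearest, ties to even; NaN inputs stay NaN. */
static bf16_t f32_to_bf16(float f)
{
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);
    if ((bits & 0x7fffffffu) > 0x7f800000u)
        return (bf16_t)((bits >> 16) | 0x0040u); /* force a quiet NaN */
    bits += 0x7fffu + ((bits >> 16) & 1u);       /* rounding bias */
    return (bf16_t)(bits >> 16);
}

/* Dot product of two BF16 vectors: accumulate in FP32, round once at the end. */
static bf16_t dot_bf16_fp32acc(size_t n, const bf16_t *x, const bf16_t *y)
{
    assert(n == 0 || (x != NULL && y != NULL));
    float acc = 0.0f;
    for (size_t i = 0; i < n; i++)
        acc += bf16_to_f32(x[i]) * bf16_to_f32(y[i]);
    return f32_to_bf16(acc);
}
```

For two vectors of 512 ones, the FP32 accumulator reaches 512 exactly, and 512 is representable in BF16. Rounding to BF16 after every addition would instead stall at 256, because 257 is not representable in BF16 and the tie rounds back to 256.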
4. Proposed API Signatures
Below are examples of the proposed function signatures, modeled on existing BLAS conventions:
BGEMM
void bgemm_(char *transa, char *transb, blasint *m, blasint *n, blasint *k,
bfloat16 *alpha,
const bfloat16 *a, blasint *lda,
const bfloat16 *b, blasint *ldb,
bfloat16 *beta,
bfloat16 *c, blasint *ldc);
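To make the intended semantics of this signature concrete, the following naive reference sketch computes C = alpha * op(A) * op(B) + beta * C with the same argument list. It is an illustration under stated assumptions, not the proposed kernel: `ref_bgemm_`, `op_at`, and `is_trans` are hypothetical names, the `bfloat16` and `blasint` typedefs are stand-ins chosen for this sketch, matrices are assumed column-major as in existing BLAS routines, and accumulation is done in FP32 with a single conversion to BF16 per output element.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

typedef uint16_t bfloat16; /* assumption: raw 16-bit storage */
typedef int blasint;       /* assumption: 32-bit integer interface */

static float bf16_to_f32(bfloat16 h)
{
    uint32_t bits = (uint32_t)h << 16;
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* Round to nearest, ties to even; NaN inputs stay NaN. */
static bfloat16 f32_to_bf16(float f)
{
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);
    if ((bits & 0x7fffffffu) > 0x7f800000u)
        return (bfloat16)((bits >> 16) | 0x0040u);
    bits += 0x7fffu + ((bits >> 16) & 1u);
    return (bfloat16)(bits >> 16);
}

static int is_trans(char t)
{
    return t == 'T' || t == 't' || t == 'C' || t == 'c';
}

/* Element (i, j) of op(X), where X is column-major with leading dimension ld. */
static float op_at(const bfloat16 *x, blasint ld, int trans, blasint i, blasint j)
{
    size_t row = (size_t)(trans ? j : i);
    size_t col = (size_t)(trans ? i : j);
    return bf16_to_f32(x[row + col * (size_t)ld]);
}

/* Hypothetical reference routine mirroring the proposed argument list. */
void ref_bgemm_(char *transa, char *transb, blasint *m, blasint *n, blasint *k,
                bfloat16 *alpha,
                const bfloat16 *a, blasint *lda,
                const bfloat16 *b, blasint *ldb,
                bfloat16 *beta,
                bfloat16 *c, blasint *ldc)
{
    int ta = is_trans(*transa), tb = is_trans(*transb);
    float fa = bf16_to_f32(*alpha), fb = bf16_to_f32(*beta);
    assert(*m >= 0 && *n >= 0 && *k >= 0);

    for (blasint j = 0; j < *n; j++) {
        for (blasint i = 0; i < *m; i++) {
            float acc = 0.0f; /* FP32 accumulator */
            for (blasint l = 0; l < *k; l++)
                acc += op_at(a, *lda, ta, i, l) * op_at(b, *ldb, tb, l, j);
            bfloat16 *cij = &c[(size_t)i + (size_t)j * (size_t)*ldc];
            /* As in existing BLAS routines, C is not read when beta is zero. */
            float prev = (fb == 0.0f) ? 0.0f : fb * bf16_to_f32(*cij);
            *cij = f32_to_bf16(fa * acc + prev);
        }
    }
}
```

Every argument, including the dimensions and the scalars, is passed by pointer, following the Fortran-style calling convention used in the signature above.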