Batch GEMM Guide
Batch GEMM runs many independent GEMM operations in a single API call. Instead of looping over aocl_gemm_* yourself, you hand the library arrays of matrices organized into groups, and it dispatches them with shared setup and coordinated threading.
This guide assumes you are already familiar with the single-GEMM interface. If not, read the GEMM Guide first -- the data-type naming, order/transpose/mem_format semantics, leading dimensions, and post-ops all carry over unchanged.
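A single aocl_batch_gemm_* call computes, for every operation i in the batch:
C[i] = alpha * op(A[i]) * op(B[i]) + beta * C[i]
where op(X) is X or its transpose as selected by transa/transb. alpha, beta, and the transpose flags are taken from the group that operation i belongs to.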
Use batch GEMM when you have many independent GEMMs that can be described compactly:
- Transformer/MLP workloads with many small or medium matrices.
- Operations that share a shape and parameters, so they can be grouped.
- Cases where per-call setup (kernel selection, B packing, thread fork/join) would otherwise be paid once per matrix.
Batching amortizes that setup across the whole group and lets the library schedule the work across threads as one unit, which is more efficient than issuing each GEMM separately. The interface follows the grouped CBLAS cblas_gemm_batch model (group_count + group_size[]).
The batch is divided into groups. Every operation inside a group shares the same shape and attributes (m, n, k, transpose flags, leading dimensions, alpha/beta, memory formats, and post-ops); only the matrix data pointers differ. Different groups may have completely different configurations.
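To pin down what one call computes, here is a naive reference loop over groups and operations. This is a minimal sketch, not the library's implementation or API. It makes four assumptions: storage is row-major, there are no transposes (op(X) = X), data is f32, and mem_format and post-ops are ignored. `ref_batch_gemm` and the `idx_t` alias (a stand-in for the library's md_t) are hypothetical. Per-group values are indexed by the group number, and matrix pointers are indexed by a flat operation counter that runs through the groups in order.

```c
#include <assert.h>
#include <stdint.h>

typedef int64_t idx_t;  /* hypothetical stand-in for the library's md_t */

/* Naive reference for grouped batch semantics: row-major, op(X) = X, f32.
 * Per-group arrays (m, n, k, alpha, beta, ld*) are indexed by the group g;
 * pointer arrays (a, b, c) by the flat operation index, groups concatenated.
 * Note: BLAS implementations usually skip reading C when beta == 0; this
 * sketch always reads C, so C must be initialized. */
static void ref_batch_gemm(const idx_t* m, const idx_t* n, const idx_t* k,
                           const float* alpha, const float* const* a, const idx_t* lda,
                           const float* const* b, const idx_t* ldb,
                           const float* beta, float* const* c, const idx_t* ldc,
                           idx_t group_count, const idx_t* group_size)
{
    idx_t op = 0;  /* flat operation index into a, b, c */
    for (idx_t g = 0; g < group_count; ++g) {
        assert(lda[g] >= k[g] && ldb[g] >= n[g] && ldc[g] >= n[g]);
        for (idx_t s = 0; s < group_size[g]; ++s, ++op) {
            const float* A = a[op];
            const float* B = b[op];
            float* C = c[op];
            for (idx_t i = 0; i < m[g]; ++i)
                for (idx_t j = 0; j < n[g]; ++j) {
                    float acc = 0.0f;
                    for (idx_t p = 0; p < k[g]; ++p)
                        acc += A[i * lda[g] + p] * B[p * ldb[g] + j];
                    /* C[i] = alpha * op(A[i]) * op(B[i]) + beta * C[i] */
                    C[i * ldc[g] + j] = alpha[g] * acc + beta[g] * C[i * ldc[g] + j];
                }
        }
    }
}
```

Every operation in a group reuses that group's m, n, k, alpha, beta, and leading dimensions. Only the three matrix pointers advance per operation, which is why the library can share setup across a group.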
Two indexing rules apply to the argument arrays:
| Array kind | Length | Indexed by |
|---|---|---|
| order, transa, transb, m, n, k, alpha, beta, lda, ldb, ldc, mem_format_a, mem_format_b, metadata | group_count | group number |
| a, b, c (matrix pointer arrays) | total operations = sum(group_size[i]) | flat operation number |
- group_count -- number of groups (the md_t scalar passed by value).
- group_size[] -- length group_count; group_size[g] is the number of operations (matrix triples) in group g.
- Per-group arrays carry one value per group. For example, m[g] is the row count for every operation in group g.
- Per-operation pointer arrays (a, b, c) are laid out by concatenating groups in order: group 0's operations first, then group 1's, and so on.

The post-op metadata array is indexed per group: one dlp_metadata_t* descriptor is applied to all operations in that group (pass NULL for a group with no post-ops).
With group_count = 3 and group_size = {4, 3, 2} (9 operations total):
- m, n, k, alpha, ... are length-3 arrays (one entry per group).
- a, b, c are length-9 arrays. Operations 0..3 belong to group 0, 4..6 to group 1, 7..8 to group 2.
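The mapping between a flat operation index and its group is a running prefix sum over group_size. The helpers below sketch that bookkeeping. They are hypothetical (not part of the library) and use `idx_t` as a stand-in for md_t.

```c
#include <assert.h>
#include <stdint.h>

typedef int64_t idx_t;  /* hypothetical stand-in for md_t */

/* Total number of operations = sum(group_size[g]); this is the required
 * length of the a, b and c pointer arrays. */
static idx_t batch_total_ops(const idx_t* group_size, idx_t group_count)
{
    idx_t total = 0;
    for (idx_t g = 0; g < group_count; ++g)
        total += group_size[g];
    return total;
}

/* Map a flat operation index (into a/b/c) to its group number, or -1 if it
 * is past the end. Groups are concatenated in order, so walk them. */
static idx_t batch_group_of(idx_t op, const idx_t* group_size, idx_t group_count)
{
    assert(op >= 0);
    for (idx_t g = 0; g < group_count; ++g) {
        if (op < group_size[g])
            return g;
        op -= group_size[g];
    }
    return -1;
}

/* Flat index of the first operation in group g (prefix sum of group_size). */
static idx_t batch_group_offset(idx_t g, const idx_t* group_size)
{
    idx_t off = 0;
    for (idx_t i = 0; i < g; ++i)
        off += group_size[i];
    return off;
}
```

With group_size = {4, 3, 2}, batch_total_ops returns 9, operations 4..6 map to group 1, and batch_group_offset(2, group_size) is 7.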
Batch functions use the same naming scheme as single GEMM, with the aocl_batch_gemm_ prefix:
aocl_batch_gemm_<A_type><B_type><accumulator_type>o<output_type>
See the GEMM Guide naming convention for the short-name/type table (f32, f16, bf16, s8, u8, s4, u4, s32). The combinations below are the ones that exist as real aocl_batch_gemm_* symbols in aocl_gemm_interface_apis.h.
| Input A | Input B | Accumulator | Supported Outputs | Min ISA |
|---|---|---|---|---|
| f32 | f32 | f32 | f32 | AVX2 |
| f16 | f16 | f16 | f16, f32 | AVX512_FP16 |
| f32 | f16 | f32 | f32 | AVX512 |
| Input A | Input B | Accumulator | Supported Outputs | Min ISA |
|---|---|---|---|---|
| bf16 | bf16 | f32 | f32, bf16 | AVX2 (*) |
| bf16 | s4 | f32 | f32, bf16 | AVX512 |
| bf16 | u4 | f32 | f32, bf16 | AVX512 |
| bf16 | s8 | s32 | s32, s8, u8, f32, bf16 | AVX512_VNNI |
| Input A | Input B | Accumulator | Supported Outputs | Min ISA |
|---|---|---|---|---|
| f32 | s8 | s32 | s32, s8, u8, f32, bf16 | AVX512_VNNI |
| Input A | Input B | Accumulator | Supported Outputs | Min ISA |
|---|---|---|---|---|
| u8 | s8 | s32 | s32, s8, u8, f32, bf16 | AVX512_VNNI |
| s8 | s8 | s32 | s32, s8, u8, f32, bf16 | AVX512_VNNI |
| Input A | Input B | Accumulator | Supported Outputs | Min ISA |
|---|---|---|---|---|
| s8 | s8 (sym_quant) | s32 | f32, bf16 | AVX512_VNNI |
(*) BFloat16 batch GEMM on hardware without native AVX512_BF16 falls back to float32 kernels with transparent conversion, matching the single-GEMM behavior. See Library Overview.
The symmetric-quant entry maps to aocl_batch_gemm_s8s8s32of32_sym_quant and aocl_batch_gemm_s8s8s32obf16_sym_quant. See the Quantization Guide for scale/zero-point setup, which is identical to the single-GEMM path.
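The naming pattern is mechanical, so a symbol name can be composed from the short type names. The sketch below only assembles the string. It does not check that the combination exists; use the tables above, or aocl_gemm_interface_apis.h, for that. The symmetric-quant entries additionally carry the _sym_quant suffix. `batch_gemm_name` is a hypothetical helper, not a library function.

```c
#include <assert.h>
#include <stddef.h>
#include <stdio.h>

/* Compose aocl_batch_gemm_<A><B><acc>o<out> from short type names.
 * Returns the full name length (snprintf semantics): a result >= size
 * means buf was too small and the name was truncated. */
static int batch_gemm_name(char* buf, size_t size, const char* a_type,
                           const char* b_type, const char* acc_type,
                           const char* out_type)
{
    assert(buf != NULL && a_type && b_type && acc_type && out_type);
    return snprintf(buf, size, "aocl_batch_gemm_%s%s%so%s",
                    a_type, b_type, acc_type, out_type);
}
```

For example, ("f32", "f32", "f32", "f32") yields aocl_batch_gemm_f32f32f32of32, and ("bf16", "s4", "f32", "bf16") yields aocl_batch_gemm_bf16s4f32obf16.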
All batch functions share one parameter pattern. Here is aocl_batch_gemm_f32f32f32of32, the reference signature:
```c
void aocl_batch_gemm_f32f32f32of32(
    const char* order,            // [group_count] 'R' row-major / 'C' column-major
    const char* transa,           // [group_count] 'N' / 'T'
    const char* transb,           // [group_count] 'N' / 'T'
    const md_t* m,                // [group_count] rows of A (and C) per group
    const md_t* n,                // [group_count] columns of B (and C) per group
    const md_t* k,                // [group_count] columns of A / rows of B per group
    const float* alpha,           // [group_count] scalar per group
    const float** a,              // [total ops]   pointers to A matrices
    const md_t* lda,              // [group_count] leading dimension of A per group
    const float** b,              // [total ops]   pointers to B matrices
    const md_t* ldb,              // [group_count] leading dimension of B per group
    const float* beta,            // [group_count] scalar per group
    float** c,                    // [total ops]   pointers to C matrices (output)
    const md_t* ldc,              // [group_count] leading dimension of C per group
    const md_t group_count,       // number of groups (scalar)
    const md_t* group_size,       // [group_count] operations per group
    const char* mem_format_a,     // [group_count] 'N' / 'P' / 'R'
    const char* mem_format_b,     // [group_count] 'N' / 'P' / 'R'
    dlp_metadata_t** metadata     // [group_count] post-ops per group (NULL = none)
);
```

Other variants differ only in the element types of a, b, c, and the scalar types of alpha/beta (e.g. integer paths take const int32_t* for alpha/beta, f16 paths take const float16*). The meaning of order, transa/transb, lda/ldb/ldc, and mem_format_a/mem_format_b is exactly as documented in the GEMM Guide.
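Because each group carries its own order, transpose flags, and leading dimensions, it can help to validate them per group before the call. The sketch below assumes the conventional BLAS storage rules: op(A) is m x k, op(B) is k x n, C is m x n, and a transposed operand is stored with its dimensions swapped. The GEMM Guide remains authoritative for what the library actually checks. `min_ld` and `group_lds_ok` are hypothetical helpers.

```c
#include <assert.h>
#include <stdint.h>

typedef int64_t idx_t;  /* hypothetical stand-in for md_t */

/* Minimum leading dimension for an operand whose op(X) is rows x cols.
 * 'N' stores X as rows x cols; 'T' stores it as cols x rows.
 * Row-major needs >= stored column count, column-major >= stored row count. */
static idx_t min_ld(char order, char trans, idx_t rows, idx_t cols)
{
    assert(order == 'R' || order == 'C');
    assert(trans == 'N' || trans == 'T');
    idx_t stored_rows = (trans == 'N') ? rows : cols;
    idx_t stored_cols = (trans == 'N') ? cols : rows;
    return (order == 'R') ? stored_cols : stored_rows;
}

/* Check one group's lda/ldb/ldc against its shape and flags. */
static int group_lds_ok(char order, char transa, char transb,
                        idx_t m, idx_t n, idx_t k,
                        idx_t lda, idx_t ldb, idx_t ldc)
{
    return lda >= min_ld(order, transa, m, k)   /* op(A): m x k */
        && ldb >= min_ld(order, transb, k, n)   /* op(B): k x n */
        && ldc >= min_ld(order, 'N', m, n);     /* C is never transposed */
}
```

Loop g over 0..group_count-1 and call group_lds_ok(order[g], transa[g], transb[g], m[g], n[g], k[g], lda[g], ldb[g], ldc[g]) to flag a misconfigured group before dispatch.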
The snippet below mirrors the official example, examples/classic/batch_gemm.c, trimmed to the essentials. It runs three groups with different shapes.
```c
#include <aocl_dlp.h>
#include <stdlib.h>

int main(void)
{
    // 3 groups; group g holds group_size[g] independent GEMMs.
    const md_t group_count = 3;
    md_t group_size[3] = { 4, 3, 2 };   // 9 operations total

    // Per-group attributes (length = group_count).
    md_t m[3]   = { 128, 256, 64 };
    md_t n[3]   = { 128, 256, 512 };
    md_t k[3]   = { 128, 256, 256 };
    md_t lda[3] = { 128, 256, 256 };    // row-major, transa = 'N': lda = k
    md_t ldb[3] = { 128, 256, 512 };    // row-major, transb = 'N': ldb = n
    md_t ldc[3] = { 128, 256, 512 };    // row-major: ldc = n
    float alpha[3] = { 1.0f, 1.5f, 0.8f };
    float beta[3]  = { 0.0f, 0.0f, 0.0f };
    char order[3]  = { 'R', 'R', 'R' };
    char transa[3] = { 'N', 'N', 'N' };
    char transb[3] = { 'N', 'N', 'N' };
    char mfa[3]    = { 'N', 'N', 'N' }; // mem_format_a
    char mfb[3]    = { 'N', 'N', 'N' }; // mem_format_b

    // Per-operation matrix pointers (length = sum(group_size) = 9).
    md_t total_ops = group_size[0] + group_size[1] + group_size[2];
    const float** a = calloc(total_ops, sizeof(float*));
    const float** b = calloc(total_ops, sizeof(float*));
    float** c = calloc(total_ops, sizeof(float*));
    if (!a || !b || !c) { free(a); free(b); free(c); return EXIT_FAILURE; }

    // Allocate each operation's A/B/C using its group's sizes.
    // Operations 0..3 -> group 0, 4..6 -> group 1, 7..8 -> group 2.
    int ok = 1;
    for (md_t g = 0, op = 0; g < group_count; ++g) {
        for (md_t i = 0; i < group_size[g]; ++i, ++op) {
            a[op] = calloc((size_t)(m[g] * lda[g]), sizeof(float));
            b[op] = calloc((size_t)(k[g] * ldb[g]), sizeof(float));
            c[op] = calloc((size_t)(m[g] * ldc[g]), sizeof(float));
            ok = ok && a[op] && b[op] && c[op];
            // Fill A and B with real data here (calloc leaves them zeroed).
        }
    }

    // One post-op descriptor per group; NULL means "no post-ops".
    dlp_metadata_t* metadata[3] = { NULL, NULL, NULL };

    if (ok)
        aocl_batch_gemm_f32f32f32of32(
            order, transa, transb, m, n, k, alpha,
            a, lda, b, ldb, beta, c, ldc,
            group_count, group_size, mfa, mfb, metadata);

    for (md_t i = 0; i < total_ops; ++i) {
        free((void*)a[i]); free((void*)b[i]); free(c[i]);
    }
    free(a); free(b); free(c);
    return ok ? EXIT_SUCCESS : EXIT_FAILURE;
}
```

The same call shape applies to every variant -- swap the function name and the element types of a/b/c and alpha/beta.
When a group's B matrices are reused (e.g. weights), pre-reordering them into the library's internal layout removes per-call B packing and improves throughput. Batch GEMM honors the reordered memory tag through mem_format_b.
Workflow per group:
- Query the buffer size with the matching single-GEMM helper, e.g. aocl_get_reorder_buf_size_f32f32f32of32('R', 'N', 'B', k, n, NULL).
- Reorder each B matrix in the group with aocl_reorder_f32f32f32of32(...) into its own buffer, and store the reordered pointer back into the per-operation b array.
- Set that group's mem_format_b[g] = 'R'.
Notes and limits (enforced by the batch implementation):
- Reorder applies to B only. Requesting a reordered A in row-major returns DLP_CLSC_NOT_SUPPORTED.
- Reordering is not supported for column-major (order = 'C') groups.
- Reorder is available for the data types that expose aocl_get_reorder_buf_size_* / aocl_reorder_*: f32f32f32of32, bf16bf16f32of32, u8s8s32os32, s8s8s32os32, bf16s4f32of32, u8s4s32os32, f16f16f16of16, f32f16f32of32, and s8s8s32os32_sym_quant.
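The three limits can be checked up front before building a group that uses a reordered layout. The sketch mirrors the list above: only B, only row-major, and only type suffixes that have reorder helpers. It is hypothetical: `batch_reorder_allowed` is not a library function, and the type list should be kept in sync with the GEMM Guide's per-type table.

```c
#include <assert.h>
#include <string.h>

/* Type suffixes that expose reorder helpers, copied from the list above. */
static const char* const k_reorder_types[] = {
    "f32f32f32of32", "bf16bf16f32of32", "u8s8s32os32", "s8s8s32os32",
    "bf16s4f32of32", "u8s4s32os32", "f16f16f16of16", "f32f16f32of32",
    "s8s8s32os32_sym_quant",
};

/* Returns 1 if a group may request a reordered matrix under the stated
 * limits (matrix 'B', order 'R', supported type suffix), else 0. */
static int batch_reorder_allowed(char order, char matrix, const char* type_suffix)
{
    assert(type_suffix != NULL);
    if (matrix != 'B' || order != 'R')
        return 0;
    for (size_t i = 0; i < sizeof k_reorder_types / sizeof k_reorder_types[0]; ++i)
        if (strcmp(type_suffix, k_reorder_types[i]) == 0)
            return 1;
    return 0;
}
```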
See GEMM Guide -- Matrix Reordering for the full reorder API and the per-type support table.
Fused post-operations attach exactly like single GEMM: build a dlp_metadata_t descriptor and pass it through the per-group metadata array. Because metadata is indexed per group, one descriptor governs all operations in its group; pass NULL for groups that need no post-ops.
```c
dlp_metadata_t* metadata[3] = { meta_g0, NULL, meta_g2 }; // length = group_count
```

All batch variants accept post-ops. The FP16 rails (f16f16f16of16 and f16f16f16of32) support the standard fused post-ops, with these exceptions: A-dequantization, pre-ops, and group post-ops are not supported on the FP16 batch paths and return DLP_CLSC_NOT_SUPPORTED.
For the catalog of post-ops (BIAS, activations, SCALE, MATRIX_ADD/MUL, etc.) and how to populate dlp_metadata_t, see the Post-Operations Guide.
Each group reports status into its own metadata[g] entry (when non-NULL) via dlp_metadata_t.error_hndl. Check every group's code after the call:

```c
for (md_t g = 0; g < group_count; ++g) {
    if (metadata[g] && metadata[g]->error_hndl.error_code != DLP_CLSC_SUCCESS) {
        // group g failed -- see dlp_errors.h for the meaning of error_code
    }
}
```

The error codes are the same dlp_clsc_err_t values listed in the GEMM Guide (DLP_CLSC_NOT_SUPPORTED, DLP_CLSC_INVALID_MATRIX_DIMENSION, etc.).
Batch GEMM has dedicated, public harness assets on amd-main:
| Purpose | Asset |
|---|---|
| Benchmark driver | bench/bench_batch_gemm.cc |
| Benchmark config | bench/configs/batch_gemm_bench_config.yaml |
| Test config | tests/classic/configs/batch_gemm_test_config.yaml |
The benchmark config defines named shape groups (e.g. transformer-style m/n/k with a group_size) you can edit to model your workload. For how to run and analyze these, see DLP Benchmarking and DLP Testing.
- GEMM Guide -- Single-GEMM parameters, data types, reordering, and variant selection
- Post-Operations Guide -- Fusing BIAS, activations, and scaling
- Quantization Guide -- Symmetric quantization and mixed-precision workflows
- Examples & Tutorials -- Working code examples
- API Reference -- Generated API documentation