🧪 Experimental 🧪
wgpu supports an experimental cooperative matrix feature when `Features::EXPERIMENTAL_COOPERATIVE_MATRIX` is enabled.
This exposes hardware-accelerated matrix multiply-accumulate (MMA) operations (for example, NVIDIA tensor cores,
Metal SIMD-group matrices, and Vulkan `VK_KHR_cooperative_matrix`).
Note: The features documented here may have bugs and are subject to breaking changes. The API and shader semantics are expected to evolve. Please refer to the GitHub issue tracker for the latest status and discussions.
Cooperative matrices allow a workgroup (or equivalent execution group) to collectively:
- load small matrix tiles from memory,
- perform matrix multiply-accumulate operations on those tiles, and
- store the results back to memory.
Conceptually, this is specialized hardware that evaluates:
`C = A * B + C`
for relatively small tiles, but at very high throughput compared to composing the same operation from scalar/vector instructions.
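In scalar form, the tile operation reads as follows (a plain Rust reference with row-major storage, for illustration only; the hardware evaluates the same expression for a whole tile at once):

```rust
// Scalar reference for what a cooperative matrix unit computes:
// C' = A * B + C, for an MxK `a`, a KxN `b`, and an MxN accumulator `c`.
// Matrices are row-major flat slices; this is illustrative, not how the
// hardware iterates.
fn mma_reference(m: usize, n: usize, k: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
    for i in 0..m {
        for j in 0..n {
            // Start from the existing accumulator value (the "+ C" part).
            let mut acc = c[i * n + j];
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            c[i * n + j] = acc;
        }
    }
}
```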
Cooperative matrix operations are most useful in workloads such as:
- machine learning and inference,
- dense linear algebra and scientific computing,
- image processing, filtering, and transforms.
The cooperative nature means that all lanes in the cooperating execution group must participate in the operations; individual invocations cannot diverge.
Typical example:
- `A` is an M×K matrix.
- `B` is a K×N matrix.
- `C` is an M×N matrix, acting as the accumulator and result.
Before using cooperative matrices in shaders, you must query what configurations your hardware and backend support.
On the Adapter, wgpu exposes:
```rust
Adapter::cooperative_matrix_properties() -> Vec<CooperativeMatrixProperties>
```
Each CooperativeMatrixProperties describes a single supported configuration. Fields are:
- `m_size`: height of matrices A and C (type: `naga::CooperativeSize`)
- `n_size`: width of matrices B and C (type: `naga::CooperativeSize`)
- `k_size`: shared inner dimension of A and B (type: `naga::CooperativeSize`)
- `ab_type`: scalar element type for A and B (type: `naga::Scalar`)
- `cr_type`: scalar element type for C and the result (type: `naga::Scalar`)
- `saturating_accumulation`: `bool` indicating whether overflow clamping on accumulation is supported for this configuration
Example usage:
```rust
let coop_props = adapter.cooperative_matrix_properties();
for prop in &coop_props {
    println!(
        "{:?}x{:?}x{:?} - AB: {:?}, CR: {:?}, saturating: {}",
        prop.m_size, prop.n_size, prop.k_size,
        prop.ab_type, prop.cr_type,
        prop.saturating_accumulation,
    );
}
```
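When picking a configuration, it is convenient to search the returned list for an exact match before building pipelines. The sketch below illustrates the idea; `Config` is a hypothetical stand-in for `CooperativeMatrixProperties`, with the sizes and scalar types flattened to plain values for illustration:

```rust
// Hypothetical stand-in for `CooperativeMatrixProperties`; the real type
// uses `naga::CooperativeSize` and `naga::Scalar` for these fields.
#[derive(Debug, Clone, PartialEq)]
struct Config {
    m_size: u32,
    n_size: u32,
    k_size: u32,
    ab_type: &'static str, // e.g. "f16", "f32"
    cr_type: &'static str,
    saturating_accumulation: bool,
}

/// Returns true if the reported configurations include the tile shape and
/// element types the shader intends to use.
fn supports(props: &[Config], m: u32, n: u32, k: u32, ab: &str, cr: &str) -> bool {
    props.iter().any(|p| {
        p.m_size == m && p.n_size == n && p.k_size == k && p.ab_type == ab && p.cr_type == cr
    })
}
```

In real code the same check would run over the `Vec` returned by `adapter.cooperative_matrix_properties()`, and a missing match would select a fallback pipeline.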
You must:

- Enable `Features::EXPERIMENTAL_COOPERATIVE_MATRIX` on the `Device`.
- Query `adapter.cooperative_matrix_properties()` and ensure that the configuration(s) you intend to use in WGSL are actually available on the running adapter/backend.
- Treat the sizes and types as a contract between your shaders and the underlying hardware implementation. Using unsupported configurations is an error.
Using cooperative matrices requires enabling `Features::EXPERIMENTAL_COOPERATIVE_MATRIX`. This feature may be restricted to certain backends and hardware.
These are general guidelines, not a complete compatibility matrix:
- Metal:
  - Requires an Apple7+ (A14) or Mac2+ (M1) GPU with MSL 2.3+.
  - Strong support for 8×8 `f32`, 8×8 `f16`, and mixed-precision modes (e.g. `f16` A/B and `f32` accumulator C).
  - Implementation is based on SIMD-group matrix operations.
- Vulkan:
  - Requires the `VK_KHR_cooperative_matrix` extension.
  - Many NVIDIA and AMD GPUs support `f16` at 16×16 tile sizes and similar.
  - 8×8 `f32` support is hardware-dependent.
  - Exact configurations are enumerated by `Adapter::cooperative_matrix_properties()`.
- Other backends:
  - May not support cooperative matrices at all. In that case the feature will not be exposed, and `adapter.cooperative_matrix_properties()` will return an empty list.
Always treat the properties returned at runtime as the source of truth.
This section summarizes the host-side API elements related to cooperative matrices. (For exact signatures and details, refer to the Rust documentation.)
```rust
Adapter::cooperative_matrix_properties() -> Vec<CooperativeMatrixProperties>
```
Returns all cooperative matrix configurations supported by the adapter/backend.
`CooperativeMatrixProperties`:

- `m_size: naga::CooperativeSize`
- `n_size: naga::CooperativeSize`
- `k_size: naga::CooperativeSize`
- `ab_type: naga::Scalar`
- `cr_type: naga::Scalar`
- `saturating_accumulation: bool`
The `naga` types (`CooperativeSize`, `Scalar`) are part of the shader translation layer and
determine the legal WGSL cooperative matrix combinations.

There are currently no dedicated wgpu buffer or texture types for cooperative matrices; they are
expressed in WGSL as special value types accessed via pointers into ordinary `var<storage>` /
`var<workgroup>` / `var<private>` / etc.
Cooperative matrices are enabled and accessed via WGSL extensions. The exact extension spelling may change; the details below describe the intended semantics.
Any WGSL program using cooperative matrices must declare an extension at the top of the shader, for example:
```wgsl
enable wgpu_cooperative_matrix;
```
The shader is invalid if any cooperative matrix types or builtins are used without enabling this extension.
A cooperative matrix is a value type parameterized by:
- tile size (M×N),
- scalar element type `T`, and
- role `R` indicating how the matrix participates in the multiply-accumulate:
  - `A`: left operand
  - `B`: right operand
  - `C`: accumulator / result
Conceptually:
```wgsl
// A: MxK, B: KxN, C: MxN
type coop_matMxN<T, A>;
type coop_matMxN<T, B>;
type coop_matMxN<T, C>;
```
Concrete examples (sizes and types must match a supported configuration from
`Adapter::cooperative_matrix_properties`):
```wgsl
// 8x8 single-precision tiles
alias CoopMatA = coop_mat8x8<f32, A>;
alias CoopMatB = coop_mat8x8<f32, B>;
alias CoopMatC = coop_mat8x8<f32, C>;

// 16x16 half-precision inputs, 16x16 f32 accumulator (mixed precision)
alias CoopMat16x16A = coop_mat16x16<f16, A>;
alias CoopMat16x16B = coop_mat16x16<f16, B>;
alias CoopMat16x16C = coop_mat16x16<f32, C>;
```
The actual set of legal (M, N, T, R) combinations is defined by the cooperative matrix
properties returned at runtime; shaders must not use arbitrary combinations.
- `A` role:
  - Treated as the left operand in the multiplication. Has shape M×K.
  - Participates as `A` in `A * B + C`.
- `B` role:
  - Treated as the right operand in the multiplication. Has shape K×N.
  - Participates as `B` in `A * B + C`.
- `C` role:
  - Treated as accumulator and result. Has shape M×N.
  - Participates as `C` in `A * B + C`.
These roles are part of the type; they are not interchangeable.
WGSL provides built-in functions for operating on cooperative matrices. The exact spelling may change; the semantics are:
Collectively load a tile from memory into a cooperative matrix. Two variants select the memory layout:
- `coopLoad` — matrix is stored column-major in memory; `stride` is the number of elements between adjacent columns.
- `coopLoadT` — matrix is stored row-major in memory (i.e. transposed relative to the canonical column-major layout used by `coopLoad`); `stride` is the number of elements between adjacent rows. This is the natural fit for C-style `ptr[i * num_cols + j]` storage.
```wgsl
fn coopLoad<T, R>(
    ptr: ptr<STORAGE_CLASS, T>, // base pointer to scalar or vector elements
    stride: u32                 // elements between adjacent columns
) -> coop_matMxN<T, R>;

fn coopLoadT<T, R>(
    ptr: ptr<STORAGE_CLASS, T>, // base pointer to scalar or vector elements
    stride: u32                 // elements between adjacent rows
) -> coop_matMxN<T, R>;
```
- Loads an M×N tile (or M×K / K×N, depending on role and operation) from memory pointed to by `ptr`.
- All invocations in the cooperative group must call the chosen variant in a converged fashion.
- The memory address range must be valid and properly aligned for the scalar type.
Implementation note: Each lane contributes to filling the tile based on an implementation-defined mapping from invocation/lane ID to sub-fragment of the matrix.
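The distinction between the two variants comes down to how element `(row, col)` maps to a linear offset from the base pointer. A sketch of the addressing, in plain Rust for illustration:

```rust
// Offset (in elements) of element (row, col) from the base pointer, for
// the two layouts. `stride` is in elements, as in the load/store builtins.

// Column-major (`coopLoad` / `coopStore`): adjacent columns are `stride`
// elements apart in memory.
fn offset_col_major(row: u32, col: u32, stride: u32) -> u32 {
    col * stride + row
}

// Row-major (`coopLoadT` / `coopStoreT`): adjacent rows are `stride`
// elements apart. Matches C-style `ptr[row * num_cols + col]` storage
// when `stride == num_cols`.
fn offset_row_major(row: u32, col: u32, stride: u32) -> u32 {
    row * stride + col
}
```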
Collectively store a cooperative matrix tile back to memory. Variant selection mirrors the load builtins:
- `coopStore` — writes column-major; `stride` between columns.
- `coopStoreT` — writes row-major; `stride` between rows.
```wgsl
fn coopStore<T, R>(
    value: coop_matMxN<T, R>,
    ptr: ptr<STORAGE_CLASS, T>,
    stride: u32
);

fn coopStoreT<T, R>(
    value: coop_matMxN<T, R>,
    ptr: ptr<STORAGE_CLASS, T>,
    stride: u32
);
```
- Stores `value` into the memory region addressed by `ptr` with the given `stride`.
- All invocations in the cooperative group must participate.
- The store must not alias overlapping tiles in undefined ways.
Perform a matrix multiply-accumulate operation on cooperative matrices:
```wgsl
fn coopMultiplyAdd<Tab, Tcr, MA, KA, KB, NB>(
    a: coop_matMAxKA<Tab, A>, // A: MAxKA tile
    b: coop_matKBxNB<Tab, B>, // B: KBxNB tile (KB == KA)
    c: coop_matMAxNB<Tcr, C>  // C: MAxNB accumulator/result
) -> coop_matMAxNB<Tcr, C>;
```
Semantics:
- Computes `C' = A * B + C`.
- Returns the resulting accumulator tile `C'`.
- Implies `KA == KB` (inner dimension must match).
- Types `(Tab, Tcr)` must be one of the supported AB/CR combinations given by `CooperativeMatrixProperties`.
- Sizes `(MA, NB, KA)` must match a supported `(m_size, n_size, k_size)` triple.
For example, with a supported configuration:
```wgsl
enable wgpu_cooperative_matrix;

alias MatA = coop_mat8x8<f32, A>;
alias MatB = coop_mat8x8<f32, B>;
alias MatC = coop_mat8x8<f32, C>;

// Assumes each tile is stored column-major in memory (the plain `coopLoad`
// / `coopStore` form); use `coopLoadT` / `coopStoreT` for row-major storage.
fn matmul_tile(
    ptr_a: ptr<storage, f32>,
    ptr_b: ptr<storage, f32>,
    ptr_c: ptr<storage, f32>,
    stride: u32,
) {
    let a: MatA = coopLoad<_, A>(ptr_a, stride);
    let b: MatB = coopLoad<_, B>(ptr_b, stride);
    let c: MatC = coopLoad<_, C>(ptr_c, stride);
    let result: MatC = coopMultiplyAdd(a, b, c);
    coopStore(result, ptr_c, stride);
}
```
If `saturating_accumulation` is true for the chosen configuration, then overflow during accumulation
is clamped (i.e. saturating arithmetic). If false, overflow behavior for the accumulator follows the
underlying scalar type's semantics (e.g. IEEE-754 for floats).
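For integer accumulator types, the two behaviors can be illustrated with Rust's saturating versus wrapping addition (an analogy only; the actual behavior is defined by the backend for the chosen configuration):

```rust
// Illustration of the two accumulation behaviors for an i32 accumulator.
// With saturating accumulation, overflow clamps to the type's range; with
// it disabled, the result follows the scalar type's own overflow rules
// (modeled here as wrapping).
fn accumulate(acc: i32, term: i32, saturating: bool) -> i32 {
    if saturating {
        acc.saturating_add(term)
    } else {
        acc.wrapping_add(term)
    }
}
```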
Cooperative matrix operations are collective:
- All invocations in the relevant execution group must execute each cooperative operation in uniform control flow:
  - Using `coopLoad`/`coopLoadT`, `coopStore`/`coopStoreT`, or `coopMultiplyAdd` in divergent control flow (e.g. some lanes taking a branch, others not) is undefined behavior.
  - The exact execution group may be a workgroup, a SIMD-group / subgroup, or another backend-specific granularity; shaders must treat it abstractly.
- The workgroup (or cooperating group) size is constrained by both:
  - the cooperative matrix configuration, and
  - backend-specific implementation details.

For portable code:

- Choose a workgroup size that is known to be supported efficiently on your target backends, for example:
  - `@workgroup_size(8, 8, 1)` to operate on an 8×8 tile, or
  - a multiple of the tile size where each subgroup handles a tile.
- Avoid control-flow divergence around cooperative operations.
Example:
```wgsl
enable wgpu_cooperative_matrix;

struct Matrices {
    // Row-major tiles for A, B, C — use the `…T` load/store variants.
    data: array<f32>,
};

@group(0) @binding(0)
var<storage, read> buf_a: Matrices;
@group(0) @binding(1)
var<storage, read> buf_b: Matrices;
@group(0) @binding(2)
var<storage, read_write> buf_c: Matrices;

alias MatA = coop_mat8x8<f32, A>;
alias MatB = coop_mat8x8<f32, B>;
alias MatC = coop_mat8x8<f32, C>;

@compute @workgroup_size(8, 8, 1)
fn main(
    @builtin(workgroup_id) wg_id: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>,
) {
    // Compute tile offset; this is one of many possible mappings.
    let tile_index = wg_id.x;           // 1D tiling in this simple example
    let tile_offset = tile_index * 64u; // 8x8 tile has 64 elements

    // Base pointers for tiles of A, B, C.
    let base_a = &buf_a.data[tile_offset];
    let base_b = &buf_b.data[tile_offset];
    let base_c = &buf_c.data[tile_offset];

    let a: MatA = coopLoadT<f32, A>(base_a, 8u);
    let b: MatB = coopLoadT<f32, B>(base_b, 8u);
    let c: MatC = coopLoadT<f32, C>(base_c, 8u);
    let result: MatC = coopMultiplyAdd(a, b, c);
    coopStoreT(result, base_c, 8u);
}
```
Implementations must validate the following where possible:
- The `wgpu_cooperative_matrix` WGSL extension is enabled if any cooperative matrix types or builtins are used.
- Tile sizes `(M, N, K)` and scalar types `(ab_type, cr_type)` match at least one `CooperativeMatrixProperties` entry for the current adapter/backend.
- Workgroup size, shader stage, and other pipeline configuration constraints required by the backend are satisfied.
The following are examples of undefined behavior (non-exhaustive):
- Using cooperative matrix operations without enabling the WGSL extension.
- Using a cooperative matrix type `(M, N, T, R)` not supported by `Adapter::cooperative_matrix_properties()`.
- Mismatching sizes or roles in `coopMultiplyAdd` (e.g. incompatible M/N/K, or incorrect roles).
- Executing `coopLoad`/`coopLoadT`, `coopStore`/`coopStoreT`, or `coopMultiplyAdd` in divergent control flow within the cooperating execution group.
- Providing invalid, misaligned, or out-of-bounds pointers to any of the load/store builtins.
- Using a load/store variant (`coopLoad` vs `coopLoadT`, `coopStore` vs `coopStoreT`) whose memory layout does not match how the tile is actually stored.
- Overlapping `coopStore`/`coopStoreT` targets in a way that creates data races or aliasing that the memory model does not allow.
The example in `examples/features/src/cooperative_matrix` demonstrates using cooperative matrices to
compute `C = A * B + C`, where `A`, `B`, and `C` are all 64×64 matrices.
A high-level tiling strategy:
- Partition A, B, and C into 8×8 tiles.
- Launch one workgroup per output tile of C (for a 64×64 matrix, that is 8×8 = 64 tiles).
- Within each workgroup:
  - Loop over K-dimension tiles.
  - For each `k` tile:
    - Load an 8×8 tile of A (`MatA`).
    - Load an 8×8 tile of B (`MatB`).
    - Maintain an 8×8 accumulator tile (`MatC`) and repeatedly apply `coopMultiplyAdd`.
  - After the K loop, store the final accumulator tile back to C.
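The same tiling strategy can be written as a CPU reference, which is useful for validating GPU output. This sketch uses plain Rust with row-major 64×64 matrices in flat slices; each iteration of the inner `kt` loop plays the role of one `coopMultiplyAdd`:

```rust
const N: usize = 64;    // matrix dimension
const TILE: usize = 8;  // tile dimension

// Tiled C = A * B + C. Each (ti, tj) pair plays the role of one workgroup
// producing one output tile of C, and each `kt` step mirrors two coopLoads
// followed by a coopMultiplyAdd on 8x8 tiles.
fn tiled_matmul(a: &[f32], b: &[f32], c: &mut [f32]) {
    for ti in 0..N / TILE {
        for tj in 0..N / TILE {
            // Accumulator tile, initialized from C (the "+ C" part).
            let mut acc = [[0.0f32; TILE]; TILE];
            for i in 0..TILE {
                for j in 0..TILE {
                    acc[i][j] = c[(ti * TILE + i) * N + tj * TILE + j];
                }
            }
            // Loop over K-dimension tiles, accumulating into `acc`.
            for kt in 0..N / TILE {
                for i in 0..TILE {
                    for j in 0..TILE {
                        for p in 0..TILE {
                            acc[i][j] += a[(ti * TILE + i) * N + kt * TILE + p]
                                * b[(kt * TILE + p) * N + tj * TILE + j];
                        }
                    }
                }
            }
            // Store the final accumulator tile back to C.
            for i in 0..TILE {
                for j in 0..TILE {
                    c[(ti * TILE + i) * N + tj * TILE + j] = acc[i][j];
                }
            }
        }
    }
}
```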
Key points from the example:
- Workgroup size is chosen so that all cooperative operations are well-defined and efficient for 8×8 tiles.
- Host-side code:
  - Enables `Features::EXPERIMENTAL_COOPERATIVE_MATRIX`.
  - Queries `cooperative_matrix_properties` and verifies that 8×8 `f32` (or the chosen configuration) is supported.
  - Dispatches the compute pipeline with appropriate grid dimensions.
- Always query `adapter.cooperative_matrix_properties()` and check that the configuration your shaders use exists. Do not hard-code assumptions about available tile sizes or element types.
- Treat the cooperative execution group as an abstract concept; avoid making assumptions about how tiles are mapped to lanes beyond what is guaranteed by the spec.
- Avoid divergent control flow around cooperative operations.
- Consider providing a fallback non-cooperative implementation for devices that do not support the feature.
- This is an experimental extension; API and semantics may change across versions of `wgpu` and `naga`.