68 changes: 47 additions & 21 deletions docs/api-specs/cooperative_matrix.md
@@ -221,35 +221,57 @@ These roles are part of the type; they are not interchangeable.
WGSL provides built-in functions for operating on cooperative matrices. The exact spelling may
change; the semantics are:

#### `coopLoad`
#### `coopLoad` / `coopLoadT`

Collectively load a tile from memory into a cooperative matrix.
Collectively load a tile from memory into a cooperative matrix. Two variants
select the memory layout:

```/dev/null/example.wgsl#L1-6
- `coopLoad` — matrix is stored **column-major** in memory; `stride` is the
number of elements between adjacent columns.
- `coopLoadT` — matrix is stored **row-major** in memory (i.e. transposed
relative to the canonical column-major layout used by `coopLoad`);
`stride` is the number of elements between adjacent rows. This is the
natural fit for C-style `ptr[i * num_cols + j]` storage.

```/dev/null/example.wgsl#L1-10
fn coopLoad<T, R>(
ptr: ptr<STORAGE_CLASS, T>, // base pointer to scalar or vector elements
stride: u32 // stride (in elements) between rows/columns
stride: u32 // elements between adjacent columns
) -> coop_matMxN<T, R>;

fn coopLoadT<T, R>(
ptr: ptr<STORAGE_CLASS, T>, // base pointer to scalar or vector elements
stride: u32 // elements between adjacent rows
) -> coop_matMxN<T, R>;
```

- Loads an M×N tile (or M×K / K×N, depending on role and operation) from memory pointed to by `ptr`.
- `stride` describes the layout in memory: the number of elements between adjacent columns for `coopLoad`, or between adjacent rows for `coopLoadT`.
- All invocations in the cooperative group must call `coopLoad` in a converged fashion.
- All invocations in the cooperative group must call the chosen variant in a converged fashion.
- Memory address range must be valid and properly aligned for the scalar type.

> Implementation note: Each lane contributes to filling the tile based on an implementation-defined mapping from
> invocation/lane ID to sub-fragment of the matrix.
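Host-side, the difference between the two variants reduces to how an element `(i, j)` maps to a flat offset. A minimal Rust sketch of that mapping (the helper names are hypothetical, not part of any API):

```rust
// Illustrative element-(i, j) -> flat-offset mapping implied by the two
// load/store variants. Hypothetical host-side helpers, not a real API.

/// Column-major (`coopLoad`): `stride` separates adjacent columns.
fn offset_col_major(i: usize, j: usize, stride: usize) -> usize {
    j * stride + i
}

/// Row-major (`coopLoadT`): `stride` separates adjacent rows.
fn offset_row_major(i: usize, j: usize, stride: usize) -> usize {
    i * stride + j
}

fn main() {
    // A densely packed 8x8 tile: stride = 8 in both conventions.
    let stride = 8;
    // Element (row 2, col 3):
    assert_eq!(offset_col_major(2, 3, stride), 26); // 3 * 8 + 2
    assert_eq!(offset_row_major(2, 3, stride), 19); // 2 * 8 + 3
    // Using the wrong variant reads the transposed element:
    assert_eq!(offset_col_major(2, 3, stride), offset_row_major(3, 2, stride));
}
```

This also makes the failure mode of mixing up the variants concrete: the tile comes back transposed rather than corrupted.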

#### `coopStore`
#### `coopStore` / `coopStoreT`

Collectively store a cooperative matrix tile back to memory. Variant
selection mirrors the load builtins:

Collectively store a cooperative matrix tile back to memory.
- `coopStore` — writes **column-major**; `stride` between columns.
- `coopStoreT` — writes **row-major**; `stride` between rows.

```/dev/null/example.wgsl#L8-13
```/dev/null/example.wgsl#L12-23
fn coopStore<T, R>(
value: coop_matMxN<T, R>,
ptr: ptr<STORAGE_CLASS, T>,
stride: u32
);

fn coopStoreT<T, R>(
value: coop_matMxN<T, R>,
ptr: ptr<STORAGE_CLASS, T>,
stride: u32
);
```

- Stores `value` into the memory region addressed by `ptr` with given `stride`.
@@ -287,6 +309,8 @@ alias MatA = coop_mat8x8<f32, A>;
alias MatB = coop_mat8x8<f32, B>;
alias MatC = coop_mat8x8<f32, C>;

// Assumes each tile is stored column-major in memory (the plain `coopLoad`
// / `coopStore` form); use `coopLoadT` / `coopStoreT` for row-major storage.
fn matmul_tile(
ptr_a: ptr<storage, f32>,
ptr_b: ptr<storage, f32>,
@@ -311,8 +335,8 @@ underlying scalar type semantics (e.g. IEEE-754 for floats).
Cooperative matrix operations are **collective**:

- All invocations in the relevant execution group must execute each cooperative operation in uniform control flow:
- Using `coopLoad`, `coopStore`, or `coopMultiplyAdd` in divergent control flow (e.g. some lanes taking
a branch, others not) is undefined behavior.
- Using `coopLoad` / `coopLoadT`, `coopStore` / `coopStoreT`, or `coopMultiplyAdd` in divergent control flow
(e.g. some lanes taking a branch, others not) is undefined behavior.
- The exact execution group may be a workgroup, a SIMD-group / subgroup, or another backend-specific
granularity; shaders must treat it abstractly.

@@ -334,7 +358,7 @@ Example:
enable wgpu_cooperative_matrix;

struct Matrices {
// Row-major tiles for A, B, C.
// Row-major tiles for A, B, C — use the `…T` load/store variants.
data: array<f32>,
};

@@ -363,12 +387,12 @@ fn main(
let base_b = &buf_b.data[tile_offset];
let base_c = &buf_c.data[tile_offset];

let a: MatA = coopLoad<f32, A>(base_a, 8u);
let b: MatB = coopLoad<f32, B>(base_b, 8u);
let c: MatC = coopLoad<f32, C>(base_c, 8u);
let a: MatA = coopLoadT<f32, A>(base_a, 8u);
let b: MatB = coopLoadT<f32, B>(base_b, 8u);
let c: MatC = coopLoadT<f32, C>(base_c, 8u);

let result: MatC = coopMultiplyAdd(a, b, c);
coopStore(result, base_c, 8u);
coopStoreT(result, base_c, 8u);
}
```
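A scalar reference for what a single 8×8 `coopMultiplyAdd` step computes (`D = A * B + C`) can be sketched host-side. This is illustrative Rust, assuming row-major tiles to mirror the `…T` variants used in the example:

```rust
// Scalar reference for one 8x8 cooperative multiply-add: D = A * B + C.
// Illustrative sketch; tiles are row-major slices of length 64.
const N: usize = 8;

fn multiply_add(a: &[f32], b: &[f32], c: &[f32]) -> Vec<f32> {
    let mut d = c.to_vec();
    for i in 0..N {
        for j in 0..N {
            for k in 0..N {
                // Row-major indexing: element (r, c) lives at r * N + c.
                d[i * N + j] += a[i * N + k] * b[k * N + j];
            }
        }
    }
    d
}

fn main() {
    // Sanity check: A = identity, C = 0  =>  D == B.
    let mut a = vec![0.0f32; N * N];
    for i in 0..N {
        a[i * N + i] = 1.0;
    }
    let b: Vec<f32> = (0..(N * N)).map(|i| i as f32).collect();
    let c = vec![0.0f32; N * N];
    assert_eq!(multiply_add(&a, &b, &c), b);
}
```

A host-side reference like this is also how the accompanying example can validate GPU output, up to floating-point rounding.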

@@ -391,11 +415,13 @@ The following are examples of **undefined behavior** (non-exhaustive):
- Using a cooperative matrix type `(M, N, T, R)` not supported by
`Adapter::cooperative_matrix_properties()`.
- Mismatching sizes or roles in `coopMultiplyAdd` (e.g. incompatible M/N/K, or incorrect roles).
- Executing `coopLoad`, `coopStore`, or `coopMultiplyAdd` in divergent control flow within the
cooperating execution group.
- Providing invalid, misaligned, or out-of-bounds pointers to `coopLoad` / `coopStore`.
- Overlapping `coopStore` targets in a way that creates data races or aliasing that the memory
model does not allow.
- Executing `coopLoad` / `coopLoadT`, `coopStore` / `coopStoreT`, or `coopMultiplyAdd` in divergent
control flow within the cooperating execution group.
- Providing invalid, misaligned, or out-of-bounds pointers to any of the load/store builtins.
- Using a load/store variant (`coopLoad` vs `coopLoadT`, `coopStore` vs `coopStoreT`) whose memory
layout does not match how the tile is actually stored.
- Overlapping `coopStore` / `coopStoreT` targets in a way that creates data races or aliasing that
the memory model does not allow.

---

23 changes: 20 additions & 3 deletions examples/features/src/cooperative_matrix/mod.rs
@@ -221,9 +221,26 @@ async fn execute(

// Initialize matrices
// A is MxK, B is KxN, C is MxN (result)
// Use f32 for computation, convert to f16 if needed for GPU
let matrix_a_f32: Vec<f32> = (0..M * K).map(|i| (i % 7) as f32 * 0.1).collect();
let matrix_b_f32: Vec<f32> = (0..K * N).map(|i| (i % 11) as f32 * 0.1).collect();
// Use f32 for computation, convert to f16 if needed for GPU.
//
// The init pattern `(i * p + j * q) % m` is chosen so that neither A
// nor B is symmetric in (i, j): if the row and column weights were
// congruent modulo the divisor, the matrix would be symmetric and the
// test could no longer distinguish row-major from column-major loads.
// The weights here (`3, 5` mod 11 for A; `7, 11` mod 13 for B) keep
// the tiles asymmetric for any M/N/K.
let matrix_a_f32: Vec<f32> = (0..M * K)
.map(|idx| {
let (i, j) = (idx / K, idx % K);
((i * 3 + j * 5) % 11) as f32 * 0.1
})
.collect();
let matrix_b_f32: Vec<f32> = (0..K * N)
.map(|idx| {
let (i, j) = (idx / N, idx % N);
((i * 7 + j * 11) % 13) as f32 * 0.1
})
.collect();
let matrix_c_f32: Vec<f32> = vec![0.0; (M * N) as usize];
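The asymmetry claim in the comment above can be checked in isolation. A standalone sketch, with illustrative 8×8 dimensions:

```rust
// Verify that the A-matrix init pattern is asymmetric in (i, j), so a
// row-major / column-major mix-up would change the observed values.
// Standalone sketch; the 8x8 dimensions are illustrative.
fn main() {
    let (m, k) = (8usize, 8usize);
    let a: Vec<f32> = (0..(m * k))
        .map(|idx| {
            let (i, j) = (idx / k, idx % k);
            ((i * 3 + j * 5) % 11) as f32 * 0.1
        })
        .collect();
    // At least one pair with a[i][j] != a[j][i] proves asymmetry;
    // e.g. a[0][1] = 0.5 while a[1][0] = 0.3.
    let asymmetric = (0..m).any(|i| (0..k).any(|j| a[i * k + j] != a[j * k + i]));
    assert!(asymmetric);
}
```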

// Element size depends on precision
8 changes: 4 additions & 4 deletions examples/features/src/cooperative_matrix/shader.wgsl
@@ -38,22 +38,22 @@ fn main(@builtin(workgroup_id) workgroup_id: vec3<u32>) {

// Load the C tile (accumulator)
let c_offset = tile_row * stride + tile_col;
var c_tile = coopLoad<coop_mat8x8<f32, C>>(&matrix_c[c_offset], stride);
var c_tile = coopLoadT<coop_mat8x8<f32, C>>(&matrix_c[c_offset], stride);

// Iterate over K dimension in tiles
for (var k: u32 = 0u; k < K; k += TILE_SIZE) {
// Load A tile: rows [tile_row, tile_row+8), cols [k, k+8)
let a_offset = tile_row * K + k;
let a_tile = coopLoad<coop_mat8x8<f32, A>>(&matrix_a[a_offset], K);
let a_tile = coopLoadT<coop_mat8x8<f32, A>>(&matrix_a[a_offset], K);

// Load B tile: rows [k, k+8), cols [tile_col, tile_col+8)
let b_offset = k * stride + tile_col;
let b_tile = coopLoad<coop_mat8x8<f32, B>>(&matrix_b[b_offset], stride);
let b_tile = coopLoadT<coop_mat8x8<f32, B>>(&matrix_b[b_offset], stride);

// Multiply and accumulate: C += A * B
c_tile = coopMultiplyAdd(a_tile, b_tile, c_tile);
}

// Store the result back to C
coopStore(c_tile, &matrix_c[c_offset], stride);
coopStoreT(c_tile, &matrix_c[c_offset], stride);
}
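The base-offset arithmetic in the shader can be sanity-checked host-side. An illustrative Rust sketch (helper names and dimensions are hypothetical):

```rust
// Row-major tile base offsets matching the arithmetic in the shader.
// Hypothetical helpers; the 16x16 dimensions are illustrative.
fn c_offset(tile_row: u32, tile_col: u32, stride: u32) -> u32 {
    // C tile whose top-left element is (tile_row, tile_col), row-major.
    tile_row * stride + tile_col
}

fn a_offset(tile_row: u32, k: u32, k_dim: u32) -> u32 {
    // A tile covering rows [tile_row, tile_row + 8) and cols [k, k + 8),
    // row-major with row length k_dim.
    tile_row * k_dim + k
}

fn main() {
    // For a 16x16 C matrix (stride = 16), the tile starting at
    // (row 8, col 8) begins at flat element 8 * 16 + 8 = 136.
    assert_eq!(c_offset(8, 8, 16), 136);
    // For A with K = 16, the second K-tile in tile-row 8 starts at the
    // same flat offset.
    assert_eq!(a_offset(8, 8, 16), 136);
}
```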
14 changes: 12 additions & 2 deletions naga/src/back/msl/writer.rs
@@ -2917,7 +2917,13 @@ impl<W: Write> Writer<W> {
self.put_access_chain(data.pointer, context.policies.index, context)?;
write!(self.out, ", ")?;
self.put_expression(data.stride, context, true)?;
write!(self.out, ", {})", data.row_major)?;
// Metal's `simdgroup_load` treats its `transpose` flag as
// "memory is transposed from the simdgroup_matrix's canonical
// layout". On Apple GPUs that canonical layout is row-major,
// so `transpose=false` loads from row-major memory. WGSL's
// `coopLoadT` (row_major=true) = row-major memory, so it must
// map to `transpose=false`. Hence the negation.
write!(self.out, ", {})", !data.row_major)?;
}
crate::Expression::CooperativeMultiplyAdd { a, b, c } => {
if context.lang_version < (2, 3) {
@@ -4216,7 +4222,11 @@ impl<W: Write> Writer<W> {
)?;
write!(self.out, ", ")?;
self.put_expression(data.stride, &context.expression, true)?;
if data.row_major {
// See the comment in `CooperativeLoad` above: WGSL's
// row_major flag is negated when emitting Metal's
// `transpose` flag, so a col-major store (row_major=false)
// must use `transpose=true`.
if !data.row_major {
let matrix_origin = "0";
let transpose = true;
write!(self.out, ", {matrix_origin}, {transpose}")?;
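The mapping described in the comments above amounts to a single negation. As a standalone truth table (hypothetical helper, not the actual writer code):

```rust
// WGSL row_major flag -> Metal `transpose` flag, per the backend comments
// above. Hypothetical free function, not the actual naga writer code.
fn metal_transpose_flag(wgsl_row_major: bool) -> bool {
    // Metal's canonical simdgroup_matrix layout is row-major, so
    // `transpose = true` means "memory is column-major".
    !wgsl_row_major
}

fn main() {
    // coopLoadT / coopStoreT (row-major memory) -> transpose = false.
    assert_eq!(metal_transpose_flag(true), false);
    // coopLoad / coopStore (column-major memory) -> transpose = true.
    assert_eq!(metal_transpose_flag(false), true);
}
```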
4 changes: 2 additions & 2 deletions naga/tests/out/msl/wgsl-cooperative-matrix.metal
@@ -30,13 +30,13 @@ metal::simdgroup_float8x8 NagaCooperativeMultiplyAdd(const thread metal::simdgro
metal::simdgroup_float8x8 b = {};
metal::simdgroup_float8x8 c = {};
metal::simdgroup_float8x8 d = {};
c = NagaCooperativeLoad(&ext[4], 8u, false);
c = NagaCooperativeLoad(&ext[4], 8u, true);
metal::simdgroup_float8x8 _e6 = a;
metal::simdgroup_float8x8 _e8 = b;
metal::simdgroup_float8x8 _e9 = c;
d = NagaCooperativeMultiplyAdd(_e6, _e8, _e9);
metal::simdgroup_float8x8 _e12 = d;
simdgroup_store(_e12, &ext[0], 8u);
simdgroup_store(_e12, &ext[0], 8u, 0, true);
metal::simdgroup_float8x8 _e16 = d;
c = _e16;
return;