
[WIP][AMD][TDM] Add partial TDM copy support via warp_bases#10056

Draft
jungpark-mlir wants to merge 7 commits into triton-lang:main from jungpark-mlir:tdm-special

Conversation

@jungpark-mlir
Contributor

This PR adds an optional warp_bases attribute to AsyncTDMCopyGlobalToLocalOp that enables partial TDM copy: only a subset of warps perform useful TDM loads, while the rest get pred=0 in their descriptor (a hardware no-op: the instruction is still issued, but no data moves).

The attribute is a flattened (log2(num_warps), ndim) matrix mapping each warpId bit to a tile offset. Non-zero rows form a contiguous prefix identifying active warps; trailing all-zero rows mark inactive ones. For example, with num_warps=8 and a 256x64 block where 4 warps suffice, warp_bases = [(64,0), (128,0), (0,0)] means bits 0-1 of warpId contribute offsets along dim0 (4 active warps), while bit 2 is (0,0) (warps 4-7 are inactive duplicates). The verifier enforces that the non-zero bases match the greedy distribution from tdmGetWarpDistribution.
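To make the encoding concrete, here is a rough Python sketch of how such a flattened warp_bases matrix decodes into per-warp tile offsets and an active-warp count. The name decode_warp_bases and the exact checks are illustrative, not the PR's API; the real verifier additionally checks against tdmGetWarpDistribution.

```python
def decode_warp_bases(warp_bases, num_warps, ndim):
    """Return (active_warps, offsets): offsets[w] is the tile offset of warp w."""
    log2_warps = num_warps.bit_length() - 1
    assert num_warps == 1 << log2_warps, "num_warps must be a power of two"
    assert len(warp_bases) == log2_warps * ndim, "expect flattened (log2(num_warps), ndim)"

    # One row per warpId bit; each row is an ndim offset vector.
    rows = [tuple(warp_bases[i * ndim:(i + 1) * ndim]) for i in range(log2_warps)]
    zero = tuple([0] * ndim)

    # Non-zero rows must form a contiguous prefix; trailing zero rows are inactive bits.
    nonzero = 0
    while nonzero < log2_warps and rows[nonzero] != zero:
        nonzero += 1
    assert all(r == zero for r in rows[nonzero:]), "non-contiguous non-zero prefix"
    active_warps = 1 << nonzero

    # A warp's offset is the XOR-style sum of the rows for its set warpId bits.
    offsets = []
    for w in range(num_warps):
        off = [0] * ndim
        for bit in range(log2_warps):
            if (w >> bit) & 1:
                for d in range(ndim):
                    off[d] += rows[bit][d]
        offsets.append(tuple(off))
    return active_warps, offsets
```

With the example from the text (num_warps=8, bases [(64,0), (128,0), (0,0)]), warps 0-3 get distinct offsets 0, 64, 128, 192 along dim0, and warps 4-7 duplicate warps 0-3.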

During lowering, fillTDMDescriptor re-encodes per-warp tile dimensions for the reduced warp count and ANDs the user predicate with warpId < activeWarps. The feature works with all shared memory layouts including partitioned encodings. This is useful whenever num_warps exceeds what is needed to tile a copy — warp-pipelined loops, producer/consumer patterns, or kernels oversubscribing warps for compute density. The PR includes an example, verifier tests, and lowering lit tests.
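A sketch of the per-warp tile re-encoding, under the assumption that each non-zero warp_bases row doubles the warp count along its dimension and therefore halves the per-warp extent there. The helper name per_warp_tile is hypothetical; the actual logic lives in fillTDMDescriptor.

```python
def per_warp_tile(block_shape, warp_bases_rows):
    """Each non-zero base row doubles the warp count along its dimension,
    halving the per-warp tile extent. Zero rows (inactive warp bits) leave
    the tile unchanged."""
    tile = list(block_shape)
    for row in warp_bases_rows:
        for d, base in enumerate(row):
            if base != 0:
                tile[d] //= 2
    return tile
```

For the 256x64 block with bases [(64,0), (128,0), (0,0)], this yields a 64x64 tile per active warp, consistent with the offset spacing of 64 along dim0.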

The performance impact is not yet clear; we're investigating.

Add an optional warp_bases attribute to AsyncTDMCopyGlobalToLocalOp that
enables TDM warp specialization: only a subset of warps (activeWarps)
perform TDM copies while the remaining warps get pred=0 (hardware no-op).

- MLIR op definition: add warp_bases as OptionalAttr<DenseI64ArrayAttr>
- Verifier: validate power-of-two, contiguous prefix, greedy distribution
- Pybind: pass warp_bases from Python to MLIR
- Python API: add warp_bases param to async_load with validation
- Lowering: re-encode per-warp tile dims in TDM descriptor for
  activeWarps; emit layout_pred (warpId < activeWarps) ANDed with
  user pred
- Example: add gemm_tdm_specialized_pipelined_warp_pipelined_kernel
  with --4warp-tdm CLI flag
…or warp specialization

Replace the layout_pred approach (ANDing pred with warpId < activeWarps in
fillTDMDescriptor) with a conditional branch in emitTDMIntrinsic that skips
the entire TDM emission for inactive warps. This avoids computing dead
descriptor values and ensures tensorcnt is not incremented for inactive warps.

Also swap warp pipeline stage priorities in the specialized GEMM example
(compute stage gets higher priority).
…cation for warp specialization"

This reverts commit 17fffd9.

Add verifier negative tests (wrong size, non-contiguous prefix, greedy
mismatch) and lowering tests (predication logic, partitioned layout
instruction count) for the warp_bases attribute.

Rename "warp specialization" to "partial TDM copy" in all TDM
warp_bases-related comments and docs to better describe the mechanism.

activeWarps=0 exclusively means "warp_bases absent, all warps active."
When warp_bases is present, activeWarps is at least 1 (2^0 for all-zero
rows). This distinction matters for understanding the conditional logic
in fillTDMDescriptor and emitTDMLoadStore.
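The sentinel semantics described above can be modeled with a small sketch. The function name tdm_pred is illustrative; the real predicate logic is spread across fillTDMDescriptor and emitTDMLoadStore.

```python
def tdm_pred(user_pred, warp_id, active_warps):
    """Model of the predicate logic: active_warps == 0 is the sentinel for
    'warp_bases absent', meaning every warp performs its copy; otherwise
    duplicate warps (warp_id >= active_warps) are masked off."""
    if active_warps == 0:
        return user_pred
    return user_pred and (warp_id < active_warps)
```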
Comment on lines +103 to +104
# Partial TDM copy variant: only a subset of warps issue TDM copies.
# Duplicate warps get pred=0 (hardware no-op), freeing TDM bandwidth.
Collaborator


Just passing by, as I haven't looked at the PR in detail:
I don't think we want to expose the concept of warps at the Gluon level. I think there is a layering problem.

Contributor Author

@jungpark-mlir Apr 16, 2026


Yes, that was the main concept we tried hard not to break.
This change does not expose anything about warps, except for warp_bases, which is an element of the LinearLayout.
It only adds the ability to declare that the regions pointed to by warp0–3 and warp4–7 overlap, for example.
When regions are duplicated, tdm_copy for those warps is automatically disabled.

Does this still sound unacceptable?

Member


In my mental model this should not break block programming--all warps are still collectively programmed and they go through uniform control flow paths.

As Jungwook pointed out, this just exposes control over which warp is responsible for which elements in the tensor--we have such controls for threads, warps, etc. in blocked layouts, linear layouts, and so on. It gives the ability to declare that two warps cover the same elements, so that one warp can effectively mask off its corresponding tdm load, given the duplicated load. (The masking off is achieved with a predicate, so the wave still sees and executes the tdm instruction; it's not like warp specialization.)

Collaborator


OK, maybe I need to spend more time understanding this. The TDM copy goes from global to shared memory, so there shouldn't be any warp concept in the linear layout?

Member


For AMD, TDM is a warp-level instruction. A single Gluon gfx1250.tdm.async_load op is, under the hood, done by all the warps collectively, each warp taking a slice of the tensor. Right now we just have some heuristics to deduce the warp distribution. So warps are involved there; it's just implicit right now. This makes it explicit and controllable. No threads, though. Even for NVIDIA I think we have warps involved, distributing work to different warps?

Collaborator


But this is a leaking abstraction; the way those ops are distributed across warps is not meant to be exposed at the language level. The ops should be tile-level ops.


return tensor_descriptor(handle, shape, strides, type)


def _validate_warp_bases(warp_bases, block_shape, num_warps):
Member


We should define this validation as a static method in the C++ op definition and expose it to Python via a binding, so that the C++ op verifier and the Python API share the same logic.

: std::nullopt,
numDims);

// When partial TDM copy is active, the per-warp block shape differs from
Member


Hmm. I'm not sure this is the natural way to implement it. If this is given by the developer, we should be able to replace the logic in distributeTDMWarps with the provided bases, and then rely on free variables to handle the masking etc., like gather/scatter?

