[WIP][AMD][TDM] Add partial TDM copy support via warp_bases#10056
jungpark-mlir wants to merge 7 commits into triton-lang:main from
Conversation
Add an optional warp_bases attribute to AsyncTDMCopyGlobalToLocalOp that enables TDM warp specialization: only a subset of warps (activeWarps) perform TDM copies while the remaining warps get pred=0 (hardware no-op).
- MLIR op definition: add warp_bases as OptionalAttr<DenseI64ArrayAttr>
- Verifier: validate power-of-two, contiguous prefix, greedy distribution
- Pybind: pass warp_bases from Python to MLIR
- Python API: add warp_bases param to async_load with validation
- Lowering: re-encode per-warp tile dims in the TDM descriptor for activeWarps; emit layout_pred (warpId < activeWarps) ANDed with the user pred
- Example: add gemm_tdm_specialized_pipelined_warp_pipelined_kernel with --4warp-tdm CLI flag
…or warp specialization
Replace the layout_pred approach (ANDing pred with warpId < activeWarps in fillTDMDescriptor) with a conditional branch in emitTDMIntrinsic that skips the entire TDM emission for inactive warps. This avoids computing dead descriptor values and ensures tensorcnt is not incremented for inactive warps. Also swap the warp pipeline stage priorities in the specialized GEMM example (the compute stage gets higher priority).
…cation for warp specialization"
This reverts commit 17fffd9.
Add verifier negative tests (wrong size, non-contiguous prefix, greedy mismatch) and lowering tests (predication logic, partitioned layout instruction count) for the warp_bases attribute. Rename "warp specialization" to "partial TDM copy" in all TDM warp_bases-related comments and docs to better describe the mechanism.
activeWarps=0 means exactly one thing: warp_bases is absent and all warps are active. When warp_bases is present, activeWarps is at least 1 (2^0 for all-zero rows). This distinction matters for understanding the conditional logic in fillTDMDescriptor and emitTDMLoadStore.
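A minimal Python sketch of this convention (the helper name is invented for illustration; the real logic lives in the C++ lowering):

```python
from typing import List, Optional, Tuple


def active_warps(warp_bases: Optional[List[Tuple[int, ...]]]) -> int:
    """Sketch of the convention: 0 encodes "warp_bases absent, all warps
    active"; otherwise 2^(number of non-zero prefix rows), which is at
    least 1 (2^0) even when every row is all zero."""
    if warp_bases is None:
        return 0  # attribute absent: all warps active
    nonzero_prefix = 0
    for row in warp_bases:
        if any(row):
            nonzero_prefix += 1
        else:
            break  # trailing all-zero rows mark inactive warps
    return 1 << nonzero_prefix
```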
| # Partial TDM copy variant: only a subset of warps issue TDM copies. |
| # Duplicate warps get pred=0 (hardware no-op), freeing TDM bandwidth. |
Just passing by, as I haven't looked at the PR in detail:
I don't think we want to expose the concept of warps at the Gluon level. I think there is a layering problem.
Yes, that was the main principle we tried hard not to break.
This change does not expose anything about warps, except for warp_bases, which is an element of the LinearLayout.
It only adds the ability to declare that the regions pointed to by, for example, warps 0–3 and warps 4–7 overlap.
When regions are duplicated, tdm_copy for those warps is automatically disabled.
Does this still sound unacceptable?
In my mental model this should not break the block programming model: all warps are still collectively programmed and they go through uniform control flow paths.
As Jungwook pointed out, this just exposes control over which warp is responsible for which elements in the tensor; we already have such controls for threads, warps, etc. in blocked layouts and linear layouts. It gives the ability to declare that two warps cover the same elements, so one warp can effectively mask off its corresponding tdm load when the load is duplicated. (The masking is achieved with a predicate, so the wave still sees and executes the tdm instruction per se; it is not like warp specialization.)
ok, maybe I need to spend more time understanding this. The TDM copy is from global to shared memory, so there shouldn't be any warp concept in the linear layout?
For AMD, TDM is a warp-level instruction. A single Gluon gfx1250.tdm.async_load op is, under the hood, done by all the warps collectively, each warp taking a slice of the tensor. Right now we just have some heuristics to deduce the warp distribution. So warps are involved there; it's just implicit right now. This makes it explicit and controllable. No threads though. Even for NVIDIA, I think we have warps involved and distribute work to different warps?
But this is leaking the abstraction: the way those ops are distributed over warps is not meant to be exposed at the language level. The ops should be tile-level ops.
| return tensor_descriptor(handle, shape, strides, type) |
| def _validate_warp_bases(warp_bases, block_shape, num_warps): |
We should define this validation as a static method in the C++ op definition and expose it to Python via a binding, so that the C++ op verifier and Python share the same logic.
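As a sketch of what such shared validation could check (a hypothetical Python mirror of the verifier rules listed in the PR summary: power-of-two num_warps, correct matrix shape, contiguous non-zero prefix; the greedy-distribution check against tdmGetWarpDistribution is omitted here since it depends on internal heuristics):

```python
def validate_warp_bases(warp_bases, num_warps, ndim):
    """Hypothetical Python mirror of the C++ verifier rules.

    Checks: num_warps is a power of two, warp_bases is a
    log2(num_warps) x ndim matrix, and its non-zero rows form a
    contiguous prefix (trailing rows must be all zero).
    """
    if num_warps <= 0 or num_warps & (num_warps - 1):
        raise ValueError("num_warps must be a power of two")
    n_rows = num_warps.bit_length() - 1  # log2(num_warps)
    if len(warp_bases) != n_rows or any(len(r) != ndim for r in warp_bases):
        raise ValueError(f"warp_bases must be a {n_rows}x{ndim} matrix")
    seen_zero_row = False
    for row in warp_bases:
        if any(row):
            if seen_zero_row:
                raise ValueError("non-zero rows must form a contiguous prefix")
        else:
            seen_zero_row = True
```

Defining this once in C++ and binding it would keep the Python-side `_validate_warp_bases` and the op verifier from drifting apart.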
| : std::nullopt, |
| numDims); |
| // When partial TDM copy is active, the per-warp block shape differs from |
Hmm. I'm not sure this is the natural way to implement it. If this is given by the developer, we should be able to replace the logic in distributeTDMWarps with the provided bases, and then rely on the free variables to handle masking etc., like gather/scatter?
This PR adds an optional `warp_bases` attribute to `AsyncTDMCopyGlobalToLocalOp` that enables partial TDM copy: only a subset of warps perform useful TDM loads, while the rest get `pred=0` in their descriptor (hardware no-op; the instruction is still issued but no data moves).

The attribute is a flattened `(log2(num_warps), ndim)` matrix mapping each `warpId` bit to a tile offset. Non-zero rows form a contiguous prefix identifying active warps; trailing all-zero rows mark inactive ones. For example, with `num_warps=8` and a `256x64` block where 4 warps suffice, `warp_bases = [(64,0), (128,0), (0,0)]` means bits 0-1 of `warpId` contribute offsets along dim0 (4 active warps), while bit 2 is `(0,0)` (warps 4-7 are inactive duplicates). The verifier enforces that the non-zero bases match the greedy distribution from `tdmGetWarpDistribution`.

During lowering, `fillTDMDescriptor` re-encodes the per-warp tile dimensions for the reduced warp count and ANDs the user predicate with `warpId < activeWarps`. The feature works with all shared memory layouts, including partitioned encodings. It is useful whenever `num_warps` exceeds what is needed to tile a copy: warp-pipelined loops, producer/consumer patterns, or kernels oversubscribing warps for compute density. The PR includes an example, verifier tests, and lowering lit tests.

Performance impact is still unclear; we're investigating.
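The bit-to-offset mapping described above can be sketched in Python (the helper name is invented for illustration; bases are combined here with XOR in linear-layout style, which coincides with addition for these disjoint-bit offsets):

```python
def warp_tile_offset(warp_id, warp_bases):
    """Sketch: each set bit i of warp_id contributes warp_bases[i] to the
    warp's tile offset. Warps that differ only in bits whose rows are
    all-zero alias the same tile, which is what makes them duplicates."""
    ndim = len(warp_bases[0])
    offset = [0] * ndim
    for bit, base in enumerate(warp_bases):
        if (warp_id >> bit) & 1:
            offset = [o ^ b for o, b in zip(offset, base)]
    return tuple(offset)


# Example from the description: num_warps=8, 4 active warps along dim0.
bases = [(64, 0), (128, 0), (0, 0)]
for w in range(8):
    print(w, warp_tile_offset(w, bases))
# Warps 4-7 reproduce the offsets of warps 0-3: they are inactive duplicates.
```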