# Summary
This is the fifth PR in a series that enables TMA im2col mode (in
addition to the existing tiled mode) for NVIDIA GPUs. The goal of the
series is to support TMA im2col mode in Gluon DSL.
- First PR: #9202
- Second PR: #9225
- Third PR: #9303
- Fourth PR: #9305
- **Fifth PR (this PR): #9322**
PTX ISA documentation for TMA im2col mode:
https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-mode
TMA tensor descriptor documentation:
https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html
# Summary of Changes
Added LLVM lowering logic for `AsyncTMACopyGlobalToLocalOpConversion` to
support im2col mode.
## Im2col Mode Constraints
### pixelsPerColumn (non-contiguous dimension)
- **Maximum size**: 1024 elements
- **Corresponds to**: Spatial dimensions (N, D, H, W)
- **Block shape**: Restricted to match `shapePerCTA` (no splitting)
- **Rationale**: Avoids generating multiple TMA messages along spatial
dimensions, eliminating complex offset calculations that would depend on
input tensor shape and padding
- **Note**: 1024 is sufficient for most practical use cases
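The two pixelsPerColumn rules above can be read as a simple validity check. The following is a minimal Python sketch, not the actual lowering code: the function name `check_pixels_per_column` and its parameters are hypothetical, and only the 1024 limit and the "block shape must match `shapePerCTA`" rule come from this description.

```python
MAX_PIXELS_PER_COLUMN = 1024  # limit stated in this PR description


def check_pixels_per_column(block_pixels: int, shape_per_cta_pixels: int) -> None:
    """Hypothetical validator for the spatial (non-contiguous) block dimension."""
    if block_pixels > MAX_PIXELS_PER_COLUMN:
        raise ValueError(
            f"pixelsPerColumn={block_pixels} exceeds {MAX_PIXELS_PER_COLUMN}"
        )
    # No splitting along the spatial dimension: one TMA message must cover
    # the whole per-CTA extent, so the two sizes have to be equal.
    if block_pixels != shape_per_cta_pixels:
        raise ValueError(
            "pixelsPerColumn block shape must match shapePerCTA "
            f"({block_pixels} != {shape_per_cta_pixels})"
        )


check_pixels_per_column(128, 128)  # accepted: within limit and not split
```

Rejecting mismatched shapes up front is what lets the lowering assume that spatial offsets are always zero.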
### channelsPerPixel (contiguous dimension)
- **Maximum size**: 256 elements, or the number of elements that fit in
the swizzle byte size if swizzle is enabled
- **Multiple messages**: Supported when channel dimension exceeds block
size
- **Offset application**: Only coord[0] (channel coordinate in PTX
order) receives non-zero offsets
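As an illustration of how the channel dimension might be split into several TMA messages, here is a small Python sketch. It is not the code in this PR: the helper name `channel_message_offsets` is hypothetical, and it assumes that the swizzle limit means "as many elements as fit in the swizzle byte size" and that the per-CTA channel count divides evenly into messages.

```python
MAX_CHANNELS_PER_MESSAGE = 256  # limit stated in this PR description


def channel_message_offsets(channels_per_cta, elem_bytes, swizzle_bytes=0):
    """Return (channels per message, channel offset of each message).

    Hypothetical helper. swizzle_bytes=0 means swizzling is disabled.
    """
    limit = MAX_CHANNELS_PER_MESSAGE
    if swizzle_bytes:
        # Assumption: with swizzle enabled, one message spans at most
        # swizzle_bytes worth of elements along the contiguous dimension.
        limit = min(limit, swizzle_bytes // elem_bytes)
    block_channels = min(channels_per_cta, limit)
    if channels_per_cta % block_channels != 0:
        raise ValueError("channel count must be a multiple of the block size")
    # Each message gets a channel offset; spatial offsets stay at 0.
    offsets = list(range(0, channels_per_cta, block_channels))
    return block_channels, offsets


# 256 fp16 channels with a 128-byte swizzle: four messages of 64 channels.
print(channel_message_offsets(256, elem_bytes=2, swizzle_bytes=128))
```

Under these assumptions, each returned offset would be added to coord[0] only, which matches the "offset application" rule above.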
## Key Implementation Details
1. **Offset application**: For im2col mode, only the channel dimension
receives non-zero offsets; spatial dimension offsets are always 0
(verified by assertion)
2. **Im2col offsets reversal**: Spatial offsets (e.g., `off_w`, `off_h`)
are reversed to match PTX/CUDA innermost-to-outermost ordering,
consistent with coordinate handling
3. **Alignment with tiled mode**: These constraints align with tiled
mode behavior used for GEMM operations
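The reversal and offset rules in items 1 and 2 can be sketched in Python as follows. This is an illustrative model only: the function names are hypothetical, and the NHWC-style ordering `[n, h, w, c]` on the Triton side is an assumption made for the example.

```python
def to_ptx_order(coords, im2col_offsets):
    """Reverse outermost-to-innermost lists into PTX innermost-to-outermost order.

    Hypothetical helper; coords and im2col offsets are reversed the same way.
    """
    return list(reversed(coords)), list(reversed(im2col_offsets))


def apply_message_offset(ptx_coords, channel_offset):
    """Add a per-message offset to coord[0] (the channel) and nothing else."""
    return [ptx_coords[0] + channel_offset] + list(ptx_coords[1:])


# Assumed source order: coords [n, h, w, c], im2col offsets [off_h, off_w].
names, offs = to_ptx_order(["n", "h", "w", "c"], ["off_h", "off_w"])
print(names)  # channel first, batch last
print(offs)   # off_w first, off_h last

# Second message of a split along channels: only coord[0] moves.
print(apply_message_offset([0, 3, 2, 1], channel_offset=64))
```

Keeping both reversals in one place makes it harder for coordinates and im2col offsets to drift out of sync.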
# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a
comment.
- [x] I have written a PR description following these
[rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [ ] I have not added any `lit` tests.
  - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)