
[Nvidia] Enable TMA im2col mode -- Tensor descriptor #2

Closed

bingyizh233 wants to merge 35 commits into tma-im2col-asynTmaCopy from tma-im2col-descriptor

Conversation

@bingyizh233 (Owner) commented Jan 14, 2026

Summary

This is the second PR in a series that enables TMA im2col mode (in addition to the existing tiled mode) for NVIDIA GPUs. The goal of the series is to support TMA im2col mode in the Triton compiler and the Gluon DSL.

First PR: [Nvidia] Enable TMA im2col mode -- Tensor descriptor

PTX ISA documentation for TMA im2col mode: https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-mode
TMA tensor descriptor documentation: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html

Summary of changes

Extend TensorDescType to support both tiled and im2col TMA modes for tensor memory access. This enables convolution-friendly access patterns via NVIDIA’s TMA im2col functionality.
The additional information in TensorDescType is required because the host-side TMA path needs these parameters to build the tensor descriptor at runtime.

Changes:

  • Add new optional parameters to TT_TensorDescType:
    • mode: StringAttr ("tiled" or "im2col", default: "tiled")
    • elementStrides: DenseI64ArrayAttr (optional)
    • pixelBoxLowerCorner: DenseI64ArrayAttr (optional)
    • pixelBoxUpperCorner: DenseI64ArrayAttr (optional)
    • channelsPerPixel: optional<int64_t>
    • pixelsPerColumn: optional<int64_t>

elementStrides holds the traversal stride information, which can be used for dilated convolution.
pixelBoxLowerCorner and pixelBoxUpperCorner specify the padding information for the input.
channelsPerPixel and pixelsPerColumn specify the block size in shared memory; the block shape is pixelsPerColumn x channelsPerPixel, i.e. BLOCK_M (or BLOCK_N) x BLOCK_K.

  • Add a type verifier with the following constraints:

    • mode must be either "tiled" or "im2col"
    • In tiled mode, im2col-specific parameters must not be set
    • In im2col mode, blockType must be rank-2, and its dimensions must match channelsPerPixel (N) and pixelsPerColumn (M), if those are specified
  • Update Types.h to include BuiltinAttributes.h for DenseI64ArrayAttr

  • Update test/Triton/ops.mlir with im2col mode tests
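The verifier rules above can be modeled with a short Python sketch (illustrative only, with hypothetical names; the actual verifier is C++ code attached to TT_TensorDescType):

```python
# Illustrative model of the TT_TensorDescType verifier rules described above.
# This is a sketch with hypothetical names, not the PR's C++ implementation.

def verify_tensor_desc(mode, block_shape, element_strides=None,
                       pixel_box_lower_corner=None, pixel_box_upper_corner=None,
                       channels_per_pixel=None, pixels_per_column=None):
    """Return None if the descriptor type is valid, else an error string."""
    if mode not in ("tiled", "im2col"):
        return f"mode must be 'tiled' or 'im2col', got {mode!r}"
    im2col_params = (element_strides, pixel_box_lower_corner,
                     pixel_box_upper_corner, channels_per_pixel,
                     pixels_per_column)
    # In tiled mode, im2col-specific parameters must not be set.
    if mode == "tiled" and any(p is not None for p in im2col_params):
        return "im2col-specific parameters must not be set in tiled mode"
    if mode == "im2col":
        # blockType must be rank-2, matching pixelsPerColumn x channelsPerPixel.
        if len(block_shape) != 2:
            return "blockType must be rank-2 in im2col mode"
        m, k = block_shape
        if pixels_per_column is not None and m != pixels_per_column:
            return "blockType dim 0 must match pixelsPerColumn"
        if channels_per_pixel is not None and k != channels_per_pixel:
            return "blockType dim 1 must match channelsPerPixel"
    return None
```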

Backward compatibility: existing code using TensorDescType(blockType) or TensorDescType(blockType, isSigned) continues to work unchanged, defaulting to tiled mode with no im2col parameters.

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these
    rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

lezcano and others added 15 commits January 12, 2026 22:38
We implement a LinearLayout-based `ReduceOp` lowering. This has a
number of benefits:

- The logic is noticeably simpler as we barely have to implement
anything. ConvertLayout and some LL helpers do all the heavy lifting
- We get shmem swizzling for free
- We sometimes save a shmem round-trip (before we did it
unconditionally)
- It is now clear that we have a `tmpLl` variable we can carefully
choose (we'll do so in a future PR)
- It opens the door to returning an arbitrary layout (fusing a
`convert_layout` into this op)
- It is now really simple to generalise this op to perform cross-cluster
reductions, provided that `convert_layout` supports them.
- We fix some latent issues the previous implementation had when run on
arbitrary linear layouts. We add a funky regression test that used to
fail and now passes.
- All this while being LOC-neutral!

In future PRs we will improve the choice of `tmpLl` to avoid the last
`convert_layout` in many cases, and we will pack the inputs in shmem to
be able to vectorize the loads/stores for full reductions with multiple
inputs.

This PR was the result of quite a long (but rather successful)
vibe-coding session together with `gpt-5.2-codex`. I found it particularly
useful to be able to emit a ConvertLayout within this lowering rather
than having to call that lowering manually. This simplifies the code
quite a bit, and I would have struggled to convince MLIR to do so myself.

TODO: Benchmark
Refactor the default behavior of fresh_knobs in test fixtures.

- `fresh_knobs` (default): now preserves library paths (build, nvidia, amd
knobs), since most tests need CUDA toolkit paths to compile successfully.
It also respects environment variables like `TRITON_PTXAS_BLACKWELL_PATH`.
- `fresh_knobs_including_libraries` (new): resets ALL knobs, including
library paths.
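A minimal sketch of the split described above (the knob names besides `TRITON_PTXAS_BLACKWELL_PATH` are assumptions for illustration; the real fixtures live in Triton's test conftest and operate on triton knobs rather than a plain dict):

```python
# Illustrative model of the two fixtures' behavior, operating on a plain
# dict instead of real process state. The library-knob set is an assumption,
# except TRITON_PTXAS_BLACKWELL_PATH, which is mentioned above.
LIBRARY_KNOBS = {"TRITON_PTXAS_BLACKWELL_PATH"}

def reset_knobs(env, include_libraries=False):
    """Return a copy of `env` with Triton knob variables removed.

    include_libraries=False models `fresh_knobs` (library paths preserved,
    since most tests need CUDA toolkit paths to compile successfully);
    include_libraries=True models `fresh_knobs_including_libraries`.
    """
    def keep(key):
        if not key.startswith("TRITON_"):
            return True  # non-knob variables are always kept
        return not include_libraries and key in LIBRARY_KNOBS
    return {k: v for k, v in env.items() if keep(k)}
```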
…ng#9074)

By leveraging wrap-around due to padding, we can still get a
bank-conflict-free padded shared layout when the block size is smaller than 16KB.
Take Mx64xbf16, k-contiguous, kWidth=8, mfma16x16 as an example (rX
stands for row X): the minimal block size can be 32x64.
Padding here is set to 16 elements (32 bytes) to avoid bank conflicts.
We can pack r0, r4, r8, r12, r16, r20, r24, r28 to compose a contiguous tile:
```
r0[0+], r0[8+],
                r1[0+], r1[8+],
                                r2[0+], r2[8+],
                                                r3[0+], r3[8+],
r4[0+], r4[8+],
                r5[0+], r5[8+],
                                r6[0+], r6[8+],
                                                r7[0+], r7[8+],
r8[0+], r8[8+],
```
In LDS, the rows are arranged as below:
```
r0,  r4,  r8,  r12, r16, r20, r24, r28
pad, r1,  r5,  r9,  r13, r17, r21, r25
r29, pad, r2,  r6,  r10, r14, r18, r22,
r26, r30, pad, r3,  r7,  r11, r15, r19,
r23, r27, r31
```
…ton-lang#9208)

Clone turned out not to be needed for now, reverting to limit
complexity.
…n-lang#9204)

This optimization moves TransOps closer to their defs, but that doesn't
actually have any impact on the generated code, because the actual
transpose code is generated by ConvertLayout ops.
…lang#9210)

The test_gather[src_shape2-indices_shape2-0] test fails on RDNA3 and
RDNA4 GPUs with:
triton.runtime.errors.OutOfResources: out of resource: shared memory,
Required: 131072, Hardware limit: 65536.

Extend the existing skip condition (which already covers CDNA2 and
CDNA3) to also include RDNA3 and RDNA4 GPUs.
Before, we mistakenly allowed repeated non-zero bases.
Most autoWS passes automatically skip code that does not use autoWS, but
InsertTmemAref does not end up doing that because it needs to examine
`TMEMAlloc`s that occur outside WS regions as well. This can cause
assertion failures when the assumptions baked into the autoWS
implementation are violated (e.g. that an alloc must only have a single
use if there is no token).
…9222)

This PR replaces amdgpu intrinsic calls with ROCDL ops, which were
recently exposed in the ROCDL dialect in MLIR.
@bingyizh233 bingyizh233 changed the title Add im2col information to the tensor descriptor [Nvidia] Enable TMA im2col mode -- Tensor descriptor Jan 14, 2026
qnie-oai and others added 14 commits January 15, 2026 00:40
…n-lang#9202)

### Summary 
This is the first PR in a series that enables TMA im2col mode (in
addition to the existing tiled mode) for NVIDIA GPUs. The overall goal
of the series is to support TMA im2col mode in the Triton compiler and
the Gluon DSL.

The PTX ISA documentation for TMA im2col mode is available here:
https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-mode

TMA im2col mode is supported on Hopper and Blackwell; Ampere does not
support it.

## Summary of changes

The main change in this PR is adding support for TMA im2col mode to
`TTNG_AsyncTMACopyGlobalToLocalOp`, along with a corresponding `lit`
test.

Ultimately, `TTNG_AsyncTMACopyGlobalToLocalOp` will lower to
`cp.async.bulk.tensor` in PTX (see
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async-bulk-tensor
for some examples). That lowering will be implemented in follow-up PRs.

## Tensor shapes and arguments

For TMA im2col mode, the input tensor in global memory is 3D, 4D, or 5D:

| Rank | Layout |
|---:|---|
| 3D | NWC |
| 4D | NHWC |
| 5D | NDHWC |

Here, `N` is the batch size, `(D, H, W)` are the (depth, height, width)
spatial dimensions, and `C` is the channel dimension.

The `coord` argument (I32) in `TTNG_AsyncTMACopyGlobalToLocalOp` has the
same rank as the input tensor (3, 4, or 5).

This PR also adds an `offset` argument (I16), which represents the
im2col offset in the spatial dimensions:

- For a 3D input (NWC), `offset` is 1D (W).
- For a 4D input (NHWC), `offset` is 2D (H, W).
- For a 5D input (NDHWC), `offset` is 3D (D, H, W).

In general, $$\text{rank(offset)} = \text{rank(input)} - 2$$.

Note that `offset` is only required for im2col mode; it is not used for
the regular tiled mode.
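The rank relationship above can be spelled out concretely (a trivial illustrative sketch):

```python
# Illustrative: offset rank per supported global-memory layout,
# following rank(offset) = rank(input) - 2.
LAYOUTS = {3: "NWC", 4: "NHWC", 5: "NDHWC"}

def offset_rank(input_rank):
    assert input_rank in LAYOUTS, "im2col input must be 3D, 4D, or 5D"
    return input_rank - 2
```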

## Semantics

In TMA im2col mode, `TTNG_AsyncTMACopyGlobalToLocalOp` loads data from
global memory to shared memory in a convolution-friendly layout suitable
for explicit-GEMM algorithms. There is no corresponding im2col mode for
local-to-global copies.
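To make "convolution-friendly layout" concrete, here is a host-side NumPy sketch of the classic im2col transform that explicit-GEMM algorithms rely on (stride 1, no padding, for simplicity; this illustrates the semantics only and is not the op's implementation):

```python
# Illustrative host-side model of im2col for an NHWC input: the unfolded
# matrix lets a convolution be computed as a single GEMM (explicit GEMM).
import numpy as np

def im2col_nhwc(x, kh, kw):
    """Unfold an NHWC tensor into an (N*OH*OW, KH*KW*C) matrix."""
    n, h, w, c = x.shape
    oh, ow = h - kh + 1, w - kw + 1  # stride 1, no padding
    cols = np.empty((n * oh * ow, kh * kw * c), dtype=x.dtype)
    row = 0
    for b in range(n):
        for i in range(oh):
            for j in range(ow):
                # Each output pixel contributes one "column" of the GEMM,
                # flattened in (KH, KW, C) order.
                cols[row] = x[b, i:i+kh, j:j+kw, :].reshape(-1)
                row += 1
    return cols

# Explicit GEMM: conv(x, w) == im2col(x) @ w.reshape(KH*KW*C, F)
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 5, 3)).astype(np.float32)
w = rng.standard_normal((2, 2, 3, 4)).astype(np.float32)  # KH, KW, C, F
out = im2col_nhwc(x, 2, 2) @ w.reshape(-1, 4)             # (N*OH*OW, F)
```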



# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a
comment.

- [x] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.

- Select one of the following.
  - [ ] I have not added any `lit` tests.
  - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
…riton-lang#9228)

For gfx1250, descriptor loads are no longer rewritten as regular loads.
Without this change, `CanonicalizePointers` will fail because it cannot
transform `MakeTensorDescOp` to use the rematerialized `FatPtr`.
…-lang#9206)

Cluster barriers are pure execution barriers and wait until all CTAs in a
cluster have signaled the barrier. Note that this only means that one
warp per CTA has signaled the barrier, so we need an additional CTA-wide
sync to ensure all warps are synced as well.
For real-world kernels we get the CTA-wide sync from an `async_wait` or
`Membar` if we pipeline through registers. For Gluon kernels this could
look like:
```
    tdm.async_wait(0) # or ttgl.barrier()
    cluster.signal()
    cluster.wait()
```
The cluster barrier does not implicitly include the CTA sync to avoid
having unnecessary CTA scope barriers, as shown in the example above.

The main purpose of cluster execution synchronization is to keep
warps/CTAs within a cluster temporally aligned. This ensures that
multicast loads have a high probability (or even a guarantee) of
broadcasting their data to multiple CTAs in the cluster.
New partition analysis pass based on a data-flow graph and incremental,
heuristic-driven partition merging.

The aim of this is to provide a more general approach to partition
scheduling.

This PR is a drop-in replacement for the existing pass and provides
largely the same behavior. Tests have been updated where required.

I have verified on B200 that there are no perf regressions for
`09-persistent-matmul.py` and `06-fused-attention.py`

---------

Co-authored-by: Jeff Niu <jeffniu22@gmail.com>
…criptor creation (triton-lang#9235)

This PR adds a workaround for a driver TMA descriptor related bug which
will cause occasional errors on Blackwell when the tensor's backing
memory allocation is less than 128KB and it is not a dense
non-overlapping tensor.

We follow the CUTLASS change for the same issue:
NVIDIA/cutlass@b7ecaa6#diff-1dfcaf77b33258ff3175540718d9caff1cd471215f741ba42943ef00770e6d04

Unfortunately, there is no test for this, since it is difficult to come
up with a Triton program that reliably reproduces the error.
…g#9234)

Currently the serialized data contain target specific options so
preloading on a different target can cause dramatic failures.
nzaghen and others added 6 commits January 16, 2026 09:25
This barrier allows synchronising progress across the CTA. It also
allows specifying memory-visibility guarantees after the barrier has
completed.

A ttg.barrier can be on a specific set of address spaces.

This also replaces all uses of ttg.local_barrier
and mlir::gpu::barrier with the new ttg.barrier.
Add tcgen05.ld.red to Gluon; it uses a B300 feature that loads from TMEM
and computes stats.

https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-ld

This shows a perf improvement on Gluon flash attention on B300, except
for causal fp16 attention:

```
$ python python/examples/gluon/01-attention-forward.py
Attention Z=4 H=32 D=64 causal=False has_tmem_red=False:
     N_CTX  triton-fp16 (TFLOPS)  triton-fp8 (TFLOPS)
0   1024.0            252.977768           277.967979
1   2048.0            816.426391           838.455303
2   4096.0            896.514441           911.042858
3   8192.0            918.765777           936.400368
4  16384.0            922.699215           944.950928
5  32768.0            918.744499           957.229512
6  65536.0            904.918074           958.457140
Attention Z=4 H=32 D=64 causal=False has_tmem_red=True:
     N_CTX  triton-fp16 (TFLOPS)  triton-fp8 (TFLOPS)
0   1024.0            263.684175           273.544647
1   2048.0            832.115359           885.267418
2   4096.0            916.381550           965.582454
3   8192.0            941.261428           994.436253
4  16384.0            948.707711          1004.449777
5  32768.0            948.091834          1018.365675
6  65536.0            936.832723          1023.509932
Attention Z=4 H=32 D=64 causal=True has_tmem_red=False:
     N_CTX  triton-fp16 (TFLOPS)  triton-fp8 (TFLOPS)
0   1024.0            136.571821           139.948484
1   2048.0            535.434991           558.382838
2   4096.0            724.982807           755.171625
3   8192.0            796.134863           837.101509
4  16384.0            872.450133           915.462029
5  32768.0            889.374786           943.823419
6  65536.0            869.217661           959.193145
Attention Z=4 H=32 D=64 causal=True has_tmem_red=True:
     N_CTX  triton-fp16 (TFLOPS)  triton-fp8 (TFLOPS)
0   1024.0            132.868438           134.038787
1   2048.0            530.566081           549.091688
2   4096.0            701.461330           928.786050
3   8192.0            778.478193          1040.645422
4  16384.0            854.631747          1148.441103
5  32768.0            862.854052          1186.579258
6  65536.0            857.717297          1199.146270
Attention Z=4 H=32 D=128 causal=False has_tmem_red=False:
     N_CTX  triton-fp16 (TFLOPS)  triton-fp8 (TFLOPS)
0   1024.0            521.084272           552.213352
1   2048.0           1293.432732          1569.988492
2   4096.0           1338.064712          1613.715592
3   8192.0           1312.318013          1720.438548
4  16384.0           1309.450228          1786.343871
5  32768.0           1298.798675          1808.070473
6  65536.0           1345.804669          1798.150675
Attention Z=4 H=32 D=128 causal=False has_tmem_red=True:
     N_CTX  triton-fp16 (TFLOPS)  triton-fp8 (TFLOPS)
0   1024.0            531.395175           562.816992
1   2048.0           1322.340532          1678.849868
2   4096.0           1367.440832          1753.998270
3   8192.0           1323.674061          1841.206150
4  16384.0           1341.168380          1934.546652
5  32768.0           1330.483606          1998.499839
6  65536.0           1354.725476          1964.254907
Attention Z=4 H=32 D=128 causal=True has_tmem_red=False:
     N_CTX  triton-fp16 (TFLOPS)  triton-fp8 (TFLOPS)
0   1024.0            263.377373           268.205711
1   2048.0            958.217023          1113.171399
2   4096.0           1139.760977          1369.767809
3   8192.0           1199.838614          1560.139154
4  16384.0           1212.925922          1750.154897
5  32768.0           1249.760329          1870.659789
6  65536.0           1156.212302          1820.937396
Attention Z=4 H=32 D=128 causal=True has_tmem_red=True:
     N_CTX  triton-fp16 (TFLOPS)  triton-fp8 (TFLOPS)
0   1024.0            259.034867           268.180008
1   2048.0            931.347834          1089.042686
2   4096.0           1085.310044          1625.788256
3   8192.0           1123.443093          1741.503819
4  16384.0           1174.031312          1939.977100
5  32768.0           1176.217897          2012.142667
6  65536.0           1121.913976          1956.926863
```

---------

Co-authored-by: evghenii <egaburov@nvidia>