[TRTLLM-12580][perf] ltx2: fused RMSNorm+RoPE across all attention paths + PE pre-shard#13985
/bot run --disable-fail-fast
Walkthrough
This PR adds fused CUDA kernels and PyTorch operators for efficient full-dimension RMSNorm and RoPE computation in LTX-2 attention, integrating them into the transformer model with updated positional embedding layouts, caching, and cross-attention flow to enable split-kernel execution on Q/K tensors.
Changes
Fused Kernels for LTX-2 Attention Optimization
Estimated code review effort: 4 (Complex) | ~75 minutes
Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 6
🧹 Nitpick comments (2)
cpp/tensorrt_llm/kernels/fusedDiTSplitQKNormRopeKernel.h (1)
17-18: ⚡ Quick win
Use `#pragma once` in this new header.
The repo standard for `*.h`/`*.hpp`/`*.cuh` is `#pragma once`, so this new file should follow that instead of introducing another guard pattern.
As per coding guidelines, "`**/*.{h,hpp,cuh}`: Use `#pragma once` for C++ header include guards."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/kernels/fusedDiTSplitQKNormRopeKernel.h` around lines 17 - 18, replace the traditional include guard in fusedDiTSplitQKNormRopeKernel.h with the repo-standard `#pragma once`: remove the `#ifndef`/`#define`/`#endif` block that defines TRTLLM_FUSEDDITSPLITQKNORMROPEKERNEL_H and add a single `#pragma once` at the top of the file so the header follows the project's include-guard convention.
cpp/tensorrt_llm/kernels/fusedDiTSplitNormKernel.h (1)
17-18: ⚡ Quick win
Use `#pragma once` in this new header.
The repo's header rule is `#pragma once`, so the guard here will be a style violation in a brand-new `.h` file.
As per coding guidelines, `**/*.{h,hpp,cuh}`: Use `#pragma once` for C++ header include guards.
Suggested change
-#ifndef TRTLLM_FUSEDDITSPLITNORMKERNEL_H
-#define TRTLLM_FUSEDDITSPLITNORMKERNEL_H
+#pragma once
...
-#endif // TRTLLM_FUSEDDITSPLITNORMKERNEL_H
Also applies to: 56-56
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@cpp/tensorrt_llm/kernels/fusedDiTSplitNormKernel.h` around lines 17 - 18, replace the traditional include-guard macro in this new header with the repo-standard pragma: remove the `#ifndef`/`#define` (TRTLLM_FUSEDDITSPLITNORMKERNEL_H) and the matching `#endif`, and add a single line `#pragma once` at the top of fusedDiTSplitNormKernel.h; ensure any other identical guard occurrence in this file (the second occurrence noted) is removed so the file uses only `#pragma once` as the include guard.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@cpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.cu`:
- Around line 265-278: The shared buffer warp_sums is written by laneId==0
inside reduce_partial (after calling tensorrt_llm::common::warpReduceSum) but
then immediately read by all threads; add a block barrier to ensure all warps
finish writing before any warp begins summing the array. Specifically, in the
lambda reduce_partial (used for Q and K reductions) insert a __syncthreads()
after the laneId==0 assignment to warp_sums[warpId] (or equivalently before the
loop that reads warp_sums) so the second reduction cannot start overwriting
warp_sums while another warp is still summing; keep references to laneId,
warpId, warp_sums and warpReduceSum to locate the change.
In `@cpp/tensorrt_llm/kernels/fusedDiTSplitQKNormRopeKernel.cu`:
- Around line 189-202: The shuffle uses a full-32-bit mask (0xffffffff) which is
undefined when some lanes in the warp have early-exited; capture the actual
participating lanes with unsigned int active = __activemask() and pass that
active mask into __shfl_xor_sync instead of 0xffffffff. Compute xor_mask and
negate as you already do (xor_mask, negate) but ensure active is read once
before the loop and use __shfl_xor_sync(active, elements[i], xor_mask) for all
CHUNK_ELEMS shuffles so only active lanes participate consistently (refer to
xor_mask, negate, elements, CHUNK_ELEMS, and __shfl_xor_sync).
In `@cpp/tensorrt_llm/thop/fusedDiTQKNormRopeOp.cpp`:
- Around line 52-71: The launch currently uses qkv.get_device() but does not
validate that q_weight, k_weight, any optional add-weights, cos_emb and sin_emb
live on the same CUDA device; add explicit device checks after you compute auto
dev = qkv.get_device() (or wherever the kernel target is chosen) and assert with
TORCH_CHECK(each_tensor.get_device() == dev, "tensor X must be on same device as
qkv") for q_weight, k_weight, cos_emb, sin_emb and any optional weight/add
tensors used by the kernel, and replicate the same device-validation in the
alternate branch (the block around lines 103-148) so mixed-device inputs fail
fast with a clear error message.
In `@cpp/tensorrt_llm/thop/fusedDiTSplitNormOp.cpp`:
- Around line 39-51: The code currently validates tensor and weight types but
not that weight lives on the same CUDA device as tensor before calling
tensorrt_llm::kernels::launchFusedDiTSplitNormFullDim; add a device check (e.g.
using TORCH_CHECK or extending CHECK_INPUT) to assert weight.device().is_cuda()
and weight.get_device() == tensor.get_device() (or equivalent) prior to
obtaining stream and launching launchFusedDiTSplitNormFullDim so you never pass
a pointer from a different GPU.
In `@cpp/tensorrt_llm/thop/fusedDiTSplitQKNormRopeOp.cpp`:
- Around line 40-85: The code currently only uses tensor.get_device() for the
kernel launch but doesn't assert that weight, cos_emb, and sin_emb live on the
same device; add device checks and use the tensor device for the CUDA stream.
Specifically, in the function that contains these checks (the block around
CHECK_INPUT(...) and the launchFusedDiTSplitNormFullDimRope call), add
TORCH_CHECK statements asserting weight.device() == tensor.device(),
cos_emb.device() == tensor.device(), and
sin_emb.device() == tensor.device() with clear error messages, and
replace/get the CUDA stream via
at::cuda::getCurrentCUDAStream(tensor.get_device()) (already used) ensuring the
device matches; this will prevent dereferencing foreign-device pointers at
kernel launch.
In `@tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/rope.py`:
- Around line 55-67: Revert the TensorRT-specific token-major reshape/slicing
and block-duplicated cos/sin handling from the ltx2_core RoPE implementation:
remove the code paths that reshape input_tensor to (b, t, h, -1), set
needs_reshape, and the branch that slices cos_freqs/sin_freqs based on equality
with input_tensor.shape[-1] (including the use of half and the comment about
_split_freqs_cis); restore the upstream, generic RoPE behavior so ltx2_core uses
the original cos_freqs/sin_freqs layout. Implement the token-major reshape and
block-duplicated-layout translation instead in the TensorRT-owned caller/helper
layer that invokes this module (handling cos_freqs, sin_freqs, input_tensor
transformations there). This also applies to the analogous changes around the
other block (lines referenced 175-186) that perform the same TensorRT-specific
translations.
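The warp_sums race flagged in the first inline comment can be illustrated with a sequential simulation. This is only a sketch: the function below is a hypothetical stand-in for the kernel's reduce_partial lambda, Python runs the phases in order so no race can actually occur, and the comments mark where a real CUDA block would need `__syncthreads()`.

```python
WARP_SIZE = 32

def reduce_partial(values, warp_sums):
    """Simulate one block-wide sum; warp_sums stands in for the shared buffer."""
    num_warps = len(warp_sums)
    assert len(values) == num_warps * WARP_SIZE
    # Phase 1: each warp reduces its 32 lanes (warpReduceSum on the GPU),
    # then lane 0 stores the warp total into the shared buffer.
    for warp_id in range(num_warps):
        lanes = values[warp_id * WARP_SIZE:(warp_id + 1) * WARP_SIZE]
        warp_sums[warp_id] = sum(lanes)
    # Barrier A: on a GPU a __syncthreads() is needed here so every slot of
    # warp_sums is written before any thread starts reading it.
    total = sum(warp_sums)
    # Barrier B: a trailing __syncthreads() is needed here so the next call
    # cannot overwrite warp_sums while a slower warp is still summing it.
    return total

shared = [0] * 8                         # one slot per warp, 256 threads per block
q_sq = [i * i for i in range(256)]       # stand-in for per-thread Q partial sums
k_sq = [2] * 256                         # stand-in for per-thread K partial sums
q_total = reduce_partial(q_sq, shared)   # Q reduction
k_total = reduce_partial(k_sq, shared)   # K reduction reuses the same buffer
```

Because the Q and K reductions share one buffer, both barriers matter: the first protects the read inside a call, the second protects the buffer between calls.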
📒 Files selected for processing (19)
- cpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.cu
- cpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.h
- cpp/tensorrt_llm/kernels/fusedDiTSplitNormKernel.cu
- cpp/tensorrt_llm/kernels/fusedDiTSplitNormKernel.h
- cpp/tensorrt_llm/kernels/fusedDiTSplitQKNormRopeKernel.cu
- cpp/tensorrt_llm/kernels/fusedDiTSplitQKNormRopeKernel.h
- cpp/tensorrt_llm/thop/CMakeLists.txt
- cpp/tensorrt_llm/thop/fusedDiTQKNormRopeOp.cpp
- cpp/tensorrt_llm/thop/fusedDiTSplitNormOp.cpp
- cpp/tensorrt_llm/thop/fusedDiTSplitQKNormRopeOp.cpp
- tensorrt_llm/_torch/compilation/utils.py
- tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/rope.py
- tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/transformer_args.py
- tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py
- tensorrt_llm/_torch/visual_gen/models/ltx2/text_cache.py
- tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py
- tensorrt_llm/_torch/visual_gen/modules/attention.py
- tests/unittest/_torch/thop/parallel/test_fused_dit_split_norm.py
- tests/unittest/_torch/thop/parallel/test_fused_dit_split_qk_norm_rope.py
a4c12f5 to a595314
/bot run --disable-fail-fast
PR_Github #47717 [ run ] triggered by Bot. Commit:
PR_Github #47717 [ run ] completed with state
…x kernel UB
CR-flagged fixes:
- fusedDiTQKNormRopeKernel.cu: add a trailing __syncthreads() in reduce_partial, preventing a race when warp_sums[] is reused for Q->K reductions (CR PR NVIDIA#13985 thread 1).
- fusedDiTQKNormRopeKernel.cu + fusedDiTSplitQKNormRopeKernel.cu: use __activemask() instead of 0xffffffff for the rotate-half __shfl_xor_sync, which avoids UB for small num_heads*HEAD_DIM where the surrounding chunk loop has a partial-warp early exit (CR thread 2).
PE cache plumbing simplification (data flow):
- Drop the 4 *_pe_2d duplicate fields in TextCache; the single *_pe field now holds the form the consumer expects (2D [T_local, H*D] contiguous when fuse_qk_norm_rope=True, 4D [B, T_local, H, D] otherwise).
- Revert ltx2_core/transformer_args.py to upstream (drop the two _2d fields + two _2d kwargs that C8 had added to the upstream-mirrored file).
- LTX2Attention now explicitly sets fuse_qk_norm_rope=True (the base class default for qk_norm_mode="full" was False, but the LTX-2 forward path ignored the flag); forward() now actually gates on it.
- _shard_transformer_args drops the per-step _shard_pe; PE is sharded once in prepare_text_cache via _make_pe_local (renamed from _make_pe_2d_local; now produces 2D or 4D based on the fuse flag).
- BasicAVTransformerBlock's 6 'pe=*._2d or *._4d' fallback expressions collapse to a single 'pe=*._pe' reference.
- _forward_unfused gains a pe.ndim assert so the naive eager path fails loudly if anyone passes the fused 2D form.
- pipeline_ltx2 cuda-graph clone/copy halved (10 -> 6 calls per TextCache).
Test reorg:
- Move test_fused_dit_split_qk_norm_rope.py + test_fused_dit_split_norm.py from parallel/ to parallel_hw_agnostic/. Extend the packed test file with full-dim cells covering LTX-2 self-attn shapes (T=12288 H=32 D=128 + T=504 H=32 D=64, including the broadcast-over-B path).
Verification:
- 159 unit tests pass (packed + split + norm-only across fp32/bf16 cos).
- 1-GPU 40-step LTX-2 e2e (gs=3.0): raw video sha256 bit-identical to the pre-cleanup HEAD (99cc34517b19e3e12fb66ccc439b4c5f7b2575cf862e627fb504e1fdcc120755).
Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
8cc054e to fe30827
/bot run --disable-fail-fast
PR_Github #49082 [ run ] triggered by Bot. Commit:
PR_Github #49082 [ run ] completed with state
/bot run
/bot run --disable-fail-fast
PR_Github #49367 [ run ] triggered by Bot. Commit:
PR_Github #49368 [ run ] triggered by Bot. Commit:
PR_Github #49367 [ run ] completed with state
PR_Github #49368 [ run ] completed with state
PR_Github #49850 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #49919 [ run ] triggered by Bot. Commit:
PR_Github #49889 [ run ] completed with state
PR_Github #49919 [ run ] completed with state
Wires the LTX-2 transformer to fused RMSNorm+RoPE kernels at every attention call site (self / text-cross / AV-cross). Unblocks LTX-2 production (rope_type='split') previously gated out. SPLIT cos/sin block-duplicated to head_dim in _split_freqs_cis so kernel INTERLEAVE=false branch can consume it directly. AV cross-attn rope moved into project_kv (RoPE commutes with seq-dim concat) — saves cos/sin all-gather and U× K-rope compute under Ulysses. Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
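The claim that RoPE commutes with a seq-dim concat follows from RoPE being applied token by token. The sketch below checks it on plain lists; the split-half rotate_half convention, the frequency base of 10000, and all function names are assumptions made for illustration, not taken from the PR's code.

```python
import math

def rotate_half(x):
    """[x1, x2] -> [-x2, x1] (split-half convention; an assumption here)."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]

def make_cos_sin(positions, head_dim, base=10000.0):
    """Per-token cos/sin, block-duplicated to head_dim."""
    cos, sin = [], []
    for pos in positions:
        angles = [pos * base ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]
        cos.append([math.cos(a) for a in angles] * 2)
        sin.append([math.sin(a) for a in angles] * 2)
    return cos, sin

def apply_rope(tokens, cos, sin):
    """x * cos + rotate_half(x) * sin, independently for every token."""
    out = []
    for x, c, s in zip(tokens, cos, sin):
        r = rotate_half(x)
        out.append([xi * ci + ri * si for xi, ci, ri, si in zip(x, c, r, s)])
    return out

head_dim = 4
k_a = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 2.0, 0.0]]   # two tokens from one stream
k_b = [[-2.0, 1.0, 0.0, 3.0]]                          # one token from another stream
cos_a, sin_a = make_cos_sin([0, 1], head_dim)
cos_b, sin_b = make_cos_sin([7], head_dim)

rope_then_concat = apply_rope(k_a, cos_a, sin_a) + apply_rope(k_b, cos_b, sin_b)
concat_then_rope = apply_rope(k_a + k_b, cos_a + cos_b, sin_a + sin_b)
```

Since each output token depends only on its own cos/sin row, rotating each chunk before concatenation gives the same result as rotating the concatenated sequence, which is what allows the rotation to move into project_kv.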
…atch
Part A of plan_v3_2: fix the text-cross-attn Q-norm fallback that ran
flashinfer norm 96x/step (~50 ms / 10 step on LTX-2 video) and consolidate
LTX2Attention.forward dispatch around qkv_mode.
Changes:
- New kernel `fused_dit_split_norm` (mirror of fused_dit_split_norm_rope
without the RoPE step). Block=256 chunked reduce; head_dim in {64, 128};
num_heads <= 32. Code is ~50 percent of split norm+rope kernel (no cos/sin
LDG, no rotation). New thop binding + inplace_map entry.
- LTX2Attention helpers: new _apply_split_norm and dispatcher
_apply_split_norm_or_norm_rope (norm-only when pe is None, else norm+rope).
- LTX2Attention.forward refactored: dispatch by self.qkv_mode (FUSE_QKV ->
packed kernel; SEPARATE_QKV -> split fuse on Q with optional Q+K when
uncached). The three can_fuse_* booleans + cross-attn-uncached dead branch
are removed; fallback _forward_unfused only handles head_dim outside
{64, 128}.
- LTX2Attention.project_kv: K-norm path always goes through the split-fuse
dispatcher; old _pe_compatible() probe gone.
- Tests: new test_fused_dit_split_norm.py covering parametrized small shapes
+ LTX-2 video/audio shapes + non-contiguous reject.
Verified: 64 unit tests pass (21 new norm-only + 43 split norm+rope
regression). End-to-end + nsys validation in follow-up commits.
Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
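As a reading aid for the dispatcher described above (norm-only when pe is None, else norm+rope), here is a minimal pure-Python reference for a single row. The function names, the eps value, and the rotate-half convention are assumptions; the real kernels operate in place on GPU tensors.

```python
import math

def rms_norm_full_dim(row, weight, eps=1e-6):
    """RMSNorm over the whole row (num_heads * head_dim), not per head."""
    mean_sq = sum(v * v for v in row) / len(row)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(row, weight)]

def rope_per_head(row, cos, sin, head_dim):
    """Rotate each head slice of the row with its matching cos/sin slice."""
    out = []
    half = head_dim // 2
    for start in range(0, len(row), head_dim):
        x = row[start:start + head_dim]
        rot = [-v for v in x[half:]] + x[:half]
        c = cos[start:start + head_dim]
        s = sin[start:start + head_dim]
        out.extend(xi * ci + ri * si for xi, ci, ri, si in zip(x, c, rot, s))
    return out

def split_norm_or_norm_rope(row, weight, head_dim, pe=None):
    """Norm-only when pe is None, otherwise norm followed by RoPE."""
    normed = rms_norm_full_dim(row, weight)
    if pe is None:
        return normed
    cos, sin = pe
    return rope_per_head(normed, cos, sin, head_dim)

row = [3.0, 4.0, 0.0, 0.0]            # 2 heads x head_dim 2
weight = [1.0] * 4
norm_only = split_norm_or_norm_rope(row, weight, head_dim=2)
quarter_turn = ([0.0] * 4, [1.0] * 4)  # cos=0, sin=1 rotates each pair by 90 degrees
norm_rope = split_norm_or_norm_rope(row, weight, head_dim=2, pe=quarter_turn)
```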
Remove unused HEAD_DIM cases, dead INTERLEAVE / PER_HEAD_COS template branches, and obsolete test parametrizations from the fused DiT split QK-norm-rope kernels left over after the qkv_mode dispatch refactor. No functional change. Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
Drop the trailing swapaxes(1, 2) inside _split_freqs_cis so the per-head
SPLIT cos/sin are produced directly in token-major [B, T, H, D] layout.
This is exactly what the fused norm+RoPE kernel reads after
reshape(-1, num_heads * head_dim), so the helpers no longer need to do a
permute(0, 2, 1, 3).contiguous() on every call before the kernel launch.
Updated consumers:
- rope._apply_split_rotary_emb: reshape input to match cos directly,
no swapaxes on input or output
- LTX2Attention._apply_split_norm_rope: drop the cos/sin permute branch
- Attention.apply_qk_norm_rope: drop the cos/sin permute branch
- LTXModel._shard_pe (Ulysses): slice the seq dim at index 1 instead
of index 2 for 4D SPLIT-rope cos/sin
The kernel itself is unchanged. The only behavioral difference is the
elimination of the per-call permute().contiguous() in the helpers, which
removes a recurring allocation + memory-shuffle prologue on the
self-attn FUSE_QKV path.
Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
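Why the token-major layout removes the permute can be checked with flat-offset arithmetic. The sketch below compares element offsets of a contiguous [B, T, H, D] tensor and a contiguous [B, H, T, D] tensor against the 2D [B*T, H*D] view the kernel reads; the helper names are illustrative only.

```python
from itertools import product

def offset_token_major(b, t, h, d, T, H, D):
    """Flat offset in a contiguous [B, T, H, D] tensor."""
    return ((b * T + t) * H + h) * D + d

def offset_head_major(b, t, h, d, T, H, D):
    """Flat offset in a contiguous [B, H, T, D] tensor."""
    return ((b * H + h) * T + t) * D + d

def offset_2d(row, col, H, D):
    """Flat offset in a contiguous [B*T, H*D] tensor."""
    return row * (H * D) + col

B, T, H, D = 2, 3, 4, 8
indices = list(product(range(B), range(T), range(H), range(D)))

# Token-major storage already matches the 2D layout element for element.
token_major_is_view = all(
    offset_token_major(b, t, h, d, T, H, D) == offset_2d(b * T + t, h * D + d, H, D)
    for b, t, h, d in indices
)
# Head-major storage does not, so it would need permute + contiguous first.
head_major_is_view = all(
    offset_head_major(b, t, h, d, T, H, D) == offset_2d(b * T + t, h * D + d, H, D)
    for b, t, h, d in indices
)
```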
…in helper
Add a bf16 branch to the fused full-dim RMSNorm+RoPE kernels so the
helper no longer has to upcast cos/sin from bf16 to fp32 before the
launch. Two kernels gain a `typename CosT` template parameter:
- fusedDiTSplitQKNormRopeKernel (LTX-2 cross-attn / SEPARATE_QKV)
- fusedDiTQKNormFullDimRopeKernel (LTX-2 self-attn / FUSE_QKV)
Both keep the existing fp32 cos path; bf16 cos is loaded as a single
LDG.128 (uint4) per chunk vs 2x LDG.128 (float4) for fp32, and upcast
in registers via __bfloat1622float2 before the math (which still runs
in fp32). Math is unchanged: bf16 -> fp32 is lossless, so kernel output
is bit-equivalent to the fp32-cos path.
Op bindings (`fusedDiTSplitQKNormRopeOp.cpp`, `fusedDiTQKNormRopeOp.cpp`)
relax the cos/sin dtype check from strict-fp32 to fp32-or-bf16 and
forward `cos_is_bf16` to the launcher. The per-head FLUX/Cosmos path
(`fusedDiTQKNormRopeKernel`, non-FullDim) keeps fp32-only since FLUX
cos is already fp32 upstream.
Helpers `LTX2Attention._apply_split_norm_rope` and
`Attention.apply_qk_norm_rope` drop `.float()` from the cos chain;
cos.reshape(..).contiguous() now passes bf16 through directly. For
LTX-2 this skips the per-call bf16->fp32 alloc + cast (~200 MB bf16 ->
~400 MB fp32 on the video shape). For FLUX this is a no-op since cos was
already fp32.
Unit tests parametrize on cos_dtype in {fp32, bf16}; 104/104 pass on
B200 in 1.28s, including LTX-2 production shapes
(T=12288 H=32 D=128, T=504 H=32 D=64) for both interleave modes.
E2E: 40-step LTX-2 single-stage produces a bit-identical video
(sha256) vs the prior B-1 commit. Single-op kernel time on B200
(B=2, 20-iter median) drops on LTX-2 cells:
ltx2_video_self_attn (T=12288 H=32 D=128) 0.3029 -> 0.2464 ms (-19%)
ltx2_audio_self_attn (T=504 H=32 D=64) 0.0095 -> 0.0086 ms (-10%)
ltx2_text_cross_video (T=12288 H=32 D=128) 0.1789 -> 0.1387 ms (-22%)
ltx2_av_a2v_q (T=12288 H=32 D=128) 0.1789 -> 0.1386 ms (-22%)
FLUX per-head cells unchanged within noise.
Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
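The statement that bf16 -> fp32 is lossless can be checked at the bit level: a bf16 value is the upper 16 bits of an fp32 pattern, so upcasting only appends zero bits. The sketch below uses hypothetical helper names and assumes a chunk of 8 elements, inferred from the "one LDG.128 for bf16 vs two for fp32" description rather than read from the kernel source.

```python
import math
import struct

def f32_bits(x):
    """Bit pattern of x after rounding to fp32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bits_to_f32(bits):
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def f32_to_bf16_bits(x):
    """Round to nearest even and keep the top 16 bits (finite inputs only)."""
    bits = f32_bits(x)
    bits += 0x7FFF + ((bits >> 16) & 1)
    return (bits >> 16) & 0xFFFF

def bf16_bits_to_f32(b16):
    """Upcast: place the 16 bf16 bits in the top half, zero the rest."""
    return bits_to_f32(b16 << 16)

def ldg128_per_chunk(chunk_elems, bytes_per_elem):
    """Number of 16-byte loads needed for one chunk."""
    return (chunk_elems * bytes_per_elem + 15) // 16

cos_values = [math.cos(0.1 * i) for i in range(64)]
upcast = [bf16_bits_to_f32(f32_to_bf16_bits(v)) for v in cos_values]

# Downcasting the upcast value returns the same bf16 bits: nothing was lost.
roundtrip_exact = all(
    f32_to_bf16_bits(u) == f32_to_bf16_bits(v) for u, v in zip(upcast, cos_values)
)
# The upcast fp32 pattern carries no information in its low 16 bits.
low_bits_zero = all(f32_bits(u) & 0xFFFF == 0 for u in upcast)
```

The precision loss happens when cos/sin are first stored as bf16, not when the kernel widens them, which is consistent with the bit-identical e2e result reported above.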
…elpers
Replace the numel-based 4-way if/elif chain that derived cos_last with a direct 2-way dispatch keyed off qk_norm_mode (Attention.apply_qk_norm_rope) or hard-coded num_heads * head_dim (LTX2Attention._apply_split_norm_rope, LTX-2-only path).
The 4-way derivation was an overgeneralization I introduced in the original fuse-all-paths commit (f62cb5c): it conflated "is cos per-head vs shared-across-heads" (which determines cos_last) with "does cos carry the batch dim" (which is handled separately by the .reshape + optional .repeat below). Only the first is needed to pick cos_last; the second is figured out post-reshape via cos_2d.shape[0] == S vs B*S.
Attention.apply_qk_norm_rope (shared FLUX/Cosmos/LTX-2):
cos_last = self.q_dim if self.qk_norm_mode == "full" else self.head_dim
LTX2Attention._apply_split_norm_rope (LTX-2 only):
cos_last = num_heads * self.head_dim
No functional change. 104/104 kernel unit tests pass; 40-step LTX-2 single-stage e2e sha256 unchanged (99cc34517b19e3e12fb66ccc439b4c5f7b2575cf862e627fb504e1fdcc120755).
Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
…t(B, 1)
Add cos_seq_per_batch template param to the LTX-2 full-dim norm+rope kernels (split + packed). When non-zero, the kernel computes cos_tokenIdx = tokenIdx % cos_seq_per_batch to broadcast cos/sin across B without a host-side tile and its L2 footprint. The op binding auto-detects broadcast mode: when cos.size(0) != num_tokens and cos.size(0) divides num_tokens evenly, it sets cos_seq_per_batch := cos.size(0).
Helpers in LTX2Attention._apply_split_norm_rope and Attention.apply_qk_norm_rope pass cos directly as [T, H*D] instead of repeating to [B*T, H*D]. The per-head FLUX kernel does not support broadcast (its indexing assumes flat token-major); the op binding rejects broadcast for that path with TORCH_CHECK, so the FLUX caller still tiles when B > 1.
Lossless: in-kernel modulo is exact. Verified by 104 unit tests across fp32/bf16 cos and all paths. Kernel bench (B=2, B200, packed video self-attn T=12288 H=32 D=128): -58 us/call (~9% kernel-time); neutral/noise on smaller shapes.
Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
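The broadcast rule can be sketched in a few lines. The helper names below are hypothetical, and the detection logic is a reading of the commit message (broadcast when the cos row count differs from num_tokens and divides it evenly), not a copy of the op binding.

```python
def detect_cos_seq_per_batch(num_tokens, cos_rows):
    """Return 0 for the flat layout, or the per-batch sequence length for broadcast."""
    if cos_rows == num_tokens:
        return 0
    if num_tokens % cos_rows != 0:
        raise ValueError("cos rows must divide num_tokens evenly")
    return cos_rows

def cos_row_for_token(token_idx, cos_seq_per_batch):
    """Row of cos/sin a token reads; the modulo wraps across batch entries."""
    if cos_seq_per_batch == 0:
        return token_idx
    return token_idx % cos_seq_per_batch

B, T = 2, 3
cos = [[0.1 * t, 0.2 * t] for t in range(T)]   # [T, H*D] with H*D = 2
tiled = cos * B                                # host-side tile to [B*T, H*D]

mode = detect_cos_seq_per_batch(num_tokens=B * T, cos_rows=len(cos))
broadcast_matches_tile = all(
    tiled[i] == cos[cos_row_for_token(i, mode)] for i in range(B * T)
)
```

Reading row `i % T` of the untiled table yields the same values as reading row `i` of the tiled copy, so the tile is unnecessary for kernels that support the mode.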
…step reshape+contiguous
Add 4 new optional fields to TextCache and 2 to TransformerArgs carrying the sharded-local 2D contiguous [T_local, H*D] form of self-attn and cross-attn PE -- the exact shape consumed by the fused norm+rope kernels. Computed once in LTXModel.prepare_text_cache() via _make_pe_2d_local(), which slices the 4D PE along the seq dim by Ulysses rank, then reshape + .contiguous().
The 4D PE flows through args.positional_embeddings as before for the un-fused fallback (apply_rotary_emb). _shard_transformer_args passes the 2D fields through unchanged (already sharded).
Hot helpers (apply_qk_norm_rope / _apply_split_norm_rope) auto-handle the 2D form: cos.reshape(-1, cos_last).contiguous() degenerates to view+no-op when input is already 2D contiguous. Call sites in BasicAVTransformerBlock prefer the 2D variant with 4D fallback (`or` operator), so paths without a 2D form (e.g. INTERLEAVED rope models) keep working. pipeline_ltx2 cuda-graph clone/copy paths cover the new fields.
Lossless: 2D form is reshape(B*T,-1) of the sharded 4D PE; reshape on a contiguous tensor is a free view, .contiguous() only materializes when the prior slice-view is strided (one alloc per generate() vs per step). Verified by 127 unit tests on the kernel ops (all paths, fp32/bf16 cos).
Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
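The one-time pre-shard can be pictured on nested lists. This is a sketch under stated assumptions: the function name is hypothetical, each rank is assumed to own one contiguous, equally sized chunk of the sequence, and nested lists stand in for the 4D tensor.

```python
def make_pe_2d_local(pe_4d, rank, world_size):
    """Slice a [B, T, H, D] table along T for one rank, then flatten to [B*T_local, H*D]."""
    seq_len = len(pe_4d[0])
    assert seq_len % world_size == 0, "sketch assumes an evenly divisible sequence"
    t_local = seq_len // world_size
    start = rank * t_local
    rows = []
    for batch in pe_4d:
        for token in batch[start:start + t_local]:
            # Concatenate the H head vectors of one token into a single row.
            rows.append([v for head in token for v in head])
    return rows

B, T, H, D = 1, 8, 2, 4
# Each entry records its own (t, h, d) coordinate so the slice is easy to inspect.
pe = [[[[(t, h, d) for d in range(D)] for h in range(H)] for t in range(T)]
      for _ in range(B)]
local = make_pe_2d_local(pe, rank=1, world_size=4)
```

With 8 tokens and 4 ranks, rank 1 ends up with tokens 2 and 3 as two rows of H*D = 8 entries. Doing this once per generate() call rather than once per step is the saving the commit describes.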
…+ SMEM staging
Restructure all three fused kernels (fusedDiTSplitNormKernel, fusedDiTSplitQKNormRopeKernel, fusedDiTQKNormFullDimRopeKernel) with a unified 2-row-per-CTA layout (256 threads = 2 rows x 128 threads x 4 warps) and a cp.async HBM->SMEM staging pipeline for X / cos / sin, overlapped with synchronous register-cache loads for the per-head norm weight. Phase 1 (sum^2) reads X from SMEM with no HBM re-read; phase 2 reads X+cos+sin from SMEM and multiplies by the register-cached weight to write the output in-place.
Deletes the V1 single-row variants and renames the V2 implementations back to the original names so the LTX-2 caller and the op binding are unchanged. Adds cudaFuncSetAttribute(MaxDynamicSharedMemorySize) before launch for kernels that exceed the 48 KB default. Per-head FLUX kernel is untouched.
Raises DRAM throughput on the LTX-2 video self-attention shape from ~33% to ~56% of HBM peak, reaching parity with flashinfer's CuTeDSL RMSNorm on the norm-only path while preserving the in-place semantics LTX-2's downstream attention requires.
Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
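The 2-row-per-CTA layout can be made concrete with index arithmetic. The mapping and the shared-memory estimate below are assumptions for illustration: the estimate counts only the staged X, cos, and sin buffers with no padding, and the real kernels may order threads or size buffers differently.

```python
ROWS_PER_CTA = 2
THREADS_PER_ROW = 128           # 4 warps of 32 threads
WARP_SIZE = 32
DEFAULT_SMEM_LIMIT = 48 * 1024  # bytes available without opting in to more

def thread_coords(block_idx, thread_idx):
    """Map a thread in a 256-thread block to (row, warp within row, lane)."""
    row_in_cta = thread_idx // THREADS_PER_ROW
    tid_in_row = thread_idx % THREADS_PER_ROW
    return {
        "row": block_idx * ROWS_PER_CTA + row_in_cta,
        "warp_in_row": tid_in_row // WARP_SIZE,
        "lane": tid_in_row % WARP_SIZE,
    }

def grid_size(num_rows):
    """Number of blocks needed when every block handles two rows."""
    return (num_rows + ROWS_PER_CTA - 1) // ROWS_PER_CTA

def smem_bytes(row_elems, x_bytes, cos_bytes=None):
    """Staging bytes per block: X for both rows, plus cos and sin when RoPE is fused."""
    per_row = row_elems * x_bytes
    if cos_bytes is not None:
        per_row += 2 * row_elems * cos_bytes
    return ROWS_PER_CTA * per_row

first = thread_coords(block_idx=3, thread_idx=0)
last = thread_coords(block_idx=3, thread_idx=255)
norm_only = smem_bytes(row_elems=32 * 128, x_bytes=2)
norm_rope_fp32_cos = smem_bytes(row_elems=32 * 128, x_bytes=2, cos_bytes=4)
needs_opt_in = norm_rope_fp32_cos > DEFAULT_SMEM_LIMIT
```

Under these assumptions a bf16 row of 32 x 128 elements with fp32 cos/sin needs 80 KB per block, which is the kind of case where the cudaFuncSetAttribute call mentioned above becomes necessary.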
Comments
- Replace internal V1/V2/V3 dev nomenclature in two .cu files with
descriptive Strategy blocks (no code change).
- Translate a stray Chinese comment to English.
Over-defense removal in transformer_ltx2.py
- Drop the dead `else: # unknown QKVMode` fallback in LTX2Attention.forward
(qkv_mode is constructed in __init__ from a two-value Enum).
- Drop the over-defensive `assert pe.ndim >= 3` in _forward_unfused that
only restated the documented caller contract.
- Drop misleading "Override via config.attention.fuse_qk_norm_rope=False"
comment — there is no such config knob.
- Simplify layered conditional defaults for fuse_video / fuse_audio in
prepare_text_cache; the `is not None` short-circuit was unused.
- Simplify _make_pe_local: drop the silent `ndim != 4 -> None` branch, the
redundant `% seq_parallel_size == 0` check (already covered by
_audio_is_sharded), and the redundant `else: cos_out = cos_local`.
Symmetric naming + base-class hoist
- Rename Attention.apply_qk_norm_rope -> apply_packed_qk_norm_rope to
pair with the new split variants.
- Move LTX2Attention._apply_split_norm_rope / _apply_split_norm /
_apply_split_norm_or_norm_rope to the Attention base class (no logic
change; bodies use only base-class fields). Drop the leading underscore
so all four fused-kernel entry points sit at the same visibility tier.
- Update FLUX (2 call sites) and LTX-2 (5 call sites) accordingly.
- LLM-side Gemma / QkNormAttention are unrelated and untouched.
No behavioral change; only comments, dead branches, and method placement.
Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
… to PR NVIDIA#13052 cross-head op
Attention now takes a qk_norm_rope_kernel kwarg whose value is the torch.ops.trtllm op name backing the fused path:
- "fused_dit_qk_norm_rope" (default): per-head autodispatch for FLUX/Cosmos plus full-dim path for LTX-2. Bounded by the kernel template to num_heads <= 32 and head_dim in {64, 128}.
- "fused_dit_cross_head_qk_norm_rope": PR NVIDIA#13052 cross-head kernel for WAN sizes outside the default op's envelope (head_dim=256, or num_heads > 32, e.g. WAN-14B at 40 heads). Requires fp32 head-broadcast cos.
WAN-14B is 40 heads x 128, which trips the default op's num_heads_q <= 32 check, so transformer_wan.py opts WAN into the cross-head op explicitly. LTX-2 and FLUX keep the default selector (zero behavior change -- the default op is the same op they were already calling).
The cross-head dispatch in apply_packed_qk_norm_rope is an early return that uses head-broadcast fp32 cos (the form PR NVIDIA#13052's op expects); the existing full/per-head path below is untouched.
Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
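The envelope of the default op can be written down as a small predicate. Note that in the PR the model file opts in explicitly; the `suggest_qk_norm_rope_kernel` helper below is hypothetical and only illustrates which shapes fall outside the default op's limits.

```python
DEFAULT_OP = "fused_dit_qk_norm_rope"
CROSS_HEAD_OP = "fused_dit_cross_head_qk_norm_rope"

def fits_default_envelope(num_heads, head_dim):
    """Limits quoted in the commit: num_heads <= 32 and head_dim in {64, 128}."""
    return num_heads <= 32 and head_dim in (64, 128)

def suggest_qk_norm_rope_kernel(num_heads, head_dim):
    """Hypothetical helper: pick the op name a model would pass as the kwarg."""
    if fits_default_envelope(num_heads, head_dim):
        return DEFAULT_OP
    return CROSS_HEAD_OP

ltx2_video = suggest_qk_norm_rope_kernel(num_heads=32, head_dim=128)
wan_14b = suggest_qk_norm_rope_kernel(num_heads=40, head_dim=128)
wide_head = suggest_qk_norm_rope_kernel(num_heads=16, head_dim=256)
```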
…contract
LTX2Attention now hardcodes fuse_qk_norm_rope=True in __init__, which makes forward() route every self-attn call through the FUSE_QKV branch that unpacks `cos, sin = pe` unconditionally. The four self-attn sanity / backend-equivalence tests previously passed pe=None, which silently fell into the _forward_unfused path back when fuse_qk_norm_rope defaulted to False. That implicit reliance is gone, so the tests now hit a TypeError at line 265.
Fix: build an identity-rotation RoPE tuple (cos=1, sin=0, shape [B,T,H,D]) in a `_make_pe` helper and pass it through pe= on the four self-attn test cases. cos/sin layout mirrors what `_split_freqs_cis` produces in production, so the tests exercise the same fused norm+RoPE kernel path without needing real RoPE angles. Identity rotation keeps the resulting shape and norm checks meaningful (q*1 + rotate_half(q)*0 = q).
Cross-attention tests are unaffected: they go through SEPARATE_QKV and apply_split_norm_or_norm_rope, which already accepts pe=None as norm-only.
E2E LTX-2 nvfp4 single-stage smoke test still passes (32.7s, 12.13 MB mp4). All six tests in test_ltx2_attention.py PASS after the change.
Signed-off-by: Yiyun Lu <yiyunl@nvidia.com>
Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
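The identity-rotation trick from this commit message can be checked in plain Python. The sketch below is not the test's `_make_pe` helper: `make_identity_pe` is a hypothetical stand-in that uses nested lists instead of tensors, and `rotate_half` follows the usual (x1, x2) -> (-x2, x1) convention.

```python
def rotate_half(x):
    """Rotate-half convention: the two halves (x1, x2) become (-x2, x1)."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rope(x, cos, sin):
    """Element-wise x * cos + rotate_half(x) * sin for one head vector."""
    rot = rotate_half(x)
    return [xi * c + ri * s for xi, c, ri, s in zip(x, cos, rot, sin)]


def make_identity_pe(B, T, H, D):
    """Hypothetical stand-in for the test helper: cos=1, sin=0, shape [B, T, H, D]."""
    cos = [[[[1.0] * D for _ in range(H)] for _ in range(T)] for _ in range(B)]
    sin = [[[[0.0] * D for _ in range(H)] for _ in range(T)] for _ in range(B)]
    return cos, sin


cos, sin = make_identity_pe(B=1, T=2, H=2, D=4)
q = [0.5, -1.0, 2.0, 3.0]
# With cos=1 and sin=0 the rotation leaves q unchanged.
out = apply_rope(q, cos[0][0][0], sin[0][0][0])
```

Because the output equals the input, shape and norm assertions downstream still test something real while the RoPE code path is exercised.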
fbe14ed to 306f21b
…ths + PE pre-shard (NVIDIA#13985) Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com> Signed-off-by: Yiyun Lu <yiyunl@nvidia.com>
Description
This PR fuses RMSNorm + RoPE into a single CUDA kernel across all six attention paths in the LTX-2 video/audio DiT, eliminating eager-PyTorch `apply_qk_norm` + `apply_rotary_emb` work on the per-step hot path. Two ops cover both packed FUSE_QKV (self-attn) and SEPARATE_QKV (cross-attn / AV cross-attn) layouts, plus a norm-only variant for text Q-norm. cos/sin are broadcast across `B` inside the kernel, and per-step PE shape prep is lifted out of the denoise loop into `prepare_text_cache`.
Why
LTX-2 forward had a long chain of small host ops between every GEMM and SDPA (eager RMSNorm, eager rotate-half RoPE, `cos.repeat(B, 1)`, `.reshape(-1).contiguous()`), generating dozens of triton inductor kernels per block per step. Under cuda_graph the per-launch cost is mostly hidden, but under multi-GPU Ulysses (and without cuda_graph, e.g. when profiling), the host work shows up as both kernel time and an exposed prep gap before each fused op. This PR collapses all of that into a single kernel call per QK norm+rope site and pre-computes the sharded-local PE once per `generate()`.
Kernel design
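As a reference point for this section, the unfused per-token computation (RMSNorm followed by rotate-half RoPE) can be sketched in plain Python. The sketch makes several assumptions that the description does not spell out: the RMS is taken over the full `H*D` row, `eps` is `1e-6`, and cos/sin are given per head as flat `[H*D]` rows. The function names are illustrative only.

```python
import math


def rms_norm(row, weight, eps=1e-6):
    """RMSNorm over one row: row / sqrt(mean(row^2) + eps) * weight."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
    return [v * inv_rms * w for v, w in zip(row, weight)]


def rope_rotate_half(head, cos, sin):
    """Rotate-half RoPE on one head: head * cos + rotate_half(head) * sin."""
    half = len(head) // 2
    rot = [-v for v in head[half:]] + head[:half]
    return [h * c + r * s for h, c, r, s in zip(head, cos, rot, sin)]


def norm_then_rope(row, weight, cos, sin, head_dim, eps=1e-6):
    """Full-row RMSNorm (assumed), then RoPE applied independently per head."""
    normed = rms_norm(row, weight, eps)
    out = []
    for start in range(0, len(row), head_dim):
        end = start + head_dim
        out.extend(rope_rotate_half(normed[start:end], cos[start:end], sin[start:end]))
    return out


# One token, H=2 heads, D=4, with a 90-degree rotation (cos=0, sin=1).
row = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
weight = [1.0] * 8
normed = rms_norm(row, weight)
out = norm_then_rope(row, weight, cos=[0.0] * 8, sin=[1.0] * 8, head_dim=4)
```

In eager PyTorch each of these steps is a separate launch with an intermediate tensor in between; the fused ops read the row once, apply both steps in registers, and write the result back in place.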
Three ops, two call shapes
- `fused_dit_qk_norm_rope`: packed FUSE_QKV self-attn. `qkv = Linear(x)` is one tensor `[N, (Hq+Hk+Hv)*D]`; the kernel walks Q and K slices, applies RMSNorm with the per-stream weight, then RoPE; V is untouched.
- `fused_dit_split_norm_rope`: SEPARATE_QKV cross-attn. Single Q (or K) tensor `[N, H*D]`; norm + rope in place.
- `fused_dit_split_norm`: same as above but no RoPE (text Q-norm dispatches here when `pe=None`).
Optimizations
- `cp.async` HBM → SMEM staging for X / cos / sin (single commit group), overlapped with sync register caching of the norm weight (per-head full-dim weight cached as `uint4[MAX_CHUNKS]` regs).
- `cudaFuncSetAttribute(MaxDynamicSharedMemorySize)` at launch time, per kernel specialization, to escape the 48 KB default SMEM cap when bf16 X + bf16 cos + bf16 sin all stage in SMEM.
- `griddepcontrol` wait / launch_dependents (sm_90+) for stream-ordered concurrent kernel start.
- Broadcast over `B` via `cos_tokenIdx = tokenIdx % cos_seq_per_batch`; lets the host pass cos at `[T_local, H*D]` without `.repeat(B, 1)` tiling.
These get the LTX-2 video self-attention shape to ~56% of HBM peak DRAM throughput, matching `flashinfer.norm.rmsnorm` (`RMSNormKernel` in `norm.cuh`) on the norm-only path while preserving the in-place semantics LTX-2's downstream attention requires.
Supported parameters
- `HEAD_DIM`
- `INTERLEAVE`: `true` (pair (2i, 2i+1)) or `false` (rotate-half via `__shfl_xor_sync`)
- `PER_HEAD_COS`: `true` (cos shape `[T, H*D]`, per-head 3D RoPE) or `false` (cos shape `[T, D]`, shared across heads)
- `CosT`: `float` or `__nv_bfloat16` (lossless upcast to fp32 in regs at use)
- `cos_seq_per_batch`: `0` (flat `[B*T, ...]`) or `>0` (broadcast over `B` inside kernel)
- `num_txt_tokens`: `-1` (no dual-stream) or `>0` (split norm weight on text/image boundary)
Extensions over the original FLUX kernel
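The effect of `cos_seq_per_batch` on which cos/sin row a token reads can be illustrated with a tiny index mapping. This is a host-side sketch of the arithmetic only, assuming tokens are flattened batch-major (index `b * T + t`); the function name `cos_row_for_token` is hypothetical and does not appear in the kernel.

```python
def cos_row_for_token(token_idx: int, cos_seq_per_batch: int) -> int:
    """Map a flat token index in [0, B*T) to the cos/sin row it reads.

    cos_seq_per_batch == 0: cos is flat [B*T, ...], one row per token.
    cos_seq_per_batch  > 0: cos is [T, ...] and is reused for every batch
                            element, so the index wraps modulo T.
    """
    if cos_seq_per_batch > 0:
        return token_idx % cos_seq_per_batch
    return token_idx


B, T = 2, 3
flat = [cos_row_for_token(i, 0) for i in range(B * T)]
bcast = [cos_row_for_token(i, T) for i in range(B * T)]
print(flat)   # [0, 1, 2, 3, 4, 5]
print(bcast)  # [0, 1, 2, 0, 1, 2]
```

With the wrapped index the host never needs to materialize `B` copies of cos/sin, which is the `cos.repeat(B, 1)` work the description says was removed.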
The pre-PR repo had a per-head FLUX/Cosmos packed kernel (`fusedDiTQKNormRopeKernel`): Q/K share one cos/sin pair `[T, D]` across heads, fp32 cos required, dual-stream norm weight selection by text-image boundary. This PR keeps that kernel intact for the existing FLUX call sites and adds, on top:
- Full-dim norm weight `[H*D]` (vs `[D]`): LTX-2 / WAN style.
- Per-head cos `[T, H*D]`: different freqs per head (LTX-2 3D RoPE).
- Broadcast over `B` to absorb host-side `cos.repeat(B, 1)`.
- `fused_dit_split_norm_rope` / `fused_dit_split_norm`: new kernels for SEPARATE_QKV (cross-attn), not extensions of the original.
Python-side changes
`tensorrt_llm/_torch/visual_gen/modules/attention.py` (base class)
- `apply_qk_norm_rope` renamed to `apply_packed_qk_norm_rope` for symmetric naming with the new split variants; body extended to handle full-dim cos shape, bf16 cos, and kernel-side B-broadcast.
- `apply_split_norm_rope`, `apply_split_norm`, `apply_split_norm_or_norm_rope` calling the SEPARATE_QKV ops.
- `interleave: bool` and `fuse_qk_norm_rope: bool` constructor kwargs.
`tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py`
- `LTX2Attention.forward` rewritten as a fast-path dispatcher: FUSE_QKV → packed kernel; SEPARATE_QKV (cached / uncached) → split kernel. `_forward_unfused` retained as fallback for unsupported configs.
- `project_kv` extended with optional `pe`: when set, RoPE is applied on the local K shard before all-gather. RoPE commutes with seq-dim concat, so this is bit-identical to post-gather rope while saving the cos/sin all-gather collective and reducing K-rope FLOPs by U× under Ulysses.
- `BasicAVTransformerBlock` AV cross-attn call sites switched to project-before-rotate-before-gather (`k_pe=None` signals "K already rotated"); the old `_sp_gather_pe` helper is removed.
- `LTXModel._make_pe_local`: one-time PE shard + reshape at cache-build time, replacing the per-step `_shard_pe` slicing. Stores either 2D `[T_local, H*D]` (fused kernel) or 4D `[B, T_local, H, D]` (eager fallback) in `TextCache`.
- `_shard_transformer_args` no longer touches PE (now sharded-local at cache build).
`tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/rope.py`
- cos/sin laid out as `[B, T, H, D]` and block-duplicated `D/2 → D` so the fused kernel reads a flat strided pattern; the eager helper slices back to `D/2` for the unfused path (bit-identical since both halves are equal).
Other
- `text_cache.py`: docstring updated to describe the dual-form PE convention (2D fused vs 4D eager).
- `compilation/utils.py`: two new ops marked in-place so torch.compile / inductor skips materializing copies.
Performance
Unit tests
159 passed, 24 skipped in `tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_*.py`. Coverage spans fp32 / bf16 cos × {per-head, full-dim} × interleave / rotate-half × broadcast / flat. Skipped cases are hardware-specific dual-stream variants not exercised on this rig.
Kernel-level (B200, bf16 cos, baseline = production-faithful unfused path matching LTX-2's `_forward_unfused`)
End-to-end (40 step, 768×1280×121 frames, gs=3.0, B200 NVFP4, n=3 timed median)
Pre-PR is the merge-base with main; "This PR" is HEAD of the branch.
Multi-GPU configs gain more because the project-before-gather pattern eliminates the cos/sin all-gather collective for AV cross-attn under Ulysses. Per-run jitter is < 0.5% on every cell; output SHAs are identical across `n=3` for every config.
Tests
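The project-before-gather equivalence rests on RoPE acting on each token independently, and that property can be sanity-checked in isolation. The sketch below is not part of the PR's test suite. It assumes 1D integer positions, interleaved pairs, and a `base ** (-2i/D)` frequency schedule, whereas LTX-2 uses precomputed 3D cos/sin; all names are illustrative.

```python
import math


def rope_interleaved(x, pos, base=10000.0):
    """Rotate pairs (2i, 2i+1) of one token vector by angle pos * base^(-2i/D)."""
    d = len(x)
    out = []
    for i in range(0, d, 2):  # i is the even index of each pair
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        out += [x[i] * c - x[i + 1] * s, x[i] * s + x[i + 1] * c]
    return out


def rope_sequence(tokens, start_pos):
    """Apply RoPE to consecutive tokens whose first position is start_pos."""
    return [rope_interleaved(x, start_pos + t) for t, x in enumerate(tokens)]


# Four K tokens split across two sequence-parallel ranks (U = 2).
k = [[1.0, 0.0, 0.5, -0.5], [0.0, 1.0, 2.0, 1.0],
     [3.0, -1.0, 0.0, 0.25], [-2.0, 0.5, 1.5, 1.0]]
shards = [k[:2], k[2:]]

# Rotate-then-gather: each rank only needs its own slice of positions.
rotated_local = [rope_sequence(shard, start_pos=2 * rank)
                 for rank, shard in enumerate(shards)]
gathered = rotated_local[0] + rotated_local[1]

# Gather-then-rotate: the post-gather ordering, used here as the reference.
reference = rope_sequence(k, start_pos=0)

assert gathered == reference
```

Each rank rotates `T / U` tokens instead of `T`, and no rank needs the cos/sin rows of another rank, which is where the saved collective and the U× FLOP reduction come from.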
- `tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_qk_norm_rope.py`: packed kernel (existing per-head FLUX cases + new full-dim LTX-2 cases).
- `tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_split_qk_norm_rope.py`: split norm + RoPE across cos dtype, rotate-half / interleave, per-head / shared cos.
- `tests/unittest/_torch/thop/parallel_hw_agnostic/test_fused_dit_split_norm.py`: split norm-only for text cross-attn Q-norm.
PR Checklist
PR description clearly explains what and why.
PR Follows TRT-LLM CODING GUIDELINES.
Test cases are provided for new code paths.
Any new dependencies have been scanned for license and vulnerabilities.
CODEOWNERS updated if ownership changes.
Documentation updated as needed.
Update tava architecture diagram if significant design change.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.