Skip to content

Commit 2355ea0

Browse files
committed
[None][chore] ltx2: address CR review, simplify PE cache plumbing, fix kernel UB
CR-flagged fixes: - fusedDiTQKNormRopeKernel.cu: add trailing __syncthreads() in reduce_partial, preventing race when warp_sums[] is reused for Q->K reductions (CR PR NVIDIA#13985 thread 1). - fusedDiTQKNormRopeKernel.cu + fusedDiTSplitQKNormRopeKernel.cu: use __activemask() instead of 0xffffffff for the rotate-half __shfl_xor_sync, which avoided UB for small num_heads*HEAD_DIM where the surrounding chunk loop has partial-warp early-exit (CR thread 2). PE cache plumbing simplification (data flow): - Drop the 4 *_pe_2d duplicate fields in TextCache; the single *_pe field now holds the form the consumer expects (2D [T_local, H*D] contiguous when fuse_qk_norm_rope=True, 4D [B, T_local, H, D] otherwise). - Revert ltx2_core/transformer_args.py to upstream (drop the two _2d fields + two _2d kwargs that C8 had added to the upstream-mirrored file). - LTX2Attention now explicitly sets fuse_qk_norm_rope=True (the base class default for qk_norm_mode="full" was False, but the LTX-2 forward path ignored the flag); forward() now actually gates on it. - _shard_transformer_args drops the per-step _shard_pe — PE is sharded one-time in prepare_text_cache via _make_pe_local (renamed from _make_pe_2d_local; now produces 2D or 4D based on the fuse flag). - BasicAVTransformerBlock's 6 'pe=*._2d or *._4d' fallback expressions collapse to a single 'pe=*._pe' reference. - _forward_unfused gains a pe.ndim assert so the naive eager path fails loud if anyone passes the fused 2D form. - pipeline_ltx2 cuda-graph clone/copy halved (10 -> 6 calls per TextCache). Test reorg: - Move test_fused_dit_split_qk_norm_rope.py + test_fused_dit_split_norm.py from parallel/ to parallel_hw_agnostic/. Extend the packed test file with full-dim cells covering LTX-2 self-attn shapes (T=12288 H=32 D=128 + T=504 H=32 D=64, including the broadcast-over-B path). Verification: - 159 unit tests pass (packed + split + norm-only across fp32/bf16 cos). - 1-GPU 40-step LTX-2 e2e (gs=3.0): raw video sha256 bit-identical to the pre-cleanup HEAD (99cc34517b19e3e12fb66ccc439b4c5f7b2575cf862e627fb504e1fdcc120755). Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
1 parent 5feedce commit 2355ea0

11 files changed

Lines changed: 362 additions & 126 deletions

File tree

cpp/tensorrt_llm/kernels/fusedDiTQKNormRopeKernel.cu

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ __global__ void fusedDiTQKNormRopeKernel(__nv_bfloat16* qkv, // [num_tokens, tot
217217
// each head h reads its own slice (LTX-2 INTERLEAVED RoPE).
218218
// Note: when PER_HEAD_COS=true, Q and K share the same cos/sin buffer (same
219219
// num_heads_q == num_heads_k for LTX-2 self-attn).
220-
// CosT: float (fp32 cos) or __nv_bfloat16 (B-2: kernel upcasts in registers).
220+
// CosT: float (fp32 cos) or __nv_bfloat16 (kernel upcasts bf16 to fp32 in registers, lossless).
221221
template <int HEAD_DIM, bool INTERLEAVE, bool PER_HEAD_COS, typename CosT>
222222
__global__ void fusedDiTQKNormFullDimRopeKernel(__nv_bfloat16* qkv, int const num_heads_q, int const num_heads_k,
223223
int const num_heads_v, float const eps, __nv_bfloat16 const* q_weight, __nv_bfloat16 const* k_weight,
@@ -274,6 +274,10 @@ __global__ void fusedDiTQKNormFullDimRopeKernel(__nv_bfloat16* qkv, int const nu
274274
#pragma unroll
275275
for (int w = 0; w < N_WARPS; w++)
276276
total += warp_sums[w];
277+
// Trailing barrier: this lambda is called twice (Q then K) and reuses
278+
// warp_sums; without this, warp X's next-iteration lane-0 write can race
279+
// warp Y's pending read of the previous iteration.
280+
__syncthreads();
277281
return total;
278282
};
279283

@@ -405,13 +409,16 @@ __global__ void fusedDiTQKNormFullDimRopeKernel(__nv_bfloat16* qkv, int const nu
405409
else
406410
{
407411
// rotate-half (LTX-2 SPLIT): partner element at +HEAD_DIM/2 within head.
408-
// Inline partner exchange to avoid 8-reg partner array (Step 1 reg-pressure opt).
412+
// Inline partner exchange to avoid 8-reg partner array (reg-pressure opt).
413+
// Use __activemask() because the surrounding chunk loop can have
414+
// `continue` early-exit on partial warps for small N.
409415
constexpr int xor_mask = HEAD_DIM / 16;
410416
bool const negate = ((laneId & xor_mask) == 0);
417+
unsigned const activeMask = __activemask();
411418
#pragma unroll
412419
for (int i = 0; i < CHUNK_ELEMS; i++)
413420
{
414-
float p = __shfl_xor_sync(0xffffffff, elements[i], xor_mask);
421+
float p = __shfl_xor_sync(activeMask, elements[i], xor_mask);
415422
if (negate)
416423
{
417424
p = -p;

cpp/tensorrt_llm/kernels/fusedDiTSplitQKNormRopeKernel.cu

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ namespace kernels
3737
//
3838
// PER_HEAD_COS=false: cos/sin shape [num_tokens, HEAD_DIM] (FLUX-style, head broadcast).
3939
// PER_HEAD_COS=true: cos/sin shape [num_tokens, num_heads * HEAD_DIM] (LTX-2 3D RoPE).
40-
// CosT: float (fp32 cos) or __nv_bfloat16 (B-2: kernel upcasts in registers).
40+
// CosT: float (fp32 cos) or __nv_bfloat16 (kernel upcasts bf16 to fp32 in registers, lossless).
4141
template <int HEAD_DIM, bool INTERLEAVE, bool PER_HEAD_COS, typename CosT>
4242
__global__ void fusedDiTSplitNormFullDimRopeKernel(__nv_bfloat16* __restrict__ tensor, int const num_tokens,
4343
int const num_heads, float const eps, __nv_bfloat16 const* __restrict__ weight, CosT const* __restrict__ cos_emb,
@@ -188,12 +188,15 @@ __global__ void fusedDiTSplitNormFullDimRopeKernel(__nv_bfloat16* __restrict__ t
188188
{
189189
// rotate-half: partner element at +HEAD_DIM/2 within the same head.
190190
// Inline partner exchange (single reg `p` per iter, no array).
191+
// Use __activemask() because the surrounding chunk loop can have
192+
// `continue` early-exit on partial warps for small num_heads*HEAD_DIM.
191193
constexpr int xor_mask = HEAD_DIM / 16;
192194
bool const negate = ((laneId & xor_mask) == 0);
195+
unsigned const activeMask = __activemask();
193196
#pragma unroll
194197
for (int i = 0; i < CHUNK_ELEMS; i++)
195198
{
196-
float p = __shfl_xor_sync(0xffffffff, elements[i], xor_mask);
199+
float p = __shfl_xor_sync(activeMask, elements[i], xor_mask);
197200
if (negate)
198201
{
199202
p = -p;

cpp/tensorrt_llm/thop/fusedDiTSplitQKNormRopeOp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ void fused_dit_split_norm_rope(torch::Tensor& tensor, int64_t num_heads, int64_t
3939

4040
CHECK_INPUT(tensor, torch::kBFloat16);
4141
CHECK_INPUT(weight, torch::kBFloat16);
42-
// Cos/sin may be fp32 (legacy) or bf16 (B-2: kernel upcasts in registers).
42+
// Cos/sin may be fp32 or bf16 (kernel upcasts bf16 to fp32 in registers, lossless).
4343
auto const cos_dtype = cos_emb.scalar_type();
4444
TORCH_CHECK(cos_dtype == torch::kFloat32 || cos_dtype == torch::kBFloat16,
4545
"cos_emb dtype must be float32 or bfloat16, got ", cos_dtype);

tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/transformer_args.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,6 @@ class TransformerArgs:
3232
cross_scale_shift_timestep: torch.Tensor | None
3333
cross_gate_timestep: torch.Tensor | None
3434
enabled: bool
35-
# Sharded-local 2D contiguous [T_local, H*D] forms of *positional_embeddings*
36-
# and *cross_positional_embeddings*, computed once in prepare_text_cache
37-
# (loop-external) and threaded through to fused norm+rope kernels. The 4D
38-
# forms above stay around for the unfused fallback (apply_rotary_emb).
39-
positional_embeddings_2d: tuple[torch.Tensor, torch.Tensor] | None = None
40-
cross_positional_embeddings_2d: tuple[torch.Tensor, torch.Tensor] | None = None
4135

4236

4337
class TransformerArgsPreprocessor:
@@ -161,14 +155,11 @@ def prepare(
161155
static_mask: torch.Tensor | None,
162156
static_pe: tuple[torch.Tensor, torch.Tensor],
163157
static_cross_pe: tuple[torch.Tensor, torch.Tensor] | None = None,
164-
static_pe_2d: tuple[torch.Tensor, torch.Tensor] | None = None,
165-
static_cross_pe_2d: tuple[torch.Tensor, torch.Tensor] | None = None,
166158
) -> TransformerArgs:
167159
"""Build TransformerArgs for one denoise step.
168160
169161
Step-invariant static args are always required. *static_cross_pe*
170-
and *static_pe_2d* / *static_cross_pe_2d* are only meaningful when
171-
provided by the caller; ignored in this base class for *_cross_pe.
162+
is only used by the MultiModal subclass; ignored here.
172163
"""
173164
x = self.patchify_proj(modality.latent.contiguous())
174165
timestep, embedded_timestep = self._prepare_timestep(
@@ -185,8 +176,6 @@ def prepare(
185176
cross_scale_shift_timestep=None,
186177
cross_gate_timestep=None,
187178
enabled=modality.enabled,
188-
positional_embeddings_2d=static_pe_2d,
189-
cross_positional_embeddings_2d=None,
190179
)
191180

192181

@@ -266,16 +255,13 @@ def prepare(
266255
static_mask: torch.Tensor | None,
267256
static_pe: tuple[torch.Tensor, torch.Tensor],
268257
static_cross_pe: tuple[torch.Tensor, torch.Tensor],
269-
static_pe_2d: tuple[torch.Tensor, torch.Tensor] | None = None,
270-
static_cross_pe_2d: tuple[torch.Tensor, torch.Tensor] | None = None,
271258
) -> TransformerArgs:
272259
"""Build TransformerArgs for one denoise step with pre-computed static outputs."""
273260
transformer_args = self.simple_preprocessor.prepare(
274261
modality,
275262
static_context=static_context,
276263
static_mask=static_mask,
277264
static_pe=static_pe,
278-
static_pe_2d=static_pe_2d,
279265
)
280266
cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
281267
timestep=modality.timesteps,
@@ -288,7 +274,6 @@ def prepare(
288274
cross_positional_embeddings=static_cross_pe,
289275
cross_scale_shift_timestep=cross_scale_shift_timestep,
290276
cross_gate_timestep=cross_gate_timestep,
291-
cross_positional_embeddings_2d=static_cross_pe_2d,
292277
)
293278

294279
def _prepare_cross_attention_timestep(

tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -379,16 +379,12 @@ def _clone_value(v):
379379
video_context=v.video_context.clone() if v.video_context is not None else None,
380380
video_mask=v.video_mask.clone() if v.video_mask is not None else None,
381381
video_pe=clone_pair(v.video_pe),
382-
video_pe_2d=clone_pair(v.video_pe_2d),
383382
video_cross_pe=clone_pair(v.video_cross_pe),
384-
video_cross_pe_2d=clone_pair(v.video_cross_pe_2d),
385383
video_kv=[clone_pair(kv) for kv in v.video_kv] if v.video_kv is not None else None,
386384
audio_context=v.audio_context.clone() if v.audio_context is not None else None,
387385
audio_mask=v.audio_mask.clone() if v.audio_mask is not None else None,
388386
audio_pe=clone_pair(v.audio_pe),
389-
audio_pe_2d=clone_pair(v.audio_pe_2d),
390387
audio_cross_pe=clone_pair(v.audio_cross_pe),
391-
audio_cross_pe_2d=clone_pair(v.audio_cross_pe_2d),
392388
audio_kv=[clone_pair(kv) for kv in v.audio_kv] if v.audio_kv is not None else None,
393389
)
394390
if isinstance(v, torch.Tensor):
@@ -417,9 +413,7 @@ def _copy_value(dst, src):
417413
if dst.video_mask is not None and src.video_mask is not None:
418414
dst.video_mask.copy_(src.video_mask)
419415
copy_pair(dst.video_pe, src.video_pe)
420-
copy_pair(dst.video_pe_2d, src.video_pe_2d)
421416
copy_pair(dst.video_cross_pe, src.video_cross_pe)
422-
copy_pair(dst.video_cross_pe_2d, src.video_cross_pe_2d)
423417
if dst.video_kv is not None and src.video_kv is not None:
424418
for d, s in zip(dst.video_kv, src.video_kv):
425419
copy_pair(d, s)
@@ -428,9 +422,7 @@ def _copy_value(dst, src):
428422
if dst.audio_mask is not None and src.audio_mask is not None:
429423
dst.audio_mask.copy_(src.audio_mask)
430424
copy_pair(dst.audio_pe, src.audio_pe)
431-
copy_pair(dst.audio_pe_2d, src.audio_pe_2d)
432425
copy_pair(dst.audio_cross_pe, src.audio_cross_pe)
433-
copy_pair(dst.audio_cross_pe_2d, src.audio_cross_pe_2d)
434426
if dst.audio_kv is not None and src.audio_kv is not None:
435427
for d, s in zip(dst.audio_kv, src.audio_kv):
436428
copy_pair(d, s)

tensorrt_llm/_torch/visual_gen/models/ltx2/text_cache.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,36 +20,37 @@
2020
class TextCache:
2121
"""Pre-computed text-derived tensors that are constant across denoise steps.
2222
23+
The ``*_pe`` fields hold sharded-local positional embeddings in the form
24+
the consumer wants:
25+
26+
- ``fuse_qk_norm_rope=True`` (LTX-2 default): 2D ``[T_local, H*D]``
27+
contiguous, fed directly to the fused norm+rope kernel.
28+
- ``fuse_qk_norm_rope=False``: 4D ``[B, T_local, H, D]`` sharded but
29+
otherwise unchanged, for the naive ``apply_rotary_emb`` path.
30+
31+
Form is decided at cache-build time (``LTXModel.prepare_text_cache``); no
32+
per-step reshape, ``.contiguous()``, or shard slicing.
33+
2334
Attributes:
2435
video_context: Projected text embedding for video cross-attention.
2536
video_mask: Attention mask for video text cross-attention.
26-
video_pe: RoPE (cos, sin) for video. 4D form [1, T, H, D], un-sharded.
27-
video_pe_2d: Sharded-local 2D contiguous form [T_local, H*D] of video_pe,
28-
fed directly to fused norm+rope kernels — skips per-step reshape +
29-
``.contiguous()`` in the hot helper.
37+
video_pe: Sharded-local RoPE (cos, sin) for video self-attn.
38+
video_cross_pe: Sharded-local RoPE for video AV cross-attn (audio-video model only).
3039
audio_context: Projected text embedding for audio cross-attention.
3140
audio_mask: Attention mask for audio text cross-attention.
32-
audio_pe: RoPE (cos, sin) for audio. 4D form, un-sharded.
33-
audio_pe_2d: Sharded-local 2D contiguous form of audio_pe.
34-
video_cross_pe: Cross-modal RoPE for video (audio-video model only).
35-
video_cross_pe_2d: Sharded-local 2D contiguous form of video_cross_pe.
36-
audio_cross_pe: Cross-modal RoPE for audio (audio-video model only).
37-
audio_cross_pe_2d: Sharded-local 2D contiguous form of audio_cross_pe.
41+
audio_pe: Sharded-local RoPE (cos, sin) for audio self-attn.
42+
audio_cross_pe: Sharded-local RoPE for audio AV cross-attn (audio-video model only).
3843
video_kv: Per-layer pre-projected text K/V for video cross-attention.
3944
audio_kv: Per-layer pre-projected text K/V for audio cross-attention.
4045
"""
4146

4247
video_context: Optional[torch.Tensor] = None
4348
video_mask: Optional[torch.Tensor] = None
4449
video_pe: Optional[tuple[torch.Tensor, torch.Tensor]] = None
45-
video_pe_2d: Optional[tuple[torch.Tensor, torch.Tensor]] = None
4650
video_cross_pe: Optional[tuple[torch.Tensor, torch.Tensor]] = None
47-
video_cross_pe_2d: Optional[tuple[torch.Tensor, torch.Tensor]] = None
4851
video_kv: Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None
4952
audio_context: Optional[torch.Tensor] = None
5053
audio_mask: Optional[torch.Tensor] = None
5154
audio_pe: Optional[tuple[torch.Tensor, torch.Tensor]] = None
52-
audio_pe_2d: Optional[tuple[torch.Tensor, torch.Tensor]] = None
5355
audio_cross_pe: Optional[tuple[torch.Tensor, torch.Tensor]] = None
54-
audio_cross_pe_2d: Optional[tuple[torch.Tensor, torch.Tensor]] = None
5556
audio_kv: Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None

0 commit comments

Comments
 (0)