Skip to content

Commit fbe14ed

Browse files
committed
[None][test] ltx2: self-attn unit tests pass valid pe to match fused contract
LTX2Attention now hardcodes fuse_qk_norm_rope=True in __init__, which makes forward() route every self-attn call through the FUSE_QKV branch that unpacks `cos, sin = pe` unconditionally. The four self-attn sanity / backend- equivalence tests previously passed pe=None, which silently fell into the _forward_unfused path back when fuse_qk_norm_rope defaulted to False. That implicit reliance is gone, so the tests now hit a TypeError at line 265. Fix: build an identity-rotation RoPE tuple (cos=1, sin=0, shape [B,T,H,D]) in a `_make_pe` helper and pass it through pe= on the four self-attn testcases. cos/sin layout mirrors what `_split_freqs_cis` produces in production, so the tests exercise the same fused norm+RoPE kernel path without needing real RoPE angles. Identity rotation keeps the resulting shape and norm checks meaningful (q*1 + rotate_half(q)*0 = q). Cross-attention tests are unaffected — they go through SEPARATE_QKV and apply_split_norm_or_norm_rope, which already accepts pe=None as norm-only. E2E LTX-2 nvfp4 single-stage smoke test still passes (32.7s, 12.13 MB mp4). All six tests in test_ltx2_attention.py PASS after the change. Signed-off-by: Yiyun Lu <yiyunl@nvidia.com> Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
1 parent db27b43 commit fbe14ed

1 file changed

Lines changed: 30 additions & 5 deletions

File tree

tests/unittest/_torch/visual_gen/test_ltx2_attention.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,27 @@ def _init_weights(module: torch.nn.Module, std: float = 0.02):
5252
torch.nn.init.normal_(p, mean=0.0, std=std)
5353

5454

55+
def _make_pe(
56+
batch_size: int,
57+
seq_len: int,
58+
heads: int,
59+
head_dim: int,
60+
dtype: torch.dtype,
61+
device: str,
62+
) -> tuple[torch.Tensor, torch.Tensor]:
63+
"""Build an identity-rotation (cos=1, sin=0) RoPE tuple for self-attn tests.
64+
65+
LTX-2 self-attn forward (fuse_qk_norm_rope=True, head_dim ∈ {64, 128}) requires
66+
``pe`` to be a non-None ``(cos, sin)`` tuple in token-major [B, T, H, D] layout —
67+
the same shape ``_split_freqs_cis`` produces in production. cos=1, sin=0 makes
68+
the RoPE step an identity, so shape-only sanity checks remain meaningful while
69+
still exercising the fused norm+RoPE kernel.
70+
"""
71+
cos = torch.ones(batch_size, seq_len, heads, head_dim, device=device, dtype=dtype)
72+
sin = torch.zeros(batch_size, seq_len, heads, head_dim, device=device, dtype=dtype)
73+
return cos, sin
74+
75+
5576
class TestLTX2SelfAttention(unittest.TestCase):
5677
"""Test LTX2Attention self-attention with different backends."""
5778

@@ -86,9 +107,10 @@ def test_vanilla_self_attention_sanity(self):
86107
)
87108

88109
x = torch.randn(batch_size, seq_len, query_dim, device=self.DEVICE, dtype=dtype) * 0.02
110+
pe = _make_pe(batch_size, seq_len, heads, head_dim, dtype, self.DEVICE)
89111

90112
with torch.no_grad():
91-
output = attn(x, context=None, pe=None)
113+
output = attn(x, context=None, pe=pe)
92114

93115
self.assertEqual(output.shape, (batch_size, seq_len, query_dim))
94116

@@ -122,9 +144,10 @@ def test_trtllm_self_attention_sanity(self):
122144
)
123145

124146
x = torch.randn(batch_size, seq_len, query_dim, device=self.DEVICE, dtype=dtype) * 0.02
147+
pe = _make_pe(batch_size, seq_len, heads, head_dim, dtype, self.DEVICE)
125148

126149
with torch.no_grad():
127-
output = attn(x, context=None, pe=None)
150+
output = attn(x, context=None, pe=pe)
128151

129152
self.assertEqual(output.shape, (batch_size, seq_len, query_dim))
130153

@@ -248,9 +271,10 @@ def test_gated_self_attention_sanity(self):
248271
self.assertIsNotNone(attn.to_gate_logits, "Gated attention should create to_gate_logits")
249272

250273
x = torch.randn(batch_size, seq_len, query_dim, device=self.DEVICE, dtype=dtype) * 0.02
274+
pe = _make_pe(batch_size, seq_len, heads, head_dim, dtype, self.DEVICE)
251275

252276
with torch.no_grad():
253-
output = attn(x, context=None, pe=None)
277+
output = attn(x, context=None, pe=pe)
254278

255279
self.assertEqual(output.shape, (batch_size, seq_len, query_dim))
256280

@@ -308,10 +332,11 @@ def test_backend_equivalence(self):
308332
trtllm_attn.load_state_dict(vanilla_attn.state_dict())
309333

310334
x = torch.randn(batch_size, seq_len, query_dim, device=self.DEVICE, dtype=dtype) * 0.02
335+
pe = _make_pe(batch_size, seq_len, heads, head_dim, dtype, self.DEVICE)
311336

312337
with torch.no_grad():
313-
out_vanilla = vanilla_attn(x.clone(), context=None, pe=None)
314-
out_trtllm = trtllm_attn(x.clone(), context=None, pe=None)
338+
out_vanilla = vanilla_attn(x.clone(), context=None, pe=pe)
339+
out_trtllm = trtllm_attn(x.clone(), context=None, pe=pe)
315340

316341
# Skip comparison if either has NaN/Inf (can happen with random weights)
317342
has_nan = torch.isnan(out_vanilla).any() or torch.isnan(out_trtllm).any()

0 commit comments

Comments
 (0)