Skip to content

Commit fb64fd4

Browse files
zoranzhaometa-codesync[bot]
authored andcommitted
Make HSTU Triton attention TLX path safe under enable_tma=False
Reviewed By: htyu Differential Revision: D101714344 fbshipit-source-id: a6dc2c35f8ec9ef47dd28e42f2b959b111d16fa2
1 parent 5ccde78 commit fb64fd4

3 files changed

Lines changed: 66 additions & 1 deletion

File tree

generative_recommenders/ops/hstu_compute.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ def hstu_preprocess_and_attention(
206206
sort_by_length: bool,
207207
prefill: bool = False,
208208
kernel: HammerKernel = HammerKernel.PYTORCH,
209+
enable_tma: Optional[bool] = None,
209210
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
210211
if not is_fx_tracing():
211212
torch._assert(max_seq_len > 0, "max_seq_len must be larger than 0")
@@ -239,6 +240,7 @@ def hstu_preprocess_and_attention(
239240
recompute_uvqk_in_backward=recompute_uvqk_in_backward,
240241
recompute_normed_x_in_backward=recompute_normed_x_in_backward,
241242
sort_by_length=sort_by_length,
243+
enable_tma=enable_tma,
242244
)
243245
attn_output = attn_output.view(-1, hidden_dim * num_heads)
244246
k = None

generative_recommenders/ops/triton/triton_hstu_attention.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,60 @@ def _host_descriptor_pre_hook(nargs):
8282
nargs["K"].block_shape = [BLOCK_N, BLOCK_D_Q]
8383

8484

85+
# pyre-ignore[2]
86+
def _early_config_prune(
87+
configs: List[triton.Config],
88+
named_args,
89+
**kwargs,
90+
) -> List[triton.Config]:
91+
"""Filter autotune configs that are incompatible with the current call.
92+
93+
The TLX (warp-specialized) variant of ``_hstu_attn_fwd`` calls
94+
``tlx.async_descriptor_load(Q, ...)`` which requires Q/K/V to be real TMA
95+
tensor descriptors (``tl.tensor_descriptor_base``). They are only
96+
constructed by the host wrapper when ``ENABLE_TMA=True`` AND the host
97+
``TensorDescriptor`` API is importable. If the kernel is invoked without
98+
those preconditions, raw tensors flow into the TLX path and the
99+
``isinstance(desc, tl.tensor_descriptor_base)`` assert in
100+
``triton/language/extra/tlx/mem_ops.py`` fires at compile time.
101+
102+
We make autotuning robust to that mismatch by dropping any config with
103+
``USE_TLX=True`` whenever ENABLE_TMA is not set or TMA host descriptors
104+
are unavailable. This is purely defensive: if the caller threads
105+
``enable_tma=True`` (see ``_should_enable_tma`` below) the TLX configs
106+
remain eligible.
107+
"""
108+
enable_tma = kwargs.get("ENABLE_TMA", None)
109+
if enable_tma is None:
110+
enable_tma = named_args.get("ENABLE_TMA", False)
111+
if enable_tma and tensor_descriptor_tma:
112+
return configs
113+
pruned = [c for c in configs if not c.kwargs.get("USE_TLX", False)]
114+
# Safety: never return an empty config list.
115+
return pruned if pruned else configs
116+
117+
118+
def _should_enable_tma() -> bool:
119+
"""Return True iff the TMA / TLX fast path can be safely enabled.
120+
121+
Conditions:
122+
* The host ``triton.tools.tensor_descriptor.TensorDescriptor`` API is
123+
importable (``tensor_descriptor_tma``).
124+
* CUDA is available and the device is Hopper (compute capability 9),
125+
which is the only architecture for which TLX configs are emitted in
126+
``_get_fw_configs``.
127+
"""
128+
if not tensor_descriptor_tma:
129+
return False
130+
if not torch.cuda.is_available():
131+
return False
132+
try:
133+
device_capability = torch.cuda.get_device_capability()[0]
134+
except (RuntimeError, AssertionError):
135+
return False
136+
return device_capability == 9
137+
138+
85139
def _get_fw_configs() -> List[triton.Config]: # noqa: C901
86140
configs = []
87141
if torch.version.hip:
@@ -1513,6 +1567,7 @@ def _hstu_attn_fwd_compute_tlx( # noqa C901
15131567
"DeltaSize",
15141568
"IS_DELTA_Q",
15151569
],
1570+
prune_configs_by={"early_config_prune": _early_config_prune},
15161571
)
15171572
@triton.jit
15181573
def _hstu_attn_fwd( # noqa C901
@@ -1656,6 +1711,7 @@ def _hstu_attn_fwd( # noqa C901
16561711
"DeltaSize",
16571712
"IS_DELTA_Q",
16581713
],
1714+
prune_configs_by={"early_config_prune": _early_config_prune},
16591715
)
16601716
@triton.jit
16611717
def _hstu_attn_fwd_persistent( # noqa C901

generative_recommenders/ops/triton/triton_hstu_preprocess_and_attention.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
triton_addmm_fwd,
2626
)
2727
from generative_recommenders.ops.triton.triton_hstu_attention import (
28+
_should_enable_tma,
2829
triton_hstu_attention_bwd,
2930
triton_hstu_attention_fwd,
3031
)
@@ -310,8 +311,14 @@ def triton_hstu_preprocess_and_attention(
310311
recompute_uvqk_in_backward: bool = False,
311312
recompute_normed_x_in_backward: bool = False,
312313
sort_by_length: bool = False,
313-
enable_tma: bool = False,
314+
enable_tma: Optional[bool] = None,
314315
) -> Tuple[torch.Tensor, torch.Tensor]:
316+
# When the caller does not specify enable_tma, auto-detect whether the
317+
# TMA / TLX fast path is safe on this device. Resolving here (vs inside
318+
# the autograd Function.forward) keeps a concrete bool flowing through
319+
# ctx.save_for_backward / ctx attributes.
320+
if enable_tma is None:
321+
enable_tma = _should_enable_tma()
315322
return _HSTUPreprocessAndAttentionFunction.apply(
316323
x,
317324
norm_weight,

0 commit comments

Comments
 (0)