@@ -82,6 +82,60 @@ def _host_descriptor_pre_hook(nargs):
8282 nargs ["K" ].block_shape = [BLOCK_N , BLOCK_D_Q ]
8383
8484
# pyre-ignore[2]
def _early_config_prune(
    configs: List[triton.Config],
    named_args,
    **kwargs,
) -> List[triton.Config]:
    """Drop TLX autotune configs when the call cannot supply TMA descriptors.

    Configs carrying ``USE_TLX=True`` select the warp-specialized variant of
    ``_hstu_attn_fwd``, which loads Q/K/V via
    ``tlx.async_descriptor_load`` and therefore needs genuine TMA tensor
    descriptors (``tl.tensor_descriptor_base``). The host wrapper only builds
    those when ``ENABLE_TMA=True`` and the host ``TensorDescriptor`` API could
    be imported. Without both, raw tensors would reach the TLX path and trip
    the ``isinstance(desc, tl.tensor_descriptor_base)`` assert in
    ``triton/language/extra/tlx/mem_ops.py`` at compile time.

    Args:
        configs: Candidate autotune configs.
        named_args: Mapping of the kernel's named arguments; consulted for
            ``ENABLE_TMA`` only when it is absent (or ``None``) in ``kwargs``.
        **kwargs: Remaining call-time keyword arguments; ``ENABLE_TMA`` is
            looked up here first.

    Returns:
        ``configs`` unchanged when TMA is both requested and available.
        Otherwise the subset without ``USE_TLX=True``; if that subset would be
        empty, the original ``configs`` is returned so the result is never an
        empty list.
    """
    # Keyword arguments take precedence; fall back to the named-args mapping.
    tma_requested = kwargs.get("ENABLE_TMA", None)
    if tma_requested is None:
        tma_requested = named_args.get("ENABLE_TMA", False)

    if not (tma_requested and tensor_descriptor_tma):
        non_tlx = [
            cfg for cfg in configs if not cfg.kwargs.get("USE_TLX", False)
        ]
        # Only hand back the filtered list if something survived the filter.
        if non_tlx:
            return non_tlx

    return configs
116+
117+
def _should_enable_tma() -> bool:
    """Return True iff the TMA / TLX fast path can be safely enabled.

    All of the following must hold:
      * The host ``triton.tools.tensor_descriptor.TensorDescriptor`` API is
        importable (``tensor_descriptor_tma`` is truthy).
      * PyTorch is not a ROCm/HIP build.
      * CUDA is available and the current device reports compute capability
        major version 9 (Hopper), which is the only architecture for which
        TLX configs are emitted in ``_get_fw_configs``.

    Returns:
        ``True`` when every condition above is satisfied, ``False`` otherwise
        (including when querying the device capability raises).
    """
    if not tensor_descriptor_tma:
        return False
    # ROCm builds expose AMD GPUs through the ``torch.cuda`` namespace, and
    # some of them (e.g. gfx90a / gfx942) also report a major capability of 9.
    # The capability check alone would therefore misclassify them as Hopper.
    if torch.version.hip:
        return False
    if not torch.cuda.is_available():
        return False
    try:
        major_capability = torch.cuda.get_device_capability()[0]
    except (RuntimeError, AssertionError):
        # Treat any failure to query the device as "TMA not available".
        return False
    return major_capability == 9
137+
138+
85139def _get_fw_configs () -> List [triton .Config ]: # noqa: C901
86140 configs = []
87141 if torch .version .hip :
@@ -1513,6 +1567,7 @@ def _hstu_attn_fwd_compute_tlx( # noqa C901
15131567 "DeltaSize" ,
15141568 "IS_DELTA_Q" ,
15151569 ],
1570+ prune_configs_by = {"early_config_prune" : _early_config_prune },
15161571)
15171572@triton .jit
15181573def _hstu_attn_fwd ( # noqa C901
@@ -1656,6 +1711,7 @@ def _hstu_attn_fwd( # noqa C901
16561711 "DeltaSize" ,
16571712 "IS_DELTA_Q" ,
16581713 ],
1714+ prune_configs_by = {"early_config_prune" : _early_config_prune },
16591715)
16601716@triton .jit
16611717def _hstu_attn_fwd_persistent ( # noqa C901
0 commit comments