Summary
Profiling a Wan 2.2 T2V A14B training run (Primus, MI300-class GPU, mbs=1) shows that 10.9 % of step time is attention with no FLOPs and no roofline in the TraceLens report. 7.00 % sits in the `other` bucket: `aiter::fmha_v3_varlen_bwd` accounts for 6.83 % and `aten::_flash_attention_forward` for 0.17 %. The remaining 3.92 % sits in `InferenceAttention`. That row is `aiter::fmha_v3_varlen_fwd`, whose existing perf-model class requires an sglang/vLLM annotation that training traces don't have, so it returns `GFLOPS = None` while the row is silently labelled `InferenceAttention`.
This issue covers three fixes in one PR, plus the equivalent non-varlen path formerly tracked in #590 and the wrapper-layer name from #290.
Trace evidence (Wan 2.2 T2V A14B, Primus, BF16, mbs=1)

| Op | Count | ms | % step | xlsx category | GFLOPS in xlsx |
|---|---|---|---|---|---|
| `aiter::fmha_v3_varlen_fwd` | 160 | 1 553 | 3.92 | InferenceAttention | None |
| `aiter::fmha_v3_varlen_bwd` | 80 | 2 705 | 6.83 | other | None |
| `aten::_flash_attention_forward` | 21 | 68 | 0.17 | other | None |
| `aten::_scaled_dot_product_flash_attention` | 1 | 0.06 | <0.01 | SDPA_fwd | 59.8 ✓ |

Category mix of the full step (from `ops_summary_by_category`):

| Category | % | ms |
|---|---|---|
| elementwise | 58.97 | 23 344 |
| reduce | 20.84 | 8 250 |
| other | 7.19 | 2 845 |
| GEMM | 6.49 | 2 568 |
| InferenceAttention | 3.92 | 1 553 |

1. New perf model: `aiter::(wrapper_)fmha_v3_varlen_{fwd,bwd}`
Currently uncategorized (varlen bwd) or annotation-only (varlen fwd, see item 3). This is the dominant gap: `aiter::fmha_v3_varlen_bwd` alone is 6.83 % of step, 80 events, mean 65.4 ms each.
`Input Dims` layout from `ops_unique_args` for `aiter::fmha_v3_varlen_bwd`:

`((T, H, D), (T, H, D), (T, H, D), (T, H, D), (T, H, D), (H, T), (B+1,), (B+1,), (), ...)`

- `[0]`–`[4]`: Q, K, V, O, dO, each `(T, H, D)`
- `[5]`: softmax_lse, `(H, T)`
- `[6]`, `[7]`: cu_q, cu_kv, each `(B+1,)`

In this trace, `(T, H, D) = (32 760, 40, 128)`.

`Concrete Inputs` slots for the varlen bwd:

- `[8]` = `max_seqlen_q`
- `[9]` = `max_seqlen_kv`
- `[10]` = dropout
- `[11]` = scale (= 1/√d_h)
- `[12]`/`[13]` = `is_causal`, `window_size` flags
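A minimal sketch of reading those slots out of a single trace event. It assumes a Chrome-trace event dict whose `args` carry `Input Dims` (nested shape lists) and `Concrete Inputs` (scalars assumed to be serialized as strings), as the PyTorch profiler records them with `record_shapes=True`. The names `parse_varlen_bwd_args` and `VarlenBwdParams` are hypothetical and not TraceLens API. Slot `[12]` is read as `is_causal` only, and the window-size flags are ignored.

```python
from dataclasses import dataclass


@dataclass
class VarlenBwdParams:
    total_tokens: int   # T
    num_heads: int      # H
    head_dim: int       # D
    batch: int          # B, recovered from the (B+1,) cu_q shape
    max_seqlen_q: int
    max_seqlen_kv: int
    dropout: float
    scale: float
    is_causal: bool


def parse_varlen_bwd_args(event: dict) -> VarlenBwdParams:
    """Hypothetical parser for an aiter::fmha_v3_varlen_bwd trace event."""
    dims = event["args"]["Input Dims"]
    concrete = event["args"]["Concrete Inputs"]
    t, h, d = dims[0]        # Q is (T, H, D)
    batch = dims[6][0] - 1   # cu_q is (B+1,)
    return VarlenBwdParams(
        total_tokens=t,
        num_heads=h,
        head_dim=d,
        batch=batch,
        max_seqlen_q=int(concrete[8]),
        max_seqlen_kv=int(concrete[9]),
        dropout=float(concrete[10]),
        scale=float(concrete[11]),
        # assumption: booleans are recorded as "True"/"False" strings
        is_causal=concrete[12].strip().lower() in ("true", "1"),
    )
```

The `(B+1,)` shape of `cu_q` only gives the number of sequences, not their lengths, so `max_seqlen_*` is the only per-sequence length information available from the event itself.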
Fix: add `class aiter_fmha_v3_varlen_forward(SDPA)` and `class aiter_fmha_v3_varlen_backward(SDPA)` in `TraceLens/PerfModel/perf_model.py`. Mirror the existing `flash_attention_varlen_forward` / `flash_attention_varlen_backward` (lines 2370 / 2473): same `flops_func` / `flops_bwd_func(..., flash_impl=True)` math (`bwd = 5/2 × fwd` for square `N_Q = N_KV`); only the `Input Dims` / `Concrete Inputs` indices differ.
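For reference, here is a self-contained sketch of the non-causal FLOP count such classes would produce for a packed varlen batch. Per sequence and head, the forward pass is two matmuls (QKᵀ and PV) of 2·N_Q·N_KV·d FLOPs each. A flash-style backward recomputes QKᵀ on top of the four gradient matmuls (dV, dP, dQ, dK), which gives the 5/2 ratio. The sketch assumes per-sequence lengths are available as cu_seqlens *values*. `Input Dims` only records the `(B+1,)` shape, so a shape-only model has to approximate (for B = 1 the single length is T). Function names are hypothetical, the causal case is deliberately omitted, and this is not the existing `flops_func`.

```python
from typing import Sequence


def seqlens_from_cu(cu_seqlens: Sequence[int]) -> list[int]:
    """Per-sequence lengths from a cumulative-offset array of length B+1."""
    return [hi - lo for lo, hi in zip(cu_seqlens[:-1], cu_seqlens[1:])]


def varlen_attn_fwd_flops(cu_q: Sequence[int], cu_kv: Sequence[int],
                          num_heads: int, head_dim: int) -> int:
    """Non-causal forward FLOPs: QK^T and PV, each 2*N_q*N_kv*d per head."""
    lens_q, lens_kv = seqlens_from_cu(cu_q), seqlens_from_cu(cu_kv)
    if len(lens_q) != len(lens_kv):
        raise ValueError("cu_q and cu_kv must describe the same number of sequences")
    pairs = sum(n_q * n_kv for n_q, n_kv in zip(lens_q, lens_kv))
    return 4 * pairs * num_heads * head_dim


def varlen_attn_bwd_flops(cu_q: Sequence[int], cu_kv: Sequence[int],
                          num_heads: int, head_dim: int) -> int:
    """Flash-style backward: recomputed QK^T plus dV, dP, dQ, dK = 5 matmuls."""
    return 5 * varlen_attn_fwd_flops(cu_q, cu_kv, num_heads, head_dim) // 2
```

With the trace's `(T, H, D) = (32 760, 40, 128)` and, as an assumption, a single packed sequence (`cu_q = cu_kv = [0, 32760]`), the forward count comes to roughly 2.2 × 10¹³ FLOPs per call.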
Add four mappings in `torch_op_mapping.py`:

- `aiter::fmha_v3_varlen_fwd` → `aiter_fmha_v3_varlen_forward`
- `aiter::wrapper_fmha_v3_varlen_fwd` → `aiter_fmha_v3_varlen_forward`
- `aiter::fmha_v3_varlen_bwd` → `aiter_fmha_v3_varlen_backward`
- `aiter::wrapper_fmha_v3_varlen_bwd` → `aiter_fmha_v3_varlen_backward`

Add `aiter::fmha_v3_varlen_bwd` and `aiter::wrapper_fmha_v3_varlen_bwd` to the `sdpa_bwd_names` list (line 270) so the category split routes them to `SDPA_bwd`.
Also added (related to #590): add `aiter_fmha_v3_bwd(SDPA)` for the non-varlen path (`aiter::fmha_v3_bwd`). It is not exercised by Wan 2.2 but is trivially symmetric to the existing `aiter__fmha_v3_fwd` (line 3020). Add it to `sdpa_bwd_names`.
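The following minimal sketch shows what the new mappings and the `sdpa_bwd_names` additions amount to for routing. It assumes a plain name-to-class dict plus a name list. `OP_TO_PERF_MODEL`, `SDPA_BWD_NAMES_ADDED`, `attention_category`, and the use of class-name strings instead of class objects are illustrative stand-ins, not the actual structure of `torch_op_mapping.py`.

```python
# Illustrative stand-in for the new torch_op_mapping.py entries (strings, not classes).
OP_TO_PERF_MODEL = {
    "aiter::fmha_v3_varlen_fwd": "aiter_fmha_v3_varlen_forward",
    "aiter::wrapper_fmha_v3_varlen_fwd": "aiter_fmha_v3_varlen_forward",
    "aiter::fmha_v3_varlen_bwd": "aiter_fmha_v3_varlen_backward",
    "aiter::wrapper_fmha_v3_varlen_bwd": "aiter_fmha_v3_varlen_backward",
    "aiter::fmha_v3_bwd": "aiter_fmha_v3_bwd",  # non-varlen path (#590)
}

# Only the names this change adds to sdpa_bwd_names; existing entries are omitted.
SDPA_BWD_NAMES_ADDED = [
    "aiter::fmha_v3_varlen_bwd",
    "aiter::wrapper_fmha_v3_varlen_bwd",
    "aiter::fmha_v3_bwd",
]


def attention_category(op_name: str) -> str:
    """Route a mapped attention op to SDPA_bwd or SDPA_fwd.

    Anything outside this toy table falls through to "other" here.
    """
    if op_name not in OP_TO_PERF_MODEL:
        return "other"
    return "SDPA_bwd" if op_name in SDPA_BWD_NAMES_ADDED else "SDPA_fwd"
```

Keeping the wrapper names in both the mapping and `sdpa_bwd_names` matters: a wrapper op mapped to the backward class but missing from the list would be routed to the forward category.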
2. New perf model: `aten::_flash_attention_forward`
Currently `other`. It is small here (0.17 %, 21 events), but the same dispatcher path appears in any PyTorch training run using `torch.nn.functional.scaled_dot_product_attention` with the flash backend selected.
Fix: add `class aten___flash_attention_forward(SDPA)` mirroring the existing `aten__scaled_dot_product_flash_attention` (line 2690). Input layout is `(B, S, H, D)` from `Input Dims`; `scale` / `is_causal` come from `Concrete Inputs`. Map in `torch_op_mapping.py`.
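As a sketch of the shape-based count for the dense layout, the function below assumes Q is `(B, S_q, H, D)` and K is `(B, S_kv, H_kv, D)` in `Input Dims`, and counts FLOPs over the Q heads. With `is_causal` it uses the exact n(n+1)/2 unmasked pairs of a square causal mask. The non-square causal case is left unsupported rather than guessed. The function name is hypothetical, and the existing SDPA `flops_func` may treat the causal case differently.

```python
def dense_attn_fwd_flops(q_shape: tuple, k_shape: tuple, is_causal: bool = False) -> int:
    """Forward FLOPs (QK^T + PV) for (B, S, H, D) tensors, counted over Q heads."""
    b, s_q, h, d = q_shape
    b_k, s_kv, _h_kv, d_k = k_shape  # K may have fewer heads (GQA); unused here
    if (b, d) != (b_k, d_k):
        raise ValueError("Q and K must agree on batch size and head dim")
    if is_causal:
        if s_q != s_kv:
            raise NotImplementedError("non-square causal masks are not handled in this sketch")
        pairs = s_q * (s_q + 1) // 2  # query i attends keys 0..i
    else:
        pairs = s_q * s_kv
    return 4 * b * h * pairs * d  # two matmuls, 2 FLOPs per multiply-add
```

The same function covers the shape math for the non-varlen `aiter::fmha_v3_bwd` if its result is scaled by 5/2 under the same flash-recompute assumption.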
3. Training-context fix for `aiter::fmha_v3_varlen_fwd`
The class `aiter_fmha_v3_varlen_fwd` exists in `TraceLens/PerfModel/extensions/attention_perf_model_extensions.py` but inherits from `InferenceAttention` and requires an inference annotation injected by sglang/vLLM to compute FLOPs. Training traces (Primus / Megatron / vanilla PyTorch) have no such annotation, so:

- `get_param_details` returns `no_perf_param_details()`
- `GFLOPS`, `Data Moved`, and the roofline all become `None`
- the row remains labelled `InferenceAttention`, masking the gap

In this trace that hides 3.92 % / 1.5 s of step time.
Fix: with the new `aiter_fmha_v3_varlen_forward(SDPA)` class from item 1 mapped in `torch_op_mapping.py`, the training-context path goes through `SDPA` (shape-based FLOPs from `Input Dims` + `Concrete Inputs`) and the row lands in `SDPA_fwd`. The `InferenceAttention`-extension class (annotation path) remains for sglang/vLLM inference reports. Preserve it under the existing pseudo-op decomposition pipeline so inference reports don't regress.
If both paths fire for the same event (unlikely in practice, since they're triggered by different processing stages), the inference-extension annotation should win when present. Concretely, keep the inference extension as-is and let it supply the result whenever the annotation exists. The core `torch_op_mapping.py` mapping acts as the fallback: when the extension returns `no_perf_param_details` (as it does for every training trace), the core path still provides a non-None, shape-based result.
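The intended precedence (annotation path when the annotation is present, otherwise the shape-based core path) can be sketched as a simple fallback chain. `NO_PERF` stands in for `no_perf_param_details()`. Both resolver callables are hypothetical placeholders for the extension's annotation path and the core `torch_op_mapping.py` path. The sketch shows only the ordering, not TraceLens internals.

```python
from typing import Callable, Optional

ParamDetails = Optional[dict]
NO_PERF: ParamDetails = None  # stand-in for no_perf_param_details()


def resolve_param_details(
    event: dict,
    inference_extension: Callable[[dict], ParamDetails],
    core_sdpa_model: Callable[[dict], ParamDetails],
) -> ParamDetails:
    """Prefer the sglang/vLLM annotation; fall back to shape-based SDPA FLOPs."""
    details = inference_extension(event)
    if details is not NO_PERF:
        return details  # annotated inference trace: behaviour unchanged
    return core_sdpa_model(event)  # training trace: no annotation, use Input Dims
```

Ordering the chain this way is what lets the existing `aiter_fmha_v3_varlen_fwd` test keep producing identical FLOPs while training traces stop returning `None`.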
Expected impact on Wan 2.2

| Metric | Pre-fix | Post-fix |
|---|---|---|
| `other` % of step | 7.19 | ≈ 0.35 |
| `InferenceAttention` % of step | 3.92 | 0 (moved to `SDPA_fwd`) |
| Attention rows in `ops_summary` with non-None `GFLOPS` | 1 / 4 | 4 / 4 |

Test plan
- `tests/PerfModel/test_perf_model.py`: one test per new class (`aiter_fmha_v3_varlen_forward`, `aiter_fmha_v3_varlen_backward`, `aiter_fmha_v3_bwd`, `aten___flash_attention_forward`), using fixed `args` payloads.
- `other` ≤ 0.5 % and no rows with `name LIKE 'fmha%'` or `_flash_attention%` have null `GFLOPS`.
- `aiter_fmha_v3_varlen_fwd` test: annotation path still wins, with identical FLOPs to current `main`.
- `black` formatting; existing test suite passes.

Related

- #290 (`aiter::wrapper_fmha_v3_varlen_fwd` performance model) and #590 (Perf model for `aiter::fmha_v3_bwd`, non-varlen): all three are addressed by the same PR, #651 (feat(perfmodel): add varlen aiter FA + aten::_flash_attention_forward perf models (Wan 2.2 coverage)), which is the one that will actually close them.
- `torch_op_mapping.py` (lines ~36–47) and the cousin classes `flash_attention_varlen_forward` / `flash_attention_varlen_backward` (lines 2370 / 2473).