
Add perf models for aiter::(wrapper_)fmha_v3_varlen_{fwd,bwd} + aten::_flash_attention_forward + fix training-context categorization (Wan 2.2 coverage) #650

@gphuang

Summary

Profiling a Wan 2.2 T2V A14B training run (Primus, MI300-class GPU, mbs=1) shows that 10.9 % of step time is spent in attention ops that get no FLOPs and no roofline in the TraceLens report:

  • 7.00 % sits in the `other` bucket: aiter::fmha_v3_varlen_bwd (6.83 %) and aten::_flash_attention_forward (0.17 %).
  • 3.92 % sits in `InferenceAttention`: aiter::fmha_v3_varlen_fwd. Its existing perf-model class requires an sglang/vLLM annotation that training traces don't have, so it returns GFLOPS = None while the row is silently labelled InferenceAttention.

This issue covers three fixes in one PR, plus the equivalent non-varlen path formerly tracked in #590 and the wrapper-layer name from #290.

Trace evidence (Wan 2.2 T2V A14B, Primus, BF16, mbs=1)

Op | Count | Total time (ms) | % of step | Category in xlsx | GFLOPS in xlsx
---|---|---|---|---|---
aiter::fmha_v3_varlen_fwd | 160 | 1,553 | 3.92 | InferenceAttention | None
aiter::fmha_v3_varlen_bwd | 80 | 2,705 | 6.83 | other | None
aten::_flash_attention_forward | 21 | 68 | 0.17 | other | None
aten::_scaled_dot_product_flash_attention | 1 | 0.06 | <0.01 | SDPA_fwd | 59.8 ✓

Category mix of the full step (from ops_summary_by_category):

Category | % of step | Total time (ms)
---|---|---
elementwise | 58.97 | 23,344
reduce | 20.84 | 8,250
other | 7.19 | 2,845
GEMM | 6.49 | 2,568
InferenceAttention | 3.92 | 1,553

1. New perf model — aiter::(wrapper_)fmha_v3_varlen_{fwd,bwd}

Currently uncategorized (varlen bwd lands in other) or annotation-only (varlen fwd; see item 3). This is the dominant gap: aiter::fmha_v3_varlen_bwd alone accounts for 6.83 % of step time across 80 events (mean 65.4 ms each).

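Input Dims layout from ops_unique_args for aiter::fmha_v3_varlen_bwd:

((T, H, D), (T, H, D), (T, H, D), (T, H, D), (T, H, D), (H, T), (B+1,), (B+1,), (), ...)

  • [0] Q, [1] K, [2] V, [3] O, [4] dO: (T, H, D)
  • [5] softmax_lse: (H, T)
  • [6] cu_q, [7] cu_kv: (B+1,)
  • [8] onward: scalars (shape ()), listed under Concrete Inputs below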

with (T, H, D) = (32760, 40, 128) in this trace (T = total tokens packed across the batch, H = heads, D = head dim).

Concrete Inputs slots for the varlen bwd:

  • [8] = max_seqlen_q
  • [9] = max_seqlen_kv
  • [10] = dropout
  • [11] = scale (= 1/√d_h)
  • [12]/[13] = is_causal, window_size flags

Fix: add class aiter_fmha_v3_varlen_forward(SDPA) and class aiter_fmha_v3_varlen_backward(SDPA) in TraceLens/PerfModel/perf_model.py. Mirror the existing flash_attention_varlen_forward / flash_attention_varlen_backward (lines 2370 / 2473) — same flops_func / flops_bwd_func(..., flash_impl=True) math (bwd = 5/2 × fwd for square N_Q=N_KV), only the Input Dims / Concrete Inputs indices differ.
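The FLOPs math the new classes would share can be sketched standalone. This is not TraceLens's flops_func / flops_bwd_func but a hypothetical re-derivation: the forward pass is two GEMMs per sequence and head (S = QKᵀ and O = PV, 2·q·k·D FLOPs each), counted with the query head count, and a flash-style backward that recomputes S runs five GEMMs (S, dV, dP, dQ, dK), which gives the 5/2 ratio quoted above. Halving for causal masks is an approximation.

```python
from typing import Iterable, Tuple


def varlen_attn_fwd_flops(seq_lens: Iterable[Tuple[int, int]], num_heads_q: int,
                          head_dim: int, is_causal: bool = False) -> float:
    """Forward FLOPs summed over packed sequences.

    Per sequence and head: S = Q K^T costs 2*q*k*D and O = P V costs 2*q*k*D.
    """
    total = 0.0
    for q_len, kv_len in seq_lens:
        total += 4.0 * num_heads_q * q_len * kv_len * head_dim
    # Approximation: a causal mask skips about half of the q x k score matrix.
    return total / 2 if is_causal else total


def varlen_attn_bwd_flops(fwd_flops: float) -> float:
    """Flash-style backward: recompute S, then dV, dP, dQ, dK (5 GEMMs vs 2 in fwd)."""
    return 2.5 * fwd_flops


# Assumption: B = 1, so the single packed sequence spans all T = 32760 tokens.
fwd = varlen_attn_fwd_flops([(32760, 32760)], num_heads_q=40, head_dim=128)
bwd = varlen_attn_bwd_flops(fwd)
```

Input Dims records only the shape of cu_q, not its values, so per-sequence lengths are exact only when B = 1. Under that assumption and with no causal mask, this gives about 2.2e13 FLOPs per forward call and 2.5× that per backward call.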

Add four mappings in torch_op_mapping.py:

  • aiter::fmha_v3_varlen_fwd → aiter_fmha_v3_varlen_forward
  • aiter::wrapper_fmha_v3_varlen_fwd → aiter_fmha_v3_varlen_forward
  • aiter::fmha_v3_varlen_bwd → aiter_fmha_v3_varlen_backward
  • aiter::wrapper_fmha_v3_varlen_bwd → aiter_fmha_v3_varlen_backward

Add aiter::fmha_v3_varlen_bwd and aiter::wrapper_fmha_v3_varlen_bwd to the sdpa_bwd_names list (line 270) so the category split routes them to SDPA_bwd.

Also (related to #590): add aiter_fmha_v3_bwd(SDPA) for the non-varlen path (aiter::fmha_v3_bwd). It is not exercised by Wan 2.2 but is trivially symmetric to the existing aiter__fmha_v3_fwd (line 3020). Add it to sdpa_bwd_names as well.
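The mapping and category-routing changes can be pictured with the sketch below. The table name OP_TO_PERF_MODEL, the stub classes and the categorize helper are hypothetical stand-ins; the real torch_op_mapping.py structure and TraceLens's category split may be organized differently. Only the op names, class names and sdpa_bwd_names come from this issue.

```python
# Hypothetical stand-ins for the real classes in TraceLens/PerfModel/perf_model.py.
class SDPA: ...
class aiter_fmha_v3_varlen_forward(SDPA): ...
class aiter_fmha_v3_varlen_backward(SDPA): ...
class aiter_fmha_v3_bwd(SDPA): ...


# Hypothetical shape of the op-name -> perf-model-class table in torch_op_mapping.py.
OP_TO_PERF_MODEL = {
    "aiter::fmha_v3_varlen_fwd": aiter_fmha_v3_varlen_forward,
    "aiter::wrapper_fmha_v3_varlen_fwd": aiter_fmha_v3_varlen_forward,
    "aiter::fmha_v3_varlen_bwd": aiter_fmha_v3_varlen_backward,
    "aiter::wrapper_fmha_v3_varlen_bwd": aiter_fmha_v3_varlen_backward,
    "aiter::fmha_v3_bwd": aiter_fmha_v3_bwd,
}

# New entries appended to sdpa_bwd_names (existing entries omitted here).
sdpa_bwd_names = [
    "aiter::fmha_v3_varlen_bwd",
    "aiter::wrapper_fmha_v3_varlen_bwd",
    "aiter::fmha_v3_bwd",
]


def categorize(op_name: str) -> str:
    """Simplified model of the category split: unmapped ops fall into 'other'."""
    cls = OP_TO_PERF_MODEL.get(op_name)
    if cls is None:
        return "other"
    if issubclass(cls, SDPA):
        return "SDPA_bwd" if op_name in sdpa_bwd_names else "SDPA_fwd"
    return cls.__name__
```

Mapping both the raw and wrapper_ names to the same class keeps the FLOPs logic in one place regardless of which layer the trace records.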

2. New perf model — aten::_flash_attention_forward

Currently lands in other. It is small here (0.17 %, 21 events), but the same dispatcher path appears in any PyTorch training run using torch.nn.functional.scaled_dot_product_attention with the flash backend selected.

Fix: add class aten___flash_attention_forward(SDPA) mirroring the existing aten__scaled_dot_product_flash_attention (line 2690). Input layout is (B, S, H, D) from Input Dims; scale / is_causal from Concrete Inputs. Map in torch_op_mapping.py.
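A sketch of the shape-based FLOPs for aten::_flash_attention_forward, assuming Input Dims[0] and [1] are the (B, S, H, D) query and key shapes. Compared with the SDPA-level op, whose inputs are (B, H, S, D), the main difference is likely just where S and H sit in the tuple. The function name is hypothetical, causal halving is approximate, and the softmax scale does not enter the count.

```python
from typing import Sequence


def flash_attention_forward_flops(q_shape: Sequence[int], k_shape: Sequence[int],
                                  is_causal: bool = False) -> float:
    """Shape-based FLOPs for a dense flash-attention forward with (B, S, H, D) inputs."""
    batch, s_q, h_q, head_dim = q_shape   # query: (B, S_q, H_q, D)
    _, s_kv, _, _ = k_shape               # key:   (B, S_kv, H_kv, D)
    # Two GEMMs per head (Q K^T and P V), 2 * S_q * S_kv * D FLOPs each.
    flops = 4.0 * batch * h_q * s_q * s_kv * head_dim
    # Approximation: a causal mask skips about half of the score matrix.
    return flops / 2 if is_causal else flops
```

Reading S from position 1 rather than 2 is the detail most likely to go wrong when copying from the SDPA-level class, so it is worth covering in the unit test.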

3. Training-context fix for aiter::fmha_v3_varlen_fwd

The class aiter_fmha_v3_varlen_fwd exists in TraceLens/PerfModel/extensions/attention_perf_model_extensions.py but inherits from InferenceAttention and requires an inference annotation injected by sglang/vLLM to compute FLOPs. Training traces (Primus / Megatron / vanilla PyTorch) have no such annotation, so:

  • get_param_details returns no_perf_param_details()
  • GFLOPS, Data Moved, roofline all become None
  • the row remains labelled InferenceAttention, masking the gap

In this trace, that hides 3.92 % (≈ 1.55 s) of step time.

Fix: with the new aiter_fmha_v3_varlen_forward(SDPA) class from item 1 mapped in torch_op_mapping.py, the training-context path goes through SDPA (shape-based FLOPs from Input Dims + Concrete Inputs) and the row lands in SDPA_fwd. The InferenceAttention-extension class (annotation path) remains for sglang/vLLM inference reports — preserve it under the existing pseudo-op decomposition pipeline so inference reports don't regress.

If both paths fire for the same event (unlikely in practice, since they are triggered by different processing stages), the inference-extension annotation should win when present. Concretely, keep the inference extension as-is and give it precedence: when it finds an annotation, its result is used unchanged; when it returns no_perf_param_details (no annotation, as in training traces), fall back to the core torch_op_mapping.py mapping, which always provides a non-None shape-based result.
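The precedence rule can be expressed as a small resolver. Here annotation_path and shape_path are hypothetical callables standing in for the inference extension and the core SDPA class, None stands in for no_perf_param_details(), and the returned category strings are illustrative rather than TraceLens's actual routing code.

```python
from typing import Any, Callable, Dict, Optional, Tuple

# None stands in for no_perf_param_details() in this sketch.
ParamDetails = Optional[Dict[str, Any]]


def resolve_param_details(
    event: Dict[str, Any],
    annotation_path: Callable[[Dict[str, Any]], ParamDetails],
    shape_path: Callable[[Dict[str, Any]], ParamDetails],
) -> Tuple[ParamDetails, str]:
    """Annotation path wins when it yields params; otherwise fall back to shapes."""
    details = annotation_path(event)
    if details is not None:
        return details, "InferenceAttention"
    # Training traces carry no sglang/vLLM annotation: use the shape-based SDPA class.
    return shape_path(event), "SDPA_fwd"


# Hypothetical callables for illustration.
def annotation(ev: Dict[str, Any]) -> ParamDetails:
    return ev.get("inference_annotation")


def shapes(ev: Dict[str, Any]) -> ParamDetails:
    return {"source": "Input Dims + Concrete Inputs"}
```

Because the annotation result is returned untouched whenever it exists, inference reports keep exactly the FLOPs they get on current main.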

Expected impact on Wan 2.2

Metric | Pre-fix | Post-fix
---|---|---
other, % of step | 7.19 | ≈ 0.35
InferenceAttention, % of step | 3.92 | 0 (moved to SDPA_fwd)
Attention rows in ops_summary with non-None GFLOPS | 1 / 4 | 4 / 4

Test plan

  • Unit tests in tests/PerfModel/test_perf_model.py: one per new class (aiter_fmha_v3_varlen_forward, aiter_fmha_v3_varlen_backward, aiter_fmha_v3_bwd, aten___flash_attention_forward), using fixed args payloads.
  • Regenerate the Wan 2.2 T2V A14B perf report and assert that other ≤ 0.5 % and that no row whose name matches '%fmha%' or '%_flash_attention%' has null GFLOPS.
  • Regression check on the existing sglang/vLLM inference trace used by the original aiter_fmha_v3_varlen_fwd test: annotation path still wins, identical FLOPs to current main.
  • black formatting; existing test suite passes.
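The report-level assertion from the second bullet could look like the helper below, assuming the report has already been loaded into a category-to-percentage mapping and a list of per-op rows with name and GFLOPS keys (hypothetical column names). Shell-style patterns via fnmatch replace the SQL LIKE wording, and NaN is treated as null because spreadsheet readers often return NaN for empty cells.

```python
import fnmatch
import math
from typing import Any, Iterable, List, Mapping

ATTENTION_PATTERNS = ("*fmha*", "*_flash_attention*")


def _is_null(value: Any) -> bool:
    # Empty spreadsheet cells often come back as NaN rather than None.
    return value is None or (isinstance(value, float) and math.isnan(value))


def check_attention_coverage(category_pct: Mapping[str, float],
                             op_rows: Iterable[Mapping[str, Any]],
                             other_max_pct: float = 0.5) -> None:
    """Raise AssertionError if 'other' is too large or an attention op lacks GFLOPS."""
    other = category_pct.get("other", 0.0)
    assert other <= other_max_pct, f"other = {other} % > {other_max_pct} %"
    missing: List[str] = [
        row["name"]
        for row in op_rows
        if any(fnmatch.fnmatchcase(row["name"], p) for p in ATTENTION_PATTERNS)
        and _is_null(row.get("GFLOPS"))
    ]
    assert not missing, f"attention ops with null GFLOPS: {missing}"
```

Collecting every offending name before asserting makes a failing run report all uncovered attention ops at once rather than only the first.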

Labels: effort: M (2-3 days), enhancement (New feature or request), perf_model (Add performance model for calculating TFLOPS/s and TB/s)
