
feat: add get flashinfer-trace interface .fi_trace #2931

Open
yyihuang wants to merge 29 commits into flashinfer-ai:main from yyihuang:fi_trace

Conversation

Collaborator

@yyihuang yyihuang commented Mar 31, 2026

📌 Description

Adds a trace layer to FlashInfer so every public kernel can be described as a portable benchmark / replay definition without tying the description to any particular launcher.

  • flashinfer/trace/template.py — new TraceTemplate schema with named axes (Var/Const), typed inputs/outputs (Tensor/Scalar), optional reference implementation, and tag/constraint metadata.
  • flashinfer/trace/templates/*.py — one module per operator family (attention, cascade, GDN, GEMM, MoE, norm, activation, sampling). Each file declares the schema and, where feasible, an executable reference.
  • @flashinfer_api(trace=...) (extends the existing decorator in flashinfer/api_logging.py) — attaches .fi_trace() to the decorated function/method and, when FLASHINFER_TRACE_DUMP=1, writes a per-shape JSON definition to FLASHINFER_TRACE_DUMP_DIR before the kernel runs (crash-safe).
  • fi_trace() helpers — public entry points for programmatic trace generation from any @flashinfer_api-decorated API or a bound method.
  • tests/trace/ — template-consistency tests (signature ↔ axes/inputs), end-to-end reference checks, and an example.py that drives a realistic workload and dumps 45 JSON definitions to tests/trace/fi_trace_out/ covering LLaMA-3.1, DeepSeek-V3, Gemma, Qwen3-Next, etc.
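For readers unfamiliar with the schema, here is a minimal, hypothetical sketch of the TraceTemplate idea — simplified names only; the real schema lives in flashinfer/trace/template.py and differs in detail:

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


# Hypothetical, simplified mirror of the TraceTemplate idea: named axes
# (variable or constant) plus typed tensors whose shapes are keyed by axes.
@dataclass
class Axis:
    name: str
    const: Optional[int] = None  # None -> variable axis (the Var/Const split)


@dataclass
class TensorSpec:
    name: str
    shape: Tuple[str, ...]  # axis names, resolved per call
    dtype: str


@dataclass
class TraceTemplate:
    op: str
    axes: List[Axis]
    inputs: List[TensorSpec]
    outputs: List[TensorSpec]

    def to_definition(self, **axis_values: int) -> Dict:
        # Resolve variable axes from the call site; constants stay fixed.
        resolved = {
            a.name: a.const if a.const is not None else axis_values[a.name]
            for a in self.axes
        }
        return {
            "op": self.op,
            "axes": resolved,
            "inputs": [
                {"name": t.name,
                 "shape": [resolved[s] for s in t.shape],
                 "dtype": t.dtype}
                for t in self.inputs
            ],
        }


# Example: an rmsnorm-like op with one variable axis and one constant axis.
rmsnorm_template = TraceTemplate(
    op="rmsnorm",
    axes=[Axis("batch"), Axis("hidden", const=4096)],
    inputs=[TensorSpec("input", ("batch", "hidden"), "bfloat16"),
            TensorSpec("weight", ("hidden",), "bfloat16")],
    outputs=[TensorSpec("output", ("batch", "hidden"), "bfloat16")],
)
```

Calling `rmsnorm_template.to_definition(batch=8)` yields a per-shape dictionary ready to serialize as JSON, which is the role the dumped definitions play here.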

Why

  • Lets external tooling (e.g. flashinfer-bench) consume a single self-describing JSON per op instead of reverse-engineering Python call sites.
  • Gives each kernel a single authoritative description of its shape constraints, keeping plan→run wrappers, benchmarks, and regression tests aligned.
  • Zero overhead at FLASHINFER_LOGLEVEL=0 / FLASHINFER_TRACE_DUMP unset (decorator is a no-op in that path).
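The zero-overhead claim hinges on the enable/disable decision being made at decoration time rather than per call. A hedged sketch of that pattern (not the actual api_logging.py implementation):

```python
import functools
import os


# Sketch of the "no-op when disabled" pattern: when neither logging nor trace
# dump is enabled, the decorator returns the original function untouched, so
# the disabled path adds no wrapper and no per-call overhead.
def flashinfer_api_sketch(func=None, *, trace=None):
    def wrap(f):
        enabled = (
            os.environ.get("FLASHINFER_LOGLEVEL", "0") != "0"
            or os.environ.get("FLASHINFER_TRACE_DUMP") is not None
        )
        if not enabled:
            return f  # zero overhead: no wrapper at all

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # logging / trace dump would happen here when enabled
            return f(*args, **kwargs)

        return wrapper

    return wrap if func is None else wrap(func)


@flashinfer_api_sketch(trace=None)
def add(a, b):
    return a + b
```

Either way the decorated call behaves identically; only the instrumentation differs.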

Covered APIs

Attention (paged/ragged prefill, paged decode, MLA), sampling (top-k/top-p/top-k-top-p), GEMM (bf16, fp8, mxfp8, fp4), fused MoE (fp8/fp4 block-scale × 6 routing methods), norm (rmsnorm, fused-add, quant variants, Gemma variants, layernorm), activation (silu/gelu/gelu-tanh + mul), cascade merge (state/state-in-place/states), and GDN (decode, MTP, chunk prefill).

🔍 Related Issues

🚀 Pull Request Checklist

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit.
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

  • Tests have been added or updated as needed (tests/trace/test_fi_trace.py, tests/trace/test_fi_trace_template_consistency.py).
  • All tests are passing locally (pytest tests/trace/ -v → 139 passed).

Reviewer Notes

  • The decorator ordering is load-bearing: @flashinfer_api(trace=...) must be innermost so the trace dump happens immediately before the kernel runs. One consequence: for mm_fp4 / mm_mxfp8 on SM<100, the outer @backend_requirement raises before the dump, which is why their JSONs are only regenerated on Blackwell. See tests/trace/example.py for the realistic workload.
  • A subsequent fix-up commit drops three redundant @flashinfer_api decorators that caused double-logging at FLASHINFER_LOGLEVEL=3+ (subclass __init__ overrides and the trtllm_low_latency_gemm internal helper).
  • Known remaining CodeRabbit items (follow-up work, not blocking this PR): some MoE reference helpers still hardcode H/I/top_k/n_group; gdn_prefill_trace lacks the head-ratio constraints that gdn_decode / gdn_mtp already have; the E2E test synthesizer uses 0 for int32 inputs, which makes some synthesized definitions nonsensical.
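The ordering pitfall in the first note can be reproduced with plain stand-in decorators — hypothetical simplifications, not the real @backend_requirement / @flashinfer_api:

```python
# Stand-in decorators (hypothetical) showing why an outermost capability gate
# means the trace dump is skipped on unsupported devices: the outer wrapper
# raises before the inner wrapper (which dumps) ever runs.
dumped = []


def backend_requirement(min_sm):
    def deco(f):
        def wrapper(*args, _sm=80, **kwargs):
            if _sm < min_sm:
                raise RuntimeError(f"requires SM{min_sm}, got SM{_sm}")
            return f(*args, **kwargs)
        return wrapper
    return deco


def flashinfer_api(trace=None):
    def deco(f):
        def wrapper(*args, **kwargs):
            dumped.append(trace)  # crash-safe: dump just before the kernel
            return f(*args, **kwargs)
        return wrapper
    return deco


@backend_requirement(min_sm=100)   # outermost: raises first on old GPUs
@flashinfer_api(trace="mm_fp4")    # innermost: dump runs only if the gate passes
def mm_fp4(a, b):
    return a * b
```

On a simulated SM90 device the gate raises and nothing is dumped; on SM100 the dump fires and the kernel runs.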

🤖 Generated with Claude Code


coderabbitai Bot commented Mar 31, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

This change adds a TraceTemplate-based tracing system with fi_trace generation and registration, attaches trace templates to many kernels (attention, GEMM, sampling, norm, GDN, MoE) via an extended @flashinfer_api(trace=...) decorator, and includes tests and JSON trace fixtures to validate generation and consistency.

Changes

Cohort / File(s) — Summary

  • Core trace infra (flashinfer/trace/template.py, flashinfer/trace/__init__.py, flashinfer/fi_trace.py): New TraceTemplate schema, axis/input/output descriptors, builders for fi_trace functions, legacy registry support, and JSON dump behavior.
  • Decorator & registry (flashinfer/api_logging.py, flashinfer/__init__.py): Extended flashinfer_api to accept trace=; _attach_fi_trace and _TRACE_REGISTRY added; fi_trace re-exported at package level.
  • Trace templates package (flashinfer/trace/templates/__init__.py, flashinfer/trace/templates/*.py): Many new per-op TraceTemplate modules — attention.py, gdn.py, gemm.py, moe.py, norm.py, sampling.py — plus dispatch helpers for MoE/GEMM/attention templates.
  • Instrumentation of APIs (flashinfer/attention.py, flashinfer/decode.py, flashinfer/prefill.py, flashinfer/gdn_decode.py, flashinfer/gdn_prefill.py, flashinfer/gemm/gemm_base.py, flashinfer/fused_moe/core.py, flashinfer/mla/_core.py, flashinfer/mla/cute_dsl/mla_decode.py, flashinfer/norm/__init__.py, flashinfer/sampling.py, flashinfer/trtllm_low_latency_gemm.py): Added or parameterized @flashinfer_api usages with trace= to attach appropriate templates; some constructors/entrypoints newly decorated to expose .fi_trace.
  • Tests & examples (tests/trace/test_fi_trace.py, tests/trace/test_fi_trace_template_consistency.py, tests/trace/example.py): New unit/integration tests to validate fi_trace generation, template signature/axis coverage, E2E JSON dumping, and an example trace-generation script.
  • JSON fixtures (tests/trace/fi_trace_out/*.json): ~30+ new JSON trace specification fixtures covering sampling, GEMM variants, attention (GQA/MLA/DSA), GDN, RMSNorm, and MoE (FP8/FP4) for use in tests and examples.
  • Docs / guide update (.claude/skills/add-cuda-kernel/SKILL.md): Updated workflow to require TraceTemplate creation and registration; guidance for attaching templates to APIs and running consistency tests.

Sequence Diagram(s)

sequenceDiagram
  participant Client
  participant API as flashinfer_api wrapper
  participant Template as TraceTemplate / dispatcher
  participant FiTrace as fi_trace builder
  participant FS as Filesystem

  Client->>API: call decorated function (possibly trace=callable)
  API->>Template: resolve trace_template (dispatch or static)
  Template-->>FiTrace: build fi_trace_fn (bind reference, axes, inputs/outputs)
  API->>FiTrace: if tracing enabled -> invoke fi_trace_fn(**bound_args)
  FiTrace->>FS: write <name>.json (if save_dir or env enabled)
  FiTrace-->>API: return trace dict
  API-->>Client: execute original function and return result

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • yzh119
  • bkryu
  • aleozlx
  • cyx-6
  • jimmyzho
  • yongwww
  • nv-yunzheq
  • IwakuraRein
  • saltyminty
  • sricketts
  • samuellees
  • kahyunnam
  • jiahanc
  • dhiraj113

Poem

🐰 I nibble code and stitch a trace,

Templates bloom in every place,
APIs now wear a tiny tag,
JSON crumbs inside the bag,
Hops and hops — the traces race! 🎉

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — docstring coverage is 62.26%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check ✅ Passed — the title clearly and concisely describes the main change: adding a .fi_trace interface for flashinfer-trace functionality. It directly relates to the additions of the fi_trace module, trace templates, and API decorators throughout the changeset.
  • Description check ✅ Passed — the PR description is comprehensive and detailed, with clear structure, objectives, implementation details, and rationale for the trace layer feature.




Comment @coderabbitai help to get the list of available commands and usage tips.


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds the @flashinfer_api decorator to multiple classes and functions across the library, including attention and decode wrappers as well as GEMM execution utilities, to enable API logging. The review feedback points out that applying this decorator to subclasses whose base classes are already decorated results in redundant log entries. Additionally, nested calls between decorated functions may lead to duplicate logging, suggesting that the logging logic should handle re-entrancy or that certain decorators should be removed to reduce overhead.

Comment thread flashinfer/attention.py Outdated
a convenient interface for using attention sinks during prefill or decode attention.
"""

@flashinfer_api


medium

Adding @flashinfer_api to BatchAttentionWithAttentionSinkWrapper.__init__ will result in double logging during initialization. This class inherits from BatchPrefillWithPagedKVCacheWrapper, whose __init__ method is already decorated with @flashinfer_api. Since the decorator uses the class name of the instance (args[0]), both the subclass and base class decorators will log an entry for BatchAttentionWithAttentionSinkWrapper.__init__. This redundancy clutters the logs and adds unnecessary overhead. Consider removing the decorator from the subclass if the base class logging is sufficient for your tracing needs.
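The double-logging mechanism this comment describes can be demonstrated with a toy decorator (hypothetical; the real one is flashinfer_api in flashinfer/api_logging.py):

```python
# Toy logging decorator that, like the one described above, resolves the
# class name from args[0] (the instance). When both a base-class __init__
# and a subclass __init__ are decorated, one construction produces two log
# entries, both attributed to the subclass.
log = []


def logged(f):
    def wrapper(*args, **kwargs):
        cls = type(args[0]).__name__ if args else "?"
        log.append(f"{cls}.{f.__name__}")
        return f(*args, **kwargs)
    return wrapper


class Base:
    @logged
    def __init__(self):
        pass


class Sub(Base):
    @logged
    def __init__(self):
        super().__init__()  # triggers the base decorator a second time


Sub()
```

Because type(args[0]) is always the concrete subclass, the base-class entry is indistinguishable from the subclass entry — exactly the redundancy the review flags.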

Comment thread flashinfer/decode.py Outdated
:class:`BatchDecodeWithPagedKVCacheWrapper`
"""

@flashinfer_api


medium

Similar to the issue in BatchAttentionWithAttentionSinkWrapper, decorating CUDAGraphBatchDecodeWithPagedKVCacheWrapper.__init__ leads to redundant log entries because its base class BatchDecodeWithPagedKVCacheWrapper.__init__ is already decorated. Both will log as CUDAGraphBatchDecodeWithPagedKVCacheWrapper.__init__ due to how the decorator resolves the class name from the instance.

Comment thread flashinfer/trtllm_low_latency_gemm.py Outdated
)


@flashinfer_api


medium

Decorating trtllm_low_latency_gemm will cause double logging when it is called internally by other decorated APIs, such as mm_fp8 in flashinfer/gemm/gemm_base.py. While it is important to trace this function when called directly, the current logging implementation will produce redundant entries for nested calls. This should ideally be addressed in the logging decorator's logic to handle re-entrancy, but for now, be aware of the log duplication.
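One way to handle the re-entrancy this comment mentions is a thread-local guard so only the outermost decorated call logs — a sketch of a possible fix, not the current implementation:

```python
import threading

# Thread-local re-entrancy guard: nested calls between decorated functions
# are executed but not logged; only the outermost entry point produces a
# log record.
_state = threading.local()
log = []


def logged(f):
    def wrapper(*args, **kwargs):
        if getattr(_state, "inside", False):
            return f(*args, **kwargs)  # nested call: skip logging
        _state.inside = True
        try:
            log.append(f.__name__)
            return f(*args, **kwargs)
        finally:
            _state.inside = False
    return wrapper


@logged
def inner(x):
    return x + 1


@logged
def outer(x):
    return inner(x) * 2
```

Calling outer() logs only "outer"; calling inner() directly still logs "inner", so direct tracing of the helper is preserved.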


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

Note

Due to the large number of review comments, Critical severity comments were prioritized as inline comments.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/gdn_prefill.py (1)

86-100: 🛠️ Refactor suggestion | 🟠 Major

Add backend capability gating on this SM90-only API.

chunk_gated_delta_rule documents an SM90 requirement but is not decorated with @backend_requirement. Please add the backend/capability gate alongside @flashinfer_api(...) so unsupported devices fail fast with a clear message.

As per coding guidelines: Use @backend_requirement decorator on APIs that have compute capability requirements and provide is_compute_capability_supported(cc) and is_backend_supported() methods.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_prefill.py` around lines 86 - 100, The function
chunk_gated_delta_rule is SM90-only but lacks the backend capability guard; add
the `@backend_requirement`(...) decorator alongside
`@flashinfer_api`(trace=gdn_prefill_trace) to check is_backend_supported() and
is_compute_capability_supported(cc) for SM90 and return a clear fail-fast
message for unsupported devices. Use the decorator to declare the required
compute capability (SM90) and backend, referencing chunk_gated_delta_rule so the
check runs before execution and produces a helpful error if the device is not
supported.
flashinfer/trtllm_low_latency_gemm.py (1)

119-125: 🛠️ Refactor suggestion | 🟠 Major

Add @backend_requirement for this Blackwell-only entrypoint.

trtllm_low_latency_gemm is documented as Blackwell-only, but the API is not gated with @backend_requirement. Please add the explicit capability/backend guard so callers get deterministic early validation.

As per coding guidelines: Use @backend_requirement decorator on APIs that have compute capability requirements and provide is_compute_capability_supported(cc) and is_backend_supported() methods.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trtllm_low_latency_gemm.py` around lines 119 - 125, Add the
`@backend_requirement` decorator to the trtllm_low_latency_gemm entrypoint to gate
it to Blackwell-only execution: place `@backend_requirement`(...) immediately
above the trtllm_low_latency_gemm definition and provide checks that call the
module's support helpers (e.g., is_compute_capability_supported and
is_backend_supported) or small wrapper functions that return True only for
Blackwell compute capability/backend; ensure the decorator references the
correct check functions so callers receive deterministic early validation for
Blackwell-only usage of trtllm_low_latency_gemm.
🟠 Major comments (23)
flashinfer/trace/example/fi_trace_out/gemm_mxfp8_N4096_K4096.json-32-38 (1)

32-38: ⚠️ Potential issue | 🟠 Major

GEMM reference uses an incompatible transpose with declared shapes.

With A: [M, K] and B: [K, N] (Line 32–38), Line 66 should compute A @ B, not A @ B.T. Current reference is dimensionally inconsistent.

Suggested fix
-    return torch.matmul(A_scaled, B_scaled.T).to(torch.bfloat16)\n"
+    return torch.matmul(A_scaled, B_scaled).to(torch.bfloat16)\n"

Also applies to: 66-66

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/example/fi_trace_out/gemm_mxfp8_N4096_K4096.json` around
lines 32 - 38, The GEMM reference is using an incompatible transpose for tensor
B given the declared shapes "A": [M,K] and "B": [K,N]; update the computation
that currently multiplies A by B.T to multiply A by B instead so the operation
becomes A @ B (ensure the result shape is [M,N]), and verify any accompanying
description/metadata (e.g., keys "A", "B" and dtype "float8_e4m3fn") and
comments reflect no transpose on B.
flashinfer/trace/example/fi_trace_out/top_k_sampling_v128256.json-38-43 (1)

38-43: ⚠️ Potential issue | 🟠 Major

Output dtype is inconsistent with the traced API contract.

Line 42 declares samples as int64, but this API path returns int32 by default (when indices is not provided). The reference in Line 46 also allocates int64, so both schema and reference are misaligned.

Suggested fix
-      "dtype": "int64",
+      "dtype": "int32",
-    samples = torch.empty(batch_size, dtype=torch.int64, device=device)\n
+    samples = torch.empty(batch_size, dtype=torch.int32, device=device)\n

Also applies to: 46-46

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/example/fi_trace_out/top_k_sampling_v128256.json` around
lines 38 - 43, The schema and reference for the "samples" field in
top_k_sampling_v128256.json incorrectly use dtype "int64" while the API returns
int32 by default; change the "samples" dtype from "int64" to "int32" in the JSON
schema and update the corresponding reference allocation that currently creates
int64 to allocate int32 instead (look for the "samples" field and any reference
example/allocation near the "shape": ["batch_size"] and the later allocation on
Line 46 to ensure both schema and example match int32).
flashinfer/trace/example/fi_trace_out/fused_add_rmsnorm_h5120.json-49-58 (1)

49-58: ⚠️ Potential issue | 🟠 Major

Reference return signature does not match declared outputs.

Line 49–56 declares two outputs (output, residual), but Line 58’s reference returns only one tensor. This makes the trace definition internally inconsistent for validators/consumers.

Suggested fix
-    return y.to(hidden_states.dtype)\n"
+    return y.to(hidden_states.dtype), x.to(hidden_states.dtype)\n"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/example/fi_trace_out/fused_add_rmsnorm_h5120.json` around
lines 49 - 58, The reference function _fused_add_rmsnorm_reference currently
returns only the normalized output tensor but the trace declares two outputs
("output" and "residual"), so update the reference to return both values to
match the schema: compute y as now and also produce the updated residual
(residual + hidden_states in float32, cast back to residual.dtype) and return
(y, updated_residual) (or alternatively change the trace outputs to a single
"output" if the residual should not be returned); ensure names/ordering match
the declared outputs.
flashinfer/trace/example/fi_trace_out/gemm_bf16_N256_K7168.json-30-36 (1)

30-36: ⚠️ Potential issue | 🟠 Major

BF16 GEMM reference is inconsistent with input shape declaration.

Given B shape [K, N] (Line 30–36), Line 48 should not transpose B for matmul. The current expression conflicts with the stated tensor contract.

Suggested fix
-  "reference": "def _mm_reference(A, B):\n    return torch.matmul(A, B.T)\n"
+  "reference": "def _mm_reference(A, B):\n    return torch.matmul(A, B)\n"

Also applies to: 48-48

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/example/fi_trace_out/gemm_bf16_N256_K7168.json` around lines
30 - 36, The JSON metadata declares tensor "B" with shape ["K","N"] (physical
column-major [K, N]) but the matmul expression erroneously transposes B; update
the matmul expression that currently uses B.T (or otherwise transposes "B") so
it uses "B" directly to match the declared [K,N] contract, and ensure any
accompanying description/comment is adjusted to reflect no transpose is applied.
flashinfer/trace/example/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.json-29-43 (1)

29-43: ⚠️ Potential issue | 🟠 Major

FP4 GEMM schema and reference are shape-inconsistent.

Line 29–43 declares unpacked shapes ([M, K], [K, N]) while Line 71 treats A/B as packed bytes and reconstructs logical dims by multiplying by 2. On top of that, the final B_scaled.T introduces another dimension mismatch.

Please make schema and reference consistent in one direction:

  1. keep packed semantics and declare packed shapes, or
  2. keep unpacked shapes and remove nibble-unpack logic.
    Also, the final GEMM should not transpose B_scaled under the current shape declarations.

Also applies to: 71-71

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/example/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.json`
around lines 29 - 43, The schema declares A and B as packed uint8 tensors but
later code treats them as unpacked nibbles and reconstructs logical dims (the
nibble-unpack logic) and then does B_scaled.T which creates a mismatch; pick one
consistent approach and fix both schema and code: either (A) declare A/B shapes
as packed (bytes) and keep the nibble-unpack/reconstruction code that expands to
logical shapes but remove the final transpose of B_scaled (or transpose before
unpacking) so GEMM uses matching [M,K] and [K,N], or (B) declare A/B as unpacked
shapes ([M,K], [K,N]) and remove the nibble-unpack/reconstruction entirely;
update the "description" fields (fp4 e2m1fn_x2 packed as uint8) and references
to B_scaled and its transpose to match the chosen convention (adjust usage of
B_scaled.T accordingly).
flashinfer/gdn_decode.py-349-350 (1)

349-350: 🛠️ Refactor suggestion | 🟠 Major

Add @backend_requirement on these SM-constrained public APIs.

At Line 349 and Line 490, these APIs are decorated for tracing but still lack explicit backend capability guards at the API boundary.

As per coding guidelines: Use @backend_requirement decorator on APIs that have compute capability requirements and provide is_compute_capability_supported(cc) and is_backend_supported() methods.

Also applies to: 490-491

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 349 - 350, The public API function
gated_delta_rule_decode (decorated with
`@flashinfer_api`(trace=gated_delta_rule_decode_trace)) is SM-constrained and must
be guarded by the backend capability decorator; add `@backend_requirement` above
its definition and implement the decorator to call the module's
is_compute_capability_supported(cc) and is_backend_supported() helpers. Do the
same for the other SM-constrained public API in this file that is currently
decorated with `@flashinfer_api` around the later section (the second gated rule
decode API at the other occurrence) so both API entrypoints check compute
capability and backend support before proceeding.
flashinfer/gdn_decode.py-36-53 (1)

36-53: ⚠️ Potential issue | 🟠 Major

Decouple trace-template import from flashinfer_api fallback.

If trace template import fails but flashinfer_api is available, the current combined try block still falls back to a no-op decorator, silently disabling API logging/tracing behavior.

♻️ Suggested fix
-try:
-    from .api_logging import flashinfer_api
-    from .trace.templates.gdn import (
-        gated_delta_rule_decode_trace,
-        gdn_mtp_trace,
-    )
-    _FLASHINFER_AVAILABLE = True
-except ImportError:
-    _FLASHINFER_AVAILABLE = False
-    gated_delta_rule_decode_trace = None  # type: ignore[assignment]
-    gdn_mtp_trace = None  # type: ignore[assignment]
-
-    # Fallback decorator for standalone usage (accepts trace= kwarg)
-    def flashinfer_api(func=None, *, trace=None):  # type: ignore[misc]
-        if func is None:
-            return lambda f: f
-        return func
+try:
+    from .api_logging import flashinfer_api
+    _FLASHINFER_AVAILABLE = True
+except ImportError:
+    _FLASHINFER_AVAILABLE = False
+    def flashinfer_api(func=None, *, trace=None):  # type: ignore[misc]
+        if func is None:
+            return lambda f: f
+        return func
+
+try:
+    from .trace.templates.gdn import (
+        gated_delta_rule_decode_trace,
+        gdn_mtp_trace,
+    )
+except ImportError:
+    gated_delta_rule_decode_trace = None  # type: ignore[assignment]
+    gdn_mtp_trace = None  # type: ignore[assignment]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_decode.py` around lines 36 - 53, The combined try/except hides
a missing trace-template import by replacing flashinfer_api with a no-op; split
imports so flashinfer_api is imported in its own try/except and sets
_FLASHINFER_AVAILABLE, then separately attempt to import
gated_delta_rule_decode_trace and gdn_mtp_trace and only set them to None on
failure—define the fallback flashinfer_api decorator only when the
flashinfer_api import itself fails so trace import failures do not disable API
logging/tracing.
flashinfer/trace/templates/sampling.py-24-41 (1)

24-41: ⚠️ Potential issue | 🟠 Major

The sampling references are not reproducible as written.

These references call torch.multinomial, but the template schema does not carry any RNG input, seed, or pre-generated random variate. The same trace payload can therefore emit different samples across runs, which makes the generated definitions unstable as reference artifacts. Please encode the randomness in the trace inputs or make the reference deterministic.

Also applies to: 79-103, 141-173
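One deterministic alternative, sketched under the assumption that a per-row uniform variate is added to the trace inputs: invert the top-k CDF with torch.searchsorted instead of calling torch.multinomial.

```python
import torch


# Deterministic top-k sampling sketch: the randomness is an explicit input
# (u, one uniform variate per batch row), so the same trace payload always
# reproduces the same samples.
def top_k_sampling_deterministic(probs: torch.Tensor, k: int,
                                 u: torch.Tensor) -> torch.Tensor:
    # probs: [batch, vocab] normalized probabilities; u: [batch] in [0, 1)
    topk_vals, topk_idx = torch.topk(probs, k, dim=-1)
    topk_vals = topk_vals / topk_vals.sum(dim=-1, keepdim=True)  # renormalize
    cdf = torch.cumsum(topk_vals, dim=-1)
    # inverse-CDF sampling: first position where cdf >= u
    pos = torch.searchsorted(cdf, u.unsqueeze(-1)).clamp(max=k - 1)
    return topk_idx.gather(-1, pos).squeeze(-1)
```

The same u always selects the same token, which makes the dumped definitions stable as reference artifacts.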

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/sampling.py` around lines 24 - 41, The
_top_k_sampling_reference function (and the other sampling reference blocks at
lines 79-103 and 141-173) currently calls torch.multinomial which uses
nondeterministic RNG not captured by the trace; change the reference signatures
to accept explicit randomness (e.g., a per-sample uniform variates tensor or an
RNG seed/tensor) and use those inputs to deterministically draw samples: after
filtering/renormalizing the probabilities in _top_k_sampling_reference, compute
the cumulative distribution and select the token whose cdf first exceeds the
provided uniform variate for that batch (instead of torch.multinomial), and
apply the same pattern to the other sampling reference functions so the
randomness is fully encoded in trace inputs.
flashinfer/api_logging.py-1497-1503 (1)

1497-1503: ⚠️ Potential issue | 🟠 Major

Don’t silently disable .fi_trace on attachment errors.

These except Exception: pass blocks turn template/build failures into invisible feature loss: a broken trace template can quietly remove .fi_trace, and FLASHINFER_TRACE_DUMP=1 can fail to write anything without surfacing why. Please preserve the failure in a stub fi_trace or emit a warning instead of dropping it on the floor.

Also applies to: 1516-1517
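A sketch of the suggested behavior (hypothetical helper names): warn at attach time and keep the attribute as a stub that raises with the original cause, instead of silently dropping .fi_trace.

```python
import warnings


# Fail loudly instead of silently dropping .fi_trace: on a template/build
# error, warn once and install a stub that surfaces the original failure
# when invoked.
def attach_fi_trace(func, build_trace_fn):
    try:
        func.fi_trace = build_trace_fn()
    except Exception as exc:
        warnings.warn(f"fi_trace disabled for {func.__name__}: {exc!r}")
        captured = exc

        def _stub(*args, **kwargs):
            raise RuntimeError(
                f"fi_trace for {func.__name__} failed to build"
            ) from captured

        func.fi_trace = _stub
    return func
```

The attribute is always present, so FLASHINFER_TRACE_DUMP=1 failures surface with a cause instead of producing no output and no explanation.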

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/api_logging.py` around lines 1497 - 1503, The current try/except
around calling fi_trace_fn (guarded by _is_trace_dump_enabled and using
_sig.bind(...)) swallows all exceptions and silently disables .fi_trace; change
this to catch Exception as e, log or warn about the attachment/templating error
(include the exception), and install a stub fi_trace function that preserves the
attribute but emits the warning (or raises) when invoked so the feature failure
is visible; apply the same change to the analogous block that appears for the
second attachment (the other try/except using _sig.bind and fi_trace_fn).
flashinfer/trace/example/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.json-112-112 (1)

112-112: ⚠️ Potential issue | 🟠 Major

The decode reference mixes page IDs with token indices.

Line 112 declares kv_indices as page IDs, but the reference flattens the cache to [num_pages * page_size, ...] and indexes that flattened tensor directly with those IDs. With page_size=64, each selected page contributes only its first token to k_b/v_b, so the logits and outputs are wrong. Please fix the upstream template to index pages first, then flatten to tokens, and regenerate this artifact.
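The fix this comment asks for — gather pages first, then flatten the page-token axis — looks roughly like this sketch (hypothetical helper; cache layout assumed [num_pages, page_size, num_kv_heads, head_dim]):

```python
import torch


# Correct paged-KV gathering: index the cache by page IDs first, then
# flatten the per-page token axis, instead of indexing a token-flattened
# cache with page IDs (which keeps only the first token of each page).
def gather_paged_kv(k_cache: torch.Tensor,
                    page_ids: torch.Tensor) -> torch.Tensor:
    # k_cache: [num_pages, page_size, num_kv_heads, head_dim]
    pages = k_cache[page_ids]                     # [L, page_size, H_kv, D]
    return pages.reshape(-1, *k_cache.shape[2:])  # [L * page_size, H_kv, D]
```

With page_size > 1 this yields all tokens of every selected page, which is what the attention matmuls in the reference need.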

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@flashinfer/trace/example/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.json`
at line 112, The reference implementation _gqa_paged_decode_reference
incorrectly treats kv_indices as token indices against k_flat/v_flat; instead
treat kv_indices as page IDs: use kv_indptr to select pages, index
k_cache/v_cache by page IDs to get per-page tensors, then reshape/flatten each
selected page into tokens (or index within page using page_size) before
computing k_b and v_b; update the logic that computes k_flat/v_flat (or remove
flattening) so you first select pages via kv_indices[page_start:page_end] ->
page_ids, then gather k_cache[page_ids] and v_cache[page_ids] and reshape to
token dimension prior to matmuls, then regenerate the artifact.
flashinfer/trace/example/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.json-119-119 (1)

119-119: ⚠️ Potential issue | 🟠 Major

Schema/reference mismatch for kv_indices.

Line 119 says kv_indices are page IDs, but the reference indexes k_cache.reshape(-1, ...) / v_cache.reshape(-1, ...) with those IDs and sets num_kv_tokens from the number of pages. With page_size=16, that drops page_size - 1 tokens from every selected page, so the causal window and outputs are wrong for paged inputs. Please fix the source template to gather pages first, then flatten their token dimension, and regenerate this artifact.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@flashinfer/trace/example/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.json`
at line 119, The reference treats kv_indices as page IDs but indexes
k_cache/v_cache after reshaping into pages (k_flat/v_flat), which incorrectly
drops the per-page token dimension; in _gqa_paged_prefill_reference gather the
full pages first (use page_ids = kv_indices[kv_start:kv_end] to index k_cache
and v_cache by page dimension), then flatten the page-token axis so k_b and v_b
include all page_size tokens (adjust k_flat/v_flat usage or index
k_cache/v_cache directly), set num_kv_tokens = page_ids.shape[0] * page_size,
and update loops that compute max_kv, logits, attn, and output to iterate over
the flattened token sequence accordingly (refer to symbols kv_indices, k_cache,
v_cache, k_flat, v_flat, page_size, num_kv_tokens).
flashinfer/trace/example/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json-123-123 (1)

123-123: ⚠️ Potential issue | 🟠 Major

This ps64 reference still assumes one token per page.

Line 123 indexes kv_indices into ckv_cache/kpe_cache without flattening the selected page_size dimension first. In this file page_size is 64, so Kc/Kp remain page tensors instead of [L, D] token matrices, and the subsequent decode matmuls no longer implement the declared operator. Please fix the upstream template for multi-token pages and regenerate this example.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@flashinfer/trace/example/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json`
at line 123, The reference _mla_paged_decode_reference assumes one-token pages
but for page_size=64 you must flatten the per-page token dimension after
selecting pages: when building Kc/Kp from Kc_all/Kp_all (currently from
ckv_cache.squeeze(1)/kpe_cache.squeeze(1)) do Kc_all[tok_idx] and
Kp_all[tok_idx] then reshape/flatten the result to [L, head_dim_ckv] and [L,
head_dim_kpe] respectively (e.g., .reshape(-1, head_dim_ckv) / .reshape(-1,
head_dim_kpe]) before computing logits and softmax) so the decode matmuls use
token-level matrices; update _mla_paged_decode_reference to flatten selected
pages accordingly and regenerate the example.
flashinfer/fi_trace.py-273-280 (1)

273-280: ⚠️ Potential issue | 🟠 Major

The public helper never actually falls back to the legacy registry.

This module keeps _REGISTRY, register_fi_trace(), and build_fi_trace_fn() for backwards compatibility, but fi_trace() only checks actual_func.fi_trace. Any legacy caller that registered a spec by qualname will still hit the "No fi_trace spec is registered" error path.

Possible fix
     actual_func = getattr(func_or_method, "__func__", func_or_method)
     trace_fn = getattr(actual_func, "fi_trace", None)
     if trace_fn is None:
-        qualname = getattr(actual_func, "__qualname__", repr(actual_func))
-        raise ValueError(
-            f"No fi_trace spec is registered for '{qualname}'. "
-            "Only `@flashinfer_api`(trace=...)-decorated functions support fi_trace."
-        )
+        qualname = getattr(actual_func, "__qualname__", None)
+        spec = _REGISTRY.get(qualname) if qualname is not None else None
+        if spec is not None:
+            trace_fn = build_fi_trace_fn(spec)
+        else:
+            qualname = qualname or repr(actual_func)
+            raise ValueError(
+                f"No fi_trace spec is registered for '{qualname}'. "
+                "Only `@flashinfer_api`(trace=...)-decorated functions support fi_trace."
+            )
     return trace_fn(save_dir=save_dir, **kwargs)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fi_trace.py` around lines 273 - 280, The public helper fi_trace
currently only checks the bound attribute actual_func.fi_trace and never looks
up the legacy registry, so entries registered via register_fi_trace/_REGISTRY or
built via build_fi_trace_fn are ignored; update the code after obtaining
qualname to fall back to the legacy registry by looking up _REGISTRY[qualname]
or calling build_fi_trace_fn(qualname) (using the same qualname computed from
actual_func.__qualname__ or repr(actual_func)) and use that trace_fn when
present before raising the ValueError so legacy-registered specs are honored.
tests/test_fi_trace.py-357-362 (1)

357-362: ⚠️ Potential issue | 🟠 Major

These use-case tests allocate model-sized tensors even though fi_trace only inspects metadata.

The num_pages=8192 decode case materializes about 512 MiB of KV cache, and the MLA example adds another ~288 MiB, just to read .shape and .dtype. That is likely to slow or OOM CI without adding coverage. Please shrink these fixtures or move the model-scale examples out of the unit suite.

Also applies to: 418-424

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_fi_trace.py` around lines 357 - 362, The test allocates
model-sized tensors (num_pages, page_size, q, k_cache, v_cache) even though
fi_trace only reads metadata; reduce memory by shrinking num_pages and page_size
to small values (e.g., single- or double-digit sizes) or replace large concrete
tensors with lightweight stand-ins (small shaped tensors or meta-device tensors)
in the test vectors q, k_cache, v_cache used by test_fi_trace functions; apply
the same change to the other occurrence around lines 418-424 to avoid CI OOMs
while preserving the shape/dtype intent for fi_trace.
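Since fi_trace only reads `.shape` and `.dtype`, the model-scale fixtures could be replaced by lightweight stand-ins. A pure-Python sketch of that idea (the `FakeTensor` name is illustrative; the real tests would more likely use small tensors or torch's meta device):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class FakeTensor:
    """Carries only the metadata fi_trace inspects: shape and dtype."""
    shape: tuple
    dtype: str

# The ~512 MiB KV cache from the num_pages=8192 decode case, described in a
# few bytes of metadata instead of materialized device memory.
k_cache = FakeTensor(shape=(8192, 16, 8, 128), dtype="float16")
assert k_cache.shape[0] == 8192 and k_cache.dtype == "float16"
```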
flashinfer/trace/example/fi_trace_out/gdn_mtp_qk4_v8_d128.json-170-170 (1)

170-170: ⚠️ Potential issue | 🟠 Major

final_state never reflects the updates computed in the loop.

state_HVK is mutated for each token, but the function returns initial_state.clone() without writing any updated state back into it. The example therefore emits the original state pool while the schema says final_state is the updated recurrent state.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/example/fi_trace_out/gdn_mtp_qk4_v8_d128.json` at line 170,
The function returns initial_state.clone() even though state_HVK is updated per
token; fix by writing the updated per-pool state_HVK back into the state pool
using the same indexing flow you used to read it: after finishing the token loop
for a batch item (or whenever you update state_HVK), assign
final_state[state_idx] = state_HVK.transpose(-1, -2) (or update initial_state
in-place) so that final_state (returned) contains the mutated states; ensure you
use initial_state_indices/state_idx to map back and preserve dtype/device the
same way intermediate_states_buffer is handled.
flashinfer/trace/templates/gemm.py-111-215 (1)

111-215: ⚠️ Potential issue | 🟠 Major

The non-BF16 templates emit shapes with undefined or mismatched axes.

mm_fp8_trace uses K_div_block_size / block_size, mm_mxfp8_trace uses K_div_32, and mm_fp4_trace uses K_div_block_size / N_div_block_size, but none of those derived dimensions are declared in axes or tied back to K/N with constraints. mm_fp4_trace also labels packed uint8 operands as logical [M, K] and [K, N], so the discovered axis values will be off on the packed dimension. The resulting JSON is not self-contained for these ops.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/gemm.py` around lines 111 - 215, The templates
mm_fp8_trace, mm_mxfp8_trace, and mm_fp4_trace declare derived dimensions
(K_div_block_size, block_size, K_div_32, N_div_block_size) in their Tensor
shapes but never define them in axes or relate them back to K/N; also mm_fp4
inputs are described as logical [M,K]/[K,N] while the stored packed uint8 layout
changes the packed dimension. Fix by adding explicit axes entries for each
derived dimension in the axes dict (e.g., "block_size", "K_div_block_size",
"K_div_32", "N_div_block_size" or a packed axis like "K_packed") and document
the arithmetic relationship (K_div_block_size = K // block_size, K_div_32 = K //
32, N_div_block_size = N // block_size or K_packed = packed_length_of(K) for
fp4); then update the corresponding Tensor shapes in mm_fp8_trace,
mm_mxfp8_trace, and mm_fp4_trace to reference those axes (and adjust A/B shapes
for fp4 to use the packed axis instead of logical K) so the JSON is
self-contained and axis relationships are explicit.
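One way to make the derived dimensions self-contained is to compute them from the base axes and record the arithmetic alongside. The helper below is a sketch of that relationship only; the dict layout is illustrative, not the actual TraceTemplate axes API:

```python
# Derived axes tied back to K/N/block_size, with the arithmetic made explicit.
def derived_axes(K: int, N: int, block_size: int) -> dict:
    assert K % block_size == 0 and N % block_size == 0
    return {
        "K": K,
        "N": N,
        "block_size": block_size,
        "K_div_block_size": K // block_size,  # constraint: K_div_block_size == K // block_size
        "K_div_32": K // 32,                  # constraint: K_div_32 == K // 32
        "N_div_block_size": N // block_size,  # constraint: N_div_block_size == N // block_size
    }

# Values from the gemm_fp8 example (K=7168, N=1536, block_size=128).
axes = derived_axes(K=7168, N=1536, block_size=128)
assert axes["K_div_block_size"] == 56
```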
flashinfer/trace/example/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.json-112-112 (1)

112-112: ⚠️ Potential issue | 🟠 Major

The paged decode reference is indexing page IDs as if they were token IDs.

kv_indices is documented here as a page-ID array, but k_cache.reshape(-1, ...) / v_cache.reshape(-1, ...) followed by ...[token_ids] only selects one row per page and drops the remaining page_size - 1 tokens. If this reference is used for verification, any multi-token page will compare against the wrong attention result. Based on learnings, when native paged KV layout is used, page indices are not supposed to be flattened into token indices.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@flashinfer/trace/example/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.json`
at line 112, The reference implementation in _gqa_paged_decode_reference
incorrectly treats kv_indices as token IDs by indexing into k_flat/v_flat
(created by k_cache.reshape(-1,...)), which drops tokens within multi-token
pages; instead, treat kv_indices as page IDs: extract pages via k_cache[pages]
and v_cache[pages] (where pages =
kv_indices[page_start:page_end].to(torch.long)), then combine the page_size
dimension (e.g., .reshape(-1, num_kv_heads, head_dim)) so all tokens in each
page are included before computing logits/attention; update k_b, v_b, and any
downstream uses to reflect this page->token expansion while keeping q_b and
gqa_ratio logic unchanged.
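The page-vs-token indexing bug above can be shown in a few lines. This NumPy sketch (a stand-in for the torch reference; shapes and names mirror the template but are illustrative) contrasts the wrong and right gathers:

```python
import numpy as np

num_pages, page_size, num_kv_heads, head_dim = 6, 16, 8, 128
k_cache = np.random.rand(num_pages, page_size, num_kv_heads, head_dim)

# Pages assigned to one request: these are page IDs, not token IDs.
page_ids = np.array([4, 1, 3])

# Wrong: flatten first, then index with page IDs -> one token per page,
# dropping the other page_size - 1 tokens of each page.
k_wrong = k_cache.reshape(-1, num_kv_heads, head_dim)[page_ids]
assert k_wrong.shape == (3, num_kv_heads, head_dim)

# Right: gather whole pages, then flatten the page-token axis.
k_b = k_cache[page_ids].reshape(-1, num_kv_heads, head_dim)
assert k_b.shape == (3 * page_size, num_kv_heads, head_dim)
```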
flashinfer/trace/templates/gemm.py-22-85 (1)

22-85: ⚠️ Potential issue | 🟠 Major

Fix B tensor handling in quantized GEMM references and resolve undefined symbolic dimensions.

The quantized GEMM references have multiple critical issues:

  1. Matrix multiply semantics: All references multiply with B.T despite describing B with physical shape [K, N]. This is mathematically incorrect: [M, K] @ [K, N].T = [M, K] @ [N, K] has mismatched inner dimensions. The references should either remove the transpose or update schemas to describe B as [N, K].

  2. FP8 block layout: _mm_fp8_reference() reshapes [K//block_size, N, block_size] directly to [K, N] without permuting first. TRT-LLM block layout requires permutation before reshape to reconstruct the original matrix correctly (i.e., .reshape(K_div_bs, block_size, N).permute(1, 0, 2).reshape(K, N)).

  3. FP4 decoding: _unpack_fp4() extracts raw nibble values (0–15) via bitwise masking and casts to float32 without decoding the e2m1fn format. The reference cannot serve as a correctness oracle without proper FP4 value lookup or conversion.

  4. Undefined symbolic axes: The FP8, MXFP8, and FP4 templates reference symbolic dimensions (K_div_block_size, K_div_32, N_div_block_size, block_size) not declared in their axes dictionaries, preventing proper schema validation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/gemm.py` around lines 22 - 85, The GEMM refs
incorrectly transpose B and misuse block layouts: update _mm_reference,
_mm_fp8_reference, _mm_mxfp8_reference, and _mm_fp4_reference so matmul uses A @
B (not A @ B.T) if B is intended as [K, N], or alternatively document/reshape B
to [N, K] consistently; in _mm_fp8_reference before reshaping apply the TRT-LLM
permutation (reshape to [K_div_bs, block_size, N] then permute(1,0,2) then
reshape to [K, N]) instead of direct reshape; replace _unpack_fp4 with proper
e2m1fn decoding (use a lookup/decode table to map 4-bit nibble values to
float32) rather than raw nibble casts so FP4 semantics are correct; and add the
missing symbolic axis declarations for K_div_bs/K_div_32, N_div_block_size and
block_size in the template axes metadata so schema validation can resolve those
symbols (referencing the functions _mm_fp8_reference, _mm_mxfp8_reference,
_mm_fp4_reference and helper _unpack_fp4 to locate the changes).
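For the FP4 point specifically, E2M1 has only eight positive magnitudes ({0, 0.5, 1, 1.5, 2, 3, 4, 6}) with bit 3 of each nibble as the sign, so a lookup table suffices. A minimal decode sketch (the low-nibble-first packing order is an assumption, not taken from the source):

```python
# Positive E2M1 values indexed by the low 3 bits of the nibble.
_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def decode_e2m1_nibble(nibble: int) -> float:
    """Decode one 4-bit e2m1fn code (sign bit in bit 3) to float."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * _E2M1[nibble & 0x7]

def unpack_fp4(byte: int) -> tuple:
    """Decode one packed uint8 into two FP4 values (low nibble first)."""
    return decode_e2m1_nibble(byte & 0xF), decode_e2m1_nibble(byte >> 4)

assert decode_e2m1_nibble(0x7) == 6.0
assert decode_e2m1_nibble(0xF) == -6.0
```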
flashinfer/trace/templates/attention.py-113-116 (1)

113-116: ⚠️ Potential issue | 🟠 Major

Add the grouped-query head constraints to the GQA templates.

The GQA references rely on num_qo_heads // num_kv_heads being a valid grouping factor, but the schema currently accepts shapes where num_qo_heads < num_kv_heads or the ratio is non-integral. In those cases the reference either divides by zero or walks kv_h past the last KV head. Please add num_qo_heads >= num_kv_heads and num_qo_heads % num_kv_heads == 0 here, and mirror the same invariant in gqa_paged_prefill_trace and gqa_ragged_prefill_trace.

Possible fix
     constraints=[
         "len_indptr == batch_size + 1",
         "num_kv_indices == kv_indptr[-1].item()",
+        "num_qo_heads >= num_kv_heads",
+        "num_qo_heads % num_kv_heads == 0",
     ],
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/attention.py` around lines 113 - 116, The schema
for the GQA templates currently allows invalid head groupings; update the
constraints list in the attention template (the constraints array where
"len_indptr == batch_size + 1" and "num_kv_indices == kv_indptr[-1].item()" are
defined) to also require "num_qo_heads >= num_kv_heads" and "num_qo_heads %
num_kv_heads == 0", and apply the same two invariants to the corresponding
constraint lists in gqa_paged_prefill_trace and gqa_ragged_prefill_trace so the
grouped-query computation (which uses num_qo_heads // num_kv_heads and kv head
indexing) never divides by zero or indexes past the last KV head.
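The two invariants amount to requiring that the GQA group size num_qo_heads // num_kv_heads is a positive integer. A hypothetical runtime check (the real templates express this as constraint strings, not a function):

```python
def check_gqa_heads(num_qo_heads: int, num_kv_heads: int) -> int:
    """Return the GQA group size, rejecting invalid head configurations."""
    if num_kv_heads <= 0 or num_qo_heads < num_kv_heads:
        raise ValueError("num_qo_heads must be >= num_kv_heads (and both > 0)")
    if num_qo_heads % num_kv_heads != 0:
        raise ValueError("num_qo_heads must be a multiple of num_kv_heads")
    return num_qo_heads // num_kv_heads

# Valid: each of 8 KV heads serves a group of 4 query heads.
assert check_gqa_heads(32, 8) == 4
```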
flashinfer/trace/templates/gdn.py-164-168 (1)

164-168: ⚠️ Potential issue | 🟠 Major

Enforce seq_len == 1 in the decode template.

_gdn_decode_reference depends on squeeze(1) removing the time axis. If seq_len is anything else, it starts repeating along the sequence dimension instead of the head dimension and the reference becomes invalid. The description already says decode is single-token, so make that a hard constraint.

Possible fix
     constraints=[
+        "seq_len == 1",
         "num_v_heads >= num_q_heads",
         "num_v_heads % num_q_heads == 0",
         "num_k_heads == num_q_heads",
     ],
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/gdn.py` around lines 164 - 168, Add a hard
constraint enforcing single-token decoding by adding "seq_len == 1" to the
decode template's constraints list so the template and its consumer
_gdn_decode_reference (which relies on squeeze(1) removing the time axis) never
run with seq_len > 1; update the constraints array (the one containing
"num_v_heads >= num_q_heads", "num_v_heads % num_q_heads == 0", "num_k_heads ==
num_q_heads") to include "seq_len == 1".
flashinfer/trace/templates/gdn.py-327-330 (1)

327-330: ⚠️ Potential issue | 🟠 Major

Prefill is missing the GVA head-shape invariants used by the reference.

The prefill reference expands Q/K with num_v_heads // num_q_heads and num_v_heads // num_k_heads, so it needs the same head relationship guarantees as decode/MTP. Right now the schema accepts shapes that can truncate the repeat factor or produce an output whose head axis no longer matches the declared num_v_heads.

Possible fix
     constraints=[
         "len_cu_seqlens == num_seqs + 1",
         "total_seq_len == cu_seqlens[-1].item()",
+        "num_v_heads >= num_q_heads",
+        "num_v_heads % num_q_heads == 0",
+        "num_k_heads == num_q_heads",
     ],
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/gdn.py` around lines 327 - 330, The schema is
missing invariants that guarantee the GVA head expansion used in prefill; add
constraints to the same constraints list to require divisibility so expansions
don't truncate heads: include "num_v_heads % num_q_heads == 0" and "num_v_heads
% num_k_heads == 0" (referring to the symbols num_v_heads, num_q_heads,
num_k_heads) so the prefill expansion of Q/K by num_v_heads // num_q_heads and
num_v_heads // num_k_heads preserves the declared num_v_heads head axis.
flashinfer/trace/templates/attention.py-42-43 (1)

42-43: ⚠️ Potential issue | 🟠 Major

Don't treat page ids as flattened token ids.

After reshape(-1, ...), indexing with raw kv_indices only fetches one flattened slot per page and ignores the other page_size - 1 entries. That makes the paged GQA reference wrong for any page_size > 1, and the same pattern repeats in _gqa_paged_prefill_reference, _mla_paged_decode_reference, and _mla_paged_prefill_reference. Either materialize full pages and trim the last page with explicit length metadata, or constrain these paged templates to page_size == 1.

Possible direction
-    k_flat = k_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)
-    v_flat = v_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)
+    # Materialize selected pages first; then flatten tokens within those pages.
+    # The last page still needs an explicit length input to trim padding correctly.

...
-        token_ids = kv_indices[page_start:page_end].to(torch.long)
-        k_b = k_flat[token_ids]  # [T, num_kv_heads, head_dim]
-        v_b = v_flat[token_ids]
+        page_ids = kv_indices[page_start:page_end].to(torch.long)
+        k_b = k_cache[page_ids].reshape(-1, num_kv_heads, head_dim).to(torch.float32)
+        v_b = v_cache[page_ids].reshape(-1, num_kv_heads, head_dim).to(torch.float32)

Also applies to: 51-53

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/attention.py` around lines 42 - 43, The code
flattens page slots with k_cache.reshape(-1, num_kv_heads, head_dim)
(k_flat/v_flat) and then indexes with kv_indices, which treats page ids as
single flattened token ids and therefore drops the other page_size-1 entries;
update the paged templates (_gqa_paged_prefill_reference,
_mla_paged_decode_reference, _mla_paged_prefill_reference) to either (A)
materialize full page slices before flattening (i.e., expand/reshape to include
page_size dimension, gather full pages using kv_indices, then trim the final
partial page using explicit length metadata) or (B) enforce/validate page_size
== 1 at the start of these functions and raise an error if otherwise; ensure all
uses of k_flat, v_flat and kv_indices are adjusted accordingly so each page
returns all its key/value slots rather than a single flattened slot.
flashinfer/trace/templates/gdn.py-377-410 (1)

377-410: ⚠️ Potential issue | 🟠 Major

Write the updated slot back before returning final_state.

state_HVK is updated for every token, but nothing persists it into the returned pool. As written, final_state = initial_state.clone() returns the original state unchanged, which contradicts the template contract and breaks stateful verification.

Possible fix
     output = torch.zeros(
         (B, T, num_v_heads, head_size), dtype=torch.bfloat16, device=device
     )
+    final_state = initial_state.clone().float()
     cache_intermediate = intermediate_states_buffer is not None

     for b_idx in range(B):
         state_idx = int(initial_state_indices[b_idx].item())
-        state_HVK = initial_state[state_idx].clone().float().transpose(-1, -2)  # [H,V,K] -> [H,K,V]
+        state_HVK = final_state[state_idx].transpose(-1, -2).clone()  # [H,V,K] -> [H,K,V]

         for t in range(T):
             ...
             if cache_intermediate:
                 intermediate_states_buffer[state_idx, t] = state_HVK.transpose(-1, -2)  # [H,K,V] -> [H,V,K]
+
+        final_state[state_idx] = state_HVK.transpose(-1, -2)

-    final_state = initial_state.clone()
     return output, final_state
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/gdn.py` around lines 377 - 410, Summary:
final_state is returned unchanged because updated state_HVK is never written
back to the pool. Fix: clone initial_state into final_state before the outer
loop (or otherwise create a mutable final_state), and after processing each
batch element (after the inner t loop where state_HVK holds the final per-head
slot), write the updated slot back via final_state[state_idx] =
state_HVK.transpose(-1, -2) so the returned final_state reflects the updates;
reference symbols: initial_state, final_state, state_HVK, state_idx,
initial_state_indices.
🟡 Minor comments (4)
flashinfer/trace/template.py-371-376 (1)

371-376: ⚠️ Potential issue | 🟡 Minor

Silent exception swallowing may hide bugs during axis extraction.

Catching bare Exception and passing silently can mask unexpected errors (e.g., TypeError, AttributeError) that indicate template misconfiguration or API misuse. Consider logging at debug level or being more specific about expected exceptions.

🔧 Proposed fix to add debug logging
+import logging
+
+_logger = logging.getLogger(__name__)
+
 # In fi_trace function:
             for axis_name, extractor in axis_extractors.items():
                 try:
                     val = extractor(kwargs)
                     if val is not None:
                         axis_values[axis_name] = val
-                except Exception:
-                    pass
+                except Exception as e:
+                    _logger.debug("Failed to extract axis %r: %s", axis_name, e)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/template.py` around lines 371 - 376, The try/except around
calling extractor(kwargs) silently swallows all exceptions (using bare except),
which can hide bugs; update the except to capture the exception as e and emit a
debug-level log including axis_name, extractor, and the exception/traceback (or
narrow the except to expected errors like KeyError/IndexError/ValueError if
applicable) before continuing, ensuring axis_values and axis_name remain
unchanged on failure; reuse the module's logger instance (or create one if none
exists) so the failure context is recorded for debugging.
flashinfer/trace/templates/norm.py-56-88 (1)

56-88: ⚠️ Potential issue | 🟡 Minor

Reference implementation return value mismatch with template outputs.

The _fused_add_rmsnorm_reference function returns only y (single tensor), but fused_add_rmsnorm_trace defines two outputs: output and residual. The reference should return both to match the template schema.

🐛 Proposed fix to return both outputs
 @torch.no_grad()
 def _fused_add_rmsnorm_reference(hidden_states, residual, weight):
     """Fused Add + RMSNorm. Epsilon is fixed at 1e-6."""
     EPS = 1e-6
-    x = hidden_states.to(torch.float32) + residual.to(torch.float32)
+    residual_updated = hidden_states + residual
+    x = residual_updated.to(torch.float32)
     inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + EPS)
     y = (x * inv_rms) * weight.to(torch.float32)
-    return y.to(hidden_states.dtype)
+    return y.to(hidden_states.dtype), residual_updated
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/norm.py` around lines 56 - 88, The reference
function _fused_add_rmsnorm_reference currently returns only y but the
TraceTemplate fused_add_rmsnorm_trace declares two outputs ("output" and updated
"residual"); update _fused_add_rmsnorm_reference to return a tuple (output,
residual_out) where residual_out reflects the in-place semantics described
(residual += hidden_states) — e.g., compute residual_out =
residual.to(torch.float32) + hidden_states.to(torch.float32) (or perform an
in-place add if appropriate), cast both y and residual_out back to the original
hidden_states dtype, and return them in the same order as the template outputs.
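A NumPy sketch of a fused add + RMSNorm that returns both declared outputs, matching the two-output schema described above (eps fixed at 1e-6 as in the reference; this is illustrative, not the torch implementation):

```python
import numpy as np

def fused_add_rmsnorm_ref(hidden_states, residual, weight, eps=1e-6):
    # Updated residual: residual += hidden_states (computed in float32).
    x = hidden_states.astype(np.float32) + residual.astype(np.float32)
    inv_rms = 1.0 / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    y = x * inv_rms * weight.astype(np.float32)
    # Return BOTH template outputs: normalized output and updated residual.
    return y.astype(hidden_states.dtype), x.astype(hidden_states.dtype)

h = np.ones((2, 4), dtype=np.float32)
r = np.ones((2, 4), dtype=np.float32)
w = np.ones(4, dtype=np.float32)
out, res = fused_add_rmsnorm_ref(h, r, w)
assert np.allclose(res, 2.0)             # residual += hidden_states
assert np.allclose(out, 1.0, atol=1e-3)  # 2 / rms(2) == 1
```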
flashinfer/trace/example/__main__.py-1-1 (1)

1-1: ⚠️ Potential issue | 🟡 Minor

Replace wildcard import with explicit module reference.

Line 1 uses from .example import *, which triggers Ruff F403 and obscures what is actually imported. Since example.py defines no __all__ and is structured as a side-effect module (not an export container), use from . import example instead to make the intent explicit.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/example/__main__.py` at line 1, Replace the wildcard import
in __main__.py: remove "from .example import *" and import the module explicitly
(use "from . import example") so the code references names via the example
module; update any direct references that relied on the star-import to be
prefixed with "example." to keep intent explicit and satisfy Ruff F403.
flashinfer/trace/example/example.py-129-130 (1)

129-130: ⚠️ Potential issue | 🟡 Minor

Avoid silent except Exception: pass in the example runner.

These blocks hide unexpected failures and can make the generated trace set look complete when it is not.

♻️ Suggested fix
-except Exception:
-    pass  # Requires Blackwell (SM100+)
+except Exception as e:
+    print(f"[skip] mm_mxfp8 example not run: {e}")  # Requires Blackwell (SM100+)

-except Exception:
-    pass  # Requires Blackwell (SM100+)
+except Exception as e:
+    print(f"[skip] mm_fp4 example not run: {e}")  # Requires Blackwell (SM100+)

-except Exception:
-    pass  # May require specific GPU/TRT-LLM support
+except Exception as e:
+    print(f"[skip] trtllm_fp8_block_scale_moe example not run: {e}")

Also applies to: 140-141, 276-277

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/example/example.py` around lines 129 - 130, Replace each
silent "except Exception: pass" in the example runner (the three occurrences
that suppress errors for the Blackwell/SM100+ path) with targeted handling:
catch only the expected exception (e.g., ImportError or ModuleNotFoundError when
Blackwell is absent) or, if you must continue on error, log the full exception
with logging.exception or traceback.print_exc including contextual information
about which trace/step failed; do not swallow unexpected exceptions—re-raise
them after logging so real failures are visible.
🧹 Nitpick comments (2)
flashinfer/trace/template.py (1)

474-474: Consider using spread operator for list construction.

Per static analysis, using spread syntax is more idiomatic.

♻️ Suggested change
-            all_tags = [f"fi_api:{fi_api}"] + template.tags
+            all_tags = [f"fi_api:{fi_api}", *template.tags]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/template.py` at line 474, Replace the list concatenation
used to build all_tags with Python list unpacking for readability: instead of
creating all_tags via [f"fi_api:{fi_api}"] + template.tags, construct it using
the spread/unpacking form to include f"fi_api:{fi_api}" and all elements from
template.tags (referencing variables all_tags, fi_api, and template.tags in
template.py).
flashinfer/trace/__init__.py (1)

23-25: Consider sorting __all__ and reconsidering private export.

Per static analysis, __all__ should be sorted. Additionally, _TRACE_DUMP_DIR has a private naming convention (underscore prefix) but is exported publicly—consider renaming to TRACE_DUMP_DIR if it's meant for external use, or documenting why it's exposed.

♻️ Suggested sorted `__all__`
-__all__ = ["TraceTemplate", "Var", "Const", "Tensor", "Scalar", "_TRACE_DUMP_DIR"]
+__all__ = ["Const", "Scalar", "Tensor", "TraceTemplate", "Var", "_TRACE_DUMP_DIR"]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/__init__.py` around lines 23 - 25, The __all__ list is
unsorted and exposes a name with a leading underscore (_TRACE_DUMP_DIR) which
conflicts with its private naming; update the __all__ declaration so entries are
alphabetically sorted (e.g., Const, Scalar, Tensor, TraceTemplate, Var) and
decide whether _TRACE_DUMP_DIR is meant to be public—if so rename it to
TRACE_DUMP_DIR in template.py and here and export that, otherwise remove it from
__all__ (or add a comment/docstring explaining why the underscored name is
intentionally exported) so exports and naming are consistent; adjust
imports/usage accordingly (TraceTemplate, Var, Const, Tensor, Scalar, and the
chosen dump-dir symbol).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 834f4a1a-3013-4d83-80ed-7022baffd452

📥 Commits

Reviewing files that changed from the base of the PR and between 2ca0d38 and 8453636.

📒 Files selected for processing (49)
  • flashinfer/__init__.py
  • flashinfer/api_logging.py
  • flashinfer/attention.py
  • flashinfer/decode.py
  • flashinfer/fi_trace.py
  • flashinfer/fused_moe/core.py
  • flashinfer/gdn_decode.py
  • flashinfer/gdn_prefill.py
  • flashinfer/gemm/gemm_base.py
  • flashinfer/mla/_core.py
  • flashinfer/mla/cute_dsl/mla_decode.py
  • flashinfer/norm/__init__.py
  • flashinfer/prefill.py
  • flashinfer/sampling.py
  • flashinfer/trace/__init__.py
  • flashinfer/trace/example/__main__.py
  • flashinfer/trace/example/example.py
  • flashinfer/trace/example/fi_trace_out/fused_add_rmsnorm_h5120.json
  • flashinfer/trace/example/fi_trace_out/gdn_decode_qk4_v8_d128.json
  • flashinfer/trace/example/fi_trace_out/gdn_mtp_qk4_v8_d128.json
  • flashinfer/trace/example/fi_trace_out/gemm_bf16_N256_K7168.json
  • flashinfer/trace/example/fi_trace_out/gemm_bf16_N4096_K4096.json
  • flashinfer/trace/example/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.json
  • flashinfer/trace/example/fi_trace_out/gemm_fp8_N1536_K7168.json
  • flashinfer/trace/example/fi_trace_out/gemm_mxfp8_N4096_K4096.json
  • flashinfer/trace/example/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.json
  • flashinfer/trace/example/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.json
  • flashinfer/trace/example/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.json
  • flashinfer/trace/example/fi_trace_out/gqa_ragged_h32_kv8_d128.json
  • flashinfer/trace/example/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.json
  • flashinfer/trace/example/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json
  • flashinfer/trace/example/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.json
  • flashinfer/trace/example/fi_trace_out/rmsnorm_h4096.json
  • flashinfer/trace/example/fi_trace_out/rmsnorm_h7168.json
  • flashinfer/trace/example/fi_trace_out/top_k_sampling_v128256.json
  • flashinfer/trace/example/fi_trace_out/top_k_top_p_sampling_v128256.json
  • flashinfer/trace/example/fi_trace_out/top_k_top_p_sampling_v151936.json
  • flashinfer/trace/example/fi_trace_out/top_p_sampling_v128256.json
  • flashinfer/trace/example/fi_trace_out/top_p_sampling_v151936.json
  • flashinfer/trace/template.py
  • flashinfer/trace/templates/__init__.py
  • flashinfer/trace/templates/attention.py
  • flashinfer/trace/templates/gdn.py
  • flashinfer/trace/templates/gemm.py
  • flashinfer/trace/templates/moe.py
  • flashinfer/trace/templates/norm.py
  • flashinfer/trace/templates/sampling.py
  • flashinfer/trtllm_low_latency_gemm.py
  • tests/test_fi_trace.py

"dtype": "bfloat16"
}
},
"reference": "def _mm_fp8_reference(A, B):\n \"\"\"Dequantize FP8 block-scale inputs and compute C = A @ B.T.\n\n B is in TRT-LLM block layout [K//block_size, N, block_size] and is\n reshaped to [K, N] before the matmul.\n \"\"\"\n K_div_bs, N, block_size = B.shape\n B_fp32 = B.reshape(K_div_bs * block_size, N).to(torch.float32)\n A_fp32 = A.to(torch.float32)\n return torch.matmul(A_fp32, B_fp32.T).to(torch.bfloat16)\n"

⚠️ Potential issue | 🔴 Critical

Fix the embedded reference matmul transpose.

At Line 50, B_fp32 is reshaped to [K, N], so A_fp32 @ B_fp32.T is invalid when K != N (here 7168 != 1536). The reference should multiply with B_fp32 (or reshape differently if transposed semantics are intended).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/example/fi_trace_out/gemm_fp8_N1536_K7168.json` at line 50,
The helper _mm_fp8_reference currently reshapes B into B_fp32 =
B.reshape(K_div_bs * block_size, N) (i.e., [K, N]) but then computes
torch.matmul(A_fp32, B_fp32.T), which is wrong when K != N; change the matmul to
use B_fp32 (torch.matmul(A_fp32, B_fp32)) so the multiplication matches the
reshaped [K, N] layout, or alternatively reshape B to [N, K] if you truly need
B_fp32.T semantics—fix the call in _mm_fp8_reference referencing B_fp32 and
A_fp32 accordingly.
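The shape mismatch is easy to demonstrate. A quick NumPy check (sizes scaled down from the example's K=7168, N=1536): with B reshaped to [K, N], A @ B.T only type-checks when K == N, while A @ B is the correct contraction.

```python
import numpy as np

M, K, N = 4, 7, 3
A = np.random.rand(M, K)
B = np.random.rand(K, N)

C = A @ B                   # valid: [M, K] @ [K, N] -> [M, N]
assert C.shape == (M, N)

try:
    A @ B.T                 # invalid: [M, K] @ [N, K] has mismatched inner dims
    raise AssertionError("expected a shape mismatch")
except ValueError:
    pass
```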

Comment thread flashinfer/trace/templates/moe.py Outdated
Comment on lines +25 to +27
H = 7168
I = 2048
BLOCK = 128

⚠️ Potential issue | 🔴 Critical

Hardcoded H/I makes reference execution shape-fragile.

_fp8_moe_run_experts is wired to H=7168 and I=2048, but template axes are shape-driven. This will fail or produce invalid behavior for other valid MoE shapes.

💡 Proposed fix
-H = 7168
-I = 2048
 BLOCK = 128
@@
 def _fp8_moe_run_experts(
@@
-    T = hidden_states.shape[0]
+    T, H = hidden_states.shape
+    I = gemm2_weights.shape[2]
+    gemm1_out = gemm1_weights.shape[1]
+    if gemm1_out != 2 * I:
+        raise ValueError(
+            f"Invalid gemm1_out_size={gemm1_out}, expected 2 * intermediate_size={2 * I}"
+        )
@@
-    A_scale_expanded = (
-        A_scale_TH.unsqueeze(-1).repeat(1, 1, BLOCK).reshape(T, H).contiguous()
-    )
+    A_scale_expanded = A_scale_TH.unsqueeze(-1).repeat(1, 1, BLOCK).reshape(T, H).contiguous()
@@
-        X1, X2 = G1[:, :I], G1[:, I:]
+        X1, X2 = G1[:, :I], G1[:, I:]

Also applies to: 48-58, 72-73, 86-87

🧰 Tools
🪛 Ruff (0.15.7)

[error] 26-26: Ambiguous variable name: I

(E741)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/moe.py` around lines 25 - 27, Hardcoded constants
H=7168 and I=2048 (and uses of BLOCK) make _fp8_moe_run_experts and related
templates shape-fragile; change these to compute H and I from the template/axis
sizes at runtime and use a derived BLOCK (e.g., based on H/I or
template.block_size) instead of literal numbers. Locate the constants H, I,
BLOCK and replace them with expressions that read the relevant template axes or
tensor shapes (reference the template used by _fp8_moe_run_experts and other
occurrences) so all occurrences use dynamically computed sizes rather than
hardcoded values.
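A hedged sketch of the shape derivation the proposed fix describes, assuming the conventions visible in this review (hypothetical helper name; `hidden_states` is `[T, H]`, `gemm1_weights` is `[E, 2*I, H]`, `gemm2_weights` is `[E, H, I]`):

```python
def derive_moe_sizes(hidden_states_shape, gemm1_weights_shape, gemm2_weights_shape):
    """Derive T, H, I from tensor shapes instead of hardcoding H=7168, I=2048."""
    T, H = hidden_states_shape
    I = gemm2_weights_shape[2]          # intermediate size from gemm2 [E, H, I]
    gemm1_out = gemm1_weights_shape[1]  # gemm1 produces gate + up: 2 * I rows
    if gemm1_out != 2 * I:
        raise ValueError(
            f"Invalid gemm1_out_size={gemm1_out}, expected 2 * intermediate_size={2 * I}"
        )
    return T, H, I

# Works for the DeepSeek-V3 shapes used here and for any other valid MoE shape:
assert derive_moe_sizes((128, 7168), (32, 4096, 7168), (32, 7168, 2048)) == (128, 7168, 2048)
assert derive_moe_sizes((8, 1024), (4, 512, 1024), (4, 1024, 256)) == (8, 1024, 256)
```

Deriving sizes this way also gives an early, descriptive failure when the gemm1/gemm2 shapes disagree, instead of a silent wrong result.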

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 6

🧹 Nitpick comments (2)
tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json (1)

135-146: Consider specifying dtype for intermediate_states_buffer.

The dtype is set to "unknown" while all other tensors have explicit dtypes. Since this buffer stores intermediate states similar to initial_state and final_state (both float32), consider using "float32" for consistency—or add documentation explaining why the dtype is indeterminate.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json` around lines 135 - 146,
The schema entry "intermediate_states_buffer" currently has dtype "unknown";
change it to a concrete dtype (e.g., "float32") to match the similar tensors
"initial_state" and "final_state", or if dtype truly varies, add a clear
description explaining why it's indeterminate and what types are allowed; update
the "intermediate_states_buffer" dtype field and its description accordingly to
ensure consistency and clarity.
tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json (1)

85-97: Use concrete integer dtypes for index tensors.

kv_indptr and kv_indices are currently "dtype": "unknown". This weakens schema validation and downstream codegen/consumers. Prefer explicit integer types (typically int32 or int64) for both.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json` around
lines 85 - 97, The schema uses "dtype": "unknown" for the index tensors
kv_indptr (shape "len_indptr") and kv_indices (shape "num_kv_indices"); change
both to a concrete integer dtype (prefer int32, or int64 if you need 64-bit
indices) so downstream validation and codegen can rely on a fixed integer
type—update the "dtype" entries for kv_indptr and kv_indices accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/trace/example.py`:
- Around line 1-294: The file is a standalone script so pytest won't collect it;
convert it into a proper pytest test by moving the top-level side-effect code
into a single test function (e.g., def test_generate_fi_trace_jsons(tmp_path):)
while preserving the early environment setup (os.environ.setdefault(...) and
SAVE_DIR) before importing flashinfer, and use the tmp_path fixture to override
FLASHINFER_TRACE_DUMP_DIR/SAVE_DIR so outputs go to a test-isolated directory;
keep all calls to flashinfer functions and wrappers (e.g., flashinfer.rmsnorm,
flashinfer.fused_add_rmsnorm, flashinfer.top_k_sampling_from_probs,
flashinfer.mm_bf16, flashinfer.gdn_decode.gated_delta_rule_decode,
BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper, BatchMLAPagedAttentionWrapper,
flashinfer.fused_moe.trtllm_fp8_block_scale_moe, etc.) inside that test, and
remove or adapt prints/assert the expected JSON files exist via
SAVE_DIR.glob("*.json") to make the test assertions deterministic for CI.

In `@tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json`:
- Around line 120-124: The "scale" field is documented as having a default
(1/sqrt(head_size)) but isn't marked optional; update the JSON schema entry for
"scale" so consumers know it may be omitted—e.g., add an optional/nullable flag
or remove it from any "required" list and set "optional": true (or equivalent)
next to the "scale" property to reflect the default behavior.
- Line 148: The reference function _gdn_decode_reference uses math.sqrt and
F.softplus but the serialized source string has no imports, causing NameError
when exec/eval runs; fix by either injecting math and torch.nn.functional as F
into the exec/eval globals where _gdn_decode_reference is executed (ensure names
"math" and "F" are present) or prepend/import lines ("import math" and "import
torch.nn.functional as F") to the serialized reference string so
_gdn_decode_reference has the required symbols at runtime.
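A small self-contained sketch of the first option (injecting the required modules into the `exec` globals), using a stand-in `src` string; in the real case the namespace would also carry `"F": torch.nn.functional`:

```python
import math

# Stand-in for the serialized reference source, which uses `math` without importing it.
src = "def ref(x):\n    return math.sqrt(x)\n"

# Inject the names the snippet expects so the deserialized function resolves them.
namespace = {"math": math}
exec(src, namespace)
ref = namespace["ref"]

assert ref(9.0) == 3.0  # no NameError: `math` was provided via exec globals
```

Prepending `import` lines to the serialized string works too, but injecting globals keeps the stored source minimal and puts the dependency contract in one place on the consumer side.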

In `@tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json`:
- Around line 159-168: The doc string for "final_state" references an undefined
parameter disable_state_update; either add a boolean input named
disable_state_update to the inputs section (e.g., description: "If true,
recurrent state updates are disabled and final_state remains unchanged") or
remove the mention "Unchanged if disable_state_update=True" from the
"final_state" description; update the "final_state" description or inputs
accordingly so the documentation no longer refers to an undefined symbol.
- Line 170: The reference function _gdn_mtp_reference updates per-batch states
in state_HVK but then returns final_state = initial_state.clone(), discarding
updates; fix by creating final_state = initial_state.clone() before the batch
loop and after processing each batch element (using state_idx =
int(initial_state_indices[b_idx].item())) write the updated state back with
final_state[state_idx] = state_HVK.transpose(-1, -2) (matching the stored
[H,V,K] layout); ensure types remain consistent (match .float()/.to dtype as
needed) and then return output, final_state.
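The clone-then-write-back pattern this fix describes can be shown schematically in pure Python (hypothetical helper, scalar states standing in for the `[H, V, K]` tensors): clone the state pool once, update the per-batch working state, and write it back at its state index so updates are not discarded.

```python
def run_with_writeback(initial_states, state_indices, update_fn):
    """Clone the state pool, update per-batch states, write each back by index."""
    final_states = list(initial_states)  # clone before the batch loop
    for b_idx, state_idx in enumerate(state_indices):
        working = update_fn(final_states[state_idx], b_idx)
        final_states[state_idx] = working  # write-back: the bug returned the clone untouched
    return final_states

final = run_with_writeback([0, 0, 0], [2, 0], lambda s, b: s + b + 1)
assert final == [2, 0, 1]  # entries at indices 0 and 2 updated, index 1 untouched
```

Returning `initial_state.clone()` alone, as the buggy reference did, corresponds to skipping the write-back line entirely.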

In `@tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json`:
- Line 123: In _mla_paged_decode_reference the use of ckv_cache.squeeze(1) and
kpe_cache.squeeze(1) is wrong for paged tensors (shape [num_pages, page_size,
head_dim_*]) and leaves a 3D tensor so Kc_all[tok_idx] yields [L, page_size,
head_dim]; replace squeeze(1) with a flattening reshape (e.g. reshape(num_pages
* page_size, head_dim_ckv) / reshape(..., head_dim_kpe) or view(-1, head_dim_*))
so Kc_all and Kp_all become 2D token-major tensors before indexing, and ensure
kv_indptr and kv_indices are cast to explicit integer dtype (torch.long/int64)
before use to remove schema ambiguity.
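A pure-Python sketch of why the flatten matters (nested lists standing in for the paged tensors): a paged cache `[num_pages, page_size, head_dim]` must be reshaped to a token-major `[num_pages * page_size, head_dim]` layout before indexing with global token ids, whereas `squeeze(1)` only drops a size-1 dimension and leaves the tensor 3D when `page_size > 1`.

```python
def flatten_paged(cache):
    """Flatten [num_pages][page_size][...] into token-major [num_pages * page_size][...]."""
    return [row for page in cache for row in page]

num_pages, page_size = 2, 3
# Row i of page p holds the global token id p * page_size + i.
cache = [[[p * page_size + t] for t in range(page_size)] for p in range(num_pages)]

flat = flatten_paged(cache)
assert len(flat) == num_pages * page_size

tok_idx = 4  # global token id = page_id * page_size + in_page_offset
assert flat[tok_idx] == [4]  # direct token indexing works only after flattening
```

In torch terms this is `cache.reshape(num_pages * page_size, head_dim)` (or `view(-1, head_dim)` on a contiguous tensor) rather than `squeeze(1)`.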

---

Nitpick comments:
In `@tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json`:
- Around line 135-146: The schema entry "intermediate_states_buffer" currently
has dtype "unknown"; change it to a concrete dtype (e.g., "float32") to match
the similar tensors "initial_state" and "final_state", or if dtype truly varies,
add a clear description explaining why it's indeterminate and what types are
allowed; update the "intermediate_states_buffer" dtype field and its description
accordingly to ensure consistency and clarity.

In `@tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json`:
- Around line 85-97: The schema uses "dtype": "unknown" for the index tensors
kv_indptr (shape "len_indptr") and kv_indices (shape "num_kv_indices"); change
both to a concrete integer dtype (prefer int32, or int64 if you need 64-bit
indices) so downstream validation and codegen can rely on a fixed integer
type—update the "dtype" entries for kv_indptr and kv_indices accordingly.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9d5f1191-ca90-4d41-be7d-6533b17213b1

📥 Commits

Reviewing files that changed from the base of the PR and between 8453636 and c5296b7.

📒 Files selected for processing (29)
  • flashinfer/decode.py
  • flashinfer/fused_moe/core.py
  • flashinfer/gdn_decode.py
  • flashinfer/gemm/gemm_base.py
  • flashinfer/norm/__init__.py
  • flashinfer/prefill.py
  • tests/trace/example.py
  • tests/trace/fi_trace_out/fused_add_rmsnorm_h5120.json
  • tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json
  • tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json
  • tests/trace/fi_trace_out/gemm_bf16_N256_K7168.json
  • tests/trace/fi_trace_out/gemm_bf16_N4096_K4096.json
  • tests/trace/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.json
  • tests/trace/fi_trace_out/gemm_fp8_N1536_K7168.json
  • tests/trace/fi_trace_out/gemm_mxfp8_N4096_K4096.json
  • tests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.json
  • tests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.json
  • tests/trace/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.json
  • tests/trace/fi_trace_out/gqa_ragged_h32_kv8_d128.json
  • tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.json
  • tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json
  • tests/trace/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.json
  • tests/trace/fi_trace_out/rmsnorm_h4096.json
  • tests/trace/fi_trace_out/rmsnorm_h7168.json
  • tests/trace/fi_trace_out/top_k_sampling_v128256.json
  • tests/trace/fi_trace_out/top_k_top_p_sampling_v128256.json
  • tests/trace/fi_trace_out/top_k_top_p_sampling_v151936.json
  • tests/trace/fi_trace_out/top_p_sampling_v128256.json
  • tests/trace/fi_trace_out/top_p_sampling_v151936.json
✅ Files skipped from review due to trivial changes (20)
  • flashinfer/norm/__init__.py
  • tests/trace/fi_trace_out/gemm_bf16_N4096_K4096.json
  • tests/trace/fi_trace_out/rmsnorm_h7168.json
  • tests/trace/fi_trace_out/gemm_bf16_N256_K7168.json
  • tests/trace/fi_trace_out/fused_add_rmsnorm_h5120.json
  • tests/trace/fi_trace_out/top_k_top_p_sampling_v151936.json
  • tests/trace/fi_trace_out/rmsnorm_h4096.json
  • tests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.json
  • tests/trace/fi_trace_out/top_k_sampling_v128256.json
  • tests/trace/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.json
  • tests/trace/fi_trace_out/gemm_fp8_N1536_K7168.json
  • tests/trace/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.json
  • tests/trace/fi_trace_out/top_p_sampling_v151936.json
  • tests/trace/fi_trace_out/gqa_ragged_h32_kv8_d128.json
  • tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.json
  • tests/trace/fi_trace_out/top_p_sampling_v128256.json
  • tests/trace/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.json
  • tests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.json
  • flashinfer/prefill.py
  • tests/trace/fi_trace_out/gemm_mxfp8_N4096_K4096.json
🚧 Files skipped from review as they are similar to previous changes (3)
  • flashinfer/decode.py
  • flashinfer/gdn_decode.py
  • flashinfer/gemm/gemm_base.py

Comment thread tests/trace/example.py
Comment on lines +1 to +294
"""
fi_trace example: generate flashinfer-bench definition JSON files via auto-dump.

Run:
python tests/trace/example.py

When FLASHINFER_TRACE_DUMP=1 (set below), every @flashinfer_api(trace=...) decorated
function automatically writes a trace JSON on its first call for each unique input
shape. Subsequent calls with the same shape are deduplicated (no re-write).

The output directory is controlled by FLASHINFER_TRACE_DUMP_DIR.

Requires a CUDA-capable GPU.

Results:
- The run writes these example JSON files under the fi_trace_out directory:
fused_add_rmsnorm_h5120.json
gdn_decode_qk4_v8_d128_k_last.json
gdn_mtp_qk4_v8_d128_k_last.json
gdn_prefill_qk4_v8_d128_k_last.json
gemm_bf16_n256_k7168.json
gemm_bf16_n4096_k4096.json
gemm_fp4_n2048_k7168.json
gemm_fp8_n1536_k7168.json
gemm_mxfp8_n4096_k4096.json
gqa_paged_decode_h32_kv8_d128_ps16.json
gqa_paged_decode_h32_kv8_d128_ps64.json
gqa_paged_prefill_h32_kv8_d128_ps16.json
gqa_ragged_prefill_h32_kv8_d128.json
mla_paged_decode_h16_ckv512_kpe64_ps1.json
mla_paged_decode_h16_ckv512_kpe64_ps64.json
moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048.json
rmsnorm_h4096.json
rmsnorm_h7168.json
top_k_sampling_from_probs_v128256.json
top_k_top_p_sampling_from_probs_v128256.json
top_k_top_p_sampling_from_probs_v151936.json
top_p_sampling_from_probs_v128256.json
top_p_sampling_from_probs_v151936.json

Note: top_p_sampling files appear for vocab_size=151936 because
top_k_top_p_sampling (top_k_first order) calls top_p_sampling internally.
"""

import json
import os
from pathlib import Path

# Must be set before any flashinfer import: template.py reads these at module load time.
os.environ.setdefault(
    "FLASHINFER_TRACE_DUMP_DIR",
    str(Path(__file__).parent / "fi_trace_out"),
)
os.environ.setdefault("FLASHINFER_TRACE_DUMP", "1")

SAVE_DIR = Path(os.environ["FLASHINFER_TRACE_DUMP_DIR"])

import torch

import flashinfer
import flashinfer.norm
import flashinfer.sampling
import flashinfer.gemm
import flashinfer.gdn_decode
import flashinfer.fused_moe
from flashinfer.decode import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import (
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.mla import BatchMLAPagedAttentionWrapper

device = "cuda"
WORKSPACE = 128 * 1024 * 1024 # 128 MB

print(f"\nAuto-dumping fi_trace JSON files to {SAVE_DIR}/\n")

# ── rmsnorm ───────────────────────────────────────────────────────────────────
# Llama-3.1-8B (hidden=4096) and DeepSeek-V3 (hidden=7168)
for hidden_size in (4096, 7168):
    hidden = torch.randn(32, hidden_size, dtype=torch.bfloat16, device=device)
    weight = torch.ones(hidden_size, dtype=torch.bfloat16, device=device)
    flashinfer.rmsnorm(hidden, weight)

# ── fused_add_rmsnorm (Qwen3-14B, hidden=5120) ───────────────────────────────
x = torch.randn(32, 5120, dtype=torch.bfloat16, device=device)
res = torch.randn(32, 5120, dtype=torch.bfloat16, device=device)
w = torch.ones(5120, dtype=torch.bfloat16, device=device)
flashinfer.fused_add_rmsnorm(x, res, w)

# ── sampling (Llama vocab=128256) ─────────────────────────────────────────────
probs = torch.rand(64, 128256, dtype=torch.float32, device=device)
top_k = torch.full((64,), 50, dtype=torch.int32, device=device)
top_p = torch.full((64,), 0.9, dtype=torch.float32, device=device)
flashinfer.top_k_sampling_from_probs(probs, top_k)
flashinfer.top_p_sampling_from_probs(probs, top_p)
flashinfer.top_k_top_p_sampling_from_probs(probs, top_k, top_p)

# ── sampling (Qwen3 vocab=151936) ─────────────────────────────────────────────
probs = torch.rand(64, 151936, dtype=torch.float32, device=device)
flashinfer.top_k_top_p_sampling_from_probs(probs, top_k, top_p)

# ── GEMM bf16 ─────────────────────────────────────────────────────────────────
# Llama-3.1-8B o_proj (4096×4096) and DeepSeek-V3 moe.gate (256×7168)
# Use cutlass backend to avoid cuDNN dependency.
# mm_bf16 expects b in column-major layout with shape [K, N].
# randn(N, K).T gives shape [K, N] with strides (1, N); the kernel transposes
# b back to [N, K] (contiguous) before calling the C++ matmul.
for N, K in ((4096, 4096), (256, 7168)):
    a = torch.randn(128, K, dtype=torch.bfloat16, device=device)
    b = torch.randn(N, K, dtype=torch.bfloat16, device=device).T  # [K, N] column-major; b.T is contiguous
    flashinfer.mm_bf16(a, b, backend="cutlass")

# ── GEMM fp8 block-scale (DeepSeek-V3 q_proj: M×7168→1536, block=128) ────────
M, K, N, BS = 128, 7168, 1536, 128
a_fp8 = torch.zeros(M, K, dtype=torch.float8_e4m3fn, device=device)
b_fp8 = torch.zeros(K // BS, N, BS, dtype=torch.float8_e4m3fn, device=device)
alpha_fp8 = torch.tensor(1.0, dtype=torch.float32, device=device)
flashinfer.mm_fp8(a_fp8, b_fp8, alpha_fp8)

# ── GEMM mxfp8 (Blackwell SM100+: M×4096@4096×4096, block=32) ────────────────
try:
    M, K, N = 128, 4096, 4096
    a_mxfp8 = torch.zeros(M, K, dtype=torch.float8_e4m3fn, device=device)
    b_mxfp8 = torch.zeros(K, N, dtype=torch.float8_e4m3fn, device=device)
    a_ds = torch.ones(M, K // 32, dtype=torch.uint8, device=device)
    b_ds = torch.ones(K // 32, N, dtype=torch.uint8, device=device)
    flashinfer.gemm.mm_mxfp8(a_mxfp8, b_mxfp8, a_ds, b_ds)
except Exception:
    pass  # Requires Blackwell (SM100+)

# ── GEMM fp4 (Blackwell SM100+: M×7168@2048×7168, block=16) ─────────────────
try:
    M, K, N, BS4 = 128, 7168, 2048, 16
    a_fp4 = torch.zeros(M, K, dtype=torch.uint8, device=device)
    b_fp4 = torch.zeros(K, N, dtype=torch.uint8, device=device)
    a_d4 = torch.ones(M, K // BS4, dtype=torch.float8_e4m3fn, device=device)
    b_d4 = torch.ones(K, N // BS4, dtype=torch.float8_e4m3fn, device=device)
    flashinfer.gemm.mm_fp4(a_fp4, b_fp4, a_d4, b_d4, block_size=BS4)
except Exception:
    pass  # Requires Blackwell (SM100+)

# ── GQA paged decode (Llama-3.1-8B, h=32/kv=8/d=128) ────────────────────────
num_qo, num_kv, head_dim, batch_size = 32, 8, 128, 32

for page_size, num_pages in ((16, 128), (64, 32)):
    total = batch_size * num_pages
    kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device=device) * num_pages
    kv_indices = torch.arange(total, dtype=torch.int32, device=device)
    kv_last = torch.full((batch_size,), page_size, dtype=torch.int32, device=device)

    ws = torch.empty(WORKSPACE, dtype=torch.uint8, device=device)
    dec = BatchDecodeWithPagedKVCacheWrapper(ws, "NHD")
    dec.plan(
        kv_indptr, kv_indices, kv_last,
        num_qo, num_kv, head_dim, page_size,
        q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16,
    )
    q_d = torch.randn(batch_size, num_qo, head_dim, dtype=torch.bfloat16, device=device)
    kc = torch.randn(total, page_size, num_kv, head_dim, dtype=torch.bfloat16, device=device)
    vc = torch.randn(total, page_size, num_kv, head_dim, dtype=torch.bfloat16, device=device)
    dec.run(q_d, (kc, vc))

# ── GQA paged prefill (Llama-3.1-8B, h=32/kv=8/d=128, page_size=16) ─────────
n_req, total_q, np_pf, page_size = 4, 512, 32, 16
total_pf = n_req * np_pf
qo_indptr = torch.tensor([0, 128, 256, 384, 512], dtype=torch.int32, device=device)
kv_indptr_p = torch.arange(n_req + 1, dtype=torch.int32, device=device) * np_pf
kv_idx_p = torch.arange(total_pf, dtype=torch.int32, device=device)
kv_last_p = torch.full((n_req,), page_size, dtype=torch.int32, device=device)

ws_pf = torch.empty(WORKSPACE, dtype=torch.uint8, device=device)
pf = BatchPrefillWithPagedKVCacheWrapper(ws_pf, "NHD")
pf.plan(
    qo_indptr, kv_indptr_p, kv_idx_p, kv_last_p,
    num_qo, num_kv, head_dim, page_size,
    causal=True, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16,
)
q_pf = torch.randn(total_q, num_qo, head_dim, dtype=torch.bfloat16, device=device)
kc_pf = torch.randn(total_pf, page_size, num_kv, head_dim, dtype=torch.bfloat16, device=device)
vc_pf = torch.randn(total_pf, page_size, num_kv, head_dim, dtype=torch.bfloat16, device=device)
pf.run(q_pf, (kc_pf, vc_pf))

# ── GQA ragged prefill (Llama-3.1-8B) ────────────────────────────────────────
qo_indptr_r = torch.tensor([0, 64, 128, 192, 256], dtype=torch.int32, device=device)
kv_indptr_r = torch.tensor([0, 128, 256, 384, 512], dtype=torch.int32, device=device)

ws_r = torch.empty(WORKSPACE, dtype=torch.uint8, device=device)
rag = BatchPrefillWithRaggedKVCacheWrapper(ws_r, "NHD")
rag.plan(
    qo_indptr_r, kv_indptr_r,
    num_qo, num_kv, head_dim,
    causal=True, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16,
)
q_r = torch.randn(256, num_qo, head_dim, dtype=torch.bfloat16, device=device)
k_r = torch.randn(512, num_kv, head_dim, dtype=torch.bfloat16, device=device)
v_r = torch.randn(512, num_kv, head_dim, dtype=torch.bfloat16, device=device)
rag.run(q_r, k_r, v_r)

# ── MLA paged decode (DeepSeek-V3 TP=8, h=16/ckv=512/kpe=64) ─────────────────
mla_b, mla_h, ckv, kpe = 128, 16, 512, 64

for mla_ps, mla_np in ((64, 32), (1, 2048)):
    total_mla = mla_b * mla_np
    mla_qo_indptr = torch.arange(mla_b + 1, dtype=torch.int32, device=device)
    mla_kv_indptr = torch.arange(mla_b + 1, dtype=torch.int32, device=device) * mla_np
    mla_kv_indices = torch.arange(total_mla, dtype=torch.int32, device=device)
    mla_kv_len = torch.full((mla_b,), mla_np * mla_ps, dtype=torch.int32, device=device)

    ws_mla = torch.empty(WORKSPACE, dtype=torch.uint8, device=device)
    mla = BatchMLAPagedAttentionWrapper(ws_mla)
    mla.plan(
        mla_qo_indptr, mla_kv_indptr, mla_kv_indices, mla_kv_len,
        mla_h, ckv, kpe, mla_ps,
        causal=False, sm_scale=1.0 / (ckv ** 0.5),
        q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16,
    )
    q_nope = torch.randn(mla_b, mla_h, ckv, dtype=torch.bfloat16, device=device)
    q_pe = torch.randn(mla_b, mla_h, kpe, dtype=torch.bfloat16, device=device)
    ckv_cache = torch.randn(total_mla, mla_ps, ckv, dtype=torch.bfloat16, device=device)
    kpe_cache = torch.randn(total_mla, mla_ps, kpe, dtype=torch.bfloat16, device=device)
    mla.run(q_nope, q_pe, ckv_cache, kpe_cache)

# ── GDN decode (Qwen3-Next TP=4, qk=4/v=8/d=128) ────────────────────────────
B, H, HV, K = 4, 4, 8, 128
q = torch.randn(B, 1, H, K, dtype=torch.bfloat16, device=device)
k = torch.randn(B, 1, H, K, dtype=torch.bfloat16, device=device)
v = torch.randn(B, 1, HV, K, dtype=torch.bfloat16, device=device)
state = torch.zeros(B, HV, K, K, dtype=torch.float32, device=device)
A_log = torch.zeros(HV, dtype=torch.float32, device=device)
a = torch.zeros(B, 1, HV, dtype=torch.bfloat16, device=device)
dt_bias = torch.zeros(HV, dtype=torch.float32, device=device)
b_ = torch.zeros(B, 1, HV, dtype=torch.bfloat16, device=device)
flashinfer.gdn_decode.gated_delta_rule_decode(q, k, v, state, A_log, a, dt_bias, b_)

# ── GDN MTP (Qwen3-Next TP=4, spec_len=4) ────────────────────────────────────
T_mtp, pool_size = 4, 8
q_m = torch.randn(B, T_mtp, H, K, dtype=torch.bfloat16, device=device)
k_m = torch.randn(B, T_mtp, H, K, dtype=torch.bfloat16, device=device)
v_m = torch.randn(B, T_mtp, HV, K, dtype=torch.bfloat16, device=device)
init_state = torch.zeros(pool_size, HV, K, K, dtype=torch.float32, device=device)
init_idx = torch.arange(B, dtype=torch.int32, device=device)
A_log_m = torch.zeros(HV, dtype=torch.float32, device=device)
a_m = torch.zeros(B, T_mtp, HV, dtype=torch.bfloat16, device=device)
dt_bias_m = torch.zeros(HV, dtype=torch.float32, device=device)
b_m = torch.zeros(B, T_mtp, HV, dtype=torch.bfloat16, device=device)
flashinfer.gdn_decode.gated_delta_rule_mtp(
    q_m, k_m, v_m, init_state, init_idx, A_log_m, a_m, dt_bias_m, b_m
)

# ── MoE FP8 (DeepSeek-V3 EP=8: 256 experts, 32 local, h=7168, i=2048, top_k=8)
try:
    T_moe, H_moe, I_moe, E_tot, E_loc, BS = 128, 7168, 2048, 256, 32, 128
    routing_logits = torch.randn(T_moe, E_tot, dtype=torch.float32, device=device)
    routing_bias = torch.zeros(E_tot, dtype=torch.bfloat16, device=device)
    hs = torch.zeros(T_moe, H_moe, dtype=torch.float8_e4m3fn, device=device)
    hs_scale = torch.ones(H_moe // BS, T_moe, dtype=torch.float32, device=device)
    w1 = torch.zeros(E_loc, 2 * I_moe, H_moe, dtype=torch.float8_e4m3fn, device=device)
    w1s = torch.ones(E_loc, (2 * I_moe) // BS, H_moe // BS, dtype=torch.float32, device=device)
    w2 = torch.zeros(E_loc, H_moe, I_moe, dtype=torch.float8_e4m3fn, device=device)
    w2s = torch.ones(E_loc, H_moe // BS, I_moe // BS, dtype=torch.float32, device=device)
    flashinfer.fused_moe.trtllm_fp8_block_scale_moe(
        routing_logits, routing_bias,
        hs, hs_scale,
        w1, w1s,
        w2, w2s,
        num_experts=E_tot,
        top_k=8,
        n_group=8,
        topk_group=3,
        intermediate_size=I_moe,
        local_expert_offset=0,
        local_num_experts=E_loc,
        routed_scaling_factor=2.5,
    )
except Exception:
    pass  # May require specific GPU/TRT-LLM support

# ── Summary ───────────────────────────────────────────────────────────────────
files = sorted(SAVE_DIR.glob("*.json"))
print(f"\nWrote {len(files)} definition files:\n")
for f in files:
    defn = json.loads(f.read_text())
    print(f" {f.name}")
    print(f" op_type : {defn['op_type']}")
    print(f" fi_api : {next(t for t in defn['tags'] if t.startswith('fi_api:'))}")
    const_axes = {
        k: v["value"]
        for k, v in defn["axes"].items()
        if v["type"] == "const" and "value" in v
    }
    if const_axes:
        print(f" axes : {const_axes}")
    print()

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# Check if the file exists and get its basic info
find tests -name "example.py" -type f | head -5

Repository: flashinfer-ai/flashinfer

Length of output: 90


🏁 Script executed:

# Check the structure of tests directory
find tests -type f -name "*.py" | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 859


🏁 Script executed:

# Look for test_ prefixed functions in the current example.py
if [ -f "tests/trace/example.py" ]; then
  grep -n "^def test_" tests/trace/example.py || echo "No test_ functions found"
fi

Repository: flashinfer-ai/flashinfer

Length of output: 92


🏁 Script executed:

# Check if there's a main block and understand the current structure
if [ -f "tests/trace/example.py" ]; then
  grep -n "if __name__" tests/trace/example.py
fi

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Look at other test files to understand the pytest pattern used
find tests -name "test_*.py" -type f | head -5

Repository: flashinfer-ai/flashinfer

Length of output: 270


🏁 Script executed:

# Check the actual content and structure of example.py
tail -20 tests/trace/example.py

Repository: flashinfer-ai/flashinfer

Length of output: 741


🏁 Script executed:

# Look at a properly structured test file to understand the pattern
head -50 tests/attention/test_alibi.py

Repository: flashinfer-ai/flashinfer

Length of output: 1692


🏁 Script executed:

# Check if tests/trace directory has any other test files
ls -la tests/trace/

Repository: flashinfer-ai/flashinfer

Length of output: 309


🏁 Script executed:

# Look for pytest configuration to understand test discovery
find . -name "pytest.ini" -o -name "pyproject.toml" -o -name "setup.cfg" | head -3 | xargs grep -l "testpaths\|python_files" 2>/dev/null || echo "No specific pytest config found for test patterns"

Repository: flashinfer-ai/flashinfer

Length of output: 117


Refactor as a proper pytest test with test_ prefix to ensure CI coverage.

This file is a runnable script with side effects but lacks a test_ function, so pytest won't collect it during CI. Per coding guidelines for tests/**/*.py, test functions must be prefixed with test_ and structured in subdirectories matching kernel categories.

Suggested structure
+def run_trace_example(save_dir: Path) -> list[Path]:
+    # existing body here
+    return sorted(save_dir.glob("*.json"))
+
+def test_fi_trace_example_generates_defs(tmp_path, monkeypatch):
+    monkeypatch.setenv("FLASHINFER_TRACE_DUMP", "1")
+    monkeypatch.setenv("FLASHINFER_TRACE_DUMP_DIR", str(tmp_path))
+    files = run_trace_example(tmp_path)
+    assert files, "Expected fi_trace JSON files to be generated"
+
+if __name__ == "__main__":
+    run_trace_example(SAVE_DIR)
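The isolation pattern behind this suggestion can be shown as a self-contained sketch (no pytest required here; `run_trace_example` is a hypothetical wrapper standing in for the script body, which here just writes one file): point the dump directory at a temporary path, run the workload, then assert on the JSON files it produced.

```python
import json
import os
import tempfile
from pathlib import Path

def run_trace_example(save_dir: Path) -> list:
    # Stand-in for the real workload: write one definition file into save_dir.
    (save_dir / "rmsnorm_h4096.json").write_text(json.dumps({"op_type": "rmsnorm"}))
    return sorted(save_dir.glob("*.json"))

with tempfile.TemporaryDirectory() as d:
    tmp = Path(d)
    # In the pytest version, monkeypatch.setenv + tmp_path replace these two lines.
    os.environ["FLASHINFER_TRACE_DUMP"] = "1"
    os.environ["FLASHINFER_TRACE_DUMP_DIR"] = str(tmp)
    files = run_trace_example(tmp)
    assert files, "Expected fi_trace JSON files to be generated"
```

The key property is that outputs land in a per-run temporary directory, so repeated CI runs cannot collide or pass spuriously on stale files.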
🧰 Tools
🪛 Ruff (0.15.7)

[warning] 104-104: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF003)


[warning] 104-104: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF003)


[warning] 114-114: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF003)


[warning] 121-121: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF003)


[warning] 121-121: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF003)


[error] 129-130: try-except-pass detected, consider logging the exception

(S110)


[warning] 129-129: Do not catch blind exception: Exception

(BLE001)


[warning] 132-132: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF003)


[warning] 132-132: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF003)


[error] 140-141: try-except-pass detected, consider logging the exception

(S110)


[warning] 140-140: Do not catch blind exception: Exception

(BLE001)


[error] 276-277: try-except-pass detected, consider logging the exception

(S110)


[warning] 276-276: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/example.py` around lines 1 - 294, The file is a standalone script
so pytest won't collect it; convert it into a proper pytest test by moving the
top-level side-effect code into a single test function (e.g., def
test_generate_fi_trace_jsons(tmp_path):) while preserving the early environment
setup (os.environ.setdefault(...) and SAVE_DIR) before importing flashinfer, and
use the tmp_path fixture to override FLASHINFER_TRACE_DUMP_DIR/SAVE_DIR so
outputs go to a test-isolated directory; keep all calls to flashinfer functions
and wrappers (e.g., flashinfer.rmsnorm, flashinfer.fused_add_rmsnorm,
flashinfer.top_k_sampling_from_probs, flashinfer.mm_bf16,
flashinfer.gdn_decode.gated_delta_rule_decode,
BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper, BatchMLAPagedAttentionWrapper,
flashinfer.fused_moe.trtllm_fp8_block_scale_moe, etc.) inside that test, and
remove or adapt prints/assert the expected JSON files exist via
SAVE_DIR.glob("*.json") to make the test assertions deterministic for CI.

Comment on lines +120 to +124
"scale": {
"shape": null,
"dtype": "float32",
"description": "Scale factor. Default is 1/sqrt(head_size)."
}

⚠️ Potential issue | 🟠 Major

Mark scale as optional to match the declared default behavior.

Line 123 says a default is applied (1/sqrt(head_size)), but scale is not marked optional. This can make schema consumers treat it as required.

🛠️ Proposed fix
     "scale": {
       "shape": null,
       "dtype": "float32",
+      "optional": true,
       "description": "Scale factor. Default is 1/sqrt(head_size)."
     }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
"scale": {
"shape": null,
"dtype": "float32",
"description": "Scale factor. Default is 1/sqrt(head_size)."
}
"scale": {
"shape": null,
"dtype": "float32",
"optional": true,
"description": "Scale factor. Default is 1/sqrt(head_size)."
}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json` around lines 120 - 124,
The "scale" field is documented as having a default (1/sqrt(head_size)) but
isn't marked optional; update the JSON schema entry for "scale" so consumers
know it may be omitted—e.g., add an optional/nullable flag or remove it from any
"required" list and set "optional": true (or equivalent) next to the "scale"
property to reflect the default behavior.

"description": "Updated recurrent state in k-last layout [B, H, V, K]."
}
},
"reference": "@torch.no_grad()\ndef _gdn_decode_reference(q, k, v, state, A_log, a, dt_bias, b, scale):\n \"\"\"\n Gated Delta Net decode reference implementation (k-last layout).\n\n State layout: [B, H, V, K] (k-last, K dimension at the end)\n\n Gate computation:\n g = exp(-exp(A_log) * softplus(a + dt_bias))\n beta = sigmoid(b)\n\n Delta rule update:\n state_new = g * state_old + k^T @ (beta * v + (1-beta) * k @ state_old) - k^T @ (k @ state_old)\n output = scale * q @ state_new\n \"\"\"\n B, T, num_q_heads, K = q.shape\n _, _, num_k_heads, _ = k.shape\n _, _, num_v_heads, V = v.shape\n num_heads = num_v_heads\n device = q.device\n\n if scale is None or scale == 0.0:\n scale = 1.0 / math.sqrt(K)\n\n x = a.float() + dt_bias.float() # [B, 1, HV]\n g = torch.exp(-torch.exp(A_log.float()) * F.softplus(x)) # [B, 1, HV]\n beta = torch.sigmoid(b.float()) # [B, 1, HV]\n\n q_f32 = q.squeeze(1).float()\n k_f32 = k.squeeze(1).float()\n v_f32 = v.squeeze(1).float()\n g_f32 = g.squeeze(1).float()\n beta_f32 = beta.squeeze(1).float()\n\n if state is not None:\n state_f32 = state.float()\n else:\n state_f32 = torch.zeros(B, num_heads, V, K, dtype=torch.float32, device=device)\n\n q_exp = q_f32.repeat_interleave(num_v_heads // num_q_heads, dim=1)\n k_exp = k_f32.repeat_interleave(num_v_heads // num_k_heads, dim=1)\n\n new_state = torch.zeros_like(state_f32)\n output = torch.zeros(B, num_heads, V, dtype=torch.float32, device=device)\n\n for b_idx in range(B):\n for h_idx in range(num_heads):\n q_h = q_exp[b_idx, h_idx]\n k_h = k_exp[b_idx, h_idx]\n v_h = v_f32[b_idx, h_idx]\n h_state = state_f32[b_idx, h_idx].clone().transpose(-1, -2) # [V,K] -> [K,V]\n g_val = g_f32[b_idx, h_idx]\n beta_val = beta_f32[b_idx, h_idx]\n\n old_state = g_val * h_state\n old_v = k_h @ old_state\n new_v = beta_val * v_h + (1 - beta_val) * old_v\n state_remove = k_h.unsqueeze(1) @ old_v.unsqueeze(0)\n state_update = k_h.unsqueeze(1) @ new_v.unsqueeze(0)\n h_state = old_state - state_remove + 
state_update\n\n output[b_idx, h_idx] = scale * (q_h @ h_state)\n new_state[b_idx, h_idx] = h_state.transpose(-1, -2) # [K,V] -> [V,K]\n\n output = output.unsqueeze(1).to(torch.bfloat16)\n return output, new_state\n"
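The gate formulas quoted in the reference's docstring can be sanity-checked numerically with NumPy stand-ins for `torch` and `F` (toy scalar values; the real code operates on `[B, 1, HV]` tensors):

```python
import numpy as np

def softplus(x):  # stand-in for F.softplus
    return np.log1p(np.exp(x))

def sigmoid(x):   # stand-in for torch.sigmoid
    return 1.0 / (1.0 + np.exp(-x))

A_log, a, dt_bias, b = 0.5, -1.0, 0.25, 2.0
g = np.exp(-np.exp(A_log) * softplus(a + dt_bias))  # decay gate, always in (0, 1)
beta = sigmoid(b)                                   # mixing gate, always in (0, 1)
assert 0.0 < g < 1.0 and 0.0 < beta < 1.0
print(float(g), float(beta))
```

Because `softplus` is non-negative and `exp(A_log)` is positive, `g` is guaranteed to lie in `(0, 1]`, which is what makes it usable as a per-step state decay.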

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

# 1) Find where trace "reference" snippets are parsed/executed.
rg -n -C3 --type=py 'reference|exec\(|compile\('

# 2) Verify whether execution globals/locals include math and torch.nn.functional as F.
rg -n -C3 --type=py 'import math|torch\.nn\.functional as F|globals\(|locals\('

# 3) Find other trace fixtures that use math./F. in embedded references
# to confirm the expected convention is consistent.
rg -n --type=json 'math\.|F\.softplus|F\.' tests

Repository: flashinfer-ai/flashinfer

Length of output: 50381


🏁 Script executed:

#!/bin/bash
# Find files that load or process trace JSON files
find tests -type f -name "*.py" | head -20
rg -l "trace.*json|gdn_decode" tests --type py | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 998


🏁 Script executed:

#!/bin/bash
# Search for code that parses/loads trace JSON files
rg -n "load.*trace|json\.load|reference" tests --type py -A 3 | grep -E "trace|reference|exec|compile" | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 3912


🏁 Script executed:

#!/bin/bash
# Look for test files in the trace directory
ls -la tests/trace*.py tests/*trace*.py 2>/dev/null || find tests -name "*trace*" -type f

Repository: flashinfer-ai/flashinfer

Length of output: 163


🏁 Script executed:

#!/bin/bash
# Read the main trace test file to understand how reference is executed
wc -l tests/test_fi_trace.py
head -100 tests/test_fi_trace.py

Repository: flashinfer-ai/flashinfer

Length of output: 3745


🏁 Script executed:

#!/bin/bash
# Search for exec/compile patterns in trace test file
rg -n "exec|compile|reference|math|import F|import math" tests/test_fi_trace.py -A 5 -B 2

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

#!/bin/bash
# Check the actual trace JSON file
head -30 tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json

Repository: flashinfer-ai/flashinfer

Length of output: 940


🏁 Script executed:

#!/bin/bash
# Search for code that loads/executes these trace JSON files
rg -n "gdn_decode|fi_trace_out" tests --type py

Repository: flashinfer-ai/flashinfer

Length of output: 2731


🏁 Script executed:

#!/bin/bash
# Search more broadly for "reference" field execution in Python test files
rg -n "reference.*exec|exec.*reference|\[.reference.\]" tests --type py -A 5

Repository: flashinfer-ai/flashinfer

Length of output: 1018


🏁 Script executed:

#!/bin/bash
# Search for any code that executes or uses the "reference" field from JSON traces
rg -n "\.reference|reference.*exec|eval.*reference" tests --type py -A 3 -B 3

Repository: flashinfer-ai/flashinfer

Length of output: 2062


🏁 Script executed:

#!/bin/bash
# Look at the gdn test file that's mentioned
wc -l tests/gdn/test_decode_delta_rule.py
head -50 tests/gdn/test_decode_delta_rule.py

Repository: flashinfer-ai/flashinfer

Length of output: 1537


🏁 Script executed:

#!/bin/bash
# Search for how reference implementations are called in the gdn tests
rg -n "_gdn_decode_reference|_reference" tests/gdn --type py -B 2 -A 2

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

#!/bin/bash
# Look for any code that loads JSON traces and executes reference field
find tests -name "*.py" -type f -exec rg -l "reference" {} \; | head -5
rg -n "loads\(|compile\(|exec\(" tests/trace --type py

Repository: flashinfer-ai/flashinfer

Length of output: 324


🏁 Script executed:

#!/bin/bash
# Check if there's a trace execution utility in the flashinfer source
find . -path ./tests -prune -o -name "*.py" -type f -print | xargs rg -l "fi_trace|trace.*reference" 2>/dev/null | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 284


🏁 Script executed:

#!/bin/bash
# Search more broadly for any eval/exec of code strings in Python files
rg -n "eval\(|exec\(|compile\(" . --type py --iglob "*/flashinfer/*" | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

#!/bin/bash
# Examine the trace template and related files
wc -l flashinfer/trace/template.py flashinfer/fi_trace.py
head -80 flashinfer/trace/template.py

Repository: flashinfer-ai/flashinfer

Length of output: 3025


🏁 Script executed:

#!/bin/bash
# Search for reference field usage in trace modules
rg -n "reference" flashinfer/trace --type py -B 2 -A 5

Repository: flashinfer-ai/flashinfer

Length of output: 30715


🏁 Script executed:

#!/bin/bash
# Check fi_trace.py for reference handling
rg -n "reference" flashinfer/fi_trace.py -B 3 -A 5

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

#!/bin/bash
# Look at the template.py file more carefully to see where reference is serialized
sed -n '480,510p' flashinfer/trace/template.py

Repository: flashinfer-ai/flashinfer

Length of output: 1460


🏁 Script executed:

#!/bin/bash
# Key finding so far: `reference` is serialized as source code via inspect.getsource().
# Check whether the trace system actually EXECUTES the reference field or just stores it.

# Now check the GDN reference function definition
cat flashinfer/trace/templates/gdn.py | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 136


🏁 Script executed:

#!/bin/bash
# Search for where the "reference" field from JSON is actually executed
rg -n "json.*reference|loads.*reference|exec.*reference" . --type py --iglob "*/flashinfer/*" -A 5 -B 2

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

#!/bin/bash
# Check the actual gdn.py template to see the reference function definition
head -80 flashinfer/trace/templates/gdn.py

Repository: flashinfer-ai/flashinfer

Length of output: 2773


🏁 Script executed:

#!/bin/bash
# Look for any code that loads trace JSON and executes the reference
find . \( -path "*/flashinfer/*" -o -path "*/tests/*" \) -name "*.py" | xargs grep -l "reference.*exec\|exec.*reference" 2>/dev/null

Repository: flashinfer-ai/flashinfer

Length of output: 164


🏁 Script executed:

#!/bin/bash
# Check if there are any test files that load and run the reference field from JSON traces
find tests -name "*.py" -type f | xargs grep -l "\.json" | head -5

Repository: flashinfer-ai/flashinfer

Length of output: 204


🏁 Script executed:

#!/bin/bash
# Search for where reference strings might be executed from JSON
rg -n "reference.*eval|reference.*exec|exec.*inspect|compile.*reference" tests flashinfer --type py -B 3 -A 3

Repository: flashinfer-ai/flashinfer

Length of output: 1200


🏁 Script executed:

#!/bin/bash
# Check test_artifacts.py to see if it loads/executes traces
wc -l tests/test_artifacts.py
head -100 tests/test_artifacts.py

Repository: flashinfer-ai/flashinfer

Length of output: 4681


🏁 Script executed:

#!/bin/bash
# Look for any test that loads the JSON traces and uses them
rg -n "fi_trace_out|gdn_decode_qk4" tests --type py -B 2 -A 5

Repository: flashinfer-ai/flashinfer

Length of output: 185


🏁 Script executed:

#!/bin/bash
# Search for references to trace JSON files
rg "fi_trace_out|gdn_decode_qk4" tests -B 2 -A 5

Repository: flashinfer-ai/flashinfer

Length of output: 1661


🏁 Script executed:

#!/bin/bash
# Look for any test files that actually load and execute the reference from JSON
rg "reference.*code|reference.*string|compile\(|exec\(" flashinfer/trace --type py -B 3 -A 5

Repository: flashinfer-ai/flashinfer

Length of output: 1775


🏁 Script executed:

#!/bin/bash
# Directly check the JSON file to confirm what the reference field contains
python3 << 'EOF'
import json
with open('tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json') as f:
    data = json.load(f)
    if 'reference' in data:
        ref = data['reference']
        print("First 500 chars of reference field:")
        print(ref[:500])
        print("\n...\n")
        print("Last 500 chars:")
        print(ref[-500:])
        # Check for imports
        if 'import math' in ref:
            print("\n✓ Contains: import math")
        if 'import torch' in ref:
            print("✓ Contains: import torch")
        if 'torch.nn.functional as F' in ref or 'torch.nn.functional' in ref:
            print("✓ Contains: torch.nn.functional reference")
        if 'F.softplus' in ref:
            print("✓ Uses: F.softplus")
        if 'math.sqrt' in ref:
            print("✓ Uses: math.sqrt")
    else:
        print("No 'reference' field in JSON")
EOF

Repository: flashinfer-ai/flashinfer

Length of output: 1164


🏁 Script executed:

#!/bin/bash
# Search for any code that actually executes the reference field from JSON
rg "reference.*exec|exec.*reference|compile.*reference|eval.*reference" flashinfer tests --type py -B 5 -A 5

Repository: flashinfer-ai/flashinfer

Length of output: 1875


🏁 Script executed:

#!/bin/bash
# Check if the reference field is ONLY stored for documentation or if it's ever executed
rg "loads\(|load\(" flashinfer/trace --type py -B 3 -A 10 | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

#!/bin/bash
# Look for any test that loads and executes trace JSON references
find tests -name "*.py" -type f -exec grep -l "loads\|exec\|compile" {} \; | xargs grep -l "reference\|trace.*json" 2>/dev/null

Repository: flashinfer-ai/flashinfer

Length of output: 416


🏁 Script executed:

#!/bin/bash
# Check if reference is ever loaded and executed from the JSON
rg -n "reference" flashinfer/trace/template.py -B 3 -A 10 | grep -A 10 "reference"

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

#!/bin/bash
# Search more directly for any eval/exec in the entire flashinfer package
find flashinfer -name "*.py" -type f | xargs grep -n "eval\|exec\|compile" | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 2853


🏁 Script executed:

#!/bin/bash
# Verify the actual imports in the gdn.py template file
head -30 flashinfer/trace/templates/gdn.py | grep -E "^import|^from"

Repository: flashinfer-ai/flashinfer

Length of output: 189


🏁 Script executed:

#!/bin/bash
# Confirm the imports are NOT in the extracted reference string from JSON
python3 << 'EOF'
import json
with open('tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json') as f:
    data = json.load(f)
    if 'reference' in data:
        ref = data['reference']
        print("Checking for imports in the reference field:")
        if 'import math' in ref:
            print("✓ Contains: import math")
        else:
            print("✗ MISSING: import math")
        if 'import torch' in ref:
            print("✓ Contains: import torch")
        else:
            print("✗ MISSING: import torch")
        if 'torch.nn.functional' in ref or 'import F' in ref:
            print("✓ Contains: torch.nn.functional or F import")
        else:
            print("✗ MISSING: torch.nn.functional as F import")
        print("\nFirst 300 chars of reference:")
        print(ref[:300])
EOF

Repository: flashinfer-ai/flashinfer

Length of output: 534


🏁 Script executed:

#!/bin/bash
# Check if there are any plans or code to execute the reference field in the future
rg -n "reference.*exec|load.*reference|reference.*run" flashinfer --type py -B 3 -A 5

Repository: flashinfer-ai/flashinfer

Length of output: 50


The embedded reference code lacks required imports and will fail if executed.

Line 148's reference field contains math.sqrt and F.softplus but is serialized as a function source string without the required imports (import math and import torch.nn.functional as F). If this reference string is ever executed via exec() or eval(), it will raise NameError for unbound symbols.

Ensure that any trace execution context either:

  1. Injects math and torch.nn.functional (aliased as F) into the execution globals, or
  2. Includes the necessary import statements in the serialized reference string.
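Option 1 can be sketched as follows. `load_reference` is a hypothetical helper, not part of the PR; real snippets would also need `torch` and `F` (`torch.nn.functional`) injected into the same globals:

```python
import math
import types

def load_reference(source: str, extra_globals=None):
    """Exec a serialized fi_trace reference snippet with its missing imports injected.

    Assumes, as the dumped traces do, that the snippet defines exactly one function.
    """
    env = {"math": math}  # inject names the snippet uses but never imports
    if extra_globals:
        env.update(extra_globals)
    exec(compile(source, "<fi_trace reference>", "exec"), env)
    defined = [v for k, v in env.items()
               if isinstance(v, types.FunctionType) and not k.startswith("__")]
    assert len(defined) == 1, "expected the snippet to define exactly one function"
    return defined[0]

# A toy snippet that, like the dumped references, uses math.sqrt without importing it.
src = "def _ref(x):\n    return math.sqrt(x)\n"
ref = load_reference(src)
print(ref(9.0))  # 3.0
```

Injecting globals (option 1) keeps the serialized strings byte-identical to `inspect.getsource()`, whereas prepending imports (option 2) makes each JSON self-contained; either resolves the `NameError`.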
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json` at line 148, The
reference function _gdn_decode_reference uses math.sqrt and F.softplus but the
serialized source string has no imports, causing NameError when exec/eval runs;
fix by either injecting math and torch.nn.functional as F into the exec/eval
globals where _gdn_decode_reference is executed (ensure names "math" and "F" are
present) or prepend/import lines ("import math" and "import torch.nn.functional
as F") to the serialized reference string so _gdn_decode_reference has the
required symbols at runtime.

Comment on lines +159 to +168
"final_state": {
"shape": [
"pool_size",
"num_v_heads",
"head_size",
"head_size"
],
"dtype": "float32",
"description": "Updated recurrent state pool in k-last layout [pool_size, H, V, K]. Unchanged if disable_state_update=True."
}

⚠️ Potential issue | 🟡 Minor

Documentation references undefined parameter disable_state_update.

Line 167 states "Unchanged if disable_state_update=True" but disable_state_update is not defined in the inputs section. Either add this parameter to inputs if it's required, or remove the reference from the description.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json` around lines 159 - 168,
The doc string for "final_state" references an undefined parameter
disable_state_update; either add a boolean input named disable_state_update to
the inputs section (e.g., description: "If true, recurrent state updates are
disabled and final_state remains unchanged") or remove the mention "Unchanged if
disable_state_update=True" from the "final_state" description; update the
"final_state" description or inputs accordingly so the documentation no longer
refers to an undefined symbol.

"description": "Updated recurrent state pool in k-last layout [pool_size, H, V, K]. Unchanged if disable_state_update=True."
}
},
"reference": "@torch.no_grad()\ndef _gdn_mtp_reference(\n q, k, v, initial_state, initial_state_indices, A_log, a, dt_bias, b, scale,\n intermediate_states_buffer=None,\n):\n \"\"\"\n Gated Delta Net MTP (Multi-Token Prediction) reference implementation.\n\n State layout: [pool_size, H, V, K] (k-last, K dimension at the end)\n\n Gate computation:\n g = exp(-exp(A_log) * softplus(a + dt_bias))\n beta = sigmoid(b)\n\n For each token t in sequence:\n state_new = g_t * state_old + k_t^T @ (beta_t * v_t + (1-beta_t) * k_t @ state_old) - k_t^T @ (k_t @ state_old)\n output_t = scale * q_t @ state_new\n state_old = state_new # Update for next token\n \"\"\"\n B, T, num_q_heads, head_size = q.shape\n _, _, num_k_heads, _ = k.shape\n _, _, num_v_heads, _ = v.shape\n device = q.device\n\n if scale is None or scale == 0.0:\n scale = 1.0 / math.sqrt(head_size)\n\n x = a.float() + dt_bias.float() # [B, T, HV]\n g = torch.exp(-torch.exp(A_log.float()) * F.softplus(x)) # [B, T, HV]\n beta = torch.sigmoid(b.float()) # [B, T, HV]\n\n q_exp = q.repeat_interleave(num_v_heads // num_q_heads, dim=2) # [B, T, HV, K]\n k_exp = k.repeat_interleave(num_v_heads // num_k_heads, dim=2) # [B, T, HV, K]\n\n output = torch.zeros(\n (B, T, num_v_heads, head_size), dtype=torch.bfloat16, device=device\n )\n cache_intermediate = intermediate_states_buffer is not None\n\n for b_idx in range(B):\n state_idx = int(initial_state_indices[b_idx].item())\n state_HVK = initial_state[state_idx].clone().float().transpose(-1, -2) # [H,V,K] -> [H,K,V]\n\n for t in range(T):\n q_HK = q_exp[b_idx, t].float() # [HV, K]\n k_HK = k_exp[b_idx, t].float() # [HV, K]\n v_HV = v[b_idx, t].float() # [HV, V]\n g_H = g[b_idx, t] # [HV]\n beta_H = beta[b_idx, t] # [HV]\n\n for h_idx in range(num_v_heads):\n q_h = q_HK[h_idx]\n k_h = k_HK[h_idx]\n v_h = v_HV[h_idx]\n h_state = state_HVK[h_idx]\n g_val = g_H[h_idx]\n beta_val = beta_H[h_idx]\n\n old_state = g_val * h_state\n old_v = k_h @ old_state\n new_v = beta_val * v_h + (1 
- beta_val) * old_v\n state_remove = k_h.unsqueeze(1) @ old_v.unsqueeze(0)\n state_update = k_h.unsqueeze(1) @ new_v.unsqueeze(0)\n h_state = old_state - state_remove + state_update\n\n output[b_idx, t, h_idx] = (scale * (q_h @ h_state)).to(torch.bfloat16)\n state_HVK[h_idx] = h_state\n\n if cache_intermediate:\n intermediate_states_buffer[state_idx, t] = state_HVK.transpose(-1, -2) # [H,K,V] -> [H,V,K]\n\n final_state = initial_state.clone()\n return output, final_state\n"

⚠️ Potential issue | 🔴 Critical

Reference implementation does not return the updated state.

The reference function computes state updates in state_HVK for each batch element, but at the end returns initial_state.clone() instead of the accumulated updated state:

final_state = initial_state.clone()
return output, final_state

This means final_state will always equal the input initial_state, discarding all computed state updates. The correct behavior should write state_HVK.transpose(-1, -2) back to final_state[state_idx] after processing each batch element.

🐛 Proposed fix
+    final_state = initial_state.clone()
     for b_idx in range(B):
         state_idx = int(initial_state_indices[b_idx].item())
         state_HVK = initial_state[state_idx].clone().float().transpose(-1, -2)  # [H,V,K] -> [H,K,V]
 
         for t in range(T):
             # ... state update logic ...
 
             if cache_intermediate:
                 intermediate_states_buffer[state_idx, t] = state_HVK.transpose(-1, -2)  # [H,K,V] -> [H,V,K]
 
-    final_state = initial_state.clone()
+        # Write back updated state for this batch element
+        final_state[state_idx] = state_HVK.transpose(-1, -2)  # [H,K,V] -> [H,V,K]
+
     return output, final_state
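The failure mode is easy to reproduce in miniature (NumPy sketch with hypothetical names; the `+ 1.0` stands in for the delta-rule update):

```python
import numpy as np

def run(initial_state, indices, write_back: bool):
    # initial_state: [pool_size, H] toy state pool; indices: one pool slot per batch row.
    final_state = initial_state.copy()           # analogous to initial_state.clone()
    for b_idx, state_idx in enumerate(indices):
        state = initial_state[state_idx].copy()  # per-example working copy
        state = state + 1.0                      # stand-in for the delta-rule update
        if write_back:
            final_state[state_idx] = state       # the fix: commit the update
    return final_state

pool = np.zeros((2, 3))
buggy = run(pool, [0, 1], write_back=False)  # updates silently discarded
fixed = run(pool, [0, 1], write_back=True)   # updates committed per slot
print(np.array_equal(buggy, pool), np.array_equal(fixed, pool + 1.0))  # True True
```

Without the write-back, the working copy is mutated and then dropped, so the returned pool is indistinguishable from the input, which is exactly what the review flags in `_gdn_mtp_reference`.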
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json` at line 170, The reference
function _gdn_mtp_reference updates per-batch states in state_HVK but then
returns final_state = initial_state.clone(), discarding updates; fix by creating
final_state = initial_state.clone() before the batch loop and after processing
each batch element (using state_idx = int(initial_state_indices[b_idx].item()))
write the updated state back with final_state[state_idx] =
state_HVK.transpose(-1, -2) (matching the stored [H,V,K] layout); ensure types
remain consistent (match .float()/.to dtype as needed) and then return output,
final_state.

"description": "The 2-based log-sum-exp of attention logits."
}
},
"reference": "@torch.no_grad()\ndef _mla_paged_decode_reference(\n q_nope, q_pe, ckv_cache, kpe_cache, kv_indptr, kv_indices, sm_scale\n):\n batch_size, num_qo_heads, head_dim_ckv = q_nope.shape\n len_indptr = kv_indptr.shape[0]\n\n Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv]\n Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe]\n\n output = torch.zeros(\n (batch_size, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=q_nope.device\n )\n lse = torch.full(\n (batch_size, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=q_nope.device\n )\n\n for b in range(batch_size):\n page_beg = int(kv_indptr[b].item())\n page_end = int(kv_indptr[b + 1].item())\n if page_beg >= page_end:\n output[b].zero_()\n continue\n tok_idx = kv_indices[page_beg:page_end].to(torch.long)\n Kc = Kc_all[tok_idx] # [L, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [L, head_dim_kpe]\n qn = q_nope[b].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[b].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n logits = ((qn @ Kc.T) + (qp @ Kp.T)) * sm_scale # [num_qo_heads, L]\n lse[b] = torch.logsumexp(logits, dim=-1) / math.log(2.0)\n output[b] = (torch.softmax(logits, dim=-1) @ Kc).to(torch.bfloat16)\n\n return output, lse\n"

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

For PyTorch 2.x: if `qn` is shape `[H, D]` and `Kc` is shape `[L, P, D]`, what are the rules for `qn @ Kc.T`? Also, what does `.T` do for tensors with more than 2 dimensions?

💡 Result:

For PyTorch 2.x, the `@` operator uses `torch.matmul` rules for matrix multiplication. For `qn` (shape `[H, D]`) `@ Kc.T` where `Kc` is `[L, P, D]`:

- `tensor.T` on a tensor with more than 2 dimensions reverses all dimensions: `Kc` `[L, P, D]` becomes `[D, P, L]`. This behavior is deprecated for >2-D tensors and will error in future releases; use `.mT` to transpose only the last two dims, or `.transpose(-2, -1)`.
- `matmul` then treats the 2-D `qn` `[H, D]` and 3-D `Kc.T` `[D, P, L]` as a batched multiply: `qn` is promoted to `[1, H, D]` and the batch dim broadcasts, but the inner matrix dimensions (`D` vs. `P`) do not match, so the call raises a `RuntimeError` shape mismatch (unless `P == D` by coincidence).
- To compute batched attention scores, transpose only the last two dims: `qn @ Kc.transpose(-2, -1)` is `[H, D] @ [L, D, P]`, which broadcasts the missing batch dim on `qn` and yields `[L, H, P]`. `.mT` on `Kc` `[L, P, D]` likewise gives `[L, D, P]`, so `qn @ Kc.mT` produces the same `[L, H, P]` result.
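NumPy's `matmul` follows the same batching and broadcasting rules as `torch.matmul`, so the shapes can be checked without torch (toy sizes matching the trace: `H=16`, `D=512`, pages `L`, page size `P`):

```python
import numpy as np

H, D, L, P = 16, 512, 4, 64
qn = np.zeros((H, D))
Kc = np.zeros((L, P, D))

# Reversing ALL dims (what torch's deprecated >2-D .T does) gives [D, P, L],
# whose inner matrix dim P cannot multiply against qn's D.
# Transposing only the last two dims (torch's .mT / transpose(-2, -1)) works:
scores = qn @ Kc.transpose(0, 2, 1)   # [H, D] @ [L, D, P] -> batched matmul
print(scores.shape)  # (4, 16, 64), i.e. [L, H, P] with the batch dim leading
```

Note the batch dimension `L` leads the result, which is why flattening the paged cache to 2-D before the matmul is the simpler fix for this reference.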



🏁 Script executed:

# First, let's look at the JSON file structure and find the input shapes
fd --type f "mla_paged_decode_h16_ckv512_kpe64_ps64.json" | head -1 | xargs -I {} sh -c '
  # Get file size to decide how to read it
  wc -l {} | awk "{print \$1}" 
'

Repository: flashinfer-ai/flashinfer

Length of output: 238


🏁 Script executed:

# Find the JSON file using git ls-files
git ls-files | grep "mla_paged_decode_h16_ckv512_kpe64_ps64.json"

Repository: flashinfer-ai/flashinfer

Length of output: 136


🏁 Script executed:

# Get the line count and examine the JSON structure
wc -l tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json

Repository: flashinfer-ai/flashinfer

Length of output: 140


🏁 Script executed:

# Read the entire JSON file to understand input shapes
cat -n tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json

Repository: flashinfer-ai/flashinfer

Length of output: 5318


The reference function's squeeze(1) is incorrect for the paged layout.

On Line 123, squeeze(1) cannot remove the page_size dimension when it equals 64. The inputs ckv_cache and kpe_cache have shape [num_pages, page_size, head_dim_*], so squeeze(1) leaves them 3D. The subsequent indexing Kc_all[tok_idx] produces [L, 64, head_dim_ckv] instead of the intended 2D [L, head_dim_ckv], causing the matmul qn @ Kc.T to fail.

Replace squeeze(1) with reshape(num_pages * page_size, head_dim_ckv) to flatten pages into a token-major 2D tensor, or adjust the indexing logic if indices are page-level rather than token-level.

Additionally, specify explicit dtypes for kv_indptr and kv_indices (currently "unknown") to reduce schema ambiguity.
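The proposed flattening can be sketched as follows (NumPy, hypothetical small sizes; token-level `kv_indices` assumed, as the review suggests):

```python
import numpy as np

num_pages, page_size, head_dim = 3, 4, 8
ckv_cache = np.arange(num_pages * page_size * head_dim, dtype=np.float32)
ckv_cache = ckv_cache.reshape(num_pages, page_size, head_dim)

# squeeze(1) is a no-op here because page_size != 1; instead flatten pages
# into a token-major 2-D view so token indices select single rows.
Kc_all = ckv_cache.reshape(num_pages * page_size, head_dim)

tok_idx = np.array([0, 5, 11], dtype=np.int64)  # token-level indices, explicit int64
Kc = Kc_all[tok_idx]
print(Kc.shape)  # (3, 8): [L, head_dim], ready for qn @ Kc.T
```

Flat token index `5` maps to page `5 // page_size = 1`, slot `5 % page_size = 1`, so the token-major reshape preserves the paged addressing.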

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json` at line
123, In _mla_paged_decode_reference the use of ckv_cache.squeeze(1) and
kpe_cache.squeeze(1) is wrong for paged tensors (shape [num_pages, page_size,
head_dim_*]) and leaves a 3D tensor so Kc_all[tok_idx] yields [L, page_size,
head_dim]; replace squeeze(1) with a flattening reshape (e.g. reshape(num_pages
* page_size, head_dim_ckv) / reshape(..., head_dim_kpe) or view(-1, head_dim_*))
so Kc_all and Kp_all become 2D token-major tensors before indexing, and ensure
kv_indptr and kv_indices are cast to explicit integer dtype (torch.long/int64)
before use to remove schema ambiguity.


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 12

♻️ Duplicate comments (1)
flashinfer/trace/templates/moe.py (1)

25-27: ⚠️ Potential issue | 🟠 Major

The MoE reference is still hard-wired to one hidden/intermediate shape.

H=7168 and I=2048 leak into the scale expansion, output allocation, and G1 split, so any traced MoE with different shapes will either reshape incorrectly or slice the expert output wrong.

Suggested fix
-H = 7168
-I = 2048
 BLOCK = 128
@@
-    T = hidden_states.shape[0]
+    T, H = hidden_states.shape
+    I = gemm2_weights.shape[2]
+    gemm1_out = gemm1_weights.shape[1]
+    if gemm1_out != 2 * I:
+        raise ValueError(
+            f"Invalid gemm1_out_size={gemm1_out}, expected 2 * intermediate_size={2 * I}"
+        )

Also applies to: 53-57, 72-88
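The shape inference in the suggested fix can be exercised standalone (NumPy sketch; the weight layouts `gemm1_weights: [E, 2*I, H]` and `gemm2_weights: [E, H, I]` are assumptions taken from the diff, not confirmed against the kernel):

```python
import numpy as np

E, T, H, I = 4, 8, 64, 32  # experts, tokens, hidden, intermediate (toy sizes)
hidden_states = np.zeros((T, H))
gemm1_weights = np.zeros((E, 2 * I, H))  # assumed layout: gate+up fused on dim 1
gemm2_weights = np.zeros((E, H, I))      # assumed layout: I on the last dim

# Infer sizes from the tensors instead of hard-coding H=7168 / I=2048.
T_run, H_run = hidden_states.shape
I_run = gemm2_weights.shape[2]
gemm1_out = gemm1_weights.shape[1]
if gemm1_out != 2 * I_run:
    raise ValueError(
        f"Invalid gemm1_out_size={gemm1_out}, expected 2 * intermediate_size={2 * I_run}"
    )
print(T_run, H_run, I_run)  # 8 64 32
```

Deriving `H` and `I` at runtime makes the reference valid for any traced MoE shape, with the `2 * I` check catching mismatched fused-GEMM layouts early.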

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/moe.py` around lines 25 - 27, The code hard-codes
H, I, and BLOCK which leak into scale expansion, output allocation, and the G1
split causing incorrect reshapes/slices for other MoE shapes; replace these
constants with dynamic values derived from the model/tensor shapes (e.g., infer
hidden_size and intermediate_size from the input/weight tensors or pass them as
parameters), update any uses in scale expansion, output allocation, and the G1
split logic (references: H, I, BLOCK, and the G1 split/expert output slicing
code in moe.py) to compute sizes at runtime and use those computed sizes for
reshape, split and slice operations so traced models with different H/I/BLOCK
work correctly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/trace/templates/attention.py`:
- Line 357: The variables len_indptr and page_size are assigned but never used;
to fix the ruff F841 error, either remove the unused assignments or rename them
to _len_indptr and _page_size (or prefix with a single underscore) where they
are set (e.g., the len_indptr = kv_indptr.shape[0] assignment and the other
page_size assignment in the same module), and apply the same change to the
second occurrence that also triggers the warning so the linter no longer reports
unused locals.
- Around line 144-157: The prefill path incorrectly treats kv_indices as token
indices by indexing k_flat/v_flat directly with page_ids (kv_indices) and
computing num_kv_tokens from page_ids.shape[0]; instead expand the selected
pages to token-level rows first: use kv_indices[kv_start:kv_end] to select page
rows from the full per-page KV buffer (not the flattened token axis), then
concatenate or expand those page rows into token-level k_b and v_b and compute
num_kv_tokens from the resulting expanded KV token rows; update usages around
k_flat, v_flat, page_ids, k_b, v_b and num_kv_tokens so page->token expansion
happens before indexing the flattened token axis.
- Around line 42-53: kv_indices currently represents page IDs, so indexing
k_flat/v_flat directly with kv_indices selects wrong rows when page_size > 1;
instead, first gather the pages from the original k_cache and v_cache using
kv_indices (use kv_indices to index the page dimension of k_cache/v_cache to
produce per-token page slices), then flatten or reshape the gathered per-page
tensors into token-level rows and proceed (so create k_b/v_b by gathering pages
via kv_indices from k_cache/v_cache, then reshape to [T, num_kv_heads, head_dim]
before using them as k_b and v_b); update all uses of k_flat/v_flat and
token_ids accordingly and ensure kv_indptr logic still slices the kv_indices by
token count, not flattened token offsets.
- Around line 359-383: The reference implementations _mla_paged_decode_reference
and _mla_paged_prefill_reference assume page_size==1 by calling
ckv_cache.squeeze(1) and kpe_cache.squeeze(1); instead update these functions to
flatten the page and token dimensions so arbitrary page_size works (e.g.,
replace squeeze(1) with a reshape/flatten to (-1, head_dim_ckv) for Kc_all and
(-1, head_dim_kpe) for Kp_all or use flatten(0,1)), ensuring subsequent indexing
via kv_indices still selects the correct token rows; alternatively, if you
prefer to keep the current code, enforce page_size==1 in the TraceTemplate
schema, but do not leave squeeze(1) as-is.

In `@flashinfer/trace/templates/gdn.py`:
- Around line 153-157: The Tensor schema for the "output" entries in
flashinfer/trace/templates/gdn.py currently uses dtype_from="q" but the
implementation always casts outputs to torch.bfloat16; update the schema to
reflect the real emitted dtype by replacing dtype_from="q" with dtype="bfloat16"
for the "output" Tensor declarations (the entries named "output" in the
template), or alternatively model the runtime control explicitly if outputs can
vary; make the same change for the other "output" Tensor occurrences mentioned
so the trace metadata matches the torch.bfloat16 casts in the code.
- Around line 382-415: The function mutates per-example head states in state_HVK
but returns final_state built from the unchanged initial_state; fix by writing
the updated state_HVK back into final_state before returning. After the outer
loops (or just before return), clone initial_state into final_state as done now
and then for each b_idx set final_state[state_idx] = state_HVK.transpose(-1, -2)
(or assign the corresponding typed/ device-matched tensor) so the updated
[H,V,K] state for the sample index (state_idx derived from
initial_state_indices[b_idx]) is committed; ensure dtype/device matches
initial_state when assigning.
- Around line 205-206: gdn_prefill_trace currently expands q and k with
repeat_interleave using num_v_heads // num_q_heads and num_v_heads //
num_k_heads but does not validate the required head-ratio constraints; add
explicit checks in gdn_prefill_trace to assert num_v_heads >= num_q_heads and
num_v_heads % num_q_heads == 0 and also assert num_k_heads == num_q_heads (or
otherwise enforce the same constraints used by decode/MTP), and apply the same
fixes to the other expansion site (the block around the q/k/v repeat_interleave
at the later occurrence). Ensure the assertions raise clear errors mentioning
num_v_heads, num_q_heads, and num_k_heads so invalid head layouts are rejected
before repeat_interleave is called.

In `@flashinfer/trace/templates/gemm.py`:
- Around line 57-78: The template misdeclares packed uint8 inputs as logical FP4
shapes causing fi_trace to infer wrong K/N; update the public trace signatures
(or add a pre-trace extractor) so the runtime sees packed dimensions: treat A
and B as [M, K_packed] and [K_packed, N_packed] (or expose an extractor that
maps packed -> logical by doubling the last axis) and propagate corrected
logical axes into mm_fp4_trace before calling _mm_fp4_reference/_unpack_fp4;
apply the same change to the other occurrence around the second block (the
191-200 region) and ensure a_descale/b_descale shape metadata matches the
packed-block layout.
- Around lines 22-35: The reference GEMM helpers transpose B (using .T) even
though B is modeled as the physical [K, N] tensor. Update _mm_reference and
_mm_fp8_reference (and the other similar reference helpers in the file) to
multiply A by B directly, removing the .T and B_fp32.T, while keeping the same
dtype conversions and return types (e.g., _mm_fp8_reference should still
dequantize to float32, matmul, then cast to bfloat16). Also update any
docstrings/comments that incorrectly describe B as needing a transpose.

In `flashinfer/trace/templates/moe.py`:
- Around lines 577-598: The direct attribute assignment
trtllm_fp8_block_scale_moe_trace_dispatch.templates triggers mypy attr-defined
errors. Replace it with a setattr call that attaches the templates list at
runtime (e.g., setattr(trtllm_fp8_block_scale_moe_trace_dispatch, "templates",
list(_MOE_TRACE_BY_ROUTING_TYPE.values()))), keeping the same value and
preserving the _attach_fi_trace registration behavior.

In `tests/trace/test_fi_trace_template_consistency.py`:
- Around lines 193-199: The E2E generator currently assigns 0 to int32 scalars
in the loop over template.inputs, which can create impossible values (e.g.,
block_size=0). Update the assignment in the branch that checks
isinstance(descriptor, Scalar) and calls _resolved_param(json_key, descriptor)
so that int32 defaults are positive (e.g., 1 or another small positive value),
preferably with per-parameter overrides for constrained scalars before
populating kwargs. Keep optional descriptors skipped and preserve the dtype
branch for non-int32 floats so assert_fi_trace_complete() validates realistic
traces.

---

Duplicate comments:
In `flashinfer/trace/templates/moe.py`:
- Around lines 25-27: The code hard-codes H, I, and BLOCK, and these constants
leak into scale expansion, output allocation, and the G1 split, causing
incorrect reshapes/slices for other MoE shapes. Replace them with values
derived dynamically from the model/tensor shapes (e.g., infer hidden_size and
intermediate_size from the input/weight tensors, or pass them as parameters).
Update every use in scale expansion, output allocation, and the G1 split logic
(references: H, I, BLOCK, and the G1 split/expert-output slicing code in
moe.py) to compute sizes at runtime and use them for reshape, split, and slice
operations so traced models with different H/I/BLOCK work correctly.
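One way to derive the sizes is to read them off the tensor shapes at trace time. The sketch below assumes hidden_states is [num_tokens, H] and the fused gate/up (G1) weight is [num_experts, 2*I, H]; both layouts are assumptions and should be adjusted to the real tensors in moe.py:

```python
def infer_moe_sizes(hidden_states_shape, gemm1_weights_shape, scale_block=128):
    """Derive H, I, and the number of scale blocks from tensor shapes.

    Assumes hidden_states is [num_tokens, H] and the fused gate/up (G1)
    weight is [num_experts, 2 * I, H] -- adjust to the real layouts.
    """
    _, hidden_size = hidden_states_shape
    num_experts, two_i, weight_k = gemm1_weights_shape
    assert weight_k == hidden_size, "weight K dim must match hidden size"
    assert two_i % 2 == 0, "fused G1 weight must split evenly into gate/up"
    intermediate_size = two_i // 2
    assert hidden_size % scale_block == 0 and intermediate_size % scale_block == 0
    return hidden_size, intermediate_size, hidden_size // scale_block
```

Computed this way, reshape/split/slice sites can consume the returned sizes instead of module-level constants.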

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 8f771e55-82cf-425e-9083-fba1ef3390e8

📥 Commits

Reviewing files that changed from the base of the PR and between c5296b7 and f7e2129.

📒 Files selected for processing (7)
  • .claude/skills/add-cuda-kernel/SKILL.md
  • flashinfer/api_logging.py
  • flashinfer/trace/templates/attention.py
  • flashinfer/trace/templates/gdn.py
  • flashinfer/trace/templates/gemm.py
  • flashinfer/trace/templates/moe.py
  • tests/trace/test_fi_trace_template_consistency.py

Comment thread flashinfer/trace/templates/attention.py Outdated
Comment on lines +42 to +53
k_flat = k_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)
v_flat = v_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)

for b in range(batch_size):
    page_start = int(kv_indptr[b].item())
    page_end = int(kv_indptr[b + 1].item())
    if page_start >= page_end:
        output[b].zero_()
        continue
    token_ids = kv_indices[page_start:page_end].to(torch.long)
    k_b = k_flat[token_ids]  # [T, num_kv_heads, head_dim]
    v_b = v_flat[token_ids]

⚠️ Potential issue | 🟠 Major

kv_indices are page IDs, but decode reference indexes flattened tokens.

This reference is incorrect when page_size > 1: indexing k_flat/v_flat with page IDs selects wrong rows. Use page gather first, then flatten within selected pages.

Proposed fix
-    k_flat = k_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)
-    v_flat = v_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)
@@
-        token_ids = kv_indices[page_start:page_end].to(torch.long)
-        k_b = k_flat[token_ids]  # [T, num_kv_heads, head_dim]
-        v_b = v_flat[token_ids]
+        page_ids = kv_indices[page_start:page_end].to(torch.long)
+        k_b = k_cache[page_ids].reshape(-1, num_kv_heads, head_dim).to(torch.float32)
+        v_b = v_cache[page_ids].reshape(-1, num_kv_heads, head_dim).to(torch.float32)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `flashinfer/trace/templates/attention.py` around lines 42 - 53, kv_indices
currently represents page IDs, so indexing k_flat/v_flat directly with
kv_indices selects wrong rows when page_size > 1; instead, first gather the
pages from the original k_cache and v_cache using kv_indices (use kv_indices to
index the page dimension of k_cache/v_cache to produce per-token page slices),
then flatten or reshape the gathered per-page tensors into token-level rows and
proceed (so create k_b/v_b by gathering pages via kv_indices from
k_cache/v_cache, then reshape to [T, num_kv_heads, head_dim] before using them
as k_b and v_b); update all uses of k_flat/v_flat and token_ids accordingly and
ensure kv_indptr logic still slices the kv_indices by token count, not flattened
token offsets.
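The page-vs-token distinction is easy to demonstrate with a small NumPy model of a paged KV cache (NumPy stands in for torch here; shapes mirror the reference):

```python
import numpy as np

num_pages, page_size, num_kv_heads, head_dim = 4, 2, 1, 3
k_cache = np.arange(num_pages * page_size * num_kv_heads * head_dim, dtype=np.float32)
k_cache = k_cache.reshape(num_pages, page_size, num_kv_heads, head_dim)

page_ids = np.array([2, 0])  # this request owns pages 2 and 0

# Buggy pattern: flatten to token rows first, then index with *page* IDs.
k_flat = k_cache.reshape(-1, num_kv_heads, head_dim)
wrong = k_flat[page_ids]  # only 2 rows, and they are token rows 2 and 0

# Fixed pattern: gather whole pages, then flatten within the selection.
right = k_cache[page_ids].reshape(-1, num_kv_heads, head_dim)  # 4 token rows
```

With page_size == 1 the two patterns coincide, which is why the bug only shows up for page_size > 1.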

Comment thread flashinfer/trace/templates/attention.py Outdated
Comment on lines +144 to +157
k_flat = k_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)
v_flat = v_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)

for b in range(len_indptr - 1):
    q_start = int(qo_indptr[b].item())
    q_end = int(qo_indptr[b + 1].item())
    kv_start = int(kv_indptr[b].item())
    kv_end = int(kv_indptr[b + 1].item())
    if q_start >= q_end or kv_start >= kv_end:
        continue
    page_ids = kv_indices[kv_start:kv_end].to(torch.long)
    k_b = k_flat[page_ids]
    v_b = v_flat[page_ids]
    num_kv_tokens = page_ids.shape[0]

⚠️ Potential issue | 🟠 Major

Prefill reference has the same page-id/token-id mismatch.

kv_indices are documented as page IDs, but this path indexes a flattened token axis directly. Expand selected pages first, then derive num_kv_tokens from expanded KV rows.

Proposed fix
-    k_flat = k_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)
-    v_flat = v_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)
@@
-        page_ids = kv_indices[kv_start:kv_end].to(torch.long)
-        k_b = k_flat[page_ids]
-        v_b = v_flat[page_ids]
-        num_kv_tokens = page_ids.shape[0]
+        page_ids = kv_indices[kv_start:kv_end].to(torch.long)
+        k_b = k_cache[page_ids].reshape(-1, num_kv_heads, head_dim).to(torch.float32)
+        v_b = v_cache[page_ids].reshape(-1, num_kv_heads, head_dim).to(torch.float32)
+        num_kv_tokens = k_b.shape[0]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `flashinfer/trace/templates/attention.py` around lines 144 - 157, the prefill
path incorrectly treats kv_indices as token indices by indexing k_flat/v_flat
directly with page_ids (kv_indices) and computing num_kv_tokens from
page_ids.shape[0]; instead expand the selected pages to token-level rows first:
use kv_indices[kv_start:kv_end] to select page rows from the full per-page KV
buffer (not the flattened token axis), then concatenate or expand those page
rows into token-level k_b and v_b and compute num_kv_tokens from the resulting
expanded KV token rows; update usages around k_flat, v_flat, page_ids, k_b, v_b
and num_kv_tokens so page->token expansion happens before indexing the flattened
token axis.

Comment thread flashinfer/trace/templates/attention.py Outdated
Comment on lines +359 to +383
Kc_all = ckv_cache.squeeze(1).to(torch.float32)  # [num_pages, head_dim_ckv]
Kp_all = kpe_cache.squeeze(1).to(torch.float32)  # [num_pages, head_dim_kpe]

output = torch.zeros(
    (batch_size, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=q_nope.device
)
lse = torch.full(
    (batch_size, num_qo_heads), -float("inf"), dtype=torch.float32, device=q_nope.device
)

for b in range(batch_size):
    page_beg = int(kv_indptr[b].item())
    page_end = int(kv_indptr[b + 1].item())
    if page_beg >= page_end:
        output[b].zero_()
        continue
    tok_idx = kv_indices[page_beg:page_end].to(torch.long)
    Kc = Kc_all[tok_idx]  # [L, head_dim_ckv]
    Kp = Kp_all[tok_idx]  # [L, head_dim_kpe]
    qn = q_nope[b].to(torch.float32)  # [num_qo_heads, head_dim_ckv]
    qp = q_pe[b].to(torch.float32)  # [num_qo_heads, head_dim_kpe]
    logits = ((qn @ Kc.T) + (qp @ Kp.T)) * sm_scale  # [num_qo_heads, L]
    lse[b] = torch.logsumexp(logits, dim=-1) / math.log(2.0)
    output[b] = (torch.softmax(logits, dim=-1) @ Kc).to(torch.bfloat16)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

# 1) Locate all MLA paged wrappers/templates and check whether page_size is constrained to 1.
rg -nP --type=py '\bmla_paged\b|\bpage_size\b|\bkv_indices\b|\bckv_cache\b|\bkpe_cache\b' -C2

# 2) Find explicit invariants/assertions for page_size==1 (if any).
rg -nP --type=py 'page_size\s*==\s*1|assert\s+.*page_size|squeeze\(\s*1\s*\)' -C2

# 3) Inspect fi_trace tests/fixtures for MLA coverage and page_size values used.
rg -nP --type=py --type=json 'mla_paged|page_size' tests -C2

Repository: flashinfer-ai/flashinfer

Length of output: 50381


🏁 Script executed:

# Examine the specific file and lines mentioned in the review
head -n 520 flashinfer/trace/templates/attention.py | tail -n 180

Repository: flashinfer-ai/flashinfer

Length of output: 6896


🏁 Script executed:

# Check for schema constraints on page_size for MLA
rg -nP 'page_size.*==.*1|page_size.*constraint|mla.*page_size|page_size.*schema' flashinfer -A3 -B3

Repository: flashinfer-ai/flashinfer

Length of output: 6949


🏁 Script executed:

# Look for comments or documentation about MLA page_size limitations
rg -nP 'MLA|mla_paged|mla' flashinfer/trace/templates/ -A5 -B5 | head -n 200

Repository: flashinfer-ai/flashinfer

Length of output: 6947


🏁 Script executed:

# Check how the reference functions are used
rg -n "reference.*mla_paged|_mla_paged.*reference" flashinfer/trace/templates/attention.py -A2 -B2

Repository: flashinfer-ai/flashinfer

Length of output: 620


🏁 Script executed:

# Look at test_fi_trace.py to see if it uses the reference implementations
grep -n "reference\|fi_trace.*mla" tests/test_fi_trace.py | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if TraceTemplate reference functions are ever called
rg -n "\.reference\(\|run_reference" flashinfer/trace -A3 -B3

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check how TraceTemplate is used and if reference is called
rg -n "class TraceTemplate|def.*reference" flashinfer/trace -A5 | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 6857


🏁 Script executed:

# Look at how fi_trace is implemented
find flashinfer -name "*.py" -exec grep -l "fi_trace" {} \; | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 270


🏁 Script executed:

# Check if there's validation code that calls the reference implementations
rg -n "reference" flashinfer/trace/core.py -A3 -B3

Repository: flashinfer-ai/flashinfer

Length of output: 132


🏁 Script executed:

# Check the fi_trace.py file
cat -n flashinfer/fi_trace.py | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 4308


🏁 Script executed:

# Check the template.py file for TraceTemplate
grep -n "class TraceTemplate\|reference" flashinfer/trace/template.py | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 476


🏁 Script executed:

# Look at how reference is used in tests
grep -n "reference\|\.run\|\.plan" tests/test_fi_trace.py -A2 -B2 | head -80

Repository: flashinfer-ai/flashinfer

Length of output: 2127


🏁 Script executed:

# Check if reference functions are ever actually invoked/called
rg -n "template\.reference\|_mla_paged_decode_reference\|_mla_paged_prefill_reference" flashinfer -A3

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Look at the actual CUDA/kernel implementations to see if they have page_size constraints
find flashinfer/mla -name "*.py" -o -name "*.cu" | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 316


🏁 Script executed:

# Check the cute_dsl MLA implementation for page_size constraints
grep -n "page_size.*1\|squeeze" flashinfer/mla/cute_dsl/mla_decode_fp16.py | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 318


🏁 Script executed:

# Check what the actual MLA wrapper implementation looks like
grep -n "class.*MLA\|def.*run\|page_size" flashinfer/mla/_core.py | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 1405


🏁 Script executed:

# Look at tests to see if they verify correctness with page_size != 1
grep -n "page_size.*64\|page_size.*16\|test_mla" tests/test_fi_trace.py -B5 -A10

Repository: flashinfer-ai/flashinfer

Length of output: 4860


🏁 Script executed:

# Check if there are actual runtime tests for MLA with different page sizes
rg -n "BatchMLAPagedAttentionWrapper\|mla_paged" tests/ -A10 | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if reference implementations are actually called during testing or validation
rg -n "reference\(" flashinfer -A5 | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 3161


🏁 Script executed:

# Look at the actual wrapper implementation to see how it handles different page_size
sed -n '442,550p' flashinfer/mla/_core.py

Repository: flashinfer-ai/flashinfer

Length of output: 4684


🏁 Script executed:

# Check if squeeze is used anywhere in actual MLA implementations
rg -n "squeeze" flashinfer/mla/_core.py

Repository: flashinfer-ai/flashinfer

Length of output: 213


🏁 Script executed:

# Check if the reference implementations are actually ever invoked during normal operation
rg -n "\.reference\|call.*reference" flashinfer -A5 -B5

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Verify that actual kernels handle arbitrary page_size
grep -n "page_size\|ckv_cache\|kpe_cache" flashinfer/mla/_core.py | grep -A3 -B3 "cutlass\|cuda"

Repository: flashinfer-ai/flashinfer

Length of output: 708


🏁 Script executed:

# Check what the cute_dsl actually does with page_size
head -n 100 flashinfer/mla/cute_dsl/mla_decode_fp16.py | grep -n "page_size"

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if reference implementations are embedded in JSON or used elsewhere
grep -rn "reference" flashinfer/trace/template.py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 451


🏁 Script executed:

# Look for how flashinfer-bench uses the reference
find . -name "*bench*" -type f | grep -E "\.py$|\.json$" | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 515


🏁 Script executed:

# Verify that tests with page_size=64 actually pass/work
grep -A 20 "def test_mla_paged_fi_trace" tests/test_fi_trace.py | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 917


Reference implementations for MLA paged templates assume single-token pages but schema permits arbitrary page_size.

The _mla_paged_decode_reference() and _mla_paged_prefill_reference() functions use squeeze(1) to remove the page dimension, which is only valid when page_size == 1. However, the TraceTemplate schema defines page_size as a generic Const() without constraining it to 1. These reference implementations are embedded in the JSON trace output and used by flashinfer-bench for validation.

Tests already use page_size=64 (test_fi_trace.py:253, test_fi_trace.py:419), which would cause the reference implementations to fail or produce incorrect results.

Either enforce page_size == 1 as a schema constraint, or rewrite the reference implementations to properly flatten pages and tokens (e.g., Kc_all = ckv_cache.reshape(-1, head_dim_ckv) instead of squeeze) before the indexing operations.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `flashinfer/trace/templates/attention.py` around lines 359 - 383, the
reference implementations _mla_paged_decode_reference and
_mla_paged_prefill_reference assume page_size==1 by calling ckv_cache.squeeze(1)
and kpe_cache.squeeze(1); instead update these functions to flatten the page and
token dimensions so arbitrary page_size works (e.g., replace squeeze(1) with a
reshape/flatten to (-1, head_dim_ckv) for Kc_all and (-1, head_dim_kpe) for
Kp_all or use flatten(0,1)), ensuring subsequent indexing via kv_indices still
selects the correct token rows; alternatively, if you prefer to keep the current
code, enforce page_size==1 in the TraceTemplate schema, but do not leave
squeeze(1) as-is.
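The difference is easy to see with NumPy standing in for torch: squeezing the page axis only works when page_size == 1, while a reshape flattens any page size into token rows:

```python
import numpy as np

num_pages, page_size, head_dim_ckv = 3, 4, 8
ckv_cache = np.zeros((num_pages, page_size, head_dim_ckv), dtype=np.float32)

# squeeze(axis=1) requires the page axis to be exactly 1:
try:
    ckv_cache.squeeze(1)
except ValueError:
    pass  # page_size == 4, so the axis cannot be squeezed

# reshape handles any page_size by flattening pages into token rows:
Kc_all = ckv_cache.reshape(-1, head_dim_ckv)  # [num_pages * page_size, head_dim_ckv]
```

With the reshape form, kv_indices must then be interpreted per-page (as in the page-gather fixes above), not per-token.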

Comment on lines +153 to +157
"output": Tensor(
["batch_size", "seq_len", "num_v_heads", "head_size"],
dtype_from="q",
description="Attention output. Shape follows num_v_heads in GVA mode.",
),

⚠️ Potential issue | 🟠 Major

The templates report output as dtype_from="q", but the references always emit bfloat16.

Lines 91, 208-210, and 377-379 cast the output tensors to torch.bfloat16, so these schemas become wrong as soon as q is not already bfloat16. The trace metadata should either fix the dtype to bfloat16 or model the real output-dtype control explicitly.

Suggested fix
-            dtype_from="q",
+            dtype="bfloat16",
@@
-            dtype_from="q",
+            dtype="bfloat16",
@@
-            dtype_from="q",
+            dtype="bfloat16",

Also applies to: 321-325, 486-490

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `flashinfer/trace/templates/gdn.py` around lines 153 - 157, the Tensor schema
for the "output" entries in flashinfer/trace/templates/gdn.py currently uses
dtype_from="q" but the implementation always casts outputs to torch.bfloat16;
update the schema to reflect the real emitted dtype by replacing dtype_from="q"
with dtype="bfloat16" for the "output" Tensor declarations (the entries named
"output" in the template), or alternatively model the runtime control explicitly
if outputs can vary; make the same change for the other "output" Tensor
occurrences mentioned so the trace metadata matches the torch.bfloat16 casts in
the code.

Comment thread flashinfer/trace/templates/gemm.py Outdated
Comment on lines +22 to +35
def _mm_reference(A, B):
    return torch.matmul(A, B.T)


def _mm_fp8_reference(A, B):
    """Dequantize FP8 block-scale inputs and compute C = A @ B.T.

    B is in TRT-LLM block layout [K//block_size, N, block_size] and is
    reshaped to [K, N] before the matmul.
    """
    K_div_bs, N, block_size = B.shape
    B_fp32 = B.reshape(K_div_bs * block_size, N).to(torch.float32)
    A_fp32 = A.to(torch.float32)
    return torch.matmul(A_fp32, B_fp32.T).to(torch.bfloat16)

⚠️ Potential issue | 🟠 Major

The GEMM references are transposing B after already modeling it as physical [K, N].

For mm_bf16 this is shape-invalid as soon as N != K, and the quantized helpers have the same problem after dequantization. Given these templates describe b as the physical [K, N] tensor, the reference path should multiply by B directly.

Suggested fix
 def _mm_reference(A, B):
-    return torch.matmul(A, B.T)
+    return torch.matmul(A, B)
@@
-    return torch.matmul(A_fp32, B_fp32.T).to(torch.bfloat16)
+    return torch.matmul(A_fp32, B_fp32).to(torch.bfloat16)
@@
-    return torch.matmul(A_scaled, B_scaled.T).to(torch.bfloat16)
+    return torch.matmul(A_scaled, B_scaled).to(torch.bfloat16)
@@
-    return torch.matmul(A_scaled, B_scaled.T).to(torch.bfloat16)
+    return torch.matmul(A_scaled, B_scaled).to(torch.bfloat16)

Also applies to: 38-55, 57-85

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `flashinfer/trace/templates/gemm.py` around lines 22 - 35, the reference GEMM
helpers currently transpose B (using .T) even though B is modeled as the
physical [K, N] tensor; update _mm_reference and _mm_fp8_reference (and the
other similar reference helpers in the file) to multiply A by B directly (remove
the .T and B_fp32.T), keeping the same dtype conversions and return types (e.g.,
_mm_fp8_reference should still dequantize to float32, matmul, then cast to
bfloat16), and update any docstrings/comments that incorrectly describe B as
needing transpose.

Comment on lines +57 to +78
def _mm_fp4_reference(A, B, a_descale, b_descale, block_size=16):
    """Dequantize FP4 inputs and compute C = A @ B.T.

    A and B are fp4 e2m1fn values packed two-per-byte as uint8.
    a_descale: [M, K//block_size], b_descale: [K, N//block_size].
    The reference unpacks the nibbles and applies the block scales.
    """
    def _unpack_fp4(packed, rows, cols):
        # Each byte holds two fp4 nibbles (low nibble = first element).
        lo = (packed & 0x0F).to(torch.float32)
        hi = ((packed >> 4) & 0x0F).to(torch.float32)
        # Interleave low/high nibbles along the last dimension.
        out = torch.stack([lo, hi], dim=-1).reshape(rows, cols)
        return out

    M, K_packed = A.shape
    K = K_packed * 2
    _, N_packed = B.shape
    N = N_packed * 2

    A_fp32 = _unpack_fp4(A, M, K)
    B_fp32 = _unpack_fp4(B, K, N)

⚠️ Potential issue | 🟠 Major

mm_fp4_trace cannot infer the right logical axes from packed inputs.

Lines 72-78 make it clear the runtime tensors are packed uint8 shapes, but the template still declares a and b as [M, K] and [K, N]. fi_trace will therefore report halved or conflicting K/N values for real FP4 calls. This needs packed-dimension axes or a custom extractor before the public API can emit correct runtime traces.

Also applies to: 191-200

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `flashinfer/trace/templates/gemm.py` around lines 57 - 78, the template
misdeclares packed uint8 inputs as logical FP4 shapes causing fi_trace to infer
wrong K/N; update the public trace signatures (or add a pre-trace extractor) so
the runtime sees packed dimensions: treat A and B as [M, K_packed] and
[K_packed, N_packed] (or expose an extractor that maps packed -> logical by
doubling the last axis) and propagate corrected logical axes into mm_fp4_trace
before calling _mm_fp4_reference/_unpack_fp4; apply the same change to the other
occurrence around the second block (the 191-200 region) and ensure
a_descale/b_descale shape metadata matches the packed-block layout.
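The packed-to-logical mapping can be sanity-checked with a small NumPy unpacker mirroring `_unpack_fp4` (the 4-bit values are kept as raw codes here rather than decoded e2m1 floats, which is all the axis question needs):

```python
import numpy as np

def unpack_fp4_codes(packed: np.ndarray) -> np.ndarray:
    """Expand each uint8 byte into two 4-bit codes (low nibble first)."""
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    return np.stack([lo, hi], axis=-1).reshape(packed.shape[0], -1)

A_packed = np.array([[0x21, 0x43]], dtype=np.uint8)  # shape [M=1, K_packed=2]
codes = unpack_fp4_codes(A_packed)  # shape [1, K_packed * 2] -- logical K doubles
```

Because the last axis doubles on unpack, a template that reads K from the packed tensor's shape reports half the logical K unless the schema models the packed dimension explicitly.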

Comment thread flashinfer/trace/templates/moe.py
Comment on lines +577 to +598
def trtllm_fp8_block_scale_moe_trace_dispatch(**kwargs):
    """Return the appropriate TraceTemplate for the given ``routing_method_type``.

    Pass this as ``trace=trtllm_fp8_block_scale_moe_trace_dispatch`` to
    ``@flashinfer_api`` so the correct template is selected at call time::

        @flashinfer_api(trace=trtllm_fp8_block_scale_moe_trace_dispatch)
        def trtllm_fp8_block_scale_moe(..., routing_method_type: int = 0, ...):
            ...

    Returns ``None`` for ``RoutingMethodType.Unspecified`` (6), which
    suppresses trace generation.
    """
    routing_method_type = int(kwargs.get("routing_method_type", 0))
    return _MOE_TRACE_BY_ROUTING_TYPE.get(routing_method_type)


# Expose all possible templates so _attach_fi_trace can auto-register them
# in _TRACE_REGISTRY for consistency testing.
trtllm_fp8_block_scale_moe_trace_dispatch.templates = list(
    _MOE_TRACE_BY_ROUTING_TYPE.values()
)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's examine the file to see the current state
cat -n flashinfer/trace/templates/moe.py | sed -n '570,605p'

Repository: flashinfer-ai/flashinfer

Length of output: 1461


🏁 Script executed:

# Check if there's a mypy configuration or pre-commit hooks
fd -type f -name ".pre-commit*" -o -name "mypy.ini" -o -name "pyproject.toml" | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 238


🏁 Script executed:

# Look for mypy configuration in pyproject.toml or setup.cfg
fd -type f \( -name "pyproject.toml" -o -name "setup.cfg" \) -exec head -50 {} +

Repository: flashinfer-ai/flashinfer

Length of output: 238


🏁 Script executed:

find . -maxdepth 2 -type f \( -name "pyproject.toml" -o -name ".pre-commit-config.yaml" -o -name "setup.cfg" \) 2>/dev/null

Repository: flashinfer-ai/flashinfer

Length of output: 182


🏁 Script executed:

# Search for mypy configuration
rg -l "mypy|attr-defined" --type toml --type yaml --type cfg 2>/dev/null | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if there are any comments or errors mentioning attr-defined in the moe.py file
rg "attr-defined|setattr" flashinfer/trace/templates/moe.py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if setattr is used elsewhere in the codebase for similar purposes
rg "setattr\(" flashinfer/ -A 2 | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

cat .pre-commit-config.yaml

Repository: flashinfer-ai/flashinfer

Length of output: 1722


🏁 Script executed:

cat pyproject.toml

Repository: flashinfer-ai/flashinfer

Length of output: 3158


🏁 Script executed:

# Check if there are any .pyi stub files for this module
find flashinfer/trace -name "*.pyi" 2>/dev/null

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Look for Protocol or TypedDict definitions related to function templates
rg "Protocol|TypedDict" flashinfer/trace/ --type python | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 98


🏁 Script executed:

# Check if _attach_fi_trace function expects templates attribute
rg "_attach_fi_trace" flashinfer/ -B 5 -A 10 | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 3287


🏁 Script executed:

# Check for other similar patterns where attributes are assigned to functions
rg "^\w+\.\w+\s*=\s*" flashinfer/trace/templates/ -A 0 | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 161


🏁 Script executed:

# Verify the exact mypy behavior - look at the api_logging.py more carefully
cat -n flashinfer/api_logging.py | sed -n '1,100p'

Repository: flashinfer-ai/flashinfer

Length of output: 4356


🏁 Script executed:

# Look for any existing patterns where setattr might be used or recommended
rg "templates\s*=" flashinfer/ --type py -B 2 -A 2

Repository: flashinfer-ai/flashinfer

Length of output: 460


🏁 Script executed:

# Check if there are any mypy: ignore comments or type: ignore in moe.py
rg "mypy|type.*ignore" flashinfer/trace/templates/moe.py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Look for the function signature and what it returns
cat -n flashinfer/trace/templates/moe.py | sed -n '577,592p'

Repository: flashinfer-ai/flashinfer

Length of output: 881


🏁 Script executed:

# Check what TraceTemplate is to understand the type signature better
rg "class TraceTemplate|def TraceTemplate" flashinfer/ --type py -A 5 | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 380


🏁 Script executed:

# Verify the fix would work by checking how setattr is commonly used in the codebase
rg "setattr" flashinfer/ --type py -B 1 -A 1 | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 356


🏁 Script executed:

# Find where the .templates attribute is accessed
rg "\.templates" flashinfer/ --type py -B 2 -A 2

Repository: flashinfer-ai/flashinfer

Length of output: 4375


🏁 Script executed:

# Check if there's any hasattr check
rg "hasattr.*templates" flashinfer/ --type py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Look at the complete _attach_fi_trace function to understand how templates is used
rg "_attach_fi_trace" flashinfer/api_logging.py -A 50 | head -80

Repository: flashinfer-ai/flashinfer

Length of output: 3908


Use setattr() to avoid mypy's attr-defined error.

The direct assignment at line 596 triggers mypy's strict mode because the function object has no declared templates attribute. Using setattr() preserves identical runtime behavior while satisfying type checking requirements, which is necessary for pre-commit validation.

Suggested fix
-trtllm_fp8_block_scale_moe_trace_dispatch.templates = list(
-    _MOE_TRACE_BY_ROUTING_TYPE.values()
-)
+setattr(
+    trtllm_fp8_block_scale_moe_trace_dispatch,
+    "templates",
+    list(_MOE_TRACE_BY_ROUTING_TYPE.values()),
+)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

def trtllm_fp8_block_scale_moe_trace_dispatch(**kwargs):
    """Return the appropriate TraceTemplate for the given ``routing_method_type``.

    Pass this as ``trace=trtllm_fp8_block_scale_moe_trace_dispatch`` to
    ``@flashinfer_api`` so the correct template is selected at call time::

        @flashinfer_api(trace=trtllm_fp8_block_scale_moe_trace_dispatch)
        def trtllm_fp8_block_scale_moe(..., routing_method_type: int = 0, ...):
            ...

    Returns ``None`` for ``RoutingMethodType.Unspecified`` (6), which
    suppresses trace generation.
    """
    routing_method_type = int(kwargs.get("routing_method_type", 0))
    return _MOE_TRACE_BY_ROUTING_TYPE.get(routing_method_type)


# Expose all possible templates so _attach_fi_trace can auto-register them
# in _TRACE_REGISTRY for consistency testing.
setattr(
    trtllm_fp8_block_scale_moe_trace_dispatch,
    "templates",
    list(_MOE_TRACE_BY_ROUTING_TYPE.values()),
)
🧰 Tools
🪛 GitHub Actions: pre-commit

[error] 596-596: mypy error: "Callable[[KwArg(Any)], Any]" has no attribute "templates" [attr-defined]

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `flashinfer/trace/templates/moe.py` around lines 577 - 598, the direct
attribute assignment trtllm_fp8_block_scale_moe_trace_dispatch.templates causes
mypy attr-defined errors; replace that assignment with a setattr call to attach
the templates list at runtime (e.g., use
setattr(trtllm_fp8_block_scale_moe_trace_dispatch, "templates",
list(_MOE_TRACE_BY_ROUTING_TYPE.values()))), keeping the same value (list of
_MOE_TRACE_BY_ROUTING_TYPE.values()) and preserving behavior for
_attach_fi_trace registration.

Comment on lines +193 to +199
for json_key, descriptor in template.inputs.items():
if isinstance(descriptor, Scalar):
if descriptor.optional:
continue
p = _resolved_param(json_key, descriptor)
kwargs[p] = 0 if descriptor.dtype == "int32" else 1.0


⚠️ Potential issue | 🟠 Major

The E2E generator should not synthesize int32 inputs as zero.

This currently creates impossible definitions such as block_size=0, so assert_fi_trace_complete() can pass on traces that would never be valid at runtime. Use a positive default, ideally with per-parameter overrides for constrained scalars.

Suggested fix
+_INT_SAMPLE_DEFAULTS = {
+    "block_size": 16,
+    "top_k": 1,
+    "n_group": 1,
+    "topk_group": 1,
+}
+
@@
         if isinstance(descriptor, Scalar):
             if descriptor.optional:
                 continue
             p = _resolved_param(json_key, descriptor)
-            kwargs[p] = 0 if descriptor.dtype == "int32" else 1.0
+            if descriptor.dtype == "int32":
+                kwargs[p] = _INT_SAMPLE_DEFAULTS.get(p, axis_size)
+            else:
+                kwargs[p] = 1.0
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/test_fi_trace_template_consistency.py` around lines 193 - 199,
The test E2E generator currently assigns 0 for int32 scalars in the loop over
template.inputs which can create impossible values (e.g., block_size=0); update
the assignment in the loop that inspects isinstance(descriptor, Scalar) and uses
_resolved_param(json_key, descriptor) so that int32 defaults are positive (e.g.,
1 or another small positive) and preferably support per-parameter overrides for
constrained scalars before populating kwargs; ensure any change keeps optional
descriptors skipped and preserves the dtype branch for non-int32 floats, so
assert_fi_trace_complete() validates realistic traces.
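The per-parameter override scheme suggested above can be sketched in isolation (the names `_INT_SAMPLE_DEFAULTS` and `sample_scalar` are illustrative helpers, not part of the FlashInfer codebase):

```python
# Hypothetical sketch of the suggested fix: constrained int32 scalars get
# known-good positive defaults; anything else falls back to a generic
# positive axis size, so a value like block_size=0 can never be emitted.
_INT_SAMPLE_DEFAULTS = {
    "block_size": 16,
    "top_k": 1,
    "n_group": 1,
    "topk_group": 1,
}

def sample_scalar(param_name: str, dtype: str, axis_size: int = 8):
    """Synthesize a plausible placeholder value for a Scalar input."""
    if dtype == "int32":
        return _INT_SAMPLE_DEFAULTS.get(param_name, axis_size)
    return 1.0
```

Any trace produced this way satisfies basic positivity constraints, so a completeness check cannot be fooled by an impossible definition.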

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 8

♻️ Duplicate comments (12)
flashinfer/trace/templates/gdn.py (3)

165-169: ⚠️ Potential issue | 🟠 Major

Report output as bfloat16 in the schema.

All three references cast their outputs to torch.bfloat16, so inheriting dtype from q makes the trace metadata wrong whenever callers use another input dtype.

Also applies to: 351-355, 537-541

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/gdn.py` around lines 165 - 169, The schema
currently sets the attention "output" Tensor using dtype_from="q", which
misreports dtype because outputs are cast to torch.bfloat16; update the Tensor
definition for the "output" field in the GDN templates to use an explicit dtype
of "bfloat16" (replace dtype_from="q" with dtype="bfloat16") for the occurrences
around the shown block and the other two occurrences (near lines 351-355 and
537-541) so the trace metadata correctly reflects torch.bfloat16 outputs.

421-458: ⚠️ Potential issue | 🟠 Major

Persist the updated pooled state before returning.

state_HVK is updated for every token, but final_state is cloned from initial_state after the loop and never receives those updates. The returned final_state is therefore stale, and the generated JSON fixture will be stale too.

Suggested fix
-    for b_idx in range(B):
+    final_state = initial_state.clone()
+    for b_idx in range(B):
         state_idx = int(initial_state_indices[b_idx].item())
         state_HVK = (
             initial_state[state_idx].clone().float().transpose(-1, -2)
         )  # [H,V,K] -> [H,K,V]
@@
             if cache_intermediate:
                 intermediate_states_buffer[state_idx, t] = state_HVK.transpose(
                     -1, -2
                 )  # [H,K,V] -> [H,V,K]
-
-    final_state = initial_state.clone()
+        final_state[state_idx] = state_HVK.transpose(-1, -2)
     return output, final_state
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/gdn.py` around lines 421 - 458, The loop updates
state_HVK per batch/state but final_state is created from initial_state and
never updated, so return value is stale; update final_state with the pooled
(transposed) state_HVK for each corresponding state index
(initial_state_indices) after finishing updates for that state (or after the
outer loops) so final_state[state_idx] = state_HVK.transpose(-1, -2) (match the
same [H,V,K] ↔ [H,K,V] orientation used for initial_state/state_HVK) before
returning output and final_state.

362-365: ⚠️ Potential issue | 🟠 Major

gdn_prefill_trace needs the same head-ratio constraints as the other GDN templates.

The reference divides by num_v_heads // num_q_heads and num_v_heads // num_k_heads, but this template currently accepts layouts that make those expansions invalid.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/gdn.py` around lines 362 - 365, The
gdn_prefill_trace template is missing head-ratio validity checks: add
constraints ensuring num_v_heads is divisible by num_q_heads and by num_k_heads
(e.g., num_v_heads % num_q_heads == 0 and num_v_heads % num_k_heads == 0) so the
downstream divisions (num_v_heads // num_q_heads and num_v_heads // num_k_heads)
used elsewhere are valid; update the constraints list in gdn_prefill_trace to
include these checks referencing the variables num_v_heads, num_q_heads, and
num_k_heads.
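The divisibility checks asked for here can be expressed as a small standalone predicate (a sketch only; the real template would encode these as constraint metadata on `gdn_prefill_trace`):

```python
def check_gdn_head_ratios(num_v_heads: int, num_q_heads: int, num_k_heads: int) -> bool:
    """Return True when the GQA-style head expansions used by the reference
    (num_v_heads // num_q_heads and num_v_heads // num_k_heads) are exact,
    i.e. no heads would be silently dropped by integer division."""
    return num_v_heads % num_q_heads == 0 and num_v_heads % num_k_heads == 0
```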
flashinfer/trace/templates/gemm.py (2)

180-217: ⚠️ Potential issue | 🟠 Major

The FP4 trace schema still advertises unpacked shapes.

A and B are packed uint8 tensors at runtime, so exposing them as [M, K] and [K, N] makes fi_trace infer the wrong dimensions for real FP4 calls. Model the packed axes explicitly or add an extractor that maps packed sizes back to logical K/N.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/gemm.py` around lines 180 - 217, The mm_fp4_trace
TraceTemplate currently lists inputs "A" and "B" with unpacked shapes ["M","K"]
and ["K","N"], but at runtime these are packed uint8 FP4 buffers; update
mm_fp4_trace so the Tensor entries for "A" and "B" describe the packed axes
(e.g., K_packed/K_block or bytes per packed row) or add an extractor that
converts the packed dimensions back to logical K and N (use the existing
"block_size" Var/Scalar to compute K//block_size and N//block_size);
specifically modify the Tensor definitions for "A" and "B" in mm_fp4_trace (and
any related axis defs such as "K" or "N") so fi_trace will infer correct runtime
shapes for FP4-packed inputs.
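The extractor idea can be sketched as a pure shape mapping (illustrative names; it assumes two 4-bit values are packed per uint8 byte, which is the usual FP4 packing):

```python
def logical_fp4_shape(packed_shape, elems_per_byte: int = 2):
    """Map a packed uint8 FP4 operand shape [M, K_packed] back to the
    logical [M, K] the GEMM actually computes over, so fi_trace can
    report the runtime K rather than the byte count."""
    m, k_packed = packed_shape
    return m, k_packed * elems_per_byte
```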

22-35: ⚠️ Potential issue | 🟠 Major

Multiply by the physical [K, N] weight matrix in these references.

Each template models B as a physical [K, N] tensor, but the references all call ... @ B.T. That breaks mm_bf16 as soon as N != K and skews the quantized references the same way.

Suggested fix
 def _mm_reference(A, B):
-    return torch.matmul(A, B.T)
+    return torch.matmul(A, B)
@@
-    return torch.matmul(A_fp32, B_fp32.T).to(torch.bfloat16)
+    return torch.matmul(A_fp32, B_fp32).to(torch.bfloat16)
@@
-    return torch.matmul(A_scaled, B_scaled.T).to(torch.bfloat16)
+    return torch.matmul(A_scaled, B_scaled).to(torch.bfloat16)
@@
-    return torch.matmul(A_scaled, B_scaled.T).to(torch.bfloat16)
+    return torch.matmul(A_scaled, B_scaled).to(torch.bfloat16)

Also applies to: 38-55, 57-86

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/gemm.py` around lines 22 - 35, The reference
implementations currently multiply by B.T, but B represents the physical [K, N]
weight matrix so using B.T swaps dims and breaks cases where N != K; update
_mm_reference to compute torch.matmul(A, B) (not A @ B.T), and in
_mm_fp8_reference reshape B into [K, N] (B_fp32 = B.reshape(K_div_bs *
block_size, N)) and use torch.matmul(A_fp32, B_fp32) (remove the trailing .T),
applying the same fix to the other reference helpers mentioned (the FP8 and bf16
variants in the file).
flashinfer/trace/templates/attention.py (3)

140-169: ⚠️ Potential issue | 🟠 Major

Expand selected pages before applying the prefill causal window.

The reference currently indexes k_flat/v_flat with page ids and sets num_kv_tokens = page_ids.shape[0], so both the causal window and the gathered KV tensors are off by page_size for real paged caches.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/attention.py` around lines 140 - 169, The code is
indexing k_flat/v_flat by page_ids and treating num_kv_tokens as
page_ids.shape[0], which is incorrect for paged caches; you must expand page
indices to per-token indices before building k_b/v_b and computing num_kv_tokens
so the causal window and gathers operate at token granularity. Change the gather
so that page_ids are multiplied/expanded by page_size into token_indices (e.g.
token_indices = page_ids.unsqueeze(1)*page_size + torch.arange(page_size,
device=...)) and then use those token_indices to index the original
k_flat/v_flat (or reshape k_cache/v_cache into per-token and gather by
token_indices) so k_b/v_b contain all tokens from the selected pages, set
num_kv_tokens = token_indices.numel() (or actual token count if last page
partial), and adjust uses of max_kv, delta, and slicing (k_b[:max_kv],
v_b[:max_kv]) accordingly.

357-385: ⚠️ Potential issue | 🟠 Major

These MLA references still assume page_size == 1.

Both paths call squeeze(1) on paged caches, but the schema accepts arbitrary page_size and the tests already use larger values like 64. Flatten the page/token dimensions or constrain the template to single-token pages.

Also applies to: 476-518

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/attention.py` around lines 357 - 385, The code
currently assumes page_size == 1 by calling ckv_cache.squeeze(1) and
kpe_cache.squeeze(1) (variables Kc_all/Kp_all) which collapses the page
dimension; instead merge the page and token dimensions so arbitrary page_size
works: replace the squeeze(1) usage with a reshape/view that flattens the first
two dims (e.g. ckv_cache.reshape(-1, head_dim_ckv) and kpe_cache.reshape(-1,
head_dim_kpe)) and keep using kv_indices[kv_indptr[b]:kv_indptr[b+1]] (tok_idx)
to index into the flattened Kc_all/Kp_all so Kc = Kc_all[tok_idx] and Kp =
Kp_all[tok_idx] work for multi-token pages; apply the same change to the other
block (lines ~476-518) where ckv_cache/kpe_cache are squeezed.

39-58: ⚠️ Potential issue | 🟠 Major

Treat kv_indices as page ids in the decode reference.

kv_indices are documented as page ids, but this code indexes the flattened token buffer with them. That only stays correct when page_size == 1; otherwise the reference gathers the wrong KV rows.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/attention.py` around lines 39 - 58, kv_indices are
page IDs but the code treats them as token indices when building token_ids; fix
by expanding each page id into its page_size token-row indices before indexing
k_flat/v_flat. In the loop over b replace the current token_ids =
kv_indices[page_start:page_end].to(torch.long) with logic that maps each page id
p to the contiguous token index range p*page_size .. (p+1)*page_size-1
(preserving dtype/device), then flatten that to a 1D tensor and use it to build
k_b and v_b so k_b/v_b remain shaped [T, num_kv_heads, head_dim]; keep using
k_flat/v_flat, kv_indptr, kv_indices, page_size, k_b, v_b, token_ids, and ensure
device/torch.long handling remains correct.
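The page-to-token expansion described in these attention findings can be sketched without tensors (a minimal illustration; the real fix would build a torch.long index tensor on the right device):

```python
def expand_page_ids(page_ids, page_size):
    """Expand page ids into contiguous per-token row indices, so a
    flattened [num_pages * page_size, ...] KV buffer can be gathered at
    token granularity instead of page granularity. With page_size == 1
    this degenerates to the identity, which is why the bug was latent."""
    return [p * page_size + t for p in page_ids for t in range(page_size)]
```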
tests/trace/example.py (1)

54-373: ⚠️ Potential issue | 🟠 Major

Pytest won’t collect this trace example.

Everything after the env setup runs as import-time side effects, but the file defines no test_... entrypoint. CI will never exercise trace generation unless this body is moved into a real pytest test and the script entrypoint is kept separate.

As per coding guidelines, tests/**/*.py: Prefix test functions with test_ and structure tests by feature in tests/ subdirectories matching kernel categories.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/example.py` around lines 54 - 373, The file runs the entire
trace-generation body at import time so pytest won't collect it; keep the
environment setup (the FLASHINFER_TRACE_* os.environ lines and SAVE_DIR) and the
imports as-is but move everything that performs work (starting from
device/WORKSPACE and all calls that exercise flashinfer APIs, e.g., the loops
that call flashinfer.rmsnorm, flashinfer.mm_bf16,
BatchDecodeWithPagedKVCacheWrapper.plan/run,
BatchPrefillWithPagedKVCacheWrapper.plan/run,
BatchPrefillWithRaggedKVCacheWrapper.plan/run,
BatchMLAPagedAttentionWrapper.plan/run, flashinfer.gdn_decode.*,
flashinfer.fused_moe.* and the final JSON summary) into a pytest test function
named with a test_ prefix (e.g., test_generate_fi_traces) so CI will execute it;
also add a minimal if __name__ == "__main__" guard to call that function when
run as a script so the example remains runnable standalone.
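The restructuring described above follows a common pattern; this is a self-contained sketch (names such as `generate_traces` and `test_generate_fi_traces`, and the dummy JSON payload, are illustrative stand-ins for the real workload):

```python
import json
import os
import tempfile

def generate_traces(save_dir: str):
    """Stand-in for the example's trace-dumping workload: doing the work in
    a plain function lets both pytest and the script entrypoint drive it."""
    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, "example_trace.json")
    with open(path, "w") as f:
        json.dump({"op": "rmsnorm", "axes": {"hidden_size": 4096}}, f)
    return [path]

def test_generate_fi_traces(tmp_path=None):
    # pytest injects tmp_path; the default keeps the function runnable standalone
    save_dir = str(tmp_path) if tmp_path is not None else tempfile.mkdtemp()
    written = generate_traces(save_dir)
    assert written and all(os.path.exists(p) for p in written)

if __name__ == "__main__":
    # keeps the example usable as a plain script outside pytest
    test_generate_fi_traces()
```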
flashinfer/trace/templates/moe.py (3)

648-652: ⚠️ Potential issue | 🟡 Minor

Use setattr() to avoid the mypy attr-defined error.

The direct attribute assignment triggers mypy's attr-defined error because function objects don't have a declared templates attribute. Use setattr() to preserve runtime behavior while satisfying type checking.

Suggested fix
 # Expose all possible templates so _attach_fi_trace can auto-register them
 # in _TRACE_REGISTRY for consistency testing.
-trtllm_fp8_block_scale_moe_trace_dispatch.templates = list(
-    _MOE_TRACE_BY_ROUTING_TYPE.values()
-)
+setattr(
+    trtllm_fp8_block_scale_moe_trace_dispatch,
+    "templates",
+    list(_MOE_TRACE_BY_ROUTING_TYPE.values()),
+)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/moe.py` around lines 648 - 652, The assignment to
add a dynamic attribute on the function
trtllm_fp8_block_scale_moe_trace_dispatch causes mypy attr-defined errors;
replace the direct assignment with a setattr call so the templates attribute is
attached at runtime (e.g., setattr(trtllm_fp8_block_scale_moe_trace_dispatch,
"templates", list(_MOE_TRACE_BY_ROUTING_TYPE.values()))) to preserve behavior
while satisfying type checking.

25-27: ⚠️ Potential issue | 🟠 Major

Hardcoded H and I constants make reference execution shape-fragile.

The module-level constants H=7168 and I=2048 are used in _fp8_moe_run_experts but the actual hidden_size and intermediate_size can vary. This will produce incorrect results or errors for other valid MoE configurations.

Suggested fix — derive H and I from tensor shapes
-H = 7168
-I = 2048
 BLOCK = 128


 @torch.no_grad()
 def _fp8_moe_run_experts(
     hidden_states,
     hidden_states_scale,
     gemm1_weights,
     gemm1_weights_scale,
     gemm2_weights,
     gemm2_weights_scale,
     weights,
     topk_idx,
     local_expert_offset,
     E_global,
 ):
-    T = hidden_states.shape[0]
+    T, H = hidden_states.shape
+    I = gemm2_weights.shape[2]
     E_local = gemm1_weights.shape[0]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/moe.py` around lines 25 - 27, The module currently
hardcodes H=7168 and I=2048 which breaks _fp8_moe_run_experts for models with
different hidden/intermediate sizes; change the code to derive hidden_size and
intermediate_size at runtime from tensor shapes (e.g., infer hidden_size from
the input/hidden tensor shape[-1] or the expert weight shapes, and infer
intermediate_size from the feedforward weight/output shapes) and replace uses of
H and I with those derived values (also ensure BLOCK is computed/validated
against hidden_size if needed); update all references in _fp8_moe_run_experts to
use the derived variables so the function works for arbitrary MoE shapes.
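The shape-derivation idea can be sketched as a small helper (illustrative name; the assumed layouts, hidden_states [T, H] and gemm2_weights [E_local, H, I], follow the suggested fix above):

```python
def derive_moe_dims(hidden_states_shape, gemm2_weights_shape):
    """Infer (hidden_size, intermediate_size) from tensor shapes instead of
    module-level constants, so the reference works for any MoE config."""
    _, hidden_size = hidden_states_shape
    _, h_check, intermediate_size = gemm2_weights_shape
    if h_check != hidden_size:
        raise ValueError("gemm2_weights hidden dim does not match hidden_states")
    return hidden_size, intermediate_size
```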

126-131: ⚠️ Potential issue | 🟠 Major

Reference implementations hardcode routing parameters that should be configurable.

TOP_K=8, N_GROUP=8, TOPK_GROUP=4 are hardcoded, but the public API accepts these as arguments. If these references are used for numerical validation, they will only be correct for one configuration.

If these references are only for schema validation (not numerical correctness), consider adding a comment to clarify their limited scope.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/moe.py` around lines 126 - 131, The template
hardcodes routing parameters TOP_K, N_GROUP, TOPK_GROUP which conflict with the
public API; update the code that defines TOP_K, N_GROUP, TOPK_GROUP to read the
corresponding function arguments (e.g., top_k, n_group, topk_group) or the
routing parameters object instead of fixed literals so the template matches
whatever configuration is passed via routing_logits' caller (or if these values
are truly only for shape/schema checks, replace the literals with a clarifying
comment near TOP_K/N_GROUP/TOPK_GROUP stating they are placeholder defaults used
only for schema validation and not numeric correctness). Ensure you change the
occurrences of TOP_K, N_GROUP, TOPK_GROUP in this module (referenced with
routing_logits, E_global, T) accordingly.
🧹 Nitpick comments (8)
flashinfer/trace/template.py (3)

473-473: Consider using list unpacking for slightly cleaner syntax.

Ruff suggests [f"fi_api:{fi_api}", *template.tags] instead of list concatenation.

Suggested fix
-            all_tags = [f"fi_api:{fi_api}"] + template.tags
+            all_tags = [f"fi_api:{fi_api}", *template.tags]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/template.py` at line 473, Replace the list concatenation
that builds all_tags with list unpacking for clearer syntax: in the
function/block where variable all_tags is assigned (currently using all_tags =
[f"fi_api:{fi_api}"] + template.tags), change it to construct the list using
[f"fi_api:{fi_api}", *template.tags] so it directly prepends the formatted
fi_api tag to template.tags; keep the same variable name and semantics.

426-443: Auto-infer dtype uses first matching input — document this behavior.

The auto-inference logic selects the dtype from the first input tensor with overlapping dimension names (line 443 break). This is a reasonable heuristic, but if multiple inputs have overlapping dims with different dtypes, the choice is arbitrary. Consider adding a brief inline comment noting this precedence.

Suggested documentation
                     else:
-                        # Auto-infer: find first input tensor with overlapping dims
+                        # Auto-infer: use dtype from first input tensor with overlapping
+                        # dims. If multiple inputs overlap, precedence follows dict order.
                         dtype = "unknown"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/template.py` around lines 426 - 443, The auto-infer branch
in template.py sets dtype to the first matching input's type (looping over
template.inputs, checking Tensor instances and overlapping descriptor.dim_names,
using _get_tensor and _dtype_str, then break), which is arbitrary when multiple
inputs overlap; add a concise inline comment near this logic (around the loop
and the break) stating that this chooses the first matching input's dtype as the
precedence rule and that other overlapping inputs may be ignored, so callers
should avoid ambiguous multiple-dtype overlaps or explicitly provide dtype to
override; keep the comment short and reference template.inputs, Tensor,
descriptor, _get_tensor, and _dtype_str.

370-378: Silent exception swallowing may hide bugs during axis extraction.

The bare except Exception: pass at lines 376-377 silently ignores all errors during axis value extraction. While this provides robustness, it can hide bugs in extractor logic or unexpected input types. Consider at minimum logging at debug level.

Suggested fix
+import logging
+
+_logger = logging.getLogger(__name__)
+
 # In fi_trace function:
             for axis_name, extractor in axis_extractors.items():
                 try:
                     val = extractor(kwargs)
                     if val is not None:
                         axis_values[axis_name] = val
-                except Exception:
-                    pass
+                except Exception as exc:
+                    _logger.debug(
+                        "Axis extraction failed for %s: %s", axis_name, exc
+                    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/template.py` around lines 370 - 378, The code silently
swallows all exceptions when running axis extractor functions (axis_extractors
-> extractor(kwargs) populating axis_values), which hides bugs; change the bare
"except Exception: pass" to catch the exception as e and log it at debug level
(e.g., logger.debug("axis extractor %s failed for kwargs=%s: %s", axis_name,
kwargs, e, exc_info=True)) so failures are recorded but extraction remains
robust, and if there is no existing logger in this module create one via
logging.getLogger(__name__) and import logging.
tests/trace/test_fi_trace_template_consistency.py (4)

399-408: Variable k shadows the tensor k defined earlier.

The loop variable k at line 400 shadows the tensor k defined at line 391. While this doesn't affect correctness (the tensor is no longer needed at this point), it reduces readability.

Suggested fix
     non_optional_unknown = [
-        k
-        for k, v in defn["inputs"].items()
-        if isinstance(v, dict)
-        and v.get("dtype") == "unknown"
-        and not v.get("optional", False)
+        key
+        for key, val in defn["inputs"].items()
+        if isinstance(val, dict)
+        and val.get("dtype") == "unknown"
+        and not val.get("optional", False)
     ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/test_fi_trace_template_consistency.py` around lines 399 - 408,
The loop variable k in the comprehension that builds non_optional_unknown
shadows the tensor named k earlier; rename the loop variable (e.g., to
input_name or inp_key) used in the comprehension and in the f-string so it no
longer collides with the tensor k, updating the comprehension over
defn["inputs"].items() and the f"Non-optional inputs with unknown dtype: {...}"
reference accordingly.

495-496: Use a raw string for the regex pattern.

The pattern contains backslashes and should be a raw string to avoid unintended escapes and satisfy Ruff RUF043.

Suggested fix
-    with pytest.raises(AssertionError, match="param=.*hidden_state.*not found"):
+    with pytest.raises(AssertionError, match=r"param=.*hidden_state.*not found"):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/test_fi_trace_template_consistency.py` around lines 495 - 496,
The regex passed to pytest.raises should be a raw string to avoid accidental
escape sequences; update the call to pytest.raises(AssertionError, match=...)
used around assert_template_signature_consistency(func, broken,
label="meta-test") so the match argument is a raw string literal (prefix it with
r, e.g. r"param=.*hidden_state.*not found") to satisfy Ruff RUF043 and ensure
the pattern is interpreted correctly.

430-460: Rename ambiguous variable I in the MoE routing test.

Ruff flags I as ambiguous (E741). Consider renaming to intermediate or inter_size for clarity.

Suggested fix
-    T, E, EL, H, I, BS = 4, 16, 2, 256, 64, 128
+    T, E, EL, H, INTER, BS = 4, 16, 2, 256, 64, 128
     defn = trtllm_fp8_block_scale_moe.fi_trace(
         routing_logits=torch.zeros(T, E, dtype=torch.float32),
         routing_bias=torch.zeros(E, dtype=torch.bfloat16),
         hidden_states=torch.zeros(T, H, dtype=torch.float8_e4m3fn),
         hidden_states_scale=torch.ones(H // BS, T, dtype=torch.float32),
-        gemm1_weights=torch.zeros(EL, 2 * I, H, dtype=torch.float8_e4m3fn),
-        gemm1_weights_scale=torch.ones(EL, (2 * I) // BS, H // BS, dtype=torch.float32),
-        gemm2_weights=torch.zeros(EL, H, I, dtype=torch.float8_e4m3fn),
-        gemm2_weights_scale=torch.ones(EL, H // BS, I // BS, dtype=torch.float32),
+        gemm1_weights=torch.zeros(EL, 2 * INTER, H, dtype=torch.float8_e4m3fn),
+        gemm1_weights_scale=torch.ones(EL, (2 * INTER) // BS, H // BS, dtype=torch.float32),
+        gemm2_weights=torch.zeros(EL, H, INTER, dtype=torch.float8_e4m3fn),
+        gemm2_weights_scale=torch.ones(EL, H // BS, INTER // BS, dtype=torch.float32),
         num_experts=E,
         top_k=top_k,
-        intermediate_size=I,
+        intermediate_size=INTER,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/test_fi_trace_template_consistency.py` around lines 430 - 460,
The test function test_fi_trace_complete_moe_routing uses a single-letter
variable I (intermediate size) which triggers an ambiguity lint (E741); rename I
to a descriptive identifier (e.g., inter_size or intermediate) and update all
references inside the function and the fi_trace(...) call (intermediate_size=I,
shapes using I, 2 * I etc.) so the values and assertions remain identical but
the variable name is clear and matches usage in
trtllm_fp8_block_scale_moe.fi_trace.

369-370: Rename ambiguous loop variable l to improve readability.

Ruff flags l as ambiguous (E741) because it can be confused with 1. Consider renaming to lbl or label.

Suggested fix
-_E2E_PAIRS = [(f, t, l) for f, t, l in _ALL_PAIRS if l not in _E2E_SKIP]
+_E2E_PAIRS = [(f, t, lbl) for f, t, lbl in _ALL_PAIRS if lbl not in _E2E_SKIP]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/test_fi_trace_template_consistency.py` around lines 369 - 370,
The list comprehension _E2E_PAIRS uses an ambiguous loop variable named "l";
rename it to a clearer identifier (e.g., "label" or "lbl") in the comprehension
and update the subsequent _E2E_IDS comprehension to unpack/use that new name so
both _E2E_PAIRS = [(f, t, label) for f, t, label in _ALL_PAIRS if label not in
_E2E_SKIP] and _E2E_IDS = [label for _, _, label in _E2E_PAIRS] remain
consistent.
flashinfer/trace/templates/moe.py (1)

85-88: Rename ambiguous variable O to improve readability.

Ruff flags O as ambiguous (E741) because it can be confused with 0. Consider renaming to out or output_e.

Suggested fix
-        O = (silu_X2 * X1).matmul(W2[le].t())
+        expert_out = (silu_X2 * X1).matmul(W2[le].t())
         # per-expert contribution weight for each token
         w_tok = weights.index_select(0, token_idx)
         # find which slot in topk_idx[token_idx] corresponds to ge
         match = (topk_idx.index_select(0, token_idx) == ge).float()
         w_e = (w_tok * match).sum(dim=1)
-        output.index_add_(0, token_idx, O * w_e.unsqueeze(1))
+        output.index_add_(0, token_idx, expert_out * w_e.unsqueeze(1))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/moe.py` around lines 85 - 88, Rename the ambiguous
variable O used in the moe attention feedforward block to a clearer name (e.g.,
out or output_e) to avoid confusion with the digit zero; update the assignment
and any subsequent uses where O appears (the expression "(silu_X2 *
X1).matmul(W2[le].t())") and ensure references to G1, X1, X2, silu_X2, W13, W2,
A_e, and le remain correct with the new variable name.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/api_logging.py`:
- Around line 1510-1531: The wrapper created by _attach_fi_trace (and returned
by flashinfer_api) adds runtime cost even when tracing/LOGLEVEL=0; change
_attach_fi_trace so that if tracing is disabled (i.e., _is_trace_dump_enabled()
is False and caller requested zero-overhead) it does not create
&_auto_dump_wrapper but instead attaches fi_trace to the original callable via
setattr(original, "fi_trace", fi_trace_fn) and returns original; otherwise keep
the current wrapper behavior. Also avoid direct attribute assignment on
Callable-typed objects that triggers mypy attr-defined errors by using
setattr(original, "fi_trace", fi_trace_fn) or by casting to Any/creating a small
Protocol for fi_trace to satisfy type-checkers (e.g., cast(original, Any) or
define Protocol with fi_trace) so the pipeline no longer errors.
- Around line 1508-1531: Replace direct attribute assignments to .fi_trace with
setattr to avoid mypy attr-defined errors: where the diff sets wrapped.fi_trace
= fi_trace_fn and _auto_dump_wrapper.fi_trace = fi_trace_fn, change those direct
assignments to use setattr(wrapped, "fi_trace", fi_trace_fn) and
setattr(_auto_dump_wrapper, "fi_trace", fi_trace_fn). Keep the same semantics
(assign the fi_trace_fn callable) and leave other code in _auto_dump_wrapper,
_sig, and fi_trace_fn unchanged.

In `@flashinfer/fi_trace.py`:
- Around line 238-285: The function fi_trace currently types func_or_method as
Callable but relies on the object (actual_func) exposing a .fi_trace attribute;
update the typing to make that contract explicit by introducing a Protocol
(e.g., TracedCallable with a fi_trace(self, save_dir: Optional[Union[str, Path]]
= None, **kwargs) -> Dict[str, Any]) and use that Protocol as the type for
func_or_method (or cast actual_func to TracedCallable before accessing
.fi_trace); ensure the Protocol signature matches how trace_fn is called in
fi_trace and import typing.Protocol and any necessary types so mypy recognizes
the requirement.
- Around line 103-110: The import line bringing in Const, Scalar, Tensor,
TraceTemplate, and Var from .trace.template is unused in build_fi_trace_fn and
causing Ruff F401 warnings; remove those five names (or the whole legacy import
if nothing else from that module is used) so only needed symbols remain imported
in flashinfer/fi_trace.py and eliminate the unused imports Const, Scalar,
Tensor, TraceTemplate, Var from the import statement that currently appears
alongside build_fi_trace_fn.

In `@flashinfer/trace/templates/gdn.py`:
- Around line 502-546: The template schema is missing the disable_state_update
input required by gated_delta_rule_mtp, so add a boolean Tensor/Scalar entry
named "disable_state_update" to the inputs dict (matching how other flags are
represented) and mark it optional or required consistent with
gated_delta_rule_mtp's signature; ensure you reference the symbol name
"disable_state_update" and update the inputs block near the existing
"initial_state"/"final_state" entries so the trace can distinguish
state-updating vs non-updating behavior described in "final_state".

In `@tests/trace/test_fi_trace_template_consistency.py`:
- Around line 384-394: Remove the unused import gqa_paged_decode_trace from the
test; locate the import statement that reads "from
flashinfer.trace.templates.attention import gqa_paged_decode_trace" and delete
it so the test only imports and uses
BatchDecodeWithPagedKVCacheWrapper.run.fi_trace (ensure no other references to
gqa_paged_decode_trace remain in the file).
- Around line 309-321: The import of flashinfer.sampling is unused and flagged
by pre-commit; either remove the import statement for flashinfer.sampling or
ensure it registers decorators used by _TRACE_REGISTRY (so the import has side
effects). Locate the import block containing flashinfer.sampling (near imports
for flashinfer.decode, flashinfer.gdn_decode, etc.) and delete the
flashinfer.sampling line if no decorated functions from that module are expected
to be registered, otherwise import the specific symbols that cause registration
or add a comment explaining the necessary side-effect to avoid removal by
linters.

In `@tests/trace/test_fi_trace.py`:
- Line 20: Remove the unused top-level import of pytest in
tests/trace/test_fi_trace.py: delete the line "import pytest" since the file
relies on pytest fixtures (tmp_path, monkeypatch) provided by pytest's runtime
and does not reference the pytest symbol directly; ensure no other code in the
module uses the pytest name before committing.

---

Duplicate comments:
In `@flashinfer/trace/templates/attention.py`:
- Around line 140-169: The code indexes k_flat/v_flat by page_ids and treats num_kv_tokens as page_ids.shape[0], which is incorrect for paged caches; page indices must be expanded to per-token indices before building k_b/v_b and computing num_kv_tokens, so the causal window and gathers operate at token granularity. Change the gather so that page_ids are expanded by page_size into token_indices (e.g. token_indices = page_ids.unsqueeze(1)*page_size + torch.arange(page_size, device=...)) and use those token_indices to index the original k_flat/v_flat (or reshape k_cache/v_cache into a per-token layout and gather by token_indices) so k_b/v_b contain all tokens from the selected pages. Set num_kv_tokens = token_indices.numel() (or the actual token count if the last page is partial), and adjust the uses of max_kv, delta, and the k_b[:max_kv], v_b[:max_kv] slices accordingly.
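The page-to-token expansion this comment describes can be sketched in plain Python (a torch-free stand-in; `page_ids` and `page_size` mirror the variables named above):

```python
def expand_page_ids_to_token_indices(page_ids, page_size):
    """Expand each page id p into its contiguous token rows
    p*page_size .. (p+1)*page_size - 1, flattened into one index list."""
    return [p * page_size + t for p in page_ids for t in range(page_size)]
```

In the tensor version the same arithmetic becomes `page_ids.unsqueeze(1) * page_size + torch.arange(page_size)`, and `num_kv_tokens` is the length of the result rather than `page_ids.shape[0]`.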
- Around line 357-385: The code currently assumes page_size == 1 by calling
ckv_cache.squeeze(1) and kpe_cache.squeeze(1) (variables Kc_all/Kp_all) which
collapses the page dimension; instead merge the page and token dimensions so
arbitrary page_size works: replace the squeeze(1) usage with a reshape/view that
flattens the first two dims (e.g. ckv_cache.reshape(-1, head_dim_ckv) and
kpe_cache.reshape(-1, head_dim_kpe)) and keep using
kv_indices[kv_indptr[b]:kv_indptr[b+1]] (tok_idx) to index into the flattened
Kc_all/Kp_all so Kc = Kc_all[tok_idx] and Kp = Kp_all[tok_idx] work for
multi-token pages; apply the same change to the other block (lines ~476-518)
where ckv_cache/kpe_cache are squeezed.
- Around line 39-58: kv_indices are page IDs but the code treats them as token
indices when building token_ids; fix by expanding each page id into its
page_size token-row indices before indexing k_flat/v_flat. In the loop over b
replace the current token_ids = kv_indices[page_start:page_end].to(torch.long)
with logic that maps each page id p to the contiguous token index range
p*page_size .. (p+1)*page_size-1 (preserving dtype/device), then flatten that to
a 1D tensor and use it to build k_b and v_b so k_b/v_b remain shaped [T,
num_kv_heads, head_dim]; keep using k_flat/v_flat, kv_indptr, kv_indices,
page_size, k_b, v_b, token_ids, and ensure device/torch.long handling remains
correct.

In `@flashinfer/trace/templates/gdn.py`:
- Around line 165-169: The schema currently sets the attention "output" Tensor
using dtype_from="q", which misreports dtype because outputs are cast to
torch.bfloat16; update the Tensor definition for the "output" field in the GDN
templates to use an explicit dtype of "bfloat16" (replace dtype_from="q" with
dtype="bfloat16") for the occurrences around the shown block and the other two
occurrences (near lines 351-355 and 537-541) so the trace metadata correctly
reflects torch.bfloat16 outputs.
- Around line 421-458: The loop updates state_HVK per batch/state but
final_state is created from initial_state and never updated, so return value is
stale; update final_state with the pooled (transposed) state_HVK for each
corresponding state index (initial_state_indices) after finishing updates for
that state (or after the outer loops) so final_state[state_idx] =
state_HVK.transpose(-1, -2) (match the same [H,V,K] ↔ [H,K,V] orientation used
for initial_state/state_HVK) before returning output and final_state.
- Around line 362-365: The gdn_prefill_trace template is missing head-ratio
validity checks: add constraints ensuring num_v_heads is divisible by
num_q_heads and by num_k_heads (e.g., num_v_heads % num_q_heads == 0 and
num_v_heads % num_k_heads == 0) so the downstream divisions (num_v_heads //
num_q_heads and num_v_heads // num_k_heads) used elsewhere are valid; update the
constraints list in gdn_prefill_trace to include these checks referencing the
variables num_v_heads, num_q_heads, and num_k_heads.
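The requested constraints reduce to two exact-divisibility checks; a minimal sketch (the helper name is hypothetical):

```python
def check_gdn_head_ratios(num_v_heads: int, num_q_heads: int, num_k_heads: int) -> bool:
    # Downstream code computes num_v_heads // num_q_heads and
    # num_v_heads // num_k_heads, so both divisions must be exact.
    return num_v_heads % num_q_heads == 0 and num_v_heads % num_k_heads == 0
```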

In `@flashinfer/trace/templates/gemm.py`:
- Around line 180-217: The mm_fp4_trace TraceTemplate currently lists inputs "A"
and "B" with unpacked shapes ["M","K"] and ["K","N"], but at runtime these are
packed uint8 FP4 buffers; update mm_fp4_trace so the Tensor entries for "A" and
"B" describe the packed axes (e.g., K_packed/K_block or bytes per packed row) or
add an extractor that converts the packed dimensions back to logical K and N
(use the existing "block_size" Var/Scalar to compute K//block_size and
N//block_size); specifically modify the Tensor definitions for "A" and "B" in
mm_fp4_trace (and any related axis defs such as "K" or "N") so fi_trace will
infer correct runtime shapes for FP4-packed inputs.
- Around line 22-35: The reference implementations currently multiply by B.T,
but B represents the physical [K, N] weight matrix so using B.T swaps dims and
breaks cases where N != K; update _mm_reference to compute torch.matmul(A, B)
(not A @ B.T), and in _mm_fp8_reference reshape B into [K, N] (B_fp32 =
B.reshape(K_div_bs * block_size, N)) and use torch.matmul(A_fp32, B_fp32)
(remove the trailing .T), applying the same fix to the other reference helpers
mentioned (the FP8 and bf16 variants in the file).
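Why `B.T` only appears to work when N == K can be seen from a shape check alone (an illustrative helper, not the library code):

```python
def matmul_out_shape(a_shape, b_shape, transpose_b=False):
    """Output shape of A @ B (or A @ B.T), raising on an inner-dim mismatch."""
    m, k = a_shape
    rows, cols = b_shape
    if transpose_b:
        rows, cols = cols, rows  # B.T turns a physical [K, N] into [N, K]
    if k != rows:
        raise ValueError(f"inner dims differ: {k} vs {rows}")
    return (m, cols)
```

With a physical [K, N] weight, `A @ B` composes as [M, K] x [K, N]; `A @ B.T` only type-checks when N == K, which is how square test shapes can mask the bug.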

In `@flashinfer/trace/templates/moe.py`:
- Around line 648-652: The assignment to add a dynamic attribute on the function
trtllm_fp8_block_scale_moe_trace_dispatch causes mypy attr-defined errors;
replace the direct assignment with a setattr call so the templates attribute is
attached at runtime (e.g., setattr(trtllm_fp8_block_scale_moe_trace_dispatch,
"templates", list(_MOE_TRACE_BY_ROUTING_TYPE.values()))) to preserve behavior
while satisfying type checking.
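The setattr pattern is a one-line change; a minimal stand-in (function and attribute names mirror the comment):

```python
from typing import Callable, List


def attach_templates(fn: Callable, templates: List) -> Callable:
    # Direct `fn.templates = ...` on a Callable-typed name trips mypy's
    # attr-defined check; setattr expresses the dynamic attribute explicitly.
    setattr(fn, "templates", list(templates))
    return fn
```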
- Around line 25-27: The module currently hardcodes H=7168 and I=2048 which
breaks _fp8_moe_run_experts for models with different hidden/intermediate sizes;
change the code to derive hidden_size and intermediate_size at runtime from
tensor shapes (e.g., infer hidden_size from the input/hidden tensor shape[-1] or
the expert weight shapes, and infer intermediate_size from the feedforward
weight/output shapes) and replace uses of H and I with those derived values
(also ensure BLOCK is computed/validated against hidden_size if needed); update
all references in _fp8_moe_run_experts to use the derived variables so the
function works for arbitrary MoE shapes.
- Around line 126-131: The template hardcodes routing parameters TOP_K, N_GROUP,
TOPK_GROUP which conflict with the public API; update the code that defines
TOP_K, N_GROUP, TOPK_GROUP to read the corresponding function arguments (e.g.,
top_k, n_group, topk_group) or the routing parameters object instead of fixed
literals so the template matches whatever configuration is passed via
routing_logits' caller (or if these values are truly only for shape/schema
checks, replace the literals with a clarifying comment near
TOP_K/N_GROUP/TOPK_GROUP stating they are placeholder defaults used only for
schema validation and not numeric correctness). Ensure you change the
occurrences of TOP_K, N_GROUP, TOPK_GROUP in this module (referenced with
routing_logits, E_global, T) accordingly.

In `@tests/trace/example.py`:
- Around line 54-373: The file runs the entire trace-generation body at import time, so pytest won't collect it. Keep the environment setup (the FLASHINFER_TRACE_* os.environ lines and SAVE_DIR) and the imports as-is, but move everything that performs work (starting from device/WORKSPACE and all calls that exercise flashinfer APIs, e.g. the loops that call flashinfer.rmsnorm, flashinfer.mm_bf16, BatchDecodeWithPagedKVCacheWrapper.plan/run, BatchPrefillWithPagedKVCacheWrapper.plan/run, BatchPrefillWithRaggedKVCacheWrapper.plan/run, BatchMLAPagedAttentionWrapper.plan/run, flashinfer.gdn_decode.*, flashinfer.fused_moe.*, and the final JSON summary) into a pytest test function named with a test_ prefix (e.g. test_generate_fi_traces) so CI will execute it. Also add a minimal if __name__ == "__main__" guard that calls the function, so the example remains runnable standalone.
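The resulting layout could look like this skeleton (the body is a placeholder standing in for the moved workload):

```python
import os

# Environment setup stays at module level so the decorator sees it on import.
os.environ.setdefault("FLASHINFER_TRACE_DUMP", "1")


def test_generate_fi_traces():
    # ... device/WORKSPACE setup and all flashinfer API calls move here ...
    traces_written = 0  # placeholder for the real trace count
    assert traces_written >= 0


if __name__ == "__main__":
    # Keep the example runnable standalone: `python tests/trace/example.py`.
    test_generate_fi_traces()
```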

---

Nitpick comments:
In `@flashinfer/trace/template.py`:
- Line 473: Replace the list concatenation that builds all_tags with list
unpacking for clearer syntax: in the function/block where variable all_tags is
assigned (currently using all_tags = [f"fi_api:{fi_api}"] + template.tags),
change it to construct the list using [f"fi_api:{fi_api}", *template.tags] so it
directly prepends the formatted fi_api tag to template.tags; keep the same
variable name and semantics.
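For reference, the unpacking form reads (a sketch; `build_all_tags` is a hypothetical wrapper around the assignment):

```python
def build_all_tags(fi_api, tags):
    # [x, *rest] prepends without the intermediate list that [x] + rest builds.
    return [f"fi_api:{fi_api}", *tags]
```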
- Around line 426-443: The auto-infer branch in template.py sets dtype to the
first matching input's type (looping over template.inputs, checking Tensor
instances and overlapping descriptor.dim_names, using _get_tensor and
_dtype_str, then break), which is arbitrary when multiple inputs overlap; add a
concise inline comment near this logic (around the loop and the break) stating
that this chooses the first matching input's dtype as the precedence rule and
that other overlapping inputs may be ignored, so callers should avoid ambiguous
multiple-dtype overlaps or explicitly provide dtype to override; keep the
comment short and reference template.inputs, Tensor, descriptor, _get_tensor,
and _dtype_str.
- Around line 370-378: The code silently swallows all exceptions when running
axis extractor functions (axis_extractors -> extractor(kwargs) populating
axis_values), which hides bugs; change the bare "except Exception: pass" to
catch the exception as e and log it at debug level (e.g., logger.debug("axis
extractor %s failed for kwargs=%s: %s", axis_name, kwargs, e, exc_info=True)) so
failures are recorded but extraction remains robust, and if there is no existing
logger in this module create one via logging.getLogger(__name__) and import
logging.
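A sketch of the logged-but-robust extraction loop (the helper name is hypothetical; the logger call matches the one suggested above):

```python
import logging

logger = logging.getLogger(__name__)


def run_axis_extractors(axis_extractors, kwargs):
    """Populate axis_values, recording extractor failures at debug level
    instead of silently swallowing them."""
    axis_values = {}
    for axis_name, extractor in axis_extractors.items():
        try:
            axis_values[axis_name] = extractor(kwargs)
        except Exception as e:
            logger.debug(
                "axis extractor %s failed for kwargs=%s: %s",
                axis_name, kwargs, e, exc_info=True,
            )
    return axis_values
```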

In `@flashinfer/trace/templates/moe.py`:
- Around line 85-88: Rename the ambiguous variable O used in the moe attention
feedforward block to a clearer name (e.g., out or output_e) to avoid confusion
with the digit zero; update the assignment and any subsequent uses where O
appears (the expression "(silu_X2 * X1).matmul(W2[le].t())") and ensure
references to G1, X1, X2, silu_X2, W13, W2, A_e, and le remain correct with the
new variable name.

In `@tests/trace/test_fi_trace_template_consistency.py`:
- Around line 399-408: The loop variable k in the comprehension that builds
non_optional_unknown shadows the tensor named k earlier; rename the loop
variable (e.g., to input_name or inp_key) used in the comprehension and in the
f-string so it no longer collides with the tensor k, updating the comprehension
over defn["inputs"].items() and the f"Non-optional inputs with unknown dtype:
{...}" reference accordingly.
- Around line 495-496: The regex passed to pytest.raises should be a raw string
to avoid accidental escape sequences; update the call to
pytest.raises(AssertionError, match=...) used around
assert_template_signature_consistency(func, broken, label="meta-test") so the
match argument is a raw string literal (prefix it with r, e.g.
r"param=.*hidden_state.*not found") to satisfy Ruff RUF043 and ensure the
pattern is interpreted correctly.
- Around line 430-460: The test function test_fi_trace_complete_moe_routing uses
a single-letter variable I (intermediate size) which triggers an ambiguity lint
(E741); rename I to a descriptive identifier (e.g., inter_size or intermediate)
and update all references inside the function and the fi_trace(...) call
(intermediate_size=I, shapes using I, 2 * I etc.) so the values and assertions
remain identical but the variable name is clear and matches usage in
trtllm_fp8_block_scale_moe.fi_trace.
- Around line 369-370: The list comprehension _E2E_PAIRS uses an ambiguous loop
variable named "l"; rename it to a clearer identifier (e.g., "label" or "lbl")
in the comprehension and update the subsequent _E2E_IDS comprehension to
unpack/use that new name so both _E2E_PAIRS = [(f, t, label) for f, t, label in
_ALL_PAIRS if label not in _E2E_SKIP] and _E2E_IDS = [label for _, _, label in
_E2E_PAIRS] remain consistent.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a970c3e2-8b74-4f9f-b236-d08657910713

📥 Commits

Reviewing files that changed from the base of the PR and between f7e2129 and 2f4aceb.

📒 Files selected for processing (32)
  • flashinfer/api_logging.py
  • flashinfer/fi_trace.py
  • flashinfer/trace/template.py
  • flashinfer/trace/templates/attention.py
  • flashinfer/trace/templates/gdn.py
  • flashinfer/trace/templates/gemm.py
  • flashinfer/trace/templates/moe.py
  • tests/trace/example.py
  • tests/trace/fi_trace_out/fused_add_rmsnorm_h5120.json
  • tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json
  • tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json
  • tests/trace/fi_trace_out/gemm_bf16_N256_K7168.json
  • tests/trace/fi_trace_out/gemm_bf16_N4096_K4096.json
  • tests/trace/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.json
  • tests/trace/fi_trace_out/gemm_fp8_N1536_K7168.json
  • tests/trace/fi_trace_out/gemm_mxfp8_N4096_K4096.json
  • tests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.json
  • tests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.json
  • tests/trace/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.json
  • tests/trace/fi_trace_out/gqa_ragged_h32_kv8_d128.json
  • tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.json
  • tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json
  • tests/trace/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.json
  • tests/trace/fi_trace_out/rmsnorm_h4096.json
  • tests/trace/fi_trace_out/rmsnorm_h7168.json
  • tests/trace/fi_trace_out/top_k_sampling_v128256.json
  • tests/trace/fi_trace_out/top_k_top_p_sampling_v128256.json
  • tests/trace/fi_trace_out/top_k_top_p_sampling_v151936.json
  • tests/trace/fi_trace_out/top_p_sampling_v128256.json
  • tests/trace/fi_trace_out/top_p_sampling_v151936.json
  • tests/trace/test_fi_trace.py
  • tests/trace/test_fi_trace_template_consistency.py
✅ Files skipped from review due to trivial changes (15)
  • tests/trace/fi_trace_out/gemm_bf16_N256_K7168.json
  • tests/trace/fi_trace_out/rmsnorm_h4096.json
  • tests/trace/fi_trace_out/rmsnorm_h7168.json
  • tests/trace/fi_trace_out/gemm_bf16_N4096_K4096.json
  • tests/trace/fi_trace_out/gemm_fp8_N1536_K7168.json
  • tests/trace/fi_trace_out/fused_add_rmsnorm_h5120.json
  • tests/trace/fi_trace_out/top_k_top_p_sampling_v128256.json
  • tests/trace/fi_trace_out/gemm_mxfp8_N4096_K4096.json
  • tests/trace/fi_trace_out/top_p_sampling_v151936.json
  • tests/trace/fi_trace_out/top_k_sampling_v128256.json
  • tests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.json
  • tests/trace/fi_trace_out/top_k_top_p_sampling_v151936.json
  • tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json
  • tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.json
  • tests/trace/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.json
🚧 Files skipped from review as they are similar to previous changes (6)
  • tests/trace/fi_trace_out/top_p_sampling_v128256.json
  • tests/trace/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.json
  • tests/trace/fi_trace_out/gqa_ragged_h32_kv8_d128.json
  • tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json
  • tests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.json
  • tests/trace/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.json

Comment thread flashinfer/api_logging.py Outdated
Comment on lines +1508 to +1531
wrapped.fi_trace = fi_trace_fn

# Auto-dump wrapper: checked lazily at call time so that callers
# can set FLASHINFER_TRACE_DUMP after importing flashinfer (e.g.
# when running via ``python -m``).
_inner = wrapped
_sig = inspect.signature(original)

@functools.wraps(_inner)
def _auto_dump_wrapper(*args, **kwargs):
# Generate trace BEFORE the actual call (crash-safe: schema
# depends only on input shapes/dtypes, not on whether the
# computation succeeds).
if _is_trace_dump_enabled():
try:
bound = _sig.bind(*args, **kwargs)
bound.apply_defaults()
fi_trace_fn(**dict(bound.arguments))
except Exception:
pass
return _inner(*args, **kwargs)

_auto_dump_wrapper.fi_trace = fi_trace_fn
return _auto_dump_wrapper
⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

find . -maxdepth 5 -name "*.py" -path "*flashinfer*" | grep -i api | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 95


🏁 Script executed:

ls -la flashinfer/ 2>/dev/null | head -20 || echo "Directory not found"

Repository: flashinfer-ai/flashinfer

Length of output: 1275


🏁 Script executed:

git ls-files "*api_logging*"

Repository: flashinfer-ai/flashinfer

Length of output: 93


🏁 Script executed:

sed -n '1508,1540p' flashinfer/api_logging.py

Repository: flashinfer-ai/flashinfer

Length of output: 1518


🏁 Script executed:

grep -n "\.fi_trace\s*=" flashinfer/api_logging.py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 238


🏁 Script executed:

sed -n '1490,1545p' flashinfer/api_logging.py

Repository: flashinfer-ai/flashinfer

Length of output: 2559


🏁 Script executed:

head -100 flashinfer/api_logging.py | grep -E "^(import|from|class|def|__all__)"

Repository: flashinfer-ai/flashinfer

Length of output: 406


🏁 Script executed:

grep -B 30 "wrapped.fi_trace = fi_trace_fn" flashinfer/api_logging.py | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 1594


🏁 Script executed:

sed -n '1450,1510p' flashinfer/api_logging.py

Repository: flashinfer-ai/flashinfer

Length of output: 3083


🏁 Script executed:

find . -name "pyproject.toml" -o -name "setup.cfg" -o -name "mypy.ini" | head -5

Repository: flashinfer-ai/flashinfer

Length of output: 156


🏁 Script executed:

grep -n "def.*wrapped" flashinfer/api_logging.py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

cat pyproject.toml | grep -A 20 "\[tool.mypy\]"

Repository: flashinfer-ai/flashinfer

Length of output: 429


🏁 Script executed:

cat pyproject.toml | grep -A 20 "lint\|mypy"

Repository: flashinfer-ai/flashinfer

Length of output: 712


🏁 Script executed:

sed -n '1450,1560p' flashinfer/api_logging.py | cat -n

Repository: flashinfer-ai/flashinfer

Length of output: 6002


🏁 Script executed:

sed -n '1450,1545p' flashinfer/api_logging.py | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 4750


🏁 Script executed:

grep -n "def flashinfer_api\|def _attach_fi_trace" flashinfer/api_logging.py

Repository: flashinfer-ai/flashinfer

Length of output: 169


🏁 Script executed:

python3 -m mypy flashinfer/api_logging.py --no-error-summary 2>&1 | grep -i "fi_trace\|attr-defined" | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

sed -n '1435,1460p' flashinfer/api_logging.py

Repository: flashinfer-ai/flashinfer

Length of output: 1278


🏁 Script executed:

sed -n '1557,1620p' flashinfer/api_logging.py

Repository: flashinfer-ai/flashinfer

Length of output: 2910


🏁 Script executed:

python3 << 'EOF'
# Check what mypy would say about direct assignment to Callable
import sys
import tempfile
import os

code = '''
from typing import Callable

def my_decorator(original: Callable) -> Callable:
    wrapped: Callable = original
    wrapped.fi_trace = lambda: None  # This is the issue
    return wrapped
'''

with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
    f.write(code)
    temp_file = f.name

try:
    os.system(f'python3 -m mypy {temp_file} 2>&1 | head -20')
finally:
    os.unlink(temp_file)
EOF

Repository: flashinfer-ai/flashinfer

Length of output: 106


🏁 Script executed:

grep -n "fi_trace" flashinfer/api_logging.py | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 1728


🏁 Script executed:

grep -n "class.*Protocol\|@.*Protocol" flashinfer/api_logging.py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

grep -B 5 "def _attach_fi_trace" flashinfer/api_logging.py | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 287


🏁 Script executed:

sed -n '1400,1445p' flashinfer/api_logging.py

Repository: flashinfer-ai/flashinfer

Length of output: 1579


🏁 Script executed:

python3 << 'EOF'
# Check if setattr vs direct assignment matters for mypy
code1 = '''
from typing import Callable

def decorator(func: Callable) -> Callable:
    func.fi_trace = lambda: None  # Direct assignment - mypy error
    return func
'''

code2 = '''
from typing import Callable

def decorator(func: Callable) -> Callable:
    setattr(func, "fi_trace", lambda: None)  # setattr - mypy OK
    return func
'''

print("Direct assignment would trigger attr-defined error on Callable types")
print("setattr() is the mypy-friendly approach for dynamic attributes")
EOF

Repository: flashinfer-ai/flashinfer

Length of output: 199


🏁 Script executed:

grep -C 3 "wrapped: Callable\|original: Callable" flashinfer/api_logging.py | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 230


Use setattr() for dynamic attribute assignment to avoid mypy attr-defined errors.

Direct assignment of .fi_trace to Callable-typed variables triggers mypy's attr-defined check. Replace direct assignment with setattr() to explicitly model the dynamic attribute:

Required changes
-            wrapped.fi_trace = fi_trace_fn
+            setattr(wrapped, "fi_trace", fi_trace_fn)
@@
-            _auto_dump_wrapper.fi_trace = fi_trace_fn
+            setattr(_auto_dump_wrapper, "fi_trace", fi_trace_fn)
@@
-                wrapped.fi_trace = build_fi_trace_fn(spec)
+                setattr(wrapped, "fi_trace", build_fi_trace_fn(spec))
🧰 Tools
🪛 GitHub Actions: pre-commit

[error] 1508-1539: mypy error(s): "Callable[..., Any]" has no attribute "fi_trace" / "_Wrapped[...]" has no attribute "fi_trace" [attr-defined]

🪛 Ruff (0.15.9)

[error] 1526-1527: try-except-pass detected, consider logging the exception

(S110)


[warning] 1526-1526: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/api_logging.py` around lines 1508 - 1531, Replace direct attribute
assignments to .fi_trace with setattr to avoid mypy attr-defined errors: where
the diff sets wrapped.fi_trace = fi_trace_fn and _auto_dump_wrapper.fi_trace =
fi_trace_fn, change those direct assignments to use setattr(wrapped, "fi_trace",
fi_trace_fn) and setattr(_auto_dump_wrapper, "fi_trace", fi_trace_fn). Keep the
same semantics (assign the fi_trace_fn callable) and leave other code in
_auto_dump_wrapper, _sig, and fi_trace_fn unchanged.

Comment thread flashinfer/api_logging.py
Comment thread flashinfer/fi_trace.py
Comment on lines +103 to +110
# Import the old implementation from the trace package for backwards compat.
from .trace.template import ( # noqa: PLC0415
Const,
Scalar,
Tensor,
TraceTemplate,
Var,
)
⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, check if the file exists and examine the relevant lines
if [ -f "flashinfer/fi_trace.py" ]; then
    echo "=== File exists, examining lines 100-115 ==="
    sed -n '100,115p' flashinfer/fi_trace.py | cat -n
else
    echo "File not found"
fi

Repository: flashinfer-ai/flashinfer

Length of output: 675


🏁 Script executed:

# Search for usages of the imported names in the file
echo "=== Searching for usages of Const, Scalar, Tensor, TraceTemplate, Var ==="
for name in Const Scalar Tensor TraceTemplate Var; do
    echo "--- Searching for '$name' ---"
    rg "\b${name}\b" flashinfer/fi_trace.py -n | head -20
done

Repository: flashinfer-ai/flashinfer

Length of output: 732


🏁 Script executed:

# Check the full file to understand the import context
echo "=== Counting lines and examining structure ==="
wc -l flashinfer/fi_trace.py
echo ""
echo "=== Checking for function definitions ==="
rg "^def " flashinfer/fi_trace.py -n

Repository: flashinfer-ai/flashinfer

Length of output: 332


🏁 Script executed:

# Examine the build_fi_trace_fn function in detail to see if these imports are used
echo "=== Examining build_fi_trace_fn function (lines 97-237) ==="
sed -n '97,237p' flashinfer/fi_trace.py | head -80

Repository: flashinfer-ai/flashinfer

Length of output: 2919


🏁 Script executed:

# Check if there are any references to these imports after line 110
echo "=== Checking if imports are used after definition ==="
sed -n '110,237p' flashinfer/fi_trace.py | grep -E "Const|Scalar|Tensor[^_]|TraceTemplate|Var" | grep -v "torch.Tensor" | grep -v "isinstance"

Repository: flashinfer-ai/flashinfer

Length of output: 121


🏁 Script executed:

# Verify the file can be linted - check if Ruff would report F401
echo "=== Checking for Ruff/linting setup ==="
if command -v ruff &> /dev/null; then
    echo "Ruff is available, checking for F401 violations..."
    ruff check flashinfer/fi_trace.py --select F401 2>&1 | head -20
else
    echo "Ruff not available in sandbox"
fi

Repository: flashinfer-ai/flashinfer

Length of output: 1078


Remove unused template imports from the legacy builder.

The names Const, Scalar, Tensor, TraceTemplate, and Var imported at lines 105–109 are not referenced anywhere in build_fi_trace_fn. Ruff F401 violations are already flagged on these lines.

🧰 Tools
🪛 GitHub Actions: pre-commit

[error] 105-105: ruff (F401): .trace.template.Const imported but unused


[error] 106-106: ruff (F401): .trace.template.Scalar imported but unused


[error] 107-107: ruff (F401): .trace.template.Tensor imported but unused


[error] 108-108: ruff (F401): .trace.template.TraceTemplate imported but unused


[error] 109-109: ruff (F401): .trace.template.Var imported but unused

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fi_trace.py` around lines 103 - 110, The import line bringing in
Const, Scalar, Tensor, TraceTemplate, and Var from .trace.template is unused in
build_fi_trace_fn and causing Ruff F401 warnings; remove those five names (or
the whole legacy import if nothing else from that module is used) so only needed
symbols remain imported in flashinfer/fi_trace.py and eliminate the unused
imports Const, Scalar, Tensor, TraceTemplate, Var from the import statement that
currently appears alongside build_fi_trace_fn.

Comment thread flashinfer/fi_trace.py
Comment on lines +238 to +285
def fi_trace(
func_or_method: Callable,
save_dir: Optional[Union[str, Path]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
"""Generate a flashinfer-bench definition JSON for any FlashInfer API call.

Parameters
----------
func_or_method:
A ``@flashinfer_api``-decorated function or (bound) method.
save_dir:
Directory where the JSON definition file should be written.
Falls back to ``FLASHINFER_TRACE_DUMP_DIR`` env-var when *None*.
**kwargs:
The same tensor arguments you would pass to the real API.

Returns
-------
dict
A flashinfer-bench compatible definition dictionary.

Examples
--------
Standalone function::

defn = fi_trace(flashinfer.norm.rmsnorm, input=hidden, weight=weight)

Bound method (instance.run)::

defn = fi_trace(wrapper.run, q=q_tensor, paged_kv_cache=(k, v))

Class-level (unbound)::

defn = fi_trace(
flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper.run,
q=q_tensor, paged_kv_cache=(k, v),
)
"""
actual_func = getattr(func_or_method, "__func__", func_or_method)
trace_fn = getattr(actual_func, "fi_trace", None)
if trace_fn is None:
qualname = getattr(actual_func, "__qualname__", repr(actual_func))
raise ValueError(
f"No fi_trace spec is registered for '{qualname}'. "
"Only @flashinfer_api(trace=...)-decorated functions support fi_trace."
)
return trace_fn(save_dir=save_dir, **kwargs)
⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, let's find and examine the fi_trace.py file
find . -name "fi_trace.py" -type f

Repository: flashinfer-ai/flashinfer

Length of output: 92


🏁 Script executed:

# Count lines in the file to understand its size
wc -l flashinfer/fi_trace.py

Repository: flashinfer-ai/flashinfer

Length of output: 94


🏁 Script executed:

# Read the function in question with context
sed -n '230,295p' flashinfer/fi_trace.py | cat -n

Repository: flashinfer-ai/flashinfer

Length of output: 2241


🏁 Script executed:

# Check the imports at the top of the file to see what type annotations are available
head -40 flashinfer/fi_trace.py | cat -n

Repository: flashinfer-ai/flashinfer

Length of output: 1954


🏁 Script executed:

# Search for any Protocol definitions or type hints related to fi_trace in the codebase
rg "Protocol|fi_trace" flashinfer/*.py --type py -B 2 -A 2 | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 5252


🏁 Script executed:

# Check if there's a pyproject.toml or setup.cfg with mypy configuration
find . -name "pyproject.toml" -o -name "setup.cfg" -o -name "mypy.ini" | head -5

Repository: flashinfer-ai/flashinfer

Length of output: 156


🏁 Script executed:

# Read full imports section of fi_trace.py
head -80 flashinfer/fi_trace.py | cat -n

Repository: flashinfer-ai/flashinfer

Length of output: 3516


🏁 Script executed:

# Check pyproject.toml for mypy configuration
cat pyproject.toml | grep -A 20 "tool.mypy" | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 429


🏁 Script executed:

# Search for existing Protocol definitions in the codebase
rg "from typing import|Protocol" flashinfer/*.py --type py -B 1 -A 1 | head -80

Repository: flashinfer-ai/flashinfer

Length of output: 2994


🏁 Script executed:

# Check if there are any type-related issues or comments in fi_trace.py
rg "type:|Protocol|HasAttr" flashinfer/fi_trace.py --type py -B 2 -A 2

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Verify the exact issue by looking at the complete function implementation once more
sed -n '238,256p' flashinfer/fi_trace.py | cat -n

Repository: flashinfer-ai/flashinfer

Length of output: 792


🏁 Script executed:

# Check if there's a typing.Protocol or typing_extensions.Protocol import anywhere
rg "^from typing|^import typing" flashinfer/fi_trace.py

Repository: flashinfer-ai/flashinfer

Length of output: 123


Make the .fi_trace attribute requirement explicit in the type signature.

The function expects func_or_method (or its __func__) to have a .fi_trace attribute, but the signature declares plain Callable. Define a Protocol for this contract (e.g., class TracedCallable(Protocol): def fi_trace(...) -> Dict[str, Any]: ...) or use a cast when resolving the actual function. This will satisfy mypy and make the requirement clear to callers.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fi_trace.py` around lines 238 - 285, The function fi_trace
currently types func_or_method as Callable but relies on the object
(actual_func) exposing a .fi_trace attribute; update the typing to make that
contract explicit by introducing a Protocol (e.g., TracedCallable with a
fi_trace(self, save_dir: Optional[Union[str, Path]] = None, **kwargs) ->
Dict[str, Any]) and use that Protocol as the type for func_or_method (or cast
actual_func to TracedCallable before accessing .fi_trace); ensure the Protocol
signature matches how trace_fn is called in fi_trace and import typing.Protocol
and any necessary types so mypy recognizes the requirement.
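A minimal sketch of what such a Protocol could look like. The names `TracedCallable` and `resolve_traced` are illustrative, and the `fi_trace` signature shown here is an assumption based on the review comment, not the actual flashinfer API:

```python
from pathlib import Path
from typing import Any, Dict, Optional, Protocol, Union, cast


class TracedCallable(Protocol):
    """Contract for a callable decorated by @flashinfer_api(trace=...)."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...

    def fi_trace(
        self, save_dir: Optional[Union[str, Path]] = None, **kwargs: Any
    ) -> Dict[str, Any]: ...


def resolve_traced(func_or_method: Any) -> TracedCallable:
    """Unwrap a bound method to its underlying function and assert the contract."""
    actual = getattr(func_or_method, "__func__", func_or_method)
    if not hasattr(actual, "fi_trace"):
        raise TypeError(f"{actual!r} has no .fi_trace attribute")
    return cast(TracedCallable, actual)
```

With this in place, `fi_trace()` can type its parameter as `TracedCallable` (or keep `Callable` and `cast` at the access site), and mypy will flag callers that pass an undecorated function.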

Comment on lines +502 to +546
"initial_state": Tensor(
["pool_size", "num_v_heads", "head_size", "head_size"],
description="Initial recurrent state pool in k-last layout [pool_size, H, V, K].",
),
"initial_state_indices": Tensor(
["batch_size"],
description="Indices mapping each batch to its initial state in the pool.",
),
"A_log": Tensor(
["num_v_heads"],
description="Log decay parameter (learnable). Used to compute g = exp(-exp(A_log) * softplus(a + dt_bias)).",
),
"a": Tensor(
["batch_size", "seq_len", "num_v_heads"],
description="Input-dependent decay from projection.",
),
"dt_bias": Tensor(
["num_v_heads"],
description="Decay bias (learnable). Added to 'a' before softplus.",
),
"b": Tensor(
["batch_size", "seq_len", "num_v_heads"],
description="Update gate input from projection. beta = sigmoid(b).",
),
"scale": Scalar(
"float32",
description="Scale factor. Default is 1/sqrt(head_size).",
),
"intermediate_states_buffer": Tensor(
["pool_size", "seq_len", "num_v_heads", "head_size", "head_size"],
optional=True,
description="Optional buffer for caching intermediate states for potential rollback.",
),
},
outputs={
"output": Tensor(
["batch_size", "seq_len", "num_v_heads", "head_size"],
dtype_from="q",
description="Attention output for all T tokens. Shape follows num_v_heads in GVA mode.",
),
"final_state": Tensor(
["pool_size", "num_v_heads", "head_size", "head_size"],
dtype="float32",
description="Updated recurrent state pool in k-last layout [pool_size, H, V, K]. Unchanged if disable_state_update=True.",
),

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, find the gated_delta_rule_mtp function definition
rg "def gated_delta_rule_mtp" -A 20

Repository: flashinfer-ai/flashinfer

Length of output: 3167


🏁 Script executed:

# Also look at the template file to understand its full context
fd "gdn.py" --type f

Repository: flashinfer-ai/flashinfer

Length of output: 123


🏁 Script executed:

# Check for references to disable_state_update in the codebase
rg "disable_state_update" -A 2 -B 2

Repository: flashinfer-ai/flashinfer

Length of output: 49287


🏁 Script executed:

# Get the full template definition including inputs
sed -n '450,550p' flashinfer/trace/templates/gdn.py

Repository: flashinfer-ai/flashinfer

Length of output: 4136


🏁 Script executed:

# Look for where disable_state_update might be defined elsewhere in the template
rg "disable_state_update" flashinfer/trace/templates/gdn.py -B 5 -A 5

Repository: flashinfer-ai/flashinfer

Length of output: 558


🏁 Script executed:

# Check if there are other GDN-related templates that might model disable_state_update
fd "gdn" flashinfer/trace/templates/ -type f

Repository: flashinfer-ai/flashinfer

Length of output: 238


🏁 Script executed:

# Look for all trace templates in the gdn area
find flashinfer/trace/templates -name "*gdn*" -o -name "*delta*"

Repository: flashinfer-ai/flashinfer

Length of output: 101


🏁 Script executed:

# Check what trace templates exist
ls -la flashinfer/trace/templates/

Repository: flashinfer-ai/flashinfer

Length of output: 598


🏁 Script executed:

# Check the gdn_decode.py function signature to see all parameters
sed -n '1,150p' flashinfer/gdn_decode.py | grep -A 50 "def gated_delta_rule_mtp"

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Get the exact parameter list from gdn_decode.py
grep -A 15 "def gated_delta_rule_mtp" flashinfer/gdn_decode.py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 586


Add disable_state_update as a template input.

The gated_delta_rule_mtp() function supports a disable_state_update parameter that controls whether final_state is updated, yet the template does not expose it as an input. While the output description correctly mentions this behavior, the schema omission causes fi_trace to emit identical specifications for both state-updating and non-updating modes.

Add disable_state_update as a boolean input (optional or required, as per the function's design) to accurately model the two distinct operational modes.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/gdn.py` around lines 502 - 546, The template
schema is missing the disable_state_update input required by
gated_delta_rule_mtp, so add a boolean Tensor/Scalar entry named
"disable_state_update" to the inputs dict (matching how other flags are
represented) and mark it optional or required consistent with
gated_delta_rule_mtp's signature; ensure you reference the symbol name
"disable_state_update" and update the inputs block near the existing
"initial_state"/"final_state" entries so the trace can distinguish
state-updating vs non-updating behavior described in "final_state".

Comment thread tests/trace/test_fi_trace_template_consistency.py
Comment on lines +384 to +394
def test_fi_trace_complete_gqa_paged_decode():
"""GQA paged decode: tuple paged_kv_cache input handled correctly."""
from flashinfer.decode import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.trace.templates.attention import gqa_paged_decode_trace

B, H, KV, D, P, NP = 4, 8, 4, 64, 16, 8
q = torch.zeros(B, H, D, dtype=torch.bfloat16)
k = torch.zeros(NP, P, KV, D, dtype=torch.bfloat16)
v = torch.zeros(NP, P, KV, D, dtype=torch.bfloat16)

defn = BatchDecodeWithPagedKVCacheWrapper.run.fi_trace(q=q, paged_kv_cache=(k, v))

⚠️ Potential issue | 🟡 Minor

Remove unused import gqa_paged_decode_trace.

The import gqa_paged_decode_trace at line 387 is flagged as unused by pre-commit. The test only uses BatchDecodeWithPagedKVCacheWrapper.run.fi_trace.

Suggested fix
 def test_fi_trace_complete_gqa_paged_decode():
     """GQA paged decode: tuple paged_kv_cache input handled correctly."""
     from flashinfer.decode import BatchDecodeWithPagedKVCacheWrapper
-    from flashinfer.trace.templates.attention import gqa_paged_decode_trace
 
     B, H, KV, D, P, NP = 4, 8, 4, 64, 16, 8
🧰 Tools
🪛 GitHub Actions: pre-commit

[error] 387-387: ruff (F401): flashinfer.trace.templates.attention.gqa_paged_decode_trace imported but unused

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/test_fi_trace_template_consistency.py` around lines 384 - 394,
Remove the unused import gqa_paged_decode_trace from the test; locate the import
statement that reads "from flashinfer.trace.templates.attention import
gqa_paged_decode_trace" and delete it so the test only imports and uses
BatchDecodeWithPagedKVCacheWrapper.run.fi_trace (ensure no other references to
gqa_paged_decode_trace remain in the file).

Comment thread tests/trace/test_fi_trace.py Outdated
@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (3)
flashinfer/api_logging.py (1)

1510-1531: ⚠️ Potential issue | 🟠 Major

FLASHINFER_LOGLEVEL=0 no longer preserves the zero-overhead path for traced APIs.

For any @flashinfer_api(trace=...) function, this still returns _attach_fi_trace(...), and _attach_fi_trace() always builds _auto_dump_wrapper. That means every call pays an extra Python frame and _is_trace_dump_enabled() check even when logging is disabled, which contradicts the decorator's documented “returns original function” contract.

Also applies to: 1629-1634
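One way to restore the zero-overhead path, sketched here as a hypothetical reimplementation (this is not the actual api_logging.py code): attach .fi_trace as an attribute on the original function, and return the function unwrapped unless dump mode is enabled, so the disabled path pays no extra Python frame:

```python
import functools
import os


def _is_trace_dump_enabled() -> bool:
    return os.environ.get("FLASHINFER_TRACE_DUMP") == "1"


def flashinfer_api(trace=None):
    """Illustrative decorator sketch, not the real flashinfer_api."""

    def decorator(func):
        if trace is not None:
            # An attribute costs nothing at call time; .fi_trace stays reachable.
            func.fi_trace = trace
        if trace is None or not _is_trace_dump_enabled():
            # Zero-overhead path: return the original function, no wrapper frame.
            # Note: the env var is read once at decoration time in this sketch,
            # trading runtime toggling for the documented no-op contract.
            return func

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            trace(*args, **kwargs)  # dump the definition before the kernel runs
            return func(*args, **kwargs)

        wrapper.fi_trace = trace
        return wrapper

    return decorator
```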

tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json (1)

126-126: ⚠️ Potential issue | 🟠 Major

Flatten the paged cache before the reference matmul.

The schema above says ckv_cache/kpe_cache are [num_pages, page_size, head_dim_*], so with page_size=64 the squeeze(1) calls do nothing. Kc_all[tok_idx]/Kp_all[tok_idx] therefore stay 3D, and the later qn @ Kc.T path no longer matches the intended [num_qo_heads, L] attention score computation. Reshape the caches to token-major 2D tensors before indexing, or rewrite the reference to handle paged tensors directly.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json` at line
126, The reference incorrectly uses ckv_cache.squeeze(1) / kpe_cache.squeeze(1)
which leaves a 3D paged tensor when page_size>1; in _mla_paged_decode_reference
you must flatten the page-major caches to token-major 2D tensors
(num_pages*page_size, head_dim_ckv/kpe) before indexing (i.e., compute Kc_all
and Kp_all as [num_tokens, head_dim_*] rather than [num_pages, page_size,
head_dim_*]), then select Kc/Kp with tok_idx so that qn @ Kc.T and qp @ Kp.T
produce [num_qo_heads, L] logits; update the Kc_all/Kp_all creation near their
assignments and ensure subsequent uses (Kc, Kp, logits, output) operate on the
flattened shapes.
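The flattening step the comment asks for can be sketched as follows. Shapes and names here are illustrative (they mirror the schema described above, not the actual reference code), and torch is assumed available:

```python
import torch


def flatten_paged_cache(cache: torch.Tensor) -> torch.Tensor:
    """[num_pages, page_size, head_dim] -> token-major [num_pages*page_size, head_dim]."""
    num_pages, page_size, head_dim = cache.shape
    return cache.reshape(num_pages * page_size, head_dim)


# Illustrative dimensions matching the mla_paged_decode_h16_ckv512_kpe64_ps64 case.
num_pages, page_size, head_dim_ckv, num_qo_heads = 3, 64, 512, 16
ckv_cache = torch.randn(num_pages, page_size, head_dim_ckv)

Kc_all = flatten_paged_cache(ckv_cache)  # [192, 512], one row per token
tok_idx = torch.arange(100)              # this request's first L=100 tokens
Kc = Kc_all[tok_idx]                     # [L, head_dim_ckv], now 2D
qn = torch.randn(num_qo_heads, head_dim_ckv)
logits = qn @ Kc.T                       # [num_qo_heads, L], as intended
```

With `squeeze(1)` instead of the reshape, `Kc_all[tok_idx]` would stay 3D whenever page_size > 1 and the matmul would no longer produce [num_qo_heads, L] scores.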
tests/trace/example.py (1)

54-551: ⚠️ Potential issue | 🟠 Major

Pytest still won't execute this trace generator.

tests/trace/example.py is still a standalone script with top-level side effects and no test_* entrypoint, so CI won't collect it or validate the generated fixtures. Please move the body into a test/helper and keep a __main__ guard only for manual runs.

As per coding guidelines, tests/**/*.py: Prefix test functions with test_ and structure tests by feature in tests/ subdirectories matching kernel categories.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/example.py` around lines 54 - 551, The file currently executes
trace-generation at import (top-level side effects: SAVE_DIR, the whole sequence
of flashinfer.* calls, and the final files/print summary), so move the entire
body into a callable function (e.g. generate_fi_traces or build_example_traces)
that encapsulates planning/running wrappers and the final JSON-summary logic,
then add a pytest entrypoint test_example_traces() in the same module (or a new
tests/ submodule) that calls that function and asserts expected output (e.g.
presence/count of files from SAVE_DIR or that no exceptions occur), and retain
an if __name__ == "__main__": guard to call generate_fi_traces() for manual
runs; reference SAVE_DIR, the trace-generation sequence (all flashinfer.* calls
and wrapper usages like BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper, BatchPrefillWithRaggedKVCacheWrapper,
BatchMLAPagedAttentionWrapper, and the files variable/summary) when making the
changes.
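The requested structure, in skeleton form. Function names come from the review suggestion; the body is a placeholder, and SAVE_DIR points at a temp directory here to keep the sketch side-effect-free (the real file uses tests/trace/fi_trace_out):

```python
import tempfile
from pathlib import Path

SAVE_DIR = Path(tempfile.gettempdir()) / "fi_trace_out_example"


def generate_fi_traces() -> list:
    """Run the trace-generation workload and return the dumped JSON files."""
    SAVE_DIR.mkdir(parents=True, exist_ok=True)
    # ... flashinfer.* calls and wrapper usages go here ...
    return sorted(SAVE_DIR.glob("*.json"))


def test_example_traces():
    """Pytest entrypoint: the generator must run cleanly and return a file list."""
    files = generate_fi_traces()
    assert isinstance(files, list)


if __name__ == "__main__":
    # Manual-run path for regenerating fixtures outside pytest.
    for f in generate_fi_traces():
        print(f)
```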
🧹 Nitpick comments (8)
flashinfer/trace/templates/attention.py (3)

28-41: Prefix unused unpacked variables with underscore.

page_size is unpacked from k_cache.shape but never used in the function. Prefix with _ to satisfy linter.

Proposed fix
 def _gqa_paged_decode_reference(q, k_cache, v_cache, kv_indptr, kv_indices, sm_scale):
     batch_size, num_qo_heads, head_dim = q.shape
-    _, page_size, num_kv_heads, _ = k_cache.shape
+    _, _page_size, num_kv_heads, _ = k_cache.shape
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/attention.py` around lines 28 - 41, The variable
page_size unpacked in _gqa_paged_decode_reference is unused and should be
prefixed with an underscore to satisfy the linter; change the unpacking from "_,
page_size, num_kv_heads, _ = k_cache.shape" style to use "_page_size" (or simply
"_" if preferred) so the function signature still captures batch dimensions but
removes the unused symbol while keeping references to q, k_cache, v_cache,
kv_indptr, kv_indices, and sm_scale intact.

244-248: Prefix unused unpacked variable with underscore.

total_kv is unpacked but never used. Prefix with _ to satisfy linter.

Proposed fix
 def _gqa_ragged_prefill_reference(q, k, v, qo_indptr, kv_indptr, sm_scale):
     total_q, num_qo_heads, head_dim = q.shape
-    total_kv, num_kv_heads, _ = k.shape
+    _total_kv, num_kv_heads, _ = k.shape
     len_indptr = qo_indptr.shape[0]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/attention.py` around lines 244 - 248, In
_gqa_ragged_prefill_reference, the unpacking assigns total_kv from k.shape but
it's unused; change the unpacked name to _total_kv (or prefix with underscore)
in the q, k, v shape assignment so the linter recognizes it as intentionally
unused (i.e., update the tuple unpack on the line with "total_q, num_qo_heads,
head_dim = q.shape" / "total_kv, num_kv_heads, _ = k.shape" to use _total_kv).

125-144: Prefix unused unpacked variables with underscore.

Both num_pages and page_size are unpacked but never used. This triggers ruff RUF059.

Proposed fix
 def _gqa_paged_prefill_reference(
     q, k_cache, v_cache, qo_indptr, kv_indptr, kv_indices, sm_scale
 ):
     total_q, num_qo_heads, head_dim = q.shape
-    num_pages, page_size, num_kv_heads, _ = k_cache.shape
+    _num_pages, _page_size, num_kv_heads, _ = k_cache.shape
     len_indptr = qo_indptr.shape[0]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/attention.py` around lines 125 - 144, The
variables num_pages and page_size are unpacked in _gqa_paged_prefill_reference
but never used, triggering ruff RUF059; update the tuple unpacking to prefix
unused names with an underscore (e.g., _num_pages, _page_size or simply _, _) in
the line that unpacks k_cache.shape so the intent is clear and the linter
warning is resolved while leaving k_cache usage unchanged.
tests/trace/test_fi_trace_template_consistency.py (3)

440-466: Consider renaming I to INTERMEDIATE or INTER_SIZE for clarity.

The variable name I is flagged as ambiguous (E741) because it can be confused with 1 or l. The same applies to line 488.

Proposed fix
-    T, E, EL, H, I, BS = 4, 16, 2, 256, 64, 128
+    T, E, EL, H, INTER, BS = 4, 16, 2, 256, 64, 128
     defn = trtllm_fp8_block_scale_moe.fi_trace(
         routing_logits=torch.zeros(T, E, dtype=torch.float32),
         routing_bias=torch.zeros(E, dtype=torch.bfloat16),
         hidden_states=torch.zeros(T, H, dtype=torch.float8_e4m3fn),
         hidden_states_scale=torch.ones(H // BS, T, dtype=torch.float32),
-        gemm1_weights=torch.zeros(EL, 2 * I, H, dtype=torch.float8_e4m3fn),
-        gemm1_weights_scale=torch.ones(EL, (2 * I) // BS, H // BS, dtype=torch.float32),
-        gemm2_weights=torch.zeros(EL, H, I, dtype=torch.float8_e4m3fn),
-        gemm2_weights_scale=torch.ones(EL, H // BS, I // BS, dtype=torch.float32),
+        gemm1_weights=torch.zeros(EL, 2 * INTER, H, dtype=torch.float8_e4m3fn),
+        gemm1_weights_scale=torch.ones(EL, (2 * INTER) // BS, H // BS, dtype=torch.float32),
+        gemm2_weights=torch.zeros(EL, H, INTER, dtype=torch.float8_e4m3fn),
+        gemm2_weights_scale=torch.ones(EL, H // BS, INTER // BS, dtype=torch.float32),
         num_experts=E,
         top_k=top_k,
-        intermediate_size=I,
+        intermediate_size=INTER,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/test_fi_trace_template_consistency.py` around lines 440 - 466,
Rename the ambiguous single-letter variable I to a clearer name like
INTERMEDIATE or INTER_SIZE throughout the test (e.g., the variable declaration
T, E, EL, H, INTERMEDIATE, BS = ... and all uses: gemm1_weights shape (EL, 2 *
INTERMEDIATE, H), gemm2_weights shape (EL, H, INTERMEDIATE), and
intermediate_size=INTERMEDIATE in the trtllm_fp8_block_scale_moe.fi_trace(...)
call) and similarly update any other occurrence on the nearby line 488 so all
references remain consistent.

374-376: Rename ambiguous variable l to label for clarity.

The single-letter l can be confused with 1 or I. Use a more descriptive name.

Proposed fix
-_E2E_PAIRS = [(f, t, l) for f, t, l in _ALL_PAIRS if l not in _E2E_SKIP]
-_E2E_IDS = [label for _, _, label in _E2E_PAIRS]
+_E2E_PAIRS = [(f, t, label) for f, t, label in _ALL_PAIRS if label not in _E2E_SKIP]
+_E2E_IDS = [lbl for _, _, lbl in _E2E_PAIRS]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/test_fi_trace_template_consistency.py` around lines 374 - 376,
Rename the ambiguous single-letter variable in the list comprehension: change
the unpacking in _E2E_PAIRS from (f, t, l) to (f, t, label) and update the
filter to use label instead of l; also update _E2E_IDS to unpack/use the same
name (e.g., [label for _, _, label in _E2E_PAIRS]) so all references use the
descriptive symbol label while keeping the existing logic with _ALL_PAIRS and
_E2E_SKIP.

560-563: Use raw string for regex pattern with metacharacters.

The pattern contains regex metacharacters (* and .) but is not a raw string. While it works here because there are no backslash escapes, using r"..." is safer and clearer.

Proposed fix
-    with pytest.raises(AssertionError, match="param=.*hidden_state.*not found"):
+    with pytest.raises(AssertionError, match=r"param=.*hidden_state.*not found"):
         assert_template_signature_consistency(func, broken, label="meta-test")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/trace/test_fi_trace_template_consistency.py` around lines 560 - 563,
The regex string passed to pytest.raises in the test
(match="param=.*hidden_state.*not found") uses metacharacters but isn't a raw
string; change it to a raw string (e.g., r"param=.*hidden_state.*not found") in
the pytest.raises call that wraps assert_template_signature_consistency(func,
broken, label="meta-test") to ensure backslashes and metacharacters are
interpreted correctly; update the test invocation around
_make_gdn_decode_func(), func, and broken accordingly.
flashinfer/trace/templates/moe.py (2)

674-678: Replace ambiguous × (multiplication sign) with ASCII x.

The Unicode multiplication sign × (U+00D7) can cause confusion. Use ASCII x or * instead.

Proposed fix
     "gemm1_out_size": Const(
-        description="Output size of FC1 (2 × intermediate_size for SwiGLU).",
+        description="Output size of FC1 (2 * intermediate_size for SwiGLU).",
         abbrev="",
     ),
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/moe.py` around lines 674 - 678, The description
string for the Const named "gemm1_out_size" contains a Unicode multiplication
sign (`×`); update the description in the gemm1_out_size Const to use an ASCII
"x" (or "*" if you prefer) instead (e.g., change "2 × intermediate_size for
SwiGLU" to "2 x intermediate_size for SwiGLU") so the comment uses plain ASCII
characters.

795-806: FP4 MoE templates cannot be validated against reference implementations.

All FP4 templates pass reference=None because the _make_standard_fp4_moe_trace factory does not accept a reference parameter (unlike the FP8 factory). No FP4 MoE reference implementations are defined. Given that FP4 templates are marked as status:experimental, either implement reference functions or document why validation is deferred.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/trace/templates/moe.py` around lines 795 - 806, The FP4 MoE
factory _make_standard_fp4_moe_trace currently hardcodes reference=None so FP4
templates cannot be validated; update the factory signature to accept an
optional reference parameter (e.g., reference=None) and pass that through into
TraceTemplate(reference=reference), then update all call sites that construct
FP4 MoE traces to provide a proper reference function or explicitly pass None
with a comment; additionally, either implement the missing FP4 MoE reference
functions (and register them where other references live) or add clear
documentation in the template module explaining that FP4 validation is deferred
and why.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/api_logging.py`:
- Around line 1521-1527: The trace auto-dump currently swallows all exceptions
in the block guarded by _is_trace_dump_enabled(), making failures invisible;
change the except Exception: pass to catch Exception as e and emit a non-fatal
log (e.g., processLogger.warning or module logger) that includes the failing
trace function name (use fi_trace_fn.__name__ or _sig.signature info) and the
exception information (e) or traceback, then continue; update the try/except
around _sig.bind(*args, **kwargs), bound.apply_defaults(), and
fi_trace_fn(**dict(bound.arguments)) to log the diagnostic while keeping the
call non-fatal.

---

Duplicate comments:
In `@tests/trace/example.py`:
- Around line 54-551: The file currently executes trace-generation at import
(top-level side effects: SAVE_DIR, the whole sequence of flashinfer.* calls, and
the final files/print summary), so move the entire body into a callable function
(e.g. generate_fi_traces or build_example_traces) that encapsulates
planning/running wrappers and the final JSON-summary logic, then add a pytest
entrypoint test_example_traces() in the same module (or a new tests/ submodule)
that calls that function and asserts expected output (e.g. presence/count of
files from SAVE_DIR or that no exceptions occur), and retain an if __name__ ==
"__main__": guard to call generate_fi_traces() for manual runs; reference
SAVE_DIR, the trace-generation sequence (all flashinfer.* calls and wrapper
usages like BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper, BatchPrefillWithRaggedKVCacheWrapper,
BatchMLAPagedAttentionWrapper, and the files variable/summary) when making the
changes.

In `@tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json`:
- Line 126: The reference incorrectly uses ckv_cache.squeeze(1) /
kpe_cache.squeeze(1) which leaves a 3D paged tensor when page_size>1; in
_mla_paged_decode_reference you must flatten the page-major caches to
token-major 2D tensors (num_pages*page_size, head_dim_ckv/kpe) before indexing
(i.e., compute Kc_all and Kp_all as [num_tokens, head_dim_*] rather than
[num_pages, page_size, head_dim_*]), then select Kc/Kp with tok_idx so that qn @
Kc.T and qp @ Kp.T produce [num_qo_heads, L] logits; update the Kc_all/Kp_all
creation near their assignments and ensure subsequent uses (Kc, Kp, logits,
output) operate on the flattened shapes.

---

Nitpick comments:
In `@flashinfer/trace/templates/attention.py`:
- Around line 28-41: The variable page_size unpacked in
_gqa_paged_decode_reference is unused and should be prefixed with an underscore
to satisfy the linter; change the unpacking from "_, page_size, num_kv_heads, _
= k_cache.shape" style to use "_page_size" (or simply "_" if preferred) so the
function signature still captures batch dimensions but removes the unused symbol
while keeping references to q, k_cache, v_cache, kv_indptr, kv_indices, and
sm_scale intact.
- Around line 244-248: In _gqa_ragged_prefill_reference, the unpacking assigns
total_kv from k.shape but it's unused; change the unpacked name to _total_kv (or
prefix with underscore) in the q, k, v shape assignment so the linter recognizes
it as intentionally unused (i.e., update the tuple unpack on the line with
"total_q, num_qo_heads, head_dim = q.shape" / "total_kv, num_kv_heads, _ =
k.shape" to use _total_kv).
- Around line 125-144: The variables num_pages and page_size are unpacked in
_gqa_paged_prefill_reference but never used, triggering ruff RUF059; update the
tuple unpacking to prefix unused names with an underscore (e.g., _num_pages,
_page_size or simply _, _) in the line that unpacks k_cache.shape so the intent
is clear and the linter warning is resolved while leaving k_cache usage
unchanged.

In `@flashinfer/trace/templates/moe.py`:
- Around line 674-678: The description string for the Const named
"gemm1_out_size" contains a Unicode multiplication sign (`×`); update the
description in the gemm1_out_size Const to use an ASCII "x" (or "*" if you
prefer) instead (e.g., change "2 × intermediate_size for SwiGLU" to "2 x
intermediate_size for SwiGLU") so the comment uses plain ASCII characters.
- Around line 795-806: The FP4 MoE factory _make_standard_fp4_moe_trace
currently hardcodes reference=None so FP4 templates cannot be validated; update
the factory signature to accept an optional reference parameter (e.g.,
reference=None) and pass that through into TraceTemplate(reference=reference),
then update all call sites that construct FP4 MoE traces to provide a proper
reference function or explicitly pass None with a comment; additionally, either
implement the missing FP4 MoE reference functions (and register them where other
references live) or add clear documentation in the template module explaining
that FP4 validation is deferred and why.

In `@tests/trace/test_fi_trace_template_consistency.py`:
- Around line 440-466: Rename the ambiguous single-letter variable I to a
clearer name like INTERMEDIATE or INTER_SIZE throughout the test (e.g., the
variable declaration T, E, EL, H, INTERMEDIATE, BS = ... and all uses:
gemm1_weights shape (EL, 2 * INTERMEDIATE, H), gemm2_weights shape (EL, H,
INTERMEDIATE), and intermediate_size=INTERMEDIATE in the
trtllm_fp8_block_scale_moe.fi_trace(...) call) and similarly update any other
occurrence on the nearby line 488 so all references remain consistent.
- Around line 374-376: Rename the ambiguous single-letter variable in the list
comprehension: change the unpacking in _E2E_PAIRS from (f, t, l) to (f, t,
label) and update the filter to use label instead of l; also update _E2E_IDS to
unpack/use the same name (e.g., [label for _, _, label in _E2E_PAIRS]) so all
references use the descriptive symbol label while keeping the existing logic
with _ALL_PAIRS and _E2E_SKIP.
- Around line 560-563: The regex string passed to pytest.raises in the test
(match="param=.*hidden_state.*not found") uses metacharacters but isn't a raw
string; change it to a raw string (e.g., r"param=.*hidden_state.*not found") in
the pytest.raises call that wraps assert_template_signature_consistency(func,
broken, label="meta-test") to ensure backslashes and metacharacters are
interpreted correctly; update the test invocation around
_make_gdn_decode_func(), func, and broken accordingly.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4aacac64-d181-4f0b-a266-998fd025caf7

📥 Commits

Reviewing files that changed from the base of the PR and between 2f4aceb and c2843a5.

📒 Files selected for processing (22)
  • flashinfer/api_logging.py
  • flashinfer/fi_trace.py
  • flashinfer/fused_moe/core.py
  • flashinfer/trace/templates/attention.py
  • flashinfer/trace/templates/moe.py
  • tests/trace/example.py
  • tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.json
  • tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json
  • tests/trace/fi_trace_out/moe_fp4_block_scale_default_routing_topk8_e32_h7168_i2048.json
  • tests/trace/fi_trace_out/moe_fp4_block_scale_ds_routing_topk8_e32_h7168_i2048_ng8_kg4.json
  • tests/trace/fi_trace_out/moe_fp4_block_scale_llama4_routing_topk1_e32_h7168_i2048.json
  • tests/trace/fi_trace_out/moe_fp4_block_scale_renormalize_naive_routing_topk8_e32_h7168_i2048.json
  • tests/trace/fi_trace_out/moe_fp4_block_scale_renormalize_routing_topk8_e32_h7168_i2048.json
  • tests/trace/fi_trace_out/moe_fp4_block_scale_topk_routing_topk8_e32_h7168_i2048.json
  • tests/trace/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.json
  • tests/trace/fi_trace_out/moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048.json
  • tests/trace/fi_trace_out/moe_fp8_block_scale_llama4_routing_topk1_e32_h7168_i2048.json
  • tests/trace/fi_trace_out/moe_fp8_block_scale_renormalize_naive_routing_topk8_e32_h7168_i2048.json
  • tests/trace/fi_trace_out/moe_fp8_block_scale_renormalize_routing_topk8_e32_h7168_i2048.json
  • tests/trace/fi_trace_out/moe_fp8_block_scale_topk_routing_topk8_e32_h7168_i2048.json
  • tests/trace/test_fi_trace.py
  • tests/trace/test_fi_trace_template_consistency.py
✅ Files skipped from review due to trivial changes (10)
  • tests/trace/fi_trace_out/moe_fp4_block_scale_ds_routing_topk8_e32_h7168_i2048_ng8_kg4.json
  • tests/trace/fi_trace_out/moe_fp8_block_scale_renormalize_routing_topk8_e32_h7168_i2048.json
  • tests/trace/fi_trace_out/moe_fp4_block_scale_renormalize_routing_topk8_e32_h7168_i2048.json
  • tests/trace/fi_trace_out/moe_fp4_block_scale_llama4_routing_topk1_e32_h7168_i2048.json
  • tests/trace/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.json
  • tests/trace/fi_trace_out/moe_fp8_block_scale_renormalize_naive_routing_topk8_e32_h7168_i2048.json
  • tests/trace/fi_trace_out/moe_fp8_block_scale_topk_routing_topk8_e32_h7168_i2048.json
  • tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.json
  • tests/trace/fi_trace_out/moe_fp4_block_scale_renormalize_naive_routing_topk8_e32_h7168_i2048.json
  • tests/trace/fi_trace_out/moe_fp4_block_scale_default_routing_topk8_e32_h7168_i2048.json
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/fused_moe/core.py

Comment thread flashinfer/api_logging.py Outdated
averyhNV and others added 7 commits April 21, 2026 01:41
- Add flashinfer/trace/templates/activation.py: silu_and_mul, gelu_and_mul,
  gelu_tanh_and_mul (used in FFN layers of LLaMA/Mistral/GPT-style models)
- Add flashinfer/trace/templates/cascade.py: merge_state, merge_state_in_place,
  merge_states (cascade/speculative attention state merging)
- Extend flashinfer/trace/templates/norm.py: rmsnorm_quant, fused_add_rmsnorm_quant,
  gemma_rmsnorm, gemma_fused_add_rmsnorm, layernorm (additional norm variants)
- Wire @flashinfer_api(trace=...) for all 11 new templates in activation.py,
  cascade.py, and norm/__init__.py
- Update example.py: add activation and cascade calls, update docstring to list
  all 39 expected output files (33 original + 6 new)
- Add tests/trace/fi_trace_out/ to .gitignore

AI-assisted

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Add example calls in tests/trace/example.py for rmsnorm_quant,
fused_add_rmsnorm_quant, gemma_rmsnorm, gemma_fused_add_rmsnorm,
layernorm, and gdn_prefill. Update docstring to list all 45 expected
JSON files.

Add "Trace Template Checklist" section to CLAUDE.md documenting the
steps for wiring trace to new APIs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Remove tests/trace/fi_trace_out/ from .gitignore so generated benchmark
  definition JSONs are committed alongside the code that produces them.
- Wrap mm_bf16 and mm_fp8 calls in contextlib.suppress so example.py runs
  end-to-end on SM90 (H100). mm_bf16 now uses backend="auto" (cudnn on
  SM<100, cutlass on SM100+); mm_fp8's low-latency GEMM is SM100-only at
  runtime but the trace still dumps before launch.
- Add newly-generated trace JSONs for the activation, cascade, norm-quant,
  gemma-norm, layernorm, and gdn-prefill APIs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
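The crash-safe pattern described above (trace dumps before launch, failures suppressed so example.py runs end-to-end) can be sketched as follows. The helper names here are hypothetical illustrations, not the actual flashinfer internals:

```python
import contextlib
import json
import tempfile
from pathlib import Path

def run_with_trace(fn, dump_dir, name, **kwargs):
    """Hypothetical sketch: write the trace definition BEFORE launching the
    kernel, then tolerate a runtime failure (e.g. an SM100-only GEMM on an
    SM90 machine). The JSON survives even if fn() raises."""
    Path(dump_dir).mkdir(parents=True, exist_ok=True)
    definition = {"name": name, "kwargs": sorted(kwargs)}  # CPU-side metadata only
    (Path(dump_dir) / f"{name}.json").write_text(json.dumps(definition))
    with contextlib.suppress(RuntimeError):  # kernel may be unsupported at runtime
        fn(**kwargs)

def unsupported_gemm(**kwargs):
    raise RuntimeError("low-latency GEMM requires SM100")

out_dir = tempfile.mkdtemp()
run_with_trace(unsupported_gemm, out_dir, "mm_fp8_demo", m=8, n=16)
# The trace file exists even though the "kernel" raised.
```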
… state update

Three sites had @flashinfer_api on a subclass or internal helper whose
parent/caller was already decorated, producing duplicate log entries at
higher FLASHINFER_LOGLEVEL values. Remove the redundant decorator:

- BatchAttentionWithAttentionSinkWrapper.__init__ (parent
  BatchPrefillWithPagedKVCacheWrapper.__init__ already decorated)
- CUDAGraphBatchDecodeWithPagedKVCacheWrapper.__init__ (parent
  BatchDecodeWithPagedKVCacheWrapper.__init__ already decorated)
- trtllm_low_latency_gemm (called internally by the already-decorated
  mm_fp8)

Also fix _gdn_mtp_reference in flashinfer/trace/templates/gdn.py: the
function was returning initial_state.clone() as final_state, silently
discarding every state update accumulated across the T tokens. Now
final_state is built once outside the batch loop and the [H,K,V]
scratch buffer is committed back to the pool slot as [H,V,K] after
each sequence. Regenerate gdn_mtp_qk4_v8_d128.json so the embedded
reference matches.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
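The final_state bug fixed here reduces to a common pattern: per-token updates are computed but the function returns a copy of the input state. A toy numpy stand-in (the real delta-rule update is more involved):

```python
import numpy as np

def mtp_reference_buggy(x, initial_state):
    # Bug pattern: the state update runs for every token, but the return
    # value is a clone of the INPUT state — all updates silently dropped.
    state = initial_state.copy()
    for t in range(x.shape[0]):
        state = state + x[t]          # toy stand-in for the delta-rule update
    return initial_state.copy()       # <-- discards the accumulated state

def mtp_reference_fixed(x, initial_state):
    state = initial_state.copy()
    for t in range(x.shape[0]):
        state = state + x[t]
    return state                      # commit the accumulated state

x = np.ones((4, 3), dtype=np.float32)
s0 = np.zeros(3, dtype=np.float32)
```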
Collaborator

@bkryu left a comment

Overall looks good to me. Left a comment about some (I believe) non-APIs that we may not want to decorate.

Also, not sure if you intended the PR to be exhaustive, but I think you missed:

  • cuDNN and TRTLLM attention variants
  • CUTLASS Fused MoE
  • Quantization APIs
  • RoPE APIs

Comment thread flashinfer/gemm/gemm_base.py Outdated
return graph


@flashinfer_api
Collaborator

Not sure if this is a FlashInfer API that we expect users to run. Do we need the label here?

Comment thread flashinfer/gemm/gemm_base.py Outdated
return graph


@flashinfer_api
Collaborator

Ditto

Comment thread flashinfer/gemm/gemm_base.py Outdated
return graph


@flashinfer_api
Collaborator

Ditto

Comment thread flashinfer/gemm/gemm_base.py Outdated
return graph


@flashinfer_api
Collaborator

Ditto

averyhNV and others added 6 commits April 21, 2026 18:44
Demonstrates that @flashinfer_api(trace=...) auto-dump is compatible
with torch.cuda.graph capture:

- Schema extraction reads only CPU-side metadata (.shape, .dtype) and
  writes JSON via host-thread file I/O — no CUDA stream ops, so nothing
  corrupts the captured graph even if a write fires inside the capture
  block.
- The _DUMPED_NAMES dedup in flashinfer/trace/template.py ensures at
  most one write per (process, trace name), so re-entering the decorated
  wrapper during capture is cheap.
- Graph replay does not execute Python, so auto-dump cannot fire on
  replay under any circumstance.

Example uses CUDAGraphBatchDecodeWithPagedKVCacheWrapper with
Llama-3.1-8B shapes, captures wrapper.run(), replays 5×, and verifies
numerical equivalence to eager.

fi_trace_out_cudagraph/ is gitignored — the single JSON it produces is
identical to the one committed under fi_trace_out/ for the same op.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Remove tests/trace/fi_trace_out_cudagraph/ from .gitignore and commit
the single JSON produced by example_cuda_graph.py so reviewers can
inspect the schema without running the example.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…n reference matmul

### B1 — GEMM references compute A @ B instead of A @ B.T
`_mm_reference` and the three quantized helpers in
flashinfer/trace/templates/gemm.py modeled `B` as physical `[K, N]` in
the template inputs but then computed `A @ B.T`, which is only valid
when `K == N`. This would crash for every non-square weight shape we
trace (e.g. 7168→256 in example.py). Drop the `.T` in all four refs
and update the three "C = A @ B.T" template descriptions.
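The shape argument behind B1 in one line of numpy (a stand-in for the reference matmuls, not the actual template code):

```python
import numpy as np

# B is stored physically as [K, N]; the reference must compute A @ B.
M, K, N = 4, 7168, 256          # non-square weight, as in the example.py shapes
A = np.random.randn(M, K).astype(np.float32)
B = np.random.randn(K, N).astype(np.float32)

C = A @ B                       # correct: [M, K] @ [K, N] -> [M, N]
assert C.shape == (M, N)

failed = False
try:
    A @ B.T                     # buggy form: [M, K] @ [N, K], valid only when K == N
except ValueError:
    failed = True
assert failed
```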

### B2 — paged GQA refs treat kv_indices as token IDs instead of page IDs
`_gqa_paged_decode_reference` and `_gqa_paged_prefill_reference`
flattened `k_cache` to `[num_tokens, ...]` and indexed with
`kv_indices`, which are page IDs. The lookup only gave correct tokens
when `page_size == 1`. Gather pages first, then reshape the gathered
`[num_selected_pages, page_size, ...]` into a single token axis.
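A minimal numpy sketch of the buggy lookup versus the page-gather fix (layout simplified; not the actual reference code):

```python
import numpy as np

num_pages, page_size, num_kv_heads, head_dim = 10, 4, 2, 8
k_cache = np.random.randn(num_pages, page_size, num_kv_heads, head_dim)
kv_indices = np.array([7, 2, 5])          # page IDs for one sequence

# Buggy lookup: flatten to a token axis and index with PAGE ids — the
# selected rows only coincide with the right tokens when page_size == 1.
flat = k_cache.reshape(num_pages * page_size, num_kv_heads, head_dim)
wrong = flat[kv_indices]

# Fix: gather whole pages first, then merge (page, slot) into one token axis.
gathered = k_cache[kv_indices]                         # [3, page_size, H, D]
tokens = gathered.reshape(-1, num_kv_heads, head_dim)  # [3 * page_size, H, D]
assert tokens.shape[0] == len(kv_indices) * page_size
```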

### B3 — MLA refs silently assumed page_size=1 via squeeze(1)
`_mla_paged_decode_reference` and `_mla_paged_prefill_reference` used
`ckv_cache.squeeze(1)` which is a no-op for page_size != 1, leaving a
3-D tensor that would break later matmuls. Apply the same page-gather
fix as B2 so both page_size=1 and page_size>1 MLA work.

Regenerate the 7 affected JSON fixtures and the cuda-graph example JSON
so their embedded reference strings reflect the fixes.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
… auto-dump diag

B4: _fp8_moe_run_experts in flashinfer/trace/templates/moe.py no longer
reads module-level H=7168/I=2048/BLOCK=128; those are derived from
hidden_states.shape and gemm1_weights.shape so the reference is valid
for any MoE shape, not just DeepSeek-V3.

B5: The five fp8 MoE routing references now accept top_k (and n_group
/ topk_group for DeepSeek-V3) as explicit parameters instead of
hardcoding TOP_K=8/N_GROUP=8/TOPK_GROUP=4. Corresponding Scalar
inputs are added to each template so external consumers of the trace
JSON pass the correct routing configuration.

B6: gdn_prefill_trace gains the head-ratio constraints
(num_v_heads >= num_q_heads, divisibility, num_k_heads == num_q_heads)
that its reference already assumes via repeat_interleave.

B7: GDN decode/prefill/MTP outputs now declare dtype="bfloat16" to
match the reference (the references always emit bfloat16, so the
previous dtype_from="q" was incorrect when q was fp16 or fp32).

B9: scale Scalar is marked optional=True in all three GDN templates
(decode/prefill/MTP). The reference already handles scale=None.

B10: Drop the "Unchanged if disable_state_update=True" phrase from
gdn_mtp_trace.final_state — disable_state_update is a real kwarg on
gated_delta_rule_mtp but not modelled as an input on the template, so
referencing it in the description was misleading.

B8: tests/trace/test_fi_trace_template_consistency.py E2E synthesizer
uses per-key positive defaults for int32 scalars
(block_size=16, top_k=1, n_group=1, topk_group=1, ...) instead of 0,
so synthesized definitions are semantically valid.

B11: _auto_dump_wrapper in flashinfer/api_logging.py now emits a
warnings.warn() when schema binding or trace file write fails, deduped
per (API name, error class). Users previously saw missing JSON files
with no explanation.
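The B11 dedup can be sketched as a module-level seen-set keyed on (API name, error class); the names below are illustrative, not the actual api_logging.py internals:

```python
import warnings

_WARNED = set()  # hypothetical dedup key: (API name, error class)

def warn_trace_failure(api_name, err):
    """Warn once per (API, error class) so a hot loop that fails the
    same way on every call does not spam the user."""
    key = (api_name, type(err))
    if key in _WARNED:
        return
    _WARNED.add(key)
    warnings.warn(f"fi_trace dump failed for {api_name}: {err!r}")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    for _ in range(3):
        warn_trace_failure("mm_fp8", OSError("read-only dump dir"))
# Only the first failure produced a warning; the repeats were deduped.
```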

Regenerate the 6 MoE JSON fixtures + GDN decode/prefill/MTP fixtures so
the embedded reference strings and input schemas match.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…e helpers

Per bkryu's review on PR flashinfer-ai#2931: the four
execute_cudnn_gemm_*_graph_override_shape functions in
flashinfer/gemm/gemm_base.py are internal helpers called from the
already-decorated mm_fp4 / mm_mxfp8 / mm_fp8 / mm_bf16 user APIs.
Decorating them too causes double log entries at
FLASHINFER_LOGLEVEL>=1 (same pattern fixed earlier for
trtllm_low_latency_gemm and the CUDAGraph wrapper __init__).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Per bkryu's review on PR flashinfer-ai#2931: several user-facing APIs were decorated
with @flashinfer_api but had no trace template attached. This commit
wires trace templates to RoPE and quantization.

RoPE (flashinfer/trace/templates/rope.py, 10 new templates):
  - apply_rope / apply_rope_inplace
  - apply_rope_pos_ids / apply_rope_pos_ids_inplace
  - apply_llama31_rope / apply_llama31_rope_inplace
  - apply_llama31_rope_pos_ids / apply_llama31_rope_pos_ids_inplace
  - apply_rope_with_cos_sin_cache / apply_rope_with_cos_sin_cache_inplace

Quantization (flashinfer/trace/templates/quantize.py, 4 new templates):
  - fp4_quantize, nvfp4_quantize, mxfp4_quantize, mxfp8_quantize

Follow-ups (not addressed in this commit): cuDNN/TRTLLM attention
variants (single_prefill/single_decode, cudnn_batch_*, trtllm_batch_*)
and MoE variants (cutlass_fused_moe, trtllm_bf16_moe, etc.) still need
templates.

Add example calls for RoPE and quantization in tests/trace/example.py
and commit the 14 regenerated JSON fixtures.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Contributor

@Ubospica left a comment

LGTM

averyhNV and others added 7 commits April 22, 2026 00:12
…fine attention template descriptions

Addresses bkryu's top-level review on PR flashinfer-ai#2931 listing missing trace
templates, and responds to follow-up feedback that attention
descriptions were redundant.

New templates (13):
  attention.py: single_decode_with_kv_cache_trace,
    single_prefill_with_kv_cache_trace,
    trtllm_batch_decode_trace, trtllm_batch_context_trace,
    cudnn_batch_decode_trace, cudnn_batch_prefill_trace
  moe.py: cutlass_fused_moe_trace, trtllm_bf16_moe_trace,
    trtllm_bf16_routed_moe_trace, trtllm_fp8_per_tensor_scale_moe_trace,
    trtllm_fp8_block_scale_routed_moe_trace,
    trtllm_fp4_block_scale_routed_moe_trace,
    trtllm_mxint4_block_scale_moe_trace

Wire-ups:
  flashinfer/decode.py: single_decode_with_kv_cache,
    trtllm_batch_decode_with_kv_cache
  flashinfer/prefill.py: single_prefill_with_kv_cache,
    trtllm_batch_context_with_kv_cache
  flashinfer/cudnn/decode.py: cudnn_batch_decode_with_kv_cache
  flashinfer/cudnn/prefill.py: cudnn_batch_prefill_with_kv_cache
  flashinfer/fused_moe/core.py: 7 MoE variants

Attention description polish (flashinfer/trace/templates/attention.py):
  Replaced verbose cross-referencing paragraphs with one- or two-
  sentence identifiers that state (a) the API wrapped, (b) one or two
  distinctive structural features. Added a module-level comparison
  table as the single source of truth for how templates differ. The
  table lists each template's batching, KV layout, indexing mechanism,
  stage, and backend, so consumers can pick the right template without
  parsing every description.

Also add per-key positive int32 defaults in the E2E synthesizer for
num_experts, intermediate_size, hidden_size (in addition to the
earlier block_size/top_k/n_group/topk_group defaults) and introduce
_TRTLLM_MOE_ROUTED_AXES so routed-variant templates mark num_experts
and intermediate_size as Var (they arrive as scalar kwargs when
topk_ids is pre-computed, so the routing_logits shape can't resolve
them).

Tests: 220 passed (was 139 before the whole review cycle).
Regenerate affected JSON fixtures so their embedded descriptions and
schemas match.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The six trtllm_fp4_block_scale_moe_*_routing_trace templates previously
had reference=None. This commit adds executable reference functions
modelled after tests/moe/test_trtllm_gen_fused_moe.py::run_moe_dequant,
so external consumers (flashinfer-bench) can verify kernel output
against the reference.

Helpers added to flashinfer/trace/templates/moe.py:
  - _unpack_fp4_e2m1: 16-entry LUT-based unpack of uint8-packed
    e2m1fn FP4 values into float32 (sign + exponent + mantissa), so
    the returned tensor has twice the packed last dim.
  - _ue8m0_to_float32: decode UE8M0 (MX-format) scales.
  - _decode_block_scales: dispatches UE8M0 vs fp8_e4m3fn based on the
    scale dtype.
  - _dequantize_fp4_tensor: unpack + apply per-block scales to a
    packed FP4 tensor. Block size is inferred from the shape ratio so
    NvFP4 (block_size=16) and MXFP4 (block_size=32) both work.
  - _dequantize_fp4_hidden_states: handles the three activation
    formats the kernel accepts — bfloat16, float8_e4m3fn (MXFP8) with
    UE8M0 per-32 scales, and uint8-packed FP4.
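  The LUT unpack and UE8M0 decode can be sketched in numpy (nibble packing
  order is an assumption here; the committed helpers are the source of truth):

```python
import numpy as np

# 16-entry LUT for e2m1fn nibbles: bit 3 = sign, bits 2..1 = exponent, bit 0 = mantissa.
_E2M1_LUT = np.array(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=np.float32,
)

def unpack_fp4_e2m1(packed):
    """Unpack uint8 bytes holding two FP4 values each (low nibble first is
    assumed) into float32 with twice the packed last dim."""
    lo = _E2M1_LUT[packed & 0xF]
    hi = _E2M1_LUT[packed >> 4]
    return np.stack([lo, hi], axis=-1).reshape(*packed.shape[:-1], -1)

def ue8m0_to_float32(scales):
    """UE8M0 scales are pure biased exponents: value = 2 ** (x - 127)."""
    return np.float32(2.0) ** (scales.astype(np.int32) - 127)

assert unpack_fp4_e2m1(np.array([0x07], dtype=np.uint8))[0] == 6.0   # low nibble 0x7
assert unpack_fp4_e2m1(np.array([0xF0], dtype=np.uint8))[1] == -6.0  # high nibble 0xF
assert ue8m0_to_float32(np.array([127], dtype=np.uint8))[0] == 1.0
```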

Shared MoE kernel (_fp4_moe_run_experts): dequantizes weights and
hidden states, gathers per-expert tokens, does GEMM1 → SwiGLU
(silu(X2) * X1 to match trtllm-gen's convention) → GEMM2, applies
optional biases, and combines per-expert contributions weighted by
the routing weights. Emits bfloat16 output to match the template
schema.
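The silu(X2) * X1 convention noted above can be sketched as (which half carries the gate is the whole point of the convention; the split order below follows the commit text, not verified against the kernel):

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def swiglu_trtllm(gemm1_out):
    """Split the GEMM1 output into [X1, X2] halves along the last dim and
    compute silu(X2) * X1 — per the commit, trtllm-gen gates the SECOND
    half, whereas many references compute silu(X1) * X2."""
    x1, x2 = np.split(gemm1_out, 2, axis=-1)
    return silu(x2) * x1

y = swiglu_trtllm(np.random.randn(4, 2 * 8).astype(np.float32))
assert y.shape == (4, 8)   # last dim halves after the gated product
```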

Per-routing references (6, one per RoutingMethodType.{Default,
Renormalize, DeepSeekV3, Llama4, RenormalizeNaive, TopK}) compute
their own topk_idx + weights and call _fp4_moe_run_experts. DS
routing replicates the sigmoid → group-top2 → topk_group → top_k
path used in DeepSeek-V3.

Verified all six paths produce finite bfloat16 output of the expected
shape on NvFP4 hidden states (uint8 packed + fp8_e4m3fn scales),
MXFP8 hidden states (float8_e4m3fn + UE8M0 scales), and bf16
hidden states. Also verified the E2M1 LUT: nibble 0x7 → 6.0,
0xF → -6.0, etc.

Regenerate all six FP4 MoE JSON fixtures so they embed the new
reference source (previously absent).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…rectness tests

Before this commit: 14 templates had reference=None. Now: every template
has an executable reference, and each reference is verified numerically
against its corresponding flashinfer API in
tests/trace/test_reference_correctness.py.

Templates with new references (per file):

  flashinfer/trace/templates/rope.py (10):
    apply_rope, apply_rope_inplace, apply_rope_pos_ids,
    apply_rope_pos_ids_inplace, apply_llama31_rope,
    apply_llama31_rope_inplace, apply_llama31_rope_pos_ids,
    apply_llama31_rope_pos_ids_inplace,
    apply_rope_with_cos_sin_cache, apply_rope_with_cos_sin_cache_inplace
    Helpers: _rope_freqs, _llama31_freqs (piecewise NTK scaling),
    _rotate, _positions_from_indptr, _apply_rope_core.
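    A minimal "rotate half" RoPE sketch in the spirit of these helpers
    (layout and interleaving are assumptions, not flashinfer's exact
    convention):

```python
import numpy as np

def rope_freqs(head_dim, theta=10000.0):
    # Standard RoPE inverse frequencies: theta ** (-2i / d) for i in [0, d/2).
    return theta ** (-np.arange(0, head_dim, 2, dtype=np.float32) / head_dim)

def apply_rope(x, pos_ids, theta=10000.0):
    """Apply non-interleaved rotary embedding to x of shape [T, head_dim]."""
    d = x.shape[-1]
    ang = pos_ids[:, None].astype(np.float32) * rope_freqs(d, theta)  # [T, d/2]
    cos = np.concatenate([np.cos(ang), np.cos(ang)], axis=-1)
    sin = np.concatenate([np.sin(ang), np.sin(ang)], axis=-1)
    x1, x2 = x[..., : d // 2], x[..., d // 2:]
    rotated = np.concatenate([-x2, x1], axis=-1)
    return x * cos + rotated * sin

q = np.random.randn(5, 64).astype(np.float32)
out0 = apply_rope(q, np.zeros(5, dtype=np.int64))  # position 0: identity rotation
assert np.allclose(out0, q, atol=1e-6)
```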

  flashinfer/trace/templates/norm.py (2):
    rmsnorm_quant, fused_add_rmsnorm_quant (RMSNorm + per-tensor
    FP8 quantize; returns fp8_e4m3fn + optional updated residual).

  flashinfer/trace/templates/cascade.py (1):
    merge_state_in_place (LSE-weighted merge with optional mask).
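    The LSE-weighted merge semantics (without the in-place write or mask)
    can be sketched as:

```python
import numpy as np

def merge_state(v_a, lse_a, v_b, lse_b):
    """Numerically stable merge of two partial attention outputs with
    their log-sum-exp weights — a sketch of the cascade merge math."""
    lse = np.logaddexp(lse_a, lse_b)            # log(e^lse_a + e^lse_b)
    w_a = np.exp(lse_a - lse)[..., None]
    w_b = np.exp(lse_b - lse)[..., None]
    return v_a * w_a + v_b * w_b, lse

v = np.random.randn(3, 8).astype(np.float32)
lse = np.random.randn(3).astype(np.float32)
merged_v, merged_lse = merge_state(v, lse, v, lse)
assert np.allclose(merged_v, v, atol=1e-6)                  # identical halves: value unchanged
assert np.allclose(merged_lse, lse + np.log(2), atol=1e-6)  # log-sum doubles
```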

  flashinfer/trace/templates/quantize.py (4):
    fp4_quantize, nvfp4_quantize, mxfp4_quantize, mxfp8_quantize.
    E2M1 nearest-magnitude rounding, UE8M0 vs fp8_e4m3fn scale
    decoding, NvFP4 block_size=16 vs MXFP4/MXFP8 block_size=32.

  flashinfer/trace/templates/attention.py (6):
    single_decode, single_prefill (contiguous KV SDPA with causal),
    trtllm_batch_decode, trtllm_batch_context (rectangular block_tables
    + interleaved kv_cache + bmm1/bmm2 scales),
    cudnn_batch_decode, cudnn_batch_prefill (separate k/v caches,
    actual_seq_lens_q/kv, optional LSE return).
    Helpers: _trtllm_kv_from_cache,
    _trtllm_paged_attention_reference.

  flashinfer/trace/templates/moe.py (7):
    cutlass_fused_moe (precomputed expert ids + scales),
    trtllm_bf16_moe, trtllm_bf16_routed_moe (un-quantized),
    trtllm_fp8_per_tensor_scale_moe (per-expert scalar scales),
    trtllm_fp8_block_scale_routed_moe,
    trtllm_fp4_block_scale_routed_moe (reuses _fp8_moe_run_experts
    / _fp4_moe_run_experts),
    trtllm_mxint4_block_scale_moe (int4 unpack + bf16 scales).

Correctness tests (tests/trace/test_reference_correctness.py):
  18 numerical tests compare reference output to the live flashinfer
  API on the same inputs, within per-dtype tolerances:
    - RoPE (10): bf16 output within 5e-2 of kernel (1 bf16 ULP)
    - rmsnorm_quant, fused_add_rmsnorm_quant: residual exact; fp8
      output compared after multiplying by scale
    - merge_state_in_place: bf16/float32 within 5e-3
    - mxfp8_quantize, fp4_quantize round-trip: within 50% relative
      error (FP4 has inherent quantization error)
    - single_decode, single_prefill (causal): within 5e-2
  5 tests are marked skipped with clear reasons (cuDNN/TRT-LLM
  kernels require specific runtime/hardware; those references are
  covered by the shape-and-finite smoke test
  test_moe_references_produce_valid_outputs).

Also regenerate every trace JSON under tests/trace/fi_trace_out/ so
the new reference source strings are embedded in the committed
fixtures.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
basedpyright flagged ~45 "parameter not accessed" hints in the new MoE
references. The unused params are intentional — the references accept
the full API signature so external consumers can call them with the
same kwargs they'd pass to the corresponding flashinfer API. Add
explicit ``del`` statements at the top of each reference to document
that the params are accepted for API parity but unused in the
reference computation, silencing the hints.

Affects the 7 references added in the previous commit:
  _cutlass_fused_moe_reference, _trtllm_bf16_moe_reference,
  _trtllm_bf16_routed_moe_reference,
  _trtllm_fp8_per_tensor_scale_moe_reference,
  _trtllm_fp8_block_scale_routed_moe_reference,
  _trtllm_fp4_block_scale_routed_moe_reference,
  _trtllm_mxint4_block_scale_moe_reference

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ignatures

Previous commit added ``del unused_a, unused_b, ...`` at the top of each of
the 7 new MoE references to silence basedpyright's
``reportUnusedParameter`` hints. That was noisy boilerplate.

The more Pythonic fix is to (a) drop parameters that neither the template
inputs schema nor the reference body reference, and (b) rename the
catch-all ``**kwargs`` to ``**_unused`` — the ``_`` prefix is the standard
convention that tells linters "intentionally unused." External callers can
still pass any extra API kwargs by keyword; they land in ``**_unused`` and
are silently discarded.

Net effect per reference:
  _cutlass_fused_moe_reference:
    drop output_dtype, quant_scales (kept via **_unused)
  _trtllm_bf16_moe_reference:
    drop n_group, topk_group, intermediate_size, local_num_experts,
    routing_method_type (kept via **_unused)
  _trtllm_bf16_routed_moe_reference:
    drop n_group, topk_group, intermediate_size, local_num_experts
  _trtllm_fp8_per_tensor_scale_moe_reference:
    drop n_group, topk_group, intermediate_size, local_num_experts,
    routing_method_type
  _trtllm_fp8_block_scale_routed_moe_reference:
    drop routing_bias, n_group, topk_group, intermediate_size,
    local_num_experts
  _trtllm_fp4_block_scale_routed_moe_reference:
    drop routing_bias, gemm1_alpha/beta/clamp_limit, output1_scale_scalar,
    output1_scale_gate_scalar, output2_scale_scalar, n_group, topk_group,
    intermediate_size, local_num_experts
  _trtllm_mxint4_block_scale_moe_reference:
    drop gemm1_alpha/beta/clamp_limit, n_group, topk_group,
    intermediate_size, local_num_experts, routing_method_type

Net diff: +51 / -75 — shorter, self-documenting signatures with no ``del``
boilerplate, and basedpyright is quiet. Test
``test_moe_references_produce_valid_outputs`` updated to call the MxInt4
reference with all-kwargs so it doesn't rely on positional ordering.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>