feat: add get flashinfer-trace interface .fi_trace #2931
yyihuang wants to merge 29 commits into flashinfer-ai:main
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough
This change adds a TraceTemplate-based tracing system with fi_trace generation and registration, and attaches trace templates to public APIs via the `@flashinfer_api` decorator.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant API as flashinfer_api wrapper
    participant Template as TraceTemplate / dispatcher
    participant FiTrace as fi_trace builder
    participant FS as Filesystem
    Client->>API: call decorated function (possibly trace=callable)
    API->>Template: resolve trace_template (dispatch or static)
    Template-->>FiTrace: build fi_trace_fn (bind reference, axes, inputs/outputs)
    API->>FiTrace: if tracing enabled -> invoke fi_trace_fn(**bound_args)
    FiTrace->>FS: write <name>.json (if save_dir or env enabled)
    FiTrace-->>API: return trace dict
    API-->>Client: execute original function and return result
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Code Review
This pull request adds the @flashinfer_api decorator to multiple classes and functions across the library, including attention and decode wrappers as well as GEMM execution utilities, to enable API logging. The review feedback points out that applying this decorator to subclasses whose base classes are already decorated results in redundant log entries. Additionally, nested calls between decorated functions may lead to duplicate logging, suggesting that the logging logic should handle re-entrancy or that certain decorators should be removed to reduce overhead.
```python
        a convenient interface for using attention sinks during prefill or decode attention.
        """


@flashinfer_api
```
Adding @flashinfer_api to BatchAttentionWithAttentionSinkWrapper.__init__ will result in double logging during initialization. This class inherits from BatchPrefillWithPagedKVCacheWrapper, whose __init__ method is already decorated with @flashinfer_api. Since the decorator uses the class name of the instance (args[0]), both the subclass and base class decorators will log an entry for BatchAttentionWithAttentionSinkWrapper.__init__. This redundancy clutters the logs and adds unnecessary overhead. Consider removing the decorator from the subclass if the base class logging is sufficient for your tracing needs.
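The mechanism is easy to reproduce with a toy stand-in for the decorator (names here are illustrative, not FlashInfer's actual implementation): because the class name is resolved from the runtime instance, the inherited decorated `__init__` logs under the subclass name too.

```python
import functools

LOG = []

def toy_api_log(func):
    """Hypothetical stand-in for @flashinfer_api: logs the *runtime* class name."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Resolves the class from the instance (args[0]), as described in the
        # review -- not from the class where the method was defined.
        LOG.append(f"{type(args[0]).__name__}.{func.__name__}")
        return func(*args, **kwargs)
    return wrapper

class Base:
    @toy_api_log
    def __init__(self):
        pass

class Sub(Base):
    @toy_api_log
    def __init__(self):
        super().__init__()

Sub()
print(LOG)  # ['Sub.__init__', 'Sub.__init__'] -- the same entry twice
```

Both the subclass wrapper and the inherited base wrapper see the same `Sub` instance, so the single construction produces two identical log entries.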
```python
    :class:`BatchDecodeWithPagedKVCacheWrapper`
    """


@flashinfer_api
```
Similar to the issue in BatchAttentionWithAttentionSinkWrapper, decorating CUDAGraphBatchDecodeWithPagedKVCacheWrapper.__init__ leads to redundant log entries because its base class BatchDecodeWithPagedKVCacheWrapper.__init__ is already decorated. Both will log as CUDAGraphBatchDecodeWithPagedKVCacheWrapper.__init__ due to how the decorator resolves the class name from the instance.
```python
)


@flashinfer_api
```
Decorating trtllm_low_latency_gemm will cause double logging when it is called internally by other decorated APIs, such as mm_fp8 in flashinfer/gemm/gemm_base.py. While it is important to trace this function when called directly, the current logging implementation will produce redundant entries for nested calls. This should ideally be addressed in the logging decorator's logic to handle re-entrancy, but for now, be aware of the log duplication.
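One common way to handle such re-entrancy is a depth counter in thread-local state, so only the outermost decorated call logs. The sketch below uses hypothetical stub functions, not the real FlashInfer decorator.

```python
import functools
import threading

LOG = []
_depth = threading.local()

def toy_api_log(func):
    """Hypothetical re-entrancy-aware logger; the real flashinfer_api differs."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        depth = getattr(_depth, "value", 0)
        if depth == 0:  # only log the outermost decorated call
            LOG.append(func.__name__)
        _depth.value = depth + 1
        try:
            return func(*args, **kwargs)
        finally:
            _depth.value = depth  # restore on the way out
    return wrapper

@toy_api_log
def trtllm_low_latency_gemm_stub():
    return "gemm"

@toy_api_log
def mm_fp8_stub():
    # nested call into another decorated API, as in mm_fp8 -> trtllm_low_latency_gemm
    return trtllm_low_latency_gemm_stub()

mm_fp8_stub()                   # nested call path: logged once, at the top level
trtllm_low_latency_gemm_stub()  # direct call path: still logged
print(LOG)  # ['mm_fp8_stub', 'trtllm_low_latency_gemm_stub']
```

With this guard, the inner function is still traced when called directly, but nested invocations no longer produce duplicate entries.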
Actionable comments posted: 2
Note: Due to the large number of review comments, Critical severity comments were prioritized as inline comments.
Caution: Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/gdn_prefill.py (1)
86-100: 🛠️ Refactor suggestion | 🟠 Major — Add backend capability gating on this SM90-only API.
`chunk_gated_delta_rule` documents an SM90 requirement but is not decorated with `@backend_requirement`. Please add the backend/capability gate alongside `@flashinfer_api(...)` so unsupported devices fail fast with a clear message. As per coding guidelines: use the `@backend_requirement` decorator on APIs that have compute capability requirements and provide `is_compute_capability_supported(cc)` and `is_backend_supported()` methods.
flashinfer/trtllm_low_latency_gemm.py (1)
119-125: 🛠️ Refactor suggestion | 🟠 Major — Add `@backend_requirement` for this Blackwell-only entrypoint.
`trtllm_low_latency_gemm` is documented as Blackwell-only, but the API is not gated with `@backend_requirement`. Please add the explicit capability/backend guard so callers get deterministic early validation. As per coding guidelines: use the `@backend_requirement` decorator on APIs that have compute capability requirements and provide `is_compute_capability_supported(cc)` and `is_backend_supported()` methods.
🟠 Major comments (23)
flashinfer/trace/example/fi_trace_out/gemm_mxfp8_N4096_K4096.json-32-38 (1)
32-38: ⚠️ Potential issue | 🟠 Major — GEMM reference uses an incompatible transpose with declared shapes.
With `A: [M, K]` and `B: [K, N]` (lines 32–38), line 66 should compute `A @ B`, not `A @ B.T`. The current reference is dimensionally inconsistent.

Suggested fix:

```diff
-    return torch.matmul(A_scaled, B_scaled.T).to(torch.bfloat16)\n"
+    return torch.matmul(A_scaled, B_scaled).to(torch.bfloat16)\n"
```

Also applies to: 66-66
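The shape argument can be checked mechanically; NumPy's `matmul` follows the same inner-dimension rule as `torch.matmul`, so a small sketch (illustrative shapes, not the trace's) makes the inconsistency concrete.

```python
import numpy as np

M, K, N = 4, 8, 6
A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(K, N).astype(np.float32)

C = A @ B            # declared contract: [M, K] @ [K, N] -> [M, N]
assert C.shape == (M, N)

# The transposed variant only type-checks when K == N; with K != N it is a
# hard shape error, which is the reviewer's point.
try:
    A @ B.T          # [M, K] @ [N, K] -- inner dims K vs N mismatch
    raise AssertionError("unreachable when K != N")
except ValueError:
    print("A @ B.T raises: inner dimensions mismatch")
```

The square `N == K == 4096` case in this artifact hides the error at runtime, which is why the declared shapes, not a sample run, are the right thing to validate against.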
flashinfer/trace/example/fi_trace_out/top_k_sampling_v128256.json-38-43 (1)
38-43: ⚠️ Potential issue | 🟠 Major — Output dtype is inconsistent with the traced API contract.
Line 42 declares `samples` as `int64`, but this API path returns `int32` by default (when `indices` is not provided). The reference in line 46 also allocates `int64`, so both schema and reference are misaligned.

Suggested fix:

```diff
-        "dtype": "int64",
+        "dtype": "int32",
```

```diff
-    samples = torch.empty(batch_size, dtype=torch.int64, device=device)\n
+    samples = torch.empty(batch_size, dtype=torch.int32, device=device)\n
```

Also applies to: 46-46
flashinfer/trace/example/fi_trace_out/fused_add_rmsnorm_h5120.json-49-58 (1)
49-58: ⚠️ Potential issue | 🟠 Major — Reference return signature does not match declared outputs.
Lines 49–56 declare two outputs (`output`, `residual`), but line 58's reference returns only one tensor. This makes the trace definition internally inconsistent for validators/consumers.

Suggested fix:

```diff
-    return y.to(hidden_states.dtype)\n"
+    return y.to(hidden_states.dtype), x.to(hidden_states.dtype)\n"
```
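A minimal NumPy sketch of a two-output fused add + RMSNorm reference (illustrative only; the actual reference uses torch and the template's exact dtypes) shows the shape of the fix: the updated residual is a first-class output, not a side effect.

```python
import numpy as np

def fused_add_rmsnorm_ref(hidden_states, residual, weight, eps=1e-6):
    """Sketch of a two-output fused add + RMSNorm reference (not FlashInfer's)."""
    # The fused op first adds hidden_states into the residual stream...
    x = residual.astype(np.float32) + hidden_states.astype(np.float32)
    # ...then RMS-normalizes the sum.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    y = (x / rms) * weight.astype(np.float32)
    # Return the normalized output AND the updated residual, matching a
    # two-output schema ("output", "residual").
    return y.astype(hidden_states.dtype), x.astype(hidden_states.dtype)

h = np.ones((2, 8), dtype=np.float32)
r = np.ones((2, 8), dtype=np.float32)
w = np.ones(8, dtype=np.float32)
out, new_residual = fused_add_rmsnorm_ref(h, r, w)
assert out.shape == new_residual.shape == (2, 8)
```

A consumer validating the trace can then compare both declared outputs against the reference, instead of silently skipping `residual`.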
flashinfer/trace/example/fi_trace_out/gemm_bf16_N256_K7168.json-30-36 (1)
30-36: ⚠️ Potential issue | 🟠 Major — BF16 GEMM reference is inconsistent with the input shape declaration.
Given `B` shape `[K, N]` (lines 30–36), line 48 should not transpose `B` for the matmul. The current expression conflicts with the stated tensor contract.

Suggested fix:

```diff
-    "reference": "def _mm_reference(A, B):\n    return torch.matmul(A, B.T)\n"
+    "reference": "def _mm_reference(A, B):\n    return torch.matmul(A, B)\n"
```

Also applies to: 48-48
flashinfer/trace/example/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.json-29-43 (1)
29-43: ⚠️ Potential issue | 🟠 Major — FP4 GEMM schema and reference are shape-inconsistent.
Lines 29–43 declare unpacked shapes (`[M, K]`, `[K, N]`) while line 71 treats `A`/`B` as packed bytes and reconstructs logical dims by multiplying by 2. On top of that, the final `B_scaled.T` introduces another dimension mismatch. Please make schema and reference consistent in one direction:
- keep packed semantics and declare packed shapes, or
- keep unpacked shapes and remove the nibble-unpack logic.

Also, the final GEMM should not transpose `B_scaled` under the current shape declarations.

Also applies to: 71-71
flashinfer/gdn_decode.py-349-350 (1)
349-350: 🛠️ Refactor suggestion | 🟠 Major — Add `@backend_requirement` on these SM-constrained public APIs.
At line 349 and line 490, these APIs are decorated for tracing but still lack explicit backend capability guards at the API boundary. As per coding guidelines: use the `@backend_requirement` decorator on APIs that have compute capability requirements and provide `is_compute_capability_supported(cc)` and `is_backend_supported()` methods.

Also applies to: 490-491
flashinfer/gdn_decode.py-36-53 (1)
36-53: ⚠️ Potential issue | 🟠 Major — Decouple the trace-template import from the `flashinfer_api` fallback.
If the trace template import fails but `flashinfer_api` is available, the current combined `try` block still falls back to a no-op decorator, silently disabling API logging/tracing behavior.

♻️ Suggested fix:

```diff
-try:
-    from .api_logging import flashinfer_api
-    from .trace.templates.gdn import (
-        gated_delta_rule_decode_trace,
-        gdn_mtp_trace,
-    )
-    _FLASHINFER_AVAILABLE = True
-except ImportError:
-    _FLASHINFER_AVAILABLE = False
-    gated_delta_rule_decode_trace = None  # type: ignore[assignment]
-    gdn_mtp_trace = None  # type: ignore[assignment]
-
-    # Fallback decorator for standalone usage (accepts trace= kwarg)
-    def flashinfer_api(func=None, *, trace=None):  # type: ignore[misc]
-        if func is None:
-            return lambda f: f
-        return func
+try:
+    from .api_logging import flashinfer_api
+    _FLASHINFER_AVAILABLE = True
+except ImportError:
+    _FLASHINFER_AVAILABLE = False
+
+    def flashinfer_api(func=None, *, trace=None):  # type: ignore[misc]
+        if func is None:
+            return lambda f: f
+        return func
+
+try:
+    from .trace.templates.gdn import (
+        gated_delta_rule_decode_trace,
+        gdn_mtp_trace,
+    )
+except ImportError:
+    gated_delta_rule_decode_trace = None  # type: ignore[assignment]
+    gdn_mtp_trace = None  # type: ignore[assignment]
```
flashinfer/trace/templates/sampling.py-24-41 (1)
24-41: ⚠️ Potential issue | 🟠 Major — The sampling references are not reproducible as written.
These references call `torch.multinomial`, but the template schema does not carry any RNG input, seed, or pre-generated random variate. The same trace payload can therefore emit different `samples` across runs, which makes the generated definitions unstable as reference artifacts. Please encode the randomness in the trace inputs or make the reference deterministic.

Also applies to: 79-103, 141-173
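One way to make such a reference deterministic is inverse-CDF sampling driven by an explicit uniform variate carried in the trace inputs. The sketch below uses NumPy and a hypothetical per-row `u` field, not the template's actual schema.

```python
import numpy as np

def top_k_sampling_ref(probs, top_k, u):
    """Deterministic sketch: randomness enters only via `u`, a per-row uniform
    variate that a trace schema could carry (hypothetical field name)."""
    batch, vocab = probs.shape
    samples = np.empty(batch, dtype=np.int32)
    for b in range(batch):
        p = probs[b].copy()
        kth = np.sort(p)[-top_k]          # keep only the top-k probabilities
        p[p < kth] = 0.0
        p /= p.sum()
        cdf = np.cumsum(p)
        samples[b] = np.searchsorted(cdf, u[b])  # inverse-CDF draw
    return samples

probs = np.array([[0.1, 0.2, 0.3, 0.4],
                  [0.5, 0.25, 0.15, 0.1]], dtype=np.float64)
u = np.array([0.5, 0.5])
s1 = top_k_sampling_ref(probs, top_k=2, u=u)
s2 = top_k_sampling_ref(probs, top_k=2, u=u)
assert np.array_equal(s1, s2)   # same trace inputs -> same samples, every run
```

Because `u` is part of the payload, two runs of the reference on the same trace always agree, which is what a stored artifact needs.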
flashinfer/api_logging.py-1497-1503 (1)
1497-1503: ⚠️ Potential issue | 🟠 Major — Don’t silently disable `.fi_trace` on attachment errors.
These `except Exception: pass` blocks turn template/build failures into invisible feature loss: a broken trace template can quietly remove `.fi_trace`, and `FLASHINFER_TRACE_DUMP=1` can fail to write anything without surfacing why. Please preserve the failure in a stub `fi_trace` or emit a warning instead of dropping it on the floor.

Also applies to: 1516-1517
flashinfer/trace/example/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.json-112-112 (1)
112-112: ⚠️ Potential issue | 🟠 Major — The decode reference mixes page IDs with token indices.
Line 112 declares `kv_indices` as page IDs, but the reference flattens the cache to `[num_pages * page_size, ...]` and indexes that flattened tensor directly with those IDs. With `page_size=64`, each selected page contributes only its first token to `k_b`/`v_b`, so the logits and outputs are wrong. Please fix the upstream template to index pages first, then flatten to tokens, and regenerate this artifact.
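The difference between the two indexing orders is easy to see in a small NumPy sketch (shapes chosen for illustration, much smaller than the artifact's):

```python
import numpy as np

num_pages, page_size, d = 5, 4, 2
k_cache = np.arange(num_pages * page_size * d, dtype=np.float32).reshape(
    num_pages, page_size, d)          # [num_pages, page_size, d]
kv_indices = np.array([3, 1])         # page IDs for one request

# Wrong (what the artifact does): flatten first, then index with page IDs --
# this grabs single *tokens* at rows 3 and 1, not pages 3 and 1.
wrong = k_cache.reshape(-1, d)[kv_indices]
assert wrong.shape == (2, d)

# Right: gather whole pages first, then flatten the page/token axes.
right = k_cache[kv_indices].reshape(-1, d)
assert right.shape == (2 * page_size, d)   # every token of each page survives
```

The wrong order silently drops `page_size - 1` tokens per page, which is exactly the failure mode described above.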
flashinfer/trace/example/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.json-119-119 (1)
119-119: ⚠️ Potential issue | 🟠 Major — Schema/reference mismatch for `kv_indices`.
Line 119 says `kv_indices` are page IDs, but the reference indexes `k_cache.reshape(-1, ...)`/`v_cache.reshape(-1, ...)` with those IDs and sets `num_kv_tokens` from the number of pages. With `page_size=16`, that drops `page_size - 1` tokens from every selected page, so the causal window and outputs are wrong for paged inputs. Please fix the source template to gather pages first, then flatten their token dimension, and regenerate this artifact.
flashinfer/trace/example/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json-123-123 (1)
123-123: ⚠️ Potential issue | 🟠 Major — This ps64 reference still assumes one token per page.
Line 123 indexes `kv_indices` into `ckv_cache`/`kpe_cache` without flattening the selected `page_size` dimension first. In this file `page_size` is 64, so `Kc`/`Kp` remain page tensors instead of `[L, D]` token matrices, and the subsequent decode matmuls no longer implement the declared operator. Please fix the upstream template for multi-token pages and regenerate this example.
flashinfer/fi_trace.py-273-280 (1)
273-280: ⚠️ Potential issue | 🟠 Major — The public helper never actually falls back to the legacy registry.
This module keeps `_REGISTRY`, `register_fi_trace()`, and `build_fi_trace_fn()` for backwards compatibility, but `fi_trace()` only checks `actual_func.fi_trace`. Any legacy caller that registered a spec by qualname will still hit the "No fi_trace spec is registered" path.

Possible fix:

```diff
 actual_func = getattr(func_or_method, "__func__", func_or_method)
 trace_fn = getattr(actual_func, "fi_trace", None)
 if trace_fn is None:
-    qualname = getattr(actual_func, "__qualname__", repr(actual_func))
-    raise ValueError(
-        f"No fi_trace spec is registered for '{qualname}'. "
-        "Only @flashinfer_api(trace=...)-decorated functions support fi_trace."
-    )
+    qualname = getattr(actual_func, "__qualname__", None)
+    spec = _REGISTRY.get(qualname) if qualname is not None else None
+    if spec is not None:
+        trace_fn = build_fi_trace_fn(spec)
+    else:
+        qualname = qualname or repr(actual_func)
+        raise ValueError(
+            f"No fi_trace spec is registered for '{qualname}'. "
+            "Only @flashinfer_api(trace=...)-decorated functions support fi_trace."
+        )
 return trace_fn(save_dir=save_dir, **kwargs)
```
tests/test_fi_trace.py-357-362 (1)
357-362: ⚠️ Potential issue | 🟠 Major — These use-case tests allocate model-sized tensors even though fi_trace only inspects metadata.
The `num_pages=8192` decode case materializes about 512 MiB of KV cache, and the MLA example adds another ~288 MiB, just to read `.shape` and `.dtype`. That is likely to slow or OOM CI without adding coverage. Please shrink these fixtures or move the model-scale examples out of the unit suite.

Also applies to: 418-424
flashinfer/trace/example/fi_trace_out/gdn_mtp_qk4_v8_d128.json-170-170 (1)
170-170: ⚠️ Potential issue | 🟠 Major — `final_state` never reflects the updates computed in the loop.
`state_HVK` is mutated for each token, but the function returns `initial_state.clone()` without writing any updated state back into it. The example therefore emits the original state pool while the schema says `final_state` is the updated recurrent state.
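A condensed NumPy sketch of the missing write-back (names loosely modeled on the trace's variables, not the actual reference code) shows the required data flow: the per-request state must land back in the returned pool.

```python
import numpy as np

def decode_with_state_writeback(initial_state, state_indices, deltas):
    """Sketch: per-request recurrent state is written back into the returned
    pool, not just mutated in a local copy (all names hypothetical)."""
    final_state = initial_state.copy()
    for b, state_idx in enumerate(state_indices):
        state = final_state[state_idx].copy()
        state += deltas[b]                 # stands in for the per-token loop
        final_state[state_idx] = state     # the write-back the review asks for
    return final_state

pool = np.zeros((3, 2, 2), dtype=np.float32)
out = decode_with_state_writeback(pool, state_indices=[2, 0],
                                  deltas=np.ones((2, 2, 2), dtype=np.float32))
assert np.allclose(out[2], 1.0) and np.allclose(out[0], 1.0)
assert np.allclose(out[1], 0.0)   # untouched slot keeps its initial value
```

Without the `final_state[state_idx] = state` line, the function would return the unmodified pool, which is the bug the review describes.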
flashinfer/trace/templates/gemm.py-111-215 (1)
111-215: ⚠️ Potential issue | 🟠 Major — The non-BF16 templates emit shapes with undefined or mismatched axes.
`mm_fp8_trace` uses `K_div_block_size`/`block_size`, `mm_mxfp8_trace` uses `K_div_32`, and `mm_fp4_trace` uses `K_div_block_size`/`N_div_block_size`, but none of those derived dimensions are declared in `axes` or tied back to `K`/`N` with constraints. `mm_fp4_trace` also labels packed uint8 operands as logical `[M, K]` and `[K, N]`, so the discovered axis values will be off on the packed dimension. The resulting JSON is not self-contained for these ops.
flashinfer/trace/example/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.json-112-112 (1)
112-112: ⚠️ Potential issue | 🟠 Major — The paged decode reference is indexing page IDs as if they were token IDs.
`kv_indices` is documented here as a page-ID array, but `k_cache.reshape(-1, ...)`/`v_cache.reshape(-1, ...)` followed by `...[token_ids]` only selects one row per page and drops the remaining `page_size - 1` tokens. If this reference is used for verification, any multi-token page will compare against the wrong attention result. Based on learnings, when the native paged KV layout is used, page indices are not supposed to be flattened into token indices.
flashinfer/trace/templates/gemm.py-22-85 (1)
22-85: ⚠️ Potential issue | 🟠 Major — Fix B tensor handling in quantized GEMM references and resolve undefined symbolic dimensions.
The quantized GEMM references have multiple critical issues:
- Matrix multiply semantics: all references multiply with `B.T` despite describing B with physical shape `[K, N]`. This is mathematically incorrect: `[M, K] @ [K, N].T = [M, K] @ [N, K]` has mismatched inner dimensions. The references should either remove the transpose or update the schemas to describe B as `[N, K]`.
- FP8 block layout: `_mm_fp8_reference()` reshapes `[K//block_size, N, block_size]` directly to `[K, N]` without permuting first. The TRT-LLM block layout requires permutation before reshape to reconstruct the original matrix correctly (i.e., `.reshape(K_div_bs, block_size, N).permute(1, 0, 2).reshape(K, N)`).
- FP4 decoding: `_unpack_fp4()` extracts raw nibble values (0–15) via bitwise masking and casts to float32 without decoding the e2m1fn format. The reference cannot serve as a correctness oracle without a proper FP4 value lookup or conversion.
- Undefined symbolic axes: the FP8, MXFP8, and FP4 templates reference symbolic dimensions (`K_div_block_size`, `K_div_32`, `N_div_block_size`, `block_size`) not declared in their `axes` dictionaries, preventing proper schema validation.
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/gemm.py` around lines 22 - 85, The GEMM refs incorrectly transpose B and misuse block layouts: update _mm_reference, _mm_fp8_reference, _mm_mxfp8_reference, and _mm_fp4_reference so matmul uses A @ B (not A @ B.T) if B is intended as [K, N], or alternatively document/reshape B to [N, K] consistently; in _mm_fp8_reference before reshaping apply the TRT-LLM permutation (reshape to [K_div_bs, block_size, N] then permute(1,0,2) then reshape to [K, N]) instead of direct reshape; replace _unpack_fp4 with proper e2m1fn decoding (use a lookup/decode table to map 4-bit nibble values to float32) rather than raw nibble casts so FP4 semantics are correct; and add the missing symbolic axis declarations for K_div_bs/K_div_32, N_div_block_size and block_size in the template axes metadata so schema validation can resolve those symbols (referencing the functions _mm_fp8_reference, _mm_mxfp8_reference, _mm_fp4_reference and helper _unpack_fp4 to locate the changes).

flashinfer/trace/templates/attention.py-113-116 (1)
113-116: ⚠️ Potential issue | 🟠 Major

Add the grouped-query head constraints to the GQA templates.
The GQA references rely on `num_qo_heads // num_kv_heads` being a valid grouping factor, but the schema currently accepts shapes where `num_qo_heads < num_kv_heads` or the ratio is non-integral. In those cases the reference either divides by zero or walks `kv_h` past the last KV head. Please add `num_qo_heads >= num_kv_heads` and `num_qo_heads % num_kv_heads == 0` here, and mirror the same invariant in `gqa_paged_prefill_trace` and `gqa_ragged_prefill_trace`.

Possible fix
```diff
 constraints=[
     "len_indptr == batch_size + 1",
     "num_kv_indices == kv_indptr[-1].item()",
+    "num_qo_heads >= num_kv_heads",
+    "num_qo_heads % num_kv_heads == 0",
 ],
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/attention.py` around lines 113 - 116, The schema for the GQA templates currently allows invalid head groupings; update the constraints list in the attention template (the constraints array where "len_indptr == batch_size + 1" and "num_kv_indices == kv_indptr[-1].item()" are defined) to also require "num_qo_heads >= num_kv_heads" and "num_qo_heads % num_kv_heads == 0", and apply the same two invariants to the corresponding constraint lists in gqa_paged_prefill_trace and gqa_ragged_prefill_trace so the grouped-query computation (which uses num_qo_heads // num_kv_heads and kv head indexing) never divides by zero or indexes past the last KV head.

flashinfer/trace/templates/gdn.py-164-168 (1)
164-168: ⚠️ Potential issue | 🟠 Major

Enforce `seq_len == 1` in the decode template.

`_gdn_decode_reference` depends on `squeeze(1)` removing the time axis. If `seq_len` is anything else, it starts repeating along the sequence dimension instead of the head dimension and the reference becomes invalid. The description already says decode is single-token, so make that a hard constraint.

Possible fix
```diff
 constraints=[
+    "seq_len == 1",
     "num_v_heads >= num_q_heads",
     "num_v_heads % num_q_heads == 0",
     "num_k_heads == num_q_heads",
 ],
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/gdn.py` around lines 164 - 168, Add a hard constraint enforcing single-token decoding by adding "seq_len == 1" to the decode template's constraints list so the template and its consumer _gdn_decode_reference (which relies on squeeze(1) removing the time axis) never run with seq_len > 1; update the constraints array (the one containing "num_v_heads >= num_q_heads", "num_v_heads % num_q_heads == 0", "num_k_heads == num_q_heads") to include "seq_len == 1".

flashinfer/trace/templates/gdn.py-327-330 (1)
327-330: ⚠️ Potential issue | 🟠 Major

Prefill is missing the GVA head-shape invariants used by the reference.
The prefill reference expands Q/K with `num_v_heads // num_q_heads` and `num_v_heads // num_k_heads`, so it needs the same head relationship guarantees as decode/MTP. Right now the schema accepts shapes that can truncate the repeat factor or produce an output whose head axis no longer matches the declared `num_v_heads`.

Possible fix
```diff
 constraints=[
     "len_cu_seqlens == num_seqs + 1",
     "total_seq_len == cu_seqlens[-1].item()",
+    "num_v_heads >= num_q_heads",
+    "num_v_heads % num_q_heads == 0",
+    "num_k_heads == num_q_heads",
 ],
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/gdn.py` around lines 327 - 330, The schema is missing invariants that guarantee the GVA head expansion used in prefill; add constraints to the same constraints list to require divisibility so expansions don't truncate heads: include "num_v_heads % num_q_heads == 0" and "num_v_heads % num_k_heads == 0" (referring to the symbols num_v_heads, num_q_heads, num_k_heads) so the prefill expansion of Q/K by num_v_heads // num_q_heads and num_v_heads // num_k_heads preserves the declared num_v_heads head axis.

flashinfer/trace/templates/attention.py-42-43 (1)
42-43: ⚠️ Potential issue | 🟠 Major

Don't treat page ids as flattened token ids.
After `reshape(-1, ...)`, indexing with raw `kv_indices` only fetches one flattened slot per page and ignores the other `page_size - 1` entries. That makes the paged GQA reference wrong for any `page_size > 1`, and the same pattern repeats in `_gqa_paged_prefill_reference`, `_mla_paged_decode_reference`, and `_mla_paged_prefill_reference`. Either materialize full pages and trim the last page with explicit length metadata, or constrain these paged templates to `page_size == 1`.

Possible direction
```diff
-k_flat = k_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)
-v_flat = v_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)
+# Materialize selected pages first; then flatten tokens within those pages.
+# The last page still needs an explicit length input to trim padding correctly.
 ...
-token_ids = kv_indices[page_start:page_end].to(torch.long)
-k_b = k_flat[token_ids]  # [T, num_kv_heads, head_dim]
-v_b = v_flat[token_ids]
+page_ids = kv_indices[page_start:page_end].to(torch.long)
+k_b = k_cache[page_ids].reshape(-1, num_kv_heads, head_dim).to(torch.float32)
+v_b = v_cache[page_ids].reshape(-1, num_kv_heads, head_dim).to(torch.float32)
```

Also applies to: 51-53
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/attention.py` around lines 42 - 43, The code flattens page slots with k_cache.reshape(-1, num_kv_heads, head_dim) (k_flat/v_flat) and then indexes with kv_indices, which treats page ids as single flattened token ids and therefore drops the other page_size-1 entries; update the paged templates (_gqa_paged_prefill_reference, _mla_paged_decode_reference, _mla_paged_prefill_reference) to either (A) materialize full page slices before flattening (i.e., expand/reshape to include page_size dimension, gather full pages using kv_indices, then trim the final partial page using explicit length metadata) or (B) enforce/validate page_size == 1 at the start of these functions and raise an error if otherwise; ensure all uses of k_flat, v_flat and kv_indices are adjusted accordingly so each page returns all its key/value slots rather than a single flattened slot.

flashinfer/trace/templates/gdn.py-377-410 (1)
377-410: ⚠️ Potential issue | 🟠 Major

Write the updated slot back before returning `final_state`.

`state_HVK` is updated for every token, but nothing persists it into the returned pool. As written, `final_state = initial_state.clone()` returns the original state unchanged, which contradicts the template contract and breaks stateful verification.

Possible fix
```diff
     output = torch.zeros(
         (B, T, num_v_heads, head_size), dtype=torch.bfloat16, device=device
     )
+    final_state = initial_state.clone().float()
     cache_intermediate = intermediate_states_buffer is not None
     for b_idx in range(B):
         state_idx = int(initial_state_indices[b_idx].item())
-        state_HVK = initial_state[state_idx].clone().float().transpose(-1, -2)  # [H,V,K] -> [H,K,V]
+        state_HVK = final_state[state_idx].transpose(-1, -2).clone()  # [H,V,K] -> [H,K,V]
         for t in range(T):
             ...
             if cache_intermediate:
                 intermediate_states_buffer[state_idx, t] = state_HVK.transpose(-1, -2)  # [H,K,V] -> [H,V,K]
+
+        final_state[state_idx] = state_HVK.transpose(-1, -2)
-    final_state = initial_state.clone()
     return output, final_state
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/gdn.py` around lines 377 - 410, Summary: final_state is returned unchanged because updated state_HVK is never written back to the pool. Fix: clone initial_state into final_state before the outer loop (or otherwise create a mutable final_state), and after processing each batch element (after the inner t loop where state_HVK holds the final per-head slot), write the updated slot back via final_state[state_idx] = state_HVK.transpose(-1, -2) so the returned final_state reflects the updates; reference symbols: initial_state, final_state, state_HVK, state_idx, initial_state_indices.
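The clone-then-write-back pattern this fix calls for can be shown with a pure-Python sketch, using nested lists in place of tensors. The pool layout and the names (`run_with_writeback`, `state_pool`, `state_indices`, `deltas`) are illustrative, not the repo's API:

```python
# Pure-Python sketch of the write-back pattern: clone the pool up front,
# update a local copy of the selected slot per batch element, then persist
# the updated slot back into the pool before returning it.
def run_with_writeback(state_pool, state_indices, deltas):
    """Apply per-batch token deltas to pooled states and persist the result."""
    final_state = [list(s) for s in state_pool]  # clone the pool up front
    for b_idx, state_idx in enumerate(state_indices):
        slot = list(final_state[state_idx])      # work on a local copy
        for d in deltas[b_idx]:
            slot = [x + d for x in slot]         # per-token state update
        final_state[state_idx] = slot            # write the slot back
    return final_state

pool = [[0.0, 0.0], [1.0, 1.0]]
updated = run_with_writeback(pool, [1], [[0.5, 0.5]])
```

Without the final write-back line, `updated` would equal the original pool, which is exactly the bug the review describes.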
🟡 Minor comments (4)
flashinfer/trace/template.py-371-376 (1)
371-376: ⚠️ Potential issue | 🟡 Minor

Silent exception swallowing may hide bugs during axis extraction.

Catching bare `Exception` and passing silently can mask unexpected errors (e.g., `TypeError`, `AttributeError`) that indicate template misconfiguration or API misuse. Consider logging at debug level or being more specific about expected exceptions.

🔧 Proposed fix to add debug logging
```diff
+import logging
+
+_logger = logging.getLogger(__name__)
+
 # In fi_trace function:
 for axis_name, extractor in axis_extractors.items():
     try:
         val = extractor(kwargs)
         if val is not None:
             axis_values[axis_name] = val
-    except Exception:
-        pass
+    except Exception as e:
+        _logger.debug("Failed to extract axis %r: %s", axis_name, e)
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/template.py` around lines 371 - 376, The try/except around calling extractor(kwargs) silently swallows all exceptions (using bare except), which can hide bugs; update the except to capture the exception as e and emit a debug-level log including axis_name, extractor, and the exception/traceback (or narrow the except to expected errors like KeyError/IndexError/ValueError if applicable) before continuing, ensuring axis_values and axis_name remain unchanged on failure; reuse the module's logger instance (or create one if none exists) so the failure context is recorded for debugging.

flashinfer/trace/templates/norm.py-56-88 (1)
56-88: ⚠️ Potential issue | 🟡 Minor

Reference implementation return value mismatch with template outputs.

The `_fused_add_rmsnorm_reference` function returns only `y` (a single tensor), but `fused_add_rmsnorm_trace` defines two outputs: `output` and `residual`. The reference should return both to match the template schema.

🐛 Proposed fix to return both outputs
```diff
 @torch.no_grad()
 def _fused_add_rmsnorm_reference(hidden_states, residual, weight):
     """Fused Add + RMSNorm. Epsilon is fixed at 1e-6."""
     EPS = 1e-6
-    x = hidden_states.to(torch.float32) + residual.to(torch.float32)
+    residual_updated = hidden_states + residual
+    x = residual_updated.to(torch.float32)
     inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + EPS)
     y = (x * inv_rms) * weight.to(torch.float32)
-    return y.to(hidden_states.dtype)
+    return y.to(hidden_states.dtype), residual_updated
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/norm.py` around lines 56 - 88, The reference function _fused_add_rmsnorm_reference currently returns only y but the TraceTemplate fused_add_rmsnorm_trace declares two outputs ("output" and updated "residual"); update _fused_add_rmsnorm_reference to return a tuple (output, residual_out) where residual_out reflects the in-place semantics described (residual += hidden_states), e.g., compute residual_out = residual.to(torch.float32) + hidden_states.to(torch.float32) (or perform an in-place add if appropriate), cast both y and residual_out back to the original hidden_states dtype, and return them in the same order as the template outputs.

flashinfer/trace/example/__main__.py-1-1 (1)
1-1: ⚠️ Potential issue | 🟡 Minor

Replace wildcard import with explicit module reference.

Line 1 uses `from .example import *`, which triggers Ruff F403 and obscures what is actually imported. Since `example.py` defines no `__all__` and is structured as a side-effect module (not an export container), use `from . import example` instead to make the intent explicit.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/example/__main__.py` at line 1, Replace the wildcard import in __main__.py: remove "from .example import *" and import the module explicitly (use "from . import example") so the code references names via the example module; update any direct references that relied on the star-import to be prefixed with "example." to keep intent explicit and satisfy Ruff F403.

flashinfer/trace/example/example.py-129-130 (1)
129-130: ⚠️ Potential issue | 🟡 Minor

Avoid silent `except Exception: pass` in the example runner.

These blocks hide unexpected failures and can make the generated trace set look complete when it is not.
♻️ Suggested fix
```diff
-except Exception:
-    pass  # Requires Blackwell (SM100+)
+except Exception as e:
+    print(f"[skip] mm_mxfp8 example not run: {e}")  # Requires Blackwell (SM100+)

-except Exception:
-    pass  # Requires Blackwell (SM100+)
+except Exception as e:
+    print(f"[skip] mm_fp4 example not run: {e}")  # Requires Blackwell (SM100+)

-except Exception:
-    pass  # May require specific GPU/TRT-LLM support
+except Exception as e:
+    print(f"[skip] trtllm_fp8_block_scale_moe example not run: {e}")
```

Also applies to: 140-141, 276-277
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/example/example.py` around lines 129 - 130, Replace each silent "except Exception: pass" in the example runner (the three occurrences that suppress errors for the Blackwell/SM100+ path) with targeted handling: catch only the expected exception (e.g., ImportError or ModuleNotFoundError when Blackwell is absent) or, if you must continue on error, log the full exception with logging.exception or traceback.print_exc including contextual information about which trace/step failed; do not swallow unexpected exceptions—re-raise them after logging so real failures are visible.
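The "report the skip instead of swallowing it" pattern the reviewer suggests can be sketched in a few lines. `run_example` and `needs_blackwell` are hypothetical helpers for illustration, not part of flashinfer:

```python
# Sketch: surface why an example was skipped instead of silently passing,
# so a partial trace set is visibly partial.
def run_example(name, fn):
    try:
        fn()
        return True
    except Exception as e:
        print(f"[skip] {name} not run: {e}")
        return False

def needs_blackwell():
    # Stand-in for a kernel that raises on unsupported hardware.
    raise RuntimeError("Requires Blackwell (SM100+)")

ok = run_example("noop", lambda: None)
skipped = run_example("mm_mxfp8", needs_blackwell)
```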
🧹 Nitpick comments (2)
flashinfer/trace/template.py (1)
474-474: Consider using spread operator for list construction.

Per static analysis, using spread syntax is more idiomatic.
♻️ Suggested change
```diff
-    all_tags = [f"fi_api:{fi_api}"] + template.tags
+    all_tags = [f"fi_api:{fi_api}", *template.tags]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/template.py` at line 474, Replace the list concatenation used to build all_tags with Python list unpacking for readability: instead of creating all_tags via [f"fi_api:{fi_api}"] + template.tags, construct it using the spread/unpacking form to include f"fi_api:{fi_api}" and all elements from template.tags (referencing variables all_tags, fi_api, and template.tags in template.py).

flashinfer/trace/__init__.py (1)
23-25: Consider sorting `__all__` and reconsidering private export.

Per static analysis, `__all__` should be sorted. Additionally, `_TRACE_DUMP_DIR` has a private naming convention (underscore prefix) but is exported publicly; consider renaming it to `TRACE_DUMP_DIR` if it's meant for external use, or documenting why it's exposed.

♻️ Suggested sorted `__all__`
```diff
-__all__ = ["TraceTemplate", "Var", "Const", "Tensor", "Scalar", "_TRACE_DUMP_DIR"]
+__all__ = ["Const", "Scalar", "Tensor", "TraceTemplate", "Var", "_TRACE_DUMP_DIR"]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/__init__.py` around lines 23 - 25, The __all__ list is unsorted and exposes a name with a leading underscore (_TRACE_DUMP_DIR) which conflicts with its private naming; update the __all__ declaration so entries are alphabetically sorted (e.g., Const, Scalar, Tensor, TraceTemplate, Var) and decide whether _TRACE_DUMP_DIR is meant to be public—if so rename it to TRACE_DUMP_DIR in template.py and here and export that, otherwise remove it from __all__ (or add a comment/docstring explaining why the underscored name is intentionally exported) so exports and naming are consistent; adjust imports/usage accordingly (TraceTemplate, Var, Const, Tensor, Scalar, and the chosen dump-dir symbol).
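Both nitpicks above (spread-style list construction and a sorted export list) fit in a tiny sketch; the values below are illustrative, not the module's real tags or exports:

```python
# Spread-style construction avoids an intermediate list from concatenation,
# and sorted() makes the export list order deterministic.
fi_api = "mm_fp8"
template_tags = ["gemm", "fp8"]
all_tags = [f"fi_api:{fi_api}", *template_tags]

exports = sorted(["TraceTemplate", "Var", "Const", "Tensor", "Scalar"])
```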
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 834f4a1a-3013-4d83-80ed-7022baffd452
📒 Files selected for processing (49)
- flashinfer/__init__.py
- flashinfer/api_logging.py
- flashinfer/attention.py
- flashinfer/decode.py
- flashinfer/fi_trace.py
- flashinfer/fused_moe/core.py
- flashinfer/gdn_decode.py
- flashinfer/gdn_prefill.py
- flashinfer/gemm/gemm_base.py
- flashinfer/mla/_core.py
- flashinfer/mla/cute_dsl/mla_decode.py
- flashinfer/norm/__init__.py
- flashinfer/prefill.py
- flashinfer/sampling.py
- flashinfer/trace/__init__.py
- flashinfer/trace/example/__main__.py
- flashinfer/trace/example/example.py
- flashinfer/trace/example/fi_trace_out/fused_add_rmsnorm_h5120.json
- flashinfer/trace/example/fi_trace_out/gdn_decode_qk4_v8_d128.json
- flashinfer/trace/example/fi_trace_out/gdn_mtp_qk4_v8_d128.json
- flashinfer/trace/example/fi_trace_out/gemm_bf16_N256_K7168.json
- flashinfer/trace/example/fi_trace_out/gemm_bf16_N4096_K4096.json
- flashinfer/trace/example/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.json
- flashinfer/trace/example/fi_trace_out/gemm_fp8_N1536_K7168.json
- flashinfer/trace/example/fi_trace_out/gemm_mxfp8_N4096_K4096.json
- flashinfer/trace/example/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.json
- flashinfer/trace/example/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.json
- flashinfer/trace/example/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.json
- flashinfer/trace/example/fi_trace_out/gqa_ragged_h32_kv8_d128.json
- flashinfer/trace/example/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.json
- flashinfer/trace/example/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json
- flashinfer/trace/example/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.json
- flashinfer/trace/example/fi_trace_out/rmsnorm_h4096.json
- flashinfer/trace/example/fi_trace_out/rmsnorm_h7168.json
- flashinfer/trace/example/fi_trace_out/top_k_sampling_v128256.json
- flashinfer/trace/example/fi_trace_out/top_k_top_p_sampling_v128256.json
- flashinfer/trace/example/fi_trace_out/top_k_top_p_sampling_v151936.json
- flashinfer/trace/example/fi_trace_out/top_p_sampling_v128256.json
- flashinfer/trace/example/fi_trace_out/top_p_sampling_v151936.json
- flashinfer/trace/template.py
- flashinfer/trace/templates/__init__.py
- flashinfer/trace/templates/attention.py
- flashinfer/trace/templates/gdn.py
- flashinfer/trace/templates/gemm.py
- flashinfer/trace/templates/moe.py
- flashinfer/trace/templates/norm.py
- flashinfer/trace/templates/sampling.py
- flashinfer/trtllm_low_latency_gemm.py
- tests/test_fi_trace.py
```json
        "dtype": "bfloat16"
      }
    },
    "reference": "def _mm_fp8_reference(A, B):\n    \"\"\"Dequantize FP8 block-scale inputs and compute C = A @ B.T.\n\n    B is in TRT-LLM block layout [K//block_size, N, block_size] and is\n    reshaped to [K, N] before the matmul.\n    \"\"\"\n    K_div_bs, N, block_size = B.shape\n    B_fp32 = B.reshape(K_div_bs * block_size, N).to(torch.float32)\n    A_fp32 = A.to(torch.float32)\n    return torch.matmul(A_fp32, B_fp32.T).to(torch.bfloat16)\n"
```
Fix the embedded reference matmul transpose.
At Line 50, B_fp32 is reshaped to [K, N], so A_fp32 @ B_fp32.T is invalid when K != N (here 7168 != 1536). The reference should multiply with B_fp32 (or reshape differently if transposed semantics are intended).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/trace/example/fi_trace_out/gemm_fp8_N1536_K7168.json` at line 50,
The helper _mm_fp8_reference currently reshapes B into B_fp32 =
B.reshape(K_div_bs * block_size, N) (i.e., [K, N]) but then computes
torch.matmul(A_fp32, B_fp32.T), which is wrong when K != N; change the matmul to
use B_fp32 (torch.matmul(A_fp32, B_fp32)) so the multiplication matches the
reshaped [K, N] layout, or alternatively reshape B to [N, K] if you truly need
B_fp32.T semantics—fix the call in _mm_fp8_reference referencing B_fp32 and
A_fp32 accordingly.
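The shape mismatch flagged above can be illustrated with a torch-free sketch: with B reshaped to `[K, N]`, `A @ B` is valid while `A @ B.T` is not when `K != N`. Plain tuples stand in for tensors, and the helper name is hypothetical:

```python
# Shape-only matmul check: [M, K] @ [K, N] requires matching inner dims.
def matmul_shape(a_shape, b_shape):
    (m, k1), (k2, n) = a_shape, b_shape
    if k1 != k2:
        raise ValueError(f"inner dims differ: {k1} != {k2}")
    return (m, n)

A, B = (4, 7168), (7168, 1536)  # [M, K] and the reshaped [K, N]
out = matmul_shape(A, B)        # valid product shape
try:
    matmul_shape(A, (B[1], B[0]))  # transposing B gives [N, K]: 7168 != 1536
    transposed_ok = True
except ValueError:
    transposed_ok = False
```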
```python
H = 7168
I = 2048
BLOCK = 128
```
Hardcoded H/I makes reference execution shape-fragile.
_fp8_moe_run_experts is wired to H=7168 and I=2048, but template axes are shape-driven. This will fail or produce invalid behavior for other valid MoE shapes.
💡 Proposed fix
-H = 7168
-I = 2048
BLOCK = 128
@@
def _fp8_moe_run_experts(
@@
- T = hidden_states.shape[0]
+ T, H = hidden_states.shape
+ I = gemm2_weights.shape[2]
+ gemm1_out = gemm1_weights.shape[1]
+ if gemm1_out != 2 * I:
+ raise ValueError(
+ f"Invalid gemm1_out_size={gemm1_out}, expected 2 * intermediate_size={2 * I}"
+ )
@@
- A_scale_expanded = (
- A_scale_TH.unsqueeze(-1).repeat(1, 1, BLOCK).reshape(T, H).contiguous()
- )
+ A_scale_expanded = A_scale_TH.unsqueeze(-1).repeat(1, 1, BLOCK).reshape(T, H).contiguous()
@@
- X1, X2 = G1[:, :I], G1[:, I:]
+ X1, X2 = G1[:, :I], G1[:, I:]

Also applies to: 48-58, 72-73, 86-87
🧰 Tools
🪛 Ruff (0.15.7)
[error] 26-26: Ambiguous variable name: I
(E741)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/trace/templates/moe.py` around lines 25 - 27, Hardcoded constants
H=7168 and I=2048 (and uses of BLOCK) make _fp8_moe_run_experts and related
templates shape-fragile; change these to compute H and I from the template/axis
sizes at runtime and use a derived BLOCK (e.g., based on H/I or
template.block_size) instead of literal numbers. Locate the constants H, I,
BLOCK and replace them with expressions that read the relevant template axes or
tensor shapes (reference the template used by _fp8_moe_run_experts and other
occurrences) so all occurrences use dynamically computed sizes rather than
hardcoded values.
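Deriving the sizes from runtime shapes, as the diff above suggests, can be sketched shape-only in plain Python. The weight layouts assumed here (hidden_states `[T, H]`, gemm1_weights `[E, 2*I, H]`, gemm2_weights `[E, H, I]`) and the helper name are assumptions for illustration:

```python
# Infer T, H, and the intermediate size from shapes instead of hardcoding,
# and validate the gemm1 output dimension against 2 * intermediate_size.
def infer_moe_dims(hidden_shape, gemm1_shape, gemm2_shape):
    T, H = hidden_shape
    inter_size = gemm2_shape[2]
    if gemm1_shape[1] != 2 * inter_size:
        raise ValueError(
            f"Invalid gemm1_out_size={gemm1_shape[1]}, expected {2 * inter_size}"
        )
    return T, H, inter_size

dims = infer_moe_dims((4, 7168), (32, 4096, 7168), (32, 7168, 2048))
```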
Actionable comments posted: 6
🧹 Nitpick comments (2)
tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json (1)
135-146: Consider specifying dtype for `intermediate_states_buffer`.

The `dtype` is set to `"unknown"` while all other tensors have explicit dtypes. Since this buffer stores intermediate states similar to `initial_state` and `final_state` (both `float32`), consider using `"float32"` for consistency, or add documentation explaining why the dtype is indeterminate.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json` around lines 135 - 146, The schema entry "intermediate_states_buffer" currently has dtype "unknown"; change it to a concrete dtype (e.g., "float32") to match the similar tensors "initial_state" and "final_state", or if dtype truly varies, add a clear description explaining why it's indeterminate and what types are allowed; update the "intermediate_states_buffer" dtype field and its description accordingly to ensure consistency and clarity.

tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json (1)
85-97: Use concrete integer dtypes for index tensors.
`kv_indptr` and `kv_indices` are currently `"dtype": "unknown"`. This weakens schema validation and downstream codegen/consumers. Prefer explicit integer types (typically `int32` or `int64`) for both.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json` around lines 85 - 97, The schema uses "dtype": "unknown" for the index tensors kv_indptr (shape "len_indptr") and kv_indices (shape "num_kv_indices"); change both to a concrete integer dtype (prefer int32, or int64 if you need 64-bit indices) so downstream validation and codegen can rely on a fixed integer type—update the "dtype" entries for kv_indptr and kv_indices accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/trace/example.py`:
- Around line 1-294: The file is a standalone script so pytest won't collect it;
convert it into a proper pytest test by moving the top-level side-effect code
into a single test function (e.g., def test_generate_fi_trace_jsons(tmp_path):)
while preserving the early environment setup (os.environ.setdefault(...) and
SAVE_DIR) before importing flashinfer, and use the tmp_path fixture to override
FLASHINFER_TRACE_DUMP_DIR/SAVE_DIR so outputs go to a test-isolated directory;
keep all calls to flashinfer functions and wrappers (e.g., flashinfer.rmsnorm,
flashinfer.fused_add_rmsnorm, flashinfer.top_k_sampling_from_probs,
flashinfer.mm_bf16, flashinfer.gdn_decode.gated_delta_rule_decode,
BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper, BatchMLAPagedAttentionWrapper,
flashinfer.fused_moe.trtllm_fp8_block_scale_moe, etc.) inside that test, and
remove or adapt prints/assert the expected JSON files exist via
SAVE_DIR.glob("*.json") to make the test assertions deterministic for CI.
In `@tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json`:
- Around line 120-124: The "scale" field is documented as having a default
(1/sqrt(head_size)) but isn't marked optional; update the JSON schema entry for
"scale" so consumers know it may be omitted—e.g., add an optional/nullable flag
or remove it from any "required" list and set "optional": true (or equivalent)
next to the "scale" property to reflect the default behavior.
- Line 148: The reference function _gdn_decode_reference uses math.sqrt and
F.softplus but the serialized source string has no imports, causing NameError
when exec/eval runs; fix by either injecting math and torch.nn.functional as F
into the exec/eval globals where _gdn_decode_reference is executed (ensure names
"math" and "F" are present) or prepend/import lines ("import math" and "import
torch.nn.functional as F") to the serialized reference string so
_gdn_decode_reference has the required symbols at runtime.
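The globals-injection fix described above can be demonstrated in a few lines; in the real case the globals dict would also carry `torch` and `F`, and `ref`/`src` here are illustrative names:

```python
# Inject the modules a serialized reference needs into the exec() globals
# so names like `math` resolve when the source string is executed.
import math

src = "def ref(x):\n    return math.sqrt(x)\n"
env = {"math": math}  # without this, calling ref would raise NameError
exec(src, env)
result = env["ref"](9.0)
```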
In `@tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json`:
- Around line 159-168: The doc string for "final_state" references an undefined
parameter disable_state_update; either add a boolean input named
disable_state_update to the inputs section (e.g., description: "If true,
recurrent state updates are disabled and final_state remains unchanged") or
remove the mention "Unchanged if disable_state_update=True" from the
"final_state" description; update the "final_state" description or inputs
accordingly so the documentation no longer refers to an undefined symbol.
- Line 170: The reference function _gdn_mtp_reference updates per-batch states
in state_HVK but then returns final_state = initial_state.clone(), discarding
updates; fix by creating final_state = initial_state.clone() before the batch
loop and after processing each batch element (using state_idx =
int(initial_state_indices[b_idx].item())) write the updated state back with
final_state[state_idx] = state_HVK.transpose(-1, -2) (matching the stored
[H,V,K] layout); ensure types remain consistent (match .float()/.to dtype as
needed) and then return output, final_state.
In `@tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json`:
- Line 123: In _mla_paged_decode_reference the use of ckv_cache.squeeze(1) and
kpe_cache.squeeze(1) is wrong for paged tensors (shape [num_pages, page_size,
head_dim_*]) and leaves a 3D tensor so Kc_all[tok_idx] yields [L, page_size,
head_dim]; replace squeeze(1) with a flattening reshape (e.g. reshape(num_pages
* page_size, head_dim_ckv) / reshape(..., head_dim_kpe) or view(-1, head_dim_*))
so Kc_all and Kp_all become 2D token-major tensors before indexing, and ensure
kv_indptr and kv_indices are cast to explicit integer dtype (torch.long/int64)
before use to remove schema ambiguity.
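The token-major flattening described above can be sketched shape-only, with nested lists standing in for `[num_pages, page_size, head_dim]` tensors (a torch-free illustration, not the repo's code):

```python
# Flatten a paged cache so a token id addresses one row:
# token index = page_id * page_size + slot_within_page.
def flatten_pages(cache):
    """[num_pages][page_size][head_dim] -> [num_pages * page_size][head_dim]."""
    return [tok for page in cache for tok in page]

num_pages, page_size, head_dim = 3, 4, 2
cache = [[[p * 10 + s] * head_dim for s in range(page_size)]
         for p in range(num_pages)]
flat = flatten_pages(cache)
```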
---
Nitpick comments:
In `@tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json`:
- Around line 135-146: The schema entry "intermediate_states_buffer" currently
has dtype "unknown"; change it to a concrete dtype (e.g., "float32") to match
the similar tensors "initial_state" and "final_state", or if dtype truly varies,
add a clear description explaining why it's indeterminate and what types are
allowed; update the "intermediate_states_buffer" dtype field and its description
accordingly to ensure consistency and clarity.
In `@tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json`:
- Around line 85-97: The schema uses "dtype": "unknown" for the index tensors
kv_indptr (shape "len_indptr") and kv_indices (shape "num_kv_indices"); change
both to a concrete integer dtype (prefer int32, or int64 if you need 64-bit
indices) so downstream validation and codegen can rely on a fixed integer
type—update the "dtype" entries for kv_indptr and kv_indices accordingly.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 9d5f1191-ca90-4d41-be7d-6533b17213b1
📒 Files selected for processing (29)
- flashinfer/decode.py
- flashinfer/fused_moe/core.py
- flashinfer/gdn_decode.py
- flashinfer/gemm/gemm_base.py
- flashinfer/norm/__init__.py
- flashinfer/prefill.py
- tests/trace/example.py
- tests/trace/fi_trace_out/fused_add_rmsnorm_h5120.json
- tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json
- tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json
- tests/trace/fi_trace_out/gemm_bf16_N256_K7168.json
- tests/trace/fi_trace_out/gemm_bf16_N4096_K4096.json
- tests/trace/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.json
- tests/trace/fi_trace_out/gemm_fp8_N1536_K7168.json
- tests/trace/fi_trace_out/gemm_mxfp8_N4096_K4096.json
- tests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.json
- tests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.json
- tests/trace/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.json
- tests/trace/fi_trace_out/gqa_ragged_h32_kv8_d128.json
- tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.json
- tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json
- tests/trace/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/rmsnorm_h4096.json
- tests/trace/fi_trace_out/rmsnorm_h7168.json
- tests/trace/fi_trace_out/top_k_sampling_v128256.json
- tests/trace/fi_trace_out/top_k_top_p_sampling_v128256.json
- tests/trace/fi_trace_out/top_k_top_p_sampling_v151936.json
- tests/trace/fi_trace_out/top_p_sampling_v128256.json
- tests/trace/fi_trace_out/top_p_sampling_v151936.json
✅ Files skipped from review due to trivial changes (20)
- flashinfer/norm/__init__.py
- tests/trace/fi_trace_out/gemm_bf16_N4096_K4096.json
- tests/trace/fi_trace_out/rmsnorm_h7168.json
- tests/trace/fi_trace_out/gemm_bf16_N256_K7168.json
- tests/trace/fi_trace_out/fused_add_rmsnorm_h5120.json
- tests/trace/fi_trace_out/top_k_top_p_sampling_v151936.json
- tests/trace/fi_trace_out/rmsnorm_h4096.json
- tests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.json
- tests/trace/fi_trace_out/top_k_sampling_v128256.json
- tests/trace/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.json
- tests/trace/fi_trace_out/gemm_fp8_N1536_K7168.json
- tests/trace/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.json
- tests/trace/fi_trace_out/top_p_sampling_v151936.json
- tests/trace/fi_trace_out/gqa_ragged_h32_kv8_d128.json
- tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.json
- tests/trace/fi_trace_out/top_p_sampling_v128256.json
- tests/trace/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.json
- flashinfer/prefill.py
- tests/trace/fi_trace_out/gemm_mxfp8_N4096_K4096.json
🚧 Files skipped from review as they are similar to previous changes (3)
- flashinfer/decode.py
- flashinfer/gdn_decode.py
- flashinfer/gemm/gemm_base.py
| """ | ||
| fi_trace example: generate flashinfer-bench definition JSON files via auto-dump. | ||
|
|
||
| Run: | ||
| python tests/trace/example.py | ||
|
|
||
| When FLASHINFER_TRACE_DUMP=1 (set below), every @flashinfer_api(trace=...) decorated | ||
| function automatically writes a trace JSON on its first call for each unique input | ||
| shape. Subsequent calls with the same shape are deduplicated (no re-write). | ||
|
|
||
| The output directory is controlled by FLASHINFER_TRACE_DUMP_DIR. | ||
|
|
||
| Requires a CUDA-capable GPU. | ||
|
|
||
| Results: | ||
| - We would get these example json files under fi_trace_out directory: | ||
| fused_add_rmsnorm_h5120.json | ||
| gdn_decode_qk4_v8_d128_k_last.json | ||
| gdn_mtp_qk4_v8_d128_k_last.json | ||
| gdn_prefill_qk4_v8_d128_k_last.json | ||
| gemm_bf16_n256_k7168.json | ||
| gemm_bf16_n4096_k4096.json | ||
| gemm_fp4_n2048_k7168.json | ||
| gemm_fp8_n1536_k7168.json | ||
| gemm_mxfp8_n4096_k4096.json | ||
| gqa_paged_decode_h32_kv8_d128_ps16.json | ||
| gqa_paged_decode_h32_kv8_d128_ps64.json | ||
| gqa_paged_prefill_h32_kv8_d128_ps16.json | ||
| gqa_ragged_prefill_h32_kv8_d128.json | ||
| mla_paged_decode_h16_ckv512_kpe64_ps1.json | ||
| mla_paged_decode_h16_ckv512_kpe64_ps64.json | ||
| moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048.json | ||
| rmsnorm_h4096.json | ||
| rmsnorm_h7168.json | ||
| top_k_sampling_from_probs_v128256.json | ||
| top_k_top_p_sampling_from_probs_v128256.json | ||
| top_k_top_p_sampling_from_probs_v151936.json | ||
| top_p_sampling_from_probs_v128256.json | ||
| top_p_sampling_from_probs_v151936.json | ||
|
|
||
| Note: top_p_sampling files appear for vocab_size=151936 because | ||
| top_k_top_p_sampling (top_k_first order) calls top_p_sampling internally. | ||
| """ | ||
|
|
||
| import json | ||
| import os | ||
| from pathlib import Path | ||
|
|
||
| # Must be set before any flashinfer import: template.py reads these at module load time. | ||
| os.environ.setdefault( | ||
| "FLASHINFER_TRACE_DUMP_DIR", | ||
| str(Path(__file__).parent / "fi_trace_out"), | ||
| ) | ||
| os.environ.setdefault("FLASHINFER_TRACE_DUMP", "1") | ||
|
|
||
| SAVE_DIR = Path(os.environ["FLASHINFER_TRACE_DUMP_DIR"]) | ||
|
|
||
| import torch | ||
|
|
||
| import flashinfer | ||
| import flashinfer.norm | ||
| import flashinfer.sampling | ||
| import flashinfer.gemm | ||
| import flashinfer.gdn_decode | ||
| import flashinfer.fused_moe | ||
| from flashinfer.decode import BatchDecodeWithPagedKVCacheWrapper | ||
| from flashinfer.prefill import ( | ||
| BatchPrefillWithPagedKVCacheWrapper, | ||
| BatchPrefillWithRaggedKVCacheWrapper, | ||
| ) | ||
| from flashinfer.mla import BatchMLAPagedAttentionWrapper | ||
|
|
||
| device = "cuda" | ||
| WORKSPACE = 128 * 1024 * 1024 # 128 MB | ||
|
|
||
| print(f"\nAuto-dumping fi_trace JSON files to {SAVE_DIR}/\n") | ||
|
|
||
| # ── rmsnorm ─────────────────────────────────────────────────────────────────── | ||
| # Llama-3.1-8B (hidden=4096) and DeepSeek-V3 (hidden=7168) | ||
| for hidden_size in (4096, 7168): | ||
| hidden = torch.randn(32, hidden_size, dtype=torch.bfloat16, device=device) | ||
| weight = torch.ones(hidden_size, dtype=torch.bfloat16, device=device) | ||
| flashinfer.rmsnorm(hidden, weight) | ||
|
|
||
| # ── fused_add_rmsnorm (Qwen3-14B, hidden=5120) ─────────────────────────────── | ||
| x = torch.randn(32, 5120, dtype=torch.bfloat16, device=device) | ||
| res = torch.randn(32, 5120, dtype=torch.bfloat16, device=device) | ||
| w = torch.ones(5120, dtype=torch.bfloat16, device=device) | ||
| flashinfer.fused_add_rmsnorm(x, res, w) | ||
|
|
||
| # ── sampling (Llama vocab=128256) ───────────────────────────────────────────── | ||
| probs = torch.rand(64, 128256, dtype=torch.float32, device=device) | ||
| top_k = torch.full((64,), 50, dtype=torch.int32, device=device) | ||
| top_p = torch.full((64,), 0.9, dtype=torch.float32, device=device) | ||
| flashinfer.top_k_sampling_from_probs(probs, top_k) | ||
| flashinfer.top_p_sampling_from_probs(probs, top_p) | ||
| flashinfer.top_k_top_p_sampling_from_probs(probs, top_k, top_p) | ||
|
|
||
| # ── sampling (Qwen3 vocab=151936) ───────────────────────────────────────────── | ||
| probs = torch.rand(64, 151936, dtype=torch.float32, device=device) | ||
| flashinfer.top_k_top_p_sampling_from_probs(probs, top_k, top_p) | ||
|
|
||
| # ── GEMM bf16 ───────────────────────────────────────────────────────────────── | ||
| # Llama-3.1-8B o_proj (4096×4096) and DeepSeek-V3 moe.gate (256×7168) | ||
| # Use cutlass backend to avoid cuDNN dependency. | ||
| # mm_bf16 expects b in column-major layout with shape [K, N]. | ||
| # randn(N, K).T gives shape [K, N] with strides (1, N); the kernel transposes | ||
| # b back to [N, K] (contiguous) before calling the C++ matmul. | ||
| for N, K in ((4096, 4096), (256, 7168)): | ||
| a = torch.randn(128, K, dtype=torch.bfloat16, device=device) | ||
| b = torch.randn(N, K, dtype=torch.bfloat16, device=device).T # [K, N] column-major; b.T is contiguous | ||
| flashinfer.mm_bf16(a, b, backend="cutlass") | ||
|
|
||
| # ── GEMM fp8 block-scale (DeepSeek-V3 q_proj: M×7168→1536, block=128) ──────── | ||
| M, K, N, BS = 128, 7168, 1536, 128 | ||
| a_fp8 = torch.zeros(M, K, dtype=torch.float8_e4m3fn, device=device) | ||
| b_fp8 = torch.zeros(K // BS, N, BS, dtype=torch.float8_e4m3fn, device=device) | ||
| alpha_fp8 = torch.tensor(1.0, dtype=torch.float32, device=device) | ||
| flashinfer.mm_fp8(a_fp8, b_fp8, alpha_fp8) | ||
|
|
||
| # ── GEMM mxfp8 (Blackwell SM100+: M×4096@4096×4096, block=32) ──────────────── | ||
| try: | ||
| M, K, N = 128, 4096, 4096 | ||
| a_mxfp8 = torch.zeros(M, K, dtype=torch.float8_e4m3fn, device=device) | ||
| b_mxfp8 = torch.zeros(K, N, dtype=torch.float8_e4m3fn, device=device) | ||
| a_ds = torch.ones(M, K // 32, dtype=torch.uint8, device=device) | ||
| b_ds = torch.ones(K // 32, N, dtype=torch.uint8, device=device) | ||
| flashinfer.gemm.mm_mxfp8(a_mxfp8, b_mxfp8, a_ds, b_ds) | ||
| except Exception: | ||
| pass # Requires Blackwell (SM100+) | ||
|
|
||
| # ── GEMM fp4 (Blackwell SM100+: M×7168@2048×7168, block=16) ───────────────── | ||
| try: | ||
| M, K, N, BS4 = 128, 7168, 2048, 16 | ||
| a_fp4 = torch.zeros(M, K, dtype=torch.uint8, device=device) | ||
| b_fp4 = torch.zeros(K, N, dtype=torch.uint8, device=device) | ||
| a_d4 = torch.ones(M, K // BS4, dtype=torch.float8_e4m3fn, device=device) | ||
| b_d4 = torch.ones(K, N // BS4, dtype=torch.float8_e4m3fn, device=device) | ||
| flashinfer.gemm.mm_fp4(a_fp4, b_fp4, a_d4, b_d4, block_size=BS4) | ||
| except Exception: | ||
| pass # Requires Blackwell (SM100+) | ||
|
|
||
| # ── GQA paged decode (Llama-3.1-8B, h=32/kv=8/d=128) ──────────────────────── | ||
| num_qo, num_kv, head_dim, batch_size = 32, 8, 128, 32 | ||
|
|
||
| for page_size, num_pages in ((16, 128), (64, 32)): | ||
| total = batch_size * num_pages | ||
| kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device=device) * num_pages | ||
| kv_indices = torch.arange(total, dtype=torch.int32, device=device) | ||
| kv_last = torch.full((batch_size,), page_size, dtype=torch.int32, device=device) | ||
|
|
||
| ws = torch.empty(WORKSPACE, dtype=torch.uint8, device=device) | ||
| dec = BatchDecodeWithPagedKVCacheWrapper(ws, "NHD") | ||
| dec.plan( | ||
| kv_indptr, kv_indices, kv_last, | ||
| num_qo, num_kv, head_dim, page_size, | ||
| q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, | ||
| ) | ||
| q_d = torch.randn(batch_size, num_qo, head_dim, dtype=torch.bfloat16, device=device) | ||
| kc = torch.randn(total, page_size, num_kv, head_dim, dtype=torch.bfloat16, device=device) | ||
| vc = torch.randn(total, page_size, num_kv, head_dim, dtype=torch.bfloat16, device=device) | ||
| dec.run(q_d, (kc, vc)) | ||
|
|
||
| # ── GQA paged prefill (Llama-3.1-8B, h=32/kv=8/d=128, page_size=16) ───────── | ||
| n_req, total_q, np_pf, page_size = 4, 512, 32, 16 | ||
| total_pf = n_req * np_pf | ||
| qo_indptr = torch.tensor([0, 128, 256, 384, 512], dtype=torch.int32, device=device) | ||
| kv_indptr_p = torch.arange(n_req + 1, dtype=torch.int32, device=device) * np_pf | ||
| kv_idx_p = torch.arange(total_pf, dtype=torch.int32, device=device) | ||
| kv_last_p = torch.full((n_req,), page_size, dtype=torch.int32, device=device) | ||
|
|
||
| ws_pf = torch.empty(WORKSPACE, dtype=torch.uint8, device=device) | ||
| pf = BatchPrefillWithPagedKVCacheWrapper(ws_pf, "NHD") | ||
| pf.plan( | ||
| qo_indptr, kv_indptr_p, kv_idx_p, kv_last_p, | ||
| num_qo, num_kv, head_dim, page_size, | ||
| causal=True, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, | ||
| ) | ||
| q_pf = torch.randn(total_q, num_qo, head_dim, dtype=torch.bfloat16, device=device) | ||
| kc_pf = torch.randn(total_pf, page_size, num_kv, head_dim, dtype=torch.bfloat16, device=device) | ||
| vc_pf = torch.randn(total_pf, page_size, num_kv, head_dim, dtype=torch.bfloat16, device=device) | ||
| pf.run(q_pf, (kc_pf, vc_pf)) | ||
|
|
||
| # ── GQA ragged prefill (Llama-3.1-8B) ──────────────────────────────────────── | ||
| qo_indptr_r = torch.tensor([0, 64, 128, 192, 256], dtype=torch.int32, device=device) | ||
| kv_indptr_r = torch.tensor([0, 128, 256, 384, 512], dtype=torch.int32, device=device) | ||
|
|
||
| ws_r = torch.empty(WORKSPACE, dtype=torch.uint8, device=device) | ||
| rag = BatchPrefillWithRaggedKVCacheWrapper(ws_r, "NHD") | ||
| rag.plan( | ||
| qo_indptr_r, kv_indptr_r, | ||
| num_qo, num_kv, head_dim, | ||
| causal=True, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, | ||
| ) | ||
| q_r = torch.randn(256, num_qo, head_dim, dtype=torch.bfloat16, device=device) | ||
| k_r = torch.randn(512, num_kv, head_dim, dtype=torch.bfloat16, device=device) | ||
| v_r = torch.randn(512, num_kv, head_dim, dtype=torch.bfloat16, device=device) | ||
| rag.run(q_r, k_r, v_r) | ||
|
|
||
| # ── MLA paged decode (DeepSeek-V3 TP=8, h=16/ckv=512/kpe=64) ───────────────── | ||
| mla_b, mla_h, ckv, kpe = 128, 16, 512, 64 | ||
|
|
||
| for mla_ps, mla_np in ((64, 32), (1, 2048)): | ||
| total_mla = mla_b * mla_np | ||
| mla_qo_indptr = torch.arange(mla_b + 1, dtype=torch.int32, device=device) | ||
| mla_kv_indptr = torch.arange(mla_b + 1, dtype=torch.int32, device=device) * mla_np | ||
| mla_kv_indices = torch.arange(total_mla, dtype=torch.int32, device=device) | ||
| mla_kv_len = torch.full((mla_b,), mla_np * mla_ps, dtype=torch.int32, device=device) | ||
|
|
||
| ws_mla = torch.empty(WORKSPACE, dtype=torch.uint8, device=device) | ||
| mla = BatchMLAPagedAttentionWrapper(ws_mla) | ||
| mla.plan( | ||
| mla_qo_indptr, mla_kv_indptr, mla_kv_indices, mla_kv_len, | ||
| mla_h, ckv, kpe, mla_ps, | ||
| causal=False, sm_scale=1.0 / (ckv ** 0.5), | ||
| q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, | ||
| ) | ||
| q_nope = torch.randn(mla_b, mla_h, ckv, dtype=torch.bfloat16, device=device) | ||
| q_pe = torch.randn(mla_b, mla_h, kpe, dtype=torch.bfloat16, device=device) | ||
| ckv_cache = torch.randn(total_mla, mla_ps, ckv, dtype=torch.bfloat16, device=device) | ||
| kpe_cache = torch.randn(total_mla, mla_ps, kpe, dtype=torch.bfloat16, device=device) | ||
| mla.run(q_nope, q_pe, ckv_cache, kpe_cache) | ||
|
|
||
| # ── GDN decode (Qwen3-Next TP=4, qk=4/v=8/d=128) ──────────────────────────── | ||
| B, H, HV, K = 4, 4, 8, 128 | ||
| q = torch.randn(B, 1, H, K, dtype=torch.bfloat16, device=device) | ||
| k = torch.randn(B, 1, H, K, dtype=torch.bfloat16, device=device) | ||
| v = torch.randn(B, 1, HV, K, dtype=torch.bfloat16, device=device) | ||
| state = torch.zeros(B, HV, K, K, dtype=torch.float32, device=device) | ||
| A_log = torch.zeros(HV, dtype=torch.float32, device=device) | ||
| a = torch.zeros(B, 1, HV, dtype=torch.bfloat16, device=device) | ||
| dt_bias = torch.zeros(HV, dtype=torch.float32, device=device) | ||
| b_ = torch.zeros(B, 1, HV, dtype=torch.bfloat16, device=device) | ||
| flashinfer.gdn_decode.gated_delta_rule_decode(q, k, v, state, A_log, a, dt_bias, b_) | ||
|
|
||
| # ── GDN MTP (Qwen3-Next TP=4, spec_len=4) ──────────────────────────────────── | ||
| T_mtp, pool_size = 4, 8 | ||
| q_m = torch.randn(B, T_mtp, H, K, dtype=torch.bfloat16, device=device) | ||
| k_m = torch.randn(B, T_mtp, H, K, dtype=torch.bfloat16, device=device) | ||
| v_m = torch.randn(B, T_mtp, HV, K, dtype=torch.bfloat16, device=device) | ||
| init_state = torch.zeros(pool_size, HV, K, K, dtype=torch.float32, device=device) | ||
| init_idx = torch.arange(B, dtype=torch.int32, device=device) | ||
| A_log_m = torch.zeros(HV, dtype=torch.float32, device=device) | ||
| a_m = torch.zeros(B, T_mtp, HV, dtype=torch.bfloat16, device=device) | ||
| dt_bias_m = torch.zeros(HV, dtype=torch.float32, device=device) | ||
| b_m = torch.zeros(B, T_mtp, HV, dtype=torch.bfloat16, device=device) | ||
| flashinfer.gdn_decode.gated_delta_rule_mtp( | ||
| q_m, k_m, v_m, init_state, init_idx, A_log_m, a_m, dt_bias_m, b_m | ||
| ) | ||
|
|
||
| # ── MoE FP8 (DeepSeek-V3 EP=8: 256 experts, 32 local, h=7168, i=2048, top_k=8) | ||
| try: | ||
| T_moe, H_moe, I_moe, E_tot, E_loc, BS = 128, 7168, 2048, 256, 32, 128 | ||
| routing_logits = torch.randn(T_moe, E_tot, dtype=torch.float32, device=device) | ||
| routing_bias = torch.zeros(E_tot, dtype=torch.bfloat16, device=device) | ||
| hs = torch.zeros(T_moe, H_moe, dtype=torch.float8_e4m3fn, device=device) | ||
| hs_scale = torch.ones(H_moe // BS, T_moe, dtype=torch.float32, device=device) | ||
| w1 = torch.zeros(E_loc, 2 * I_moe, H_moe, dtype=torch.float8_e4m3fn, device=device) | ||
| w1s = torch.ones(E_loc, (2 * I_moe) // BS, H_moe // BS, dtype=torch.float32, device=device) | ||
| w2 = torch.zeros(E_loc, H_moe, I_moe, dtype=torch.float8_e4m3fn, device=device) | ||
| w2s = torch.ones(E_loc, H_moe // BS, I_moe // BS, dtype=torch.float32, device=device) | ||
| flashinfer.fused_moe.trtllm_fp8_block_scale_moe( | ||
| routing_logits, routing_bias, | ||
| hs, hs_scale, | ||
| w1, w1s, | ||
| w2, w2s, | ||
| num_experts=E_tot, | ||
| top_k=8, | ||
| n_group=8, | ||
| topk_group=3, | ||
| intermediate_size=I_moe, | ||
| local_expert_offset=0, | ||
| local_num_experts=E_loc, | ||
| routed_scaling_factor=2.5, | ||
| ) | ||
| except Exception: | ||
| pass # May require specific GPU/TRT-LLM support | ||
|
|
||
| # ── Summary ─────────────────────────────────────────────────────────────────── | ||
| files = sorted(SAVE_DIR.glob("*.json")) | ||
| print(f"\nWrote {len(files)} definition files:\n") | ||
| for f in files: | ||
| defn = json.loads(f.read_text()) | ||
| print(f" {f.name}") | ||
| print(f" op_type : {defn['op_type']}") | ||
| print(f" fi_api : {next(t for t in defn['tags'] if t.startswith('fi_api:'))}") | ||
| const_axes = { | ||
| k: v["value"] | ||
| for k, v in defn["axes"].items() | ||
| if v["type"] == "const" and "value" in v | ||
| } | ||
| if const_axes: | ||
| print(f" axes : {const_axes}") | ||
| print() |
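For readers without a GPU, the auto-dump contract the script relies on (write one definition JSON on the first call per unique input shape, deduplicate afterwards) can be sketched with a pure-Python stand-in. `trace_dump` below is hypothetical and only mimics the behavior described in the docstring; it is not the flashinfer implementation.

```python
import json
import os
import tempfile
from pathlib import Path

# Keys already dumped in this process: (fi_api name, input shape).
_seen: set = set()

def trace_dump(name: str, shape: tuple, save_dir: Path) -> bool:
    """Write <name>_h<last_dim>.json once per unique shape.

    Returns True if a new file was written, False if tracing is disabled
    or this (name, shape) pair was already dumped (dedup).
    """
    if os.environ.get("FLASHINFER_TRACE_DUMP") != "1":
        return False  # tracing disabled
    key = (name, shape)
    if key in _seen:
        return False  # same shape already dumped, skip the re-write
    _seen.add(key)
    save_dir.mkdir(parents=True, exist_ok=True)
    path = save_dir / f"{name}_h{shape[-1]}.json"
    path.write_text(json.dumps({"name": name, "shape": list(shape)}))
    return True

os.environ["FLASHINFER_TRACE_DUMP"] = "1"
out = Path(tempfile.mkdtemp())
first = trace_dump("rmsnorm", (32, 4096), out)   # new shape: writes a file
second = trace_dump("rmsnorm", (32, 4096), out)  # same shape: deduplicated
third = trace_dump("rmsnorm", (32, 7168), out)   # new shape: second file
```

This matches the docstring's claim that repeated calls with the same shape produce a single JSON per shape.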
🧩 Analysis chain
🏁 Script executed:
# Check if the file exists and get its basic info
find tests -name "example.py" -type f | head -5Repository: flashinfer-ai/flashinfer
Length of output: 90
🏁 Script executed:
# Check the structure of tests directory
find tests -type f -name "*.py" | head -20
Repository: flashinfer-ai/flashinfer
Length of output: 859
🏁 Script executed:
# Look for test_ prefixed functions in the current example.py
if [ -f "tests/trace/example.py" ]; then
grep -n "^def test_" tests/trace/example.py || echo "No test_ functions found"
fi
Repository: flashinfer-ai/flashinfer
Length of output: 92
🏁 Script executed:
# Check if there's a main block and understand the current structure
if [ -f "tests/trace/example.py" ]; then
grep -n "if __name__" tests/trace/example.py
fi
Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Look at other test files to understand the pytest pattern used
find tests -name "test_*.py" -type f | head -5Repository: flashinfer-ai/flashinfer
Length of output: 270
🏁 Script executed:
# Check the actual content and structure of example.py
tail -20 tests/trace/example.py
Repository: flashinfer-ai/flashinfer
Length of output: 741
🏁 Script executed:
# Look at a properly structured test file to understand the pattern
head -50 tests/attention/test_alibi.py
Repository: flashinfer-ai/flashinfer
Length of output: 1692
🏁 Script executed:
# Check if tests/trace directory has any other test files
ls -la tests/trace/
Repository: flashinfer-ai/flashinfer
Length of output: 309
🏁 Script executed:
# Look for pytest configuration to understand test discovery
find . -name "pytest.ini" -o -name "pyproject.toml" -o -name "setup.cfg" | head -3 | xargs grep -l "testpaths\|python_files" 2>/dev/null || echo "No specific pytest config found for test patterns"Repository: flashinfer-ai/flashinfer
Length of output: 117
Refactor as a proper pytest test with test_ prefix to ensure CI coverage.
This file is a runnable script with side effects but lacks a test_ function, so pytest won't collect it during CI. Per coding guidelines for tests/**/*.py, test functions must be prefixed with test_ and structured in subdirectories matching kernel categories.
Suggested structure
+def run_trace_example(save_dir: Path) -> list[Path]:
+ # existing body here
+ return sorted(save_dir.glob("*.json"))
+
+def test_fi_trace_example_generates_defs(tmp_path, monkeypatch):
+ monkeypatch.setenv("FLASHINFER_TRACE_DUMP", "1")
+ monkeypatch.setenv("FLASHINFER_TRACE_DUMP_DIR", str(tmp_path))
+ files = run_trace_example(tmp_path)
+ assert files, "Expected fi_trace JSON files to be generated"
+
+if __name__ == "__main__":
+ run_trace_example(SAVE_DIR)🧰 Tools
🪛 Ruff (0.15.7)
[warning] 104-104: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
[warning] 104-104: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
[warning] 114-114: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
[warning] 121-121: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
[warning] 121-121: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
[error] 129-130: try-except-pass detected, consider logging the exception
(S110)
[warning] 129-129: Do not catch blind exception: Exception
(BLE001)
[warning] 132-132: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
[warning] 132-132: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
[error] 140-141: try-except-pass detected, consider logging the exception
(S110)
[warning] 140-140: Do not catch blind exception: Exception
(BLE001)
[error] 276-277: try-except-pass detected, consider logging the exception
(S110)
[warning] 276-276: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/trace/example.py` around lines 1 - 294: the file is a standalone script
so pytest won't collect it; convert it into a proper pytest test by moving the
top-level side-effect code into a single test function (e.g., def
test_generate_fi_trace_jsons(tmp_path):) while preserving the early environment
setup (os.environ.setdefault(...) and SAVE_DIR) before importing flashinfer, and
use the tmp_path fixture to override FLASHINFER_TRACE_DUMP_DIR/SAVE_DIR so
outputs go to a test-isolated directory; keep all calls to flashinfer functions
and wrappers (e.g., flashinfer.rmsnorm, flashinfer.fused_add_rmsnorm,
flashinfer.top_k_sampling_from_probs, flashinfer.mm_bf16,
flashinfer.gdn_decode.gated_delta_rule_decode,
BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper, BatchMLAPagedAttentionWrapper,
flashinfer.fused_moe.trtllm_fp8_block_scale_moe, etc.) inside that test, and
remove or adapt prints/assert the expected JSON files exist via
SAVE_DIR.glob("*.json") to make the test assertions deterministic for CI.
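One subtlety any monkeypatch-based refactor has to respect: per the script's own comment, template.py reads the trace env vars at module load time, so `monkeypatch.setenv` is too late once flashinfer has already been imported. A minimal CPU-only sketch of that constraint, using a hypothetical `fake_template` module as a stand-in (no GPU or flashinfer import required):

```python
import os
import sys
import tempfile
import textwrap
from pathlib import Path

# Write a tiny module that, like template.py in this PR, reads the env var
# once at import time. Later env changes are ignored.
mod_dir = Path(tempfile.mkdtemp())
(mod_dir / "fake_template.py").write_text(textwrap.dedent("""
    import os
    # Read once, at module load time.
    DUMP_DIR = os.environ.get("FLASHINFER_TRACE_DUMP_DIR", "/default")
"""))
sys.path.insert(0, str(mod_dir))

os.environ["FLASHINFER_TRACE_DUMP_DIR"] = "/from_test"
import fake_template  # env was set before import, so it is picked up

before = fake_template.DUMP_DIR

os.environ["FLASHINFER_TRACE_DUMP_DIR"] = "/too_late"
after = fake_template.DUMP_DIR  # unchanged: the module is already loaded
```

So a test that wants a tmp_path output directory must set the env vars (or reload the module) before the traced library is first imported.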
| "scale": { | ||
| "shape": null, | ||
| "dtype": "float32", | ||
| "description": "Scale factor. Default is 1/sqrt(head_size)." | ||
| } |
Mark scale as optional to match the declared default behavior.
Line 123 says a default is applied (1/sqrt(head_size)), but scale is not marked optional. This can make schema consumers treat it as required.
🛠️ Proposed fix
"scale": {
"shape": null,
"dtype": "float32",
+ "optional": true,
"description": "Scale factor. Default is 1/sqrt(head_size)."
}📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| "scale": { | |
| "shape": null, | |
| "dtype": "float32", | |
| "description": "Scale factor. Default is 1/sqrt(head_size)." | |
| } | |
| "scale": { | |
| "shape": null, | |
| "dtype": "float32", | |
| "optional": true, | |
| "description": "Scale factor. Default is 1/sqrt(head_size)." | |
| } |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json` around lines 120 - 124:
the "scale" field is documented as having a default (1/sqrt(head_size)) but
isn't marked optional; update the JSON schema entry for "scale" so consumers
know it may be omitted—e.g., add an optional/nullable flag or remove it from any
"required" list and set "optional": true (or equivalent) next to the "scale"
property to reflect the default behavior.
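To make the impact on schema consumers concrete, here is a minimal sketch of a validator that honors an `"optional"` flag. The validation semantics are an assumption for illustration, not taken from the fi_trace schema, and `missing_required` is a hypothetical helper:

```python
def missing_required(inputs_schema: dict, provided: dict) -> list:
    """Return the names of required inputs the caller did not provide.

    An input is required unless its spec carries "optional": true.
    """
    return [
        name
        for name, spec in inputs_schema.items()
        if not spec.get("optional", False) and name not in provided
    ]

schema = {
    "q": {"dtype": "bfloat16"},
    "scale": {
        "dtype": "float32",
        "optional": True,  # without this flag, omitting scale is an error
        "description": "Scale factor. Default is 1/sqrt(head_size).",
    },
}

# Caller supplies q but omits scale, relying on the documented default.
errs = missing_required(schema, {"q": object()})
```

With the flag present, omitting `scale` passes validation; without it, a consumer following this convention would reject the call.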
| "description": "Updated recurrent state in k-last layout [B, H, V, K]." | ||
| } | ||
| }, | ||
| "reference": "@torch.no_grad()\ndef _gdn_decode_reference(q, k, v, state, A_log, a, dt_bias, b, scale):\n \"\"\"\n Gated Delta Net decode reference implementation (k-last layout).\n\n State layout: [B, H, V, K] (k-last, K dimension at the end)\n\n Gate computation:\n g = exp(-exp(A_log) * softplus(a + dt_bias))\n beta = sigmoid(b)\n\n Delta rule update:\n state_new = g * state_old + k^T @ (beta * v + (1-beta) * k @ state_old) - k^T @ (k @ state_old)\n output = scale * q @ state_new\n \"\"\"\n B, T, num_q_heads, K = q.shape\n _, _, num_k_heads, _ = k.shape\n _, _, num_v_heads, V = v.shape\n num_heads = num_v_heads\n device = q.device\n\n if scale is None or scale == 0.0:\n scale = 1.0 / math.sqrt(K)\n\n x = a.float() + dt_bias.float() # [B, 1, HV]\n g = torch.exp(-torch.exp(A_log.float()) * F.softplus(x)) # [B, 1, HV]\n beta = torch.sigmoid(b.float()) # [B, 1, HV]\n\n q_f32 = q.squeeze(1).float()\n k_f32 = k.squeeze(1).float()\n v_f32 = v.squeeze(1).float()\n g_f32 = g.squeeze(1).float()\n beta_f32 = beta.squeeze(1).float()\n\n if state is not None:\n state_f32 = state.float()\n else:\n state_f32 = torch.zeros(B, num_heads, V, K, dtype=torch.float32, device=device)\n\n q_exp = q_f32.repeat_interleave(num_v_heads // num_q_heads, dim=1)\n k_exp = k_f32.repeat_interleave(num_v_heads // num_k_heads, dim=1)\n\n new_state = torch.zeros_like(state_f32)\n output = torch.zeros(B, num_heads, V, dtype=torch.float32, device=device)\n\n for b_idx in range(B):\n for h_idx in range(num_heads):\n q_h = q_exp[b_idx, h_idx]\n k_h = k_exp[b_idx, h_idx]\n v_h = v_f32[b_idx, h_idx]\n h_state = state_f32[b_idx, h_idx].clone().transpose(-1, -2) # [V,K] -> [K,V]\n g_val = g_f32[b_idx, h_idx]\n beta_val = beta_f32[b_idx, h_idx]\n\n old_state = g_val * h_state\n old_v = k_h @ old_state\n new_v = beta_val * v_h + (1 - beta_val) * old_v\n state_remove = k_h.unsqueeze(1) @ old_v.unsqueeze(0)\n state_update = k_h.unsqueeze(1) @ new_v.unsqueeze(0)\n h_state = old_state - state_remove + 
state_update\n\n output[b_idx, h_idx] = scale * (q_h @ h_state)\n new_state[b_idx, h_idx] = h_state.transpose(-1, -2) # [K,V] -> [V,K]\n\n output = output.unsqueeze(1).to(torch.bfloat16)\n return output, new_state\n" |
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# 1) Find where trace "reference" snippets are parsed/executed.
rg -n -C3 --type=py 'reference|exec\(|compile\('
# 2) Verify whether execution globals/locals include math and torch.nn.functional as F.
rg -n -C3 --type=py 'import math|torch\.nn\.functional as F|globals\(|locals\('
# 3) Find other trace fixtures that use math./F. in embedded references
# to confirm the expected convention is consistent.
rg -n --type=json 'math\.|F\.softplus|F\.' tests
Repository: flashinfer-ai/flashinfer
Length of output: 50381
🏁 Script executed:
#!/bin/bash
# Find files that load or process trace JSON files
find tests -type f -name "*.py" | head -20
rg -l "trace.*json|gdn_decode" tests --type py | head -10Repository: flashinfer-ai/flashinfer
Length of output: 998
🏁 Script executed:
#!/bin/bash
# Search for code that parses/loads trace JSON files
rg -n "load.*trace|json\.load|reference" tests --type py -A 3 | grep -E "trace|reference|exec|compile" | head -40Repository: flashinfer-ai/flashinfer
Length of output: 3912
🏁 Script executed:
#!/bin/bash
# Look for test files in the trace directory
ls -la tests/trace*.py tests/*trace*.py 2>/dev/null || find tests -name "*trace*" -type f
Repository: flashinfer-ai/flashinfer
Length of output: 163
🏁 Script executed:
#!/bin/bash
# Read the main trace test file to understand how reference is executed
wc -l tests/test_fi_trace.py
head -100 tests/test_fi_trace.py
Repository: flashinfer-ai/flashinfer
Length of output: 3745
🏁 Script executed:
#!/bin/bash
# Search for exec/compile patterns in trace test file
rg -n "exec|compile|reference|math|import F|import math" tests/test_fi_trace.py -A 5 -B 2Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Check the actual trace JSON file
head -30 tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json
Repository: flashinfer-ai/flashinfer
Length of output: 940
🏁 Script executed:
#!/bin/bash
# Search for code that loads/executes these trace JSON files
rg -rn "gdn_decode|fi_trace_out" tests --type pyRepository: flashinfer-ai/flashinfer
Length of output: 2731
🏁 Script executed:
#!/bin/bash
# Search more broadly for "reference" field execution in Python test files
rg -rn "reference.*exec|exec.*reference|\[.reference.\]" tests --type py -A 5Repository: flashinfer-ai/flashinfer
Length of output: 1018
🏁 Script executed:
#!/bin/bash
# Search for any code that executes or uses the "reference" field from JSON traces
rg -rn "\.reference|reference.*exec|eval.*reference" tests --type py -A 3 -B 3Repository: flashinfer-ai/flashinfer
Length of output: 2062
🏁 Script executed:
#!/bin/bash
# Look at the gdn test file that's mentioned
wc -l tests/gdn/test_decode_delta_rule.py
head -50 tests/gdn/test_decode_delta_rule.py
Repository: flashinfer-ai/flashinfer
Length of output: 1537
🏁 Script executed:
#!/bin/bash
# Search for how reference implementations are called in the gdn tests
rg -n "_gdn_decode_reference|_reference" tests/gdn --type py -B 2 -A 2Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Look for any code that loads JSON traces and executes reference field
find tests -name "*.py" -type f -exec rg -l "reference" {} \; | head -5
rg -rn "loads\(|compile\(|exec\(" tests/trace --type pyRepository: flashinfer-ai/flashinfer
Length of output: 324
🏁 Script executed:
#!/bin/bash
# Check if there's a trace execution utility in the flashinfer source
find . -path ./tests -prune -o -name "*.py" -type f -print | xargs rg -l "fi_trace|trace.*reference" 2>/dev/null | head -10
Repository: flashinfer-ai/flashinfer
Length of output: 284
🏁 Script executed:
#!/bin/bash
# Search more broadly for any eval/exec of code strings in Python files
rg -rn "eval\(|exec\(|compile\(" . --type py --iglob "*/flashinfer/*" | head -20Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Examine the trace template and related files
wc -l flashinfer/trace/template.py flashinfer/fi_trace.py
head -80 flashinfer/trace/template.py
Repository: flashinfer-ai/flashinfer
Length of output: 3025
🏁 Script executed:
#!/bin/bash
# Search for reference field usage in trace modules
rg -n "reference" flashinfer/trace --type py -B 2 -A 5Repository: flashinfer-ai/flashinfer
Length of output: 30715
🏁 Script executed:
#!/bin/bash
# Check fi_trace.py for reference handling
rg -n "reference" flashinfer/fi_trace.py -B 3 -A 5Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Look at the template.py file more carefully to see where reference is serialized
sed -n '480,510p' flashinfer/trace/template.py
Repository: flashinfer-ai/flashinfer
Length of output: 1460
🏁 Script executed:
#!/bin/bash
# Now check the GDN reference function definition
cat flashinfer/trace/templates/gdn.py | head -100
Repository: flashinfer-ai/flashinfer
Length of output: 136
🏁 Script executed:
#!/bin/bash
# Search for where the "reference" field from JSON is actually executed
rg -rn "json.*reference|loads.*reference|exec.*reference" . --type py --iglob "*/flashinfer/*" -A 5 -B 2Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Check the actual gdn.py template to see the reference function definition
head -80 flashinfer/trace/templates/gdn.py
Repository: flashinfer-ai/flashinfer
Length of output: 2773
🏁 Script executed:
#!/bin/bash
# Look for any code that loads trace JSON and executes the reference
find . -name "*.py" -path "*/flashinfer/*" -o -path "*/tests/*" | xargs grep -l "reference.*exec\|exec.*reference" 2>/dev/nullRepository: flashinfer-ai/flashinfer
Length of output: 164
🏁 Script executed:
#!/bin/bash
# Check if there are any test files that load and run the reference field from JSON traces
find tests -name "*.py" -type f | xargs grep -l "\.json" | head -5Repository: flashinfer-ai/flashinfer
Length of output: 204
🏁 Script executed:
#!/bin/bash
# Search for where reference strings might be executed from JSON
rg -rn "reference.*eval|reference.*exec|exec.*inspect|compile.*reference" tests flashinfer --type py -B 3 -A 3Repository: flashinfer-ai/flashinfer
Length of output: 1200
🏁 Script executed:
#!/bin/bash
# Check test_artifacts.py to see if it loads/executes traces
wc -l tests/test_artifacts.py
head -100 tests/test_artifacts.py
Repository: flashinfer-ai/flashinfer
Length of output: 4681
🏁 Script executed:
#!/bin/bash
# Look for any test that loads the JSON traces and uses them
grep -r "fi_trace_out\|gdn_decode_qk4" tests --type py -B 2 -A 5Repository: flashinfer-ai/flashinfer
Length of output: 185
🏁 Script executed:
#!/bin/bash
# Search for references to trace JSON files
rg "fi_trace_out|gdn_decode_qk4" tests -B 2 -A 5Repository: flashinfer-ai/flashinfer
Length of output: 1661
🏁 Script executed:
#!/bin/bash
# Look for any test files that actually load and execute the reference from JSON
rg "reference.*code|reference.*string|compile\(|exec\(" flashinfer/trace --type py -B 3 -A 5Repository: flashinfer-ai/flashinfer
Length of output: 1775
🏁 Script executed:
#!/bin/bash
# Directly check the JSON file to confirm what the reference field contains
python3 << 'EOF'
import json
with open('tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json') as f:
    data = json.load(f)
if 'reference' in data:
    ref = data['reference']
    print("First 500 chars of reference field:")
    print(ref[:500])
    print("\n...\n")
    print("Last 500 chars:")
    print(ref[-500:])
    # Check for imports
    if 'import math' in ref:
        print("\n✓ Contains: import math")
    if 'import torch' in ref:
        print("✓ Contains: import torch")
    if 'torch.nn.functional as F' in ref or 'torch.nn.functional' in ref:
        print("✓ Contains: torch.nn.functional reference")
    if 'F.softplus' in ref:
        print("✓ Uses: F.softplus")
    if 'math.sqrt' in ref:
        print("✓ Uses: math.sqrt")
else:
    print("No 'reference' field in JSON")
EOF
Length of output: 1164
🏁 Script executed:
#!/bin/bash
# Search for any code that actually executes the reference field from JSON
rg "reference.*exec|exec.*reference|compile.*reference|eval.*reference" flashinfer tests --type py -B 5 -A 5
Length of output: 1875
🏁 Script executed:
#!/bin/bash
# Check if the reference field is ONLY stored for documentation or if it's ever executed
rg "loads\(|load\(" flashinfer/trace --type py -B 3 -A 10 | head -50
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Look for any test that loads and executes trace JSON references
find tests -name "*.py" -type f -exec grep -l "loads\|exec\|compile" {} \; | xargs grep -l "reference\|trace.*json" 2>/dev/null
Length of output: 416
🏁 Script executed:
#!/bin/bash
# Check if reference is ever loaded and executed from the JSON
rg -rn "reference" flashinfer/trace/template.py -B 3 -A 10 | grep -A 10 "reference"
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Search more directly for any eval/exec in the entire flashinfer package
find flashinfer -name "*.py" -type f | xargs grep -n "eval\|exec\|compile" | head -30
Length of output: 2853
🏁 Script executed:
#!/bin/bash
# Verify the actual imports in the gdn.py template file
head -30 flashinfer/trace/templates/gdn.py | grep -E "^import|^from"
Length of output: 189
🏁 Script executed:
#!/bin/bash
# Confirm the imports are NOT in the extracted reference string from JSON
python3 << 'EOF'
import json
with open('tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json') as f:
data = json.load(f)
if 'reference' in data:
ref = data['reference']
print("Checking for imports in the reference field:")
if 'import math' in ref:
print("✓ Contains: import math")
else:
print("✗ MISSING: import math")
if 'import torch' in ref:
print("✓ Contains: import torch")
else:
print("✗ MISSING: import torch")
if 'torch.nn.functional' in ref or 'import F' in ref:
print("✓ Contains: torch.nn.functional or F import")
else:
print("✗ MISSING: torch.nn.functional as F import")
print("\nFirst 300 chars of reference:")
print(ref[:300])
EOFRepository: flashinfer-ai/flashinfer
Length of output: 534
🏁 Script executed:
#!/bin/bash
# Check if there are any plans or code to execute the reference field in the future
rg -rn "reference.*exec|load.*reference|reference.*run" flashinfer --type py -B 3 -A 5
Length of output: 50
The embedded reference code lacks required imports and will fail if executed.
Line 148's reference field contains math.sqrt and F.softplus but is serialized as a function source string without the required imports (import math and import torch.nn.functional as F). If this reference string is ever executed via exec() or eval(), it will raise NameError for unbound symbols.
Ensure that any trace execution context either:
- Injects math and torch.nn.functional (aliased as F) into the execution globals, or
- Includes the necessary import statements in the serialized reference string.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json` at line 148, The
reference function _gdn_decode_reference uses math.sqrt and F.softplus but the
serialized source string has no imports, causing NameError when exec/eval runs;
fix by either injecting math and torch.nn.functional as F into the exec/eval
globals where _gdn_decode_reference is executed (ensure names "math" and "F" are
present) or prepend/import lines ("import math" and "import torch.nn.functional
as F") to the serialized reference string so _gdn_decode_reference has the
required symbols at runtime.
"final_state": {
    "shape": [
        "pool_size",
        "num_v_heads",
        "head_size",
        "head_size"
    ],
    "dtype": "float32",
    "description": "Updated recurrent state pool in k-last layout [pool_size, H, V, K]. Unchanged if disable_state_update=True."
}
Documentation references undefined parameter disable_state_update.
Line 167 states "Unchanged if disable_state_update=True" but disable_state_update is not defined in the inputs section. Either add this parameter to inputs if it's required, or remove the reference from the description.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json` around lines 159 - 168,
The doc string for "final_state" references an undefined parameter
disable_state_update; either add a boolean input named disable_state_update to
the inputs section (e.g., description: "If true, recurrent state updates are
disabled and final_state remains unchanged") or remove the mention "Unchanged if
disable_state_update=True" from the "final_state" description; update the
"final_state" description or inputs accordingly so the documentation no longer
refers to an undefined symbol.
"description": "Updated recurrent state pool in k-last layout [pool_size, H, V, K]. Unchanged if disable_state_update=True."
    }
},
"reference": "@torch.no_grad()\ndef _gdn_mtp_reference(\n q, k, v, initial_state, initial_state_indices, A_log, a, dt_bias, b, scale,\n intermediate_states_buffer=None,\n):\n \"\"\"\n Gated Delta Net MTP (Multi-Token Prediction) reference implementation.\n\n State layout: [pool_size, H, V, K] (k-last, K dimension at the end)\n\n Gate computation:\n g = exp(-exp(A_log) * softplus(a + dt_bias))\n beta = sigmoid(b)\n\n For each token t in sequence:\n state_new = g_t * state_old + k_t^T @ (beta_t * v_t + (1-beta_t) * k_t @ state_old) - k_t^T @ (k_t @ state_old)\n output_t = scale * q_t @ state_new\n state_old = state_new # Update for next token\n \"\"\"\n B, T, num_q_heads, head_size = q.shape\n _, _, num_k_heads, _ = k.shape\n _, _, num_v_heads, _ = v.shape\n device = q.device\n\n if scale is None or scale == 0.0:\n scale = 1.0 / math.sqrt(head_size)\n\n x = a.float() + dt_bias.float() # [B, T, HV]\n g = torch.exp(-torch.exp(A_log.float()) * F.softplus(x)) # [B, T, HV]\n beta = torch.sigmoid(b.float()) # [B, T, HV]\n\n q_exp = q.repeat_interleave(num_v_heads // num_q_heads, dim=2) # [B, T, HV, K]\n k_exp = k.repeat_interleave(num_v_heads // num_k_heads, dim=2) # [B, T, HV, K]\n\n output = torch.zeros(\n (B, T, num_v_heads, head_size), dtype=torch.bfloat16, device=device\n )\n cache_intermediate = intermediate_states_buffer is not None\n\n for b_idx in range(B):\n state_idx = int(initial_state_indices[b_idx].item())\n state_HVK = initial_state[state_idx].clone().float().transpose(-1, -2) # [H,V,K] -> [H,K,V]\n\n for t in range(T):\n q_HK = q_exp[b_idx, t].float() # [HV, K]\n k_HK = k_exp[b_idx, t].float() # [HV, K]\n v_HV = v[b_idx, t].float() # [HV, V]\n g_H = g[b_idx, t] # [HV]\n beta_H = beta[b_idx, t] # [HV]\n\n for h_idx in range(num_v_heads):\n q_h = q_HK[h_idx]\n k_h = k_HK[h_idx]\n v_h = v_HV[h_idx]\n h_state = state_HVK[h_idx]\n g_val = g_H[h_idx]\n beta_val = beta_H[h_idx]\n\n old_state = g_val * h_state\n old_v = k_h @ old_state\n new_v = beta_val * v_h + (1 - beta_val) * old_v\n state_remove = k_h.unsqueeze(1) @ old_v.unsqueeze(0)\n state_update = k_h.unsqueeze(1) @ new_v.unsqueeze(0)\n h_state = old_state - state_remove + state_update\n\n output[b_idx, t, h_idx] = (scale * (q_h @ h_state)).to(torch.bfloat16)\n state_HVK[h_idx] = h_state\n\n if cache_intermediate:\n intermediate_states_buffer[state_idx, t] = state_HVK.transpose(-1, -2) # [H,K,V] -> [H,V,K]\n\n final_state = initial_state.clone()\n return output, final_state\n"
Reference implementation does not return the updated state.
The reference function computes state updates in state_HVK for each batch element, but at the end returns initial_state.clone() instead of the accumulated updated state:
final_state = initial_state.clone()
return output, final_state

This means final_state will always equal the input initial_state, discarding all computed state updates. The correct behavior should write state_HVK.transpose(-1, -2) back to final_state[state_idx] after processing each batch.
🐛 Proposed fix
+    final_state = initial_state.clone()
     for b_idx in range(B):
         state_idx = int(initial_state_indices[b_idx].item())
         state_HVK = initial_state[state_idx].clone().float().transpose(-1, -2)  # [H,V,K] -> [H,K,V]
         for t in range(T):
             # ... state update logic ...
             if cache_intermediate:
                 intermediate_states_buffer[state_idx, t] = state_HVK.transpose(-1, -2)  # [H,K,V] -> [H,V,K]
-    final_state = initial_state.clone()
+        # Write back updated state for this batch element
+        final_state[state_idx] = state_HVK.transpose(-1, -2)  # [H,K,V] -> [H,V,K]
+
     return output, final_state

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json` at line 170, The reference
function _gdn_mtp_reference updates per-batch states in state_HVK but then
returns final_state = initial_state.clone(), discarding updates; fix by creating
final_state = initial_state.clone() before the batch loop and after processing
each batch element (using state_idx = int(initial_state_indices[b_idx].item()))
write the updated state back with final_state[state_idx] =
state_HVK.transpose(-1, -2) (matching the stored [H,V,K] layout); ensure types
remain consistent (match .float()/.to dtype as needed) and then return output,
final_state.
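The missing write-back can be shape-checked without torch; a NumPy sketch under assumed pool dimensions, with all names hypothetical:

```python
import numpy as np

pool_size, H, V, K = 4, 2, 3, 3
initial_state = np.arange(pool_size * H * V * K, dtype=np.float32).reshape(
    pool_size, H, V, K
)
state_indices = [2, 0]  # hypothetical per-batch pool slots

final_state = initial_state.copy()
for state_idx in state_indices:
    # Work in [H, K, V] layout, as the reference does via transpose(-1, -2).
    state_HKV = initial_state[state_idx].swapaxes(-1, -2)
    state_HKV = state_HKV + 1.0  # stand-in for the real per-token update
    # The write-back step the reference is missing: commit the updated state
    # back into the pool slot, restoring the [H, V, K] layout.
    final_state[state_idx] = state_HKV.swapaxes(-1, -2)

print(final_state[2, 0, 0, 0] - initial_state[2, 0, 0, 0])  # 1.0
```

Without the final assignment, final_state would remain a pristine copy of initial_state, which is exactly the bug described above.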
"description": "The 2-based log-sum-exp of attention logits."
    }
},
"reference": "@torch.no_grad()\ndef _mla_paged_decode_reference(\n q_nope, q_pe, ckv_cache, kpe_cache, kv_indptr, kv_indices, sm_scale\n):\n batch_size, num_qo_heads, head_dim_ckv = q_nope.shape\n len_indptr = kv_indptr.shape[0]\n\n Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv]\n Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe]\n\n output = torch.zeros(\n (batch_size, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=q_nope.device\n )\n lse = torch.full(\n (batch_size, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=q_nope.device\n )\n\n for b in range(batch_size):\n page_beg = int(kv_indptr[b].item())\n page_end = int(kv_indptr[b + 1].item())\n if page_beg >= page_end:\n output[b].zero_()\n continue\n tok_idx = kv_indices[page_beg:page_end].to(torch.long)\n Kc = Kc_all[tok_idx] # [L, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [L, head_dim_kpe]\n qn = q_nope[b].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[b].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n logits = ((qn @ Kc.T) + (qp @ Kp.T)) * sm_scale # [num_qo_heads, L]\n lse[b] = torch.logsumexp(logits, dim=-1) / math.log(2.0)\n output[b] = (torch.softmax(logits, dim=-1) @ Kc).to(torch.bfloat16)\n\n return output, lse\n"
🧩 Analysis chain
🌐 Web query:
For PyTorch 2.x: if qn is shape [H, D] and Kc is shape [L, P, D], what are the rules for qn @ Kc.T? Also, what does .T do for tensors with more than 2 dimensions?
💡 Result:
For PyTorch 2.x, the @ operator follows torch.matmul rules. tensor.T on a tensor with more than 2 dimensions reverses all dimensions, so Kc with shape [L, P, D] becomes [D, P, L]; this behavior is deprecated for >2-D tensors and will raise an error in a future release — use .mT or .transpose(-2, -1) to swap only the last two dimensions. For qn [H, D] @ Kc.T [D, P, L], matmul treats the 3-D operand as a batch of D matrices of shape [P, L] and attempts [H, D] @ [P, L] per batch element, which fails with a shape-mismatch RuntimeError unless D == P. The intended batched contraction is qn @ Kc.mT (equivalently qn @ Kc.transpose(-2, -1)): Kc.mT has shape [L, D, P], the 2-D qn is broadcast across the leading batch dimension, and the result has shape [L, H, P].
Citations:
- 1: https://docs.pytorch.org/docs/stable/generated/torch.matmul.html
- 2: https://dev.to/hyperkai/matrix-and-element-wise-multiplication-in-pytorch-25le
- 3: https://runebook.dev/en/docs/pytorch/generated/torch.mm
- 4: https://docs.pytorch.org/docs/1.10/generated/torch.matmul.html
- 5: https://pytorch.org/docs/stable/tensors.html
- 6: https://docs.pytorch.org/docs/1.10/generated/torch.t.html
- 7: https://www.codegenes.net/blog/pytorch-t/
- 8: https://docs.pytorch.org/docs/2.0/tensors.html
- 9: [ux] Proposal to have t() === transpose(-1, -2), since batches are very frequent pytorch/pytorch#51280
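The shape rules above can be checked without torch — NumPy's matmul broadcasting matches torch.matmul for this case (shapes chosen arbitrarily for illustration):

```python
import numpy as np

H, D, L, P = 16, 512, 4, 64
qn = np.zeros((H, D))
Kc = np.zeros((L, P, D))

# .T on a >2-D array reverses *all* dimensions, not just the last two.
print(Kc.T.shape)  # (512, 64, 4), i.e. [D, P, L]

# Swapping only the last two axes (what torch's .mT does) gives the
# batched-matmul-friendly layout.
Kc_mT = np.swapaxes(Kc, -2, -1)
print(Kc_mT.shape)  # (4, 512, 64), i.e. [L, D, P]

# The 2-D qn broadcasts across the leading batch dim:
# [H, D] @ [L, D, P] -> [L, H, P]
print((qn @ Kc_mT).shape)  # (4, 16, 64)
```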
🏁 Script executed:
# First, let's look at the JSON file structure and find the input shapes
fd -type f -name "mla_paged_decode_h16_ckv512_kpe64_ps64.json" | head -1 | xargs -I {} sh -c '
# Get file size to decide how to read it
wc -l {} | awk "{print \$1}"
'
Length of output: 238
🏁 Script executed:
# Find the JSON file using git ls-files
git ls-files | grep "mla_paged_decode_h16_ckv512_kpe64_ps64.json"
Length of output: 136
🏁 Script executed:
# Get the line count and examine the JSON structure
wc -l tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json
Length of output: 140
🏁 Script executed:
# Read the entire JSON file to understand input shapes
cat -n tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json
Length of output: 5318
Reference function squeeze(1) is incorrect for paged layout.
On Line 123, squeeze(1) cannot remove the page_size dimension when it equals 64. The inputs ckv_cache and kpe_cache have shape [num_pages, page_size, head_dim_*], so squeeze(1) leaves them 3D. The subsequent indexing Kc_all[tok_idx] produces [L, 64, head_dim_ckv] instead of the intended 2D [L, head_dim_ckv], causing the matmul qn @ Kc.T to fail.
Replace squeeze(1) with reshape(num_pages * page_size, head_dim_ckv) to flatten pages into a token-major 2D tensor, or adjust the indexing logic if indices are page-level rather than token-level.
Additionally, specify explicit dtypes for kv_indptr and kv_indices (currently "unknown") to reduce schema ambiguity.
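The squeeze-vs-reshape distinction can be illustrated shape-wise with NumPy (dimensions borrowed from the trace; note that torch's squeeze(1) is a no-op when dim 1 is not size 1):

```python
import numpy as np

num_pages, page_size, head_dim = 8, 64, 512
cache = np.zeros((num_pages, page_size, head_dim))

# torch's cache.squeeze(1) is a no-op for page_size=64, so the tensor stays
# 3-D and per-token indexing would select whole pages instead of tokens.
squeezed = cache  # what squeeze(1) effectively returns here
print(squeezed.shape)  # (8, 64, 512)

# Flattening pages into a token-major 2-D tensor makes per-token indexing valid.
flat = cache.reshape(num_pages * page_size, head_dim)
print(flat.shape)  # (512, 512)
```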
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json` at line
123, In _mla_paged_decode_reference the use of ckv_cache.squeeze(1) and
kpe_cache.squeeze(1) is wrong for paged tensors (shape [num_pages, page_size,
head_dim_*]) and leaves a 3D tensor so Kc_all[tok_idx] yields [L, page_size,
head_dim]; replace squeeze(1) with a flattening reshape (e.g. reshape(num_pages
* page_size, head_dim_ckv) / reshape(..., head_dim_kpe) or view(-1, head_dim_*))
so Kc_all and Kp_all become 2D token-major tensors before indexing, and ensure
kv_indptr and kv_indices are cast to explicit integer dtype (torch.long/int64)
before use to remove schema ambiguity.
Actionable comments posted: 12
♻️ Duplicate comments (1)
flashinfer/trace/templates/moe.py (1)
25-27: ⚠️ Potential issue | 🟠 Major
The MoE reference is still hard-wired to one hidden/intermediate shape. H = 7168 and I = 2048 leak into the scale expansion, output allocation, and the G1 split, so any traced MoE with different shapes will either reshape incorrectly or slice the expert output wrong.
Suggested fix
H=7168andI=2048leak into the scale expansion, output allocation, andG1split, so any traced MoE with different shapes will either reshape incorrectly or slice the expert output wrong.Suggested fix
-H = 7168
-I = 2048
 BLOCK = 128
@@
-    T = hidden_states.shape[0]
+    T, H = hidden_states.shape
+    I = gemm2_weights.shape[2]
+    gemm1_out = gemm1_weights.shape[1]
+    if gemm1_out != 2 * I:
+        raise ValueError(
+            f"Invalid gemm1_out_size={gemm1_out}, expected 2 * intermediate_size={2 * I}"
+        )

Also applies to: 53-57, 72-88
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/moe.py` around lines 25 - 27, The code hard-codes H, I, and BLOCK which leak into scale expansion, output allocation, and the G1 split causing incorrect reshapes/slices for other MoE shapes; replace these constants with dynamic values derived from the model/tensor shapes (e.g., infer hidden_size and intermediate_size from the input/weight tensors or pass them as parameters), update any uses in scale expansion, output allocation, and the G1 split logic (references: H, I, BLOCK, and the G1 split/expert output slicing code in moe.py) to compute sizes at runtime and use those computed sizes for reshape, split and slice operations so traced models with different H/I/BLOCK work correctly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/trace/templates/attention.py`:
- Line 357: The variables len_indptr and page_size are assigned but never used;
to fix the ruff F841 error, either remove the unused assignments or rename them
to _len_indptr and _page_size (or prefix with a single underscore) where they
are set (e.g., the len_indptr = kv_indptr.shape[0] assignment and the other
page_size assignment in the same module), and apply the same change to the
second occurrence that also triggers the warning so the linter no longer reports
unused locals.
- Around line 144-157: The prefill path incorrectly treats kv_indices as token
indices by indexing k_flat/v_flat directly with page_ids (kv_indices) and
computing num_kv_tokens from page_ids.shape[0]; instead expand the selected
pages to token-level rows first: use kv_indices[kv_start:kv_end] to select page
rows from the full per-page KV buffer (not the flattened token axis), then
concatenate or expand those page rows into token-level k_b and v_b and compute
num_kv_tokens from the resulting expanded KV token rows; update usages around
k_flat, v_flat, page_ids, k_b, v_b and num_kv_tokens so page->token expansion
happens before indexing the flattened token axis.
- Around line 42-53: kv_indices currently represents page IDs, so indexing
k_flat/v_flat directly with kv_indices selects wrong rows when page_size > 1;
instead, first gather the pages from the original k_cache and v_cache using
kv_indices (use kv_indices to index the page dimension of k_cache/v_cache to
produce per-token page slices), then flatten or reshape the gathered per-page
tensors into token-level rows and proceed (so create k_b/v_b by gathering pages
via kv_indices from k_cache/v_cache, then reshape to [T, num_kv_heads, head_dim]
before using them as k_b and v_b); update all uses of k_flat/v_flat and
token_ids accordingly and ensure kv_indptr logic still slices the kv_indices by
token count, not flattened token offsets.
- Around line 359-383: The reference implementations _mla_paged_decode_reference
and _mla_paged_prefill_reference assume page_size==1 by calling
ckv_cache.squeeze(1) and kpe_cache.squeeze(1); instead update these functions to
flatten the page and token dimensions so arbitrary page_size works (e.g.,
replace squeeze(1) with a reshape/flatten to (-1, head_dim_ckv) for Kc_all and
(-1, head_dim_kpe) for Kp_all or use flatten(0,1)), ensuring subsequent indexing
via kv_indices still selects the correct token rows; alternatively, if you
prefer to keep the current code, enforce page_size==1 in the TraceTemplate
schema, but do not leave squeeze(1) as-is.
In `@flashinfer/trace/templates/gdn.py`:
- Around line 153-157: The Tensor schema for the "output" entries in
flashinfer/trace/templates/gdn.py currently uses dtype_from="q" but the
implementation always casts outputs to torch.bfloat16; update the schema to
reflect the real emitted dtype by replacing dtype_from="q" with dtype="bfloat16"
for the "output" Tensor declarations (the entries named "output" in the
template), or alternatively model the runtime control explicitly if outputs can
vary; make the same change for the other "output" Tensor occurrences mentioned
so the trace metadata matches the torch.bfloat16 casts in the code.
- Around line 382-415: The function mutates per-example head states in state_HVK
but returns final_state built from the unchanged initial_state; fix by writing
the updated state_HVK back into final_state before returning. After the outer
loops (or just before return), clone initial_state into final_state as done now
and then for each b_idx set final_state[state_idx] = state_HVK.transpose(-1, -2)
(or assign the corresponding typed/ device-matched tensor) so the updated
[H,V,K] state for the sample index (state_idx derived from
initial_state_indices[b_idx]) is committed; ensure dtype/device matches
initial_state when assigning.
- Around line 205-206: gdn_prefill_trace currently expands q and k with
repeat_interleave using num_v_heads // num_q_heads and num_v_heads //
num_k_heads but does not validate the required head-ratio constraints; add
explicit checks in gdn_prefill_trace to assert num_v_heads >= num_q_heads and
num_v_heads % num_q_heads == 0 and also assert num_k_heads == num_q_heads (or
otherwise enforce the same constraints used by decode/MTP), and apply the same
fixes to the other expansion site (the block around the q/k/v repeat_interleave
at the later occurrence). Ensure the assertions raise clear errors mentioning
num_v_heads, num_q_heads, and num_k_heads so invalid head layouts are rejected
before repeat_interleave is called.
In `@flashinfer/trace/templates/gemm.py`:
- Around line 57-78: The template misdeclares packed uint8 inputs as logical FP4
shapes causing fi_trace to infer wrong K/N; update the public trace signatures
(or add a pre-trace extractor) so the runtime sees packed dimensions: treat A
and B as [M, K_packed] and [K_packed, N_packed] (or expose an extractor that
maps packed -> logical by doubling the last axis) and propagate corrected
logical axes into mm_fp4_trace before calling _mm_fp4_reference/_unpack_fp4;
apply the same change to the other occurrence around the second block (the
191-200 region) and ensure a_descale/b_descale shape metadata matches the
packed-block layout.
- Around line 22-35: The reference GEMM helpers currently transpose B (using .T)
even though B is modeled as the physical [K, N] tensor; update _mm_reference and
_mm_fp8_reference (and the other similar reference helpers in the file) to
multiply A by B directly (remove the .T and B_fp32.T), keeping the same dtype
conversions and return types (e.g., _mm_fp8_reference should still dequantize to
float32, matmul, then cast to bfloat16), and update any docstrings/comments that
incorrectly describe B as needing transpose.
In `@flashinfer/trace/templates/moe.py`:
- Around line 577-598: The direct attribute assignment
trtllm_fp8_block_scale_moe_trace_dispatch.templates causes mypy attr-defined
errors; replace that assignment with a setattr call to attach the templates list
at runtime (e.g., use setattr(trtllm_fp8_block_scale_moe_trace_dispatch,
"templates", list(_MOE_TRACE_BY_ROUTING_TYPE.values()))), keeping the same value
(list of _MOE_TRACE_BY_ROUTING_TYPE.values()) and preserving behavior for
_attach_fi_trace registration.
In `@tests/trace/test_fi_trace_template_consistency.py`:
- Around line 193-199: The test E2E generator currently assigns 0 for int32
scalars in the loop over template.inputs which can create impossible values
(e.g., block_size=0); update the assignment in the loop that inspects
isinstance(descriptor, Scalar) and uses _resolved_param(json_key, descriptor) so
that int32 defaults are positive (e.g., 1 or another small positive) and
preferably support per-parameter overrides for constrained scalars before
populating kwargs; ensure any change keeps optional descriptors skipped and
preserves the dtype branch for non-int32 floats, so assert_fi_trace_complete()
validates realistic traces.
---
Duplicate comments:
In `@flashinfer/trace/templates/moe.py`:
- Around line 25-27: The code hard-codes H, I, and BLOCK which leak into scale
expansion, output allocation, and the G1 split causing incorrect reshapes/slices
for other MoE shapes; replace these constants with dynamic values derived from
the model/tensor shapes (e.g., infer hidden_size and intermediate_size from the
input/weight tensors or pass them as parameters), update any uses in scale
expansion, output allocation, and the G1 split logic (references: H, I, BLOCK,
and the G1 split/expert output slicing code in moe.py) to compute sizes at
runtime and use those computed sizes for reshape, split and slice operations so
traced models with different H/I/BLOCK work correctly.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 8f771e55-82cf-425e-9083-fba1ef3390e8
📒 Files selected for processing (7)
.claude/skills/add-cuda-kernel/SKILL.md
flashinfer/api_logging.py
flashinfer/trace/templates/attention.py
flashinfer/trace/templates/gdn.py
flashinfer/trace/templates/gemm.py
flashinfer/trace/templates/moe.py
tests/trace/test_fi_trace_template_consistency.py
k_flat = k_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)
v_flat = v_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)

for b in range(batch_size):
    page_start = int(kv_indptr[b].item())
    page_end = int(kv_indptr[b + 1].item())
    if page_start >= page_end:
        output[b].zero_()
        continue
    token_ids = kv_indices[page_start:page_end].to(torch.long)
    k_b = k_flat[token_ids]  # [T, num_kv_heads, head_dim]
    v_b = v_flat[token_ids]
kv_indices are page IDs, but decode reference indexes flattened tokens.
This reference is incorrect when page_size > 1: indexing k_flat/v_flat with page IDs selects wrong rows. Use page gather first, then flatten within selected pages.
Proposed fix
- k_flat = k_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)
- v_flat = v_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)
@@
- token_ids = kv_indices[page_start:page_end].to(torch.long)
- k_b = k_flat[token_ids] # [T, num_kv_heads, head_dim]
- v_b = v_flat[token_ids]
+ page_ids = kv_indices[page_start:page_end].to(torch.long)
+ k_b = k_cache[page_ids].reshape(-1, num_kv_heads, head_dim).to(torch.float32)
+ v_b = v_cache[page_ids].reshape(-1, num_kv_heads, head_dim).to(torch.float32)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/trace/templates/attention.py` around lines 42 - 53, kv_indices
currently represents page IDs, so indexing k_flat/v_flat directly with
kv_indices selects wrong rows when page_size > 1; instead, first gather the
pages from the original k_cache and v_cache using kv_indices (use kv_indices to
index the page dimension of k_cache/v_cache to produce per-token page slices),
then flatten or reshape the gathered per-page tensors into token-level rows and
proceed (so create k_b/v_b by gathering pages via kv_indices from
k_cache/v_cache, then reshape to [T, num_kv_heads, head_dim] before using them
as k_b and v_b); update all uses of k_flat/v_flat and token_ids accordingly and
ensure kv_indptr logic still slices the kv_indices by token count, not flattened
token offsets.
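The page-ID-vs-token-row confusion can be demonstrated with a tiny NumPy mock of the paged cache (shapes hypothetical):

```python
import numpy as np

num_pages, page_size, num_kv_heads, head_dim = 6, 2, 1, 4
k_cache = np.arange(
    num_pages * page_size * num_kv_heads * head_dim, dtype=np.float32
).reshape(num_pages, page_size, num_kv_heads, head_dim)
page_ids = np.array([3, 0])

# Buggy pattern: page IDs used to index a token-flattened axis, so "pages"
# 3 and 0 actually select token rows 3 and 0.
k_flat = k_cache.reshape(-1, num_kv_heads, head_dim)
wrong = k_flat[page_ids]

# Correct pattern: gather whole pages first, then flatten pages into tokens.
right = k_cache[page_ids].reshape(-1, num_kv_heads, head_dim)
print(right.shape)  # (4, 1, 4): 2 pages x 2 tokens each
```

The two patterns only coincide when page_size == 1, which is why the bug is easy to miss in single-token-page tests.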
k_flat = k_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)
v_flat = v_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)

for b in range(len_indptr - 1):
    q_start = int(qo_indptr[b].item())
    q_end = int(qo_indptr[b + 1].item())
    kv_start = int(kv_indptr[b].item())
    kv_end = int(kv_indptr[b + 1].item())
    if q_start >= q_end or kv_start >= kv_end:
        continue
    page_ids = kv_indices[kv_start:kv_end].to(torch.long)
    k_b = k_flat[page_ids]
    v_b = v_flat[page_ids]
    num_kv_tokens = page_ids.shape[0]
Prefill reference has the same page-id/token-id mismatch.
kv_indices are documented as page IDs, but this path indexes a flattened token axis directly. Expand selected pages first, then derive num_kv_tokens from expanded KV rows.
Proposed fix
- k_flat = k_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)
- v_flat = v_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)
@@
- page_ids = kv_indices[kv_start:kv_end].to(torch.long)
- k_b = k_flat[page_ids]
- v_b = v_flat[page_ids]
- num_kv_tokens = page_ids.shape[0]
+ page_ids = kv_indices[kv_start:kv_end].to(torch.long)
+ k_b = k_cache[page_ids].reshape(-1, num_kv_heads, head_dim).to(torch.float32)
+ v_b = v_cache[page_ids].reshape(-1, num_kv_heads, head_dim).to(torch.float32)
+ num_kv_tokens = k_b.shape[0]

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/trace/templates/attention.py` around lines 144 - 157, The prefill
path incorrectly treats kv_indices as token indices by indexing k_flat/v_flat
directly with page_ids (kv_indices) and computing num_kv_tokens from
page_ids.shape[0]; instead expand the selected pages to token-level rows first:
use kv_indices[kv_start:kv_end] to select page rows from the full per-page KV
buffer (not the flattened token axis), then concatenate or expand those page
rows into token-level k_b and v_b and compute num_kv_tokens from the resulting
expanded KV token rows; update usages around k_flat, v_flat, page_ids, k_b, v_b
and num_kv_tokens so page->token expansion happens before indexing the flattened
token axis.
Kc_all = ckv_cache.squeeze(1).to(torch.float32)  # [num_pages, head_dim_ckv]
Kp_all = kpe_cache.squeeze(1).to(torch.float32)  # [num_pages, head_dim_kpe]

output = torch.zeros(
    (batch_size, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=q_nope.device
)
lse = torch.full(
    (batch_size, num_qo_heads), -float("inf"), dtype=torch.float32, device=q_nope.device
)

for b in range(batch_size):
    page_beg = int(kv_indptr[b].item())
    page_end = int(kv_indptr[b + 1].item())
    if page_beg >= page_end:
        output[b].zero_()
        continue
    tok_idx = kv_indices[page_beg:page_end].to(torch.long)
    Kc = Kc_all[tok_idx]  # [L, head_dim_ckv]
    Kp = Kp_all[tok_idx]  # [L, head_dim_kpe]
    qn = q_nope[b].to(torch.float32)  # [num_qo_heads, head_dim_ckv]
    qp = q_pe[b].to(torch.float32)  # [num_qo_heads, head_dim_kpe]
    logits = ((qn @ Kc.T) + (qp @ Kp.T)) * sm_scale  # [num_qo_heads, L]
    lse[b] = torch.logsumexp(logits, dim=-1) / math.log(2.0)
    output[b] = (torch.softmax(logits, dim=-1) @ Kc).to(torch.bfloat16)
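The division by math.log(2.0) converts the natural-log LSE into the base-2 LSE the schema documents; the identity can be checked quickly with NumPy, independent of the kernel code:

```python
import numpy as np

logits = np.array([0.0, 1.0, 2.0])

# Natural-log log-sum-exp, then convert to base 2 by dividing by ln(2).
lse_e = np.log(np.sum(np.exp(logits)))
lse_2 = lse_e / np.log(2.0)

# 2 ** lse_2 recovers the same sum as e ** lse_e.
assert np.isclose(2.0 ** lse_2, np.sum(np.exp(logits)))
print(lse_2)
```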
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# 1) Locate all MLA paged wrappers/templates and check whether page_size is constrained to 1.
rg -nP --type=py '\bmla_paged\b|\bpage_size\b|\bkv_indices\b|\bckv_cache\b|\bkpe_cache\b' -C2
# 2) Find explicit invariants/assertions for page_size==1 (if any).
rg -nP --type=py 'page_size\s*==\s*1|assert\s+.*page_size|squeeze\(\s*1\s*\)' -C2
# 3) Inspect fi_trace tests/fixtures for MLA coverage and page_size values used.
rg -nP --type=py --type=json 'mla_paged|page_size' tests -C2

🏁 Script executed:
# Examine the specific file and lines mentioned in the review
head -n 520 flashinfer/trace/templates/attention.py | tail -n 180

🏁 Script executed:
# Check for schema constraints on page_size for MLA
rg -nP 'page_size.*==.*1|page_size.*constraint|mla.*page_size|page_size.*schema' flashinfer -A3 -B3

🏁 Script executed:
# Look for comments or documentation about MLA page_size limitations
rg -nP 'MLA|mla_paged|mla' flashinfer/trace/templates/ -A5 -B5 | head -n 200

🏁 Script executed:
# Check how the reference functions are used
rg -n "reference.*mla_paged|_mla_paged.*reference" flashinfer/trace/templates/attention.py -A2 -B2

🏁 Script executed:
# Look at test_fi_trace.py to see if it uses the reference implementations
grep -n "reference\|fi_trace.*mla" tests/test_fi_trace.py | head -30

🏁 Script executed:
# Check if TraceTemplate reference functions are ever called
rg -n "\.reference\(\|run_reference" flashinfer/trace -A3 -B3

🏁 Script executed:
# Check how TraceTemplate is used and if reference is called
rg -n "class TraceTemplate|def.*reference" flashinfer/trace -A5 | head -100

🏁 Script executed:
# Look at how fi_trace is implemented
find flashinfer -name "*.py" -exec grep -l "fi_trace" {} \; | head -10

🏁 Script executed:
# Check if there's validation code that calls the reference implementations
rg -n "reference" flashinfer/trace/core.py -A3 -B3

🏁 Script executed:
# Check the fi_trace.py file
cat -n flashinfer/fi_trace.py | head -100

🏁 Script executed:
# Check the template.py file for TraceTemplate
grep -n "class TraceTemplate\|reference" flashinfer/trace/template.py | head -50

🏁 Script executed:
# Look at how reference is used in tests
grep -n "reference\|\.run\|\.plan" tests/test_fi_trace.py -A2 -B2 | head -80

🏁 Script executed:
# Check if reference functions are ever actually invoked/called
rg -n "template\.reference\|_mla_paged_decode_reference\|_mla_paged_prefill_reference" flashinfer -A3

🏁 Script executed:
# Look at the actual CUDA/kernel implementations to see if they have page_size constraints
find flashinfer/mla -name "*.py" -o -name "*.cu" | head -10

🏁 Script executed:
# Check the cute_dsl MLA implementation for page_size constraints
grep -n "page_size.*1\|squeeze" flashinfer/mla/cute_dsl/mla_decode_fp16.py | head -30

🏁 Script executed:
# Check what the actual MLA wrapper implementation looks like
grep -n "class.*MLA\|def.*run\|page_size" flashinfer/mla/_core.py | head -50

🏁 Script executed:
# Look at tests to see if they verify correctness with page_size != 1
grep -n "page_size.*64\|page_size.*16\|test_mla" tests/test_fi_trace.py -B5 -A10

🏁 Script executed:
# Check if there are actual runtime tests for MLA with different page sizes
rg -n "BatchMLAPagedAttentionWrapper\|mla_paged" tests/ -A10 | head -100

🏁 Script executed:
# Check if reference implementations are actually called during testing or validation
rg -n "reference\(" flashinfer -A5 | head -40

🏁 Script executed:
# Look at the actual wrapper implementation to see how it handles different page_size
sed -n '442,550p' flashinfer/mla/_core.py

🏁 Script executed:
# Check if squeeze is used anywhere in actual MLA implementations
rg -n "squeeze" flashinfer/mla/_core.py

🏁 Script executed:
# Check if the reference implementations are actually ever invoked during normal operation
rg -n "\.reference\|call.*reference" flashinfer -A5 -B5

🏁 Script executed:
# Verify that actual kernels handle arbitrary page_size
grep -n "page_size\|ckv_cache\|kpe_cache" flashinfer/mla/_core.py | grep -A3 -B3 "cutlass\|cuda"

🏁 Script executed:
# Check what the cute_dsl actually does with page_size
head -n 100 flashinfer/mla/cute_dsl/mla_decode_fp16.py | grep -n "page_size"

🏁 Script executed:
# Check if reference implementations are embedded in JSON or used elsewhere
grep -rn "reference" flashinfer/trace/template.py | head -20

🏁 Script executed:
# Look for how flashinfer-bench uses the reference
find . -name "*bench*" -type f | grep -E "\.py$|\.json$" | head -10

🏁 Script executed:
# Verify that tests with page_size=64 actually pass/work
grep -A 20 "def test_mla_paged_fi_trace" tests/test_fi_trace.py | head -30
Reference implementations for MLA paged templates assume single-token pages but schema permits arbitrary page_size.
The _mla_paged_decode_reference() and _mla_paged_prefill_reference() functions use squeeze(1) to remove the page dimension, which is only valid when page_size == 1. However, the TraceTemplate schema defines page_size as a generic Const() without constraining it to 1. These reference implementations are embedded in the JSON trace output and used by flashinfer-bench for validation.
Tests already use page_size=64 (test_fi_trace.py:253, test_fi_trace.py:419), which would cause the reference implementations to fail or produce incorrect results.
Either enforce page_size == 1 as a schema constraint, or rewrite the reference implementations to properly flatten pages and tokens (e.g., Kc_all = ckv_cache.reshape(-1, head_dim_ckv) instead of squeeze) before the indexing operations.
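The reshape-based fix can be sketched in a few lines (shapes and names here are illustrative, not the template's API):

```python
import torch

# Flatten a paged cache [num_pages, page_size, D] to token rows, then expand
# selected page ids to token indices -- works for any page_size, unlike squeeze(1).
num_pages, page_size, D = 4, 2, 8
cache = torch.arange(num_pages * page_size * D, dtype=torch.float32).reshape(
    num_pages, page_size, D
)
tokens = cache.reshape(-1, D)  # [num_pages * page_size, D]

page_ids = torch.tensor([2, 0])
tok_idx = (page_ids[:, None] * page_size + torch.arange(page_size)).reshape(-1)
gathered = tokens[tok_idx]  # [len(page_ids) * page_size, D]

assert gathered.shape == (page_ids.numel() * page_size, D)
assert torch.equal(gathered[:page_size], cache[2])
```

With page_size == 1 this degenerates to the current squeeze(1) behavior, which is why the problem only surfaces once tests use larger pages.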
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/trace/templates/attention.py` around lines 359 - 383, The
reference implementations _mla_paged_decode_reference and
_mla_paged_prefill_reference assume page_size==1 by calling ckv_cache.squeeze(1)
and kpe_cache.squeeze(1); instead update these functions to flatten the page and
token dimensions so arbitrary page_size works (e.g., replace squeeze(1) with a
reshape/flatten to (-1, head_dim_ckv) for Kc_all and (-1, head_dim_kpe) for
Kp_all or use flatten(0,1)), ensuring subsequent indexing via kv_indices still
selects the correct token rows; alternatively, if you prefer to keep the current
code, enforce page_size==1 in the TraceTemplate schema, but do not leave
squeeze(1) as-is.
"output": Tensor(
    ["batch_size", "seq_len", "num_v_heads", "head_size"],
    dtype_from="q",
    description="Attention output. Shape follows num_v_heads in GVA mode.",
),
The templates report output as dtype_from="q", but the references always emit bfloat16.
Lines 91, 208-210, and 377-379 cast the output tensors to torch.bfloat16, so these schemas become wrong as soon as q is not already bfloat16. The trace metadata should either fix the dtype to bfloat16 or model the real output-dtype control explicitly.
Suggested fix
- dtype_from="q",
+ dtype="bfloat16",
@@
- dtype_from="q",
+ dtype="bfloat16",
@@
- dtype_from="q",
+ dtype="bfloat16",

Also applies to: 321-325, 486-490
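A tiny repro of why the schema drifts (illustrative shapes, not the template code): the references always cast their result to bfloat16, so deriving the output dtype from q is wrong whenever q is fp16/fp32:

```python
import torch

# q arrives as float16, but the reference output is unconditionally cast to
# bfloat16 -- so a schema claiming "output dtype follows q" records the wrong dtype.
q = torch.randn(2, 4, dtype=torch.float16)
out = torch.softmax(q.to(torch.float32), dim=-1).to(torch.bfloat16)

assert out.dtype == torch.bfloat16
assert out.dtype != q.dtype  # dtype_from="q" would claim float16 here
```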
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/trace/templates/gdn.py` around lines 153 - 157, The Tensor schema
for the "output" entries in flashinfer/trace/templates/gdn.py currently uses
dtype_from="q" but the implementation always casts outputs to torch.bfloat16;
update the schema to reflect the real emitted dtype by replacing dtype_from="q"
with dtype="bfloat16" for the "output" Tensor declarations (the entries named
"output" in the template), or alternatively model the runtime control explicitly
if outputs can vary; make the same change for the other "output" Tensor
occurrences mentioned so the trace metadata matches the torch.bfloat16 casts in
the code.
def _mm_reference(A, B):
    return torch.matmul(A, B.T)


def _mm_fp8_reference(A, B):
    """Dequantize FP8 block-scale inputs and compute C = A @ B.T.

    B is in TRT-LLM block layout [K//block_size, N, block_size] and is
    reshaped to [K, N] before the matmul.
    """
    K_div_bs, N, block_size = B.shape
    B_fp32 = B.reshape(K_div_bs * block_size, N).to(torch.float32)
    A_fp32 = A.to(torch.float32)
    return torch.matmul(A_fp32, B_fp32.T).to(torch.bfloat16)
The GEMM references are transposing B after already modeling it as physical [K, N].
For mm_bf16 this is shape-invalid as soon as N != K, and the quantized helpers have the same problem after dequantization. Given these templates describe b as the physical [K, N] tensor, the reference path should multiply by B directly.
Suggested fix
def _mm_reference(A, B):
- return torch.matmul(A, B.T)
+ return torch.matmul(A, B)
@@
- return torch.matmul(A_fp32, B_fp32.T).to(torch.bfloat16)
+ return torch.matmul(A_fp32, B_fp32).to(torch.bfloat16)
@@
- return torch.matmul(A_scaled, B_scaled.T).to(torch.bfloat16)
+ return torch.matmul(A_scaled, B_scaled).to(torch.bfloat16)
@@
- return torch.matmul(A_scaled, B_scaled.T).to(torch.bfloat16)
+ return torch.matmul(A_scaled, B_scaled).to(torch.bfloat16)

Also applies to: 38-55, 57-85
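A quick standalone repro of the shape problem (not the template code): with B stored as the physical [K, N] matrix, `A @ B.T` only type-checks when N == K:

```python
import torch

# B is the physical [K, N] weight matrix, so the reference must multiply by B
# directly; transposing it breaks as soon as N != K.
M, K, N = 3, 8, 5
A = torch.randn(M, K)
B = torch.randn(K, N)

C = torch.matmul(A, B)  # valid: [M, K] @ [K, N] -> [M, N]
assert C.shape == (M, N)

try:
    torch.matmul(A, B.T)  # [M, K] @ [N, K] -> inner-dimension mismatch
    raise AssertionError("expected a shape error")
except RuntimeError:
    pass
```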
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/trace/templates/gemm.py` around lines 22 - 35, The reference GEMM
helpers currently transpose B (using .T) even though B is modeled as the
physical [K, N] tensor; update _mm_reference and _mm_fp8_reference (and the
other similar reference helpers in the file) to multiply A by B directly (remove
the .T and B_fp32.T), keeping the same dtype conversions and return types (e.g.,
_mm_fp8_reference should still dequantize to float32, matmul, then cast to
bfloat16), and update any docstrings/comments that incorrectly describe B as
needing transpose.
def _mm_fp4_reference(A, B, a_descale, b_descale, block_size=16):
    """Dequantize FP4 inputs and compute C = A @ B.T.

    A and B are fp4 e2m1fn values packed two-per-byte as uint8.
    a_descale: [M, K//block_size], b_descale: [K, N//block_size].
    The reference unpacks the nibbles and applies the block scales.
    """

    def _unpack_fp4(packed, rows, cols):
        # Each byte holds two fp4 nibbles (low nibble = first element).
        lo = (packed & 0x0F).to(torch.float32)
        hi = ((packed >> 4) & 0x0F).to(torch.float32)
        # Interleave low/high nibbles along the last dimension.
        out = torch.stack([lo, hi], dim=-1).reshape(rows, cols)
        return out

    M, K_packed = A.shape
    K = K_packed * 2
    _, N_packed = B.shape
    N = N_packed * 2

    A_fp32 = _unpack_fp4(A, M, K)
    B_fp32 = _unpack_fp4(B, K, N)
mm_fp4_trace cannot infer the right logical axes from packed inputs.
Lines 72-78 make it clear the runtime tensors are packed uint8 shapes, but the template still declares a and b as [M, K] and [K, N]. fi_trace will therefore report halved or conflicting K/N values for real FP4 calls. This needs packed-dimension axes or a custom extractor before the public API can emit correct runtime traces.
Also applies to: 191-200
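A minimal sketch of the packed-vs-logical mismatch (nibbles are kept as raw 0-15 codes for brevity, not decoded e2m1 values):

```python
import torch

# Each byte stores two fp4 nibbles, so a logical [M, K] operand arrives as a
# [M, K // 2] uint8 tensor -- a schema reading K from the runtime shape halves it.
M, K = 2, 8
packed = torch.randint(0, 256, (M, K // 2), dtype=torch.uint8)

lo = (packed & 0x0F).to(torch.float32)
hi = ((packed >> 4) & 0x0F).to(torch.float32)
unpacked = torch.stack([lo, hi], dim=-1).reshape(M, K)  # restores logical K

assert packed.shape[-1] * 2 == unpacked.shape[-1]
```

An extractor that doubles the packed last axis (as sketched here) would let the template report the logical K and N.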
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/trace/templates/gemm.py` around lines 57 - 78, The template
misdeclares packed uint8 inputs as logical FP4 shapes causing fi_trace to infer
wrong K/N; update the public trace signatures (or add a pre-trace extractor) so
the runtime sees packed dimensions: treat A and B as [M, K_packed] and
[K_packed, N_packed] (or expose an extractor that maps packed -> logical by
doubling the last axis) and propagate corrected logical axes into mm_fp4_trace
before calling _mm_fp4_reference/_unpack_fp4; apply the same change to the other
occurrence around the second block (the 191-200 region) and ensure
a_descale/b_descale shape metadata matches the packed-block layout.
def trtllm_fp8_block_scale_moe_trace_dispatch(**kwargs):
    """Return the appropriate TraceTemplate for the given ``routing_method_type``.

    Pass this as ``trace=trtllm_fp8_block_scale_moe_trace_dispatch`` to
    ``@flashinfer_api`` so the correct template is selected at call time::

        @flashinfer_api(trace=trtllm_fp8_block_scale_moe_trace_dispatch)
        def trtllm_fp8_block_scale_moe(..., routing_method_type: int = 0, ...):
            ...

    Returns ``None`` for ``RoutingMethodType.Unspecified`` (6), which
    suppresses trace generation.
    """
    routing_method_type = int(kwargs.get("routing_method_type", 0))
    return _MOE_TRACE_BY_ROUTING_TYPE.get(routing_method_type)


# Expose all possible templates so _attach_fi_trace can auto-register them
# in _TRACE_REGISTRY for consistency testing.
trtllm_fp8_block_scale_moe_trace_dispatch.templates = list(
    _MOE_TRACE_BY_ROUTING_TYPE.values()
)
🧩 Analysis chain
🏁 Script executed:
# First, let's examine the file to see the current state
cat -n flashinfer/trace/templates/moe.py | sed -n '570,605p'

🏁 Script executed:
# Check if there's a mypy configuration or pre-commit hooks
fd -type f -name ".pre-commit*" -o -name "mypy.ini" -o -name "pyproject.toml" | head -20

🏁 Script executed:
# Look for mypy configuration in pyproject.toml or setup.cfg
fd -type f \( -name "pyproject.toml" -o -name "setup.cfg" \) -exec head -50 {} +

🏁 Script executed:
find . -maxdepth 2 -type f \( -name "pyproject.toml" -o -name ".pre-commit-config.yaml" -o -name "setup.cfg" \) 2>/dev/null

🏁 Script executed:
# Search for mypy configuration
rg -l "mypy|attr-defined" --type toml --type yaml --type cfg 2>/dev/null | head -10

🏁 Script executed:
# Check if there are any comments or errors mentioning attr-defined in the moe.py file
rg "attr-defined|setattr" flashinfer/trace/templates/moe.py

🏁 Script executed:
# Check if setattr is used elsewhere in the codebase for similar purposes
rg "setattr\(" flashinfer/ -A 2 | head -40

🏁 Script executed:
cat .pre-commit-config.yaml

🏁 Script executed:
cat pyproject.toml

🏁 Script executed:
# Check if there are any .pyi stub files for this module
find flashinfer/trace -name "*.pyi" 2>/dev/null

🏁 Script executed:
# Look for Protocol or TypedDict definitions related to function templates
rg "Protocol|TypedDict" flashinfer/trace/ --type python | head -20

🏁 Script executed:
# Check if _attach_fi_trace function expects templates attribute
rg "_attach_fi_trace" flashinfer/ -B 5 -A 10 | head -50

🏁 Script executed:
# Check for other similar patterns where attributes are assigned to functions
rg "^\w+\.\w+\s*=\s*" flashinfer/trace/templates/ -A 0 | head -20

🏁 Script executed:
# Verify the exact mypy behavior - look at the api_logging.py more carefully
cat -n flashinfer/api_logging.py | sed -n '1,100p'

🏁 Script executed:
# Look for any existing patterns where setattr might be used or recommended
rg "templates\s*=" flashinfer/ --type py -B 2 -A 2

🏁 Script executed:
# Check if there are any mypy: ignore comments or type: ignore in moe.py
rg "mypy|type.*ignore" flashinfer/trace/templates/moe.py

🏁 Script executed:
# Look for the function signature and what it returns
cat -n flashinfer/trace/templates/moe.py | sed -n '577,592p'

🏁 Script executed:
# Check what TraceTemplate is to understand the type signature better
rg "class TraceTemplate|def TraceTemplate" flashinfer/ --type py -A 5 | head -30

🏁 Script executed:
# Verify the fix would work by checking how setattr is commonly used in the codebase
rg "setattr" flashinfer/ --type py -B 1 -A 1 | head -20

🏁 Script executed:
# Find where the .templates attribute is accessed
rg "\.templates" flashinfer/ --type py -B 2 -A 2

🏁 Script executed:
# Check if there's any hasattr check
rg "hasattr.*templates" flashinfer/ --type py

🏁 Script executed:
# Look at the complete _attach_fi_trace function to understand how templates is used
rg "_attach_fi_trace" flashinfer/api_logging.py -A 50 | head -80
Use setattr() to avoid mypy's attr-defined error.
The direct assignment at line 596 triggers mypy's strict mode because the function object has no declared templates attribute. Using setattr() preserves identical runtime behavior while satisfying type checking requirements, which is necessary for pre-commit validation.
Suggested fix
-trtllm_fp8_block_scale_moe_trace_dispatch.templates = list(
- _MOE_TRACE_BY_ROUTING_TYPE.values()
-)
+setattr(
+ trtllm_fp8_block_scale_moe_trace_dispatch,
+ "templates",
+ list(_MOE_TRACE_BY_ROUTING_TYPE.values()),
+)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def trtllm_fp8_block_scale_moe_trace_dispatch(**kwargs):
    """Return the appropriate TraceTemplate for the given ``routing_method_type``.

    Pass this as ``trace=trtllm_fp8_block_scale_moe_trace_dispatch`` to
    ``@flashinfer_api`` so the correct template is selected at call time::

        @flashinfer_api(trace=trtllm_fp8_block_scale_moe_trace_dispatch)
        def trtllm_fp8_block_scale_moe(..., routing_method_type: int = 0, ...):
            ...

    Returns ``None`` for ``RoutingMethodType.Unspecified`` (6), which
    suppresses trace generation.
    """
    routing_method_type = int(kwargs.get("routing_method_type", 0))
    return _MOE_TRACE_BY_ROUTING_TYPE.get(routing_method_type)


# Expose all possible templates so _attach_fi_trace can auto-register them
# in _TRACE_REGISTRY for consistency testing.
setattr(
    trtllm_fp8_block_scale_moe_trace_dispatch,
    "templates",
    list(_MOE_TRACE_BY_ROUTING_TYPE.values()),
)
🧰 Tools
🪛 GitHub Actions: pre-commit
[error] 596-596: mypy error: "Callable[[KwArg(Any)], Any]" has no attribute "templates" [attr-defined]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/trace/templates/moe.py` around lines 577 - 598, The direct
attribute assignment trtllm_fp8_block_scale_moe_trace_dispatch.templates causes
mypy attr-defined errors; replace that assignment with a setattr call to attach
the templates list at runtime (e.g., use
setattr(trtllm_fp8_block_scale_moe_trace_dispatch, "templates",
list(_MOE_TRACE_BY_ROUTING_TYPE.values()))), keeping the same value (list of
_MOE_TRACE_BY_ROUTING_TYPE.values()) and preserving behavior for
_attach_fi_trace registration.
for json_key, descriptor in template.inputs.items():
    if isinstance(descriptor, Scalar):
        if descriptor.optional:
            continue
        p = _resolved_param(json_key, descriptor)
        kwargs[p] = 0 if descriptor.dtype == "int32" else 1.0
The E2E generator should not synthesize int32 inputs as zero.
This currently creates impossible definitions such as block_size=0, so assert_fi_trace_complete() can pass on traces that would never be valid at runtime. Use a positive default, ideally with per-parameter overrides for constrained scalars.
Suggested fix
+_INT_SAMPLE_DEFAULTS = {
+ "block_size": 16,
+ "top_k": 1,
+ "n_group": 1,
+ "topk_group": 1,
+}
+
@@
if isinstance(descriptor, Scalar):
if descriptor.optional:
continue
p = _resolved_param(json_key, descriptor)
- kwargs[p] = 0 if descriptor.dtype == "int32" else 1.0
+ if descriptor.dtype == "int32":
+ kwargs[p] = _INT_SAMPLE_DEFAULTS.get(p, axis_size)
+ else:
+ kwargs[p] = 1.0🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/trace/test_fi_trace_template_consistency.py` around lines 193 - 199,
The test E2E generator currently assigns 0 for int32 scalars in the loop over
template.inputs which can create impossible values (e.g., block_size=0); update
the assignment in the loop that inspects isinstance(descriptor, Scalar) and uses
_resolved_param(json_key, descriptor) so that int32 defaults are positive (e.g.,
1 or another small positive) and preferably support per-parameter overrides for
constrained scalars before populating kwargs; ensure any change keeps optional
descriptors skipped and preserves the dtype branch for non-int32 floats, so
assert_fi_trace_complete() validates realistic traces.
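The override-table approach from the suggested fix can be sketched as follows (the helper name and fallback are assumptions for illustration, not existing test utilities):

```python
# Hypothetical helper mirroring the suggested _INT_SAMPLE_DEFAULTS table:
# int32 scalars get a known-good positive value instead of 0.
_INT_SAMPLE_DEFAULTS = {"block_size": 16, "top_k": 1, "n_group": 1, "topk_group": 1}

def sample_int_scalar(name, fallback=8):
    # Positive default so synthesized inputs like block_size are never zero.
    return _INT_SAMPLE_DEFAULTS.get(name, fallback)

assert sample_int_scalar("block_size") == 16
assert sample_int_scalar("some_other_dim") == 8
```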
Actionable comments posted: 8
♻️ Duplicate comments (12)
flashinfer/trace/templates/gdn.py (3)
165-169: ⚠️ Potential issue | 🟠 Major

Report output as bfloat16 in the schema.

All three references cast their outputs to torch.bfloat16, so inheriting dtype from q makes the trace metadata wrong whenever callers use another input dtype.

Also applies to: 351-355, 537-541

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/gdn.py` around lines 165 - 169, The schema currently sets the attention "output" Tensor using dtype_from="q", which misreports dtype because outputs are cast to torch.bfloat16; update the Tensor definition for the "output" field in the GDN templates to use an explicit dtype of "bfloat16" (replace dtype_from="q" with dtype="bfloat16") for the occurrences around the shown block and the other two occurrences (near lines 351-355 and 537-541) so the trace metadata correctly reflects torch.bfloat16 outputs.
421-458: ⚠️ Potential issue | 🟠 Major

Persist the updated pooled state before returning.

state_HVK is updated for every token, but final_state is cloned from initial_state after the loop and never receives those updates. The returned final_state is therefore stale, and the generated JSON fixture will be stale too.

Suggested fix

-    for b_idx in range(B):
+    final_state = initial_state.clone()
+    for b_idx in range(B):
         state_idx = int(initial_state_indices[b_idx].item())
         state_HVK = (
             initial_state[state_idx].clone().float().transpose(-1, -2)
         )  # [H,V,K] -> [H,K,V]
@@
         if cache_intermediate:
             intermediate_states_buffer[state_idx, t] = state_HVK.transpose(
                 -1, -2
             )  # [H,K,V] -> [H,V,K]
-
-    final_state = initial_state.clone()
+        final_state[state_idx] = state_HVK.transpose(-1, -2)
     return output, final_state

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/gdn.py` around lines 421 - 458, The loop updates state_HVK per batch/state but final_state is created from initial_state and never updated, so return value is stale; update final_state with the pooled (transposed) state_HVK for each corresponding state index (initial_state_indices) after finishing updates for that state (or after the outer loops) so final_state[state_idx] = state_HVK.transpose(-1, -2) (match the same [H,V,K] ↔ [H,K,V] orientation used for initial_state/state_HVK) before returning output and final_state.
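The stale-clone pattern described here reproduces in a few lines (illustrative tensors, not the GDN reference itself):

```python
import torch

# Updates made to a per-batch working copy are lost if the returned state is
# cloned from the original input after the loop instead of receiving them.
initial_state = torch.zeros(2, 3)
work = initial_state[0].clone()
work += 1.0  # per-token updates accumulate here

stale = initial_state.clone()  # buggy: never sees `work`
assert torch.equal(stale[0], torch.zeros(3))

fixed = initial_state.clone()
fixed[0] = work  # persist the updated state before returning
assert torch.equal(fixed[0], torch.ones(3))
```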
362-365: ⚠️ Potential issue | 🟠 Major

gdn_prefill_trace needs the same head-ratio constraints as the other GDN templates.

The reference divides by num_v_heads // num_q_heads and num_v_heads // num_k_heads, but this template currently accepts layouts that make those expansions invalid.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/gdn.py` around lines 362 - 365, The gdn_prefill_trace template is missing head-ratio validity checks: add constraints ensuring num_v_heads is divisible by num_q_heads and by num_k_heads (e.g., num_v_heads % num_q_heads == 0 and num_v_heads % num_k_heads == 0) so the downstream divisions (num_v_heads // num_q_heads and num_v_heads // num_k_heads) used elsewhere are valid; update the constraints list in gdn_prefill_trace to include these checks referencing the variables num_v_heads, num_q_heads, and num_k_heads.

flashinfer/trace/templates/gemm.py (2)
180-217: ⚠️ Potential issue | 🟠 Major

The FP4 trace schema still advertises unpacked shapes.

A and B are packed uint8 tensors at runtime, so exposing them as [M, K] and [K, N] makes fi_trace infer the wrong dimensions for real FP4 calls. Model the packed axes explicitly or add an extractor that maps packed sizes back to logical K/N.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/gemm.py` around lines 180 - 217, The mm_fp4_trace TraceTemplate currently lists inputs "A" and "B" with unpacked shapes ["M","K"] and ["K","N"], but at runtime these are packed uint8 FP4 buffers; update mm_fp4_trace so the Tensor entries for "A" and "B" describe the packed axes (e.g., K_packed/K_block or bytes per packed row) or add an extractor that converts the packed dimensions back to logical K and N (use the existing "block_size" Var/Scalar to compute K//block_size and N//block_size); specifically modify the Tensor definitions for "A" and "B" in mm_fp4_trace (and any related axis defs such as "K" or "N") so fi_trace will infer correct runtime shapes for FP4-packed inputs.
22-35: ⚠️ Potential issue | 🟠 Major

Multiply by the physical [K, N] weight matrix in these references.

Each template models B as a physical [K, N] tensor, but the references all call ... @ B.T. That breaks mm_bf16 as soon as N != K and skews the quantized references the same way.

Suggested fix

 def _mm_reference(A, B):
-    return torch.matmul(A, B.T)
+    return torch.matmul(A, B)
@@
-    return torch.matmul(A_fp32, B_fp32.T).to(torch.bfloat16)
+    return torch.matmul(A_fp32, B_fp32).to(torch.bfloat16)
@@
-    return torch.matmul(A_scaled, B_scaled.T).to(torch.bfloat16)
+    return torch.matmul(A_scaled, B_scaled).to(torch.bfloat16)
@@
-    return torch.matmul(A_scaled, B_scaled.T).to(torch.bfloat16)
+    return torch.matmul(A_scaled, B_scaled).to(torch.bfloat16)

Also applies to: 38-55, 57-86

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/gemm.py` around lines 22 - 35, The reference implementations currently multiply by B.T, but B represents the physical [K, N] weight matrix so using B.T swaps dims and breaks cases where N != K; update _mm_reference to compute torch.matmul(A, B) (not A @ B.T), and in _mm_fp8_reference reshape B into [K, N] (B_fp32 = B.reshape(K_div_bs * block_size, N)) and use torch.matmul(A_fp32, B_fp32) (remove the trailing .T), applying the same fix to the other reference helpers mentioned (the FP8 and bf16 variants in the file).

flashinfer/trace/templates/attention.py (3)
140-169: ⚠️ Potential issue | 🟠 Major

Expand selected pages before applying the prefill causal window.

The reference currently indexes k_flat/v_flat with page ids and sets num_kv_tokens = page_ids.shape[0], so both the causal window and the gathered KV tensors are off by page_size for real paged caches.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/attention.py` around lines 140 - 169, The code is indexing k_flat/v_flat by page_ids and treating num_kv_tokens as page_ids.shape[0], which is incorrect for paged caches; you must expand page indices to per-token indices before building k_b/v_b and computing num_kv_tokens so the causal window and gathers operate at token granularity. Change the gather so that page_ids are multiplied/expanded by page_size into token_indices (e.g. token_indices = page_ids.unsqueeze(1)*page_size + torch.arange(page_size, device=...)) and then use those token_indices to index the original k_flat/v_flat (or reshape k_cache/v_cache into per-token and gather by token_indices) so k_b/v_b contain all tokens from the selected pages, set num_kv_tokens = token_indices.numel() (or actual token count if last page partial), and adjust uses of max_kv, delta, and slicing (k_b[:max_kv], v_b[:max_kv]) accordingly.
357-385: ⚠️ Potential issue | 🟠 Major

These MLA references still assume page_size == 1.

Both paths call squeeze(1) on paged caches, but the schema accepts arbitrary page_size and the tests already use larger values like 64. Flatten the page/token dimensions or constrain the template to single-token pages.

Also applies to: 476-518

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/attention.py` around lines 357 - 385, The code currently assumes page_size == 1 by calling ckv_cache.squeeze(1) and kpe_cache.squeeze(1) (variables Kc_all/Kp_all) which collapses the page dimension; instead merge the page and token dimensions so arbitrary page_size works: replace the squeeze(1) usage with a reshape/view that flattens the first two dims (e.g. ckv_cache.reshape(-1, head_dim_ckv) and kpe_cache.reshape(-1, head_dim_kpe)) and keep using kv_indices[kv_indptr[b]:kv_indptr[b+1]] (tok_idx) to index into the flattened Kc_all/Kp_all so Kc = Kc_all[tok_idx] and Kp = Kp_all[tok_idx] work for multi-token pages; apply the same change to the other block (lines ~476-518) where ckv_cache/kpe_cache are squeezed.
39-58: ⚠️ Potential issue | 🟠 Major
Treat `kv_indices` as page ids in the decode reference.
`kv_indices` are documented as page ids, but this code indexes the flattened token buffer with them. That only stays correct when `page_size == 1`; otherwise the reference gathers the wrong KV rows.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/attention.py` around lines 39 - 58, kv_indices are page IDs but the code treats them as token indices when building token_ids; fix by expanding each page id into its page_size token-row indices before indexing k_flat/v_flat. In the loop over b replace the current token_ids = kv_indices[page_start:page_end].to(torch.long) with logic that maps each page id p to the contiguous token index range p*page_size .. (p+1)*page_size-1 (preserving dtype/device), then flatten that to a 1D tensor and use it to build k_b and v_b so k_b/v_b remain shaped [T, num_kv_heads, head_dim]; keep using k_flat/v_flat, kv_indptr, kv_indices, page_size, k_b, v_b, token_ids, and ensure device/torch.long handling remains correct.
tests/trace/example.py (1)
54-373: ⚠️ Potential issue | 🟠 Major
Pytest won't collect this trace example.
Everything after the env setup runs as import-time side effects, but the file defines no `test_...` entrypoint. CI will never exercise trace generation unless this body is moved into a real pytest test and the script entrypoint is kept separate.
As per coding guidelines, `tests/**/*.py`: prefix test functions with `test_` and structure tests by feature in `tests/` subdirectories matching kernel categories.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/trace/example.py` around lines 54 - 373, The file runs the entire trace-generation body at import time so pytest won't collect it; keep the environment setup (the FLASHINFER_TRACE_* os.environ lines and SAVE_DIR) and the imports as-is but move everything that performs work (starting from device/WORKSPACE and all calls that exercise flashinfer APIs, e.g., the loops that call flashinfer.rmsnorm, flashinfer.mm_bf16, BatchDecodeWithPagedKVCacheWrapper.plan/run, BatchPrefillWithPagedKVCacheWrapper.plan/run, BatchPrefillWithRaggedKVCacheWrapper.plan/run, BatchMLAPagedAttentionWrapper.plan/run, flashinfer.gdn_decode.*, flashinfer.fused_moe.* and the final JSON summary) into a pytest test function named with a test_ prefix (e.g., test_generate_fi_traces) so CI will execute it; also add a minimal if __name__ == "__main__" guard to call that function when run as a script so the example remains runnable standalone.
flashinfer/trace/templates/moe.py (3)
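The restructuring suggested for `tests/trace/example.py` can be sketched as follows; `run_all_examples` is a hypothetical wrapper standing in for the original import-time body, not an existing helper:

```python
def run_all_examples():
    # Hypothetical stand-in for the original script body (rmsnorm,
    # GEMM, decode/prefill wrappers, MoE, ...); returns collected traces.
    return [{"name": "rmsnorm_h4096"}]


def test_generate_fi_traces():
    # Collected by pytest because of the test_ prefix.
    traces = run_all_examples()
    assert traces, "expected at least one fi_trace to be generated"


if __name__ == "__main__":
    # Keeps the example runnable as a standalone script.
    test_generate_fi_traces()
```

CI then exercises trace generation through normal pytest collection, while `python tests/trace/example.py` still works for manual runs.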
648-652: ⚠️ Potential issue | 🟡 Minor
Use `setattr()` to avoid mypy's `attr-defined` error.
The direct attribute assignment triggers mypy's `attr-defined` error because function objects don't have a declared `templates` attribute. Use `setattr()` to preserve runtime behavior while satisfying type checking.
Suggested fix
```diff
 # Expose all possible templates so _attach_fi_trace can auto-register them
 # in _TRACE_REGISTRY for consistency testing.
-trtllm_fp8_block_scale_moe_trace_dispatch.templates = list(
-    _MOE_TRACE_BY_ROUTING_TYPE.values()
-)
+setattr(
+    trtllm_fp8_block_scale_moe_trace_dispatch,
+    "templates",
+    list(_MOE_TRACE_BY_ROUTING_TYPE.values()),
+)
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/moe.py` around lines 648 - 652, The assignment to add a dynamic attribute on the function trtllm_fp8_block_scale_moe_trace_dispatch causes mypy attr-defined errors; replace the direct assignment with a setattr call so the templates attribute is attached at runtime (e.g., setattr(trtllm_fp8_block_scale_moe_trace_dispatch, "templates", list(_MOE_TRACE_BY_ROUTING_TYPE.values()))) to preserve behavior while satisfying type checking.
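A minimal sketch of the `setattr` pattern; the `attach_templates` helper and `dispatch` function are illustrative, not part of the codebase:

```python
from typing import Any, Callable


def attach_templates(fn: Callable[..., Any], templates: list) -> Callable[..., Any]:
    # setattr sidesteps mypy's attr-defined error: Callable declares no
    # "templates" attribute, but function objects accept arbitrary
    # attributes at runtime.
    setattr(fn, "templates", list(templates))
    return fn


def dispatch():
    pass


attach_templates(dispatch, ["a", "b"])
```

Runtime behavior is identical to direct assignment; only the static type checker sees a difference.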
25-27: ⚠️ Potential issue | 🟠 Major
Hardcoded `H` and `I` constants make reference execution shape-fragile.
The module-level constants `H = 7168` and `I = 2048` are used in `_fp8_moe_run_experts`, but the actual hidden_size and intermediate_size can vary. This will produce incorrect results or errors for other valid MoE configurations.
Suggested fix — derive H and I from tensor shapes
```diff
-H = 7168
-I = 2048
 BLOCK = 128

 @torch.no_grad()
 def _fp8_moe_run_experts(
     hidden_states,
     hidden_states_scale,
     gemm1_weights,
     gemm1_weights_scale,
     gemm2_weights,
     gemm2_weights_scale,
     weights,
     topk_idx,
     local_expert_offset,
     E_global,
 ):
-    T = hidden_states.shape[0]
+    T, H = hidden_states.shape
+    I = gemm2_weights.shape[2]
     E_local = gemm1_weights.shape[0]
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/moe.py` around lines 25 - 27, The module currently hardcodes H=7168 and I=2048 which breaks _fp8_moe_run_experts for models with different hidden/intermediate sizes; change the code to derive hidden_size and intermediate_size at runtime from tensor shapes (e.g., infer hidden_size from the input/hidden tensor shape[-1] or the expert weight shapes, and infer intermediate_size from the feedforward weight/output shapes) and replace uses of H and I with those derived values (also ensure BLOCK is computed/validated against hidden_size if needed); update all references in _fp8_moe_run_experts to use the derived variables so the function works for arbitrary MoE shapes.
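The shape-derivation idea can be sketched on plain tuples; the helper name is hypothetical, and the shape conventions follow the suggested fix (`hidden_states` is `[T, H]`, `gemm2_weights` is `[E_local, H, I]`):

```python
def infer_moe_sizes(hidden_states_shape, gemm2_weights_shape):
    # hidden_states is [T, H]; gemm2_weights is [E_local, H, I],
    # so H and I fall out of the shapes with no module constants.
    T, H = hidden_states_shape
    intermediate = gemm2_weights_shape[2]
    return T, H, intermediate
```

Because every size comes from the runtime tensors, the reference stays correct for any valid MoE configuration, not just the 7168/2048 case.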
126-131: ⚠️ Potential issue | 🟠 Major
Reference implementations hardcode routing parameters that should be configurable.
`TOP_K = 8`, `N_GROUP = 8`, `TOPK_GROUP = 4` are hardcoded, but the public API accepts these as arguments. If these references are used for numerical validation, they will only be correct for one configuration. If they are only for schema validation (not numerical correctness), consider adding a comment to clarify their limited scope.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/moe.py` around lines 126 - 131, The template hardcodes routing parameters TOP_K, N_GROUP, TOPK_GROUP which conflict with the public API; update the code that defines TOP_K, N_GROUP, TOPK_GROUP to read the corresponding function arguments (e.g., top_k, n_group, topk_group) or the routing parameters object instead of fixed literals so the template matches whatever configuration is passed via routing_logits' caller (or if these values are truly only for shape/schema checks, replace the literals with a clarifying comment near TOP_K/N_GROUP/TOPK_GROUP stating they are placeholder defaults used only for schema validation and not numeric correctness). Ensure you change the occurrences of TOP_K, N_GROUP, TOPK_GROUP in this module (referenced with routing_logits, E_global, T) accordingly.
🧹 Nitpick comments (8)
flashinfer/trace/template.py (3)
473-473: Consider using list unpacking for slightly cleaner syntax.
Ruff suggests `[f"fi_api:{fi_api}", *template.tags]` instead of list concatenation.
Suggested fix
```diff
-    all_tags = [f"fi_api:{fi_api}"] + template.tags
+    all_tags = [f"fi_api:{fi_api}", *template.tags]
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/template.py` at line 473, Replace the list concatenation that builds all_tags with list unpacking for clearer syntax: in the function/block where variable all_tags is assigned (currently using all_tags = [f"fi_api:{fi_api}"] + template.tags), change it to construct the list using [f"fi_api:{fi_api}", *template.tags] so it directly prepends the formatted fi_api tag to template.tags; keep the same variable name and semantics.
426-443: Auto-infer dtype uses first matching input — document this behavior.
The auto-inference logic selects the dtype from the first input tensor with overlapping dimension names (`break` at line 443). This is a reasonable heuristic, but if multiple inputs have overlapping dims with different dtypes, the choice is arbitrary. Consider adding a brief inline comment noting this precedence.
Suggested documentation
```diff
     else:
-        # Auto-infer: find first input tensor with overlapping dims
+        # Auto-infer: use dtype from first input tensor with overlapping
+        # dims. If multiple inputs overlap, precedence follows dict order.
         dtype = "unknown"
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/template.py` around lines 426 - 443, The auto-infer branch in template.py sets dtype to the first matching input's type (looping over template.inputs, checking Tensor instances and overlapping descriptor.dim_names, using _get_tensor and _dtype_str, then break), which is arbitrary when multiple inputs overlap; add a concise inline comment near this logic (around the loop and the break) stating that this chooses the first matching input's dtype as the precedence rule and that other overlapping inputs may be ignored, so callers should avoid ambiguous multiple-dtype overlaps or explicitly provide dtype to override; keep the comment short and reference template.inputs, Tensor, descriptor, _get_tensor, and _dtype_str.
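The dict-order precedence can be demonstrated with a simplified stand-in for the auto-infer loop; the `infer_dtype` helper and its dict layout are illustrative, not the template's real data model:

```python
def infer_dtype(inputs, dim_names):
    # First input whose dims overlap dim_names wins; precedence is
    # plain dict insertion order, mirroring the auto-infer loop's break.
    for name, desc in inputs.items():
        if set(desc["dims"]) & set(dim_names):
            return desc["dtype"]
    return "unknown"
```

If two inputs both overlap the requested dims with different dtypes, the earlier-inserted one silently wins, which is exactly why the precedence is worth a comment (or an explicit dtype override) at the call site.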
370-378: Silent exception swallowing may hide bugs during axis extraction.
The bare `except Exception: pass` at lines 376-377 silently ignores all errors during axis value extraction. While this provides robustness, it can hide bugs in extractor logic or unexpected input types. Consider at minimum logging at debug level.
Suggested fix
```diff
+import logging
+
+_logger = logging.getLogger(__name__)
+
 # In fi_trace function:
 for axis_name, extractor in axis_extractors.items():
     try:
         val = extractor(kwargs)
         if val is not None:
             axis_values[axis_name] = val
-    except Exception:
-        pass
+    except Exception as exc:
+        _logger.debug(
+            "Axis extraction failed for %s: %s", axis_name, exc
+        )
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/template.py` around lines 370 - 378, The code silently swallows all exceptions when running axis extractor functions (axis_extractors -> extractor(kwargs) populating axis_values), which hides bugs; change the bare "except Exception: pass" to catch the exception as e and log it at debug level (e.g., logger.debug("axis extractor %s failed for kwargs=%s: %s", axis_name, kwargs, e, exc_info=True)) so failures are recorded but extraction remains robust, and if there is no existing logger in this module create one via logging.getLogger(__name__) and import logging.
tests/trace/test_fi_trace_template_consistency.py (4)
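The debug-logging variant suggested for the axis-extraction loop in `flashinfer/trace/template.py` can be sketched as a standalone function; the `extract_axes` wrapper is illustrative:

```python
import logging

logger = logging.getLogger(__name__)


def extract_axes(axis_extractors, kwargs):
    # Failures are logged at debug level instead of silently dropped,
    # so extractor bugs stay visible without breaking tracing.
    axis_values = {}
    for axis_name, extractor in axis_extractors.items():
        try:
            val = extractor(kwargs)
            if val is not None:
                axis_values[axis_name] = val
        except Exception as exc:
            logger.debug("Axis extraction failed for %s: %s", axis_name, exc)
    return axis_values
```

A raising extractor no longer vanishes without a trace: enabling `DEBUG` logging surfaces the failure while the remaining axes are still collected.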
399-408: Variable `k` shadows the tensor `k` defined earlier.
The loop variable `k` at line 400 shadows the tensor `k` defined at line 391. While this doesn't affect correctness (the tensor is no longer needed at this point), it reduces readability.
Suggested fix
```diff
 non_optional_unknown = [
-    k
-    for k, v in defn["inputs"].items()
-    if isinstance(v, dict)
-    and v.get("dtype") == "unknown"
-    and not v.get("optional", False)
+    key
+    for key, val in defn["inputs"].items()
+    if isinstance(val, dict)
+    and val.get("dtype") == "unknown"
+    and not val.get("optional", False)
 ]
```
Verify each finding against the current code and only fix it if needed. In `@tests/trace/test_fi_trace_template_consistency.py` around lines 399 - 408, The loop variable k in the comprehension that builds non_optional_unknown shadows the tensor named k earlier; rename the loop variable (e.g., to input_name or inp_key) used in the comprehension and in the f-string so it no longer collides with the tensor k, updating the comprehension over defn["inputs"].items() and the f"Non-optional inputs with unknown dtype: {...}" reference accordingly.
495-496: Use a raw string for the regex pattern.
Regex patterns should be raw strings so any backslashes are not interpreted as string escapes; this also satisfies Ruff RUF043.
Suggested fix
```diff
-    with pytest.raises(AssertionError, match="param=.*hidden_state.*not found"):
+    with pytest.raises(AssertionError, match=r"param=.*hidden_state.*not found"):
```
Verify each finding against the current code and only fix it if needed. In `@tests/trace/test_fi_trace_template_consistency.py` around lines 495 - 496, The regex passed to pytest.raises should be a raw string to avoid accidental escape sequences; update the call to pytest.raises(AssertionError, match=...) used around assert_template_signature_consistency(func, broken, label="meta-test") so the match argument is a raw string literal (prefix it with r, e.g. r"param=.*hidden_state.*not found") to satisfy Ruff RUF043 and ensure the pattern is interpreted correctly.
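A quick self-contained check of the raw-string recommendation (this particular pattern has no backslashes, so the `r` prefix changes nothing at runtime here; it only guards against future escapes):

```python
import re

# The r-prefix keeps sequences like "\d" intact if the pattern ever
# grows one; behavior for this escape-free pattern is unchanged.
pattern = r"param=.*hidden_state.*not found"
assert re.search(pattern, "param='hidden_state' not found in signature")
```

The same string works as the `match=` argument to `pytest.raises`, which treats it as a regex via `re.search`.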
430-460: Rename ambiguous variable `I` in the MoE routing test.
Ruff flags `I` as ambiguous (E741). Consider renaming to `intermediate` or `inter_size` for clarity.
Suggested fix
```diff
-    T, E, EL, H, I, BS = 4, 16, 2, 256, 64, 128
+    T, E, EL, H, INTER, BS = 4, 16, 2, 256, 64, 128
     defn = trtllm_fp8_block_scale_moe.fi_trace(
         routing_logits=torch.zeros(T, E, dtype=torch.float32),
         routing_bias=torch.zeros(E, dtype=torch.bfloat16),
         hidden_states=torch.zeros(T, H, dtype=torch.float8_e4m3fn),
         hidden_states_scale=torch.ones(H // BS, T, dtype=torch.float32),
-        gemm1_weights=torch.zeros(EL, 2 * I, H, dtype=torch.float8_e4m3fn),
-        gemm1_weights_scale=torch.ones(EL, (2 * I) // BS, H // BS, dtype=torch.float32),
-        gemm2_weights=torch.zeros(EL, H, I, dtype=torch.float8_e4m3fn),
-        gemm2_weights_scale=torch.ones(EL, H // BS, I // BS, dtype=torch.float32),
+        gemm1_weights=torch.zeros(EL, 2 * INTER, H, dtype=torch.float8_e4m3fn),
+        gemm1_weights_scale=torch.ones(EL, (2 * INTER) // BS, H // BS, dtype=torch.float32),
+        gemm2_weights=torch.zeros(EL, H, INTER, dtype=torch.float8_e4m3fn),
+        gemm2_weights_scale=torch.ones(EL, H // BS, INTER // BS, dtype=torch.float32),
         num_experts=E,
         top_k=top_k,
-        intermediate_size=I,
+        intermediate_size=INTER,
```
Verify each finding against the current code and only fix it if needed. In `@tests/trace/test_fi_trace_template_consistency.py` around lines 430 - 460, The test function test_fi_trace_complete_moe_routing uses a single-letter variable I (intermediate size) which triggers an ambiguity lint (E741); rename I to a descriptive identifier (e.g., inter_size or intermediate) and update all references inside the function and the fi_trace(...) call (intermediate_size=I, shapes using I, 2 * I etc.) so the values and assertions remain identical but the variable name is clear and matches usage in trtllm_fp8_block_scale_moe.fi_trace.
369-370: Rename ambiguous loop variable `l` to improve readability.
Ruff flags `l` as ambiguous (E741) because it can be confused with `1`. Consider renaming to `lbl` or `label`.
Suggested fix
```diff
-_E2E_PAIRS = [(f, t, l) for f, t, l in _ALL_PAIRS if l not in _E2E_SKIP]
+_E2E_PAIRS = [(f, t, lbl) for f, t, lbl in _ALL_PAIRS if lbl not in _E2E_SKIP]
```
Verify each finding against the current code and only fix it if needed. In `@tests/trace/test_fi_trace_template_consistency.py` around lines 369 - 370, The list comprehension _E2E_PAIRS uses an ambiguous loop variable named "l"; rename it to a clearer identifier (e.g., "label" or "lbl") in the comprehension and update the subsequent _E2E_IDS comprehension to unpack/use that new name so both _E2E_PAIRS = [(f, t, label) for f, t, label in _ALL_PAIRS if label not in _E2E_SKIP] and _E2E_IDS = [label for _, _, label in _E2E_PAIRS] remain consistent.
flashinfer/trace/templates/moe.py (1)
85-88: Rename ambiguous variable `O` to improve readability.
Ruff flags `O` as ambiguous (E741) because it can be confused with `0`. Consider renaming to `out` or `output_e`.
Suggested fix
```diff
-    O = (silu_X2 * X1).matmul(W2[le].t())
+    expert_out = (silu_X2 * X1).matmul(W2[le].t())
     # per-expert contribution weight for each token
     w_tok = weights.index_select(0, token_idx)
     # find which slot in topk_idx[token_idx] corresponds to ge
     match = (topk_idx.index_select(0, token_idx) == ge).float()
     w_e = (w_tok * match).sum(dim=1)
-    output.index_add_(0, token_idx, O * w_e.unsqueeze(1))
+    output.index_add_(0, token_idx, expert_out * w_e.unsqueeze(1))
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/moe.py` around lines 85 - 88, Rename the ambiguous variable O used in the moe attention feedforward block to a clearer name (e.g., out or output_e) to avoid confusion with the digit zero; update the assignment and any subsequent uses where O appears (the expression "(silu_X2 * X1).matmul(W2[le].t())") and ensure references to G1, X1, X2, silu_X2, W13, W2, A_e, and le remain correct with the new variable name.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/api_logging.py`:
- Around line 1510-1531: The wrapper created by _attach_fi_trace (and returned
by flashinfer_api) adds runtime cost even when tracing/LOGLEVEL=0; change
_attach_fi_trace so that if tracing is disabled (i.e., _is_trace_dump_enabled()
is False and caller requested zero-overhead) it does not create
_auto_dump_wrapper but instead attaches fi_trace to the original callable via
setattr(original, "fi_trace", fi_trace_fn) and returns original; otherwise keep
the current wrapper behavior. Also avoid direct attribute assignment on
Callable-typed objects that triggers mypy attr-defined errors by using
setattr(original, "fi_trace", fi_trace_fn) or by casting to Any/creating a small
Protocol for fi_trace to satisfy type-checkers (e.g., cast(Any, original) or
define Protocol with fi_trace) so the pipeline no longer errors.
- Around line 1508-1531: Replace direct attribute assignments to .fi_trace with
setattr to avoid mypy attr-defined errors: where the diff sets wrapped.fi_trace
= fi_trace_fn and _auto_dump_wrapper.fi_trace = fi_trace_fn, change those direct
assignments to use setattr(wrapped, "fi_trace", fi_trace_fn) and
setattr(_auto_dump_wrapper, "fi_trace", fi_trace_fn). Keep the same semantics
(assign the fi_trace_fn callable) and leave other code in _auto_dump_wrapper,
_sig, and fi_trace_fn unchanged.
In `@flashinfer/fi_trace.py`:
- Around line 238-285: The function fi_trace currently types func_or_method as
Callable but relies on the object (actual_func) exposing a .fi_trace attribute;
update the typing to make that contract explicit by introducing a Protocol
(e.g., TracedCallable with a fi_trace(self, save_dir: Optional[Union[str, Path]]
= None, **kwargs) -> Dict[str, Any]) and use that Protocol as the type for
func_or_method (or cast actual_func to TracedCallable before accessing
.fi_trace); ensure the Protocol signature matches how trace_fn is called in
fi_trace and import typing.Protocol and any necessary types so mypy recognizes
the requirement.
- Around line 103-110: The import line bringing in Const, Scalar, Tensor,
TraceTemplate, and Var from .trace.template is unused in build_fi_trace_fn and
causing Ruff F401 warnings; remove those five names (or the whole legacy import
if nothing else from that module is used) so only needed symbols remain imported
in flashinfer/fi_trace.py and eliminate the unused imports Const, Scalar,
Tensor, TraceTemplate, Var from the import statement that currently appears
alongside build_fi_trace_fn.
In `@flashinfer/trace/templates/gdn.py`:
- Around line 502-546: The template schema is missing the disable_state_update
input required by gated_delta_rule_mtp, so add a boolean Tensor/Scalar entry
named "disable_state_update" to the inputs dict (matching how other flags are
represented) and mark it optional or required consistent with
gated_delta_rule_mtp's signature; ensure you reference the symbol name
"disable_state_update" and update the inputs block near the existing
"initial_state"/"final_state" entries so the trace can distinguish
state-updating vs non-updating behavior described in "final_state".
In `@tests/trace/test_fi_trace_template_consistency.py`:
- Around line 384-394: Remove the unused import gqa_paged_decode_trace from the
test; locate the import statement that reads "from
flashinfer.trace.templates.attention import gqa_paged_decode_trace" and delete
it so the test only imports and uses
BatchDecodeWithPagedKVCacheWrapper.run.fi_trace (ensure no other references to
gqa_paged_decode_trace remain in the file).
- Around line 309-321: The import of flashinfer.sampling is unused and flagged
by pre-commit; either remove the import statement for flashinfer.sampling or
ensure it registers decorators used by _TRACE_REGISTRY (so the import has side
effects). Locate the import block containing flashinfer.sampling (near imports
for flashinfer.decode, flashinfer.gdn_decode, etc.) and delete the
flashinfer.sampling line if no decorated functions from that module are expected
to be registered, otherwise import the specific symbols that cause registration
or add a comment explaining the necessary side-effect to avoid removal by
linters.
In `@tests/trace/test_fi_trace.py`:
- Line 20: Remove the unused top-level import of pytest in
tests/trace/test_fi_trace.py: delete the line "import pytest" since the file
relies on pytest fixtures (tmp_path, monkeypatch) provided by pytest's runtime
and does not reference the pytest symbol directly; ensure no other code in the
module uses the pytest name before committing.
---
Duplicate comments:
In `@flashinfer/trace/templates/attention.py`:
- Around line 140-169: The code is indexing k_flat/v_flat by page_ids and
treating num_kv_tokens as page_ids.shape[0], which is incorrect for paged
caches; you must expand page indices to per-token indices before building
k_b/v_b and computing num_kv_tokens so the causal window and gathers operate at
token granularity. Change the gather so that page_ids are multiplied/expanded by
page_size into token_indices (e.g. token_indices =
page_ids.unsqueeze(1)*page_size + torch.arange(page_size, device=...)) and then
use those token_indices to index the original k_flat/v_flat (or reshape
k_cache/v_cache into per-token and gather by token_indices) so k_b/v_b contain
all tokens from the selected pages, set num_kv_tokens = token_indices.numel()
(or actual token count if last page partial), and adjust uses of max_kv, delta,
and slicing (k_b[:max_kv], v_b[:max_kv]) accordingly.
- Around line 357-385: The code currently assumes page_size == 1 by calling
ckv_cache.squeeze(1) and kpe_cache.squeeze(1) (variables Kc_all/Kp_all) which
collapses the page dimension; instead merge the page and token dimensions so
arbitrary page_size works: replace the squeeze(1) usage with a reshape/view that
flattens the first two dims (e.g. ckv_cache.reshape(-1, head_dim_ckv) and
kpe_cache.reshape(-1, head_dim_kpe)) and keep using
kv_indices[kv_indptr[b]:kv_indptr[b+1]] (tok_idx) to index into the flattened
Kc_all/Kp_all so Kc = Kc_all[tok_idx] and Kp = Kp_all[tok_idx] work for
multi-token pages; apply the same change to the other block (lines ~476-518)
where ckv_cache/kpe_cache are squeezed.
- Around line 39-58: kv_indices are page IDs but the code treats them as token
indices when building token_ids; fix by expanding each page id into its
page_size token-row indices before indexing k_flat/v_flat. In the loop over b
replace the current token_ids = kv_indices[page_start:page_end].to(torch.long)
with logic that maps each page id p to the contiguous token index range
p*page_size .. (p+1)*page_size-1 (preserving dtype/device), then flatten that to
a 1D tensor and use it to build k_b and v_b so k_b/v_b remain shaped [T,
num_kv_heads, head_dim]; keep using k_flat/v_flat, kv_indptr, kv_indices,
page_size, k_b, v_b, token_ids, and ensure device/torch.long handling remains
correct.
In `@flashinfer/trace/templates/gdn.py`:
- Around line 165-169: The schema currently sets the attention "output" Tensor
using dtype_from="q", which misreports dtype because outputs are cast to
torch.bfloat16; update the Tensor definition for the "output" field in the GDN
templates to use an explicit dtype of "bfloat16" (replace dtype_from="q" with
dtype="bfloat16") for the occurrences around the shown block and the other two
occurrences (near lines 351-355 and 537-541) so the trace metadata correctly
reflects torch.bfloat16 outputs.
- Around line 421-458: The loop updates state_HVK per batch/state but
final_state is created from initial_state and never updated, so return value is
stale; update final_state with the pooled (transposed) state_HVK for each
corresponding state index (initial_state_indices) after finishing updates for
that state (or after the outer loops) so final_state[state_idx] =
state_HVK.transpose(-1, -2) (match the same [H,V,K] ↔ [H,K,V] orientation used
for initial_state/state_HVK) before returning output and final_state.
- Around line 362-365: The gdn_prefill_trace template is missing head-ratio
validity checks: add constraints ensuring num_v_heads is divisible by
num_q_heads and by num_k_heads (e.g., num_v_heads % num_q_heads == 0 and
num_v_heads % num_k_heads == 0) so the downstream divisions (num_v_heads //
num_q_heads and num_v_heads // num_k_heads) used elsewhere are valid; update the
constraints list in gdn_prefill_trace to include these checks referencing the
variables num_v_heads, num_q_heads, and num_k_heads.
In `@flashinfer/trace/templates/gemm.py`:
- Around line 180-217: The mm_fp4_trace TraceTemplate currently lists inputs "A"
and "B" with unpacked shapes ["M","K"] and ["K","N"], but at runtime these are
packed uint8 FP4 buffers; update mm_fp4_trace so the Tensor entries for "A" and
"B" describe the packed axes (e.g., K_packed/K_block or bytes per packed row) or
add an extractor that converts the packed dimensions back to logical K and N
(use the existing "block_size" Var/Scalar to compute K//block_size and
N//block_size); specifically modify the Tensor definitions for "A" and "B" in
mm_fp4_trace (and any related axis defs such as "K" or "N") so fi_trace will
infer correct runtime shapes for FP4-packed inputs.
- Around line 22-35: The reference implementations currently multiply by B.T,
but B represents the physical [K, N] weight matrix so using B.T swaps dims and
breaks cases where N != K; update _mm_reference to compute torch.matmul(A, B)
(not A @ B.T), and in _mm_fp8_reference reshape B into [K, N] (B_fp32 =
B.reshape(K_div_bs * block_size, N)) and use torch.matmul(A_fp32, B_fp32)
(remove the trailing .T), applying the same fix to the other reference helpers
mentioned (the FP8 and bf16 variants in the file).
In `@flashinfer/trace/templates/moe.py`:
- Around line 648-652: The assignment to add a dynamic attribute on the function
trtllm_fp8_block_scale_moe_trace_dispatch causes mypy attr-defined errors;
replace the direct assignment with a setattr call so the templates attribute is
attached at runtime (e.g., setattr(trtllm_fp8_block_scale_moe_trace_dispatch,
"templates", list(_MOE_TRACE_BY_ROUTING_TYPE.values()))) to preserve behavior
while satisfying type checking.
- Around line 25-27: The module currently hardcodes H=7168 and I=2048 which
breaks _fp8_moe_run_experts for models with different hidden/intermediate sizes;
change the code to derive hidden_size and intermediate_size at runtime from
tensor shapes (e.g., infer hidden_size from the input/hidden tensor shape[-1] or
the expert weight shapes, and infer intermediate_size from the feedforward
weight/output shapes) and replace uses of H and I with those derived values
(also ensure BLOCK is computed/validated against hidden_size if needed); update
all references in _fp8_moe_run_experts to use the derived variables so the
function works for arbitrary MoE shapes.
- Around line 126-131: The template hardcodes routing parameters TOP_K, N_GROUP,
TOPK_GROUP which conflict with the public API; update the code that defines
TOP_K, N_GROUP, TOPK_GROUP to read the corresponding function arguments (e.g.,
top_k, n_group, topk_group) or the routing parameters object instead of fixed
literals so the template matches whatever configuration is passed via
routing_logits' caller (or if these values are truly only for shape/schema
checks, replace the literals with a clarifying comment near
TOP_K/N_GROUP/TOPK_GROUP stating they are placeholder defaults used only for
schema validation and not numeric correctness). Ensure you change the
occurrences of TOP_K, N_GROUP, TOPK_GROUP in this module (referenced with
routing_logits, E_global, T) accordingly.
In `@tests/trace/example.py`:
- Around line 54-373: The file runs the entire trace-generation body at import
time so pytest won't collect it; keep the environment setup (the
FLASHINFER_TRACE_* os.environ lines and SAVE_DIR) and the imports as-is but move
everything that performs work (starting from device/WORKSPACE and all calls that
exercise flashinfer APIs, e.g., the loops that call flashinfer.rmsnorm,
flashinfer.mm_bf16, BatchDecodeWithPagedKVCacheWrapper.plan/run,
BatchPrefillWithPagedKVCacheWrapper.plan/run,
BatchPrefillWithRaggedKVCacheWrapper.plan/run,
BatchMLAPagedAttentionWrapper.plan/run, flashinfer.gdn_decode.*,
flashinfer.fused_moe.* and the final JSON summary) into a pytest test function
named with a test_ prefix (e.g., test_generate_fi_traces) so CI will execute it;
also add a minimal if __name__ == "__main__" guard to call that function when
run as a script so the example remains runnable standalone.
---
Nitpick comments:
In `@flashinfer/trace/template.py`:
- Line 473: Replace the list concatenation that builds all_tags with list
unpacking for clearer syntax: in the function/block where variable all_tags is
assigned (currently using all_tags = [f"fi_api:{fi_api}"] + template.tags),
change it to construct the list using [f"fi_api:{fi_api}", *template.tags] so it
directly prepends the formatted fi_api tag to template.tags; keep the same
variable name and semantics.
- Around line 426-443: The auto-infer branch in template.py sets dtype to the
first matching input's type (looping over template.inputs, checking Tensor
instances and overlapping descriptor.dim_names, using _get_tensor and
_dtype_str, then break), which is arbitrary when multiple inputs overlap; add a
concise inline comment near this logic (around the loop and the break) stating
that this chooses the first matching input's dtype as the precedence rule and
that other overlapping inputs may be ignored, so callers should avoid ambiguous
multiple-dtype overlaps or explicitly provide dtype to override; keep the
comment short and reference template.inputs, Tensor, descriptor, _get_tensor,
and _dtype_str.
- Around line 370-378: The code silently swallows all exceptions when running
axis extractor functions (axis_extractors -> extractor(kwargs) populating
axis_values), which hides bugs; change the bare "except Exception: pass" to
catch the exception as e and log it at debug level (e.g., logger.debug("axis
extractor %s failed for kwargs=%s: %s", axis_name, kwargs, e, exc_info=True)) so
failures are recorded but extraction remains robust, and if there is no existing
logger in this module create one via logging.getLogger(__name__) and import
logging.
In `@flashinfer/trace/templates/moe.py`:
- Around line 85-88: Rename the ambiguous variable O used in the moe attention
feedforward block to a clearer name (e.g., out or output_e) to avoid confusion
with the digit zero; update the assignment and any subsequent uses where O
appears (the expression "(silu_X2 * X1).matmul(W2[le].t())") and ensure
references to G1, X1, X2, silu_X2, W13, W2, A_e, and le remain correct with the
new variable name.
In `@tests/trace/test_fi_trace_template_consistency.py`:
- Around line 399-408: The loop variable k in the comprehension that builds
non_optional_unknown shadows the tensor named k earlier; rename the loop
variable (e.g., to input_name or inp_key) used in the comprehension and in the
f-string so it no longer collides with the tensor k, updating the comprehension
over defn["inputs"].items() and the f"Non-optional inputs with unknown dtype:
{...}" reference accordingly.
- Around line 495-496: The regex passed to pytest.raises should be a raw string
to avoid accidental escape sequences; update the call to
pytest.raises(AssertionError, match=...) used around
assert_template_signature_consistency(func, broken, label="meta-test") so the
match argument is a raw string literal (prefix it with r, e.g.
r"param=.*hidden_state.*not found") to satisfy Ruff RUF043 and ensure the
pattern is interpreted correctly.
- Around line 430-460: The test function test_fi_trace_complete_moe_routing uses
a single-letter variable I (intermediate size) which triggers an ambiguity lint
(E741); rename I to a descriptive identifier (e.g., inter_size or intermediate)
and update all references inside the function and the fi_trace(...) call
(intermediate_size=I, shapes using I, 2 * I etc.) so the values and assertions
remain identical but the variable name is clear and matches usage in
trtllm_fp8_block_scale_moe.fi_trace.
- Around line 369-370: The list comprehension _E2E_PAIRS uses an ambiguous loop
variable named "l"; rename it to a clearer identifier (e.g., "label" or "lbl")
in the comprehension and update the subsequent _E2E_IDS comprehension to
unpack/use that new name so both _E2E_PAIRS = [(f, t, label) for f, t, label in
_ALL_PAIRS if label not in _E2E_SKIP] and _E2E_IDS = [label for _, _, label in
_E2E_PAIRS] remain consistent.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: a970c3e2-8b74-4f9f-b236-d08657910713
📒 Files selected for processing (32)
flashinfer/api_logging.py
flashinfer/fi_trace.py
flashinfer/trace/template.py
flashinfer/trace/templates/attention.py
flashinfer/trace/templates/gdn.py
flashinfer/trace/templates/gemm.py
flashinfer/trace/templates/moe.py
tests/trace/example.py
tests/trace/fi_trace_out/fused_add_rmsnorm_h5120.json
tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json
tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json
tests/trace/fi_trace_out/gemm_bf16_N256_K7168.json
tests/trace/fi_trace_out/gemm_bf16_N4096_K4096.json
tests/trace/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.json
tests/trace/fi_trace_out/gemm_fp8_N1536_K7168.json
tests/trace/fi_trace_out/gemm_mxfp8_N4096_K4096.json
tests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.json
tests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.json
tests/trace/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.json
tests/trace/fi_trace_out/gqa_ragged_h32_kv8_d128.json
tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.json
tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json
tests/trace/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.json
tests/trace/fi_trace_out/rmsnorm_h4096.json
tests/trace/fi_trace_out/rmsnorm_h7168.json
tests/trace/fi_trace_out/top_k_sampling_v128256.json
tests/trace/fi_trace_out/top_k_top_p_sampling_v128256.json
tests/trace/fi_trace_out/top_k_top_p_sampling_v151936.json
tests/trace/fi_trace_out/top_p_sampling_v128256.json
tests/trace/fi_trace_out/top_p_sampling_v151936.json
tests/trace/test_fi_trace.py
tests/trace/test_fi_trace_template_consistency.py
✅ Files skipped from review due to trivial changes (15)
- tests/trace/fi_trace_out/gemm_bf16_N256_K7168.json
- tests/trace/fi_trace_out/rmsnorm_h4096.json
- tests/trace/fi_trace_out/rmsnorm_h7168.json
- tests/trace/fi_trace_out/gemm_bf16_N4096_K4096.json
- tests/trace/fi_trace_out/gemm_fp8_N1536_K7168.json
- tests/trace/fi_trace_out/fused_add_rmsnorm_h5120.json
- tests/trace/fi_trace_out/top_k_top_p_sampling_v128256.json
- tests/trace/fi_trace_out/gemm_mxfp8_N4096_K4096.json
- tests/trace/fi_trace_out/top_p_sampling_v151936.json
- tests/trace/fi_trace_out/top_k_sampling_v128256.json
- tests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.json
- tests/trace/fi_trace_out/top_k_top_p_sampling_v151936.json
- tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json
- tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.json
- tests/trace/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.json
🚧 Files skipped from review as they are similar to previous changes (6)
- tests/trace/fi_trace_out/top_p_sampling_v128256.json
- tests/trace/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.json
- tests/trace/fi_trace_out/gqa_ragged_h32_kv8_d128.json
- tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json
- tests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.json
- tests/trace/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.json
wrapped.fi_trace = fi_trace_fn

# Auto-dump wrapper: checked lazily at call time so that callers
# can set FLASHINFER_TRACE_DUMP after importing flashinfer (e.g.
# when running via ``python -m``).
_inner = wrapped
_sig = inspect.signature(original)

@functools.wraps(_inner)
def _auto_dump_wrapper(*args, **kwargs):
    # Generate trace BEFORE the actual call (crash-safe: schema
    # depends only on input shapes/dtypes, not on whether the
    # computation succeeds).
    if _is_trace_dump_enabled():
        try:
            bound = _sig.bind(*args, **kwargs)
            bound.apply_defaults()
            fi_trace_fn(**dict(bound.arguments))
        except Exception:
            pass
    return _inner(*args, **kwargs)

_auto_dump_wrapper.fi_trace = fi_trace_fn
return _auto_dump_wrapper
Use setattr() for dynamic attribute assignment to avoid mypy attr-defined errors.
Direct assignment of .fi_trace to Callable-typed variables triggers mypy's attr-defined check. Replace direct assignment with setattr() to explicitly model the dynamic attribute:
Required changes
- wrapped.fi_trace = fi_trace_fn
+ setattr(wrapped, "fi_trace", fi_trace_fn)
@@
- _auto_dump_wrapper.fi_trace = fi_trace_fn
+ setattr(_auto_dump_wrapper, "fi_trace", fi_trace_fn)
@@
- wrapped.fi_trace = build_fi_trace_fn(spec)
+ setattr(wrapped, "fi_trace", build_fi_trace_fn(spec))
🧰 Tools
🪛 GitHub Actions: pre-commit
[error] 1508-1539: mypy error(s): "Callable[..., Any]" has no attribute "fi_trace" / "_Wrapped[...]" has no attribute "fi_trace" [attr-defined]
🪛 Ruff (0.15.9)
[error] 1526-1527: try-except-pass detected, consider logging the exception
(S110)
[warning] 1526-1526: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/api_logging.py` around lines 1508 - 1531, Replace direct attribute
assignments to .fi_trace with setattr to avoid mypy attr-defined errors: where
the diff sets wrapped.fi_trace = fi_trace_fn and _auto_dump_wrapper.fi_trace =
fi_trace_fn, change those direct assignments to use setattr(wrapped, "fi_trace",
fi_trace_fn) and setattr(_auto_dump_wrapper, "fi_trace", fi_trace_fn). Keep the
same semantics (assign the fi_trace_fn callable) and leave other code in
_auto_dump_wrapper, _sig, and fi_trace_fn unchanged.
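A minimal, self-contained sketch of the recommended pattern; `attach_trace` is a hypothetical stand-in for the decorator internals, not the PR's actual helper:

```python
import functools
from typing import Any, Callable

def attach_trace(func: Callable[..., Any], trace_fn: Callable[..., dict]) -> Callable[..., Any]:
    # Direct `wrapped.fi_trace = trace_fn` fails mypy's attr-defined check
    # because Callable declares no such attribute; setattr() models the
    # dynamic attribute explicitly while keeping the same runtime semantics.
    @functools.wraps(func)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        return func(*args, **kwargs)

    setattr(wrapped, "fi_trace", trace_fn)
    return wrapped
```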
# Import the old implementation from the trace package for backwards compat.
from .trace.template import (  # noqa: PLC0415
    Const,
    Scalar,
    Tensor,
    TraceTemplate,
    Var,
)
Remove unused template imports from the legacy builder.
The names Const, Scalar, Tensor, TraceTemplate, and Var imported at lines 105–109 are not referenced anywhere in build_fi_trace_fn. Ruff F401 violations are already flagged on these lines.
🧰 Tools
🪛 GitHub Actions: pre-commit
[error] 105-105: ruff (F401): .trace.template.Const imported but unused
[error] 106-106: ruff (F401): .trace.template.Scalar imported but unused
[error] 107-107: ruff (F401): .trace.template.Tensor imported but unused
[error] 108-108: ruff (F401): .trace.template.TraceTemplate imported but unused
[error] 109-109: ruff (F401): .trace.template.Var imported but unused
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/fi_trace.py` around lines 103 - 110, The import line bringing in
Const, Scalar, Tensor, TraceTemplate, and Var from .trace.template is unused in
build_fi_trace_fn and causing Ruff F401 warnings; remove those five names (or
the whole legacy import if nothing else from that module is used) so only needed
symbols remain imported in flashinfer/fi_trace.py and eliminate the unused
imports Const, Scalar, Tensor, TraceTemplate, Var from the import statement that
currently appears alongside build_fi_trace_fn.
def fi_trace(
    func_or_method: Callable,
    save_dir: Optional[Union[str, Path]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Generate a flashinfer-bench definition JSON for any FlashInfer API call.

    Parameters
    ----------
    func_or_method:
        A ``@flashinfer_api``-decorated function or (bound) method.
    save_dir:
        Directory where the JSON definition file should be written.
        Falls back to ``FLASHINFER_TRACE_DUMP_DIR`` env-var when *None*.
    **kwargs:
        The same tensor arguments you would pass to the real API.

    Returns
    -------
    dict
        A flashinfer-bench compatible definition dictionary.

    Examples
    --------
    Standalone function::

        defn = fi_trace(flashinfer.norm.rmsnorm, input=hidden, weight=weight)

    Bound method (instance.run)::

        defn = fi_trace(wrapper.run, q=q_tensor, paged_kv_cache=(k, v))

    Class-level (unbound)::

        defn = fi_trace(
            flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper.run,
            q=q_tensor, paged_kv_cache=(k, v),
        )
    """
    actual_func = getattr(func_or_method, "__func__", func_or_method)
    trace_fn = getattr(actual_func, "fi_trace", None)
    if trace_fn is None:
        qualname = getattr(actual_func, "__qualname__", repr(actual_func))
        raise ValueError(
            f"No fi_trace spec is registered for '{qualname}'. "
            "Only @flashinfer_api(trace=...)-decorated functions support fi_trace."
        )
    return trace_fn(save_dir=save_dir, **kwargs)
Make the .fi_trace attribute requirement explicit in the type signature.
The function expects func_or_method (or its __func__) to have a .fi_trace attribute, but the signature declares plain Callable. Define a Protocol for this contract (e.g., class TracedCallable(Protocol): def fi_trace(...) -> Dict[str, Any]: ...) or use a cast when resolving the actual function. This will satisfy mypy and make the requirement clear to callers.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/fi_trace.py` around lines 238 - 285, The function fi_trace
currently types func_or_method as Callable but relies on the object
(actual_func) exposing a .fi_trace attribute; update the typing to make that
contract explicit by introducing a Protocol (e.g., TracedCallable with a
fi_trace(self, save_dir: Optional[Union[str, Path]] = None, **kwargs) ->
Dict[str, Any]) and use that Protocol as the type for func_or_method (or cast
actual_func to TracedCallable before accessing .fi_trace); ensure the Protocol
signature matches how trace_fn is called in fi_trace and import typing.Protocol
and any necessary types so mypy recognizes the requirement.
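A sketch of what such a Protocol could look like; `TracedCallable` and `resolve_trace` are illustrative names matching the reviewer's suggestion, not the PR's actual API:

```python
from pathlib import Path
from typing import Any, Dict, Optional, Protocol, Union, cast

class TracedCallable(Protocol):
    # Structural type: any callable that also carries a fi_trace attribute.
    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
    def fi_trace(
        self, save_dir: Optional[Union[str, Path]] = None, **kwargs: Any
    ) -> Dict[str, Any]: ...

def resolve_trace(func_or_method: Any) -> TracedCallable:
    # Mirrors the lookup in fi_trace(): unwrap bound methods, then require
    # the dynamic fi_trace attribute before casting to the Protocol.
    actual = getattr(func_or_method, "__func__", func_or_method)
    if getattr(actual, "fi_trace", None) is None:
        raise ValueError("no fi_trace spec is registered for this callable")
    return cast(TracedCallable, actual)
```

The `cast` documents the contract for mypy without changing runtime behavior; callers that pass an undecorated function still get the explicit `ValueError`.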
| "initial_state": Tensor( | ||
| ["pool_size", "num_v_heads", "head_size", "head_size"], | ||
| description="Initial recurrent state pool in k-last layout [pool_size, H, V, K].", | ||
| ), | ||
| "initial_state_indices": Tensor( | ||
| ["batch_size"], | ||
| description="Indices mapping each batch to its initial state in the pool.", | ||
| ), | ||
| "A_log": Tensor( | ||
| ["num_v_heads"], | ||
| description="Log decay parameter (learnable). Used to compute g = exp(-exp(A_log) * softplus(a + dt_bias)).", | ||
| ), | ||
| "a": Tensor( | ||
| ["batch_size", "seq_len", "num_v_heads"], | ||
| description="Input-dependent decay from projection.", | ||
| ), | ||
| "dt_bias": Tensor( | ||
| ["num_v_heads"], | ||
| description="Decay bias (learnable). Added to 'a' before softplus.", | ||
| ), | ||
| "b": Tensor( | ||
| ["batch_size", "seq_len", "num_v_heads"], | ||
| description="Update gate input from projection. beta = sigmoid(b).", | ||
| ), | ||
| "scale": Scalar( | ||
| "float32", | ||
| description="Scale factor. Default is 1/sqrt(head_size).", | ||
| ), | ||
| "intermediate_states_buffer": Tensor( | ||
| ["pool_size", "seq_len", "num_v_heads", "head_size", "head_size"], | ||
| optional=True, | ||
| description="Optional buffer for caching intermediate states for potential rollback.", | ||
| ), | ||
| }, | ||
| outputs={ | ||
| "output": Tensor( | ||
| ["batch_size", "seq_len", "num_v_heads", "head_size"], | ||
| dtype_from="q", | ||
| description="Attention output for all T tokens. Shape follows num_v_heads in GVA mode.", | ||
| ), | ||
| "final_state": Tensor( | ||
| ["pool_size", "num_v_heads", "head_size", "head_size"], | ||
| dtype="float32", | ||
| description="Updated recurrent state pool in k-last layout [pool_size, H, V, K]. Unchanged if disable_state_update=True.", | ||
| ), |
Add disable_state_update as a template input.
The gated_delta_rule_mtp() function supports a disable_state_update parameter that controls whether final_state is updated, yet the template does not expose it as an input. While the output description correctly mentions this behavior, the schema omission causes fi_trace to emit identical specifications for both state-updating and non-updating modes.
Add disable_state_update as a boolean input (optional or required, as per the function's design) to accurately model the two distinct operational modes.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/trace/templates/gdn.py` around lines 502 - 546, The template
schema is missing the disable_state_update input required by
gated_delta_rule_mtp, so add a boolean Tensor/Scalar entry named
"disable_state_update" to the inputs dict (matching how other flags are
represented) and mark it optional or required consistent with
gated_delta_rule_mtp's signature; ensure you reference the symbol name
"disable_state_update" and update the inputs block near the existing
"initial_state"/"final_state" entries so the trace can distinguish
state-updating vs non-updating behavior described in "final_state".
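A sketch of the suggested schema change. `Scalar` here is a local stand-in dataclass (the real descriptor lives in flashinfer/trace/template.py), and the dtype string and optionality are assumptions to be matched against gated_delta_rule_mtp's actual signature:

```python
from dataclasses import dataclass

@dataclass
class Scalar:
    # Stand-in for the template descriptor used throughout this PR.
    dtype: str
    optional: bool = False
    description: str = ""

def add_disable_state_update(inputs: dict) -> dict:
    # Expose the flag so traces can distinguish state-updating runs
    # from non-updating runs of the kernel.
    inputs["disable_state_update"] = Scalar(
        "bool",
        optional=True,
        description="When True, final_state is left unchanged.",
    )
    return inputs
```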
def test_fi_trace_complete_gqa_paged_decode():
    """GQA paged decode: tuple paged_kv_cache input handled correctly."""
    from flashinfer.decode import BatchDecodeWithPagedKVCacheWrapper
    from flashinfer.trace.templates.attention import gqa_paged_decode_trace

    B, H, KV, D, P, NP = 4, 8, 4, 64, 16, 8
    q = torch.zeros(B, H, D, dtype=torch.bfloat16)
    k = torch.zeros(NP, P, KV, D, dtype=torch.bfloat16)
    v = torch.zeros(NP, P, KV, D, dtype=torch.bfloat16)

    defn = BatchDecodeWithPagedKVCacheWrapper.run.fi_trace(q=q, paged_kv_cache=(k, v))
Remove unused import gqa_paged_decode_trace.
The import gqa_paged_decode_trace at line 387 is flagged as unused by pre-commit. The test only uses BatchDecodeWithPagedKVCacheWrapper.run.fi_trace.
Suggested fix
 def test_fi_trace_complete_gqa_paged_decode():
     """GQA paged decode: tuple paged_kv_cache input handled correctly."""
     from flashinfer.decode import BatchDecodeWithPagedKVCacheWrapper
-    from flashinfer.trace.templates.attention import gqa_paged_decode_trace

     B, H, KV, D, P, NP = 4, 8, 4, 64, 16, 8
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def test_fi_trace_complete_gqa_paged_decode():
    """GQA paged decode: tuple paged_kv_cache input handled correctly."""
    from flashinfer.decode import BatchDecodeWithPagedKVCacheWrapper

    B, H, KV, D, P, NP = 4, 8, 4, 64, 16, 8
    q = torch.zeros(B, H, D, dtype=torch.bfloat16)
    k = torch.zeros(NP, P, KV, D, dtype=torch.bfloat16)
    v = torch.zeros(NP, P, KV, D, dtype=torch.bfloat16)

    defn = BatchDecodeWithPagedKVCacheWrapper.run.fi_trace(q=q, paged_kv_cache=(k, v))
🧰 Tools
🪛 GitHub Actions: pre-commit
[error] 387-387: ruff (F401): flashinfer.trace.templates.attention.gqa_paged_decode_trace imported but unused
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/trace/test_fi_trace_template_consistency.py` around lines 384 - 394,
Remove the unused import gqa_paged_decode_trace from the test; locate the import
statement that reads "from flashinfer.trace.templates.attention import
gqa_paged_decode_trace" and delete it so the test only imports and uses
BatchDecodeWithPagedKVCacheWrapper.run.fi_trace (ensure no other references to
gqa_paged_decode_trace remain in the file).
Actionable comments posted: 1
♻️ Duplicate comments (3)
flashinfer/api_logging.py (1)
1510-1531: ⚠️ Potential issue | 🟠 Major
FLASHINFER_LOGLEVEL=0 no longer preserves the zero-overhead path for traced APIs. For any @flashinfer_api(trace=...) function, this still returns _attach_fi_trace(...), and _attach_fi_trace() always builds _auto_dump_wrapper. That means every call pays an extra Python frame and an _is_trace_dump_enabled() check even when logging is disabled, which contradicts the decorator's documented "returns original function" contract.
Also applies to: 1629-1634
tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json (1)
126-126: ⚠️ Potential issue | 🟠 Major
Flatten the paged cache before the reference matmul.
The schema above says ckv_cache/kpe_cache are [num_pages, page_size, head_dim_*], so with page_size=64 the squeeze(1) calls do nothing. Kc_all[tok_idx]/Kp_all[tok_idx] therefore stay 3D, and the later qn @ Kc.T path no longer matches the intended [num_qo_heads, L] attention score computation. Reshape the caches to token-major 2D tensors before indexing, or rewrite the reference to handle paged tensors directly.
For PyTorch 2.x, if ckv_cache has shape [num_pages, page_size, head_dim], what does ckv_cache.squeeze(1) return when page_size=64, and what shape does qn @ Kc.T use when qn is [num_qo_heads, head_dim] and Kc is [L, 64, head_dim]?
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json at line 126, the reference incorrectly uses ckv_cache.squeeze(1) / kpe_cache.squeeze(1), which leaves a 3D paged tensor when page_size > 1. In _mla_paged_decode_reference, flatten the page-major caches to token-major 2D tensors (num_pages * page_size, head_dim_ckv/kpe) before indexing, i.e. compute Kc_all and Kp_all as [num_tokens, head_dim_*] rather than [num_pages, page_size, head_dim_*], then select Kc/Kp with tok_idx so that qn @ Kc.T and qp @ Kp.T produce [num_qo_heads, L] logits; update the Kc_all/Kp_all creation near their assignments and ensure subsequent uses (Kc, Kp, logits, output) operate on the flattened shapes.
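The suggested flattening can be sketched with plain PyTorch; the function names are illustrative, following the shapes quoted above:

```python
import torch

def flatten_paged_cache(cache: torch.Tensor) -> torch.Tensor:
    # [num_pages, page_size, head_dim] -> [num_pages * page_size, head_dim],
    # so token index i maps to page i // page_size, slot i % page_size.
    num_pages, page_size, head_dim = cache.shape
    return cache.reshape(num_pages * page_size, head_dim)

def attention_scores(qn: torch.Tensor, cache: torch.Tensor, tok_idx: torch.Tensor) -> torch.Tensor:
    # qn: [num_qo_heads, head_dim]; result: [num_qo_heads, L]
    kc = flatten_paged_cache(cache)[tok_idx]  # [L, head_dim]
    return qn @ kc.T
```

With page_size=64, squeeze(1) is a no-op because dimension 1 has size 64, not 1; the reshape above is what actually produces token-major 2D tensors.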
Verify each finding against the current code and only fix it if needed. In `@tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json` at line 126, The reference incorrectly uses ckv_cache.squeeze(1) / kpe_cache.squeeze(1) which leaves a 3D paged tensor when page_size>1; in _mla_paged_decode_reference you must flatten the page-major caches to token-major 2D tensors (num_pages*page_size, head_dim_ckv/kpe) before indexing (i.e., compute Kc_all and Kp_all as [num_tokens, head_dim_*] rather than [num_pages, page_size, head_dim_*]), then select Kc/Kp with tok_idx so that qn @ Kc.T and qp @ Kp.T produce [num_qo_heads, L] logits; update the Kc_all/Kp_all creation near their assignments and ensure subsequent uses (Kc, Kp, logits, output) operate on the flattened shapes.tests/trace/example.py (1)
54-551:⚠️ Potential issue | 🟠 MajorPytest still won't execute this trace generator.
tests/trace/example.pyis still a standalone script with top-level side effects and notest_*entrypoint, so CI won't collect it or validate the generated fixtures. Please move the body into a test/helper and keep a__main__guard only for manual runs.As per coding guidelines,
tests/**/*.py: Prefix test functions withtest_and structure tests by feature intests/subdirectories matching kernel categories.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/trace/example.py` around lines 54 - 551, The file currently executes trace-generation at import (top-level side effects: SAVE_DIR, the whole sequence of flashinfer.* calls, and the final files/print summary), so move the entire body into a callable function (e.g. generate_fi_traces or build_example_traces) that encapsulates planning/running wrappers and the final JSON-summary logic, then add a pytest entrypoint test_example_traces() in the same module (or a new tests/ submodule) that calls that function and asserts expected output (e.g. presence/count of files from SAVE_DIR or that no exceptions occur), and retain an if __name__ == "__main__": guard to call generate_fi_traces() for manual runs; reference SAVE_DIR, the trace-generation sequence (all flashinfer.* calls and wrapper usages like BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper, BatchPrefillWithRaggedKVCacheWrapper, BatchMLAPagedAttentionWrapper, and the files variable/summary) when making the changes.
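The restructuring the comment asks for can be sketched like this; generate_fi_traces and the JSON payload are placeholders for the real flashinfer.* trace-generation body:

```python
import json
import tempfile
from pathlib import Path

def generate_fi_traces(save_dir: Path) -> list:
    # All side effects live in a callable instead of at module import time;
    # the real body would run the flashinfer.* wrappers and dump one JSON
    # definition per kernel into save_dir.
    save_dir.mkdir(parents=True, exist_ok=True)
    (save_dir / "demo.json").write_text(json.dumps({"name": "demo"}))
    return sorted(save_dir.glob("*.json"))

def test_example_traces():
    # pytest-collectable entrypoint validating the generated fixtures.
    with tempfile.TemporaryDirectory() as d:
        files = generate_fi_traces(Path(d))
        assert files, "expected at least one generated trace file"

if __name__ == "__main__":
    # Manual runs still work, writing next to the script.
    generate_fi_traces(Path("fi_trace_out"))
```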
🧹 Nitpick comments (8)
flashinfer/trace/templates/attention.py (3)
28-41: Prefix unused unpacked variables with underscore.
`page_size` is unpacked from `k_cache.shape` but never used in the function. Prefix with `_` to satisfy the linter.

Proposed fix

```diff
 def _gqa_paged_decode_reference(q, k_cache, v_cache, kv_indptr, kv_indices, sm_scale):
     batch_size, num_qo_heads, head_dim = q.shape
-    _, page_size, num_kv_heads, _ = k_cache.shape
+    _, _page_size, num_kv_heads, _ = k_cache.shape
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/attention.py` around lines 28 - 41, The variable page_size unpacked in _gqa_paged_decode_reference is unused and should be prefixed with an underscore to satisfy the linter; change the unpacking from "_, page_size, num_kv_heads, _ = k_cache.shape" style to use "_page_size" (or simply "_" if preferred) so the function signature still captures batch dimensions but removes the unused symbol while keeping references to q, k_cache, v_cache, kv_indptr, kv_indices, and sm_scale intact.
244-248: Prefix unused unpacked variable with underscore.
`total_kv` is unpacked but never used. Prefix with `_` to satisfy the linter.

Proposed fix

```diff
 def _gqa_ragged_prefill_reference(q, k, v, qo_indptr, kv_indptr, sm_scale):
     total_q, num_qo_heads, head_dim = q.shape
-    total_kv, num_kv_heads, _ = k.shape
+    _total_kv, num_kv_heads, _ = k.shape
     len_indptr = qo_indptr.shape[0]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/attention.py` around lines 244 - 248, In _gqa_ragged_prefill_reference, the unpacking assigns total_kv from k.shape but it's unused; change the unpacked name to _total_kv (or prefix with underscore) in the q, k, v shape assignment so the linter recognizes it as intentionally unused (i.e., update the tuple unpack on the line with "total_q, num_qo_heads, head_dim = q.shape" / "total_kv, num_kv_heads, _ = k.shape" to use _total_kv).
125-144: Prefix unused unpacked variables with underscore.

Both `num_pages` and `page_size` are unpacked but never used. This triggers ruff RUF059.

Proposed fix

```diff
 def _gqa_paged_prefill_reference(
     q, k_cache, v_cache, qo_indptr, kv_indptr, kv_indices, sm_scale
 ):
     total_q, num_qo_heads, head_dim = q.shape
-    num_pages, page_size, num_kv_heads, _ = k_cache.shape
+    _num_pages, _page_size, num_kv_heads, _ = k_cache.shape
     len_indptr = qo_indptr.shape[0]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/attention.py` around lines 125 - 144: the variables num_pages and page_size are unpacked in _gqa_paged_prefill_reference but never used, triggering ruff RUF059; update the tuple unpacking to prefix unused names with an underscore (e.g., _num_pages, _page_size or simply _, _) in the line that unpacks k_cache.shape so the intent is clear and the linter warning is resolved while leaving k_cache usage unchanged.

tests/trace/test_fi_trace_template_consistency.py (3)
440-466: Consider renaming `I` to `INTERMEDIATE` or `INTER_SIZE` for clarity.

The variable name `I` is flagged as ambiguous (E741) because it can be confused with `1` or `l`. The same applies to line 488.

Proposed fix

```diff
-T, E, EL, H, I, BS = 4, 16, 2, 256, 64, 128
+T, E, EL, H, INTER, BS = 4, 16, 2, 256, 64, 128
 defn = trtllm_fp8_block_scale_moe.fi_trace(
     routing_logits=torch.zeros(T, E, dtype=torch.float32),
     routing_bias=torch.zeros(E, dtype=torch.bfloat16),
     hidden_states=torch.zeros(T, H, dtype=torch.float8_e4m3fn),
     hidden_states_scale=torch.ones(H // BS, T, dtype=torch.float32),
-    gemm1_weights=torch.zeros(EL, 2 * I, H, dtype=torch.float8_e4m3fn),
-    gemm1_weights_scale=torch.ones(EL, (2 * I) // BS, H // BS, dtype=torch.float32),
-    gemm2_weights=torch.zeros(EL, H, I, dtype=torch.float8_e4m3fn),
-    gemm2_weights_scale=torch.ones(EL, H // BS, I // BS, dtype=torch.float32),
+    gemm1_weights=torch.zeros(EL, 2 * INTER, H, dtype=torch.float8_e4m3fn),
+    gemm1_weights_scale=torch.ones(EL, (2 * INTER) // BS, H // BS, dtype=torch.float32),
+    gemm2_weights=torch.zeros(EL, H, INTER, dtype=torch.float8_e4m3fn),
+    gemm2_weights_scale=torch.ones(EL, H // BS, INTER // BS, dtype=torch.float32),
     num_experts=E,
     top_k=top_k,
-    intermediate_size=I,
+    intermediate_size=INTER,
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/trace/test_fi_trace_template_consistency.py` around lines 440 - 466, Rename the ambiguous single-letter variable I to a clearer name like INTERMEDIATE or INTER_SIZE throughout the test (e.g., the variable declaration T, E, EL, H, INTERMEDIATE, BS = ... and all uses: gemm1_weights shape (EL, 2 * INTERMEDIATE, H), gemm2_weights shape (EL, H, INTERMEDIATE), and intermediate_size=INTERMEDIATE in the trtllm_fp8_block_scale_moe.fi_trace(...) call) and similarly update any other occurrence on the nearby line 488 so all references remain consistent.
374-376: Rename ambiguous variable `l` to `label` for clarity.

The single-letter `l` can be confused with `1` or `I`. Use a more descriptive name.

Proposed fix

```diff
-_E2E_PAIRS = [(f, t, l) for f, t, l in _ALL_PAIRS if l not in _E2E_SKIP]
-_E2E_IDS = [label for _, _, label in _E2E_PAIRS]
+_E2E_PAIRS = [(f, t, label) for f, t, label in _ALL_PAIRS if label not in _E2E_SKIP]
+_E2E_IDS = [lbl for _, _, lbl in _E2E_PAIRS]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/trace/test_fi_trace_template_consistency.py` around lines 374 - 376, Rename the ambiguous single-letter variable in the list comprehension: change the unpacking in _E2E_PAIRS from (f, t, l) to (f, t, label) and update the filter to use label instead of l; also update _E2E_IDS to unpack/use the same name (e.g., [label for _, _, label in _E2E_PAIRS]) so all references use the descriptive symbol label while keeping the existing logic with _ALL_PAIRS and _E2E_SKIP.
560-563: Use raw string for regex pattern with metacharacters.

The pattern contains regex metacharacters (`=`, `*`, `.`) but is not a raw string. While it works here because there are no escape conflicts, using `r"..."` is safer and clearer.

Proposed fix

```diff
-    with pytest.raises(AssertionError, match="param=.*hidden_state.*not found"):
+    with pytest.raises(AssertionError, match=r"param=.*hidden_state.*not found"):
         assert_template_signature_consistency(func, broken, label="meta-test")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/trace/test_fi_trace_template_consistency.py` around lines 560 - 563: the regex string passed to pytest.raises in the test (match="param=.*hidden_state.*not found") uses metacharacters but isn't a raw string; change it to a raw string (e.g., r"param=.*hidden_state.*not found") in the pytest.raises call that wraps assert_template_signature_consistency(func, broken, label="meta-test") to ensure backslashes and metacharacters are interpreted correctly; update the test invocation around _make_gdn_decode_func(), func, and broken accordingly.

flashinfer/trace/templates/moe.py (2)
674-678: Replace ambiguous `×` (multiplication sign) with ASCII `x`.

The Unicode multiplication sign `×` (U+00D7) can cause confusion. Use ASCII `x` or `*` instead.

Proposed fix

```diff
 "gemm1_out_size": Const(
-    description="Output size of FC1 (2 × intermediate_size for SwiGLU).",
+    description="Output size of FC1 (2 * intermediate_size for SwiGLU).",
     abbrev="",
 ),
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/moe.py` around lines 674 - 678, The description string for the Const named "gemm1_out_size" contains a Unicode multiplication sign (`×`); update the description in the gemm1_out_size Const to use an ASCII "x" (or "*" if you prefer) instead (e.g., change "2 × intermediate_size for SwiGLU" to "2 x intermediate_size for SwiGLU") so the comment uses plain ASCII characters.
795-806: FP4 MoE templates cannot be validated against reference implementations.

All FP4 templates pass `reference=None` because the `_make_standard_fp4_moe_trace` factory does not accept a reference parameter (unlike the FP8 factory). No FP4 MoE reference implementations are defined. Given that FP4 templates are marked as `status:experimental`, either implement reference functions or document why validation is deferred.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/moe.py` around lines 795 - 806, The FP4 MoE factory _make_standard_fp4_moe_trace currently hardcodes reference=None so FP4 templates cannot be validated; update the factory signature to accept an optional reference parameter (e.g., reference=None) and pass that through into TraceTemplate(reference=reference), then update all call sites that construct FP4 MoE traces to provide a proper reference function or explicitly pass None with a comment; additionally, either implement the missing FP4 MoE reference functions (and register them where other references live) or add clear documentation in the template module explaining that FP4 validation is deferred and why.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/api_logging.py`:
- Around line 1521-1527: The trace auto-dump currently swallows all exceptions
in the block guarded by _is_trace_dump_enabled(), making failures invisible;
change the except Exception: pass to catch Exception as e and emit a non-fatal
log (e.g., processLogger.warning or module logger) that includes the failing
trace function name (use fi_trace_fn.__name__ or _sig.signature info) and the
exception information (e) or traceback, then continue; update the try/except
around _sig.bind(*args, **kwargs), bound.apply_defaults(), and
fi_trace_fn(**dict(bound.arguments)) to log the diagnostic while keeping the
call non-fatal.
---
Duplicate comments:
In `@tests/trace/example.py`:
- Around line 54-551: The file currently executes trace-generation at import
(top-level side effects: SAVE_DIR, the whole sequence of flashinfer.* calls, and
the final files/print summary), so move the entire body into a callable function
(e.g. generate_fi_traces or build_example_traces) that encapsulates
planning/running wrappers and the final JSON-summary logic, then add a pytest
entrypoint test_example_traces() in the same module (or a new tests/ submodule)
that calls that function and asserts expected output (e.g. presence/count of
files from SAVE_DIR or that no exceptions occur), and retain an if __name__ ==
"__main__": guard to call generate_fi_traces() for manual runs; reference
SAVE_DIR, the trace-generation sequence (all flashinfer.* calls and wrapper
usages like BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper, BatchPrefillWithRaggedKVCacheWrapper,
BatchMLAPagedAttentionWrapper, and the files variable/summary) when making the
changes.
In `@tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json`:
- Line 126: The reference incorrectly uses ckv_cache.squeeze(1) /
kpe_cache.squeeze(1) which leaves a 3D paged tensor when page_size>1; in
_mla_paged_decode_reference you must flatten the page-major caches to
token-major 2D tensors (num_pages*page_size, head_dim_ckv/kpe) before indexing
(i.e., compute Kc_all and Kp_all as [num_tokens, head_dim_*] rather than
[num_pages, page_size, head_dim_*]), then select Kc/Kp with tok_idx so that qn @
Kc.T and qp @ Kp.T produce [num_qo_heads, L] logits; update the Kc_all/Kp_all
creation near their assignments and ensure subsequent uses (Kc, Kp, logits,
output) operate on the flattened shapes.
---
Nitpick comments:
In `@flashinfer/trace/templates/attention.py`:
- Around line 28-41: The variable page_size unpacked in
_gqa_paged_decode_reference is unused and should be prefixed with an underscore
to satisfy the linter; change the unpacking from "_, page_size, num_kv_heads, _
= k_cache.shape" style to use "_page_size" (or simply "_" if preferred) so the
function signature still captures batch dimensions but removes the unused symbol
while keeping references to q, k_cache, v_cache, kv_indptr, kv_indices, and
sm_scale intact.
- Around line 244-248: In _gqa_ragged_prefill_reference, the unpacking assigns
total_kv from k.shape but it's unused; change the unpacked name to _total_kv (or
prefix with underscore) in the q, k, v shape assignment so the linter recognizes
it as intentionally unused (i.e., update the tuple unpack on the line with
"total_q, num_qo_heads, head_dim = q.shape" / "total_kv, num_kv_heads, _ =
k.shape" to use _total_kv).
- Around line 125-144: The variables num_pages and page_size are unpacked in
_gqa_paged_prefill_reference but never used, triggering ruff RUF059; update the
tuple unpacking to prefix unused names with an underscore (e.g., _num_pages,
_page_size or simply _, _) in the line that unpacks k_cache.shape so the intent
is clear and the linter warning is resolved while leaving k_cache usage
unchanged.
In `@flashinfer/trace/templates/moe.py`:
- Around line 674-678: The description string for the Const named
"gemm1_out_size" contains a Unicode multiplication sign (`×`); update the
description in the gemm1_out_size Const to use an ASCII "x" (or "*" if you
prefer) instead (e.g., change "2 × intermediate_size for SwiGLU" to "2 x
intermediate_size for SwiGLU") so the comment uses plain ASCII characters.
- Around line 795-806: The FP4 MoE factory _make_standard_fp4_moe_trace
currently hardcodes reference=None so FP4 templates cannot be validated; update
the factory signature to accept an optional reference parameter (e.g.,
reference=None) and pass that through into TraceTemplate(reference=reference),
then update all call sites that construct FP4 MoE traces to provide a proper
reference function or explicitly pass None with a comment; additionally, either
implement the missing FP4 MoE reference functions (and register them where other
references live) or add clear documentation in the template module explaining
that FP4 validation is deferred and why.
In `@tests/trace/test_fi_trace_template_consistency.py`:
- Around line 440-466: Rename the ambiguous single-letter variable I to a
clearer name like INTERMEDIATE or INTER_SIZE throughout the test (e.g., the
variable declaration T, E, EL, H, INTERMEDIATE, BS = ... and all uses:
gemm1_weights shape (EL, 2 * INTERMEDIATE, H), gemm2_weights shape (EL, H,
INTERMEDIATE), and intermediate_size=INTERMEDIATE in the
trtllm_fp8_block_scale_moe.fi_trace(...) call) and similarly update any other
occurrence on the nearby line 488 so all references remain consistent.
- Around line 374-376: Rename the ambiguous single-letter variable in the list
comprehension: change the unpacking in _E2E_PAIRS from (f, t, l) to (f, t,
label) and update the filter to use label instead of l; also update _E2E_IDS to
unpack/use the same name (e.g., [label for _, _, label in _E2E_PAIRS]) so all
references use the descriptive symbol label while keeping the existing logic
with _ALL_PAIRS and _E2E_SKIP.
- Around line 560-563: The regex string passed to pytest.raises in the test
(match="param=.*hidden_state.*not found") uses metacharacters but isn't a raw
string; change it to a raw string (e.g., r"param=.*hidden_state.*not found") in
the pytest.raises call that wraps assert_template_signature_consistency(func,
broken, label="meta-test") to ensure backslashes and metacharacters are
interpreted correctly; update the test invocation around
_make_gdn_decode_func(), func, and broken accordingly.
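The page-gather fix prescribed for the MLA reference above can be sketched shape-wise. NumPy stands in for torch here, and all sizes are illustrative; the real caches carry head_dim_ckv=512 / head_dim_kpe=64.

```python
import numpy as np

num_pages, page_size, head_dim = 4, 64, 8
# Page-major cache: [num_pages, page_size, head_dim]
ckv_cache = np.arange(num_pages * page_size * head_dim, dtype=np.float32).reshape(
    num_pages, page_size, head_dim
)

kv_indices = np.array([2, 0])  # page IDs, not token IDs

# Gather whole pages first, then flatten page-major -> token-major 2D,
# instead of squeeze(1), which is a no-op when page_size != 1.
Kc = ckv_cache[kv_indices].reshape(-1, head_dim)
```

After the reshape, `Kc` is `[num_selected_pages * page_size, head_dim]`, so a query of shape `[num_qo_heads, head_dim]` produces `[num_qo_heads, L]` logits via `q @ Kc.T` for any page_size.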
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4aacac64-d181-4f0b-a266-998fd025caf7
📒 Files selected for processing (22)
- flashinfer/api_logging.py
- flashinfer/fi_trace.py
- flashinfer/fused_moe/core.py
- flashinfer/trace/templates/attention.py
- flashinfer/trace/templates/moe.py
- tests/trace/example.py
- tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.json
- tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json
- tests/trace/fi_trace_out/moe_fp4_block_scale_default_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp4_block_scale_ds_routing_topk8_e32_h7168_i2048_ng8_kg4.json
- tests/trace/fi_trace_out/moe_fp4_block_scale_llama4_routing_topk1_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp4_block_scale_renormalize_naive_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp4_block_scale_renormalize_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp4_block_scale_topk_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp8_block_scale_llama4_routing_topk1_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp8_block_scale_renormalize_naive_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp8_block_scale_renormalize_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp8_block_scale_topk_routing_topk8_e32_h7168_i2048.json
- tests/trace/test_fi_trace.py
- tests/trace/test_fi_trace_template_consistency.py
✅ Files skipped from review due to trivial changes (10)
- tests/trace/fi_trace_out/moe_fp4_block_scale_ds_routing_topk8_e32_h7168_i2048_ng8_kg4.json
- tests/trace/fi_trace_out/moe_fp8_block_scale_renormalize_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp4_block_scale_renormalize_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp4_block_scale_llama4_routing_topk1_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp8_block_scale_renormalize_naive_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp8_block_scale_topk_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.json
- tests/trace/fi_trace_out/moe_fp4_block_scale_renormalize_naive_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp4_block_scale_default_routing_topk8_e32_h7168_i2048.json
🚧 Files skipped from review as they are similar to previous changes (1)
- flashinfer/fused_moe/core.py
- Add flashinfer/trace/templates/activation.py: silu_and_mul, gelu_and_mul, gelu_tanh_and_mul (used in FFN layers of LLaMA/Mistral/GPT-style models)
- Add flashinfer/trace/templates/cascade.py: merge_state, merge_state_in_place, merge_states (cascade/speculative attention state merging)
- Extend flashinfer/trace/templates/norm.py: rmsnorm_quant, fused_add_rmsnorm_quant, gemma_rmsnorm, gemma_fused_add_rmsnorm, layernorm (additional norm variants)
- Wire @flashinfer_api(trace=...) for all 11 new templates in activation.py, cascade.py, and norm/__init__.py
- Update example.py: add activation and cascade calls, update docstring to list all 39 expected output files (33 original + 6 new)
- Add tests/trace/fi_trace_out/ to .gitignore

AI-assisted

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Add example calls in tests/trace/example.py for rmsnorm_quant, fused_add_rmsnorm_quant, gemma_rmsnorm, gemma_fused_add_rmsnorm, layernorm, and gdn_prefill. Update docstring to list all 45 expected JSON files. Add "Trace Template Checklist" section to CLAUDE.md documenting the steps for wiring trace to new APIs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Remove tests/trace/fi_trace_out/ from .gitignore so generated benchmark definition JSONs are committed alongside the code that produces them.
- Wrap mm_bf16 and mm_fp8 calls in contextlib.suppress so example.py runs end-to-end on SM90 (H100). mm_bf16 now uses backend="auto" (cudnn on SM<100, cutlass on SM100+); mm_fp8's low-latency GEMM is SM100-only at runtime but the trace still dumps before launch.
- Add newly-generated trace JSONs for the activation, cascade, norm-quant, gemma-norm, layernorm, and gdn-prefill APIs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
… state update

Three sites had @flashinfer_api on a subclass or internal helper whose parent/caller was already decorated, producing duplicate log entries at higher FLASHINFER_LOGLEVEL values. Remove the redundant decorator:

- BatchAttentionWithAttentionSinkWrapper.__init__ (parent BatchPrefillWithPagedKVCacheWrapper.__init__ already decorated)
- CUDAGraphBatchDecodeWithPagedKVCacheWrapper.__init__ (parent BatchDecodeWithPagedKVCacheWrapper.__init__ already decorated)
- trtllm_low_latency_gemm (called internally by the already-decorated mm_fp8)

Also fix _gdn_mtp_reference in flashinfer/trace/templates/gdn.py: the function was returning initial_state.clone() as final_state, silently discarding every state update accumulated across the T tokens. Now final_state is built once outside the batch loop and the [H,K,V] scratch buffer is committed back to the pool slot as [H,V,K] after each sequence. Regenerate gdn_mtp_qk4_v8_d128.json so the embedded reference matches.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
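The final-state fix described in the commit follows a common pattern: build the output buffer once outside the per-sequence loop and commit each sequence's scratch state back into it, instead of returning a copy of the initial state. Below is a hedged sketch with NumPy standing in for torch; the shapes and the identity-plus-delta "update" are illustrative, not the actual gated-delta-rule math.

```python
import numpy as np


def run_sequences(initial_state, deltas):
    """initial_state: [B, H, V, K] pool; deltas: one [H, K, V] update per sequence."""
    # Built once, outside the loop -- NOT returned as initial_state.copy()
    # at the end, which would discard every update.
    final_state = initial_state.copy()
    for b, delta in enumerate(deltas):
        # Work in [H, K, V] scratch layout...
        scratch = final_state[b].transpose(0, 2, 1) + delta
        # ...then commit back to the pool slot as [H, V, K].
        final_state[b] = scratch.transpose(0, 2, 1)
    return final_state


init = np.zeros((2, 1, 3, 4), dtype=np.float32)  # B=2, H=1, V=3, K=4
out = run_sequences(init, [np.ones((1, 4, 3)), 2 * np.ones((1, 4, 3))])
```

Each pool slot now reflects its sequence's accumulated update rather than the untouched initial state.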
bkryu
left a comment
Overall looks good to me. Left a comment about some (I believe) non-APIs that we may not want to decorate.
Also, not sure if you intended the PR to be exhaustive, but I think you missed:
- cuDNN and TRTLLM attention variants
- CUTLASS Fused MoE
- Quantization APIs
- RoPE APIs
```python
    return graph

@flashinfer_api
```

Not sure if this is a FlashInfer API that we expect users to run. Do we need the label here?
Demonstrates that @flashinfer_api(trace=...) auto-dump is compatible with torch.cuda.graph capture:

- Schema extraction reads only CPU-side metadata (.shape, .dtype) and writes JSON via host-thread file I/O — no CUDA stream ops, so nothing corrupts the captured graph even if a write fires inside the capture block.
- The _DUMPED_NAMES dedup in flashinfer/trace/template.py ensures at most one write per (process, trace name), so re-entering the decorated wrapper during capture is cheap.
- Graph replay does not execute Python, so auto-dump cannot fire on replay under any circumstance.

Example uses CUDAGraphBatchDecodeWithPagedKVCacheWrapper with Llama-3.1-8B shapes, captures wrapper.run(), replays 5×, and verifies numerical equivalence to eager. fi_trace_out_cudagraph/ is gitignored — the single JSON it produces is identical to the one committed under fi_trace_out/ for the same op.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
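The one-write-per-name dedup the commit relies on can be sketched as a simple guarded set; `_DUMPED_NAMES` is the name quoted from the commit, while `maybe_dump` and its `write` callback are illustrative.

```python
# Process-wide guard: a trace name is dumped at most once, so re-entering
# the decorated wrapper (e.g. during CUDA graph capture) is a cheap no-op.
_DUMPED_NAMES = set()


def maybe_dump(name, write):
    """Call `write(name)` the first time `name` is seen; skip afterwards."""
    if name in _DUMPED_NAMES:
        return False
    _DUMPED_NAMES.add(name)
    write(name)
    return True


calls = []
maybe_dump("decode", calls.append)  # performs the write
maybe_dump("decode", calls.append)  # deduped, no second write
```

Because the guard is keyed per process, graph replay (which runs no Python at all) and repeated captures both leave exactly one JSON per trace name.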
Remove tests/trace/fi_trace_out_cudagraph/ from .gitignore and commit the single JSON produced by example_cuda_graph.py so reviewers can inspect the schema without running the example. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…n reference matmul

### B1 — GEMM references compute A @ B instead of A @ B.T

`_mm_reference` and the three quantized helpers in flashinfer/trace/templates/gemm.py modeled `B` as physical `[K, N]` in the template inputs but then computed `A @ B.T`, which is only valid when `K == N`. This would crash for every non-square weight shape we trace (e.g. 7168→256 in example.py). Drop the `.T` in all four refs and update the three "C = A @ B.T" template descriptions.

### B2 — paged GQA refs treat kv_indices as token IDs instead of page IDs

`_gqa_paged_decode_reference` and `_gqa_paged_prefill_reference` flattened `k_cache` to `[num_tokens, ...]` and indexed with `kv_indices`, which are page IDs. The lookup only gave correct tokens when `page_size == 1`. Gather pages first, then reshape the gathered `[num_selected_pages, page_size, ...]` into a single token axis.

### B3 — MLA refs silently assumed page_size=1 via squeeze(1)

`_mla_paged_decode_reference` and `_mla_paged_prefill_reference` used `ckv_cache.squeeze(1)`, which is a no-op for page_size != 1, leaving a 3-D tensor that would break later matmuls. Apply the same page-gather fix as B2 so both page_size=1 and page_size>1 MLA work.

Regenerate the 7 affected JSON fixtures and the cuda-graph example JSON so their embedded reference strings reflect the fixes.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
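The B1 shape argument is easy to verify directly: with `B` stored as `[K, N]`, `A @ B` is the correct product, while `A @ B.T` only type-checks when `K == N`. NumPy stands in for torch; the 7168→256 sizes are the non-square example quoted in the commit.

```python
import numpy as np

M, K, N = 2, 7168, 256  # non-square weight shape from the commit message
A = np.zeros((M, K), dtype=np.float32)
B = np.zeros((K, N), dtype=np.float32)

C = A @ B  # correct: [M, K] @ [K, N] -> [M, N]

try:
    A @ B.T  # [M, K] @ [N, K] -> inner dims K != N, shape mismatch
    mismatch_raised = False
except ValueError:
    mismatch_raised = True
```

The transposed form silently "works" for square weights, which is exactly why the bug survived until a non-square trace hit it.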
… auto-dump diag

B4: _fp8_moe_run_experts in flashinfer/trace/templates/moe.py no longer reads module-level H=7168/I=2048/BLOCK=128; those are derived from hidden_states.shape and gemm1_weights.shape so the reference is valid for any MoE shape, not just DeepSeek-V3.

B5: The five fp8 MoE routing references now accept top_k (and n_group / topk_group for DeepSeek-V3) as explicit parameters instead of hardcoding TOP_K=8/N_GROUP=8/TOPK_GROUP=4. Corresponding Scalar inputs are added to each template so external consumers of the trace JSON pass the correct routing configuration.

B6: gdn_prefill_trace gains the head-ratio constraints (num_v_heads >= num_q_heads, divisibility, num_k_heads == num_q_heads) that its reference already assumes via repeat_interleave.

B7: GDN decode/prefill/MTP outputs now declare dtype="bfloat16" to match the reference (the references always emit bfloat16, so the previous dtype_from="q" was a lie when q was fp16 or fp32).

B9: scale Scalar is marked optional=True in all three GDN templates (decode/prefill/MTP). The reference already handles scale=None.

B10: Drop the "Unchanged if disable_state_update=True" phrase from gdn_mtp_trace.final_state — disable_state_update is a real kwarg on gated_delta_rule_mtp but not modelled as an input on the template, so referencing it in the description was misleading.

B8: tests/trace/test_fi_trace_template_consistency.py E2E synthesizer uses per-key positive defaults for int32 scalars (block_size=16, top_k=1, n_group=1, topk_group=1, ...) instead of 0, so synthesized definitions are semantically valid.

B11: _auto_dump_wrapper in flashinfer/api_logging.py now emits a warnings.warn() when schema binding or trace file write fails, deduped per (API name, error class). Users previously saw missing JSON files with no explanation.

Regenerate the 6 MoE JSON fixtures + GDN decode/prefill/MTP fixtures so the embedded reference strings and input schemas match.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…e helpers

Per bkryu's review on PR flashinfer-ai#2931: the four execute_cudnn_gemm_*_graph_override_shape functions in flashinfer/gemm/gemm_base.py are internal helpers called from the already-decorated mm_fp4 / mm_mxfp8 / mm_fp8 / mm_bf16 user APIs. Decorating them too causes double log entries at FLASHINFER_LOGLEVEL>=1 (same pattern fixed earlier for trtllm_low_latency_gemm and the CUDAGraph wrapper __init__).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Per bkryu's review on PR flashinfer-ai#2931: several user-facing APIs were decorated with @flashinfer_api but had no trace template attached. This commit wires trace templates to RoPE and quantization.

RoPE (flashinfer/trace/templates/rope.py, 10 new templates):
- apply_rope / apply_rope_inplace
- apply_rope_pos_ids / apply_rope_pos_ids_inplace
- apply_llama31_rope / apply_llama31_rope_inplace
- apply_llama31_rope_pos_ids / apply_llama31_rope_pos_ids_inplace
- apply_rope_with_cos_sin_cache / apply_rope_with_cos_sin_cache_inplace

Quantization (flashinfer/trace/templates/quantize.py, 4 new templates):
- fp4_quantize, nvfp4_quantize, mxfp4_quantize, mxfp8_quantize

Follow-ups (not addressed in this commit): cuDNN/TRTLLM attention variants (single_prefill/single_decode, cudnn_batch_*, trtllm_batch_*) and MoE variants (cutlass_fused_moe, trtllm_bf16_moe, etc.) still need templates.

Add example calls for RoPE and quantization in tests/trace/example.py and commit the 14 regenerated JSON fixtures.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…fine attention template descriptions

Addresses bkryu's top-level review on PR flashinfer-ai#2931 listing missing trace templates, and responds to follow-up feedback that attention descriptions were redundant.

New templates (13):
- attention.py: single_decode_with_kv_cache_trace, single_prefill_with_kv_cache_trace, trtllm_batch_decode_trace, trtllm_batch_context_trace, cudnn_batch_decode_trace, cudnn_batch_prefill_trace
- moe.py: cutlass_fused_moe_trace, trtllm_bf16_moe_trace, trtllm_bf16_routed_moe_trace, trtllm_fp8_per_tensor_scale_moe_trace, trtllm_fp8_block_scale_routed_moe_trace, trtllm_fp4_block_scale_routed_moe_trace, trtllm_mxint4_block_scale_moe_trace

Wire-ups:
- flashinfer/decode.py: single_decode_with_kv_cache, trtllm_batch_decode_with_kv_cache
- flashinfer/prefill.py: single_prefill_with_kv_cache, trtllm_batch_context_with_kv_cache
- flashinfer/cudnn/decode.py: cudnn_batch_decode_with_kv_cache
- flashinfer/cudnn/prefill.py: cudnn_batch_prefill_with_kv_cache
- flashinfer/fused_moe/core.py: 7 MoE variants

Attention description polish (flashinfer/trace/templates/attention.py): replaced verbose cross-referencing paragraphs with one- or two-sentence identifiers that state (a) the API wrapped and (b) one or two distinctive structural features. Added a module-level comparison table as the single source of truth for how templates differ. The table lists each template's batching, KV layout, indexing mechanism, stage, and backend, so consumers can pick the right template without parsing every description.

Also add per-key positive int32 defaults in the E2E synthesizer for num_experts, intermediate_size, hidden_size (in addition to the earlier block_size/top_k/n_group/topk_group defaults) and introduce _TRTLLM_MOE_ROUTED_AXES so routed-variant templates mark num_experts and intermediate_size as Var (they arrive as scalar kwargs when topk_ids is pre-computed, so the routing_logits shape can't resolve them).

Tests: 220 passed (was 139 before the whole review cycle).

Regenerate affected JSON fixtures so their embedded descriptions and schemas match.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The six trtllm_fp4_block_scale_moe_*_routing_trace templates previously
had reference=None. This commit adds executable reference functions
modelled after tests/moe/test_trtllm_gen_fused_moe.py::run_moe_dequant,
so external consumers (flashinfer-bench) can verify kernel output
against the reference.
Helpers added to flashinfer/trace/templates/moe.py:
- _unpack_fp4_e2m1: 16-entry LUT-based unpack of uint8-packed
e2m1fn FP4 values into float32 (sign + exponent + mantissa), so
the returned tensor has twice the packed last dim.
- _ue8m0_to_float32: decode UE8M0 (MX-format) scales.
- _decode_block_scales: dispatches UE8M0 vs fp8_e4m3fn based on the
scale dtype.
- _dequantize_fp4_tensor: unpack + apply per-block scales to a
packed FP4 tensor. Block size is inferred from the shape ratio so
NvFP4 (block_size=16) and MXFP4 (block_size=32) both work.
- _dequantize_fp4_hidden_states: handles the three activation
formats the kernel accepts — bfloat16, float8_e4m3fn (MXFP8) with
UE8M0 per-32 scales, and uint8-packed FP4.
Shared MoE kernel (_fp4_moe_run_experts): dequantizes weights and
hidden states, gathers per-expert tokens, does GEMM1 → SwiGLU
(silu(X2) * X1 to match trtllm-gen's convention) → GEMM2, applies
optional biases, and combines per-expert contributions weighted by
the routing weights. Emits bfloat16 output to match the template
schema.
Per-routing references (6, one per RoutingMethodType.{Default,
Renormalize, DeepSeekV3, Llama4, RenormalizeNaive, TopK}) compute
their own topk_idx + weights and call _fp4_moe_run_experts. DS
routing replicates the sigmoid → group-top2 → topk_group → top_k
path used in DeepSeek-V3.
Verified all six paths produce finite bfloat16 output of the expected
shape on NvFP4 hidden states (uint8 packed + fp8_e4m3fn scales),
MXFP8 hidden states (float8_e4m3fn + UE8M0 scales), and bf16
hidden states. Also verified the E2M1 LUT: nibble 0x7 → 6.0,
0xF → -6.0, etc.
Regenerate all six FP4 MoE JSON fixtures so they embed the new
reference source (previously absent).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…rectness tests
Before this commit: 14 templates had reference=None. Now: every template
has an executable reference, and each reference is verified numerically
against its corresponding flashinfer API in
tests/trace/test_reference_correctness.py.
Templates with new references (per file):
flashinfer/trace/templates/rope.py (10):
apply_rope, apply_rope_inplace, apply_rope_pos_ids,
apply_rope_pos_ids_inplace, apply_llama31_rope,
apply_llama31_rope_inplace, apply_llama31_rope_pos_ids,
apply_llama31_rope_pos_ids_inplace,
apply_rope_with_cos_sin_cache, apply_rope_with_cos_sin_cache_inplace
Helpers: _rope_freqs, _llama31_freqs (piecewise NTK scaling),
_rotate, _positions_from_indptr, _apply_rope_core.
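The rotate-half core those helpers build on can be sketched as follows (non-interleaved layout and default theta assumed; the actual helpers also handle indptr/pos_ids and the interleaved variant):

```python
import numpy as np

def rope_freqs(head_dim, theta=1e4):
    # One inverse frequency per dimension pair.
    return theta ** (-np.arange(0, head_dim, 2) / head_dim)

def apply_rope_rotate_half(x, positions, theta=1e4):
    """Rotate-half RoPE on x: [N, n_heads, head_dim]."""
    d = x.shape[-1]
    ang = positions[:, None] * rope_freqs(d, theta)[None, :]  # [N, d/2]
    cos = np.cos(ang)[:, None, :]
    sin = np.sin(ang)[:, None, :]
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    return np.concatenate([x1 * cos - x2 * sin,
                           x1 * sin + x2 * cos], axis=-1)
```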
flashinfer/trace/templates/norm.py (2):
rmsnorm_quant, fused_add_rmsnorm_quant (RMSNorm + per-tensor
FP8 quantize; returns fp8_e4m3fn + optional updated residual).
flashinfer/trace/templates/cascade.py (1):
merge_state_in_place (LSE-weighted merge with optional mask).
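The LSE-weighted merge is the standard numerically stable combine of two partial softmax attention states; a minimal sketch (mask handling omitted):

```python
import numpy as np

def merge_state(v_a, lse_a, v_b, lse_b):
    """Merge two partial attention states.

    v_*: [H, D] partial outputs; lse_*: [H] log-sum-exp of the softmax
    denominators. Subtracting the max keeps the exponentials stable.
    """
    m = np.maximum(lse_a, lse_b)
    wa = np.exp(lse_a - m)
    wb = np.exp(lse_b - m)
    v = (wa[:, None] * v_a + wb[:, None] * v_b) / (wa + wb)[:, None]
    lse = m + np.log(wa + wb)
    return v, lse
```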
flashinfer/trace/templates/quantize.py (4):
fp4_quantize, nvfp4_quantize, mxfp4_quantize, mxfp8_quantize.
E2M1 nearest-magnitude rounding, UE8M0 vs fp8_e4m3fn scale
decoding, NvFP4 block_size=16 vs MXFP4/MXFP8 block_size=32.
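Nearest-magnitude E2M1 rounding can be sketched against the 8-value magnitude grid (tie-breaking behavior is an assumption; the kernel may round ties differently):

```python
import numpy as np

_E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def quantize_e2m1(x):
    """Round |x| to the nearest representable E2M1 magnitude, keep sign."""
    idx = np.abs(np.abs(x)[..., None] - _E2M1_GRID).argmin(-1)
    return np.sign(x) * _E2M1_GRID[idx]
```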
flashinfer/trace/templates/attention.py (6):
single_decode, single_prefill (contiguous KV SDPA with causal),
trtllm_batch_decode, trtllm_batch_context (rectangular block_tables
+ interleaved kv_cache + bmm1/bmm2 scales),
cudnn_batch_decode, cudnn_batch_prefill (separate k/v caches,
actual_seq_lens_q/kv, optional LSE return).
Helpers: _trtllm_kv_from_cache,
_trtllm_paged_attention_reference.
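A sketch of gathering contiguous K/V from a paged cache via a block table; the [num_blocks, 2, H, page_size, D] layout with K at index 0 is an assumption and may not match the actual trtllm cache:

```python
import numpy as np

def kv_from_paged_cache(kv_cache, block_table, seq_len):
    """Gather one sequence's contiguous K/V from a paged cache.

    kv_cache: [num_blocks, 2, H, page_size, D];
    block_table: [max_blocks] physical block ids for this sequence.
    """
    page_size = kv_cache.shape[3]
    n_blocks = -(-seq_len // page_size)            # ceil division
    blocks = kv_cache[block_table[:n_blocks]]      # [n, 2, H, P, D]
    kv = blocks.transpose(1, 2, 0, 3, 4)           # [2, H, n, P, D]
    kv = kv.reshape(2, kv.shape[1], -1, kv.shape[-1])[:, :, :seq_len]
    return kv[0], kv[1]                            # K, V: [H, seq_len, D]
```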
flashinfer/trace/templates/moe.py (7):
cutlass_fused_moe (precomputed expert ids + scales),
trtllm_bf16_moe, trtllm_bf16_routed_moe (un-quantized),
trtllm_fp8_per_tensor_scale_moe (per-expert scalar scales),
trtllm_fp8_block_scale_routed_moe,
trtllm_fp4_block_scale_routed_moe (reuses _fp8_moe_run_experts
/ _fp4_moe_run_experts),
trtllm_mxint4_block_scale_moe (int4 unpack + bf16 scales).
Correctness tests (tests/trace/test_reference_correctness.py):
18 numerical tests compare reference output to the live flashinfer
API on the same inputs, within per-dtype tolerances:
- RoPE (10): bf16 output within 5e-2 of kernel (1 bf16 ULP)
- rmsnorm_quant, fused_add_rmsnorm_quant: residual exact; fp8
output compared after multiplying by scale
- merge_state_in_place: bf16/float32 within 5e-3
- mxfp8_quantize, fp4_quantize round-trip: within 50% relative
error (FP4 has inherent quantization error)
- single_decode, single_prefill (causal): within 5e-2
5 tests are marked skipped with clear reasons (cuDNN/TRT-LLM
kernels require specific runtime/hardware; those references are
covered by the shape-and-finite smoke test
test_moe_references_produce_valid_outputs).
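The fp8/fp4 comparisons above dequantize before comparing; the pattern looks roughly like this (helper name and tolerances are illustrative):

```python
import numpy as np

def assert_quant_close(q_ref, scale_ref, q_out, scale_out, rtol=0.5):
    """Compare quantized outputs in dequantized space: multiply each
    code by its scale first, then apply a loose relative tolerance
    (low-bit formats carry inherent quantization error)."""
    a = q_ref.astype(np.float32) * scale_ref
    b = q_out.astype(np.float32) * scale_out
    np.testing.assert_allclose(a, b, rtol=rtol, atol=1e-6)
```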
Also regenerate every trace JSON under tests/trace/fi_trace_out/ so
the new reference source strings are embedded in the committed
fixtures.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
basedpyright flagged ~45 "parameter not accessed" hints in the new MoE
references. The unused params are intentional — the references accept the
full API signature so external consumers can call them with the same
kwargs they'd pass to the corresponding flashinfer API.

Add explicit ``del`` statements at the top of each reference to document
that the params are accepted for API parity but unused in the reference
computation, silencing the hints.

Affects the 7 references added in the previous commit:
_cutlass_fused_moe_reference, _trtllm_bf16_moe_reference,
_trtllm_bf16_routed_moe_reference,
_trtllm_fp8_per_tensor_scale_moe_reference,
_trtllm_fp8_block_scale_routed_moe_reference,
_trtllm_fp4_block_scale_routed_moe_reference,
_trtllm_mxint4_block_scale_moe_reference

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…ignatures
Previous commit added ``del unused_a, unused_b, ...`` at the top of each of
the 7 new MoE references to silence basedpyright's
``reportUnusedParameter`` hints. That was noisy boilerplate.
The more Pythonic fix is to (a) drop parameters that neither the template
inputs schema nor the reference body reference, and (b) rename the
catch-all ``**kwargs`` to ``**_unused`` — the ``_`` prefix is the standard
convention that tells linters "intentionally unused." External callers can
still pass any extra API kwargs by keyword; they land in ``**_unused`` and
are silently discarded.
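The before/after pattern, with illustrative names rather than the actual reference signatures:

```python
# Before: noisy explicit deletes of every unused parameter.
def reference_before(x, scale, n_group, topk_group, intermediate_size, **kwargs):
    del n_group, topk_group, intermediate_size, kwargs  # silence linter
    return x * scale

# After: drop the unused params entirely; extra API kwargs are swallowed
# by **_unused, whose underscore prefix linters treat as intentional.
def reference_after(x, scale, **_unused):
    return x * scale
```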
Net effect per reference:
_cutlass_fused_moe_reference:
drop output_dtype, quant_scales (kept via **_unused)
_trtllm_bf16_moe_reference:
drop n_group, topk_group, intermediate_size, local_num_experts,
routing_method_type (kept via **_unused)
_trtllm_bf16_routed_moe_reference:
drop n_group, topk_group, intermediate_size, local_num_experts
_trtllm_fp8_per_tensor_scale_moe_reference:
drop n_group, topk_group, intermediate_size, local_num_experts,
routing_method_type
_trtllm_fp8_block_scale_routed_moe_reference:
drop routing_bias, n_group, topk_group, intermediate_size,
local_num_experts
_trtllm_fp4_block_scale_routed_moe_reference:
drop routing_bias, gemm1_alpha/beta/clamp_limit, output1_scale_scalar,
output1_scale_gate_scalar, output2_scale_scalar, n_group, topk_group,
intermediate_size, local_num_experts
_trtllm_mxint4_block_scale_moe_reference:
drop gemm1_alpha/beta/clamp_limit, n_group, topk_group,
intermediate_size, local_num_experts, routing_method_type
Net diff: +51 / -75 — shorter, self-documenting signatures with no ``del``
boilerplate, and basedpyright is quiet. Test
``test_moe_references_produce_valid_outputs`` updated to call the MxInt4
reference with all-kwargs so it doesn't rely on positional ordering.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
📌 Description
Adds a trace layer to FlashInfer so every public kernel can be described as a portable benchmark / replay definition without tying the description to any particular launcher.
- `flashinfer/trace/template.py` — new `TraceTemplate` schema with named `axes` (Var/Const), typed `inputs`/`outputs` (Tensor/Scalar), optional `reference` implementation, and tag/constraint metadata.
- `flashinfer/trace/templates/*.py` — one module per operator family (attention, cascade, GDN, GEMM, MoE, norm, activation, sampling). Each file declares the schema and, where feasible, an executable reference.
- `@flashinfer_api(trace=...)` (extends the existing decorator in `flashinfer/api_logging.py`) — attaches `.fi_trace()` to the decorated function/method and, when `FLASHINFER_TRACE_DUMP=1`, writes a per-shape JSON definition to `FLASHINFER_TRACE_DUMP_DIR` before the kernel runs (crash-safe).
- `fi_trace()` helpers — public entry points for programmatic trace generation from any `@flashinfer_api`-decorated API or a bound method.
- `tests/trace/` — template-consistency tests (signature ↔ axes/inputs), end-to-end reference checks, and an `example.py` that drives a realistic workload and dumps 45 `tests/trace/fi_trace_out/*.json` definitions across LLaMA-3.1, DeepSeek-V3, Gemma, Qwen3-Next, etc.

Why
- External consumers (`flashinfer-bench`) consume a single self-describing JSON per op instead of reverse-engineering Python call sites.
- The decorator is a no-op with `FLASHINFER_LOGLEVEL=0` / `FLASHINFER_TRACE_DUMP` unset.

Covered APIs
Attention (paged/ragged prefill, paged decode, MLA), sampling (top-k/top-p/top-k-top-p), GEMM (bf16, fp8, mxfp8, fp4), fused MoE (fp8/fp4 block-scale × 6 routing methods), norm (rmsnorm, fused-add, quant variants, Gemma variants, layernorm), activation (silu/gelu/gelu-tanh + mul), cascade merge (state/state-in-place/states), and GDN (decode, MTP, chunk prefill).
🔍 Related Issues
🚀 Pull Request Checklist
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit`.
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests added (`tests/trace/test_fi_trace.py`, `tests/trace/test_fi_trace_template_consistency.py`).
- All tests pass (`pytest tests/trace/ -v` → 139 passed).

Reviewer Notes
- `@flashinfer_api(trace=...)` must be innermost so the trace dump runs even when a surrounding `@backend_requirement` raises for unsupported capability. For `mm_fp4`/`mm_mxfp8` on SM<100 the outer `@backend_requirement` raises before the dump, which is why their JSONs are only regenerated on Blackwell. See `tests/trace/example.py` for the realistic workload.
- Removed redundant `@flashinfer_api` decorators that caused double-logging at `FLASHINFER_LOGLEVEL=3+` (subclass `__init__` overrides and the `trtllm_low_latency_gemm` internal helper).
- Known gaps: `H`/`I`/`top_k`/`n_group`; `gdn_prefill_trace` lacks the head-ratio constraints that `gdn_decode`/`gdn_mtp` already have; the E2E test synthesizer uses `0` for `int32` inputs, which makes some synthesized definitions nonsensical.

🤖 Generated with Claude Code