From cb9fb8995319ade78435666b196f22314c1c8e69 Mon Sep 17 00:00:00 2001 From: Avery Huang Date: Tue, 31 Mar 2026 23:16:24 +0000 Subject: [PATCH 01/38] add missing flashinfer_api --- flashinfer/attention.py | 1 + flashinfer/decode.py | 1 + flashinfer/gemm/gemm_base.py | 5 +++++ flashinfer/trtllm_low_latency_gemm.py | 1 + 4 files changed, 8 insertions(+) diff --git a/flashinfer/attention.py b/flashinfer/attention.py index c4bc4f27dc..f5d4bd84ff 100644 --- a/flashinfer/attention.py +++ b/flashinfer/attention.py @@ -209,6 +209,7 @@ class BatchAttentionWithAttentionSinkWrapper(BatchPrefillWithPagedKVCacheWrapper a convenient interface for using attention sinks during prefill or decode attention. """ + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 822aca407c..3cad0aa954 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1577,6 +1577,7 @@ class CUDAGraphBatchDecodeWithPagedKVCacheWrapper(BatchDecodeWithPagedKVCacheWra :class:`BatchDecodeWithPagedKVCacheWrapper` """ + @flashinfer_api def __init__( self, workspace_buffer: torch.Tensor, diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 57548c780c..842caabb86 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -1435,6 +1435,7 @@ class SegmentGEMMWrapper: True """ + @flashinfer_api def __init__( self, float_workspace_buffer: torch.Tensor, backend: str = "auto" ) -> None: @@ -2082,6 +2083,7 @@ def build_cudnn_gemm_fp4_graph_override_shape( return graph +@flashinfer_api def execute_cudnn_gemm_fp4_graph_override_shape( graph, a, @@ -2317,6 +2319,7 @@ def build_cudnn_gemm_mxfp8_graph_override_shape( return graph +@flashinfer_api def execute_cudnn_gemm_mxfp8_graph_override_shape( graph, a, @@ -2563,6 +2566,7 @@ def build_cudnn_gemm_with_per_tensor_q_graph_override_shape( return graph +@flashinfer_api def execute_cudnn_gemm_with_per_tensor_q_graph_override_shape( graph, a, b, a_scale, b_scale, c_final, workspace, tactic: int = 0 ): @@ -2891,6 +2895,7 @@ def build_cudnn_gemm_bf16_graph_override_shape( return graph +@flashinfer_api def execute_cudnn_gemm_bf16_graph_override_shape( graph, a, b, bias, c_final, workspace, tactic: int = 0 ): diff --git a/flashinfer/trtllm_low_latency_gemm.py b/flashinfer/trtllm_low_latency_gemm.py index 3aea77affb..faf1dd1103 100644 --- a/flashinfer/trtllm_low_latency_gemm.py +++ b/flashinfer/trtllm_low_latency_gemm.py @@ -116,6 +116,7 @@ def gemm_runner(): ) +@flashinfer_api def trtllm_low_latency_gemm( A: torch.Tensor, B: torch.Tensor, From 0dca5ffdb867aa0c040f613a40a7f509c246fe08 Mon Sep 17 00:00:00 2001 From: Avery Huang Date: Fri, 3 Apr 2026 17:52:42 +0000 Subject: [PATCH 02/38] init --- flashinfer/__init__.py | 1 + flashinfer/api_logging.py | 112 ++- flashinfer/decode.py | 3 +- flashinfer/fi_trace.py | 281 +++++++ flashinfer/fused_moe/core.py | 3 +- flashinfer/gdn_decode.py | 16 +- flashinfer/gdn_prefill.py | 3 +- flashinfer/gemm/gemm_base.py | 14 +- flashinfer/mla/_core.py | 3 +- flashinfer/norm/__init__.py | 5 +- flashinfer/prefill.py | 5 +- flashinfer/sampling.py | 11 +- flashinfer/trace/__init__.py | 25 + flashinfer/trace/example/__main__.py | 1 + flashinfer/trace/example/example.py | 294 ++++++++ .../fi_trace_out/fused_add_rmsnorm_h5120.json | 59 ++ .../fi_trace_out/gdn_decode_qk4_v8_d128.json | 149 ++++ .../fi_trace_out/gdn_mtp_qk4_v8_d128.json | 171 +++++ .../fi_trace_out/gemm_bf16_N256_K7168.json | 49 ++ .../fi_trace_out/gemm_bf16_N4096_K4096.json | 49 ++ .../gemm_fp4_N2048_K7168_block_size16.json | 72 ++ .../fi_trace_out/gemm_fp8_N1536_K7168.json | 51 ++ .../fi_trace_out/gemm_mxfp8_N4096_K4096.json | 67 ++ .../gqa_paged_decode_h32_kv8_d128_ps16.json | 113 +++ .../gqa_paged_decode_h32_kv8_d128_ps64.json | 113 +++ .../gqa_paged_prefill_h32_kv8_d128_ps16.json | 120 +++ .../fi_trace_out/gqa_ragged_h32_kv8_d128.json | 105 +++ ...mla_paged_decode_h16_ckv512_kpe64_ps1.json | 124 ++++ ...la_paged_decode_h16_ckv512_kpe64_ps64.json | 124 ++++ ...default_routing_topk8_e32_h7168_i2048.json | 152 ++++ .../example/fi_trace_out/rmsnorm_h4096.json | 43 ++ .../example/fi_trace_out/rmsnorm_h7168.json | 43 ++ .../fi_trace_out/top_k_sampling_v128256.json | 47 ++ .../top_k_top_p_sampling_v128256.json | 54 ++ .../top_k_top_p_sampling_v151936.json | 54 ++ .../fi_trace_out/top_p_sampling_v128256.json | 47 ++ .../fi_trace_out/top_p_sampling_v151936.json | 47 ++ flashinfer/trace/template.py | 515 +++++++++++++ flashinfer/trace/templates/__init__.py | 80 ++ flashinfer/trace/templates/attention.py | 701 ++++++++++++++++++ flashinfer/trace/templates/gdn.py | 500 +++++++++++++ flashinfer/trace/templates/gemm.py | 216 ++++++ flashinfer/trace/templates/moe.py | 591 +++++++++++++++ flashinfer/trace/templates/norm.py | 89 +++ flashinfer/trace/templates/sampling.py | 210 ++++++ tests/test_fi_trace.py | 581 +++++++++++++++ 46 files changed, 6089 insertions(+), 24 deletions(-) create mode 100644 flashinfer/fi_trace.py create mode 100644 flashinfer/trace/__init__.py create mode 100644 flashinfer/trace/example/__main__.py create mode 100644 flashinfer/trace/example/example.py create mode 100644 flashinfer/trace/example/fi_trace_out/fused_add_rmsnorm_h5120.json create mode 100644 flashinfer/trace/example/fi_trace_out/gdn_decode_qk4_v8_d128.json create mode 100644 flashinfer/trace/example/fi_trace_out/gdn_mtp_qk4_v8_d128.json create mode 100644 flashinfer/trace/example/fi_trace_out/gemm_bf16_N256_K7168.json create mode 100644 flashinfer/trace/example/fi_trace_out/gemm_bf16_N4096_K4096.json create mode 100644 flashinfer/trace/example/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.json create mode 100644 flashinfer/trace/example/fi_trace_out/gemm_fp8_N1536_K7168.json create mode 100644 flashinfer/trace/example/fi_trace_out/gemm_mxfp8_N4096_K4096.json create mode 100644 flashinfer/trace/example/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.json create mode 100644 flashinfer/trace/example/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.json create mode 100644 flashinfer/trace/example/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.json create mode 100644 flashinfer/trace/example/fi_trace_out/gqa_ragged_h32_kv8_d128.json create mode 100644 flashinfer/trace/example/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.json create mode 100644 flashinfer/trace/example/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json create mode 100644 flashinfer/trace/example/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.json create mode 100644 flashinfer/trace/example/fi_trace_out/rmsnorm_h4096.json create mode 100644 flashinfer/trace/example/fi_trace_out/rmsnorm_h7168.json create mode 100644 flashinfer/trace/example/fi_trace_out/top_k_sampling_v128256.json create mode 100644 flashinfer/trace/example/fi_trace_out/top_k_top_p_sampling_v128256.json create mode 100644 flashinfer/trace/example/fi_trace_out/top_k_top_p_sampling_v151936.json create mode 100644 flashinfer/trace/example/fi_trace_out/top_p_sampling_v128256.json create mode 100644 flashinfer/trace/example/fi_trace_out/top_p_sampling_v151936.json create mode 100644 flashinfer/trace/template.py create mode 100644 flashinfer/trace/templates/__init__.py create mode 100644 flashinfer/trace/templates/attention.py create mode 100644 flashinfer/trace/templates/gdn.py create mode 100644 flashinfer/trace/templates/gemm.py create mode 100644 flashinfer/trace/templates/moe.py create mode 100644 flashinfer/trace/templates/norm.py create mode 100644 flashinfer/trace/templates/sampling.py create mode 100644 tests/test_fi_trace.py diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index 8ced5c509a..d07b480bc6 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -184,3 +184,4 @@ from .xqa import xqa as xqa from .xqa import xqa_mla as xqa_mla from . import mamba as mamba +from .fi_trace import fi_trace as fi_trace diff --git a/flashinfer/api_logging.py b/flashinfer/api_logging.py index e88bd7d3cf..a32b9d8e22 100644 --- a/flashinfer/api_logging.py +++ b/flashinfer/api_logging.py @@ -1417,7 +1417,108 @@ def _log_function_outputs(func_name: str, result: Any, level: int) -> None: _logger.debug("\n".join(lines)) -def flashinfer_api(func: Callable = None) -> Callable: +def _attach_fi_trace( + wrapped: Callable, + original: Callable, + trace_template=None, +) -> Callable: + """Attach a ``fi_trace`` callable to *wrapped*. + + Three resolution strategies, tried in order: + + 1. **Dispatch callable** (new interface): if *trace_template* is a + plain callable (not a ``TraceTemplate``), it is called at trace time + with the bound kwargs and must return the appropriate + :class:`~flashinfer.trace.TraceTemplate` for that invocation. Use + this when a single API function needs different templates depending on + a runtime parameter (e.g. ``routing_method_type``). + 2. **Explicit template** (new interface): if *trace_template* is a + :class:`~flashinfer.trace.TraceTemplate`, use it directly. + 3. **Registry lookup** (legacy interface): look up the qualname of + *original* in the old ``_REGISTRY`` dict in ``flashinfer.fi_trace``. + + When ``FLASHINFER_TRACE_DUMP=1`` is set and a template is provided, the + returned callable also auto-dumps a trace JSON on every invocation + (deduplication: same-named files are written only once per process). + + The attachment is a no-op when neither strategy finds a spec. + """ + try: + if trace_template is not None: + from flashinfer.trace.template import ( # noqa: PLC0415 + TraceTemplate, + _is_trace_dump_enabled, + ) + + # New interface: derive fi_api from the function's module + qualname. + module = getattr(original, "__module__", "") or "" + qualname = getattr(original, "__qualname__", "") or "" + fi_api = f"{module}.{qualname}" if module else qualname + + if isinstance(trace_template, TraceTemplate): + # Static template: pre-build the fi_trace callable once. + fi_trace_fn = trace_template.build_fi_trace_fn(fi_api) + else: + # Dispatch callable: *trace_template* is a function + # ``(save_dir=None, name=None, **kwargs) -> TraceTemplate``. + # Resolve the template at call time and cache per template + # instance to avoid rebuilding extractors on every call. + _dispatch_fn = trace_template + _fi_trace_cache: Dict[int, Callable] = {} + + def fi_trace_fn( + save_dir=None, + name=None, + **kwargs: Any, + ) -> Dict[str, Any]: + tpl = _dispatch_fn(**kwargs) + if tpl is None: + return {} + tpl_id = id(tpl) + if tpl_id not in _fi_trace_cache: + _fi_trace_cache[tpl_id] = tpl.build_fi_trace_fn(fi_api) + return _fi_trace_cache[tpl_id]( + save_dir=save_dir, name=name, **kwargs + ) + + wrapped.fi_trace = fi_trace_fn + + # Auto-dump wrapper: checked lazily at call time so that callers + # can set FLASHINFER_TRACE_DUMP after importing flashinfer (e.g. + # when running via ``python -m``). + _inner = wrapped + _sig = inspect.signature(original) + + @functools.wraps(_inner) + def _auto_dump_wrapper(*args, **kwargs): + # Generate trace BEFORE the actual call (crash-safe: schema + # depends only on input shapes/dtypes, not on whether the + # computation succeeds). + if _is_trace_dump_enabled(): + try: + bound = _sig.bind(*args, **kwargs) + bound.apply_defaults() + fi_trace_fn(**dict(bound.arguments)) + except Exception: + pass + return _inner(*args, **kwargs) + + _auto_dump_wrapper.fi_trace = fi_trace_fn + return _auto_dump_wrapper + else: + # Legacy registry lookup (kept for backwards compatibility). + from flashinfer.fi_trace import _REGISTRY, build_fi_trace_fn # noqa: PLC0415 + + qualname = getattr(original, "__qualname__", "") + spec = _REGISTRY.get(qualname) + if spec is not None: + wrapped.fi_trace = build_fi_trace_fn(spec) + except Exception: + pass + return wrapped + + +def flashinfer_api(func: Callable = None, *, trace=None) -> Callable: """ Decorator to FlashInfer's APIs. @@ -1489,11 +1590,12 @@ def flashinfer_api(func: Callable = None) -> Callable: - The %i pattern is automatically replaced with the process ID for multi-process environments. - The logger does not propagate to the root logger to avoid duplicate logs. """ - # If logging is disabled, return original function with zero overhead + # If logging is disabled, return original function with zero overhead. + # We still attach fi_trace so it is always available regardless of log level. if _API_LOG_LEVEL == 0: if func is None: - return lambda f: f - return func + return lambda f: _attach_fi_trace(f, f, trace_template=trace) + return _attach_fi_trace(func, func, trace_template=trace) def decorator(f: Callable) -> Callable: @functools.wraps(f) @@ -1561,7 +1663,7 @@ def wrapper(*args, **kwargs): return result - return wrapper + return _attach_fi_trace(wrapper, f, trace_template=trace) if func is None: return decorator diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 3cad0aa954..036e49d753 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -22,6 +22,7 @@ import torch from .api_logging import flashinfer_api +from .trace.templates.attention import gqa_paged_decode_trace ## NOTE: MLA functions have been moved to mla.py, but we keep the aliases here for backward compatibility. from .mla import ( @@ -1215,7 +1216,7 @@ def run( kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... - @flashinfer_api + @flashinfer_api(trace=gqa_paged_decode_trace) def run( self, q: torch.Tensor, diff --git a/flashinfer/fi_trace.py b/flashinfer/fi_trace.py new file mode 100644 index 0000000000..727f218df9 --- /dev/null +++ b/flashinfer/fi_trace.py @@ -0,0 +1,281 @@ +# Copyright (c) 2025 by FlashInfer team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +fi_trace: Generate `flashinfer-bench `_ +compatible definition JSON for FlashInfer APIs. + +Every ``@flashinfer_api(trace=