diff --git a/.claude/skills/add-cuda-kernel/SKILL.md b/.claude/skills/add-cuda-kernel/SKILL.md index ee8c74da22..8da3c7d7f2 100644 --- a/.claude/skills/add-cuda-kernel/SKILL.md +++ b/.claude/skills/add-cuda-kernel/SKILL.md @@ -625,7 +625,155 @@ Check functions must: 3. Raise `ValueError` with descriptive message if validation fails 4. Be decorated with `@supported_compute_capability` to specify supported architectures -## Step 6: Write Tests in `tests/` +## Step 6: Add a Trace Template + +Every new kernel **must** have a `TraceTemplate` so that flashinfer-bench can auto-generate +benchmark definition files via `@flashinfer_api(trace=...)`. + +### 6a. Create the template in `flashinfer/trace/templates/` + +Add a file (or extend an existing one) in `flashinfer/trace/templates/`. The +real `flashinfer/trace/templates/norm.py` is a good reference — it shows two +variants that share an `op_type` but have distinct `name_prefix` values: + +```python +# flashinfer/trace/templates/norm.py (real file, simplified for illustration) +from ..template import Const, Tensor, TraceTemplate, Var + +# op_type – high-level operation category written to the JSON "op_type" field. +# Two templates can share the same op_type when they are variants of +# the same operation family. +# name_prefix – base string for the auto-generated filename and JSON "name" field. +# Const axis values are appended, e.g. rmsnorm_h4096.json. +# Must be unique across templates that share an op_type. + +rmsnorm_trace = TraceTemplate( + op_type="rmsnorm", # category: all RMSNorm variants share this + name_prefix="rmsnorm", # specific variant → file: rmsnorm_h.json + description="Root Mean Square Normalization. Epsilon is fixed at 1e-6.", + axes={ + "batch_size": Var(), # runtime-variable: omitted from filename + "hidden_size": Const(abbrev="h"), # baked into filename as "h" + }, + inputs={ + # json_key "hidden_states" differs from the Python param name "input", + # so param= is set explicitly. + "hidden_states": Tensor(["batch_size", "hidden_size"], param="input"), + "weight": Tensor(["hidden_size"]), # key == param, no param= needed + }, + outputs={ + "output": Tensor(["batch_size", "hidden_size"], dtype_from="input"), + }, + tags=["status:verified"], + reference=_rmsnorm_reference, +) + +fused_add_rmsnorm_trace = TraceTemplate( + op_type="rmsnorm", # same category as rmsnorm_trace above + name_prefix="fused_add_rmsnorm", # different variant → fused_add_rmsnorm_h.json + description="Fused Add + RMSNorm. Epsilon is fixed at 1e-6.", + axes={ + "batch_size": Var(), + "hidden_size": Const(abbrev="h"), + }, + inputs={ + "hidden_states": Tensor(["batch_size", "hidden_size"], param="input"), + "residual": Tensor(["batch_size", "hidden_size"]), + "weight": Tensor(["hidden_size"]), + }, + outputs={ + "output": Tensor(["batch_size", "hidden_size"], dtype_from="input"), + "residual": Tensor( + ["batch_size", "hidden_size"], + dtype_from="input", + description="Updated residual (in-place: residual += hidden_states).", + ), + }, + tags=["status:verified", "fused"], + reference=_fused_add_rmsnorm_reference, +) +``` + +Key rules: +- `Var()` → value is NOT baked into the generated name or JSON `value`. +- `Const(abbrev=...)` → value IS extracted and written to JSON. `abbrev="h"` → `h4096`; `abbrev=""` → omit from filename. +- Each `Tensor` key defaults to `param=key`; use `param="other_name"` when they differ. +- `dtype_from=""` copies the dtype from that input tensor (use the JSON key, not the param name). +- For dispatch (one function, multiple templates depending on a kwarg), pass a + plain callable as `trace=`: + ```python + def _my_trace_dispatch(**kwargs): + if kwargs.get("mode") == "fast": + return fast_trace + return slow_trace + + @flashinfer_api(trace=_my_trace_dispatch) + def my_op(..., mode="fast"): + ... + ``` + See `flashinfer/fused_moe/core.py` + `flashinfer/trace/templates/moe.py` for a + real dispatch example keyed on `routing_method_type`. + +### 6b. Attach the template to the API + +```python +# flashinfer/norm.py (real file) +from .trace.templates.norm import rmsnorm_trace + +@flashinfer_api(trace=rmsnorm_trace) +def rmsnorm(input: torch.Tensor, weight: torch.Tensor, ...) -> torch.Tensor: + ... +``` + +The `fi_api` tag is derived automatically from `func.__module__ + "." + func.__qualname__`. + +### 6c. Register your module for auto-discovery + +Open `tests/trace/test_fi_trace_template_consistency.py` and add your module to +the import list inside `_collect_template_func_pairs()`: + +```python +import flashinfer.norm # ← add your module here +``` + +That's it. `@flashinfer_api(trace=...)` automatically registers every +`(func, template)` pair in `flashinfer.api_logging._TRACE_REGISTRY` at +decoration time. Importing the module triggers the decorator, and the +parameterized tests then check: + +1. **Signature consistency**: every non-optional `param=` reference exists in the function's signature. +2. **Axis coverage**: every `Const` axis appears in at least one tensor's `dim_names` or the function's parameter list. +3. **End-to-end**: `fi_trace` with auto-generated CPU tensors returns a complete dict + (no `"unknown"` dtypes for non-optional inputs, all `Const` axes have values). + +If your template uses tuple inputs or exotic dtypes (fp8 scale tensors, etc.), +add a targeted end-to-end test at the bottom of the file and add your label to +`_E2E_SKIP` (see the MoE example there). + +For **dispatch templates** (callable `trace=`), also set a `.templates` +attribute on the dispatch function listing all possible return values: + +```python +def _my_trace_dispatch(**kwargs): ... +_my_trace_dispatch.templates = [fast_trace, slow_trace] +``` + +This lets the registry auto-discover and check all variants. + +### 6d. Run the consistency tests + +```bash +pytest tests/trace/test_fi_trace_template_consistency.py -v +``` + +A failing structural test looks like: +``` +AssertionError: [rmsnorm] Template 'rmsnorm' has param mismatches: + Input 'hidden_states' → param='x' not found in rmsnorm(['input', 'weight', 'eps']) +``` +which tells you exactly which key is wrong and what names are available. + +## Step 7: Write Tests in `tests/` Create tests in an appropriate subdirectory (e.g., `tests/elementwise/test_scale.py` or create a new subdir if needed): @@ -794,13 +942,15 @@ if __name__ == "__main__": ## Summary of Files Created/Modified ``` -include/flashinfer/scale.cuh # NEW: CUDA kernel definition -csrc/scale.cu # NEW: PyTorch launcher -csrc/scale_jit_binding.cu # NEW: TVM-FFI binding -flashinfer/jit/scale.py # NEW: JIT generator -flashinfer/scale.py # NEW: Python API -flashinfer/__init__.py # MODIFIED: Export API -flashinfer/aot.py # MODIFIED: Register AOT -tests/test_scale.py # NEW: Unit tests -benchmarks/bench_scale.py # NEW: Benchmark script +include/flashinfer/scale.cuh # NEW: CUDA kernel definition +csrc/scale.cu # NEW: PyTorch launcher +csrc/scale_jit_binding.cu # NEW: TVM-FFI binding +flashinfer/jit/scale.py # NEW: JIT generator +flashinfer/scale.py # NEW: Python API (with @flashinfer_api(trace=...)) +flashinfer/trace/templates/scale.py # NEW: TraceTemplate definition +flashinfer/__init__.py # MODIFIED: Export API +flashinfer/aot.py # MODIFIED: Register AOT +tests/test_scale.py # NEW: Kernel unit tests +tests/trace/test_fi_trace_template_consistency.py # MODIFIED: Add (func, template) pair +benchmarks/bench_scale.py # NEW: Benchmark script ``` diff --git a/CLAUDE.md b/CLAUDE.md index bbd055286a..e74821b306 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -344,6 +344,20 @@ flashinfer/ 7. Write tests in `tests/` 8. Register in `flashinfer/aot.py` for AOT compilation 9. Export in `flashinfer/__init__.py` +10. Add a `TraceTemplate` in `flashinfer/trace/templates/` and wire it via `@flashinfer_api(trace=...)` (see below) +11. Add an example call in `tests/trace/example.py`, re-run to regenerate `fi_trace_out/`, and commit the new JSON files + +### Trace Template Checklist (for new or updated APIs) + +Every public API decorated with `@flashinfer_api` should also carry a `trace=` argument so that `fi_trace()` works and auto-dump produces a benchmark definition JSON. + +1. **Create or update a `TraceTemplate`** in `flashinfer/trace/templates/.py` (e.g., `norm.py`, `activation.py`, `cascade.py`, `gdn.py`). Define `axes`, `inputs`, `outputs`, and optionally a `reference` function. +2. **Wire the template** to the API: `@flashinfer_api(trace=my_trace)` on the Python function (or class method's `run()`). +3. **Add an example call** in `tests/trace/example.py` that exercises the new trace with realistic shapes. +4. **Regenerate examples**: `rm -rf tests/trace/fi_trace_out && python tests/trace/example.py` — verify the expected JSON appears. +5. **Update the docstring** in `tests/trace/example.py` to list the new file(s). +6. **Run tests**: `pytest tests/trace/ -v` — all template-consistency and end-to-end tests must pass. +7. **Commit the new JSON files** under `tests/trace/fi_trace_out/` alongside the code changes. **Example implementations:** - **Simple**: `flashinfer/norm.py` (RMSNorm) - no Jinja, good starting point diff --git a/docs/fi_trace.rst b/docs/fi_trace.rst new file mode 100644 index 0000000000..4002283ada --- /dev/null +++ b/docs/fi_trace.rst @@ -0,0 +1,321 @@ +.. _fi_trace: + +fi_trace — Operation Schema Extraction +======================================= + +``fi_trace`` is FlashInfer's operation schema extraction system. Every +``@flashinfer_api``-decorated function automatically grows a ``.fi_trace()`` +method that captures the *shape*, *dtype*, and *axis structure* of a call as a +portable JSON file — without running the GPU kernel. + +These JSON files are the input format for `flashinfer-bench +`_, the companion benchmark +toolkit. Collecting them while running your production workload gives you a +precise benchmark suite that reflects your actual model and serving scenario. + +Quick Start +----------- + +Set two environment variables **before** importing FlashInfer: + +.. code-block:: bash + + export FLASHINFER_TRACE_DUMP=1 + export FLASHINFER_TRACE_DUMP_DIR=./fi_trace_out # default: ./fi_trace_out + + python my_inference_script.py + +FlashInfer writes one ``.json`` file per unique (op, shape) combination. +Subsequent calls with the same shapes are deduplicated — no duplicate files. + +.. code-block:: text + + fi_trace_out/ + ├── rmsnorm_h7168.json + ├── gqa_paged_decode_h32_kv8_d128_ps16.json + ├── moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048.json + └── ... + +Environment Variables +--------------------- + +.. list-table:: + :header-rows: 1 + :widths: 35 12 20 33 + + * - Variable + - Type + - Default + - Description + * - ``FLASHINFER_TRACE_DUMP`` + - int + - ``0`` + - Set to ``1`` to enable automatic JSON dumping on every API call. + * - ``FLASHINFER_TRACE_DUMP_DIR`` + - str + - ``./fi_trace_out`` + - Directory where JSON files are written. + +Both variables are read **lazily at call time**, so they can be set after +``import flashinfer`` (e.g. when using ``python -m``). + +JSON File Format +---------------- + +Each file describes one operation instance. Here is an annotated example for +``rmsnorm`` with ``hidden_size=7168``: + +.. code-block:: json + + { + "name": "rmsnorm_h7168", + "description": "Root Mean Square Normalization. Epsilon is fixed at 1e-6.", + "op_type": "rmsnorm", + "tags": [ + "fi_api:flashinfer.norm.rmsnorm", + "status:verified" + ], + "axes": { + "batch_size": { "type": "var" }, + "hidden_size": { "type": "const", "value": 7168 } + }, + "inputs": { + "hidden_states": { "shape": ["batch_size", "hidden_size"], "dtype": "bfloat16" }, + "weight": { "shape": ["hidden_size"], "dtype": "bfloat16" } + }, + "outputs": { + "output": { "shape": ["batch_size", "hidden_size"], "dtype": "bfloat16" } + }, + "reference": "..." + } + +Key fields: + +.. list-table:: + :header-rows: 1 + :widths: 20 80 + + * - Field + - Meaning + * - ``name`` + - Auto-generated from ``op_type`` / ``name_prefix`` + const-axis values. + Becomes the benchmark name in flashinfer-bench. + * - ``op_type`` + - Identifies the kernel class (``rmsnorm``, ``gqa_paged``, ``moe``, …). + * - ``tags`` + - List of key:value tags. Always includes ``fi_api:`` + and optional metadata like ``status:verified``. + * - ``axes`` + - Symbolic dimensions. ``"var"`` axes vary at runtime (batch size, + sequence length). ``"const"`` axes are fixed by model config (head + dimension, hidden size) and carry a ``"value"``. + * - ``inputs`` / ``outputs`` + - Each entry has ``"shape"`` (list of axis names) and a resolved + ``"dtype"``. Optional inputs carry ``"optional": true``. + * - ``reference`` + - Source of a pure-PyTorch reference implementation for correctness + checking (present on ``status:verified`` ops). + +Calling ``.fi_trace()`` Directly +--------------------------------- + +Every decorated function exposes a ``.fi_trace()`` method. +You can call it without running the kernel: + +.. code-block:: python + + import torch + import flashinfer + + q = torch.zeros(32, 32, 128, dtype=torch.bfloat16, device="cuda") + k = torch.zeros(64, 16, 8, 128, dtype=torch.bfloat16, device="cuda") + v = torch.zeros(64, 16, 8, 128, dtype=torch.bfloat16, device="cuda") + + schema = flashinfer.norm.rmsnorm.fi_trace( + hidden_states=torch.zeros(32, 7168, dtype=torch.bfloat16), + weight=torch.ones(7168, dtype=torch.bfloat16), + ) + print(schema["name"]) # rmsnorm_h7168 + print(schema["axes"]) # {'batch_size': {'type': 'var'}, 'hidden_size': {'type': 'const', 'value': 7168}} + +To write to a specific directory, pass ``save_dir``: + +.. code-block:: python + + schema = flashinfer.norm.rmsnorm.fi_trace( + hidden_states=..., + weight=..., + save_dir="./my_traces", + ) + +Covered Operations +------------------ + +The following FlashInfer operations have trace templates and will emit JSON +files when ``FLASHINFER_TRACE_DUMP=1``: + +.. list-table:: + :header-rows: 1 + :widths: 25 35 40 + + * - Module + - Operation + - ``op_type`` + * - ``flashinfer.norm`` + - ``rmsnorm``, ``fused_add_rmsnorm`` + - ``rmsnorm`` + * - ``flashinfer.sampling`` + - ``top_k_sampling_from_probs``, + ``top_p_sampling_from_probs``, + ``top_k_top_p_sampling_from_probs`` + - ``sampling`` + * - ``flashinfer.gemm`` + - ``mm_bf16``, ``mm_fp8``, ``mm_mxfp8``, ``mm_fp4`` + - ``gemm_bf16`` / ``gemm_fp8`` / ``gemm_mxfp8`` / ``gemm_fp4`` + * - ``flashinfer.decode`` + - ``BatchDecodeWithPagedKVCacheWrapper.run`` + - ``gqa_paged`` + * - ``flashinfer.prefill`` + - ``BatchPrefillWithPagedKVCacheWrapper.run``, + ``BatchPrefillWithRaggedKVCacheWrapper.run`` + - ``gqa_paged`` / ``gqa_ragged`` + * - ``flashinfer.mla`` + - ``BatchMLAPagedAttentionWrapper.run`` + - ``mla_paged`` + * - ``flashinfer.gdn_decode`` + - ``gated_delta_rule_decode``, ``gated_delta_rule_mtp`` + - ``gdn`` + * - ``flashinfer.gdn_prefill`` + - ``chunk_gated_delta_rule`` + - ``gdn`` + * - ``flashinfer.fused_moe`` + - ``trtllm_fp8_block_scale_moe`` (6 routing types) + - ``moe`` + * - ``flashinfer.fused_moe`` + - ``trtllm_fp4_block_scale_moe`` (6 routing types) + - ``moe`` + +MoE Routing Types +----------------- + +MoE operations dispatch to per-routing-type templates. The output filename +encodes the routing method: + +.. list-table:: + :header-rows: 1 + :widths: 10 25 65 + + * - Value + - Name + - Filename pattern (FP8 example) + * - 0 + - Default (Softmax → TopK) + - ``moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.json`` + * - 1 + - Renormalize (TopK → Softmax) + - ``moe_fp8_block_scale_renormalize_routing_topk8_e32_h7168_i2048.json`` + * - 2 + - DeepSeekV3 (Sigmoid + group selection) + - ``moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048.json`` + * - 3 + - Llama4 (Top1 → Sigmoid) + - ``moe_fp8_block_scale_llama4_routing_topk1_e32_h7168_i2048.json`` + * - 4 + - RenormalizeNaive (Softmax → TopK → Renormalize) + - ``moe_fp8_block_scale_renormalize_naive_routing_topk8_e32_h7168_i2048.json`` + * - 5 + - TopK (no normalisation) + - ``moe_fp8_block_scale_topk_routing_topk8_e32_h7168_i2048.json`` + +Example: Collecting Traces from a Real Workload +------------------------------------------------ + +The script below runs a representative set of FlashInfer ops and collects all +trace JSON files in one pass. It covers the shapes used in DeepSeek-V3-style +models with expert-parallel MoE serving. + +.. code-block:: bash + + python tests/trace/example.py + +The generated files can be passed directly to ``flashinfer-bench``: + +.. code-block:: bash + + flashinfer-bench --trace-dir fi_trace_out/ --backends fa2 cudnn cutlass + +Adding Trace Support to a New Kernel +-------------------------------------- + +When adding a new kernel (see ``CLAUDE.md`` and ``.claude/skills/add-cuda-kernel/SKILL.md`` +for the full tutorial), attach a ``TraceTemplate`` to the ``@flashinfer_api`` decorator: + +.. code-block:: python + + from flashinfer.trace.template import Const, Tensor, TraceTemplate, Var + from flashinfer.api_logging import flashinfer_api + + rmsnorm_trace = TraceTemplate( + op_type="rmsnorm", + name_prefix="rmsnorm", + description="Root Mean Square Normalization.", + axes={ + "batch_size": Var(), + "hidden_size": Const(abbrev="h"), + }, + inputs={ + "hidden_states": Tensor(["batch_size", "hidden_size"]), + "weight": Tensor(["hidden_size"]), + }, + outputs={ + "output": Tensor(["batch_size", "hidden_size"], dtype_from="hidden_states"), + }, + tags=["status:verified"], + ) + + @flashinfer_api(trace=rmsnorm_trace) + def rmsnorm(hidden_states, weight, eps=1e-6): + ... + +The template is registered automatically in ``_TRACE_REGISTRY`` at decoration +time and picked up by the consistency tests without any manual registration. + +For operations whose template depends on a runtime parameter (e.g. +``routing_method_type`` for MoE), write a dispatch callable and attach a +``.templates`` attribute so the registry discovers all variants: + +.. code-block:: python + + _TEMPLATES = {0: default_trace, 1: renorm_trace, ...} + + def my_dispatch(**kwargs): + return _TEMPLATES.get(int(kwargs.get("routing_method_type", 0))) + + my_dispatch.templates = list(_TEMPLATES.values()) + + @flashinfer_api(trace=my_dispatch) + def my_moe_op(...): + ... + +Consistency Tests +----------------- + +FlashInfer ships automated **linter-style tests** that validate every trace +template without running GPU kernels: + +.. code-block:: bash + + pytest tests/trace/test_fi_trace_template_consistency.py -v + +The tests check three properties for every registered template: + +1. **Signature consistency** — every ``param=`` reference in the template + matches a real parameter of the decorated function. +2. **Axes coverage** — every ``Const`` axis can be resolved from at least one + tensor's shape or from a scalar kwarg. +3. **End-to-end completeness** — calling ``.fi_trace()`` with auto-generated + minimal tensors returns a dict where all ``Const`` axes have values and + no input/output has ``dtype == "unknown"``. + +When you add a template, these tests run automatically with no manual +registration required. diff --git a/docs/index.rst b/docs/index.rst index 028ed54a59..55f4e0a991 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -17,6 +17,7 @@ FlashInfer is a library and kernel generator for Large Language Models that prov installation cli logging + fi_trace autotuning .. toctree:: diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index 58187ca85c..bbf548aeef 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -187,3 +187,4 @@ from .xqa import xqa as xqa from .xqa import xqa_mla as xqa_mla from . import mamba as mamba +from .fi_trace import fi_trace as fi_trace diff --git a/flashinfer/activation.py b/flashinfer/activation.py index 3bdd3df769..c1f4e4dc79 100644 --- a/flashinfer/activation.py +++ b/flashinfer/activation.py @@ -22,6 +22,11 @@ from .api_logging import flashinfer_api from .jit import gen_act_and_mul_module +from .trace.templates.activation import ( + gelu_and_mul_trace, + gelu_tanh_and_mul_trace, + silu_and_mul_trace, +) from .utils import ( device_support_pdl, register_custom_op, @@ -67,7 +72,7 @@ def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None: ) -@flashinfer_api +@flashinfer_api(trace=silu_and_mul_trace) def silu_and_mul( input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None ) -> torch.Tensor: @@ -112,7 +117,7 @@ def silu_and_mul( return out -@flashinfer_api +@flashinfer_api(trace=gelu_tanh_and_mul_trace) def gelu_tanh_and_mul( input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None ) -> torch.Tensor: @@ -153,7 +158,7 @@ def gelu_tanh_and_mul( return out -@flashinfer_api +@flashinfer_api(trace=gelu_and_mul_trace) def gelu_and_mul( input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None ) -> torch.Tensor: diff --git a/flashinfer/api_logging.py b/flashinfer/api_logging.py index e88bd7d3cf..0213b3da80 100644 --- a/flashinfer/api_logging.py +++ b/flashinfer/api_logging.py @@ -24,7 +24,7 @@ import sys from datetime import datetime from pathlib import Path -from typing import Any, Callable, Dict, Tuple, Optional +from typing import Any, Callable, Dict, List, Tuple, Optional import contextlib import importlib import torch @@ -1417,7 +1417,162 @@ def _log_function_outputs(func_name: str, result: Any, level: int) -> None: _logger.debug("\n".join(lines)) -def flashinfer_api(func: Callable = None) -> Callable: +# --------------------------------------------------------------------------- +# Trace template registry +# --------------------------------------------------------------------------- +# Populated automatically by _attach_fi_trace whenever @flashinfer_api is +# given a trace= argument. Each entry is (original_func, template, label) +# where label is the template's name_prefix (or op_type as fallback). +# +# For dispatch callables (trace=some_fn), every template listed in +# some_fn.templates is registered if that attribute exists. +# +# Read by tests/trace/test_fi_trace_template_consistency.py to auto-discover +# all registered templates without requiring manual maintenance. +_TRACE_REGISTRY: List[Tuple[Callable, Any, str]] = [] + + +def _attach_fi_trace( + wrapped: Callable, + original: Callable, + trace_template=None, +) -> Callable: + """Attach a ``fi_trace`` callable to *wrapped*. + + Three resolution strategies, tried in order: + + 1. **Dispatch callable** (new interface): if *trace_template* is a + plain callable (not a ``TraceTemplate``), it is called at trace time + with the bound kwargs and must return the appropriate + :class:`~flashinfer.trace.TraceTemplate` for that invocation. Use + this when a single API function needs different templates depending on + a runtime parameter (e.g. ``routing_method_type``). + 2. **Explicit template** (new interface): if *trace_template* is a + :class:`~flashinfer.trace.TraceTemplate`, use it directly. + 3. **Registry lookup** (legacy interface): look up the qualname of + *original* in the old ``_REGISTRY`` dict in ``flashinfer.fi_trace``. + + When ``FLASHINFER_TRACE_DUMP=1`` is set and a template is provided, the + returned callable also auto-dumps a trace JSON on every invocation + (deduplication: same-named files are written only once per process). + + The attachment is a no-op when neither strategy finds a spec. + """ + try: + if trace_template is not None: + from flashinfer.trace.template import ( # noqa: PLC0415 + TraceTemplate, + _is_trace_dump_enabled, + ) + + # New interface: derive fi_api from the function's module + qualname. + module = getattr(original, "__module__", "") or "" + qualname = getattr(original, "__qualname__", "") or "" + fi_api = f"{module}.{qualname}" if module else qualname + + if isinstance(trace_template, TraceTemplate): + # Static template: pre-build the fi_trace callable once. + fi_trace_fn = trace_template.build_fi_trace_fn(fi_api) + # Register for auto-discovery by consistency tests. + label = trace_template.name_prefix or trace_template.op_type + _TRACE_REGISTRY.append((original, trace_template, label)) + else: + # Dispatch callable: *trace_template* is a function + # ``(save_dir=None, name=None, **kwargs) -> TraceTemplate``. + # Resolve the template at call time and cache per template + # instance to avoid rebuilding extractors on every call. + # If the dispatch function exposes a .templates iterable, + # register each template for auto-discovery. + for tpl in getattr(trace_template, "templates", ()): + if isinstance(tpl, TraceTemplate): + _label = tpl.name_prefix or tpl.op_type + _TRACE_REGISTRY.append((original, tpl, _label)) + _dispatch_fn = trace_template + _fi_trace_cache: Dict[int, Callable] = {} + + def fi_trace_fn( + save_dir=None, + name=None, + **kwargs: Any, + ) -> Dict[str, Any]: + tpl = _dispatch_fn(**kwargs) + if tpl is None: + return {} + tpl_id = id(tpl) + if tpl_id not in _fi_trace_cache: + _fi_trace_cache[tpl_id] = tpl.build_fi_trace_fn(fi_api) + return _fi_trace_cache[tpl_id]( + save_dir=save_dir, name=name, **kwargs + ) + + wrapped.fi_trace = fi_trace_fn # type: ignore[attr-defined] + + # Auto-dump wrapper: checked lazily at call time so that callers + # can set FLASHINFER_TRACE_DUMP after importing flashinfer (e.g. + # when running via ``python -m``). + _inner = wrapped + _sig = inspect.signature(original) + + # Track which (function, error-type) pairs have already been warned + # about so we emit at most one diagnostic per failure class per process. + _autodump_warned: set = set() + + @functools.wraps(_inner) + def _auto_dump_wrapper(*args, **kwargs): + # Generate trace BEFORE the actual call (crash-safe: schema + # depends only on input shapes/dtypes, not on whether the + # computation succeeds). + if _is_trace_dump_enabled(): + try: + bound = _sig.bind(*args, **kwargs) + bound.apply_defaults() + fi_trace_fn(**dict(bound.arguments)) + except Exception as _exc: + # Non-fatal: the API call still runs. Warn once per + # (function, error-type) so users get a diagnostic + # instead of silently missing a trace JSON. + _key = (fi_api, type(_exc).__name__) + if _key not in _autodump_warned: + _autodump_warned.add(_key) + import warnings as _warnings # noqa: PLC0415 + + _warnings.warn( + f"[flashinfer] fi_trace auto-dump failed for " + f"'{fi_api}': {type(_exc).__name__}: {_exc}. " + f"Further occurrences of this error for this API " + f"will be suppressed.", + stacklevel=2, + ) + return _inner(*args, **kwargs) + + _auto_dump_wrapper.fi_trace = fi_trace_fn # type: ignore[attr-defined] + return _auto_dump_wrapper + else: + # Legacy registry lookup (kept for backwards compatibility). + from flashinfer.fi_trace import _REGISTRY, build_fi_trace_fn # noqa: PLC0415 + + qualname = getattr(original, "__qualname__", "") + spec = _REGISTRY.get(qualname) + if spec is not None: + wrapped.fi_trace = build_fi_trace_fn(spec) # type: ignore[attr-defined] + except Exception as _exc: + # Warn instead of silently swallowing: a broken trace template should + # be visible to the developer during import, not discovered later as a + # confusing AttributeError when calling func.fi_trace(...). + _func_name = getattr(original, "__qualname__", repr(original)) + import warnings # noqa: PLC0415 + + warnings.warn( + f"[flashinfer] Failed to attach fi_trace to '{_func_name}': " + f"{type(_exc).__name__}: {_exc}\n" + f"The function will work normally but fi_trace will be unavailable. " + f"Fix the TraceTemplate passed to @flashinfer_api(trace=...).", + stacklevel=3, + ) + return wrapped + + +def flashinfer_api(func: Callable = None, *, trace=None) -> Callable: """ Decorator to FlashInfer's APIs. @@ -1489,11 +1644,12 @@ def flashinfer_api(func: Callable = None) -> Callable: - The %i pattern is automatically replaced with the process ID for multi-process environments. - The logger does not propagate to the root logger to avoid duplicate logs. """ - # If logging is disabled, return original function with zero overhead + # If logging is disabled, return original function with zero overhead. + # We still attach fi_trace so it is always available regardless of log level. if _API_LOG_LEVEL == 0: if func is None: - return lambda f: f - return func + return lambda f: _attach_fi_trace(f, f, trace_template=trace) + return _attach_fi_trace(func, func, trace_template=trace) def decorator(f: Callable) -> Callable: @functools.wraps(f) @@ -1561,7 +1717,7 @@ def wrapper(*args, **kwargs): return result - return wrapper + return _attach_fi_trace(wrapper, f, trace_template=trace) if func is None: return decorator diff --git a/flashinfer/attention.py b/flashinfer/attention.py index c4bc4f27dc..5ce30409cc 100644 --- a/flashinfer/attention.py +++ b/flashinfer/attention.py @@ -21,6 +21,7 @@ import torch from .api_logging import flashinfer_api +from .trace.templates.attention import batch_attention_run_trace from .jit import gen_batch_attention_module from .utils import ( MaskMode, @@ -135,7 +136,7 @@ def plan( causal, ) - @flashinfer_api + @flashinfer_api(trace=batch_attention_run_trace) def run( self, q: torch.Tensor, @@ -209,6 +210,8 @@ class BatchAttentionWithAttentionSinkWrapper(BatchPrefillWithPagedKVCacheWrapper a convenient interface for using attention sinks during prefill or decode attention. """ + # No @flashinfer_api here: parent class BatchPrefillWithPagedKVCacheWrapper + # already decorates __init__, so decorating again produces double log entries. def __init__( self, float_workspace_buffer: torch.Tensor, diff --git a/flashinfer/cascade.py b/flashinfer/cascade.py index 1de363bb37..bdaaa6234e 100644 --- a/flashinfer/cascade.py +++ b/flashinfer/cascade.py @@ -23,6 +23,12 @@ from .decode import BatchDecodeWithPagedKVCacheWrapper from .jit.cascade import gen_cascade_module from .prefill import BatchPrefillWithPagedKVCacheWrapper, single_prefill_with_kv_cache +from .trace.templates.attention import multi_level_cascade_run_trace +from .trace.templates.cascade import ( + merge_state_in_place_trace, + merge_state_trace, + merge_states_trace, +) from .utils import register_custom_op, register_fake_op @@ -31,7 +37,7 @@ def get_cascade_module(): return gen_cascade_module().build_and_load() -@flashinfer_api +@flashinfer_api(trace=merge_state_trace) @register_custom_op("flashinfer::merge_state", mutates_args=()) def merge_state( v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor @@ -98,7 +104,7 @@ def _fake_merge_state( return v, s -@flashinfer_api +@flashinfer_api(trace=merge_state_in_place_trace) @register_custom_op("flashinfer::merge_state_in_place", mutates_args=("v", "s")) def merge_state_in_place( v: torch.Tensor, @@ -159,7 +165,7 @@ def _fake_merge_state_in_place( pass -@flashinfer_api +@flashinfer_api(trace=merge_states_trace) @register_custom_op("flashinfer::merge_states", mutates_args=()) def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: r"""Merge multiple attention states (v, s). @@ -512,7 +518,7 @@ def plan( begin_forward = plan - @flashinfer_api + @flashinfer_api(trace=multi_level_cascade_run_trace) def run( self, q: torch.Tensor, diff --git a/flashinfer/cudnn/decode.py b/flashinfer/cudnn/decode.py index 195ca2d49d..9b59309534 100644 --- a/flashinfer/cudnn/decode.py +++ b/flashinfer/cudnn/decode.py @@ -4,6 +4,7 @@ import torch from ..api_logging import flashinfer_api +from ..trace.templates.attention import cudnn_batch_decode_trace from .utils import get_cudnn_fmha_gen_module try: @@ -253,7 +254,7 @@ def _batch_decode_with_kv_cache( return out -@flashinfer_api +@flashinfer_api(trace=cudnn_batch_decode_trace) def cudnn_batch_decode_with_kv_cache( q: torch.Tensor, k_cache: torch.Tensor, diff --git a/flashinfer/cudnn/prefill.py b/flashinfer/cudnn/prefill.py index fc1bbb5f4c..b16d604305 100644 --- a/flashinfer/cudnn/prefill.py +++ b/flashinfer/cudnn/prefill.py @@ -4,6 +4,7 @@ import torch from ..api_logging import flashinfer_api +from ..trace.templates.attention import cudnn_batch_prefill_trace from .utils import get_cudnn_fmha_gen_module try: @@ -558,7 +559,7 @@ def _batch_prefill_with_kv_cache( return out, None -@flashinfer_api +@flashinfer_api(trace=cudnn_batch_prefill_trace) def cudnn_batch_prefill_with_kv_cache( q: torch.Tensor, k_cache: torch.Tensor, diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 822aca407c..c0daa6859d 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -22,6 +22,11 @@ import torch from .api_logging import flashinfer_api +from .trace.templates.attention import ( + gqa_paged_decode_trace, + single_decode_with_kv_cache_trace, + trtllm_batch_decode_trace, +) ## NOTE: MLA functions have been moved to mla.py, but we keep the aliases here for backward compatibility. from .mla import ( @@ -400,7 +405,7 @@ def single_decode_with_kv_cache( ) -> Tuple[torch.Tensor, torch.Tensor]: ... -@flashinfer_api +@flashinfer_api(trace=single_decode_with_kv_cache_trace) def single_decode_with_kv_cache( q: torch.Tensor, k: torch.Tensor, @@ -1215,7 +1220,7 @@ def run( kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... - @flashinfer_api + @flashinfer_api(trace=gqa_paged_decode_trace) def run( self, q: torch.Tensor, @@ -1577,6 +1582,8 @@ class CUDAGraphBatchDecodeWithPagedKVCacheWrapper(BatchDecodeWithPagedKVCacheWra :class:`BatchDecodeWithPagedKVCacheWrapper` """ + # No @flashinfer_api here: parent class BatchDecodeWithPagedKVCacheWrapper + # already decorates __init__, so decorating again produces double log entries. def __init__( self, workspace_buffer: torch.Tensor, @@ -2232,7 +2239,7 @@ def _fake_paged_run( ) -@flashinfer_api +@flashinfer_api(trace=trtllm_batch_decode_trace) def trtllm_batch_decode_with_kv_cache( query: torch.Tensor, kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], diff --git a/flashinfer/fi_trace.py b/flashinfer/fi_trace.py new file mode 100644 index 0000000000..1104eb6f07 --- /dev/null +++ b/flashinfer/fi_trace.py @@ -0,0 +1,285 @@ +# Copyright (c) 2025 by FlashInfer team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +fi_trace: Generate `flashinfer-bench `_ +compatible definition JSON for FlashInfer APIs. + +Every ``@flashinfer_api(trace=