Commits (35)

- cb9fb89 add missing flashinfer_api (averyhNV, Mar 31, 2026)
- 0dca5ff init (averyhNV, Apr 3, 2026)
- b933da3 move example to tests (averyhNV, Apr 3, 2026)
- 014ef85 add skills and checker (averyhNV, Apr 3, 2026)
- fec4310 add two meta tests (averyhNV, Apr 3, 2026)
- 13c99ba fmt and add more moe (averyhNV, Apr 3, 2026)
- 073aaee fmt and more test (averyhNV, Apr 3, 2026)
- 8c2ac40 fmt (averyhNV, Apr 3, 2026)
- afe2efa upd doc webpage (averyhNV, Apr 3, 2026)
- 7311070 add trace templates for activation, cascade, and norm variants (averyhNV, Apr 21, 2026)
- 81a9456 fmt (averyhNV, Apr 21, 2026)
- b04b3d7 add trace examples for new ops and PR checklist for trace templates (averyhNV, Apr 21, 2026)
- 6c56441 track fi_trace_out JSON files and harden example.py for non-SM100 GPUs (averyhNV, Apr 21, 2026)
- 79e3277 fmt (averyhNV, Apr 21, 2026)
- 87c1c4b fix PR #2931 review: drop double-logging and fix gdn_mtp state update (averyhNV, Apr 21, 2026)
- 811e404 Merge branch 'main' into fi_trace (yyihuang, Apr 21, 2026)
- 0737555 add CUDA-graph example for fi_trace (averyhNV, Apr 21, 2026)
- 2999978 track CUDA-graph example trace JSON for review (averyhNV, Apr 21, 2026)
- 75f7e67 fix PR #2931 review B1-B3: correct GEMM / paged-attention reference m… (averyhNV, Apr 21, 2026)
- 6ceeedc fix PR #2931 review B4-B11: MoE/GDN refs, schema polish, auto-dump diag (averyhNV, Apr 21, 2026)
- d935aee drop @flashinfer_api from internal execute_cudnn_gemm_*_override_shap… (averyhNV, Apr 21, 2026)
- d101769 add fi_trace templates for RoPE and quantization APIs (averyhNV, Apr 21, 2026)
- d2ddf27 add fi_trace for cuDNN/TRT-LLM attention, CUTLASS/TRT-LLM MoE, and re… (averyhNV, Apr 22, 2026)
- aa461c1 Merge branch 'main' into fi_trace (yyihuang, Apr 22, 2026)
- b87aea9 fmt (averyhNV, Apr 22, 2026)
- e02d5a6 add reference implementations for FP4 MoE trace templates (averyhNV, Apr 22, 2026)
- 9917cd7 add reference implementations for all remaining trace templates + cor… (averyhNV, Apr 22, 2026)
- d6a67d9 explicitly del unused params in MoE reference signatures (averyhNV, Apr 22, 2026)
- b9ad044 replace del-based unused-param suppression with **_unused + trimmed s… (averyhNV, Apr 22, 2026)
- 66528b3 Merge branch 'main' into fi_trace (yyihuang, Apr 22, 2026)
- 7d9e3fe verify fi_trace dumps JSONs during a real sglang inference pass (averyhNV, Apr 22, 2026)
- adef8b6 fmt: ruff fixes for example_sglang.py (averyhNV, Apr 22, 2026)
- 69ebf27 Merge branch 'main' into fi_trace (yyihuang, Apr 22, 2026)
- 13b0937 fmt: trailing newlines on sglang trace fixtures (averyhNV, Apr 22, 2026)
- 2ea8ffe trace: wire fi_trace to all remaining public APIs in flashinfer/__ini… (averyhNV, Apr 22, 2026)
170 changes: 160 additions & 10 deletions .claude/skills/add-cuda-kernel/SKILL.md
@@ -625,7 +625,155 @@ Check functions must:
3. Raise `ValueError` with descriptive message if validation fails
4. Be decorated with `@supported_compute_capability` to specify supported architectures

## Step 6: Write Tests in `tests/`
## Step 6: Add a Trace Template

Every new kernel **must** have a `TraceTemplate` so that flashinfer-bench can auto-generate
benchmark definition files via `@flashinfer_api(trace=...)`.

### 6a. Create the template in `flashinfer/trace/templates/`

Add a file (or extend an existing one) in `flashinfer/trace/templates/`. The
real `flashinfer/trace/templates/norm.py` is a good reference: it shows two
variants that share an `op_type` but have distinct `name_prefix` values:

```python
# flashinfer/trace/templates/norm.py (real file, simplified for illustration)
from ..template import Const, Tensor, TraceTemplate, Var

# op_type     – high-level operation category written to the JSON "op_type" field.
#               Two templates can share the same op_type when they are variants of
#               the same operation family.
# name_prefix – base string for the auto-generated filename and JSON "name" field.
#               Const axis values are appended, e.g. rmsnorm_h4096.json.
#               Must be unique across templates that share an op_type.

rmsnorm_trace = TraceTemplate(
    op_type="rmsnorm",      # category: all RMSNorm variants share this
    name_prefix="rmsnorm",  # specific variant → file: rmsnorm_h<hidden>.json
    description="Root Mean Square Normalization. Epsilon is fixed at 1e-6.",
    axes={
        "batch_size": Var(),               # runtime-variable: omitted from filename
        "hidden_size": Const(abbrev="h"),  # baked into filename as "h<value>"
    },
    inputs={
        # json_key "hidden_states" differs from the Python param name "input",
        # so param= is set explicitly.
        "hidden_states": Tensor(["batch_size", "hidden_size"], param="input"),
        "weight": Tensor(["hidden_size"]),  # key == param, no param= needed
    },
    outputs={
        "output": Tensor(["batch_size", "hidden_size"], dtype_from="input"),
    },
    tags=["status:verified"],
    reference=_rmsnorm_reference,
)

fused_add_rmsnorm_trace = TraceTemplate(
    op_type="rmsnorm",                # same category as rmsnorm_trace above
    name_prefix="fused_add_rmsnorm",  # different variant → fused_add_rmsnorm_h<hidden>.json
    description="Fused Add + RMSNorm. Epsilon is fixed at 1e-6.",
    axes={
        "batch_size": Var(),
        "hidden_size": Const(abbrev="h"),
    },
    inputs={
        "hidden_states": Tensor(["batch_size", "hidden_size"], param="input"),
        "residual": Tensor(["batch_size", "hidden_size"]),
        "weight": Tensor(["hidden_size"]),
    },
    outputs={
        "output": Tensor(["batch_size", "hidden_size"], dtype_from="input"),
        "residual": Tensor(
            ["batch_size", "hidden_size"],
            dtype_from="input",
            description="Updated residual (in-place: residual += hidden_states).",
        ),
    },
    tags=["status:verified", "fused"],
    reference=_fused_add_rmsnorm_reference,
)
```

Key rules:
- `Var()` → value is NOT baked into the generated name or JSON `value`.
- `Const(abbrev=...)` → value IS extracted and written to JSON. `abbrev="h"` → `h4096`; `abbrev=""` → omit from filename.
- Each `Tensor` key defaults to `param=key`; use `param="other_name"` when they differ.
- `dtype_from="<input_key>"` copies the dtype from that input tensor (use the JSON key, not the param name).
- For dispatch (one function, multiple templates depending on a kwarg), pass a
  plain callable as `trace=`:
  ```python
  def _my_trace_dispatch(**kwargs):
      if kwargs.get("mode") == "fast":
          return fast_trace
      return slow_trace

  @flashinfer_api(trace=_my_trace_dispatch)
  def my_op(..., mode="fast"):
      ...
  ```
See `flashinfer/fused_moe/core.py` + `flashinfer/trace/templates/moe.py` for a
real dispatch example keyed on `routing_method_type`.
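The naming rules above can be sketched in isolation. `Const`, `Var`, and `generated_name` below are minimal stand-ins for illustration, not the real flashinfer classes or API:

```python
from dataclasses import dataclass


@dataclass
class Var:
    """Runtime-variable axis: never appears in the filename."""


@dataclass
class Const:
    """Constant axis: baked into the filename when abbrev is non-empty."""
    abbrev: str = ""


def generated_name(name_prefix, axes, values):
    """Append each Const axis as <abbrev><value>; Var axes and
    empty-abbrev Consts are omitted from the filename."""
    parts = [name_prefix]
    for axis, spec in axes.items():
        if isinstance(spec, Const) and spec.abbrev:
            parts.append(f"{spec.abbrev}{values[axis]}")
    return "_".join(parts) + ".json"


name = generated_name(
    "rmsnorm",
    {"batch_size": Var(), "hidden_size": Const(abbrev="h")},
    {"batch_size": 32, "hidden_size": 4096},
)
# → "rmsnorm_h4096.json": batch_size is a Var, so only h4096 is appended
```

Note how `abbrev=""` would drop the axis from the filename entirely while (per the rules above) its value is still written into the JSON.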

### 6b. Attach the template to the API

```python
# flashinfer/norm.py (real file)
from .trace.templates.norm import rmsnorm_trace

@flashinfer_api(trace=rmsnorm_trace)
def rmsnorm(input: torch.Tensor, weight: torch.Tensor, ...) -> torch.Tensor:
    ...
```

The `fi_api` tag is derived automatically from `func.__module__ + "." + func.__qualname__`.
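That derivation is plain string concatenation; a sketch (the real tagging happens inside `@flashinfer_api`, and the module part below depends on where the snippet runs):

```python
def fi_api_tag(func):
    # Matches func.__module__ + "." + func.__qualname__ from the text above.
    return f"{func.__module__}.{func.__qualname__}"


def rmsnorm(x):
    return x


tag = fi_api_tag(rmsnorm)
# For the real flashinfer function this yields "flashinfer.norm.rmsnorm";
# here it ends in ".rmsnorm" with whatever module this snippet runs in.
```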

### 6c. Register your module for auto-discovery

Open `tests/trace/test_fi_trace_template_consistency.py` and add your module to
the import list inside `_collect_template_func_pairs()`:

```python
import flashinfer.norm # ← add your module here
```

That's it. `@flashinfer_api(trace=...)` automatically registers every
`(func, template)` pair in `flashinfer.api_logging._TRACE_REGISTRY` at
decoration time. Importing the module triggers the decorator, and the
parameterized tests then check:

1. **Signature consistency**: every non-optional `param=` reference exists in the function's signature.
2. **Axis coverage**: every `Const` axis appears in at least one tensor's `dim_names` or the function's parameter list.
3. **End-to-end**: `fi_trace` with auto-generated CPU tensors returns a complete dict
(no `"unknown"` dtypes for non-optional inputs, all `Const` axes have values).
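Check 1 can be illustrated with `inspect`; this is a simplified stand-in for the real test, with hypothetical helper names:

```python
import inspect


def check_signature(func, inputs):
    """inputs: {json_key: param_name} as a template would declare them.
    Raises AssertionError listing every param= that is missing from
    the function's signature."""
    params = set(inspect.signature(func).parameters)
    mismatches = [
        f"Input '{key}' -> param='{param}' not found in {sorted(params)}"
        for key, param in inputs.items()
        if param not in params
    ]
    assert not mismatches, "\n".join(mismatches)


def rmsnorm(input, weight, eps=1e-6):
    ...


# Passes: both param names exist in rmsnorm's signature.
check_signature(rmsnorm, {"hidden_states": "input", "weight": "weight"})
```

Declaring `param="x"` instead would raise with a message naming the bad key and the available parameters, mirroring the failure shown in 6d.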

If your template uses tuple inputs or exotic dtypes (fp8 scale tensors, etc.),
add a targeted end-to-end test at the bottom of the file and add your label to
`_E2E_SKIP` (see the MoE example there).

For **dispatch templates** (callable `trace=`), also set a `.templates`
attribute on the dispatch function listing all possible return values:

```python
def _my_trace_dispatch(**kwargs): ...
_my_trace_dispatch.templates = [fast_trace, slow_trace]
```

This lets the registry auto-discover and check all variants.

### 6d. Run the consistency tests

```bash
pytest tests/trace/test_fi_trace_template_consistency.py -v
```

A failing structural test looks like:
```
AssertionError: [rmsnorm] Template 'rmsnorm' has param mismatches:
  Input 'hidden_states' → param='x' not found in rmsnorm(['input', 'weight', 'eps'])
```
which tells you exactly which key is wrong and what names are available.

## Step 7: Write Tests in `tests/`

Create tests in an appropriate subdirectory (e.g., `tests/elementwise/test_scale.py` or create a new subdir if needed):

@@ -794,13 +942,15 @@ if __name__ == "__main__":
## Summary of Files Created/Modified

```
include/flashinfer/scale.cuh # NEW: CUDA kernel definition
csrc/scale.cu # NEW: PyTorch launcher
csrc/scale_jit_binding.cu # NEW: TVM-FFI binding
flashinfer/jit/scale.py # NEW: JIT generator
flashinfer/scale.py # NEW: Python API (with @flashinfer_api(trace=...))
flashinfer/trace/templates/scale.py # NEW: TraceTemplate definition
flashinfer/__init__.py # MODIFIED: Export API
flashinfer/aot.py # MODIFIED: Register AOT
tests/test_scale.py # NEW: Kernel unit tests
tests/trace/test_fi_trace_template_consistency.py # MODIFIED: Add (func, template) pair
benchmarks/bench_scale.py # NEW: Benchmark script
```
14 changes: 14 additions & 0 deletions CLAUDE.md
@@ -344,6 +344,20 @@ flashinfer/
7. Write tests in `tests/`
8. Register in `flashinfer/aot.py` for AOT compilation
9. Export in `flashinfer/__init__.py`
10. Add a `TraceTemplate` in `flashinfer/trace/templates/` and wire it via `@flashinfer_api(trace=...)` (see below)
11. Add an example call in `tests/trace/example.py`, re-run to regenerate `fi_trace_out/`, and commit the new JSON files

### Trace Template Checklist (for new or updated APIs)

Every public API decorated with `@flashinfer_api` should also carry a `trace=` argument so that `fi_trace()` works and auto-dump produces a benchmark definition JSON.

1. **Create or update a `TraceTemplate`** in `flashinfer/trace/templates/<category>.py` (e.g., `norm.py`, `activation.py`, `cascade.py`, `gdn.py`). Define `axes`, `inputs`, `outputs`, and optionally a `reference` function.
2. **Wire the template** to the API: `@flashinfer_api(trace=my_trace)` on the Python function (or class method's `run()`).
3. **Add an example call** in `tests/trace/example.py` that exercises the new trace with realistic shapes.
4. **Regenerate examples**: `rm -rf tests/trace/fi_trace_out && python tests/trace/example.py`, then verify the expected JSON appears.
5. **Update the docstring** in `tests/trace/example.py` to list the new file(s).
6. **Run tests**: `pytest tests/trace/ -v`; all template-consistency and end-to-end tests must pass.
7. **Commit the new JSON files** under `tests/trace/fi_trace_out/` alongside the code changes.

**Example implementations:**
- **Simple**: `flashinfer/norm.py` (RMSNorm) - no Jinja, good starting point