
Commit c8e216b

Merge branch 'main' into vtombari/cute-dsl-fp4-gemm-heur
2 parents c2d4313 + 77a179f

18 files changed: 4283 additions & 145 deletions

README.md
Lines changed: 2 additions & 1 deletion

@@ -69,7 +69,8 @@ High-Performance GPU Kernels for Inference
 | Ada Lovelace | SM 8.9 | L4, L40, RTX 40 series |
 | Hopper | SM 9.0 | H100, H200 |
 | Blackwell | SM 10.0, 10.3 | B200, B300 |
-| Blackwell | SM 12.0, 12.1 | RTX 50 series, DGX Spark, Jetson Thor |
+| Blackwell | SM 11.0 | Jetson Thor |
+| Blackwell | SM 12.0, 12.1 | RTX 50 series, DGX Spark |

 > **Note:** Not all features are supported across all compute capabilities.

benchmarks/routines/flashinfer_benchmark_utils.py
Lines changed: 1 addition & 0 deletions

@@ -197,6 +197,7 @@
         "fused_add_rmsnorm_quant",
         "rmsnorm_fp4quant",
         "add_rmsnorm_fp4quant",
+        "fused_rmsnorm_silu",
     ],
     "quantization": [
         "mxfp8_quantize",

benchmarks/routines/norm.py
Lines changed: 121 additions & 0 deletions

@@ -51,6 +51,8 @@ def run_norm_test(args):
         return testRmsnormFp4quant(args)
     elif args.routine == "add_rmsnorm_fp4quant":
         return testAddRmsnormFp4quant(args)
+    elif args.routine == "fused_rmsnorm_silu":
+        return testFusedRmsnormSilu(args)
     else:
         raise ValueError(f"Unsupported routine: {args.routine}")

@@ -1078,3 +1080,122 @@ def run_backend(backend, input_tensor, residual_tensor, weight):
         cur_res["case_tag"] = args.case_tag
         res.append(cur_res)
     return res
+
+
+def testFusedRmsnormSilu(args):
+    """
+    Test fused_rmsnorm_silu API (RMSNorm + SiLU activation).
+
+    This test:
+    1. Generates random input tensors
+    2. Runs fused_rmsnorm_silu with bf16 output
+    3. Optionally runs reference check
+    4. Measures performance metrics (memory bandwidth)
+
+    Args:
+        args: Parsed command line arguments containing test configuration
+
+    Returns:
+        list: List of dictionaries containing performance results
+    """
+    if args.verbose >= 1:
+        print("[INFO] Running testFusedRmsnormSilu")
+        print(f"[INFO] FlashInfer version: {flashinfer.__version__}")
+
+    device = get_device(args)
+    if args.generate_repro_command:
+        print(
+            f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
+        )
+
+    batch_size = args.batch_size
+    hidden_size = args.hidden_size
+    eps = args.eps
+    is_cuda_graph_compatible = not args.no_cuda_graph
+    run_refcheck = args.refcheck
+    res = []
+
+    input_dtype = dtype_str_to_torch_dtype(args.input_dtype)
+    if input_dtype != torch.bfloat16:
+        raise ValueError(
+            f"fused_rmsnorm_silu requires bfloat16 input, got {args.input_dtype}"
+        )
+
+    input_shape = (batch_size, hidden_size)
+    input_tensor = torch.randn(input_shape, dtype=torch.bfloat16, device=device)
+    weight = torch.rand(hidden_size, dtype=torch.bfloat16, device=device) * 1.5 + 0.5
+    out = torch.empty(input_shape, dtype=torch.bfloat16, device=device)
+
+    if args.verbose >= 2:
+        print(f"[VVERBOSE] {input_tensor.shape = }")
+        print(f"[VVERBOSE] {input_tensor.dtype = }")
+        print(f"[VVERBOSE] {weight.shape = }")
+
+    def run_fn(input_tensor, weight, out):
+        return flashinfer.fused_rmsnorm_silu(input_tensor, weight, eps=eps, out=out)
+
+    has_reference_output = False
+    if run_refcheck:
+        rms = torch.sqrt(
+            torch.mean(input_tensor.float() ** 2, dim=-1, keepdim=True) + eps
+        )
+        x_norm = input_tensor.float() / rms * weight.float()
+        reference_output = torch.nn.functional.silu(x_norm).to(torch.bfloat16)
+        has_reference_output = True
+
+    if run_refcheck:
+        test_out = run_fn(input_tensor, weight, out)
+        if has_reference_output:
+            (
+                num_different_elements,
+                num_elements,
+                num_different_elements_percentage,
+            ) = is_close_stats(reference_output, test_out, rtol=2e-2, atol=2e-2)
+            if num_different_elements > 0:
+                print(
+                    f"[ERROR] Output tensor mismatch: "
+                    f"{num_different_elements}/{num_elements} ({num_different_elements_percentage:.2f}%) elements differ"
+                )
+                if not args.allow_output_mismatch:
+                    raise AssertionError(
+                        f"[ERROR] Output mismatch with {num_different_elements} elements"
+                    )
+
+    times = bench_gpu_time(
+        fn=run_fn,
+        dry_run_iters=args.dry_run_iters,
+        repeat_iters=args.num_iters,
+        enable_cupti=args.use_cupti,
+        use_cuda_graph=is_cuda_graph_compatible,
+        input_args=(input_tensor, weight, out),
+    )
+
+    if len(times) > 0:
+        median_time = np.median(times)
+        std_time = np.std(times)
+
+        num_elements = np.prod(input_shape)
+        problem_bytes = (
+            num_elements * input_dtype.itemsize  # input read
+            + hidden_size * input_dtype.itemsize  # weight read
+            + num_elements * input_dtype.itemsize  # output write
+        )
+        problem_flops = num_elements * 7  # rmsnorm (5) + silu (2: exp + div)
+        tflops = problem_flops / (10**9 * median_time)
+        tb_per_sec = problem_bytes / (10**9 * median_time)
+
+        print_perf_metrics("cuda", median_time, std_time, tflops, tb_per_sec)
+
+        if args.output_path is not None:
+            cur_res = defaultdict(str)
+            cur_res["routine"] = args.routine
+            cur_res["median_time"] = median_time
+            cur_res["std_time"] = std_time
+            cur_res["tflops"] = tflops
+            cur_res["tb_per_sec"] = tb_per_sec
+            cur_res["input_dtype"] = str(input_dtype)
+            cur_res["eps"] = eps
+            cur_res["backend"] = "cuda"
+            cur_res["case_tag"] = args.case_tag
+            res.append(cur_res)
+    return res
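
For orientation, a minimal standalone sketch of the op this routine benchmarks. The fused_rmsnorm_silu call matches run_fn above, and the float32 reference and 2e-2 tolerances mirror the refcheck; the shapes and eps value here are illustrative, not taken from this diff.

import torch
import flashinfer

# Illustrative shapes; fused_rmsnorm_silu requires bfloat16 input.
x = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
w = torch.rand(4096, dtype=torch.bfloat16, device="cuda") * 1.5 + 0.5
out = torch.empty_like(x)
eps = 1e-6  # illustrative epsilon

flashinfer.fused_rmsnorm_silu(x, w, eps=eps, out=out)

# Reference: RMSNorm in float32, then SiLU, then cast back to bf16.
rms = torch.sqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + eps)
ref = torch.nn.functional.silu(x.float() / rms * w.float()).to(torch.bfloat16)
torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)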

benchmarks/samples/sample_testlist.txt
Lines changed: 8 additions & 0 deletions

@@ -133,6 +133,14 @@
 # Both SF layouts with MXFP4 format
 --routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype mxfp4 --output_both_sf_layouts -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_mxfp4_both_sf"

+## Fused RMSNorm + SiLU (SM80+, sweep-tuned on SM100/B200)
+# VAE decoder shapes (LUT-optimized on B200)
+--routine fused_rmsnorm_silu --batch_size 1560 --hidden_size 1024 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "fused_rmsnorm_silu_vae_small"
+--routine fused_rmsnorm_silu --batch_size 24960 --hidden_size 512 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "fused_rmsnorm_silu_vae_mid"
+--routine fused_rmsnorm_silu --batch_size 99840 --hidden_size 256 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "fused_rmsnorm_silu_vae_large"
+# Non-VAE shapes (fallback heuristics)
+--routine fused_rmsnorm_silu --batch_size 2048 --hidden_size 4096 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "fused_rmsnorm_silu_llama"
+
 ## Quantization (Blackwell SM10.0+ only)
 # MxFP8 Quantization - basic
 --routine mxfp8_quantize --m 1024 --k 4096 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "mxfp8_quantize_basic"
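
To make the units in testFusedRmsnormSilu concrete, here is its bandwidth arithmetic applied to the "fused_rmsnorm_silu_llama" entry above. The 0.02 ms timing is a hypothetical placeholder for illustration, not a measured result.

batch, hidden = 2048, 4096          # the "llama" case above
num_elements = batch * hidden       # 8,388,608
itemsize = 2                        # bf16
problem_bytes = (num_elements * itemsize      # input read
                 + hidden * itemsize          # weight read
                 + num_elements * itemsize)   # output write
# = 33,562,624 bytes (~33.6 MB)
median_time_ms = 0.02               # hypothetical timing
tb_per_sec = problem_bytes / (10**9 * median_time_ms)  # ~1.68 TB/s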
Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+/*
+ * Copyright (c) 2026 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tvm_ffi_utils.h"
+
+void rmsnorm_silu(TensorView output, TensorView input, TensorView weight, double eps,
+                  TensorView workspace, TensorView scale_row_out, int64_t sm_count);
+
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(rmsnorm_silu, rmsnorm_silu);

csrc/rmsnorm_silu.cu
Lines changed: 115 additions & 0 deletions

@@ -0,0 +1,115 @@
+/*
+ * Copyright (c) 2026 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// clang-format off
+// Include order matters: headers → config (defines Ktraits) → kernel (uses Ktraits)
+#include <algorithm>
+#include <flashinfer/norm/ln_silu_headers.cuh>
+#include "rmsnorm_silu_config.inc"
+#include <flashinfer/norm/ln_fwd_silu_kernel.cuh>
+// clang-format on
+
+#include "tvm_ffi_utils.h"
+
+void rmsnorm_silu(TensorView output, TensorView input, TensorView weight, double eps,
+                  TensorView workspace, TensorView scale_row_out, int64_t sm_count) {
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(output);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
+  CHECK_DEVICE(input, weight);
+  CHECK_DIM(2, input);
+  CHECK_DIM(2, output);
+  CHECK_DIM(1, weight);
+
+  int rows = input.size(0);
+  int cols = input.size(1);
+  TVM_FFI_ICHECK_EQ(cols, HIDDEN_SIZE) << "Input cols must match compiled HIDDEN_SIZE";
+  TVM_FFI_ICHECK_EQ(output.size(0), rows);
+
+  ffi::CUDADeviceGuard device_guard(input.device().device_id);
+  const cudaStream_t stream = get_stream(input.device());
+
+  // Grid dimensions (same logic as Sm100RmsNormSiluEngine::execute)
+  int ctas_per_col_max = (rows + WARPS_M - 1) / WARPS_M;
+  int ctas_per_col;
+  if (KERNEL_CFG == 2) {
+    ctas_per_col = ctas_per_col_max;
+  } else {
+    ctas_per_col =
+        std::min(static_cast<int>(sm_count) * DESIRED_OCCUPANCY / CTAS_PER_ROW, ctas_per_col_max);
+  }
+  ctas_per_col = std::max(ctas_per_col, 1);
+
+  dim3 grid(CTAS_PER_ROW * ctas_per_col);
+  dim3 block(WARPS_M * WARPS_N * 32);
+
+  // Pack kernel params
+  PersistentLnFwdParams params{};
+  params.rows = rows;
+  params.cols = cols;
+  params.ctas_per_col = ctas_per_col;
+  params.isRMSNorm = true;
+  params.noScale = false;
+  params.noBias = true;
+  params.isBatchFirst = true;
+  params.batchSize = 1;
+  params.seqLen = rows;
+  params.epsilon = static_cast<float>(eps);
+  params.x = input.data_ptr();
+  params.z = output.data_ptr();
+  params.gamma = weight.data_ptr();
+
+  // Workspace layout (128-byte aligned segments)
+  char* ws_ptr = static_cast<char*>(workspace.data_ptr());
+
+  // [0] rs: rows * sizeof(float)
+  params.rs = ws_ptr;
+  int64_t off = static_cast<int64_t>(rows) * sizeof(float);
+  off = ((off + 127) / 128) * 128;
+
+  // [aligned] fp8_scale: sizeof(float)
+  if (isFP8Out) {
+    params.fp8_out = true;
+    float* default_scale = reinterpret_cast<float*>(ws_ptr + off);
+    // Set scale = 1.0f via cudaMemcpyAsync from host
+    static const float one = 1.0f;
+    cudaMemcpyAsync(default_scale, &one, sizeof(float), cudaMemcpyHostToDevice, stream);
+    params.scale = default_scale;
+  }
+  off += sizeof(float);
+  off = ((off + 127) / 128) * 128;
+
+  // scale_row: passed as separate output tensor (NVFP4 only)
+  if (isFP4Out) {
+    params.scale_row = scale_row_out.data_ptr();
+  }
+
+  // [aligned] cooperative workspace + barriers (multi-CTA only)
+  if (CTAS_PER_ROW > 1) {
+    params.workspace = ws_ptr + off;
+    int64_t coop_ws_size =
+        static_cast<int64_t>(ctas_per_col) * WARPS_M * CTAS_PER_ROW * sizeof(float) * 2 * 2;
+    off += coop_ws_size;
+    off = ((off + 127) / 128) * 128;
+
+    params.barrier = reinterpret_cast<int*>(ws_ptr + off);
+    cudaMemsetAsync(params.barrier, 0, 2 * ctas_per_col * sizeof(int32_t), stream);
+  }
+
+  reduced_divisor divisor(rows);
+
+  ln_fwd_kernel<<<grid, block, 0, stream>>>(params, divisor);
+}
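
The workspace layout is easiest to read as offset arithmetic. Below is a sketch of the minimum byte count the launcher above implies, restating its 128-byte rounding in Python; the helper names are mine and this is not an API introduced by this diff.

def align128(off: int) -> int:
    # Same rounding as ((off + 127) / 128) * 128 in the launcher.
    return ((off + 127) // 128) * 128

def min_workspace_bytes(rows, ctas_per_col, warps_m, ctas_per_row):
    off = align128(rows * 4)   # [0] rs: one float per row
    off = align128(off + 4)    # fp8 default scale: a single float
    if ctas_per_row > 1:       # multi-CTA: cooperative buffers + barriers
        coop = ctas_per_col * warps_m * ctas_per_row * 4 * 2 * 2
        off = align128(off + coop)
        off += 2 * ctas_per_col * 4   # barrier flags, zeroed via cudaMemsetAsync
    return off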

docker/install/install_python_packages.sh
Lines changed: 1 addition & 0 deletions

@@ -32,6 +32,7 @@ pip3 install responses pytest scipy build cuda-python nvshmem4py-cu12
 if [[ "$CUDA_VERSION" == *"cu13"* ]]; then
     pip3 install --upgrade cuda-python==13.0
     pip3 install --upgrade nvidia-cudnn-cu13
+    pip3 install --upgrade "nvidia-cutlass-dsl[cu13]>=4.4.2"
 else
     pip3 install --upgrade cuda-python==12.*
     pip3 install --upgrade nvidia-cudnn-cu12

docs/api/norm.rst
Lines changed: 1 addition & 0 deletions

@@ -17,3 +17,4 @@ Kernels for normalization layers.
    gemma_rmsnorm
    gemma_fused_add_rmsnorm
    layernorm
+   fused_rmsnorm_silu

flashinfer/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -114,6 +114,7 @@
 from .norm import gemma_rmsnorm as gemma_rmsnorm
 from .norm import rmsnorm as rmsnorm
 from .norm import rmsnorm_quant as rmsnorm_quant
+from .norm import fused_rmsnorm_silu as fused_rmsnorm_silu

 try:
     from .norm import rmsnorm_fp4quant as rmsnorm_fp4quant

flashinfer/aot.py
Lines changed: 46 additions & 0 deletions

@@ -88,6 +88,14 @@
 )
 from .jit.mla import gen_mla_module
 from .jit.norm import gen_norm_module
+from .jit.rmsnorm_silu import (
+    gen_rmsnorm_silu_module,
+    select_knobs,
+    _estimate_ctas_per_row,
+    _compute_default_knobs,
+    _SUPPORTED_C,
+    _SUPPORTED_TOKENS,
+)
 from .jit.page import gen_page_module
 from .jit.quantization import gen_quantization_module
 from .jit.rope import gen_rope_module

@@ -558,6 +566,44 @@ def gen_all_modules(
         gen_sampling_module(),
         gen_topk_module(),
     ]
+    # Fused RMSNorm+SiLU: pre-compile all LUT configs (SM100+ only)
+    if has_sm100:
+        for C in _SUPPORTED_C:
+            for tokens in _SUPPORTED_TOKENS:
+                for dtype in ["bf16", "fp8", "nvfp4"]:
+                    knobs = select_knobs(C, tokens, dtype)
+                    if knobs is None:
+                        continue
+                    wm, sc, kcfg, occ, bpl = knobs
+                    cpr = _estimate_ctas_per_row(C, sc, kcfg, bpl)
+                    jit_specs.append(
+                        gen_rmsnorm_silu_module(C, dtype, wm, cpr, bpl, kcfg, occ)
+                    )
+        # Fallback configs for common hidden sizes not in the LUT.
+        # Fallback knobs depend only on (C, dtype), not num_tokens,
+        # so one module per (C, dtype) covers all token counts.
+        _FALLBACK_C = [
+            768,
+            1280,
+            1536,
+            2048,
+            2560,
+            3072,
+            4096,
+            5120,
+            6144,
+            8192,
+        ]
+        for C in _FALLBACK_C:
+            for dtype in ["bf16", "fp8", "nvfp4"]:
+                knobs = _compute_default_knobs(C, dtype)
+                if knobs is None:
+                    continue
+                wm, sc, kcfg, occ, bpl = knobs
+                cpr = _estimate_ctas_per_row(C, sc, kcfg, bpl)
+                jit_specs.append(
+                    gen_rmsnorm_silu_module(C, dtype, wm, cpr, bpl, kcfg, occ)
+                )
     # selective_state_update: one module per dtype combo per GPU arch
     _ssu_dtype_combos = [
         # (state, input, weight, matrixA, stateIndex, state_scale_dtype)
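
The two loops encode a LUT-then-fallback policy: select_knobs answers only for shapes in the sweep LUT, while _compute_default_knobs covers everything else. A sketch of that composition for a single config follows; combining the helpers this way is my reading of the loops above, and the actual runtime dispatch in flashinfer.jit.rmsnorm_silu may differ.

def spec_for(C, tokens, dtype):
    knobs = select_knobs(C, tokens, dtype)        # sweep-tuned LUT entry?
    if knobs is None:
        knobs = _compute_default_knobs(C, dtype)  # fallback heuristics
    if knobs is None:
        return None                               # unsupported (C, dtype)
    wm, sc, kcfg, occ, bpl = knobs
    cpr = _estimate_ctas_per_row(C, sc, kcfg, bpl)
    return gen_rmsnorm_silu_module(C, dtype, wm, cpr, bpl, kcfg, occ)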
