Skip to content

Commit fa81b33

Browse files
committed
Merge remote-tracking branch 'origin/main' into fused-moe-non-gated-fp8
2 parents d5887ab + 273c09c commit fa81b33

21 files changed

Lines changed: 1092 additions & 20 deletions

.claude/skills/add-cuda-kernel/SKILL.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ def gen_scale_module(dtype_in, dtype_out):
234234
- No Jinja template needed for simple operations
235235
- Just copy source files to generation directory
236236
- URI uniquely identifies the module configuration
237+
- **NEVER write to package directories** - see "JIT Directory Rules" in `CLAUDE.md`
237238

238239
### (Optional) Specifying Supported CUDA Architectures
239240

CLAUDE.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,17 @@ FlashInfer's JIT system has three layers:
189189
- `sources`: List of .cu/.cpp files to compile
190190
- `extra_cuda_cflags`, `extra_cflags`, `extra_ldflags`: Compiler flags
191191

192+
### JIT Directory Rules
193+
194+
**NEVER write to package directories** - they may be read-only after installation.
195+
196+
| Directory | Writable | Use for |
197+
|-----------|----------|---------|
198+
| `FLASHINFER_GEN_SRC_DIR` | ✓ Yes | Generated source files (Jinja output, copied .cu files) |
199+
| `FLASHINFER_JIT_DIR` | ✓ Yes | Compiled `.so` outputs |
200+
| `FLASHINFER_CSRC_DIR` | ✗ No | Read-only source templates |
201+
| `FLASHINFER_AOT_DIR` | ✗ No | Read-only pre-compiled binaries |
202+
192203
### Compilation Context: Architecture-Specific Compilation
193204

194205
FlashInfer uses `CompilationContext` to manage CUDA architecture targets. Some kernels only work on specific GPU architectures (e.g., Hopper SM90, Blackwell SM100/SM12x).

benchmarks/bench_topk.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,19 @@ def bench_top_k(
7272
result["torch_us"] = torch_ms * 1e3
7373
result["speedup_vs_torch"] = torch_ms / fi_ms
7474

75+
# SGLang comparison (only supports k=2048 and float32)
76+
if compare_sglang and HAS_SGL_KERNEL and k == 2048 and dtype == torch.float32:
77+
lengths = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
78+
measurements = bench_gpu_time(
79+
lambda: sgl_kernel.fast_topk_v2(scores, lengths, k, row_starts=None),
80+
enable_cupti=True,
81+
dry_run_iters=10,
82+
repeat_iters=100,
83+
)
84+
sg_ms = np.median(measurements)
85+
result["sglang_us"] = sg_ms * 1e3
86+
result["speedup_vs_sglang"] = sg_ms / fi_ms
87+
7588
return result
7689

7790

@@ -282,24 +295,39 @@ def main():
282295
if args.op in ["all", "top_k"]:
283296
print("=" * 100)
284297
print(f"top_k: Basic radix-based top-k selection (dtype={dtype_str})")
298+
if args.compare_sglang:
299+
print("NOTE: SGLang only supports k=2048 and float32")
285300
print("=" * 100)
286-
print(
287-
f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12} {'torch.topk':>12} {'Speedup':>10}"
288-
)
289-
print("-" * 70)
301+
302+
header = f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12} {'torch.topk':>12} {'Speedup':>10}"
303+
if args.compare_sglang:
304+
header += f" {'SGLang':>12} {'Speedup':>10}"
305+
print(header)
306+
print("-" * (70 if not args.compare_sglang else 90))
290307

291308
for batch_size in batch_sizes:
292309
for seq_len in seq_lens:
293310
for k in k_values:
294311
if k > seq_len:
295312
continue
296313
try:
297-
result = bench_top_k(batch_size, seq_len, k, dtype)
298-
print(
314+
result = bench_top_k(
315+
batch_size,
316+
seq_len,
317+
k,
318+
dtype,
319+
compare_sglang=args.compare_sglang,
320+
)
321+
line = (
299322
f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | "
300323
f"{result['flashinfer_us']:>10.2f}us {result['torch_us']:>10.2f}us "
301324
f"{result['speedup_vs_torch']:>9.2f}x"
302325
)
326+
if "sglang_us" in result:
327+
line += f" {result['sglang_us']:>10.2f}us {result['speedup_vs_sglang']:>9.2f}x"
328+
elif args.compare_sglang and k == 2048:
329+
line += " (SGLang error)"
330+
print(line)
303331
except RuntimeError as e:
304332
if "out of memory" in str(e):
305333
print(f"{batch_size:>6} {seq_len:>10} {k:>6} | OOM")

ci/docker-tags.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
flashinfer/flashinfer-ci-cu126: 20260131-a52eff1
2-
flashinfer/flashinfer-ci-cu128: 20260131-a52eff1
3-
flashinfer/flashinfer-ci-cu129: 20260131-a52eff1
4-
flashinfer/flashinfer-ci-cu130: 20260131-a52eff1
1+
flashinfer/flashinfer-ci-cu126: 20260203-9b5901e
2+
flashinfer/flashinfer-ci-cu128: 20260203-9b5901e
3+
flashinfer/flashinfer-ci-cu129: 20260203-9b5901e
4+
flashinfer/flashinfer-ci-cu130: 20260203-9b5901e

csrc/fp4_gemm_cutlass.jinja

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ ct
2626
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 4, 1, _2SM)
2727
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 4, 2, 1, _2SM)
2828
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 4, 4, 1, _2SM)
29+
INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 4, 1, 1, _2SM)
2930

3031
} // namespace gemm
3132
} // namespace flashinfer

csrc/fp4_gemm_cutlass_sm103.cu

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
/*
2+
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include <cuda_fp16.h>
17+
18+
#include <cstddef>
19+
#include <cstdint>
20+
#include <functional>
21+
#include <type_traits>
22+
#include <vector>
23+
24+
#include "flashinfer/gemm/cutlass_gemm_configs.h"
25+
#include "flashinfer/gemm/fp4_gemm_cutlass.h"
26+
#include "flashinfer/gemm/fp4_gemm_cutlass_template_sm103.h"
27+
#include "tvm_ffi_utils.h"
28+
29+
using flashinfer::gemm::ClusterShape;
30+
using flashinfer::gemm::CutlassFp4GemmRunner;
31+
using flashinfer::gemm::CutlassFp4GemmRunnerInterface;
32+
using flashinfer::gemm::CutlassGemmConfig;
33+
using flashinfer::gemm::CutlassTileConfigSM100;
34+
using flashinfer::gemm::EpilogueScheduleType;
35+
using flashinfer::gemm::FP4GemmType;
36+
using flashinfer::gemm::MainloopScheduleType;
37+
38+
namespace flashinfer {
39+
namespace gemm {
40+
template class CutlassFp4GemmRunner<__nv_bfloat16, FP4GemmType::W4A4_NVFP4_NVFP4>;
41+
template class CutlassFp4GemmRunner<half, FP4GemmType::W4A4_NVFP4_NVFP4>;
42+
} // namespace gemm
43+
} // namespace flashinfer
44+
45+
namespace torch_ext {
46+
47+
namespace {
48+
49+
CutlassGemmConfig getFp4GemmConfig(int64_t m, int64_t n, int64_t k, int64_t tactic) {
50+
auto getCutlassFp4GemmConfigs = []() {
51+
CutlassFp4GemmRunner<__nv_bfloat16, FP4GemmType::W4A4_NVFP4_NVFP4> gemmRunner;
52+
return gemmRunner.getConfigs();
53+
};
54+
static std::vector<CutlassGemmConfig> globalConfigs = getCutlassFp4GemmConfigs();
55+
TVM_FFI_ICHECK(tactic >= 0 && tactic < globalConfigs.size())
56+
<< "tactic must be between 0 and " << globalConfigs.size();
57+
return globalConfigs[tactic];
58+
}
59+
60+
template <typename T>
61+
void runGemm(TensorView out, TensorView mat1, TensorView mat2, TensorView mat1Scale,
62+
TensorView mat2Scale, TensorView globalScale, int64_t m, int64_t n, int64_t k,
63+
int64_t batch_count, CutlassGemmConfig const& gemmConfig,
64+
TensorView workspace_buffer) {
65+
CutlassFp4GemmRunner<T, FP4GemmType::W4A4_NVFP4_NVFP4> gemmRunner;
66+
67+
int64_t const required_workspace_size = gemmRunner.getWorkspaceSize(m, n, k, batch_count);
68+
int64_t const provided_workspace_size =
69+
workspace_buffer.numel() * get_element_size(workspace_buffer);
70+
71+
auto runKernel = [&](void* workspace) {
72+
gemmRunner.gemm(out.data_ptr(), mat1.data_ptr(), mat2.data_ptr(), mat1Scale.data_ptr(),
73+
mat2Scale.data_ptr(), static_cast<float*>(globalScale.data_ptr()), m, n, k,
74+
batch_count, gemmConfig, reinterpret_cast<char*>(workspace),
75+
required_workspace_size, get_stream(mat1.device()));
76+
};
77+
78+
if (provided_workspace_size < required_workspace_size) {
79+
Tensor new_workspace =
80+
alloc_tensor({required_workspace_size}, DLDataType{kDLInt, 8, 1}, mat1.device());
81+
runKernel(new_workspace.data_ptr());
82+
} else {
83+
runKernel(workspace_buffer.data_ptr());
84+
}
85+
}
86+
87+
constexpr auto FLOAT4_E2M1X2 = dl_uint8; // uint8_t
88+
constexpr auto SF_DTYPE = dl_uint8; // uint8_t
89+
90+
// mat1: [B, M, K / 2], FLOAT4_E2M1X2 or [B, M, K], FLOAT8_E4M3FN
91+
// mat2: [B, N, K / 2], FLOAT4_E2M1X2
92+
// out: [B, M, N], fp16/bf16/fp32
93+
// mat1Scale: ceil(M / 128) * 128 * ceil(K / sfVecSize / 4) * 4, SF_DTYPE (UE4M3 or UE8M0)
94+
// mat2Scale: ceil(N / 128) * 128 * ceil(K / sfVecSize / 4) * 4, SF_DTYPE (UE4M3 or UE8M0)
95+
// globalScale: [1], 1 / (((448 * 6) / mat1.abs().max()) * ((448 * 6) / mat2.abs().max()))
96+
// B = 1 for GEMM op as a special case
97+
void fp4_bmm_impl(TensorView mat1, TensorView mat2, TensorView mat1Scale, TensorView mat2Scale,
98+
TensorView globalScale, TensorView out, TensorView workspace_buffer,
99+
int64_t tactic) {
100+
CHECK_INPUT_AND_TYPE(mat1, FLOAT4_E2M1X2);
101+
CHECK_INPUT_AND_TYPE(mat2, FLOAT4_E2M1X2);
102+
103+
int mat2_k_scale = 1;
104+
105+
CHECK_INPUT_AND_TYPE(mat1Scale, SF_DTYPE);
106+
CHECK_INPUT_AND_TYPE(mat2Scale, SF_DTYPE);
107+
108+
CHECK_INPUT_AND_TYPE(globalScale, dl_float32);
109+
110+
int64_t m, n, k, b;
111+
if (mat1.ndim() == 2) {
112+
TVM_FFI_ICHECK_EQ(mat2.ndim(), 2) << "mat2 must be a matrix";
113+
TVM_FFI_ICHECK_EQ(mat1.size(1), mat2.size(1) * mat2_k_scale)
114+
<< "mat1 and mat2 shapes cannot be multiplied (" << mat1.size(0) << "x" << mat1.size(1)
115+
<< " and " << mat2.size(0) << "x" << mat2.size(1) << ")";
116+
m = mat1.size(0);
117+
n = mat2.size(0);
118+
k = mat2.size(1) * 2;
119+
b = 1;
120+
} else if (mat1.ndim() == 3) {
121+
TVM_FFI_ICHECK_EQ(mat2.ndim(), 3) << "mat2 must be a batch of matrices";
122+
TVM_FFI_ICHECK_EQ(mat1.size(0), mat2.size(0)) << "mat1 and mat2 must have the same batch size ("
123+
<< mat1.size(0) << " and " << mat2.size(0) << ")";
124+
TVM_FFI_ICHECK_EQ(mat1.size(2), mat2.size(2) * mat2_k_scale)
125+
<< "mat1 and mat2 shapes cannot be multiplied (" << mat1.size(1) << "x" << mat1.size(2)
126+
<< " and " << mat2.size(1) << "x" << mat2.size(2) << ")";
127+
m = mat1.size(1);
128+
n = mat2.size(1);
129+
k = mat2.size(2) * 2;
130+
b = mat1.size(0);
131+
} else {
132+
TVM_FFI_LOG_AND_THROW(NotImplementedError) << "mat1 must be a matrix or a batch of matrices";
133+
}
134+
135+
// No heuristic for now, we rely on the autotuner to select the best tactic.
136+
if (tactic == -1) {
137+
tactic = 0;
138+
}
139+
auto config = getFp4GemmConfig(m, n, k, tactic);
140+
141+
constexpr int alignment = 32;
142+
TVM_FFI_ICHECK_EQ(k % alignment, 0)
143+
<< "Expected k to be divisible by " << alignment << ", but got mat1 shape: (" << mat1.size(0)
144+
<< "x" << mat1.size(1) << "), k: " << k << ".";
145+
TVM_FFI_ICHECK_EQ(n % alignment, 0)
146+
<< "Expected n to be divisible by " << alignment << ", but got mat2 shape: (" << mat2.size(0)
147+
<< "x" << mat2.size(1) << ").";
148+
149+
// Validate out dimensions
150+
std::vector<int64_t> out_shape =
151+
mat1.ndim() == 2 ? std::vector<int64_t>{m, n} : std::vector<int64_t>{b, m, n};
152+
TVM_FFI_ICHECK_EQ(out.ndim(), out_shape.size())
153+
<< "out must have " << out_shape.size() << " dimensions, but got " << out.ndim();
154+
for (int i = 0; i < out_shape.size(); ++i) {
155+
TVM_FFI_ICHECK_EQ(out.size(i), out_shape[i])
156+
<< "out shape mismatch at dimension " << i << ": expected " << out_shape[i] << ", got "
157+
<< out.size(i);
158+
}
159+
160+
switch (encode_dlpack_dtype(out.dtype())) {
161+
case float16_code:
162+
runGemm<half>(out, mat1, mat2, mat1Scale, mat2Scale, globalScale, m, n, k, b, config,
163+
workspace_buffer);
164+
break;
165+
case bfloat16_code:
166+
runGemm<__nv_bfloat16>(out, mat1, mat2, mat1Scale, mat2Scale, globalScale, m, n, k, b, config,
167+
workspace_buffer);
168+
break;
169+
default:
170+
TVM_FFI_ICHECK(false) << "out_dtype must be one of fp16/bf16.";
171+
}
172+
}
173+
174+
} // namespace
175+
176+
void fp4_gemm(TensorView mat1, TensorView mat2, TensorView mat1Scale, TensorView mat2Scale,
177+
TensorView globalScale, TensorView out, TensorView workspace_buffer, int64_t tactic) {
178+
fp4_bmm_impl(mat1, mat2, mat1Scale, mat2Scale, globalScale, out, workspace_buffer, tactic);
179+
}
180+
181+
int64_t fp4_gemm_tactic_num() {
182+
auto getCutlassConfigs = []() {
183+
CutlassFp4GemmRunner<__nv_bfloat16, FP4GemmType::W4A4_NVFP4_NVFP4> gemmRunner;
184+
return gemmRunner.getConfigs();
185+
};
186+
static int64_t totalTactics = getCutlassConfigs().size();
187+
return totalTactics;
188+
}
189+
190+
} // namespace torch_ext
191+
192+
TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp4_gemm, torch_ext::fp4_gemm);
193+
TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp4_gemm_tactic_num, torch_ext::fp4_gemm_tactic_num);

csrc/fp4_gemm_cutlass_sm103.jinja

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "flashinfer/gemm/fp4_gemm_cutlass_template_sm103.h"
18+
19+
namespace flashinfer {
20+
namespace gemm {
21+
INSTANTIATE_FP4_ULTRA_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 1, 1, _1SM_sm103)
22+
INSTANTIATE_FP4_ULTRA_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 2, 1, _1SM_sm103)
23+
INSTANTIATE_FP4_ULTRA_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 4, 1, _1SM_sm103)
24+
INSTANTIATE_FP4_ULTRA_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 1, 1, _2SM_sm103)
25+
INSTANTIATE_FP4_ULTRA_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 2, 1, _2SM_sm103)
26+
INSTANTIATE_FP4_ULTRA_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 4, 1, _2SM_sm103)
27+
INSTANTIATE_FP4_ULTRA_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 4, 2, 1, _2SM_sm103)
28+
INSTANTIATE_FP4_ULTRA_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 4, 4, 1, _2SM_sm103)
29+
INSTANTIATE_FP4_ULTRA_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 4, 1, 1, _2SM_sm103)
30+
31+
} // namespace gemm
32+
} // namespace flashinfer

docker/Dockerfile.cu126

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ RUN echo "source activate py312" >> ~/.bashrc
1919
ENV PATH="/opt/conda/bin:$PATH"
2020
ENV PATH="/opt/conda/envs/py312/bin:$PATH"
2121

22+
# Ensure pip-installed nvidia-cublas takes precedence over system libraries
23+
ENV LD_LIBRARY_PATH="/opt/conda/envs/py312/lib/python3.12/site-packages/nvidia/cublas/lib/:$LD_LIBRARY_PATH"
24+
2225
# Install torch and other python packages
2326
COPY requirements.txt /install/requirements.txt
2427
COPY docker/install/install_python_packages.sh /install/install_python_packages.sh

docker/Dockerfile.cu128

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ RUN echo "source activate py312" >> ~/.bashrc
1919
ENV PATH="/opt/conda/bin:$PATH"
2020
ENV PATH="/opt/conda/envs/py312/bin:$PATH"
2121

22+
# Ensure pip-installed nvidia-cublas takes precedence over system libraries
23+
ENV LD_LIBRARY_PATH="/opt/conda/envs/py312/lib/python3.12/site-packages/nvidia/cublas/lib/:$LD_LIBRARY_PATH"
24+
2225
# Install torch and other python packages
2326
COPY requirements.txt /install/requirements.txt
2427
COPY docker/install/install_python_packages.sh /install/install_python_packages.sh

docker/Dockerfile.cu129

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ RUN echo "source activate py312" >> ~/.bashrc
1919
ENV PATH="/opt/conda/bin:$PATH"
2020
ENV PATH="/opt/conda/envs/py312/bin:$PATH"
2121

22+
# Ensure pip-installed nvidia-cublas takes precedence over system libraries
23+
ENV LD_LIBRARY_PATH="/opt/conda/envs/py312/lib/python3.12/site-packages/nvidia/cublas/lib/:$LD_LIBRARY_PATH"
24+
2225
# Triton
2326
ENV TRITON_PTXAS_PATH="/usr/local/cuda/bin/ptxas"
2427

0 commit comments

Comments
 (0)