Skip to content

Commit 1a4d4b6

Browse files
yzh119, cyx-6, and bkryu
authored and committed
refactor: refactoring cuda code to cute-dsl (part 1) (flashinfer-ai#2428)
<!-- .github/pull_request_template.md --> ## πŸ“Œ Description We prioritize using dsl for kernel development over cuda for faster JIT compilation speed. This PR is the first series that refactors the simple normalization kernels to cute-dsl. CUDA code should be ready to remove after we finish end-to-end testing. ## πŸ” Related Issues <!-- Link any related issues here --> ## πŸš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### βœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## πŸ§ͺ Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * CuTe-DSL–accelerated normalization: RMSNorm (2D/3D), LayerNorm, fused add+RMSNorm, and FP8-quantized variants exposed for runtime use. * Shared norm utilities and JIT warmup to improve kernel readiness. * **Chores** * Runtime selection and fallback for CuTe-DSL/CUDA normalization with a visibility check. * **Bug Fixes** * Safer optional-dependency handling to avoid hard failures when CUDA/CuTe-DSL is unavailable. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Yaxing Cai <caiyaxing666@gmail.com> Co-authored-by: Brian Ryu <bryu@nvidia.com> Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
1 parent 20b3e76 commit 1a4d4b6

15 files changed

Lines changed: 2581 additions & 70 deletions

File tree

csrc/flashinfer_norm_binding.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717

1818
void rmsnorm(TensorView out, TensorView input, TensorView weight, double eps, bool enable_pdl);
1919

20-
void rmsnorm_quant(TensorView out, TensorView input, TensorView weight, double scale, double eps,
21-
bool enable_pdl);
20+
void rmsnorm_quant(TensorView out, TensorView input, TensorView weight, TensorView scale,
21+
double eps, bool enable_pdl);
2222

2323
void fused_add_rmsnorm(TensorView input, TensorView residual, TensorView weight, double eps,
2424
bool enable_pdl);
2525

2626
void fused_add_rmsnorm_quant(TensorView output, TensorView input, TensorView residual,
27-
TensorView weight, double scale, double eps, bool enable_pdl);
27+
TensorView weight, TensorView scale, double eps, bool enable_pdl);
2828

2929
void gemma_rmsnorm(TensorView out, TensorView input, TensorView weight, double eps,
3030
bool enable_pdl);

csrc/norm.cu

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,15 @@ void rmsnorm(TensorView output, TensorView input, TensorView weight, double eps,
7777
}
7878
}
7979

80-
void rmsnorm_quant(TensorView output, TensorView input, TensorView weight, double scale, double eps,
81-
bool enable_pdl) {
80+
void rmsnorm_quant(TensorView output, TensorView input, TensorView weight, TensorView scale,
81+
double eps, bool enable_pdl) {
8282
CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
8383
CHECK_LAST_DIM_CONTIGUOUS_INPUT(output);
8484
CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
8585
CHECK_DEVICE(input, weight);
86+
CHECK_DEVICE(input, scale);
8687
CHECK_DIM(1, weight); // weight: (hidden_size)
88+
TVM_FFI_ICHECK_EQ(scale.numel(), 1);
8789

8890
auto input_ndim = input.ndim();
8991
if (input_ndim == 2) {
@@ -103,7 +105,7 @@ void rmsnorm_quant(TensorView output, TensorView input, TensorView weight, doubl
103105
cudaError_t status = norm::RMSNormQuant(
104106
static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(weight.data_ptr()),
105107
static_cast<o_type*>(output.data_ptr()), batch_size, hidden_size, input.stride(0),
106-
output.stride(0), static_cast<float>(scale), eps, enable_pdl, stream);
108+
output.stride(0), static_cast<float*>(scale.data_ptr()), eps, enable_pdl, stream);
107109
TVM_FFI_ICHECK(status == cudaSuccess)
108110
<< "RMSNormQuant failed with error code " << cudaGetErrorString(status);
109111
return true;
@@ -145,14 +147,15 @@ void fused_add_rmsnorm(TensorView input, TensorView residual, TensorView weight,
145147
}
146148

147149
void fused_add_rmsnorm_quant(TensorView output, TensorView input, TensorView residual,
148-
TensorView weight, double scale, double eps, bool enable_pdl) {
150+
TensorView weight, TensorView scale, double eps, bool enable_pdl) {
149151
CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
150152
CHECK_LAST_DIM_CONTIGUOUS_INPUT(residual);
151153
CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
152154
CHECK_LAST_DIM_CONTIGUOUS_INPUT(output);
153155
CHECK_DEVICE(input, residual);
154156
CHECK_DEVICE(input, weight);
155157
CHECK_DEVICE(input, output);
158+
CHECK_DEVICE(input, scale);
156159
CHECK_DIM(2, input); // input: (batch_size, hidden_size)
157160
CHECK_DIM(2, residual); // residual: (batch_size, hidden_size)
158161
CHECK_DIM(1, weight); // weight: (hidden_size)
@@ -162,6 +165,7 @@ void fused_add_rmsnorm_quant(TensorView output, TensorView input, TensorView res
162165
TVM_FFI_ICHECK_EQ(residual.size(0), batch_size);
163166
TVM_FFI_ICHECK_EQ(residual.size(1), hidden_size);
164167
TVM_FFI_ICHECK_EQ(weight.size(0), hidden_size);
168+
TVM_FFI_ICHECK_EQ(scale.numel(), 1);
165169
ffi::CUDADeviceGuard device_guard(input.device().device_id);
166170
const cudaStream_t stream = get_stream(input.device());
167171

@@ -170,8 +174,8 @@ void fused_add_rmsnorm_quant(TensorView output, TensorView input, TensorView res
170174
cudaError_t status = norm::FusedAddRMSNormQuant(
171175
static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
172176
static_cast<c_type*>(weight.data_ptr()), static_cast<o_type*>(output.data_ptr()),
173-
batch_size, hidden_size, input.stride(0), residual.stride(0), output.stride(0), scale,
174-
eps, enable_pdl, stream);
177+
batch_size, hidden_size, input.stride(0), residual.stride(0), output.stride(0),
178+
static_cast<float*>(scale.data_ptr()), eps, enable_pdl, stream);
175179

176180
TVM_FFI_ICHECK(status == cudaSuccess)
177181
<< "FusedAddRMSNormQuant failed with error code " << cudaGetErrorString(status);

docs/api/norm.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ Kernels for normalization layers.
1111
:toctree: ../generated
1212

1313
rmsnorm
14+
rmsnorm_quant
1415
fused_add_rmsnorm
16+
fused_add_rmsnorm_quant
1517
gemma_rmsnorm
1618
gemma_fused_add_rmsnorm
1719
layernorm

flashinfer/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,11 @@
112112
from .norm import gemma_rmsnorm as gemma_rmsnorm
113113
from .norm import rmsnorm as rmsnorm
114114

115-
from .norm import rmsnorm_fp4quant as rmsnorm_fp4quant
116-
from .norm import add_rmsnorm_fp4quant as add_rmsnorm_fp4quant
115+
try:
116+
from .norm import rmsnorm_fp4quant as rmsnorm_fp4quant
117+
from .norm import add_rmsnorm_fp4quant as add_rmsnorm_fp4quant
118+
except (ImportError, AttributeError):
119+
pass # nvidia-cutlass-dsl not installed
117120
from .page import append_paged_kv_cache as append_paged_kv_cache
118121
from .page import append_paged_mla_kv_cache as append_paged_mla_kv_cache
119122
from .page import get_batch_indices_positions as get_batch_indices_positions

flashinfer/cute_dsl/__init__.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,24 @@
5454
AddRMSNormFP4QuantKernel,
5555
)
5656

57+
# Backwards-compatible re-exports from flashinfer.norm.kernels submodule
58+
from ..norm.kernels import (
59+
# Kernel classes
60+
RMSNormKernel,
61+
QKRMSNormKernel,
62+
RMSNormQuantKernel,
63+
FusedAddRMSNormKernel,
64+
FusedAddRMSNormQuantKernel,
65+
LayerNormKernel,
66+
# Python API functions
67+
rmsnorm_cute,
68+
qk_rmsnorm_cute,
69+
rmsnorm_quant_cute,
70+
fused_add_rmsnorm_cute,
71+
fused_add_rmsnorm_quant_cute,
72+
layernorm_cute,
73+
)
74+
5775
__all__ = [
5876
# Utils (always available)
5977
"is_cute_dsl_available",
@@ -79,4 +97,17 @@
7997
# Add + RMSNorm + FP4 Quantization
8098
"add_rmsnorm_fp4quant",
8199
"AddRMSNormFP4QuantKernel",
100+
# Norm kernels (CuTe DSL) - backwards-compatible re-exports
101+
"RMSNormKernel",
102+
"QKRMSNormKernel",
103+
"RMSNormQuantKernel",
104+
"FusedAddRMSNormKernel",
105+
"FusedAddRMSNormQuantKernel",
106+
"LayerNormKernel",
107+
"rmsnorm_cute",
108+
"qk_rmsnorm_cute",
109+
"rmsnorm_quant_cute",
110+
"fused_add_rmsnorm_cute",
111+
"fused_add_rmsnorm_quant_cute",
112+
"layernorm_cute",
82113
]

flashinfer/cute_dsl/add_rmsnorm_fp4quant.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,8 +1012,8 @@ def tensor_api(
10121012
s_tensor,
10131013
s_unswizzled.contiguous(),
10141014
global_scale,
1015-
Int32(M),
1016-
Float32(eps),
1015+
M,
1016+
eps,
10171017
)
10181018

10191019
return tensor_api

flashinfer/cute_dsl/rmsnorm_fp4quant.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -750,8 +750,8 @@ def tensor_api(
750750
y_uint8,
751751
s_tensor,
752752
global_scale,
753-
Int32(M),
754-
Float32(eps),
753+
M,
754+
eps,
755755
)
756756

757757
return tensor_api

0 commit comments

Comments (0)