Skip to content

Commit 17103dd

Browse files
committed
Rename fused_qk_norm_rope to fused_qk_rmsnorm_rope
The kernel performs RMSNorm specifically (not LayerNorm or a generic norm), so it is renamed to fused_qk_rmsnorm_rope for consistency with FlashInfer's existing naming convention (rmsnorm, fused_add_rmsnorm, gemma_rmsnorm, fused_rmsnorm_silu). All files, imports, symbols, and docstrings are updated accordingly. Internal kernel function names (fusedQKNormRopeKernel, launchFusedQKNormRope) are kept as-is, since they are not part of the public API. Test results after the rename: 25 passed, 1 xfailed. AI-assisted. Made-with: Cursor
1 parent 7580c79 commit 17103dd

8 files changed

Lines changed: 42 additions & 42 deletions

File tree

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
"""
2-
Benchmark for fused QKNorm + 3D RoPE kernel vs eager PyTorch baseline.
2+
Benchmark for fused QK RMSNorm + 3D RoPE kernel vs eager PyTorch baseline.
33
44
Measures performance across WAN model shapes and compares:
55
- Eager: separate nn.RMSNorm + manual interleaved RoPE in PyTorch
6-
- Fused: flashinfer.diffusion_ops.fused_qk_norm_rope (single kernel)
6+
- Fused: flashinfer.diffusion_ops.fused_qk_rmsnorm_rope (single kernel)
77
88
Usage:
9-
python benchmarks/bench_fused_qk_norm_rope.py
10-
python benchmarks/bench_fused_qk_norm_rope.py --gpu 2 # run on specific GPU
9+
python benchmarks/bench_fused_qk_rmsnorm_rope.py
10+
python benchmarks/bench_fused_qk_rmsnorm_rope.py --gpu 2 # run on specific GPU
1111
"""
1212

1313
import argparse
@@ -17,7 +17,7 @@
1717
import torch.nn as nn
1818

1919
from flashinfer.testing.utils import bench_gpu_time
20-
from flashinfer.diffusion_ops import fused_qk_norm_rope
20+
from flashinfer.diffusion_ops import fused_qk_rmsnorm_rope
2121

2222

2323
def compute_rope_dims(head_dim):
@@ -133,7 +133,7 @@ def eager_fn():
133133
return q_out, k_out, v_heads
134134

135135
def fused_fn():
136-
return fused_qk_norm_rope(
136+
return fused_qk_rmsnorm_rope(
137137
qkv_combined,
138138
q_weight,
139139
k_weight,
@@ -166,7 +166,7 @@ def fused_fn():
166166

167167

168168
def main():
169-
parser = argparse.ArgumentParser(description="Benchmark fused QKNorm + 3D RoPE")
169+
parser = argparse.ArgumentParser(description="Benchmark fused QK RMSNorm + 3D RoPE")
170170
parser.add_argument("--gpu", type=int, default=0, help="GPU device index")
171171
args = parser.parse_args()
172172

csrc/flashinfer_norm_binding.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ void gemma_fused_add_rmsnorm(TensorView input, TensorView residual, TensorView w
3434

3535
void layernorm(Tensor out, Tensor input, Tensor gamma, Tensor beta, double eps);
3636

37-
void fused_qk_norm_rope_run(TensorView qkv_in, TensorView q_weight, TensorView k_weight,
37+
void fused_qk_rmsnorm_rope_run(TensorView qkv_in, TensorView q_weight, TensorView k_weight,
3838
TensorView q_out, TensorView k_out, TensorView v_out,
3939
int64_t num_tokens, int64_t seq_len, int64_t ppf, int64_t pph,
4040
int64_t ppw, int64_t num_frame_channels, int64_t num_height_channels,
@@ -51,4 +51,4 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(fused_add_rmsnorm_quant, fused_add_rmsnorm_quant);
5151
TVM_FFI_DLL_EXPORT_TYPED_FUNC(gemma_rmsnorm, gemma_rmsnorm);
5252
TVM_FFI_DLL_EXPORT_TYPED_FUNC(gemma_fused_add_rmsnorm, gemma_fused_add_rmsnorm);
5353
TVM_FFI_DLL_EXPORT_TYPED_FUNC(layernorm, layernorm);
54-
TVM_FFI_DLL_EXPORT_TYPED_FUNC(fused_qk_norm_rope, fused_qk_norm_rope_run);
54+
TVM_FFI_DLL_EXPORT_TYPED_FUNC(fused_qk_rmsnorm_rope, fused_qk_rmsnorm_rope_run);

csrc/norm.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
#include <flashinfer/fused_qk_norm_rope.cuh>
16+
#include <flashinfer/fused_qk_rmsnorm_rope.cuh>
1717
#include <flashinfer/norm.cuh>
1818

1919
#include "tvm_ffi_utils.h"
@@ -273,7 +273,7 @@ void layernorm(Tensor output, Tensor input, Tensor gamma, Tensor beta, double ep
273273
});
274274
}
275275

276-
void fused_qk_norm_rope_run(TensorView qkv_in, TensorView q_weight, TensorView k_weight,
276+
void fused_qk_rmsnorm_rope_run(TensorView qkv_in, TensorView q_weight, TensorView k_weight,
277277
TensorView q_out, TensorView k_out, TensorView v_out,
278278
int64_t num_tokens, int64_t seq_len, int64_t ppf, int64_t pph,
279279
int64_t ppw, int64_t num_frame_channels, int64_t num_height_channels,
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from flashinfer.norm import fused_qk_norm_rope
1+
from flashinfer.norm import fused_qk_rmsnorm_rope
22

33
__all__ = [
4-
"fused_qk_norm_rope",
4+
"fused_qk_rmsnorm_rope",
55
]

flashinfer/norm/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,7 @@ def fused_rmsnorm_silu(
761761
return out
762762

763763

764-
from .fused_qk_norm_rope import fused_qk_norm_rope as fused_qk_norm_rope
764+
from .fused_qk_rmsnorm_rope import fused_qk_rmsnorm_rope as fused_qk_rmsnorm_rope
765765

766766
# Public API exports
767767
__all__ = [
@@ -776,5 +776,5 @@ def fused_rmsnorm_silu(
776776
"gemma_fused_add_rmsnorm",
777777
"layernorm",
778778
"fused_rmsnorm_silu",
779-
"fused_qk_norm_rope",
779+
"fused_qk_rmsnorm_rope",
780780
]
Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
16-
Fused QKNorm + 3D RoPE for Video Generation DIT Self-Attention
16+
Fused QK RMSNorm + 3D RoPE for Video Generation DIT Self-Attention
1717
===============================================================
1818
1919
Fuses across-heads RMSNorm on Q and K, 3D rotary position embeddings
@@ -34,13 +34,13 @@
3434

3535

3636
@supported_compute_capability([80, 86, 89, 90, 100, 103, 110, 120, 121])
37-
def _check_fused_qk_norm_rope(
37+
def _check_fused_qk_rmsnorm_rope(
3838
qkv,
3939
q_weight,
4040
k_weight,
4141
**kwargs,
4242
):
43-
"""Validate inputs for fused QKNorm + 3D RoPE.
43+
"""Validate inputs for fused QK RMSNorm + 3D RoPE.
4444
4545
Architecture notes:
4646
- SM80+ (Ampere): Full support for BF16 path; FP8 output uses software emulation
@@ -119,8 +119,8 @@ def _check_fused_qk_norm_rope(
119119

120120

121121
@flashinfer_api
122-
@backend_requirement(backend_checks={}, common_check=_check_fused_qk_norm_rope)
123-
def fused_qk_norm_rope(
122+
@backend_requirement(backend_checks={}, common_check=_check_fused_qk_rmsnorm_rope)
123+
def fused_qk_rmsnorm_rope(
124124
qkv: torch.Tensor,
125125
q_weight: torch.Tensor,
126126
k_weight: torch.Tensor,
@@ -150,7 +150,7 @@ def fused_qk_norm_rope(
150150
k_out: Optional[torch.Tensor] = None,
151151
v_out: Optional[torch.Tensor] = None,
152152
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
153-
r"""Fused QKNorm + 3D RoPE + V copy for video generation DIT self-attention.
153+
r"""Fused QK RMSNorm + 3D RoPE + V copy for video generation DIT self-attention.
154154
155155
Applies across-heads RMSNorm to Q and K, then rotary position embeddings
156156
with 3D spatial decomposition (frame/height/width), and copies V to a
@@ -255,7 +255,7 @@ def fused_qk_norm_rope(
255255
k_out_flat = k_out.view(num_tokens, -1)
256256
v_out_flat = v_out.view(num_tokens, -1)
257257

258-
get_norm_module().fused_qk_norm_rope(
258+
get_norm_module().fused_qk_rmsnorm_rope(
259259
qkv_flat,
260260
q_weight,
261261
k_weight,

include/flashinfer/fused_qk_norm_rope.cuh renamed to include/flashinfer/fused_qk_rmsnorm_rope.cuh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
#ifndef FLASHINFER_FUSED_QK_NORM_ROPE_CUH_
17-
#define FLASHINFER_FUSED_QK_NORM_ROPE_CUH_
16+
#ifndef FLASHINFER_FUSED_QK_RMSNORM_ROPE_CUH_
17+
#define FLASHINFER_FUSED_QK_RMSNORM_ROPE_CUH_
1818

1919
#include <cuda_bf16.h>
2020
#include <cuda_fp8.h>
@@ -303,7 +303,7 @@ __device__ __forceinline__ void quantize_store_fp8(float2 const* elements, __nv_
303303
}
304304

305305
////////////////////////////////////////////////////////////////////////////////////////////////////
306-
// Section 6: Fused QKNorm + RoPE kernel
306+
// Section 6: Fused QK RMSNorm + RoPE kernel
307307
//
308308
// Performs across-heads RMSNorm and 3D RoPE in a single kernel (for self-attention).
309309
// Also copies V to a separate contiguous output buffer with optional FP8 quantization.
@@ -751,4 +751,4 @@ inline void launchFusedQKNormRope(void const* qkv_in, void* q_out, void* k_out,
751751

752752
} // namespace flashinfer
753753

754-
#endif // FLASHINFER_FUSED_QK_NORM_ROPE_CUH_
754+
#endif // FLASHINFER_FUSED_QK_RMSNORM_ROPE_CUH_
Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Tests for fused QKNorm + 3D RoPE kernel.
2+
Tests for fused QK RMSNorm + 3D RoPE kernel.
33
44
Tests correctness against a PyTorch reference implementation that matches
55
the WAN 2.2 model.py:
@@ -15,7 +15,7 @@
1515
import torch
1616
import torch.nn as nn
1717

18-
from flashinfer.diffusion_ops import fused_qk_norm_rope
18+
from flashinfer.diffusion_ops import fused_qk_rmsnorm_rope
1919

2020

2121
# ---------------------------------------------------------------------------
@@ -322,7 +322,7 @@ def test_interleaved_correctness(batch_size, ppf, pph, ppw):
322322
)
323323

324324
qkv_combined = torch.cat([query, key, value], dim=-1).contiguous()
325-
q_fused, k_fused, v_fused = fused_qk_norm_rope(
325+
q_fused, k_fused, v_fused = fused_qk_rmsnorm_rope(
326326
qkv_combined,
327327
norm_q.weight.contiguous(),
328328
norm_k.weight.contiguous(),
@@ -401,7 +401,7 @@ def test_neox_correctness(batch_size, ppf, pph, ppw):
401401
)
402402

403403
qkv_combined = torch.cat([query, key, value], dim=-1).contiguous()
404-
q_fused, k_fused, v_fused = fused_qk_norm_rope(
404+
q_fused, k_fused, v_fused = fused_qk_rmsnorm_rope(
405405
qkv_combined,
406406
norm_q.weight.contiguous(),
407407
norm_k.weight.contiguous(),
@@ -458,7 +458,7 @@ def test_v_passthrough():
458458
q_weight = torch.ones(hidden_dim, device=device, dtype=dtype)
459459
k_weight = torch.ones(hidden_dim, device=device, dtype=dtype)
460460

461-
_, _, v_fused = fused_qk_norm_rope(
461+
_, _, v_fused = fused_qk_rmsnorm_rope(
462462
qkv_combined,
463463
q_weight,
464464
k_weight,
@@ -509,7 +509,7 @@ def test_destination_passing():
509509
batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype
510510
)
511511

512-
q_ret, k_ret, v_ret = fused_qk_norm_rope(
512+
q_ret, k_ret, v_ret = fused_qk_rmsnorm_rope(
513513
qkv,
514514
torch.ones(hidden_dim, device=device, dtype=dtype),
515515
torch.ones(hidden_dim, device=device, dtype=dtype),
@@ -581,8 +581,8 @@ def test_2d_input():
581581
q_weight = torch.ones(hidden_dim, device=device, dtype=dtype)
582582
k_weight = torch.ones(hidden_dim, device=device, dtype=dtype)
583583

584-
q_3d, k_3d, v_3d = fused_qk_norm_rope(qkv_3d, q_weight, k_weight, **kwargs)
585-
q_2d, k_2d, v_2d = fused_qk_norm_rope(qkv_2d, q_weight, k_weight, **kwargs)
584+
q_3d, k_3d, v_3d = fused_qk_rmsnorm_rope(qkv_3d, q_weight, k_weight, **kwargs)
585+
q_2d, k_2d, v_2d = fused_qk_rmsnorm_rope(qkv_2d, q_weight, k_weight, **kwargs)
586586

587587
assert q_3d.ndim == 4, f"3D input should give 4D output, got {q_3d.ndim}D"
588588
assert q_2d.ndim == 3, f"2D input should give 3D output, got {q_2d.ndim}D"
@@ -626,7 +626,7 @@ def test_fp8_output(output_scale):
626626

627627
qkv_combined = torch.cat([query, key, value], dim=-1).contiguous()
628628

629-
q_fp8, k_fp8, v_fp8 = fused_qk_norm_rope(
629+
q_fp8, k_fp8, v_fp8 = fused_qk_rmsnorm_rope(
630630
qkv_combined,
631631
norm_q.weight.contiguous(),
632632
norm_k.weight.contiguous(),
@@ -713,7 +713,7 @@ def test_rope_only_no_norm():
713713
q_weight = torch.ones(hidden_dim, device=device, dtype=dtype)
714714
k_weight = torch.ones(hidden_dim, device=device, dtype=dtype)
715715

716-
q_fused, k_fused, _ = fused_qk_norm_rope(
716+
q_fused, k_fused, _ = fused_qk_rmsnorm_rope(
717717
qkv_combined,
718718
q_weight,
719719
k_weight,
@@ -811,7 +811,7 @@ def test_multi_config(config_name):
811811
)
812812

813813
qkv_combined = torch.cat([query, key, value], dim=-1).contiguous()
814-
q_fused, k_fused, _ = fused_qk_norm_rope(
814+
q_fused, k_fused, _ = fused_qk_rmsnorm_rope(
815815
qkv_combined,
816816
norm_q.weight.contiguous(),
817817
norm_k.weight.contiguous(),
@@ -847,7 +847,7 @@ def test_error_non_cuda():
847847
qkv = torch.randn(1, 120, 3 * 3072, dtype=torch.bfloat16)
848848
w = torch.ones(3072, dtype=torch.bfloat16)
849849
with pytest.raises((ValueError, RuntimeError)):
850-
fused_qk_norm_rope(
850+
fused_qk_rmsnorm_rope(
851851
qkv,
852852
w,
853853
w,
@@ -869,7 +869,7 @@ def test_error_wrong_dtype():
869869
qkv = torch.randn(1, 120, 3 * 3072, dtype=torch.float16, device=device)
870870
w = torch.ones(3072, dtype=torch.bfloat16, device=device)
871871
with pytest.raises((ValueError, RuntimeError)):
872-
fused_qk_norm_rope(
872+
fused_qk_rmsnorm_rope(
873873
qkv,
874874
w,
875875
w,
@@ -893,7 +893,7 @@ def test_error_bad_head_dim():
893893
qkv = torch.randn(1, 120, 3 * hidden, dtype=torch.bfloat16, device=device)
894894
w = torch.ones(hidden, dtype=torch.bfloat16, device=device)
895895
with pytest.raises((ValueError, RuntimeError)):
896-
fused_qk_norm_rope(
896+
fused_qk_rmsnorm_rope(
897897
qkv,
898898
w,
899899
w,
@@ -915,7 +915,7 @@ def test_error_channel_sum_mismatch():
915915
qkv = torch.randn(1, 120, 3 * 3072, dtype=torch.bfloat16, device=device)
916916
w = torch.ones(3072, dtype=torch.bfloat16, device=device)
917917
with pytest.raises((ValueError, RuntimeError)):
918-
fused_qk_norm_rope(
918+
fused_qk_rmsnorm_rope(
919919
qkv,
920920
w,
921921
w,
@@ -937,7 +937,7 @@ def test_error_seq_len_mismatch():
937937
qkv = torch.randn(1, 100, 3 * 3072, dtype=torch.bfloat16, device=device)
938938
w = torch.ones(3072, dtype=torch.bfloat16, device=device)
939939
with pytest.raises((ValueError, RuntimeError)):
940-
fused_qk_norm_rope(
940+
fused_qk_rmsnorm_rope(
941941
qkv,
942942
w,
943943
w,

0 commit comments

Comments (0)