Changes from 17 commits

Commits (35)
e4df35f
Add fused QKNorm+RoPE kernel header for video-gen DIT self-attention
kahyunnam Apr 22, 2026
e0cb63c
Add TVM-FFI launcher and binding for fused QKNorm+RoPE
kahyunnam Apr 22, 2026
cc62bc3
Merge fused QKNorm+RoPE into the norm JIT module
kahyunnam Apr 22, 2026
bf2eb0b
Add Python API for fused QKNorm+RoPE in norm package
kahyunnam Apr 22, 2026
cc1d649
Export fused_qk_norm_rope from norm/ and create video_gen_ops/ facade
kahyunnam Apr 22, 2026
193cc53
Add tests for fused QKNorm+RoPE and fix packed_as namespace collision
kahyunnam Apr 22, 2026
03a4eb3
Expand tests: 8 interleaved shapes, 3 NeoX xfail, multi-GPU validated
kahyunnam Apr 22, 2026
1baac3d
Fix NeoX RoPE pos_id bug and add passing NeoX tests
kahyunnam Apr 22, 2026
2cd2088
Add multi-config tests for WAN 1.3B/5B/14B model sizes
kahyunnam Apr 22, 2026
73ec639
Add benchmark for fused QKNorm+RoPE kernel
kahyunnam Apr 22, 2026
feb1966
Support 2D [num_tokens, hidden] input in fused_qk_norm_rope
kahyunnam Apr 22, 2026
b4f7d38
Clean up internal references and fix arch descriptions
kahyunnam Apr 22, 2026
2cda43d
Use single top-level import from video_gen_ops in tests
kahyunnam Apr 22, 2026
aeffd60
Fix benchmark docstring to reference video_gen_ops API path
kahyunnam Apr 22, 2026
2f1841a
Fix lint: unused imports, mypy types, clang-format, ruff format
kahyunnam Apr 22, 2026
af4247f
Rename video_gen_ops to diffusion_ops
kahyunnam Apr 22, 2026
7580c79
Tighten FP8 test tolerance from 2x to 1.5x headroom
kahyunnam Apr 22, 2026
17103dd
Rename fused_qk_norm_rope to fused_qk_rmsnorm_rope
kahyunnam Apr 22, 2026
ddfdbe9
Use int64_t for num_tokens to prevent offset overflow
kahyunnam Apr 22, 2026
693a56e
Make frequency table cache per-device for multi-GPU safety
kahyunnam Apr 22, 2026
a18bc00
Fix clang-format indentation in csrc files
kahyunnam Apr 22, 2026
2de4c2b
Merge branch 'main' into knam/fused_norm_rope_for_video_gen
kahyunnam Apr 22, 2026
5a6c91b
Fix clang-format in fused_qk_rmsnorm_rope.cuh
kahyunnam Apr 22, 2026
792eb4e
Use lazy import for get_norm_module to avoid import-time dependency
kahyunnam Apr 22, 2026
e40ecf7
Remove stray non-ASCII character from docstring
kahyunnam Apr 22, 2026
20f4ea5
Stop freeing cached freq table pointer for cudagraph safety
kahyunnam Apr 22, 2026
2cb55e6
Check CUDA return codes and use cudaMemcpyAsync in freq cache
kahyunnam Apr 22, 2026
2c92645
Guard grid dim truncation and check kernel launch error
kahyunnam Apr 22, 2026
31c8a97
Validate weights and pre-allocated output buffers
kahyunnam Apr 23, 2026
9ac8944
Add error tests for bad weights and output buffers
kahyunnam Apr 23, 2026
e630f22
Always define get_norm_module regardless of CuTe DSL availability
kahyunnam Apr 23, 2026
7de01bb
Apply ruff format to Python files
kahyunnam Apr 23, 2026
758f8f0
Add fused_qk_rmsnorm_rope to unified benchmark CLI
kahyunnam Apr 23, 2026
9b5fe87
Move fused_qk_rmsnorm_rope into norm/__init__.py
kahyunnam Apr 23, 2026
b447fc6
Move fused_qk_rmsnorm_rope.cuh into include/flashinfer/norm/
kahyunnam Apr 23, 2026
211 changes: 211 additions & 0 deletions benchmarks/bench_fused_qk_norm_rope.py
@@ -0,0 +1,211 @@
"""
Benchmark for fused QKNorm + 3D RoPE kernel vs eager PyTorch baseline.

Measures performance across WAN model shapes and compares:
- Eager: separate nn.RMSNorm + manual interleaved RoPE in PyTorch
- Fused: flashinfer.diffusion_ops.fused_qk_norm_rope (single kernel)

Usage:
python benchmarks/bench_fused_qk_norm_rope.py
python benchmarks/bench_fused_qk_norm_rope.py --gpu 2 # run on specific GPU
"""

import argparse

import numpy as np
import torch
import torch.nn as nn

from flashinfer.diffusion_ops import fused_qk_norm_rope
from flashinfer.testing.utils import bench_gpu_time


def compute_rope_dims(head_dim):
    h_dim = w_dim = 2 * (head_dim // 6)
    t_dim = head_dim - h_dim - w_dim
    return t_dim, h_dim, w_dim
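
# A quick sanity check of the split (values follow directly from the formula
# above): for WAN-style head_dim=128, h_dim = w_dim = 2 * (128 // 6) = 42 and
# t_dim = 128 - 42 - 42 = 44, so the temporal axis absorbs the remainder and
# t_dim + h_dim + w_dim == head_dim always holds.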


def apply_rotary_emb_interleaved(hidden_states, freqs_cos, freqs_sin):
    x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
    cos = freqs_cos[..., 0::2]
    sin = freqs_sin[..., 1::2]
    out = torch.empty_like(hidden_states)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out.type_as(hidden_states)


def get_1d_rotary_pos_embed(dim, length, theta, device):
    inv_freq = 1.0 / (
        theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float64) / dim)
    )
    pos = torch.arange(length, device=device, dtype=torch.float64)
    freqs = torch.einsum("i,j->ij", pos, inv_freq)
    cos_out = torch.zeros(length, dim, device=device, dtype=torch.float64)
    sin_out = torch.zeros(length, dim, device=device, dtype=torch.float64)
    cos_out[:, 0::2] = torch.cos(freqs)
    cos_out[:, 1::2] = torch.cos(freqs)
    sin_out[:, 0::2] = torch.sin(freqs)
    sin_out[:, 1::2] = torch.sin(freqs)
    return cos_out, sin_out
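
# Each frequency's cos/sin is written into both the even and the odd slot of
# its channel pair, so freqs_cos[..., 0::2] == freqs_cos[..., 1::2] (and
# likewise for sin); apply_rotary_emb_interleaved above relies on this when it
# slices only one half of each table.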


def create_3d_rotary_embeddings(
    batch_size, ppf, pph, ppw, head_dim, device, base=10000.0, dtype=torch.bfloat16
):
    h_dim = w_dim = 2 * (head_dim // 6)
    t_dim = head_dim - h_dim - w_dim
    max_len = max(ppf, pph, ppw)
    t_cos, t_sin = get_1d_rotary_pos_embed(t_dim, max_len, base, device)
    h_cos, h_sin = get_1d_rotary_pos_embed(h_dim, max_len, base, device)
    w_cos, w_sin = get_1d_rotary_pos_embed(w_dim, max_len, base, device)
    t_cos_3d = (
        t_cos[:ppf].view(1, ppf, 1, 1, t_dim).expand(batch_size, ppf, pph, ppw, t_dim)
    )
    t_sin_3d = (
        t_sin[:ppf].view(1, ppf, 1, 1, t_dim).expand(batch_size, ppf, pph, ppw, t_dim)
    )
    h_cos_3d = (
        h_cos[:pph].view(1, 1, pph, 1, h_dim).expand(batch_size, ppf, pph, ppw, h_dim)
    )
    h_sin_3d = (
        h_sin[:pph].view(1, 1, pph, 1, h_dim).expand(batch_size, ppf, pph, ppw, h_dim)
    )
    w_cos_3d = (
        w_cos[:ppw].view(1, 1, 1, ppw, w_dim).expand(batch_size, ppf, pph, ppw, w_dim)
    )
    w_sin_3d = (
        w_sin[:ppw].view(1, 1, 1, ppw, w_dim).expand(batch_size, ppf, pph, ppw, w_dim)
    )
    freqs_cos = torch.cat([t_cos_3d, h_cos_3d, w_cos_3d], dim=-1)
    freqs_sin = torch.cat([t_sin_3d, h_sin_3d, w_sin_3d], dim=-1)
    seq_len = ppf * pph * ppw
    return (
        freqs_cos.reshape(batch_size, seq_len, 1, head_dim).to(dtype),
        freqs_sin.reshape(batch_size, seq_len, 1, head_dim).to(dtype),
    )


BENCH_SHAPES = [
    # (batch, ppf, pph, ppw, description)
    (1, 5, 12, 32, "480p production (1920 tokens)"),
    (1, 5, 12, 8, "480p small (480 tokens)"),
    (1, 5, 48, 32, "720p large (7680 tokens)"),
    (2, 5, 12, 32, "batch=2 (3840 tokens)"),
    (1, 5, 6, 4, "tiny (120 tokens)"),
    (4, 5, 12, 32, "batch=4 (7680 tokens)"),
    (1, 5, 12, 16, "half seq (960 tokens)"),
    (1, 10, 12, 32, "double frames (3840 tokens)"),
]
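
# Token counts in the descriptions are ppf * pph * ppw per batch element, e.g.
# 5 * 12 * 32 = 1920 for the "480p production" shape.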


def bench_one_shape(batch_size, ppf, pph, ppw, num_heads, head_dim, eps, base, device):
    seq_len = ppf * pph * ppw
    hidden_dim = num_heads * head_dim
    t_dim, h_dim, w_dim = compute_rope_dims(head_dim)
    dtype = torch.bfloat16

    torch.manual_seed(42)
    query = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)
    key = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)
    value = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)

    norm_q = nn.RMSNorm(hidden_dim, eps=eps).to(device).to(dtype)
    norm_k = nn.RMSNorm(hidden_dim, eps=eps).to(device).to(dtype)

    freqs_cos, freqs_sin = create_3d_rotary_embeddings(
        batch_size, ppf, pph, ppw, head_dim, device, base, dtype
    )

    qkv_combined = torch.cat([query, key, value], dim=-1).contiguous()
    q_weight = norm_q.weight.contiguous()
    k_weight = norm_k.weight.contiguous()

    def eager_fn():
        q_normed = norm_q(query)
        k_normed = norm_k(key)
        q_heads = q_normed.unflatten(2, (num_heads, -1))
        k_heads = k_normed.unflatten(2, (num_heads, -1))
        v_heads = value.unflatten(2, (num_heads, -1))
        q_out = apply_rotary_emb_interleaved(q_heads, freqs_cos, freqs_sin)
        k_out = apply_rotary_emb_interleaved(k_heads, freqs_cos, freqs_sin)
        return q_out, k_out, v_heads

    def fused_fn():
        return fused_qk_norm_rope(
            qkv_combined,
            q_weight,
            k_weight,
            ppf=ppf,
            pph=pph,
            ppw=ppw,
            num_frame_channels=t_dim,
            num_height_channels=h_dim,
            num_width_channels=w_dim,
            num_heads_q=num_heads,
            num_heads_k=num_heads,
            num_heads_v=num_heads,
            head_dim=head_dim,
            eps=eps,
            base=base,
            interleave=True,
            is_qk_norm=True,
        )

    eager_times = bench_gpu_time(
        eager_fn, enable_cupti=True, dry_run_iters=10, repeat_iters=100
    )
    fused_times = bench_gpu_time(
        fused_fn, enable_cupti=True, dry_run_iters=10, repeat_iters=100
    )

    eager_ms = float(np.median(eager_times))
    fused_ms = float(np.median(fused_times))
    return eager_ms, fused_ms


def main():
    parser = argparse.ArgumentParser(description="Benchmark fused QKNorm + 3D RoPE")
    parser.add_argument("--gpu", type=int, default=0, help="GPU device index")
    args = parser.parse_args()

    device = torch.device(f"cuda:{args.gpu}")
    torch.cuda.set_device(device)
    gpu_name = torch.cuda.get_device_name(device)

    num_heads = 24
    head_dim = 128
    eps = 1e-6
    base = 10000.0

    print(f"GPU: {gpu_name}")
    print(f"Config: WAN 2.2 5B (num_heads={num_heads}, head_dim={head_dim})")
    print()
    print(f"{'Shape':<50} {'Eager (ms)':>12} {'Fused (ms)':>12} {'Speedup':>10}")
    print("-" * 90)

    for batch_size, ppf, pph, ppw, desc in BENCH_SHAPES:
        seq_len = ppf * pph * ppw
        shape_str = f"B={batch_size} {ppf}x{pph}x{ppw}={seq_len:>5} ({desc})"

        eager_ms, fused_ms = bench_one_shape(
            batch_size,
            ppf,
            pph,
            ppw,
            num_heads,
            head_dim,
            eps,
            base,
            device,
        )

        speedup = eager_ms / fused_ms if fused_ms > 0 else 0
        print(f"{shape_str:<50} {eager_ms:>12.4f} {fused_ms:>12.4f} {speedup:>9.2f}x")

    print("-" * 90)


if __name__ == "__main__":
    main()
11 changes: 11 additions & 0 deletions csrc/flashinfer_norm_binding.cu
@@ -34,10 +34,21 @@ void gemma_fused_add_rmsnorm(TensorView input, TensorView residual, TensorView w

void layernorm(Tensor out, Tensor input, Tensor gamma, Tensor beta, double eps);

void fused_qk_norm_rope_run(TensorView qkv_in, TensorView q_weight, TensorView k_weight,
                            TensorView q_out, TensorView k_out, TensorView v_out,
                            int64_t num_tokens, int64_t seq_len, int64_t ppf, int64_t pph,
                            int64_t ppw, int64_t num_frame_channels, int64_t num_height_channels,
                            int64_t num_width_channels, int64_t num_heads_q, int64_t num_heads_k,
                            int64_t num_heads_v, int64_t head_dim, double eps, double base,
                            bool interleave, double factor, double low, double high,
                            double attention_factor, bool is_qk_norm, bool output_fp8,
                            double output_quant_scale, double v_quant_scale);

TVM_FFI_DLL_EXPORT_TYPED_FUNC(rmsnorm, rmsnorm);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(rmsnorm_quant, rmsnorm_quant);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(fused_add_rmsnorm, fused_add_rmsnorm);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(fused_add_rmsnorm_quant, fused_add_rmsnorm_quant);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(gemma_rmsnorm, gemma_rmsnorm);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(gemma_fused_add_rmsnorm, gemma_fused_add_rmsnorm);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(layernorm, layernorm);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(fused_qk_norm_rope, fused_qk_norm_rope_run);
42 changes: 42 additions & 0 deletions csrc/norm.cu
@@ -13,6 +13,7 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <flashinfer/fused_qk_norm_rope.cuh>
#include <flashinfer/norm.cuh>

#include "tvm_ffi_utils.h"
@@ -271,3 +272,44 @@ void layernorm(Tensor output, Tensor input, Tensor gamma, Tensor beta, double ep
    return true;
  });
}

void fused_qk_norm_rope_run(TensorView qkv_in, TensorView q_weight, TensorView k_weight,
                            TensorView q_out, TensorView k_out, TensorView v_out,
                            int64_t num_tokens, int64_t seq_len, int64_t ppf, int64_t pph,
                            int64_t ppw, int64_t num_frame_channels, int64_t num_height_channels,
                            int64_t num_width_channels, int64_t num_heads_q, int64_t num_heads_k,
                            int64_t num_heads_v, int64_t head_dim, double eps, double base,
                            bool interleave, double factor, double low, double high,
                            double attention_factor, bool is_qk_norm, bool output_fp8,
                            double output_quant_scale, double v_quant_scale) {
  CHECK_INPUT(qkv_in);
  CHECK_INPUT(q_weight);
  CHECK_INPUT(k_weight);
  CHECK_CUDA(q_out);
  CHECK_CONTIGUOUS(q_out);
  CHECK_CUDA(k_out);
  CHECK_CONTIGUOUS(k_out);
  CHECK_CUDA(v_out);
  CHECK_CONTIGUOUS(v_out);

  CHECK_INPUT_TYPE(qkv_in, dl_bfloat16);
  CHECK_INPUT_TYPE(q_weight, dl_bfloat16);
  CHECK_INPUT_TYPE(k_weight, dl_bfloat16);

  ffi::CUDADeviceGuard device_guard(qkv_in.device().device_id);
  const cudaStream_t stream = get_stream(qkv_in.device());

  int num_sms;
  cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, qkv_in.device().device_id);

  launchFusedQKNormRope(
      qkv_in.data_ptr(), q_out.data_ptr(), k_out.data_ptr(), v_out.data_ptr(),
      static_cast<int>(num_tokens), static_cast<int>(seq_len), static_cast<int>(ppf),
      static_cast<int>(pph), static_cast<int>(ppw), static_cast<int>(num_frame_channels),
      static_cast<int>(num_height_channels), static_cast<int>(num_width_channels),
      static_cast<int>(num_heads_q), static_cast<int>(num_heads_k), static_cast<int>(num_heads_v),
      static_cast<int>(head_dim), static_cast<float>(eps), q_weight.data_ptr(), k_weight.data_ptr(),
      static_cast<float>(base), interleave, static_cast<float>(factor), static_cast<float>(low),
      static_cast<float>(high), static_cast<float>(attention_factor), stream, is_qk_norm, num_sms,
      output_fp8, static_cast<float>(output_quant_scale), static_cast<float>(v_quant_scale));
}
5 changes: 5 additions & 0 deletions flashinfer/diffusion_ops/__init__.py
@@ -0,0 +1,5 @@
from flashinfer.norm import fused_qk_norm_rope

__all__ = [
    "fused_qk_norm_rope",
]
3 changes: 3 additions & 0 deletions flashinfer/norm/__init__.py
@@ -761,6 +761,8 @@ def fused_rmsnorm_silu(
    return out


from .fused_qk_norm_rope import fused_qk_norm_rope as fused_qk_norm_rope

# Public API exports
__all__ = [
    # JIT module generator (always available)
@@ -774,4 +776,5 @@ def fused_rmsnorm_silu(
"gemma_fused_add_rmsnorm",
"layernorm",
"fused_rmsnorm_silu",
"fused_qk_norm_rope",
]