diff --git a/benchmarks/bench_fused_qk_rmsnorm_rope.py b/benchmarks/bench_fused_qk_rmsnorm_rope.py new file mode 100644 index 0000000000..02aeef489a --- /dev/null +++ b/benchmarks/bench_fused_qk_rmsnorm_rope.py @@ -0,0 +1,211 @@ +""" +Benchmark for fused QK RMSNorm + 3D RoPE kernel vs eager PyTorch baseline. + +Measures performance across WAN model shapes and compares: +- Eager: separate nn.RMSNorm + manual interleaved RoPE in PyTorch +- Fused: flashinfer.diffusion_ops.fused_qk_rmsnorm_rope (single kernel) + +Usage: + python benchmarks/bench_fused_qk_rmsnorm_rope.py + python benchmarks/bench_fused_qk_rmsnorm_rope.py --gpu 2 # run on specific GPU +""" + +import argparse + +import numpy as np +import torch +import torch.nn as nn + +from flashinfer.testing.utils import bench_gpu_time +from flashinfer.diffusion_ops import fused_qk_rmsnorm_rope + + +def compute_rope_dims(head_dim): + h_dim = w_dim = 2 * (head_dim // 6) + t_dim = head_dim - h_dim - w_dim + return t_dim, h_dim, w_dim + + +def apply_rotary_emb_interleaved(hidden_states, freqs_cos, freqs_sin): + x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + cos = freqs_cos[..., 0::2] + sin = freqs_sin[..., 1::2] + out = torch.empty_like(hidden_states) + out[..., 0::2] = x1 * cos - x2 * sin + out[..., 1::2] = x1 * sin + x2 * cos + return out.type_as(hidden_states) + + +def get_1d_rotary_pos_embed(dim, length, theta, device): + inv_freq = 1.0 / ( + theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float64) / dim) + ) + pos = torch.arange(length, device=device, dtype=torch.float64) + freqs = torch.einsum("i,j->ij", pos, inv_freq) + cos_out = torch.zeros(length, dim, device=device, dtype=torch.float64) + sin_out = torch.zeros(length, dim, device=device, dtype=torch.float64) + cos_out[:, 0::2] = torch.cos(freqs) + cos_out[:, 1::2] = torch.cos(freqs) + sin_out[:, 0::2] = torch.sin(freqs) + sin_out[:, 1::2] = torch.sin(freqs) + return cos_out, sin_out + + +def create_3d_rotary_embeddings( + batch_size, ppf, pph, ppw, head_dim, device, base=10000.0, dtype=torch.bfloat16 +): + h_dim = w_dim = 2 * (head_dim // 6) + t_dim = head_dim - h_dim - w_dim + max_len = max(ppf, pph, ppw) + t_cos, t_sin = get_1d_rotary_pos_embed(t_dim, max_len, base, device) + h_cos, h_sin = get_1d_rotary_pos_embed(h_dim, max_len, base, device) + w_cos, w_sin = get_1d_rotary_pos_embed(w_dim, max_len, base, device) + t_cos_3d = ( + t_cos[:ppf].view(1, ppf, 1, 1, t_dim).expand(batch_size, ppf, pph, ppw, t_dim) + ) + t_sin_3d = ( + t_sin[:ppf].view(1, ppf, 1, 1, t_dim).expand(batch_size, ppf, pph, ppw, t_dim) + ) + h_cos_3d = ( + h_cos[:pph].view(1, 1, pph, 1, h_dim).expand(batch_size, ppf, pph, ppw, h_dim) + ) + h_sin_3d = ( + h_sin[:pph].view(1, 1, pph, 1, h_dim).expand(batch_size, ppf, pph, ppw, h_dim) + ) + w_cos_3d = ( + w_cos[:ppw].view(1, 1, 1, ppw, w_dim).expand(batch_size, ppf, pph, ppw, w_dim) + ) + w_sin_3d = ( + w_sin[:ppw].view(1, 1, 1, ppw, w_dim).expand(batch_size, ppf, pph, ppw, w_dim) + ) + freqs_cos = torch.cat([t_cos_3d, h_cos_3d, w_cos_3d], dim=-1) + freqs_sin = torch.cat([t_sin_3d, h_sin_3d, w_sin_3d], dim=-1) + seq_len = ppf * pph * ppw + return ( + freqs_cos.reshape(batch_size, seq_len, 1, head_dim).to(dtype), + freqs_sin.reshape(batch_size, seq_len, 1, head_dim).to(dtype), + ) + + +BENCH_SHAPES = [ + # (batch, ppf, pph, ppw, description) + (1, 5, 12, 32, "480p production (1920 tokens)"), + (1, 5, 12, 8, "480p small (480 tokens)"), + (1, 5, 48, 32, "720p large (7680 tokens)"), + (2, 5, 12, 32, "batch=2 (3840 tokens)"), + (1, 5, 6, 4, "tiny 
(120 tokens)"), + (4, 5, 12, 32, "batch=4 (7680 tokens)"), + (1, 5, 12, 16, "half seq (960 tokens)"), + (1, 10, 12, 32, "double frames (3840 tokens)"), +] + + +def bench_one_shape(batch_size, ppf, pph, ppw, num_heads, head_dim, eps, base, device): + seq_len = ppf * pph * ppw + hidden_dim = num_heads * head_dim + t_dim, h_dim, w_dim = compute_rope_dims(head_dim) + dtype = torch.bfloat16 + + torch.manual_seed(42) + query = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + key = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + value = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + + norm_q = nn.RMSNorm(hidden_dim, eps=eps).to(device).to(dtype) + norm_k = nn.RMSNorm(hidden_dim, eps=eps).to(device).to(dtype) + + freqs_cos, freqs_sin = create_3d_rotary_embeddings( + batch_size, ppf, pph, ppw, head_dim, device, base, dtype + ) + + qkv_combined = torch.cat([query, key, value], dim=-1).contiguous() + q_weight = norm_q.weight.contiguous() + k_weight = norm_k.weight.contiguous() + + def eager_fn(): + q_normed = norm_q(query) + k_normed = norm_k(key) + q_heads = q_normed.unflatten(2, (num_heads, -1)) + k_heads = k_normed.unflatten(2, (num_heads, -1)) + v_heads = value.unflatten(2, (num_heads, -1)) + q_out = apply_rotary_emb_interleaved(q_heads, freqs_cos, freqs_sin) + k_out = apply_rotary_emb_interleaved(k_heads, freqs_cos, freqs_sin) + return q_out, k_out, v_heads + + def fused_fn(): + return fused_qk_rmsnorm_rope( + qkv_combined, + q_weight, + k_weight, + ppf=ppf, + pph=pph, + ppw=ppw, + num_frame_channels=t_dim, + num_height_channels=h_dim, + num_width_channels=w_dim, + num_heads_q=num_heads, + num_heads_k=num_heads, + num_heads_v=num_heads, + head_dim=head_dim, + eps=eps, + base=base, + interleave=True, + is_qk_norm=True, + ) + + eager_times = bench_gpu_time( + eager_fn, enable_cupti=True, dry_run_iters=10, repeat_iters=100 + ) + fused_times = bench_gpu_time( + fused_fn, enable_cupti=True, dry_run_iters=10, repeat_iters=100 + ) + + eager_ms = float(np.median(eager_times)) + fused_ms = float(np.median(fused_times)) + return eager_ms, fused_ms + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark fused QK RMSNorm + 3D RoPE") + parser.add_argument("--gpu", type=int, default=0, help="GPU device index") + args = parser.parse_args() + + device = torch.device(f"cuda:{args.gpu}") + torch.cuda.set_device(device) + gpu_name = torch.cuda.get_device_name(device) + + num_heads = 24 + head_dim = 128 + eps = 1e-6 + base = 10000.0 + + print(f"GPU: {gpu_name}") + print(f"Config: WAN 2.2 5B (num_heads={num_heads}, head_dim={head_dim})") + print() + print(f"{'Shape':<50} {'Eager (ms)':>12} {'Fused (ms)':>12} {'Speedup':>10}") + print("-" * 90) + + for batch_size, ppf, pph, ppw, desc in BENCH_SHAPES: + seq_len = ppf * pph * ppw + shape_str = f"B={batch_size} {ppf}x{pph}x{ppw}={seq_len:>5} ({desc})" + + eager_ms, fused_ms = bench_one_shape( + batch_size, + ppf, + pph, + ppw, + num_heads, + head_dim, + eps, + base, + device, + ) + + speedup = eager_ms / fused_ms if fused_ms > 0 else 0 + print(f"{shape_str:<50} {eager_ms:>12.4f} {fused_ms:>12.4f} {speedup:>9.2f}x") + + print("-" * 90) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index 2bf9916e61..19f56694bf 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -91,6 +91,9 @@ "scale", "eps", 
"use_global_scale", + "ppf", + "pph", + "ppw", ], "quantization": [ "alignment", @@ -213,6 +216,7 @@ "rmsnorm_fp4quant", "add_rmsnorm_fp4quant", "fused_rmsnorm_silu", + "fused_qk_rmsnorm_rope", ], "quantization": [ "mxfp8_quantize", @@ -563,6 +567,16 @@ def dtype_str_to_torch_dtype(dtype_str): "12.0": ["cute-dsl"], "12.1": ["cute-dsl"], }, + "fused_qk_rmsnorm_rope": { + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + "12.1": ["cuda"], + }, # QUANTIZATION "mxfp8_quantize": { "7.5": [], diff --git a/benchmarks/routines/norm.py b/benchmarks/routines/norm.py index ffbafdd46a..f881a01ea4 100644 --- a/benchmarks/routines/norm.py +++ b/benchmarks/routines/norm.py @@ -53,6 +53,8 @@ def run_norm_test(args): return testAddRmsnormFp4quant(args) elif args.routine == "fused_rmsnorm_silu": return testFusedRmsnormSilu(args) + elif args.routine == "fused_qk_rmsnorm_rope": + return testFusedQkRmsnormRope(args) else: raise ValueError(f"Unsupported routine: {args.routine}") @@ -152,6 +154,29 @@ def parse_norm_args(line, parser): "overrides --is_sf_swizzled_layout and returns both layouts. Default: False", ) + # fused_qk_rmsnorm_rope specific arguments + parser.add_argument( + "--ppf", + type=int, + required=False, + default=5, + help="Number of patches in frame dimension (for fused_qk_rmsnorm_rope).", + ) + parser.add_argument( + "--pph", + type=int, + required=False, + default=12, + help="Number of patches in height dimension (for fused_qk_rmsnorm_rope).", + ) + parser.add_argument( + "--ppw", + type=int, + required=False, + default=32, + help="Number of patches in width dimension (for fused_qk_rmsnorm_rope).", + ) + args = parser.parse_args(line) if args.verbose >= 1: print(f"[INFO] {args = }") @@ -1199,3 +1224,157 @@ def run_fn(input_tensor, weight, out): cur_res["case_tag"] = args.case_tag res.append(cur_res) return res + + +def testFusedQkRmsnormRope(args): + """ + Test fused QK RMSNorm + 3D RoPE + V copy API. + + Benchmarks the fused kernel for video generation DIT self-attention + (e.g. WAN 2.1/2.2). Compares against eager PyTorch (separate RMSNorm + RoPE). + + Args: + args: Parsed command line arguments + + Returns: + list: List of dicts containing performance results + """ + from flashinfer.diffusion_ops import fused_qk_rmsnorm_rope + + if args.verbose >= 1: + print("[INFO] Running testFusedQkRmsnormRope") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + + batch_size = args.batch_size + hidden_size = args.hidden_size + num_heads = args.num_heads + eps = args.eps + ppf = args.ppf + pph = args.pph + ppw = args.ppw + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + res = [] + + backends = filter_backends_by_compute_capability( + args.backends[:], args.routine, device + ) + if len(backends) == 0: + print("[ERROR] No backends to test. 
Exiting.") + return res + + if num_heads is None: + raise ValueError("--num_heads is required for fused_qk_rmsnorm_rope") + + head_dim = hidden_size // num_heads + seq_len = ppf * pph * ppw + + h_dim = w_dim = 2 * (head_dim // 6) + t_dim = head_dim - h_dim - w_dim + + input_dtype = torch.bfloat16 + + torch.manual_seed(42) + qkv = torch.randn( + batch_size, seq_len, 3 * hidden_size, dtype=input_dtype, device=device + ) + q_weight = torch.randn(hidden_size, dtype=input_dtype, device=device) + k_weight = torch.randn(hidden_size, dtype=input_dtype, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] qkv.shape = {qkv.shape}") + print(f"[VVERBOSE] seq_len = {seq_len} (ppf={ppf}, pph={pph}, ppw={ppw})") + print(f"[VVERBOSE] head_dim = {head_dim}, t/h/w = {t_dim}/{h_dim}/{w_dim}") + + kwargs = dict( + ppf=ppf, + pph=pph, + ppw=ppw, + num_frame_channels=t_dim, + num_height_channels=h_dim, + num_width_channels=w_dim, + num_heads_q=num_heads, + num_heads_k=num_heads, + num_heads_v=num_heads, + head_dim=head_dim, + eps=eps, + base=10000.0, + interleave=True, + is_qk_norm=True, + ) + + def run_fused(): + return fused_qk_rmsnorm_rope(qkv, q_weight, k_weight, **kwargs) + + # Reference check + if run_refcheck: + import torch.nn as nn + + query = qkv[..., :hidden_size] + key = qkv[..., hidden_size : 2 * hidden_size] + + norm_q = nn.RMSNorm(hidden_size, eps=eps).to(device).to(input_dtype) + norm_k = nn.RMSNorm(hidden_size, eps=eps).to(device).to(input_dtype) + with torch.no_grad(): + norm_q.weight.copy_(q_weight) + norm_k.weight.copy_(k_weight) + + q_ref = norm_q(query).unflatten(2, (num_heads, head_dim)) + k_ref = norm_k(key).unflatten(2, (num_heads, head_dim)) + + q_fused, k_fused, _ = run_fused() + + # Compare after norm, before RoPE (RoPE adds position-dependent rotation) + # For a rough check, compare magnitudes + q_diff = (q_fused.flatten(2).float() - q_ref.flatten(2).float()).abs().max() + k_diff = (k_fused.flatten(2).float() - k_ref.flatten(2).float()).abs().max() + if args.verbose >= 1: + print( + f"[INFO] Refcheck: Q max diff = {q_diff:.4f}, K max diff = {k_diff:.4f}" + ) + print("[INFO] (Note: diff includes RoPE rotation, so nonzero is expected)") + + backend_times = bench_gpu_time( + fn=run_fused, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + ) + + if len(backend_times) > 0: + median_time = np.median(backend_times) + std_time = np.std(backend_times) + + # Memory bandwidth: read QKV + Q/K weights, write Q + K + V + num_tokens = batch_size * seq_len + problem_bytes = ( + num_tokens * 3 * hidden_size * input_dtype.itemsize # QKV read + + 2 * hidden_size * input_dtype.itemsize # Q/K weight read + + num_tokens * 3 * hidden_size * input_dtype.itemsize # Q+K+V write + ) + problem_flops = num_tokens * hidden_size * 10 # rough estimate + tflops = problem_flops / (10**9 * median_time) + tb_per_sec = problem_bytes / (10**9 * median_time) + + print_perf_metrics("cuda", median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["num_heads"] = num_heads + cur_res["input_dtype"] = str(input_dtype) + cur_res["eps"] = eps + cur_res["backend"] = "cuda" + cur_res["ppf"] = ppf + cur_res["pph"] = pph + cur_res["ppw"] = ppw + cur_res["case_tag"] = args.case_tag + res.append(cur_res) 
+ return res diff --git a/csrc/flashinfer_norm_binding.cu b/csrc/flashinfer_norm_binding.cu index 816eb2754b..fb64f38046 100644 --- a/csrc/flashinfer_norm_binding.cu +++ b/csrc/flashinfer_norm_binding.cu @@ -34,6 +34,16 @@ void gemma_fused_add_rmsnorm(TensorView input, TensorView residual, TensorView w void layernorm(Tensor out, Tensor input, Tensor gamma, Tensor beta, double eps); +void fused_qk_rmsnorm_rope_run(TensorView qkv_in, TensorView q_weight, TensorView k_weight, + TensorView q_out, TensorView k_out, TensorView v_out, + int64_t num_tokens, int64_t seq_len, int64_t ppf, int64_t pph, + int64_t ppw, int64_t num_frame_channels, int64_t num_height_channels, + int64_t num_width_channels, int64_t num_heads_q, int64_t num_heads_k, + int64_t num_heads_v, int64_t head_dim, double eps, double base, + bool interleave, double factor, double low, double high, + double attention_factor, bool is_qk_norm, bool output_fp8, + double output_quant_scale, double v_quant_scale); + TVM_FFI_DLL_EXPORT_TYPED_FUNC(rmsnorm, rmsnorm); TVM_FFI_DLL_EXPORT_TYPED_FUNC(rmsnorm_quant, rmsnorm_quant); TVM_FFI_DLL_EXPORT_TYPED_FUNC(fused_add_rmsnorm, fused_add_rmsnorm); @@ -41,3 +51,4 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(fused_add_rmsnorm_quant, fused_add_rmsnorm_quant); TVM_FFI_DLL_EXPORT_TYPED_FUNC(gemma_rmsnorm, gemma_rmsnorm); TVM_FFI_DLL_EXPORT_TYPED_FUNC(gemma_fused_add_rmsnorm, gemma_fused_add_rmsnorm); TVM_FFI_DLL_EXPORT_TYPED_FUNC(layernorm, layernorm); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(fused_qk_rmsnorm_rope, fused_qk_rmsnorm_rope_run); diff --git a/csrc/norm.cu b/csrc/norm.cu index b3460a87d9..a82314e2f4 100644 --- a/csrc/norm.cu +++ b/csrc/norm.cu @@ -14,6 +14,7 @@ * limitations under the License. */ #include +#include #include "tvm_ffi_utils.h" @@ -271,3 +272,44 @@ void layernorm(Tensor output, Tensor input, Tensor gamma, Tensor beta, double ep return true; }); } + +void fused_qk_rmsnorm_rope_run(TensorView qkv_in, TensorView q_weight, TensorView k_weight, + TensorView q_out, TensorView k_out, TensorView v_out, + int64_t num_tokens, int64_t seq_len, int64_t ppf, int64_t pph, + int64_t ppw, int64_t num_frame_channels, int64_t num_height_channels, + int64_t num_width_channels, int64_t num_heads_q, int64_t num_heads_k, + int64_t num_heads_v, int64_t head_dim, double eps, double base, + bool interleave, double factor, double low, double high, + double attention_factor, bool is_qk_norm, bool output_fp8, + double output_quant_scale, double v_quant_scale) { + CHECK_INPUT(qkv_in); + CHECK_INPUT(q_weight); + CHECK_INPUT(k_weight); + CHECK_CUDA(q_out); + CHECK_CONTIGUOUS(q_out); + CHECK_CUDA(k_out); + CHECK_CONTIGUOUS(k_out); + CHECK_CUDA(v_out); + CHECK_CONTIGUOUS(v_out); + + CHECK_INPUT_TYPE(qkv_in, dl_bfloat16); + CHECK_INPUT_TYPE(q_weight, dl_bfloat16); + CHECK_INPUT_TYPE(k_weight, dl_bfloat16); + + ffi::CUDADeviceGuard device_guard(qkv_in.device().device_id); + const cudaStream_t stream = get_stream(qkv_in.device()); + + int num_sms; + cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, qkv_in.device().device_id); + + launchFusedQKNormRope( + qkv_in.data_ptr(), q_out.data_ptr(), k_out.data_ptr(), v_out.data_ptr(), num_tokens, + static_cast(seq_len), static_cast(ppf), static_cast(pph), + static_cast(ppw), static_cast(num_frame_channels), + static_cast(num_height_channels), static_cast(num_width_channels), + static_cast(num_heads_q), static_cast(num_heads_k), static_cast(num_heads_v), + static_cast(head_dim), static_cast(eps), q_weight.data_ptr(), k_weight.data_ptr(), + static_cast(base), interleave, 
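+      // YARN rope-scaling knobs follow (factor, low, high, attention_factor);
+      // the launcher requires attention_factor == 1.0 whenever factor == 1.0.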
static_cast(factor), static_cast(low), + static_cast(high), static_cast(attention_factor), stream, is_qk_norm, num_sms, + output_fp8, static_cast(output_quant_scale), static_cast(v_quant_scale)); +} diff --git a/flashinfer/diffusion_ops/__init__.py b/flashinfer/diffusion_ops/__init__.py new file mode 100644 index 0000000000..2a87e57d66 --- /dev/null +++ b/flashinfer/diffusion_ops/__init__.py @@ -0,0 +1,5 @@ +from flashinfer.norm import fused_qk_rmsnorm_rope + +__all__ = [ + "fused_qk_rmsnorm_rope", +] diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py index ba612b2853..affb4b75d0 100644 --- a/flashinfer/norm/__init__.py +++ b/flashinfer/norm/__init__.py @@ -27,7 +27,7 @@ import functools import os import warnings -from typing import Optional, Union +from typing import Optional, Tuple, Union import torch @@ -43,10 +43,12 @@ rmsnorm_trace, ) from ..utils import ( + backend_requirement, device_support_pdl, get_compute_capability, register_custom_op, register_fake_op, + supported_compute_capability, ) # Always import gen_norm_module for JIT warmup and CUDA fallback @@ -70,11 +72,16 @@ # nvidia-cutlass-dsl not installed or incompatible version _USE_CUDA_NORM = True -if _USE_CUDA_NORM: - @functools.cache - def get_norm_module(): - return gen_norm_module().build_and_load() +@functools.cache +def get_norm_module(): + """Get or compile the CUDA JIT norm module. + + Always available regardless of _USE_CUDA_NORM setting, since some + fused kernels (e.g. fused_qk_rmsnorm_rope) only have a CUDA JIT + implementation and no CuTe DSL alternative. + """ + return gen_norm_module().build_and_load() def _normalize_scale_tensor( @@ -771,6 +778,300 @@ def fused_rmsnorm_silu( return out +#################################################################################################### +# Fused QK RMSNorm + 3D RoPE for Video Generation DIT Self-Attention +#################################################################################################### + + +@supported_compute_capability([80, 86, 89, 90, 100, 103, 110, 120, 121]) +def _check_fused_qk_rmsnorm_rope( + qkv, + q_weight, + k_weight, + **kwargs, +): + """Validate inputs for fused QK RMSNorm + 3D RoPE. + + Architecture notes: + - SM80+ (Ampere): Full support for BF16 path; FP8 output uses software emulation + - SM89+ (Ada): Native FP8 E4M3 conversion instructions (faster FP8 output) + - SM90 (Hopper): Primary target architecture + - SM100/103 (Blackwell B200, B300): Native float2 packed math (FFMA2); primary target + All SM100+/SM89+ features have scalar fallbacks, so SM80 is the true minimum. 
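
    A minimal illustrative call to the public API that these checks guard
    (shapes follow the WAN 2.2 5B config used elsewhere in this PR; the
    44/42/42 channel split is just head_dim - 2*2*(head_dim//6) and
    2*(head_dim//6), not a fixed recipe)::

        qkv = torch.randn(1, 5 * 12 * 32, 3 * 24 * 128,
                          dtype=torch.bfloat16, device="cuda")
        q_w = torch.ones(24 * 128, dtype=torch.bfloat16, device="cuda")
        k_w = torch.ones(24 * 128, dtype=torch.bfloat16, device="cuda")
        q, k, v = fused_qk_rmsnorm_rope(
            qkv, q_w, k_w,
            ppf=5, pph=12, ppw=32,
            num_frame_channels=44, num_height_channels=42,
            num_width_channels=42,
            num_heads_q=24, num_heads_k=24, num_heads_v=24,
            head_dim=128,
        )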
+ """ + if not qkv.is_cuda: + raise ValueError("qkv must be a CUDA tensor") + if qkv.dtype != torch.bfloat16: + raise ValueError("qkv must be bfloat16") + if not qkv.is_contiguous(): + raise ValueError("qkv must be contiguous") + if qkv.ndim not in (2, 3): + raise ValueError( + f"qkv must be 2D [num_tokens, hidden] or 3D [batch, seq_len, hidden], " + f"got {qkv.ndim}D" + ) + + head_dim = kwargs.get("head_dim") + if head_dim not in (64, 128, 256): + raise ValueError(f"head_dim must be 64, 128, or 256, got {head_dim}") + + num_heads_q = kwargs.get("num_heads_q") + num_heads_k = kwargs.get("num_heads_k") + num_heads_v = kwargs.get("num_heads_v") + max_heads = max(num_heads_q, num_heads_k, num_heads_v) + if max_heads > 32: + raise ValueError( + f"max(num_heads_q, num_heads_k, num_heads_v) must be <= 32, got {max_heads}" + ) + + num_frame_channels = kwargs.get("num_frame_channels") + num_height_channels = kwargs.get("num_height_channels") + num_width_channels = kwargs.get("num_width_channels") + if num_frame_channels + num_height_channels + num_width_channels != head_dim: + raise ValueError( + f"num_frame_channels ({num_frame_channels}) + num_height_channels " + f"({num_height_channels}) + num_width_channels ({num_width_channels}) " + f"must equal head_dim ({head_dim})" + ) + if ( + num_frame_channels % 2 != 0 + or num_height_channels % 2 != 0 + or num_width_channels % 2 != 0 + ): + raise ValueError( + f"Channel counts must all be even (freq table uses count/2), got " + f"frame={num_frame_channels}, height={num_height_channels}, " + f"width={num_width_channels}" + ) + + ppf = kwargs.get("ppf") + pph = kwargs.get("pph") + ppw = kwargs.get("ppw") + if ppf <= 0 or pph <= 0 or ppw <= 0: + raise ValueError(f"ppf, pph, ppw must be positive, got ({ppf}, {pph}, {ppw})") + expected_seq_len = ppf * pph * ppw + if qkv.ndim == 3: + actual_seq_len = qkv.shape[1] + if actual_seq_len != expected_seq_len: + raise ValueError( + f"qkv seq_len ({actual_seq_len}) != ppf*pph*ppw ({expected_seq_len})" + ) + else: + num_tokens = qkv.shape[0] + if num_tokens % expected_seq_len != 0: + raise ValueError( + f"qkv num_tokens ({num_tokens}) must be divisible by " + f"ppf*pph*ppw ({expected_seq_len})" + ) + + return True + + +@flashinfer_api +@backend_requirement(backend_checks={}, common_check=_check_fused_qk_rmsnorm_rope) +def fused_qk_rmsnorm_rope( + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + *, + ppf: int, + pph: int, + ppw: int, + num_frame_channels: int, + num_height_channels: int, + num_width_channels: int, + num_heads_q: int, + num_heads_k: int, + num_heads_v: int, + head_dim: int, + eps: float = 1e-6, + base: float = 10000.0, + interleave: bool = True, + factor: float = 1.0, + low: float = 0.0, + high: float = 0.0, + attention_factor: float = 1.0, + is_qk_norm: bool = True, + output_fp8: bool = False, + output_quant_scale: float = 1.0, + v_quant_scale: float = 1.0, + q_out: Optional[torch.Tensor] = None, + k_out: Optional[torch.Tensor] = None, + v_out: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + r"""Fused QK RMSNorm + 3D RoPE + V copy for video generation DIT self-attention. + + Applies across-heads RMSNorm to Q and K, then rotary position embeddings + with 3D spatial decomposition (frame/height/width), and copies V to a + contiguous output buffer. Optionally quantizes all outputs to FP8 E4M3. + + Parameters + ---------- + qkv : torch.Tensor + Combined QKV input, BF16, contiguous. 
Accepted shapes: + - 3D: ``[batch, seq_len, (num_heads_q+num_heads_k+num_heads_v)*head_dim]`` + - 2D: ``[num_tokens, (num_heads_q+num_heads_k+num_heads_v)*head_dim]`` + where ``num_tokens`` must be divisible by ``ppf*pph*ppw``. + q_weight : torch.Tensor + RMSNorm weight for Q ``[num_heads_q * head_dim]``, BF16. + k_weight : torch.Tensor + RMSNorm weight for K ``[num_heads_k * head_dim]``, BF16. + ppf : int + Number of patches in frame dimension. + pph : int + Number of patches in height dimension. + ppw : int + Number of patches in width dimension. + ``seq_len = ppf * pph * ppw``. + num_frame_channels : int + RoPE frequency channels for the frame dimension (must be even). + num_height_channels : int + RoPE frequency channels for the height dimension (must be even). + num_width_channels : int + RoPE frequency channels for the width dimension (must be even). + ``num_frame_channels + num_height_channels + num_width_channels == head_dim``. + num_heads_q : int + Number of query heads. + num_heads_k : int + Number of key heads. + num_heads_v : int + Number of value heads. + head_dim : int + Dimension per head (must be 64, 128, or 256). + eps : float + RMSNorm epsilon. + base : float + RoPE base frequency. + interleave : bool + True for interleaved RoPE (non-NeoX style), False for NeoX-style. + factor : float + YARN RoPE scaling factor. 1.0 disables YARN. + low : float + YARN low frequency threshold. + high : float + YARN high frequency threshold. + attention_factor : float + YARN attention factor applied to cos/sin. Must be 1.0 when factor is 1.0. + is_qk_norm : bool + Whether to apply RMSNorm (False = RoPE only, skip normalization). + output_fp8 : bool + Quantize Q, K, V outputs to FP8 E4M3. + output_quant_scale : float + FP8 quantization scale for Q and K outputs. + v_quant_scale : float + FP8 quantization scale for V output. + q_out : Optional[torch.Tensor] + Pre-allocated Q output tensor (destination-passing style). + k_out : Optional[torch.Tensor] + Pre-allocated K output tensor. + v_out : Optional[torch.Tensor] + Pre-allocated V output tensor. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ``(q_out, k_out, v_out)``. If input is 3D, each has shape + ``[batch, seq_len, num_heads_x, head_dim]``. If input is 2D, + each has shape ``[num_tokens, num_heads_x, head_dim]``. + """ + out_dtype = torch.float8_e4m3fn if output_fp8 else torch.bfloat16 + seq_len = ppf * pph * ppw + + out_shape_q: tuple[int, ...] + out_shape_k: tuple[int, ...] + out_shape_v: tuple[int, ...] 
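+    # Both layouts share one kernel path: the input is flattened to
+    # [num_tokens, hidden] before launch, and the outputs are allocated
+    # up front with the matching batched (4D) or flat (3D) head layout.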
+ if qkv.ndim == 3: + batch_size = qkv.shape[0] + num_tokens = batch_size * seq_len + out_shape_q = (batch_size, seq_len, num_heads_q, head_dim) + out_shape_k = (batch_size, seq_len, num_heads_k, head_dim) + out_shape_v = (batch_size, seq_len, num_heads_v, head_dim) + else: + num_tokens = qkv.shape[0] + out_shape_q = (num_tokens, num_heads_q, head_dim) + out_shape_k = (num_tokens, num_heads_k, head_dim) + out_shape_v = (num_tokens, num_heads_v, head_dim) + + # Validate weights + expected_q_weight_numel = num_heads_q * head_dim + expected_k_weight_numel = num_heads_k * head_dim + if q_weight.numel() != expected_q_weight_numel: + raise ValueError( + f"q_weight size {q_weight.numel()} != num_heads_q*head_dim ({expected_q_weight_numel})" + ) + if k_weight.numel() != expected_k_weight_numel: + raise ValueError( + f"k_weight size {k_weight.numel()} != num_heads_k*head_dim ({expected_k_weight_numel})" + ) + if q_weight.dtype != torch.bfloat16 or k_weight.dtype != torch.bfloat16: + raise ValueError("q_weight and k_weight must be bfloat16") + if not q_weight.is_contiguous() or not k_weight.is_contiguous(): + raise ValueError("q_weight and k_weight must be contiguous") + + if q_out is None: + q_out = torch.empty(*out_shape_q, dtype=out_dtype, device=qkv.device) + if k_out is None: + k_out = torch.empty(*out_shape_k, dtype=out_dtype, device=qkv.device) + if v_out is None: + v_out = torch.empty(*out_shape_v, dtype=out_dtype, device=qkv.device) + + # Validate user-supplied output buffers + for name, buf, expected_shape in [ + ("q_out", q_out, out_shape_q), + ("k_out", k_out, out_shape_k), + ("v_out", v_out, out_shape_v), + ]: + if tuple(buf.shape) != expected_shape: + raise ValueError( + f"{name} shape {tuple(buf.shape)} != expected {expected_shape}" + ) + if buf.dtype != out_dtype: + raise ValueError(f"{name} dtype {buf.dtype} != expected {out_dtype}") + if not buf.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + if buf.device != qkv.device: + raise ValueError(f"{name} device {buf.device} != qkv device {qkv.device}") + + qkv_flat = qkv.view(num_tokens, -1) + q_out_flat = q_out.view(num_tokens, -1) + k_out_flat = k_out.view(num_tokens, -1) + v_out_flat = v_out.view(num_tokens, -1) + + get_norm_module().fused_qk_rmsnorm_rope( + qkv_flat, + q_weight, + k_weight, + q_out_flat, + k_out_flat, + v_out_flat, + num_tokens, + seq_len, + ppf, + pph, + ppw, + num_frame_channels, + num_height_channels, + num_width_channels, + num_heads_q, + num_heads_k, + num_heads_v, + head_dim, + float(eps), + float(base), + interleave, + float(factor), + float(low), + float(high), + float(attention_factor), + is_qk_norm, + output_fp8, + float(output_quant_scale), + float(v_quant_scale), + ) + + return q_out, k_out, v_out + + # Public API exports __all__ = [ # JIT module generator (always available) @@ -784,4 +1085,5 @@ def fused_rmsnorm_silu( "gemma_fused_add_rmsnorm", "layernorm", "fused_rmsnorm_silu", + "fused_qk_rmsnorm_rope", ] diff --git a/include/flashinfer/norm/fused_qk_rmsnorm_rope.cuh b/include/flashinfer/norm/fused_qk_rmsnorm_rope.cuh new file mode 100644 index 0000000000..f86ad687bb --- /dev/null +++ b/include/flashinfer/norm/fused_qk_rmsnorm_rope.cuh @@ -0,0 +1,770 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_FUSED_QK_RMSNORM_ROPE_CUH_ +#define FLASHINFER_FUSED_QK_RMSNORM_ROPE_CUH_ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace flashinfer { + +#define FLASHINFER_FUSED_CHECK(condition) \ + do { \ + if (!(condition)) { \ + fprintf(stderr, "FLASHINFER_FUSED_CHECK failed at %s:%d: %s\n", __FILE__, __LINE__, \ + #condition); \ + abort(); \ + } \ + } while (0) + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Section 1: IntFastDiv — fast signed integer division on GPU +// Based on Hacker's Delight, Second Edition, Chapter 10. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +class IntFastDiv { + public: + __host__ IntFastDiv() : mDivisor(1), mMagicM(0), mMagicS(-1), mAddSign(1) {} + + __host__ IntFastDiv(int divisor) : mDivisor(divisor) { + if (mDivisor == 0) throw std::runtime_error("IntFastDiv: cannot divide by 0"); + updateMagicNumbers(); + } + + __host__ IntFastDiv& operator=(int divisor) { + this->mDivisor = divisor; + if (this->mDivisor == 0) throw std::runtime_error("IntFastDiv: cannot divide by 0"); + updateMagicNumbers(); + return *this; + } + + __host__ __device__ operator int() const { return mDivisor; } + + private: + int mDivisor; + int mMagicM; + int mMagicS; + int mAddSign; + + __host__ void updateMagicNumbers() { + if (mDivisor == 1) { + mMagicM = 0; + mMagicS = -1; + mAddSign = 1; + return; + } else if (mDivisor == -1) { + mMagicM = 0; + mMagicS = -1; + mAddSign = -1; + return; + } + + int p; + unsigned int tmpAd, tmpAnc, delta, q1, r1, q2, r2, t; + unsigned const two31 = 0x80000000; + tmpAd = abs(mDivisor); + t = two31 + ((unsigned int)mDivisor >> 31); + tmpAnc = t - 1 - t % tmpAd; + p = 31; + q1 = two31 / tmpAnc; + r1 = two31 - q1 * tmpAnc; + q2 = two31 / tmpAd; + r2 = two31 - q2 * tmpAd; + do { + ++p; + q1 = 2 * q1; + r1 = 2 * r1; + if (r1 >= tmpAnc) { + ++q1; + r1 -= tmpAnc; + } + q2 = 2 * q2; + r2 = 2 * r2; + if (r2 >= tmpAd) { + ++q2; + r2 -= tmpAd; + } + delta = tmpAd - r2; + } while (q1 < delta || (q1 == delta && r1 == 0)); + this->mMagicM = q2 + 1; + if (mDivisor < 0) this->mMagicM = -this->mMagicM; + this->mMagicS = p - 32; + + if ((mDivisor > 0) && (mMagicM < 0)) + mAddSign = 1; + else if ((mDivisor < 0) && (mMagicM > 0)) + mAddSign = -1; + else + mAddSign = 0; + } + + __host__ __device__ friend int operator/(int const dividend, IntFastDiv const& divisor); +}; + +__host__ __device__ inline int operator/(int const dividend, IntFastDiv const& divisor) { + int q; +#ifdef __CUDA_ARCH__ + asm("mul.hi.s32 %0, %1, %2;" : "=r"(q) : "r"(divisor.mMagicM), "r"(dividend)); +#else + q = (((unsigned long long)((long long)divisor.mMagicM * (long long)dividend)) >> 32); +#endif + q += dividend * divisor.mAddSign; + if (divisor.mMagicS >= 0) { + q >>= divisor.mMagicS; + q += (((unsigned int)q) >> 31); + } + return q; +} + +__host__ __device__ inline int operator%(int const dividend, IntFastDiv const& divisor) { + int quotient = dividend / divisor; + int remainder = 
dividend - quotient * divisor; + return remainder; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Section 2: packed_as — maps a base type + vector width to the appropriate CUDA vector type. +// In a sub-namespace to avoid collision with identically-named templates in norm.cuh. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace fused_rope_detail { + +template +struct packed_as { + static_assert(N == 1, "packed_as only supports N=1, 2, 4"); + using type = T; +}; + +template <> +struct packed_as { + using type = uint; +}; + +template <> +struct packed_as { + using type = uint2; +}; + +template <> +struct packed_as { + using type = uint4; +}; + +} // namespace fused_rope_detail + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Section 3: FP8 E4M3 quantization helpers +// SM89+ (Ada/Hopper/Blackwell): native vectorized PTX conversion +// SM < 89: scalar __nv_fp8_e4m3 constructor fallback +//////////////////////////////////////////////////////////////////////////////////////////////////// + +__device__ __forceinline__ __nv_fp8_e4m3 quantize_fp8_e4m3(float val, float scale = 1.0f) { + return __nv_fp8_e4m3(val * scale); +} + +__device__ __forceinline__ uint16_t float2_to_fp8_e4m3_packed(float2 val, float scale = 1.0f) { + __nv_fp8_e4m3 fp8_0 = quantize_fp8_e4m3(val.x, scale); + __nv_fp8_e4m3 fp8_1 = quantize_fp8_e4m3(val.y, scale); + return (*reinterpret_cast(&fp8_0)) | ((*reinterpret_cast(&fp8_1)) << 8); +} + +__device__ __forceinline__ uint32_t float4_to_fp8_e4m3_packed(float f0, float f1, float f2, + float f3, float scale = 1.0f) { + float scaled0 = f0 * scale; + float scaled1 = f1 * scale; + float scaled2 = f2 * scale; + float scaled3 = f3 * scale; + + uint32_t result; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 890 + asm volatile( + "{\n\t" + ".reg .b16 r_lo, r_hi;\n\t" + "cvt.rn.satfinite.e4m3x2.f32 r_lo, %4, %3;\n\t" + "cvt.rn.satfinite.e4m3x2.f32 r_hi, %2, %1;\n\t" + "mov.b32 %0, {r_hi, r_lo};\n\t" + "}" + : "=r"(result) + : "f"(scaled0), "f"(scaled1), "f"(scaled2), "f"(scaled3)); +#else + __nv_fp8_e4m3 fp8_0 = __nv_fp8_e4m3(scaled0); + __nv_fp8_e4m3 fp8_1 = __nv_fp8_e4m3(scaled1); + __nv_fp8_e4m3 fp8_2 = __nv_fp8_e4m3(scaled2); + __nv_fp8_e4m3 fp8_3 = __nv_fp8_e4m3(scaled3); + result = (*reinterpret_cast(&fp8_0)) | ((*reinterpret_cast(&fp8_1)) << 8) | + ((*reinterpret_cast(&fp8_2)) << 16) | + ((*reinterpret_cast(&fp8_3)) << 24); +#endif + + return result; +} + +__device__ __forceinline__ uint2 float8_to_fp8_e4m3_packed(float2 val0, float2 val1, float2 val2, + float2 val3, float scale = 1.0f) { + uint32_t packed_lo = float4_to_fp8_e4m3_packed(val0.x, val0.y, val1.x, val1.y, scale); + uint32_t packed_hi = float4_to_fp8_e4m3_packed(val2.x, val2.y, val3.x, val3.y, scale); + return make_uint2(packed_lo, packed_hi); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Section 4: Blackwell FFMA2 intrinsics (SM100+) with scalar fallbacks +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 +__device__ __forceinline__ float2 fmul2(const float2& a, const float2& b) { + uint64_t c; + asm volatile("mul.f32x2 %0, %1, %2;\n" + : "=l"(c) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b))); + return reinterpret_cast(c); +} + +__device__ 
__forceinline__ float2 ffma2(const float2& a, const float2& b, const float2& c) { + uint64_t d; + asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n" + : "=l"(d) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b)), + "l"(reinterpret_cast(c))); + return reinterpret_cast(d); +} + +__device__ __forceinline__ float2 fadd2(const float2& a, const float2& b) { + uint64_t c; + asm volatile("add.f32x2 %0, %1, %2;\n" + : "=l"(c) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b))); + return reinterpret_cast(c); +} +#else +__device__ __forceinline__ float2 fmul2(const float2& a, const float2& b) { + return make_float2(a.x * b.x, a.y * b.y); +} + +__device__ __forceinline__ float2 ffma2(const float2& a, const float2& b, const float2& c) { + return make_float2(a.x * b.x + c.x, a.y * b.y + c.y); +} + +__device__ __forceinline__ float2 fadd2(const float2& a, const float2& b) { + return make_float2(a.x + b.x, a.y + b.y); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Section 5: Vectorized FP8 store helper +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void quantize_store_fp8(float2 const* elements, __nv_fp8_e4m3* out, + int64_t offset, float scale) { + constexpr int numFloat2PerThread = numElemsPerThread / 2; + if constexpr (numElemsPerThread == 2) { + uint16_t packed = float2_to_fp8_e4m3_packed(elements[0], scale); + *reinterpret_cast(&out[offset]) = packed; + } else if constexpr (numElemsPerThread == 4) { + uint32_t packed = float4_to_fp8_e4m3_packed(elements[0].x, elements[0].y, elements[1].x, + elements[1].y, scale); + *reinterpret_cast(&out[offset]) = packed; + } else if constexpr (numElemsPerThread == 8) { + uint2 packed = + float8_to_fp8_e4m3_packed(elements[0], elements[1], elements[2], elements[3], scale); + *reinterpret_cast(&out[offset]) = packed; + } else { +#pragma unroll + for (int ii = 0; ii < numFloat2PerThread; ii++) { + out[offset + ii * 2] = quantize_fp8_e4m3(elements[ii].x, scale); + out[offset + ii * 2 + 1] = quantize_fp8_e4m3(elements[ii].y, scale); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Section 6: Fused QK RMSNorm + RoPE kernel +// +// Performs across-heads RMSNorm and 3D RoPE in a single kernel (for self-attention). +// Also copies V to a separate contiguous output buffer with optional FP8 quantization. 
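+//
+// Example launch (illustrative, WAN 2.2 5B shapes): num_heads_q = num_heads_k =
+// num_heads_v = 24 with head_dim = 128 gives blockDim = 24 warps * 32 = 768
+// threads and gridDim = (num_tokens, 3); each warp owns one head and each
+// lane handles head_dim / 32 = 4 elements.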
+// +// Architecture: +// - 2D grid: blockIdx.x = tokenIdx, blockIdx.y = block type (0=Q, 1=K, 2=V) +// - Q/K blocks: warps cooperate for RMSNorm reduction, then each warp applies norm + RoPE +// - V blocks: just copy + optional FP8 quantize (no RMSNorm, no RoPE) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +constexpr int THREADS_PER_WARP = 32; + +template +__global__ void fusedQKNormRopeKernel( + __nv_bfloat16 const* qkv_in, void* q_out, void* k_out, void* v_out, int const num_heads_q, + int const num_heads_k, int const num_heads_v, float const eps, __nv_bfloat16 const* q_weight, + __nv_bfloat16 const* k_weight, float const* freq_table, int64_t const num_tokens, + IntFastDiv const seq_len, IntFastDiv const ppw, IntFastDiv const pphppw, + int const num_frame_channels, int const num_height_channels, int const num_width_channels, + float attention_factor, bool is_qk_norm, float output_quant_scale, float v_quant_scale) { + static_assert((head_dim & (head_dim - 1)) == 0, + "head_dim must be a power of 2 (required for bitwise modulo in NeoX RoPE path)"); + static_assert( + head_dim % (THREADS_PER_WARP * 2) == 0, + "head_dim must be divisible by 64 (each warp processes one head with even element count)"); + constexpr int log_head_dim = __builtin_ctz(head_dim); + constexpr int numElemsPerThread = head_dim / THREADS_PER_WARP; + static_assert(numElemsPerThread % 2 == 0, "numElemsPerThread must be divisible by 2"); + constexpr int numFloat2PerThread = numElemsPerThread / 2; + constexpr int elemSizeBytes = numElemsPerThread * sizeof(__nv_bfloat16); + static_assert(elemSizeBytes % 4 == 0, "elemSizeBytes must be a multiple of 4"); + constexpr int vecSize = elemSizeBytes / 4; + using vec_T = typename fused_rope_detail::packed_as::type; + + int const warpId = threadIdx.x / THREADS_PER_WARP; + int const laneId = threadIdx.x % THREADS_PER_WARP; + + int64_t const tokenIdx = blockIdx.x; + int const blockType = blockIdx.y; + + if (tokenIdx >= num_tokens) return; + + int const num_heads = num_heads_q + num_heads_k + num_heads_v; + int64_t const baseOffset = tokenIdx * (int64_t)(num_heads * head_dim); + int const headIdx = warpId; + + int const threadHeadOffset = headIdx * head_dim + laneId * numElemsPerThread; + + // ========== V blocks: simple copy + optional FP8 quantize ========== + if (blockType == 2) { + if (headIdx >= num_heads_v) return; + + int64_t const v_input_offset = + baseOffset + (int64_t)((num_heads_q + num_heads_k) * head_dim) + threadHeadOffset; + vec_T vec = *reinterpret_cast(&qkv_in[v_input_offset]); + + int64_t const v_output_offset = tokenIdx * (int64_t)(num_heads_v * head_dim) + threadHeadOffset; + + if constexpr (OUTPUT_FP8) { + float2 elements[numFloat2PerThread]; +#pragma unroll + for (int i = 0; i < vecSize; i++) { + elements[i] = __bfloat1622float2( + *reinterpret_cast<__nv_bfloat162*>(reinterpret_cast(&vec) + i)); + } + quantize_store_fp8(elements, reinterpret_cast<__nv_fp8_e4m3*>(v_out), + v_output_offset, v_quant_scale); + } else { + __nv_bfloat16* bf16_out = reinterpret_cast<__nv_bfloat16*>(v_out); + *reinterpret_cast(&bf16_out[v_output_offset]) = vec; + } + return; + } + + // ========== Q/K blocks: RMSNorm + RoPE + optional FP8 quantize ========== + __shared__ float sharedSumOfSquares[MAX_HEADS]; + + bool const isQ = (blockType == 0); + + int const num_heads_this = isQ ? 
num_heads_q : num_heads_k; + + bool const validHead = (headIdx < num_heads_this); + int const hidden_dim_this = num_heads_this * head_dim; + + float2 elements[numFloat2PerThread]; + float2 r_weight[numFloat2PerThread]; + + int const qkSegmentStart = isQ ? 0 : num_heads_q * head_dim; + int64_t const inputOffset = baseOffset + qkSegmentStart + threadHeadOffset; + + float2 sumOfSquares = make_float2(0.0f, 0.0f); + + if (validHead) { + __nv_bfloat16 const* weight_ptr = isQ ? q_weight : k_weight; + + vec_T weight_vec = *reinterpret_cast(&weight_ptr[threadHeadOffset]); + vec_T vec = *reinterpret_cast(&qkv_in[inputOffset]); + +#pragma unroll + for (int i = 0; i < vecSize; i++) { + r_weight[i] = __bfloat1622float2( + *reinterpret_cast<__nv_bfloat162*>(reinterpret_cast(&weight_vec) + i)); + + float2 vals = + __bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(reinterpret_cast(&vec) + i)); + + sumOfSquares = ffma2(vals, vals, sumOfSquares); + + elements[i] = vals; + } + } + + if (is_qk_norm) { + float sumOfSquaresScalar = sumOfSquares.x + sumOfSquares.y; + +#pragma unroll + for (int step = THREADS_PER_WARP / 2; step > 0; step /= 2) { + sumOfSquaresScalar += __shfl_xor_sync(0xffffffff, sumOfSquaresScalar, step); + } + + if (laneId == 0) { + sharedSumOfSquares[warpId] = validHead ? sumOfSquaresScalar : 0.0f; + } + __syncthreads(); + + int const warpsPerBlock = blockDim.x / THREADS_PER_WARP; + float totalSumOfSquares = (laneId < warpsPerBlock) ? sharedSumOfSquares[laneId] : 0.0f; + +#pragma unroll + for (int step = THREADS_PER_WARP / 2; step > 0; step /= 2) { + totalSumOfSquares += __shfl_xor_sync(0xffffffff, totalSumOfSquares, step); + } + + float rms_rcp = rsqrtf(totalSumOfSquares / static_cast(hidden_dim_this) + eps); + + if (validHead) { + float2 rms_rcp_vec = make_float2(rms_rcp, rms_rcp); +#pragma unroll + for (int i = 0; i < numFloat2PerThread; i++) { + elements[i] = fmul2(fmul2(elements[i], rms_rcp_vec), r_weight[i]); + + // Round to BF16 and back to match the precision of the unfused reference. + // Without this, the fused kernel carries extra float32 mantissa bits + // through RoPE, producing results that differ from the reference. + __nv_bfloat162 tmp = __float22bfloat162_rn(elements[i]); + elements[i] = __bfloat1622float2(tmp); + } + } + } + + // Apply RoPE to normalized elements + if (validHead) { + float2 elements2[numFloat2PerThread]; + float2 cos_vals[numFloat2PerThread]; + float2 sin_vals[numFloat2PerThread]; + + int const token_idx_in_seq = tokenIdx % seq_len; + int const pos_id_t = token_idx_in_seq / pphppw; + int const pos_id_x = token_idx_in_seq % pphppw; + int const pos_id_h = pos_id_x / ppw; + int const pos_id_w = pos_id_x % ppw; + + int32_t height_slice_start = num_frame_channels; + int32_t width_slice_start = num_frame_channels + num_height_channels; + + if constexpr (interleave) { +#pragma unroll + for (int ii = 0; ii < numFloat2PerThread; ii++) { + int dim_idx_x = laneId * numElemsPerThread + ii * 2; + int pos_id = dim_idx_x >= width_slice_start ? pos_id_w + : dim_idx_x >= height_slice_start ? 
pos_id_h + : pos_id_t; + + float freq_xy = freq_table[dim_idx_x >> 1]; + float theta_xy = pos_id * freq_xy; + float sin_xy, cos_xy; + __sincosf(theta_xy, &sin_xy, &cos_xy); + + sin_vals[ii] = make_float2(sin_xy, sin_xy); + cos_vals[ii] = make_float2(cos_xy, cos_xy); + + elements2[ii] = make_float2(-elements[ii].y, elements[ii].x); + } + } else { + __syncwarp(); +#pragma unroll + for (int ii = 0; ii < numFloat2PerThread; ii++) { + float elem_x = __shfl_xor_sync(0xffffffff, elements[ii].x, 16); + float elem_y = __shfl_xor_sync(0xffffffff, elements[ii].y, 16); + if (laneId < 16) { + elem_x = -elem_x; + elem_y = -elem_y; + } + elements2[ii] = make_float2(elem_x, elem_y); + } + +#pragma unroll + for (int ii = 0; ii < numFloat2PerThread; ii++) { + int dim_idx_x = laneId * numElemsPerThread + ii * 2; + dim_idx_x = (dim_idx_x * 2) & ((1 << log_head_dim) - 1); + int pos_id = dim_idx_x >= width_slice_start ? pos_id_w + : dim_idx_x >= height_slice_start ? pos_id_h + : pos_id_t; + + float freq_x = freq_table[dim_idx_x >> 1]; + float theta_x = pos_id * freq_x; + float sin_x, cos_x; + __sincosf(theta_x, &sin_x, &cos_x); + + int dim_idx_y = laneId * numElemsPerThread + ii * 2 + 1; + dim_idx_y = (dim_idx_y * 2) & ((1 << log_head_dim) - 1); + pos_id = dim_idx_y >= width_slice_start ? pos_id_w + : dim_idx_y >= height_slice_start ? pos_id_h + : pos_id_t; + + float freq_y = freq_table[dim_idx_y >> 1]; + float theta_y = pos_id * freq_y; + float sin_y, cos_y; + __sincosf(theta_y, &sin_y, &cos_y); + + sin_vals[ii] = make_float2(sin_x, sin_y); + cos_vals[ii] = make_float2(cos_x, cos_y); + } + __syncwarp(); + } + + if constexpr (HAS_YARN) { + float2 attention_factor_vec = make_float2(attention_factor, attention_factor); +#pragma unroll + for (int ii = 0; ii < numFloat2PerThread; ii++) { + elements[ii] = fmul2(ffma2(elements[ii], cos_vals[ii], fmul2(elements2[ii], sin_vals[ii])), + attention_factor_vec); + } + } else { +#pragma unroll + for (int ii = 0; ii < numFloat2PerThread; ii++) { + elements[ii] = ffma2(elements[ii], cos_vals[ii], fmul2(elements2[ii], sin_vals[ii])); + } + } + + int64_t const outputBase = tokenIdx * (int64_t)(num_heads_this * head_dim); + int64_t const outputOffset = outputBase + threadHeadOffset; + + if constexpr (OUTPUT_FP8) { + __nv_fp8_e4m3* fp8_out = + isQ ? reinterpret_cast<__nv_fp8_e4m3*>(q_out) : reinterpret_cast<__nv_fp8_e4m3*>(k_out); + quantize_store_fp8(elements, fp8_out, outputOffset, output_quant_scale); + } else { + __nv_bfloat16* bf16_out = + isQ ? reinterpret_cast<__nv_bfloat16*>(q_out) : reinterpret_cast<__nv_bfloat16*>(k_out); + vec_T vec; + for (int ii = 0; ii < vecSize; ii++) { + __nv_bfloat162 vals = __float22bfloat162_rn(elements[ii]); + reinterpret_cast<__nv_bfloat162&>(*(reinterpret_cast(&vec) + ii)) = vals; + } + vec_T* outputPtr = reinterpret_cast(&bf16_out[outputOffset]); + *outputPtr = vec; + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Section 7: Dispatch macros +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \ + if (interleave) { \ + const bool INTERLEAVE = true; \ + __VA_ARGS__ \ + } else { \ + const bool INTERLEAVE = false; \ + __VA_ARGS__ \ + } + +#define DISPATCH_OUTPUT_FP8(output_fp8, OUTPUT_FP8, ...) 
\ + if (output_fp8) { \ + const bool OUTPUT_FP8 = true; \ + __VA_ARGS__ \ + } else { \ + const bool OUTPUT_FP8 = false; \ + __VA_ARGS__ \ + } + +#define DISPATCH_HAS_YARN(has_yarn, HAS_YARN, ...) \ + if (has_yarn) { \ + const bool HAS_YARN = true; \ + __VA_ARGS__ \ + } else { \ + const bool HAS_YARN = false; \ + __VA_ARGS__ \ + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Section 8: Host-side frequency computation +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline float compute_adjusted_freq_host(int half_dim_val, float base, int dim_size, + float factor, float low, float high) { + float freq = powf(base, -2.0f * half_dim_val / static_cast(dim_size)); + if (factor != 1.0f) { + float inv_freq_extrapolation = freq; + float inv_freq_interpolation = freq / factor; + + float high_adj = high; + if (fabsf(low - high_adj) <= 1e-6f) { + high_adj += 0.001f; + } + float linear_func = (static_cast(half_dim_val) - low) / (high_adj - low); + float ramp_func = fminf(fmaxf(linear_func, 0.0f), 1.0f); + float inv_freq_extrapolation_factor = 1.0f - ramp_func; + freq = inv_freq_interpolation * (1.0f - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor; + } + return freq; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Section 9: Frequency table cache + host launcher +// +// The cache is allocated once and never freed to ensure cudagraph capture compatibility. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct FreqCacheEntry { + float* d_ptr = nullptr; + int alloc_floats = 0; + int head_dim = 0; + float base = 0.0f; + float factor = 0.0f; + float low = 0.0f; + float high = 0.0f; + int num_frame_channels = 0; + int num_height_channels = 0; + int num_width_channels = 0; +}; + +// Per-device frequency table cache. Keyed by CUDA device ID so that +// multi-GPU usage within a single process is safe. +static std::unordered_map s_freq_cache_map; + +inline void launchFusedQKNormRope(void const* qkv_in, void* q_out, void* k_out, void* v_out, + int64_t const num_tokens, int const seq_len, int const ppf, + int const pph, int const ppw, int const num_frame_channels, + int const num_height_channels, int const num_width_channels, + int const num_heads_q, int const num_heads_k, + int const num_heads_v, int const head_dim, float const eps, + void const* q_weight, void const* k_weight, float const base, + bool const interleave, float factor, float low, float high, + float attention_factor, cudaStream_t stream, bool is_qk_norm, + int const num_sms, bool output_fp8, float output_quant_scale, + float v_quant_scale) { + FLASHINFER_FUSED_CHECK((head_dim & (head_dim - 1)) == 0); + + if (factor == 1.0f) { + FLASHINFER_FUSED_CHECK(attention_factor == 1.0f); + } + + int device_id; + FLASHINFER_FUSED_CHECK(cudaGetDevice(&device_id) == cudaSuccess); + FreqCacheEntry& cache = s_freq_cache_map[device_id]; + + int const table_size = head_dim / 2; + + if (cache.alloc_floats < table_size) { + // Allocate without freeing: the old pointer (if any) is intentionally + // leaked to keep it valid for any cudagraph that captured it. 
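+    // The leak is bounded: table_size = head_dim / 2 <= 128 floats (512 bytes)
+    // per allocation, and a reallocation only happens when a larger table is
+    // requested on this device.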
+ float* new_ptr = nullptr; + FLASHINFER_FUSED_CHECK(cudaMalloc(&new_ptr, table_size * sizeof(float)) == cudaSuccess); + cache.d_ptr = new_ptr; + cache.alloc_floats = table_size; + } + + bool cache_miss = + (cache.head_dim != head_dim || cache.base != base || cache.factor != factor || + cache.low != low || cache.high != high || cache.num_frame_channels != num_frame_channels || + cache.num_height_channels != num_height_channels || + cache.num_width_channels != num_width_channels); + + if (cache_miss) { + FLASHINFER_FUSED_CHECK(table_size <= 128); + float h_freq_table[128]; + int offset = 0; + + for (int i = 0; i < num_frame_channels / 2; i++) + h_freq_table[offset++] = + compute_adjusted_freq_host(i, base, num_frame_channels, factor, low, high); + + for (int i = 0; i < num_height_channels / 2; i++) + h_freq_table[offset++] = + compute_adjusted_freq_host(i, base, num_height_channels, factor, low, high); + + for (int i = 0; i < num_width_channels / 2; i++) + h_freq_table[offset++] = + compute_adjusted_freq_host(i, base, num_width_channels, factor, low, high); + + FLASHINFER_FUSED_CHECK(offset == table_size); + + // cudaMemcpyAsync from unpinned host memory is synchronous by CUDA spec, + // so h_freq_table (stack-allocated) is safe. Using Async form so the call + // is associated with the caller's stream for cudagraph capture. + FLASHINFER_FUSED_CHECK(cudaMemcpyAsync(cache.d_ptr, h_freq_table, table_size * sizeof(float), + cudaMemcpyHostToDevice, stream) == cudaSuccess); + + cache.head_dim = head_dim; + cache.base = base; + cache.factor = factor; + cache.low = low; + cache.high = high; + cache.num_frame_channels = num_frame_channels; + cache.num_height_channels = num_height_channels; + cache.num_width_channels = num_width_channels; + } + + int const maxHeads = max(max(num_heads_q, num_heads_k), num_heads_v); + int const warpsPerBlock = maxHeads; + int const blockSize = warpsPerBlock * THREADS_PER_WARP; + + FLASHINFER_FUSED_CHECK(num_tokens > 0); + FLASHINFER_FUSED_CHECK(num_tokens <= static_cast(INT32_MAX)); + dim3 gridDim(static_cast(num_tokens), 3); + dim3 blockDim(blockSize); + + bool const has_yarn = (factor != 1.0f); + +#define LAUNCH_KERNEL(HD) \ + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { \ + DISPATCH_OUTPUT_FP8(output_fp8, OUTPUT_FP8, { \ + DISPATCH_HAS_YARN(has_yarn, HAS_YARN, { \ + fusedQKNormRopeKernel \ + <<>>( \ + reinterpret_cast<__nv_bfloat16 const*>(qkv_in), q_out, k_out, v_out, num_heads_q, \ + num_heads_k, num_heads_v, eps, reinterpret_cast<__nv_bfloat16 const*>(q_weight), \ + reinterpret_cast<__nv_bfloat16 const*>(k_weight), cache.d_ptr, num_tokens, \ + IntFastDiv(seq_len), IntFastDiv(ppw), IntFastDiv(pph * ppw), num_frame_channels, \ + num_height_channels, num_width_channels, attention_factor, is_qk_norm, \ + output_quant_scale, v_quant_scale); \ + }); \ + }); \ + }) + + switch (head_dim) { + case 64: + LAUNCH_KERNEL(64); + break; + case 128: + LAUNCH_KERNEL(128); + break; + case 256: + LAUNCH_KERNEL(256); + break; + default: + FLASHINFER_FUSED_CHECK(false); // Unsupported head_dim + break; + } + +#undef LAUNCH_KERNEL + + FLASHINFER_FUSED_CHECK(cudaGetLastError() == cudaSuccess); +} + +} // namespace flashinfer + +#endif // FLASHINFER_FUSED_QK_RMSNORM_ROPE_CUH_ diff --git a/tests/norm/test_fused_qk_rmsnorm_rope.py b/tests/norm/test_fused_qk_rmsnorm_rope.py new file mode 100644 index 0000000000..09c208ee0a --- /dev/null +++ b/tests/norm/test_fused_qk_rmsnorm_rope.py @@ -0,0 +1,1028 @@ +""" +Tests for fused QK RMSNorm + 3D RoPE kernel. 
+ +Tests correctness against a PyTorch reference implementation that matches +the WAN 2.2 model.py: + - RMSNorm across all heads (not per-head) + - 3D RoPE with frame/height/width spatial decomposition + - V passthrough copy + - Optional FP8 E4M3 quantized output + +Both interleaved and non-interleaved (NeoX) RoPE modes are tested. +""" + +import pytest +import torch +import torch.nn as nn + +from flashinfer.diffusion_ops import fused_qk_rmsnorm_rope + + +# --------------------------------------------------------------------------- +# Reference helpers +# --------------------------------------------------------------------------- + + +def apply_rotary_emb_interleaved( + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, +) -> torch.Tensor: + """Interleaved RoPE (pairs adjacent elements: 0,1 2,3 ...). + + hidden_states: [batch, seq, num_heads, head_dim] + freqs_cos/sin: [batch, seq, 1, head_dim] + """ + x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + cos = freqs_cos[..., 0::2] + sin = freqs_sin[..., 1::2] + out = torch.empty_like(hidden_states) + out[..., 0::2] = x1 * cos - x2 * sin + out[..., 1::2] = x1 * sin + x2 * cos + return out.type_as(hidden_states) + + +def apply_rotary_emb_neox( + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, +) -> torch.Tensor: + """NeoX-style RoPE matching the kernel's per-element frequency mapping. + + The kernel's NeoX path uses (dim_idx * 2) & ((1 << log_head_dim) - 1) to + map each element to a frequency index, then swaps first/second halves via + warp shuffle. Each element gets its own cos/sin value. + + hidden_states: [batch, seq, num_heads, head_dim] + freqs_cos/sin: [batch, seq, 1, head_dim] — per-element cos/sin values + computed using the kernel's mapped frequency index convention. 
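
    Note: with the kernel's mask (dim_idx * 2) & (head_dim - 1), element d and
    element d + head_dim//2 map to the same frequency and the same spatial
    dimension, so cos1 == cos2 and sin1 == sin2 and this reduces to the
    standard NeoX rotation; the per-element form is kept to mirror the kernel
    exactly.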
+ """ + half = hidden_states.shape[-1] // 2 + x1 = hidden_states[..., :half] + x2 = hidden_states[..., half:] + cos1 = freqs_cos[..., :half] + sin1 = freqs_sin[..., :half] + cos2 = freqs_cos[..., half:] + sin2 = freqs_sin[..., half:] + out = torch.empty_like(hidden_states) + out[..., :half] = x1 * cos1 - x2 * sin1 + out[..., half:] = x2 * cos2 + x1 * sin2 + return out.type_as(hidden_states) + + +def get_1d_rotary_pos_embed(dim, length, theta, device): + inv_freq = 1.0 / ( + theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float64) / dim) + ) + pos = torch.arange(length, device=device, dtype=torch.float64) + freqs = torch.einsum("i,j->ij", pos, inv_freq) + + cos_out = torch.zeros(length, dim, device=device, dtype=torch.float64) + sin_out = torch.zeros(length, dim, device=device, dtype=torch.float64) + cos_out[:, 0::2] = torch.cos(freqs) + cos_out[:, 1::2] = torch.cos(freqs) + sin_out[:, 0::2] = torch.sin(freqs) + sin_out[:, 1::2] = torch.sin(freqs) + return cos_out, sin_out + + +def create_3d_rotary_embeddings( + batch_size, ppf, pph, ppw, head_dim, device, base=10000.0, dtype=torch.bfloat16 +): + h_dim = w_dim = 2 * (head_dim // 6) + t_dim = head_dim - h_dim - w_dim + + max_len = max(ppf, pph, ppw) + t_cos, t_sin = get_1d_rotary_pos_embed(t_dim, max_len, base, device) + h_cos, h_sin = get_1d_rotary_pos_embed(h_dim, max_len, base, device) + w_cos, w_sin = get_1d_rotary_pos_embed(w_dim, max_len, base, device) + + t_cos_3d = ( + t_cos[:ppf].view(1, ppf, 1, 1, t_dim).expand(batch_size, ppf, pph, ppw, t_dim) + ) + t_sin_3d = ( + t_sin[:ppf].view(1, ppf, 1, 1, t_dim).expand(batch_size, ppf, pph, ppw, t_dim) + ) + h_cos_3d = ( + h_cos[:pph].view(1, 1, pph, 1, h_dim).expand(batch_size, ppf, pph, ppw, h_dim) + ) + h_sin_3d = ( + h_sin[:pph].view(1, 1, pph, 1, h_dim).expand(batch_size, ppf, pph, ppw, h_dim) + ) + w_cos_3d = ( + w_cos[:ppw].view(1, 1, 1, ppw, w_dim).expand(batch_size, ppf, pph, ppw, w_dim) + ) + w_sin_3d = ( + w_sin[:ppw].view(1, 1, 1, ppw, w_dim).expand(batch_size, ppf, pph, ppw, w_dim) + ) + + freqs_cos = torch.cat([t_cos_3d, h_cos_3d, w_cos_3d], dim=-1) + freqs_sin = torch.cat([t_sin_3d, h_sin_3d, w_sin_3d], dim=-1) + + seq_len = ppf * pph * ppw + freqs_cos = freqs_cos.reshape(batch_size, seq_len, 1, head_dim).to(dtype) + freqs_sin = freqs_sin.reshape(batch_size, seq_len, 1, head_dim).to(dtype) + return freqs_cos, freqs_sin + + +def create_3d_rotary_embeddings_neox( + batch_size, ppf, pph, ppw, head_dim, device, base=10000.0, dtype=torch.bfloat16 +): + """Create NeoX-style 3D rotary embeddings matching the kernel's per-element mapping. + + The kernel's NeoX path applies (dim_idx * 2) & ((1 << log_head_dim) - 1) to + compute a mapped dimension index, then uses that to look up both the frequency + AND the spatial dimension (frame/height/width) for position ID selection. + + This means each element gets its own cos/sin, and adjacent elements within + a float2 pair can map to different spatial dimensions. + + Returns freqs_cos, freqs_sin with shape [batch, seq_len, 1, head_dim]. 
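+
+    As a concrete illustration (derived from the definitions above): for
+    head_dim = 128, h_dim = w_dim = 2 * (128 // 6) = 42 and t_dim = 44, so a
+    mapped index in [0, 44) selects the frame position, one in [44, 86) the
+    height position, and one in [86, 128) the width position.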
+ """ + h_dim = w_dim = 2 * (head_dim // 6) + t_dim = head_dim - h_dim - w_dim + log_head_dim = head_dim.bit_length() - 1 + numElemsPerThread = head_dim // 32 + + freq_table = [] + for i in range(t_dim // 2): + freq_table.append(base ** (-2.0 * i / t_dim)) + for i in range(h_dim // 2): + freq_table.append(base ** (-2.0 * i / h_dim)) + for i in range(w_dim // 2): + freq_table.append(base ** (-2.0 * i / w_dim)) + freq_table = torch.tensor(freq_table, dtype=torch.float64, device=device) + + height_slice_start = t_dim + width_slice_start = t_dim + h_dim + + # Build per-element freq and spatial-dim assignment following the kernel's mapping + freq_per_elem = torch.zeros(head_dim, dtype=torch.float64, device=device) + spatial_dim_per_elem = torch.zeros(head_dim, dtype=torch.long, device=device) + + for elem_idx in range(head_dim): + laneId = elem_idx // numElemsPerThread + within_lane = elem_idx % numElemsPerThread + ii = within_lane // 2 + comp = within_lane % 2 + raw = laneId * numElemsPerThread + ii * 2 + comp + mapped = (raw * 2) & ((1 << log_head_dim) - 1) + freq_idx = mapped >> 1 + freq_per_elem[elem_idx] = freq_table[freq_idx] + if mapped >= width_slice_start: + spatial_dim_per_elem[elem_idx] = 2 # width + elif mapped >= height_slice_start: + spatial_dim_per_elem[elem_idx] = 1 # height + else: + spatial_dim_per_elem[elem_idx] = 0 # frame + + seq_len = ppf * pph * ppw + cos_out = torch.zeros( + batch_size, seq_len, 1, head_dim, dtype=torch.float64, device=device + ) + sin_out = torch.zeros( + batch_size, seq_len, 1, head_dim, dtype=torch.float64, device=device + ) + + for b in range(batch_size): + for s in range(seq_len): + tok = s + pos_t = tok // (pph * ppw) + pos_x = tok % (pph * ppw) + pos_h = pos_x // ppw + pos_w = pos_x % ppw + pos_ids = torch.tensor( + [pos_t, pos_h, pos_w], dtype=torch.float64, device=device + ) + + for d in range(head_dim): + pos_id = pos_ids[spatial_dim_per_elem[d]] + theta = pos_id * freq_per_elem[d] + cos_out[b, s, 0, d] = torch.cos(theta) + sin_out[b, s, 0, d] = torch.sin(theta) + + return cos_out.to(dtype), sin_out.to(dtype) + + +def compute_rope_dims(head_dim): + h_dim = w_dim = 2 * (head_dim // 6) + t_dim = head_dim - h_dim - w_dim + return t_dim, h_dim, w_dim + + +def reference_qk_norm_rope( + query, key, value, norm_q, norm_k, num_heads, freqs_cos, freqs_sin, interleave=True +): + query = norm_q(query) + key = norm_k(key) + + query = query.unflatten(2, (num_heads, -1)) + key = key.unflatten(2, (num_heads, -1)) + value = value.unflatten(2, (num_heads, -1)) + + apply_fn = apply_rotary_emb_interleaved if interleave else apply_rotary_emb_neox + query = apply_fn(query, freqs_cos, freqs_sin) + key = apply_fn(key, freqs_cos, freqs_sin) + + return query, key, value + + +# --------------------------------------------------------------------------- +# Fixtures / constants +# --------------------------------------------------------------------------- + +# Configs from official WAN model releases (github.com/Wan-Video/Wan2.1, Wan2.2) +WAN_CONFIGS = { + "wan2.1-1.3B": { # wan/configs/wan_t2v_1_3B.py: dim=1536, num_heads=12 + "num_heads": 12, + "head_dim": 128, + "hidden_dim": 12 * 128, # 1536 + "eps": 1e-6, + "base": 10000.0, + }, + "wan2.2-5B": { # wan/configs/wan_ti2v_5B.py: dim=3072, num_heads=24 + "num_heads": 24, + "head_dim": 128, + "hidden_dim": 24 * 128, # 3072 + "eps": 1e-6, + "base": 10000.0, + }, + "wan2.1-14B": { # wan/configs/wan_t2v_14B.py: dim=5120, num_heads=40 + "num_heads": 40, + "head_dim": 128, + "hidden_dim": 40 * 128, # 5120 + "eps": 1e-6, + 
"base": 10000.0, + }, +} + +# Default config used for most tests (WAN 2.2 5B, the production target) +WAN_CONFIG = WAN_CONFIGS["wan2.2-5B"] + +INTERLEAVED_SHAPES = [ + (1, 5, 12, 32), # Production: 5x12x32=1920 + (1, 5, 12, 8), # Smaller: 5x12x8=480 + (1, 5, 48, 32), # Larger: 5x48x32=7680 + (2, 5, 12, 32), # batch=2 + (1, 5, 6, 4), # Tiny: 5x6x4=120 + (4, 5, 12, 32), # batch=4 + (1, 5, 12, 16), # Half seq: 5x12x16=960 + (1, 10, 12, 32), # Double frames: 10x12x32=3840 +] + +NEOX_SHAPES = [ + (1, 5, 12, 8), + (1, 5, 6, 4), + (2, 5, 12, 32), +] + + +# --------------------------------------------------------------------------- +# Correctness: interleaved RoPE (primary path, used in production) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("batch_size,ppf,pph,ppw", INTERLEAVED_SHAPES) +def test_interleaved_correctness(batch_size, ppf, pph, ppw): + device = torch.device("cuda") + dtype = torch.bfloat16 + num_heads = WAN_CONFIG["num_heads"] + head_dim = WAN_CONFIG["head_dim"] + hidden_dim = WAN_CONFIG["hidden_dim"] + eps = WAN_CONFIG["eps"] + base = WAN_CONFIG["base"] + t_dim, h_dim, w_dim = compute_rope_dims(head_dim) + seq_len = ppf * pph * ppw + + torch.manual_seed(42) + query = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + key = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + value = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + + norm_q = nn.RMSNorm(hidden_dim, eps=eps).to(device).to(dtype) + norm_k = nn.RMSNorm(hidden_dim, eps=eps).to(device).to(dtype) + with torch.no_grad(): + norm_q.weight.copy_(1.0 + 0.1 * torch.randn(hidden_dim, device=device)) + norm_k.weight.copy_(1.0 + 0.1 * torch.randn(hidden_dim, device=device)) + + freqs_cos, freqs_sin = create_3d_rotary_embeddings( + batch_size, ppf, pph, ppw, head_dim, device, base, dtype + ) + + q_ref, k_ref, v_ref = reference_qk_norm_rope( + query.clone(), + key.clone(), + value.clone(), + norm_q, + norm_k, + num_heads, + freqs_cos, + freqs_sin, + interleave=True, + ) + + qkv_combined = torch.cat([query, key, value], dim=-1).contiguous() + q_fused, k_fused, v_fused = fused_qk_rmsnorm_rope( + qkv_combined, + norm_q.weight.contiguous(), + norm_k.weight.contiguous(), + ppf=ppf, + pph=pph, + ppw=ppw, + num_frame_channels=t_dim, + num_height_channels=h_dim, + num_width_channels=w_dim, + num_heads_q=num_heads, + num_heads_k=num_heads, + num_heads_v=num_heads, + head_dim=head_dim, + eps=eps, + base=base, + interleave=True, + is_qk_norm=True, + ) + + q_ref_flat = q_ref.flatten(2) + k_ref_flat = k_ref.flatten(2) + q_fused_flat = q_fused.flatten(2) + k_fused_flat = k_fused.flatten(2) + + q_max_diff = (q_fused_flat.float() - q_ref_flat.float()).abs().max().item() + k_max_diff = (k_fused_flat.float() - k_ref_flat.float()).abs().max().item() + + assert q_max_diff < 0.1, f"Q max diff {q_max_diff} >= 0.1" + assert k_max_diff < 0.1, f"K max diff {k_max_diff} >= 0.1" + + +# --------------------------------------------------------------------------- +# Correctness: non-interleaved (NeoX) RoPE — first validation of this path +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("batch_size,ppf,pph,ppw", NEOX_SHAPES) +def test_neox_correctness(batch_size, ppf, pph, ppw): + """NeoX (non-interleaved) RoPE path validation.""" + + device = torch.device("cuda") + dtype = torch.bfloat16 + num_heads = WAN_CONFIG["num_heads"] + head_dim = WAN_CONFIG["head_dim"] + 
hidden_dim = WAN_CONFIG["hidden_dim"] + eps = WAN_CONFIG["eps"] + base = WAN_CONFIG["base"] + t_dim, h_dim, w_dim = compute_rope_dims(head_dim) + seq_len = ppf * pph * ppw + + torch.manual_seed(123) + query = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + key = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + value = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + + norm_q = nn.RMSNorm(hidden_dim, eps=eps).to(device).to(dtype) + norm_k = nn.RMSNorm(hidden_dim, eps=eps).to(device).to(dtype) + with torch.no_grad(): + norm_q.weight.copy_(1.0 + 0.1 * torch.randn(hidden_dim, device=device)) + norm_k.weight.copy_(1.0 + 0.1 * torch.randn(hidden_dim, device=device)) + + freqs_cos, freqs_sin = create_3d_rotary_embeddings_neox( + batch_size, ppf, pph, ppw, head_dim, device, base, dtype + ) + + q_ref, k_ref, v_ref = reference_qk_norm_rope( + query.clone(), + key.clone(), + value.clone(), + norm_q, + norm_k, + num_heads, + freqs_cos, + freqs_sin, + interleave=False, + ) + + qkv_combined = torch.cat([query, key, value], dim=-1).contiguous() + q_fused, k_fused, v_fused = fused_qk_rmsnorm_rope( + qkv_combined, + norm_q.weight.contiguous(), + norm_k.weight.contiguous(), + ppf=ppf, + pph=pph, + ppw=ppw, + num_frame_channels=t_dim, + num_height_channels=h_dim, + num_width_channels=w_dim, + num_heads_q=num_heads, + num_heads_k=num_heads, + num_heads_v=num_heads, + head_dim=head_dim, + eps=eps, + base=base, + interleave=False, + is_qk_norm=True, + ) + + q_ref_flat = q_ref.flatten(2) + k_ref_flat = k_ref.flatten(2) + q_fused_flat = q_fused.flatten(2) + k_fused_flat = k_fused.flatten(2) + + q_max_diff = (q_fused_flat.float() - q_ref_flat.float()).abs().max().item() + k_max_diff = (k_fused_flat.float() - k_ref_flat.float()).abs().max().item() + + assert q_max_diff < 0.1, f"NeoX Q max diff {q_max_diff} >= 0.1" + assert k_max_diff < 0.1, f"NeoX K max diff {k_max_diff} >= 0.1" + + +# --------------------------------------------------------------------------- +# Correctness: V passthrough (should be an exact BF16 copy) +# --------------------------------------------------------------------------- + + +def test_v_passthrough(): + device = torch.device("cuda") + dtype = torch.bfloat16 + num_heads = WAN_CONFIG["num_heads"] + head_dim = WAN_CONFIG["head_dim"] + hidden_dim = WAN_CONFIG["hidden_dim"] + t_dim, h_dim, w_dim = compute_rope_dims(head_dim) + + batch_size, ppf, pph, ppw = 1, 5, 6, 4 + seq_len = ppf * pph * ppw + + torch.manual_seed(42) + query = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + key = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + value = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + + qkv_combined = torch.cat([query, key, value], dim=-1).contiguous() + q_weight = torch.ones(hidden_dim, device=device, dtype=dtype) + k_weight = torch.ones(hidden_dim, device=device, dtype=dtype) + + _, _, v_fused = fused_qk_rmsnorm_rope( + qkv_combined, + q_weight, + k_weight, + ppf=ppf, + pph=pph, + ppw=ppw, + num_frame_channels=t_dim, + num_height_channels=h_dim, + num_width_channels=w_dim, + num_heads_q=num_heads, + num_heads_k=num_heads, + num_heads_v=num_heads, + head_dim=head_dim, + interleave=True, + is_qk_norm=True, + ) + + v_expected = value.unflatten(2, (num_heads, head_dim)) + assert torch.equal(v_fused, v_expected), "V output should be an exact copy" + + +# --------------------------------------------------------------------------- +# 
Correctness: destination-passing style (pre-allocated output) +# --------------------------------------------------------------------------- + + +def test_destination_passing(): + device = torch.device("cuda") + dtype = torch.bfloat16 + num_heads = WAN_CONFIG["num_heads"] + head_dim = WAN_CONFIG["head_dim"] + hidden_dim = WAN_CONFIG["hidden_dim"] + t_dim, h_dim, w_dim = compute_rope_dims(head_dim) + + batch_size, ppf, pph, ppw = 1, 5, 6, 4 + seq_len = ppf * pph * ppw + + torch.manual_seed(42) + qkv = torch.randn(batch_size, seq_len, 3 * hidden_dim, device=device, dtype=dtype) + + q_out = torch.empty( + batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype + ) + k_out = torch.empty( + batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype + ) + v_out = torch.empty( + batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype + ) + + q_ret, k_ret, v_ret = fused_qk_rmsnorm_rope( + qkv, + torch.ones(hidden_dim, device=device, dtype=dtype), + torch.ones(hidden_dim, device=device, dtype=dtype), + ppf=ppf, + pph=pph, + ppw=ppw, + num_frame_channels=t_dim, + num_height_channels=h_dim, + num_width_channels=w_dim, + num_heads_q=num_heads, + num_heads_k=num_heads, + num_heads_v=num_heads, + head_dim=head_dim, + interleave=True, + is_qk_norm=True, + q_out=q_out, + k_out=k_out, + v_out=v_out, + ) + + assert q_ret is q_out + assert k_ret is k_out + assert v_ret is v_out + + +# --------------------------------------------------------------------------- +# Correctness: 2D (pre-flattened) input +# --------------------------------------------------------------------------- + + +def test_2d_input(): + """2D [num_tokens, hidden] input should produce same results as 3D.""" + + device = torch.device("cuda") + dtype = torch.bfloat16 + num_heads = WAN_CONFIG["num_heads"] + head_dim = WAN_CONFIG["head_dim"] + hidden_dim = WAN_CONFIG["hidden_dim"] + eps = WAN_CONFIG["eps"] + base = WAN_CONFIG["base"] + t_dim, h_dim, w_dim = compute_rope_dims(head_dim) + + batch_size, ppf, pph, ppw = 2, 5, 6, 4 + seq_len = ppf * pph * ppw + num_tokens = batch_size * seq_len + + torch.manual_seed(42) + qkv_3d = torch.randn( + batch_size, seq_len, 3 * hidden_dim, device=device, dtype=dtype + ) + qkv_2d = qkv_3d.view(num_tokens, 3 * hidden_dim).contiguous() + + kwargs = dict( + ppf=ppf, + pph=pph, + ppw=ppw, + num_frame_channels=t_dim, + num_height_channels=h_dim, + num_width_channels=w_dim, + num_heads_q=num_heads, + num_heads_k=num_heads, + num_heads_v=num_heads, + head_dim=head_dim, + eps=eps, + base=base, + interleave=True, + is_qk_norm=True, + ) + q_weight = torch.ones(hidden_dim, device=device, dtype=dtype) + k_weight = torch.ones(hidden_dim, device=device, dtype=dtype) + + q_3d, k_3d, v_3d = fused_qk_rmsnorm_rope(qkv_3d, q_weight, k_weight, **kwargs) + q_2d, k_2d, v_2d = fused_qk_rmsnorm_rope(qkv_2d, q_weight, k_weight, **kwargs) + + assert q_3d.ndim == 4, f"3D input should give 4D output, got {q_3d.ndim}D" + assert q_2d.ndim == 3, f"2D input should give 3D output, got {q_2d.ndim}D" + assert q_3d.shape == (batch_size, seq_len, num_heads, head_dim) + assert q_2d.shape == (num_tokens, num_heads, head_dim) + + assert torch.equal(q_3d.view(num_tokens, num_heads, head_dim), q_2d) + assert torch.equal(k_3d.view(num_tokens, num_heads, head_dim), k_2d) + assert torch.equal(v_3d.view(num_tokens, num_heads, head_dim), v_2d) + + +# --------------------------------------------------------------------------- +# Correctness: FP8 output +# 
--------------------------------------------------------------------------- + + +@pytest.mark.parametrize("output_scale", [1.0, 0.5, 2.0]) +def test_fp8_output(output_scale): + device = torch.device("cuda") + dtype = torch.bfloat16 + num_heads = WAN_CONFIG["num_heads"] + head_dim = WAN_CONFIG["head_dim"] + hidden_dim = WAN_CONFIG["hidden_dim"] + eps = WAN_CONFIG["eps"] + base = WAN_CONFIG["base"] + t_dim, h_dim, w_dim = compute_rope_dims(head_dim) + + batch_size, ppf, pph, ppw = 1, 5, 6, 4 + seq_len = ppf * pph * ppw + + torch.manual_seed(42) + query = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + key = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + value = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + + norm_q = nn.RMSNorm(hidden_dim, eps=eps).to(device).to(dtype) + norm_k = nn.RMSNorm(hidden_dim, eps=eps).to(device).to(dtype) + with torch.no_grad(): + norm_q.weight.copy_(1.0 + 0.1 * torch.randn(hidden_dim, device=device)) + norm_k.weight.copy_(1.0 + 0.1 * torch.randn(hidden_dim, device=device)) + + qkv_combined = torch.cat([query, key, value], dim=-1).contiguous() + + q_fp8, k_fp8, v_fp8 = fused_qk_rmsnorm_rope( + qkv_combined, + norm_q.weight.contiguous(), + norm_k.weight.contiguous(), + ppf=ppf, + pph=pph, + ppw=ppw, + num_frame_channels=t_dim, + num_height_channels=h_dim, + num_width_channels=w_dim, + num_heads_q=num_heads, + num_heads_k=num_heads, + num_heads_v=num_heads, + head_dim=head_dim, + eps=eps, + base=base, + interleave=True, + is_qk_norm=True, + output_fp8=True, + output_quant_scale=output_scale, + v_quant_scale=output_scale, + ) + + assert q_fp8.dtype == torch.float8_e4m3fn + assert k_fp8.dtype == torch.float8_e4m3fn + assert v_fp8.dtype == torch.float8_e4m3fn + assert q_fp8.shape == (batch_size, seq_len, num_heads, head_dim) + assert q_fp8.is_contiguous() + + freqs_cos, freqs_sin = create_3d_rotary_embeddings( + batch_size, ppf, pph, ppw, head_dim, device, base, dtype + ) + q_ref, k_ref, _ = reference_qk_norm_rope( + query.clone(), + key.clone(), + value.clone(), + norm_q, + norm_k, + num_heads, + freqs_cos, + freqs_sin, + interleave=True, + ) + q_ref_fp8 = (q_ref.float() * output_scale).to(torch.float8_e4m3fn) + k_ref_fp8 = (k_ref.float() * output_scale).to(torch.float8_e4m3fn) + + q_diff = ( + (q_fp8.flatten(2).float() - q_ref_fp8.flatten(2).float()).abs().max().item() + ) + k_diff = ( + (k_fp8.flatten(2).float() - k_ref_fp8.flatten(2).float()).abs().max().item() + ) + + # Observed max diffs scale as ~0.5 * output_scale (FP8 quantization boundary + # rounding between kernel's float32 intermediate and reference's BF16 path). + # Allow 50% headroom over observed worst case. 
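+    # Illustrative arithmetic for the bound above: at output_scale=2.0 the
+    # observed worst case is ~0.5 * 2.0 = 1.0, so max(0.75 * 2.0, 0.375) = 1.5
+    # gives the 50% headroom; the 0.375 floor keeps a minimum tolerance at
+    # small scales.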
+ max_allowed = max(0.75 * output_scale, 0.375) + assert q_diff < max_allowed, f"FP8 Q diff {q_diff} >= {max_allowed}" + assert k_diff < max_allowed, f"FP8 K diff {k_diff} >= {max_allowed}" + + +# --------------------------------------------------------------------------- +# Correctness: is_qk_norm=False (RoPE only, no normalization) +# --------------------------------------------------------------------------- + + +def test_rope_only_no_norm(): + device = torch.device("cuda") + dtype = torch.bfloat16 + num_heads = WAN_CONFIG["num_heads"] + head_dim = WAN_CONFIG["head_dim"] + hidden_dim = WAN_CONFIG["hidden_dim"] + base = WAN_CONFIG["base"] + t_dim, h_dim, w_dim = compute_rope_dims(head_dim) + + batch_size, ppf, pph, ppw = 1, 5, 6, 4 + seq_len = ppf * pph * ppw + + torch.manual_seed(42) + query = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + key = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + value = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + + qkv_combined = torch.cat([query, key, value], dim=-1).contiguous() + q_weight = torch.ones(hidden_dim, device=device, dtype=dtype) + k_weight = torch.ones(hidden_dim, device=device, dtype=dtype) + + q_fused, k_fused, _ = fused_qk_rmsnorm_rope( + qkv_combined, + q_weight, + k_weight, + ppf=ppf, + pph=pph, + ppw=ppw, + num_frame_channels=t_dim, + num_height_channels=h_dim, + num_width_channels=w_dim, + num_heads_q=num_heads, + num_heads_k=num_heads, + num_heads_v=num_heads, + head_dim=head_dim, + base=base, + interleave=True, + is_qk_norm=False, + ) + + freqs_cos, freqs_sin = create_3d_rotary_embeddings( + batch_size, ppf, pph, ppw, head_dim, device, base, dtype + ) + q_heads = query.unflatten(2, (num_heads, head_dim)) + k_heads = key.unflatten(2, (num_heads, head_dim)) + q_ref = apply_rotary_emb_interleaved(q_heads, freqs_cos, freqs_sin) + k_ref = apply_rotary_emb_interleaved(k_heads, freqs_cos, freqs_sin) + + q_diff = (q_fused.flatten(2).float() - q_ref.flatten(2).float()).abs().max().item() + k_diff = (k_fused.flatten(2).float() - k_ref.flatten(2).float()).abs().max().item() + + assert q_diff < 0.05, f"RoPE-only Q diff {q_diff} >= 0.05" + assert k_diff < 0.05, f"RoPE-only K diff {k_diff} >= 0.05" + + +# --------------------------------------------------------------------------- +# Correctness: multi-config (WAN 1.3B, 5B, 14B model sizes) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "config_name", + [ + "wan2.1-1.3B", + "wan2.2-5B", + pytest.param( + "wan2.1-14B", + marks=pytest.mark.xfail( + reason="14B has num_heads=40 which exceeds kernel MAX_HEADS=32", + raises=ValueError, + strict=True, + ), + ), + ], +) +def test_multi_config(config_name): + """Test across WAN model sizes: 1.3B (12 heads), 5B (24 heads), 14B (40 heads).""" + + cfg = WAN_CONFIGS[config_name] + device = torch.device("cuda") + dtype = torch.bfloat16 + num_heads = cfg["num_heads"] + head_dim = cfg["head_dim"] + hidden_dim = cfg["hidden_dim"] + eps = cfg["eps"] + base = cfg["base"] + t_dim, h_dim, w_dim = compute_rope_dims(head_dim) + + batch_size, ppf, pph, ppw = 1, 5, 6, 4 + seq_len = ppf * pph * ppw + + torch.manual_seed(42) + query = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + key = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + value = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype) + + norm_q = nn.RMSNorm(hidden_dim, 
eps=eps).to(device).to(dtype) + norm_k = nn.RMSNorm(hidden_dim, eps=eps).to(device).to(dtype) + with torch.no_grad(): + norm_q.weight.copy_(1.0 + 0.1 * torch.randn(hidden_dim, device=device)) + norm_k.weight.copy_(1.0 + 0.1 * torch.randn(hidden_dim, device=device)) + + freqs_cos, freqs_sin = create_3d_rotary_embeddings( + batch_size, ppf, pph, ppw, head_dim, device, base, dtype + ) + + q_ref, k_ref, _ = reference_qk_norm_rope( + query.clone(), + key.clone(), + value.clone(), + norm_q, + norm_k, + num_heads, + freqs_cos, + freqs_sin, + interleave=True, + ) + + qkv_combined = torch.cat([query, key, value], dim=-1).contiguous() + q_fused, k_fused, _ = fused_qk_rmsnorm_rope( + qkv_combined, + norm_q.weight.contiguous(), + norm_k.weight.contiguous(), + ppf=ppf, + pph=pph, + ppw=ppw, + num_frame_channels=t_dim, + num_height_channels=h_dim, + num_width_channels=w_dim, + num_heads_q=num_heads, + num_heads_k=num_heads, + num_heads_v=num_heads, + head_dim=head_dim, + eps=eps, + base=base, + interleave=True, + is_qk_norm=True, + ) + + q_diff = (q_fused.flatten(2).float() - q_ref.flatten(2).float()).abs().max().item() + k_diff = (k_fused.flatten(2).float() - k_ref.flatten(2).float()).abs().max().item() + + assert q_diff < 0.1, f"{config_name} Q max diff {q_diff} >= 0.1" + assert k_diff < 0.1, f"{config_name} K max diff {k_diff} >= 0.1" + + +# --------------------------------------------------------------------------- +# Validation: error cases +# --------------------------------------------------------------------------- + + +def test_error_non_cuda(): + qkv = torch.randn(1, 120, 3 * 3072, dtype=torch.bfloat16) + w = torch.ones(3072, dtype=torch.bfloat16) + with pytest.raises((ValueError, RuntimeError)): + fused_qk_rmsnorm_rope( + qkv, + w, + w, + ppf=5, + pph=6, + ppw=4, + num_frame_channels=44, + num_height_channels=42, + num_width_channels=42, + num_heads_q=24, + num_heads_k=24, + num_heads_v=24, + head_dim=128, + ) + + +def test_error_wrong_dtype(): + device = torch.device("cuda") + qkv = torch.randn(1, 120, 3 * 3072, dtype=torch.float16, device=device) + w = torch.ones(3072, dtype=torch.bfloat16, device=device) + with pytest.raises((ValueError, RuntimeError)): + fused_qk_rmsnorm_rope( + qkv, + w, + w, + ppf=5, + pph=6, + ppw=4, + num_frame_channels=44, + num_height_channels=42, + num_width_channels=42, + num_heads_q=24, + num_heads_k=24, + num_heads_v=24, + head_dim=128, + ) + + +def test_error_bad_head_dim(): + device = torch.device("cuda") + head_dim = 96 + hidden = 24 * head_dim + qkv = torch.randn(1, 120, 3 * hidden, dtype=torch.bfloat16, device=device) + w = torch.ones(hidden, dtype=torch.bfloat16, device=device) + with pytest.raises((ValueError, RuntimeError)): + fused_qk_rmsnorm_rope( + qkv, + w, + w, + ppf=5, + pph=6, + ppw=4, + num_frame_channels=32, + num_height_channels=32, + num_width_channels=32, + num_heads_q=24, + num_heads_k=24, + num_heads_v=24, + head_dim=head_dim, + ) + + +def test_error_channel_sum_mismatch(): + device = torch.device("cuda") + qkv = torch.randn(1, 120, 3 * 3072, dtype=torch.bfloat16, device=device) + w = torch.ones(3072, dtype=torch.bfloat16, device=device) + with pytest.raises((ValueError, RuntimeError)): + fused_qk_rmsnorm_rope( + qkv, + w, + w, + ppf=5, + pph=6, + ppw=4, + num_frame_channels=40, + num_height_channels=40, + num_width_channels=40, + num_heads_q=24, + num_heads_k=24, + num_heads_v=24, + head_dim=128, + ) + + +def test_error_seq_len_mismatch(): + device = torch.device("cuda") + qkv = torch.randn(1, 100, 3 * 3072, dtype=torch.bfloat16, 
device=device) + w = torch.ones(3072, dtype=torch.bfloat16, device=device) + with pytest.raises((ValueError, RuntimeError)): + fused_qk_rmsnorm_rope( + qkv, + w, + w, + ppf=5, + pph=6, + ppw=4, + num_frame_channels=44, + num_height_channels=42, + num_width_channels=42, + num_heads_q=24, + num_heads_k=24, + num_heads_v=24, + head_dim=128, + ) + + +def test_error_wrong_weight_size(): + device = torch.device("cuda") + qkv = torch.randn(1, 120, 3 * 3072, dtype=torch.bfloat16, device=device) + w_good = torch.ones(3072, dtype=torch.bfloat16, device=device) + w_bad = torch.ones(1024, dtype=torch.bfloat16, device=device) + with pytest.raises((ValueError, RuntimeError)): + fused_qk_rmsnorm_rope( + qkv, + w_bad, + w_good, + ppf=5, + pph=6, + ppw=4, + num_frame_channels=44, + num_height_channels=42, + num_width_channels=42, + num_heads_q=24, + num_heads_k=24, + num_heads_v=24, + head_dim=128, + ) + + +def test_error_wrong_output_shape(): + device = torch.device("cuda") + dtype = torch.bfloat16 + qkv = torch.randn(1, 120, 3 * 3072, dtype=dtype, device=device) + w = torch.ones(3072, dtype=dtype, device=device) + bad_q_out = torch.empty( + 1, 120, 12, 128, dtype=dtype, device=device + ) # 12 heads, not 24 + with pytest.raises((ValueError, RuntimeError)): + fused_qk_rmsnorm_rope( + qkv, + w, + w, + ppf=5, + pph=6, + ppw=4, + num_frame_channels=44, + num_height_channels=42, + num_width_channels=42, + num_heads_q=24, + num_heads_k=24, + num_heads_v=24, + head_dim=128, + q_out=bad_q_out, + ) + + +def test_error_wrong_output_dtype(): + device = torch.device("cuda") + qkv = torch.randn(1, 120, 3 * 3072, dtype=torch.bfloat16, device=device) + w = torch.ones(3072, dtype=torch.bfloat16, device=device) + bad_q_out = torch.empty(1, 120, 24, 128, dtype=torch.float16, device=device) + with pytest.raises((ValueError, RuntimeError)): + fused_qk_rmsnorm_rope( + qkv, + w, + w, + ppf=5, + pph=6, + ppw=4, + num_frame_channels=44, + num_height_channels=42, + num_width_channels=42, + num_heads_q=24, + num_heads_k=24, + num_heads_v=24, + head_dim=128, + q_out=bad_q_out, + )