feat: RMSNorm + RoPE fusion for WAN: flashinfer.diffusion_ops.fused_qk_rmsnorm_rope #3148
Status: Open. kahyunnam wants to merge 36 commits into flashinfer-ai:main from kahyunnam:knam/fused_norm_rope_for_video_gen.
Changes shown from 32 of 36 commits.
Commits (all authored by kahyunnam):

- `e4df35f` Add fused QKNorm+RoPE kernel header for video-gen DIT self-attention
- `e0cb63c` Add TVM-FFI launcher and binding for fused QKNorm+RoPE
- `cc62bc3` Merge fused QKNorm+RoPE into the norm JIT module
- `bf2eb0b` Add Python API for fused QKNorm+RoPE in norm package
- `cc1d649` Export fused_qk_norm_rope from norm/ and create video_gen_ops/ facade
- `193cc53` Add tests for fused QKNorm+RoPE and fix packed_as namespace collision
- `03a4eb3` Expand tests: 8 interleaved shapes, 3 NeoX xfail, multi-GPU validated
- `1baac3d` Fix NeoX RoPE pos_id bug and add passing NeoX tests
- `2cd2088` Add multi-config tests for WAN 1.3B/5B/14B model sizes
- `73ec639` Add benchmark for fused QKNorm+RoPE kernel
- `feb1966` Support 2D [num_tokens, hidden] input in fused_qk_norm_rope
- `b4f7d38` Clean up internal references and fix arch descriptions
- `2cda43d` Use single top-level import from video_gen_ops in tests
- `aeffd60` Fix benchmark docstring to reference video_gen_ops API path
- `2f1841a` Fix lint: unused imports, mypy types, clang-format, ruff format
- `af4247f` Rename video_gen_ops to diffusion_ops
- `7580c79` Tighten FP8 test tolerance from 2x to 1.5x headroom
- `17103dd` Rename fused_qk_norm_rope to fused_qk_rmsnorm_rope
- `ddfdbe9` Use int64_t for num_tokens to prevent offset overflow
- `693a56e` Make frequency table cache per-device for multi-GPU safety
- `a18bc00` Fix clang-format indentation in csrc files
- `2de4c2b` Merge branch 'main' into knam/fused_norm_rope_for_video_gen
- `5a6c91b` Fix clang-format in fused_qk_rmsnorm_rope.cuh
- `792eb4e` Use lazy import for get_norm_module to avoid import-time dependency
- `e40ecf7` Remove stray non-ASCII character from docstring
- `20f4ea5` Stop freeing cached freq table pointer for cudagraph safety
- `2cb55e6` Check CUDA return codes and use cudaMemcpyAsync in freq cache
- `2c92645` Guard grid dim truncation and check kernel launch error
- `31c8a97` Validate weights and pre-allocated output buffers
- `9ac8944` Add error tests for bad weights and output buffers
- `e630f22` Always define get_norm_module regardless of CuTe DSL availability
- `7de01bb` Apply ruff format to Python files
- `758f8f0` Add fused_qk_rmsnorm_rope to unified benchmark CLI
- `9b5fe87` Move fused_qk_rmsnorm_rope into norm/__init__.py
- `b447f

cb` Move fused_qk_rmsnorm_rope.cuh into include/flashinfer/norm/
- `c7cca8b` Merge branch 'main' into knam/fused_norm_rope_for_video_gen
New file `benchmarks/bench_fused_qk_rmsnorm_rope.py` (+211 lines):

```python
"""
Benchmark for fused QK RMSNorm + 3D RoPE kernel vs eager PyTorch baseline.

Measures performance across WAN model shapes and compares:
- Eager: separate nn.RMSNorm + manual interleaved RoPE in PyTorch
- Fused: flashinfer.diffusion_ops.fused_qk_rmsnorm_rope (single kernel)

Usage:
    python benchmarks/bench_fused_qk_rmsnorm_rope.py
    python benchmarks/bench_fused_qk_rmsnorm_rope.py --gpu 2  # run on specific GPU
"""

import argparse

import numpy as np
import torch
import torch.nn as nn

from flashinfer.testing.utils import bench_gpu_time
from flashinfer.diffusion_ops import fused_qk_rmsnorm_rope


def compute_rope_dims(head_dim):
    h_dim = w_dim = 2 * (head_dim // 6)
    t_dim = head_dim - h_dim - w_dim
    return t_dim, h_dim, w_dim


def apply_rotary_emb_interleaved(hidden_states, freqs_cos, freqs_sin):
    x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
    cos = freqs_cos[..., 0::2]
    sin = freqs_sin[..., 1::2]
    out = torch.empty_like(hidden_states)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out.type_as(hidden_states)


def get_1d_rotary_pos_embed(dim, length, theta, device):
    inv_freq = 1.0 / (
        theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float64) / dim)
    )
    pos = torch.arange(length, device=device, dtype=torch.float64)
    freqs = torch.einsum("i,j->ij", pos, inv_freq)
    cos_out = torch.zeros(length, dim, device=device, dtype=torch.float64)
    sin_out = torch.zeros(length, dim, device=device, dtype=torch.float64)
    cos_out[:, 0::2] = torch.cos(freqs)
    cos_out[:, 1::2] = torch.cos(freqs)
    sin_out[:, 0::2] = torch.sin(freqs)
    sin_out[:, 1::2] = torch.sin(freqs)
    return cos_out, sin_out


def create_3d_rotary_embeddings(
    batch_size, ppf, pph, ppw, head_dim, device, base=10000.0, dtype=torch.bfloat16
):
    h_dim = w_dim = 2 * (head_dim // 6)
    t_dim = head_dim - h_dim - w_dim
    max_len = max(ppf, pph, ppw)
    t_cos, t_sin = get_1d_rotary_pos_embed(t_dim, max_len, base, device)
    h_cos, h_sin = get_1d_rotary_pos_embed(h_dim, max_len, base, device)
    w_cos, w_sin = get_1d_rotary_pos_embed(w_dim, max_len, base, device)
    t_cos_3d = (
        t_cos[:ppf].view(1, ppf, 1, 1, t_dim).expand(batch_size, ppf, pph, ppw, t_dim)
    )
    t_sin_3d = (
        t_sin[:ppf].view(1, ppf, 1, 1, t_dim).expand(batch_size, ppf, pph, ppw, t_dim)
    )
    h_cos_3d = (
        h_cos[:pph].view(1, 1, pph, 1, h_dim).expand(batch_size, ppf, pph, ppw, h_dim)
    )
    h_sin_3d = (
        h_sin[:pph].view(1, 1, pph, 1, h_dim).expand(batch_size, ppf, pph, ppw, h_dim)
    )
    w_cos_3d = (
        w_cos[:ppw].view(1, 1, 1, ppw, w_dim).expand(batch_size, ppf, pph, ppw, w_dim)
    )
    w_sin_3d = (
        w_sin[:ppw].view(1, 1, 1, ppw, w_dim).expand(batch_size, ppf, pph, ppw, w_dim)
    )
    freqs_cos = torch.cat([t_cos_3d, h_cos_3d, w_cos_3d], dim=-1)
    freqs_sin = torch.cat([t_sin_3d, h_sin_3d, w_sin_3d], dim=-1)
    seq_len = ppf * pph * ppw
    return (
        freqs_cos.reshape(batch_size, seq_len, 1, head_dim).to(dtype),
        freqs_sin.reshape(batch_size, seq_len, 1, head_dim).to(dtype),
    )


BENCH_SHAPES = [
    # (batch, ppf, pph, ppw, description)
    (1, 5, 12, 32, "480p production (1920 tokens)"),
    (1, 5, 12, 8, "480p small (480 tokens)"),
    (1, 5, 48, 32, "720p large (7680 tokens)"),
    (2, 5, 12, 32, "batch=2 (3840 tokens)"),
    (1, 5, 6, 4, "tiny (120 tokens)"),
    (4, 5, 12, 32, "batch=4 (7680 tokens)"),
    (1, 5, 12, 16, "half seq (960 tokens)"),
    (1, 10, 12, 32, "double frames (3840 tokens)"),
]


def bench_one_shape(batch_size, ppf, pph, ppw, num_heads, head_dim, eps, base, device):
    seq_len = ppf * pph * ppw
    hidden_dim = num_heads * head_dim
    t_dim, h_dim, w_dim = compute_rope_dims(head_dim)
    dtype = torch.bfloat16

    torch.manual_seed(42)
    query = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)
    key = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)
    value = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)

    norm_q = nn.RMSNorm(hidden_dim, eps=eps).to(device).to(dtype)
    norm_k = nn.RMSNorm(hidden_dim, eps=eps).to(device).to(dtype)

    freqs_cos, freqs_sin = create_3d_rotary_embeddings(
        batch_size, ppf, pph, ppw, head_dim, device, base, dtype
    )

    qkv_combined = torch.cat([query, key, value], dim=-1).contiguous()
    q_weight = norm_q.weight.contiguous()
    k_weight = norm_k.weight.contiguous()

    def eager_fn():
        q_normed = norm_q(query)
        k_normed = norm_k(key)
        q_heads = q_normed.unflatten(2, (num_heads, -1))
        k_heads = k_normed.unflatten(2, (num_heads, -1))
        v_heads = value.unflatten(2, (num_heads, -1))
        q_out = apply_rotary_emb_interleaved(q_heads, freqs_cos, freqs_sin)
        k_out = apply_rotary_emb_interleaved(k_heads, freqs_cos, freqs_sin)
        return q_out, k_out, v_heads

    def fused_fn():
        return fused_qk_rmsnorm_rope(
            qkv_combined,
            q_weight,
            k_weight,
            ppf=ppf,
            pph=pph,
            ppw=ppw,
            num_frame_channels=t_dim,
            num_height_channels=h_dim,
            num_width_channels=w_dim,
            num_heads_q=num_heads,
            num_heads_k=num_heads,
            num_heads_v=num_heads,
            head_dim=head_dim,
            eps=eps,
            base=base,
            interleave=True,
            is_qk_norm=True,
        )

    eager_times = bench_gpu_time(
        eager_fn, enable_cupti=True, dry_run_iters=10, repeat_iters=100
    )
    fused_times = bench_gpu_time(
        fused_fn, enable_cupti=True, dry_run_iters=10, repeat_iters=100
    )

    eager_ms = float(np.median(eager_times))
    fused_ms = float(np.median(fused_times))
    return eager_ms, fused_ms


def main():
    parser = argparse.ArgumentParser(description="Benchmark fused QK RMSNorm + 3D RoPE")
    parser.add_argument("--gpu", type=int, default=0, help="GPU device index")
    args = parser.parse_args()

    device = torch.device(f"cuda:{args.gpu}")
    torch.cuda.set_device(device)
    gpu_name = torch.cuda.get_device_name(device)

    num_heads = 24
    head_dim = 128
    eps = 1e-6
    base = 10000.0

    print(f"GPU: {gpu_name}")
    print(f"Config: WAN 2.2 5B (num_heads={num_heads}, head_dim={head_dim})")
    print()
    print(f"{'Shape':<50} {'Eager (ms)':>12} {'Fused (ms)':>12} {'Speedup':>10}")
    print("-" * 90)

    for batch_size, ppf, pph, ppw, desc in BENCH_SHAPES:
        seq_len = ppf * pph * ppw
        shape_str = f"B={batch_size} {ppf}x{pph}x{ppw}={seq_len:>5} ({desc})"

        eager_ms, fused_ms = bench_one_shape(
            batch_size,
            ppf,
            pph,
            ppw,
            num_heads,
            head_dim,
            eps,
            base,
            device,
        )

        speedup = eager_ms / fused_ms if fused_ms > 0 else 0
        print(f"{shape_str:<50} {eager_ms:>12.4f} {fused_ms:>12.4f} {speedup:>9.2f}x")

    print("-" * 90)


if __name__ == "__main__":
    main()
```
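The 3D channel split in `compute_rope_dims` above gives the height and width rotary bands `2 * (head_dim // 6)` channels each (an even count, since interleaved RoPE rotates element pairs) and assigns the remainder to the temporal band. A dependency-free sketch of that arithmetic, using the WAN `head_dim=128` from this benchmark:

```python
def compute_rope_dims(head_dim):
    # Height and width bands each get 2 * (head_dim // 6) channels;
    # the temporal band takes whatever remains, so the three bands
    # always sum back to head_dim.
    h_dim = w_dim = 2 * (head_dim // 6)
    t_dim = head_dim - h_dim - w_dim
    return t_dim, h_dim, w_dim


# For the WAN 2.2 5B config benchmarked above (head_dim=128):
print(compute_rope_dims(128))  # (44, 42, 42)
```

Note that the split is not exactly a third each: rounding to even pair counts pushes the slack into the temporal band.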
New file (+5 lines), the `flashinfer.diffusion_ops` facade re-exporting the kernel:

```python
from flashinfer.norm import fused_qk_rmsnorm_rope

__all__ = [
    "fused_qk_rmsnorm_rope",
]
```
Review comment: Are we trying to compile all CUDA norm kernels even when we do not set FLASHINFER_USE_CUDA_NORM=1? If so, I wonder whether there is a way to separate out the newly added kernels.