Skip to content

Commit badddd2

Browse files
Fangzhou-Aiclaude
andauthored
[ROCm][DSV4][Perf] Fuse inverse-RoPE and cache bf16 wo_a in o-projection (vllm-project#45103)
Signed-off-by: Fangzhou Ai <fangzhouai@gmail.com> Co-authored-by: Claude Fable 5 <noreply@anthropic.com>
1 parent c906500 commit badddd2

2 files changed

Lines changed: 341 additions & 59 deletions

File tree

tests/kernels/attention/test_rocm_triton_attn_dsv4.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,3 +515,218 @@ def test_sparse_attn_decode_split_k_kernel(
515515
)
516516

517517
torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2)
518+
519+
520+
# ---------------------------------------------------------------------------
521+
# o-projection: fused inverse-RoPE + cached bf16 wo_a (rocm_inv_rope_einsum)
522+
# ---------------------------------------------------------------------------
523+
524+
525+
# Cache rows = max_position_embeddings * scaling_factor.
526+
_ROTARY_MAX_POS = 1024
527+
_ROTARY_SCALING_FACTOR = 4.0
528+
_ROTARY_CACHE_LEN = int(_ROTARY_MAX_POS * _ROTARY_SCALING_FACTOR)
529+
530+
531+
def _make_dsv4_rotary(device: torch.device):
532+
"""The official DSv4 rotary embedding, sized down for unit tests."""
533+
from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import (
534+
DeepseekV4ScalingRotaryEmbedding,
535+
)
536+
537+
# The model loader constructs layers under a default-device context;
538+
# mirror that so the fp32 cos_sin_cache lands on the GPU.
539+
with torch.device(device):
540+
rotary_emb = DeepseekV4ScalingRotaryEmbedding(
541+
head_size=ROPE_HEAD_DIM,
542+
rotary_dim=ROPE_HEAD_DIM,
543+
max_position_embeddings=_ROTARY_MAX_POS,
544+
base=10000,
545+
is_neox_style=False,
546+
scaling_factor=_ROTARY_SCALING_FACTOR,
547+
dtype=torch.bfloat16,
548+
mscale=1.0,
549+
mscale_all_dim=1.0,
550+
)
551+
rotary_emb = rotary_emb.to(device)
552+
assert rotary_emb.cos_sin_cache.shape == (_ROTARY_CACHE_LEN, ROPE_HEAD_DIM)
553+
return rotary_emb
554+
555+
556+
def _inv_rope_via_rotary_native(
557+
rotary_emb: torch.nn.Module,
558+
o: torch.Tensor,
559+
positions: torch.Tensor,
560+
) -> torch.Tensor:
561+
"""Reference: the official ``forward_native(inverse=True)`` path."""
562+
expected, _ = rotary_emb.forward_native(positions, o.clone(), None, inverse=True)
563+
return expected.to(torch.bfloat16)
564+
565+
566+
class _FakeWoA(torch.nn.Module):
567+
"""Stand-in for the wo_a linear layer holding the (optionally fp8) weight."""
568+
569+
def __init__(
570+
self, weight: torch.Tensor, weight_scale_inv: torch.Tensor | None = None
571+
) -> None:
572+
super().__init__()
573+
self.weight = weight
574+
if weight_scale_inv is not None:
575+
self.weight_scale_inv = weight_scale_inv
576+
577+
578+
@pytest.mark.parametrize("num_tokens", [1, 7, 64])
579+
@pytest.mark.parametrize("num_heads", [1, 8])
580+
@pytest.mark.parametrize("pos_dtype", [torch.int32, torch.int64])
581+
@torch.inference_mode()
582+
def test_fused_inverse_rope_gptj_matches_rotary_native(
583+
num_tokens: int, num_heads: int, pos_dtype: torch.dtype, default_vllm_config
584+
) -> None:
585+
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import _fused_inverse_rope_gptj
586+
587+
device = torch.device("cuda")
588+
torch.manual_seed(0)
589+
rotary_emb = _make_dsv4_rotary(device)
590+
o = torch.randn(
591+
num_tokens, num_heads, HEAD_DIM, dtype=torch.bfloat16, device=device
592+
)
593+
positions = torch.randint(
594+
0, _ROTARY_CACHE_LEN, (num_tokens,), dtype=pos_dtype, device=device
595+
)
596+
597+
actual = _fused_inverse_rope_gptj(
598+
o, positions, rotary_emb.cos_sin_cache, ROPE_HEAD_DIM
599+
)
600+
expected = _inv_rope_via_rotary_native(rotary_emb, o, positions)
601+
602+
assert actual.dtype == torch.bfloat16
603+
assert actual.shape == o.shape
604+
# NoPE lanes are a pure bf16 passthrough -> must be bit-exact.
605+
assert torch.equal(actual[..., :NOPE_HEAD_DIM], expected[..., :NOPE_HEAD_DIM])
606+
# RoPE lanes: tolerate at most ~1 bf16 ulp from fp32 fma ordering.
607+
torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2)
608+
609+
610+
@torch.inference_mode()
611+
def test_fused_inverse_rope_gptj_empty(default_vllm_config) -> None:
612+
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import _fused_inverse_rope_gptj
613+
614+
device = torch.device("cuda")
615+
rotary_emb = _make_dsv4_rotary(device)
616+
o = torch.empty(0, 8, HEAD_DIM, dtype=torch.bfloat16, device=device)
617+
positions = torch.empty(0, dtype=torch.int32, device=device)
618+
619+
out = _fused_inverse_rope_gptj(
620+
o, positions, rotary_emb.cos_sin_cache, ROPE_HEAD_DIM
621+
)
622+
assert out.shape == (0, 8, HEAD_DIM)
623+
assert out.dtype == torch.bfloat16
624+
625+
626+
@torch.inference_mode()
627+
def test_rocm_inv_rope_einsum_matches_rotary_native(default_vllm_config) -> None:
628+
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import rocm_inv_rope_einsum
629+
630+
device = torch.device("cuda")
631+
torch.manual_seed(2)
632+
num_tokens, num_heads = 5, 8
633+
n_local_groups = num_heads
634+
o_lora_rank = 16
635+
hidden_dim = num_heads * HEAD_DIM // n_local_groups # 512
636+
637+
rotary_emb = _make_dsv4_rotary(device)
638+
o = (
639+
torch.randn(
640+
num_tokens, num_heads, HEAD_DIM, dtype=torch.bfloat16, device=device
641+
)
642+
* 0.125
643+
)
644+
positions = torch.randint(
645+
0, _ROTARY_CACHE_LEN, (num_tokens,), dtype=torch.int32, device=device
646+
)
647+
weight = (
648+
torch.randn(n_local_groups * o_lora_rank, hidden_dim, device=device) * 0.125
649+
).to(torch.bfloat16)
650+
wo_a = _FakeWoA(weight)
651+
652+
actual = rocm_inv_rope_einsum(
653+
rotary_emb, o, positions, ROPE_HEAD_DIM, n_local_groups, o_lora_rank, wo_a
654+
)
655+
656+
o_ref = _inv_rope_via_rotary_native(rotary_emb, o, positions)
657+
o_ref = o_ref.view(num_tokens, n_local_groups, -1)
658+
wo_a_ref = weight.view(n_local_groups, o_lora_rank, hidden_dim).to(torch.bfloat16)
659+
expected = torch.einsum("tgd,grd->tgr", o_ref, wo_a_ref)
660+
661+
assert actual.shape == (num_tokens, n_local_groups, o_lora_rank)
662+
torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2)
663+
664+
665+
@torch.inference_mode()
666+
def test_get_cached_wo_a_bf16_plain_caches() -> None:
667+
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import _get_cached_wo_a_bf16
668+
669+
device = torch.device("cuda")
670+
torch.manual_seed(4)
671+
n_local_groups, o_lora_rank, hidden_dim = 2, 4, 8
672+
weight = torch.randn(
673+
n_local_groups * o_lora_rank, hidden_dim, dtype=torch.bfloat16, device=device
674+
)
675+
wo_a = _FakeWoA(weight)
676+
677+
out1 = _get_cached_wo_a_bf16(wo_a, n_local_groups, o_lora_rank, hidden_dim)
678+
expected = weight.view(n_local_groups, o_lora_rank, hidden_dim).to(torch.bfloat16)
679+
assert out1.shape == (n_local_groups, o_lora_rank, hidden_dim)
680+
torch.testing.assert_close(out1, expected, atol=0, rtol=0)
681+
assert hasattr(wo_a, "_dsv4_wo_a_bf16")
682+
683+
# Mutate the source weight: the cached tensor must be returned unchanged
684+
# (proving the dequant is not recomputed per call).
685+
wo_a.weight.zero_()
686+
out2 = _get_cached_wo_a_bf16(wo_a, n_local_groups, o_lora_rank, hidden_dim)
687+
assert out2 is out1
688+
torch.testing.assert_close(out2, expected, atol=0, rtol=0)
689+
690+
691+
@torch.inference_mode()
692+
def test_get_cached_wo_a_bf16_fp8_blockscale_caches() -> None:
693+
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import _get_cached_wo_a_bf16
694+
695+
device = torch.device("cuda")
696+
torch.manual_seed(5)
697+
n_local_groups, o_lora_rank, hidden_dim = 2, 4, 8
698+
row_block, col_block = 2, 2
699+
row_blocks = o_lora_rank // row_block
700+
col_blocks = hidden_dim // col_block
701+
702+
fp8_dtype = current_platform.fp8_dtype()
703+
weight_f32 = (
704+
torch.randn(
705+
n_local_groups, o_lora_rank, hidden_dim, dtype=torch.float32, device=device
706+
)
707+
* 0.1
708+
)
709+
weight_fp8 = weight_f32.to(fp8_dtype)
710+
scale = (
711+
torch.rand(
712+
n_local_groups, row_blocks, col_blocks, dtype=torch.float32, device=device
713+
)
714+
* 0.5
715+
+ 0.5
716+
)
717+
wo_a = _FakeWoA(
718+
weight_fp8.reshape(n_local_groups * o_lora_rank, hidden_dim),
719+
weight_scale_inv=scale.reshape(n_local_groups * row_blocks, col_blocks),
720+
)
721+
722+
out = _get_cached_wo_a_bf16(wo_a, n_local_groups, o_lora_rank, hidden_dim)
723+
724+
scale_full = scale.repeat_interleave(row_block, dim=-2).repeat_interleave(
725+
col_block, dim=-1
726+
)
727+
expected = (weight_fp8.to(torch.float32) * scale_full).to(torch.bfloat16)
728+
assert out.shape == (n_local_groups, o_lora_rank, hidden_dim)
729+
torch.testing.assert_close(out, expected, atol=0, rtol=0)
730+
731+
# Second call returns the same cached object.
732+
assert _get_cached_wo_a_bf16(wo_a, n_local_groups, o_lora_rank, hidden_dim) is out

0 commit comments

Comments
 (0)