@@ -515,3 +515,218 @@ def test_sparse_attn_decode_split_k_kernel(
515515 )
516516
517517 torch .testing .assert_close (actual , expected , atol = 2e-2 , rtol = 2e-2 )
518+
519+
520+ # ---------------------------------------------------------------------------
521+ # o-projection: fused inverse-RoPE + cached bf16 wo_a (rocm_inv_rope_einsum)
522+ # ---------------------------------------------------------------------------
523+
524+
# The rotary cos/sin cache has max_position_embeddings * scaling_factor rows;
# the helper that builds the test rotary embedding asserts this shape.
_ROTARY_MAX_POS = 1024
_ROTARY_SCALING_FACTOR = 4.0
_ROTARY_CACHE_LEN = int(_ROTARY_SCALING_FACTOR * _ROTARY_MAX_POS)
529+
530+
def _make_dsv4_rotary(device: torch.device):
    """Build a small ``DeepseekV4ScalingRotaryEmbedding`` for unit tests.

    Args:
        device: Device the module (and its ``cos_sin_cache``) is placed on.

    Returns:
        The constructed rotary embedding module, after checking that its
        ``cos_sin_cache`` has shape ``(_ROTARY_CACHE_LEN, ROPE_HEAD_DIM)``.
    """
    from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import (
        DeepseekV4ScalingRotaryEmbedding,
    )

    rope_kwargs = dict(
        head_size=ROPE_HEAD_DIM,
        rotary_dim=ROPE_HEAD_DIM,
        max_position_embeddings=_ROTARY_MAX_POS,
        base=10000,
        is_neox_style=False,
        scaling_factor=_ROTARY_SCALING_FACTOR,
        dtype=torch.bfloat16,
        mscale=1.0,
        mscale_all_dim=1.0,
    )

    # Layers are built under a default-device context by the model loader;
    # do the same here so the fp32 cos_sin_cache is created on the GPU.
    with torch.device(device):
        module = DeepseekV4ScalingRotaryEmbedding(**rope_kwargs)
    module = module.to(device)

    cache_shape = module.cos_sin_cache.shape
    assert cache_shape == (_ROTARY_CACHE_LEN, ROPE_HEAD_DIM)
    return module
554+
555+
def _inv_rope_via_rotary_native(
    rotary_emb: torch.nn.Module,
    o: torch.Tensor,
    positions: torch.Tensor,
) -> torch.Tensor:
    """Reference result computed by ``rotary_emb.forward_native(inverse=True)``.

    Args:
        rotary_emb: Module exposing ``forward_native``.
        o: Tensor to transform; a clone is handed to the module.
        positions: Positions forwarded unchanged to ``forward_native``.

    Returns:
        The first element returned by ``forward_native``, cast to bfloat16.
    """
    # Hand over a copy so the reference path cannot write into the caller's
    # tensor; no key tensor is supplied.
    query_copy = o.clone()
    rotated, _unused_key = rotary_emb.forward_native(
        positions, query_copy, None, inverse=True
    )
    return rotated.to(torch.bfloat16)
564+
565+
566+ class _FakeWoA (torch .nn .Module ):
567+ """Stand-in for the wo_a linear layer holding the (optionally fp8) weight."""
568+
569+ def __init__ (
570+ self , weight : torch .Tensor , weight_scale_inv : torch .Tensor | None = None
571+ ) -> None :
572+ super ().__init__ ()
573+ self .weight = weight
574+ if weight_scale_inv is not None :
575+ self .weight_scale_inv = weight_scale_inv
576+
577+
@pytest.mark.parametrize("num_tokens", [1, 7, 64])
@pytest.mark.parametrize("num_heads", [1, 8])
@pytest.mark.parametrize("pos_dtype", [torch.int32, torch.int64])
@torch.inference_mode()
def test_fused_inverse_rope_gptj_matches_rotary_native(
    num_tokens: int, num_heads: int, pos_dtype: torch.dtype, default_vllm_config
) -> None:
    """Compare ``_fused_inverse_rope_gptj`` with ``forward_native(inverse=True)``.

    Runs over several token counts, head counts and position dtypes with
    random bf16 input and random positions drawn from the rotary cache range.
    """
    from vllm.v1.attention.ops.rocm_aiter_mla_sparse import _fused_inverse_rope_gptj

    gpu = torch.device("cuda")
    torch.manual_seed(0)
    rope = _make_dsv4_rotary(gpu)
    attn_out = torch.randn(
        num_tokens, num_heads, HEAD_DIM, dtype=torch.bfloat16, device=gpu
    )
    pos = torch.randint(
        0, _ROTARY_CACHE_LEN, (num_tokens,), dtype=pos_dtype, device=gpu
    )

    fused = _fused_inverse_rope_gptj(attn_out, pos, rope.cos_sin_cache, ROPE_HEAD_DIM)
    reference = _inv_rope_via_rotary_native(rope, attn_out, pos)

    assert fused.dtype == torch.bfloat16
    assert fused.shape == attn_out.shape

    # The leading NOPE_HEAD_DIM lanes are copied through in bf16 without any
    # rotation, so they have to agree bit for bit.
    nope_lanes = slice(None, NOPE_HEAD_DIM)
    assert torch.equal(fused[..., nope_lanes], reference[..., nope_lanes])

    # Rotated lanes may differ by roughly one bf16 ulp because of fp32 fma
    # ordering; the tolerance below is applied to the full tensor.
    torch.testing.assert_close(fused, reference, atol=2e-2, rtol=2e-2)
608+
609+
@torch.inference_mode()
def test_fused_inverse_rope_gptj_empty(default_vllm_config) -> None:
    """``_fused_inverse_rope_gptj`` on zero tokens keeps shape and dtype."""
    from vllm.v1.attention.ops.rocm_aiter_mla_sparse import _fused_inverse_rope_gptj

    gpu = torch.device("cuda")
    rope = _make_dsv4_rotary(gpu)
    empty_shape = (0, 8, HEAD_DIM)
    no_tokens = torch.empty(*empty_shape, dtype=torch.bfloat16, device=gpu)
    no_positions = torch.empty(0, dtype=torch.int32, device=gpu)

    result = _fused_inverse_rope_gptj(
        no_tokens, no_positions, rope.cos_sin_cache, ROPE_HEAD_DIM
    )

    assert result.shape == empty_shape
    assert result.dtype == torch.bfloat16
624+
625+
@torch.inference_mode()
def test_rocm_inv_rope_einsum_matches_rotary_native(default_vllm_config) -> None:
    """Check ``rocm_inv_rope_einsum`` against native inverse RoPE plus einsum.

    The reference applies ``forward_native(inverse=True)``, groups the result
    per local group, and contracts it with the bf16 wo_a weight via
    ``torch.einsum("tgd,grd->tgr", ...)``.
    """
    from vllm.v1.attention.ops.rocm_aiter_mla_sparse import rocm_inv_rope_einsum

    gpu = torch.device("cuda")
    torch.manual_seed(2)

    num_tokens = 5
    num_heads = 8
    n_local_groups = num_heads
    o_lora_rank = 16
    # Input width seen by each group's slice of wo_a.
    hidden_dim = num_heads * HEAD_DIM // n_local_groups

    rope = _make_dsv4_rotary(gpu)
    attn_out = torch.randn(
        num_tokens, num_heads, HEAD_DIM, dtype=torch.bfloat16, device=gpu
    )
    attn_out = attn_out * 0.125
    pos = torch.randint(
        0, _ROTARY_CACHE_LEN, (num_tokens,), dtype=torch.int32, device=gpu
    )
    weight_f32 = torch.randn(n_local_groups * o_lora_rank, hidden_dim, device=gpu)
    weight = (weight_f32 * 0.125).to(torch.bfloat16)
    fake_layer = _FakeWoA(weight)

    actual = rocm_inv_rope_einsum(
        rope, attn_out, pos, ROPE_HEAD_DIM, n_local_groups, o_lora_rank, fake_layer
    )

    # Reference path: native inverse RoPE, regroup, then the grouped matmul.
    grouped_ref = _inv_rope_via_rotary_native(rope, attn_out, pos).view(
        num_tokens, n_local_groups, -1
    )
    grouped_weight = weight.view(n_local_groups, o_lora_rank, hidden_dim).to(
        torch.bfloat16
    )
    expected = torch.einsum("tgd,grd->tgr", grouped_ref, grouped_weight)

    assert actual.shape == (num_tokens, n_local_groups, o_lora_rank)
    torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2)
663+
664+
@torch.inference_mode()
def test_get_cached_wo_a_bf16_plain_caches() -> None:
    """A plain bf16 wo_a weight is reshaped once and then served from cache.

    The first call must return the weight viewed as
    ``(n_local_groups, o_lora_rank, hidden_dim)`` and store it on the layer
    as ``_dsv4_wo_a_bf16``. After the layer's source weight is replaced, a
    second call must return the very same tensor object with the original
    values, showing the result is not recomputed per call.
    """
    from vllm.v1.attention.ops.rocm_aiter_mla_sparse import _get_cached_wo_a_bf16

    device = torch.device("cuda")
    torch.manual_seed(4)
    n_local_groups, o_lora_rank, hidden_dim = 2, 4, 8
    group_shape = (n_local_groups, o_lora_rank, hidden_dim)
    weight = torch.randn(
        n_local_groups * o_lora_rank, hidden_dim, dtype=torch.bfloat16, device=device
    )
    wo_a = _FakeWoA(weight)

    # Snapshot the expected values up front. ``view`` followed by a same-dtype
    # ``.to`` makes no copy, so without the clone ``expected`` would share
    # storage with ``weight`` and could never disagree with it.
    expected = weight.view(*group_shape).to(torch.bfloat16).clone()

    out1 = _get_cached_wo_a_bf16(wo_a, n_local_groups, o_lora_rank, hidden_dim)
    assert out1.shape == group_shape
    torch.testing.assert_close(out1, expected, atol=0, rtol=0)
    assert hasattr(wo_a, "_dsv4_wo_a_bf16")

    # Replace the source weight with zeros. Rebinding (rather than zeroing in
    # place) keeps the check valid whether or not the cached tensor shares
    # storage with the original weight: a recompute would yield a new, all-zero
    # tensor and fail both assertions below.
    # NOTE(review): assumes the cache is not keyed on the identity of
    # ``wo_a.weight`` -- confirm against _get_cached_wo_a_bf16.
    wo_a.weight = torch.zeros_like(weight)
    out2 = _get_cached_wo_a_bf16(wo_a, n_local_groups, o_lora_rank, hidden_dim)
    assert out2 is out1
    torch.testing.assert_close(out2, expected, atol=0, rtol=0)
689+
690+
@torch.inference_mode()
def test_get_cached_wo_a_bf16_fp8_blockscale_caches() -> None:
    """An fp8 wo_a weight with block scales is dequantized to bf16 and cached.

    The expected tensor is built by expanding every block scale over its
    ``row_block x col_block`` tile, multiplying the fp8 weight (as fp32) by
    it, and casting to bf16. A second call must return the same object.
    """
    from vllm.v1.attention.ops.rocm_aiter_mla_sparse import _get_cached_wo_a_bf16

    gpu = torch.device("cuda")
    torch.manual_seed(5)

    n_local_groups, o_lora_rank, hidden_dim = 2, 4, 8
    row_block, col_block = 2, 2
    row_blocks = o_lora_rank // row_block
    col_blocks = hidden_dim // col_block
    group_shape = (n_local_groups, o_lora_rank, hidden_dim)

    fp8_dtype = current_platform.fp8_dtype()
    weight_f32 = torch.randn(*group_shape, dtype=torch.float32, device=gpu) * 0.1
    weight_fp8 = weight_f32.to(fp8_dtype)

    # One scale per (group, row block, column block), drawn from [0.5, 1.0).
    unit_noise = torch.rand(
        n_local_groups, row_blocks, col_blocks, dtype=torch.float32, device=gpu
    )
    scale = unit_noise * 0.5 + 0.5

    # The fake layer stores weight and scales flattened to two dimensions.
    flat_weight = weight_fp8.reshape(n_local_groups * o_lora_rank, hidden_dim)
    flat_scale = scale.reshape(n_local_groups * row_blocks, col_blocks)
    wo_a = _FakeWoA(flat_weight, weight_scale_inv=flat_scale)

    out = _get_cached_wo_a_bf16(wo_a, n_local_groups, o_lora_rank, hidden_dim)

    # Expand each block scale to element granularity, rows first, then columns.
    scale_rows = scale.repeat_interleave(row_block, dim=-2)
    scale_full = scale_rows.repeat_interleave(col_block, dim=-1)
    dequantized = weight_fp8.to(torch.float32) * scale_full
    expected = dequantized.to(torch.bfloat16)

    assert out.shape == group_shape
    torch.testing.assert_close(out, expected, atol=0, rtol=0)

    # A repeated call has to hand back the identical cached object.
    again = _get_cached_wo_a_bf16(wo_a, n_local_groups, o_lora_rank, hidden_dim)
    assert again is out
0 commit comments