11from typing import Generic , Optional , TypeVar
22import logging
3- import os
43
54from dataclasses import dataclass
65
76import torch
87
8+ from aiter import dtypes , get_mla_metadata_info_v1 , get_mla_metadata_v1
9+ from aiter .dist .parallel_state import get_tp_group
910from atom .plugin .prepare import is_vllm , is_sglang
10- from atom .utils import CpuGpuBuffer
11+ from atom .utils import CpuGpuBuffer , envs
12+ from atom .config import get_current_atom_config
13+
1114from atom .utils .forward_context import Context , AttentionMetaData
1215from atom .model_ops .attention_mha import PagedAttentionImpl
13- from atom .model_ops .attention_mla import MLAAttention
16+ from atom .model_ops .attention_mla import MLAAttention , _MLA_MIN_HEADS
1417
1518logger = logging .getLogger ("atom" )
1619
1720_PARTITION_SIZE_ROCM = 256
1821_CP_TOKENS_PER_ITER_ROCM = 32 * 1024
22+ disable_vllm_plugin_attention = envs .ATOM_DISABLE_VLLM_PLUGIN_ATTENTION
1923
2024
2125@dataclass
@@ -643,6 +647,7 @@ class AiterMLAChunkedContextMetadataForPluginMode:
643647 query_seq_lens : torch .Tensor | None = None
644648 workspace_buffer : torch .Tensor | None = None
645649 q_data_type : torch .dtype | None = None
650+ output_dtype : torch .dtype | None = None
646651
647652
648653D = TypeVar ("D" , bound = AiterMLACommonDecodeMetadataForPluginMode )
@@ -699,6 +704,49 @@ def __init__(self):
699704 "Its methods are meant to be added to other classes via decorators."
700705 )
701706
707+ # TODO: support mtp and sparse
708+ def _set_mla_persistent_worker_buffers (
709+ self , bs : int , cu_seqlens_q : torch .Tensor , max_q_len : int = 1
710+ ):
711+ split_params = {
712+ "kv_granularity" : max (self .block_size , 16 ),
713+ "max_seqlen_qo" : max_q_len ,
714+ "uni_seqlen_qo" : max_q_len ,
715+ "fast_mode" : 1 ,
716+ "max_split_per_batch" : 16 ,
717+ }
718+ var = self .mla_persistent_metadata
719+ work_meta_data = var ["work_meta_data" ]
720+ work_info_set = var ["work_info_set" ]
721+ work_indptr = var ["work_indptr" ]
722+ reduce_indptr = var ["reduce_indptr" ]
723+ reduce_final_map = var ["reduce_final_map" ]
724+ reduce_partial_map = var ["reduce_partial_map" ]
725+ get_mla_metadata_v1 (
726+ cu_seqlens_q ,
727+ self .paged_kv_indptr [: bs + 1 ], # TODO: support sparse
728+ self .paged_kv_last_page_len [:bs ],
729+ self .padded_num_attention_heads ,
730+ 1 , # nhead_kv,
731+ True ,
732+ work_meta_data ,
733+ work_info_set ,
734+ work_indptr ,
735+ reduce_indptr ,
736+ reduce_final_map ,
737+ reduce_partial_map ,
738+ page_size = self .block_size ,
739+ ** split_params ,
740+ )
741+ return {
742+ "work_meta_data" : work_meta_data ,
743+ "work_info_set" : work_info_set ,
744+ "work_indptr" : work_indptr ,
745+ "reduce_indptr" : reduce_indptr ,
746+ "reduce_final_map" : reduce_final_map ,
747+ "reduce_partial_map" : reduce_partial_map ,
748+ }
749+
702750 def _build_decode (
703751 self ,
704752 block_table_tensor : torch .Tensor ,
@@ -733,34 +781,29 @@ def _build_decode(
733781 qo_len = query_start_loc_cpu [1 :] - query_start_loc_cpu [:- 1 ]
734782 max_qo_len = qo_len .max ().item ()
735783
736- if self .compilation_config .cudagraph_mode .has_full_cudagraphs ():
737- num_actual_pages = paged_kv_indices .size (0 )
784+ num_actual_pages = paged_kv_indices .size (0 )
738785
739- self .paged_kv_indices [:num_actual_pages ].copy_ (
740- paged_kv_indices , non_blocking = True
741- )
742- self .paged_kv_indices [num_actual_pages :].fill_ (- 1 )
743- paged_kv_indices = self .paged_kv_indices [:num_actual_pages ]
786+ self .paged_kv_indices [:num_actual_pages ].copy_ (
787+ paged_kv_indices , non_blocking = True
788+ )
789+ self .paged_kv_indices [num_actual_pages :].fill_ (- 1 )
790+ paged_kv_indices = self .paged_kv_indices [:num_actual_pages ]
744791
745- self .paged_kv_indptr [: 1 + num_reqs ].copy_ (
746- paged_kv_indptr , non_blocking = True
747- )
748- self .paged_kv_indptr [1 + num_reqs :].fill_ (paged_kv_indptr [- 1 ])
749- paged_kv_indptr = self .paged_kv_indptr [: 1 + num_reqs ]
792+ self .paged_kv_indptr [: 1 + num_reqs ].copy_ (paged_kv_indptr , non_blocking = True )
793+ self .paged_kv_indptr [1 + num_reqs :].fill_ (paged_kv_indptr [- 1 ])
794+ paged_kv_indptr = self .paged_kv_indptr [: 1 + num_reqs ]
750795
751- # paged_kv_last_page_len already uses the pre-initialized buffer slice
752- # (set above), so no copy needed - buffer is always 1s.
796+ # paged_kv_last_page_len already uses the pre-initialized buffer slice
797+ # (set above), so no copy needed - buffer is always 1s.
753798
754- self .qo_indptr [: 1 + num_reqs ].copy_ (
755- query_start_loc_device , non_blocking = True
756- )
757- self .qo_indptr [1 + num_reqs :] = query_start_loc_device [- 1 ]
758- qo_indptr = self .qo_indptr [: 1 + num_reqs ]
799+ self .qo_indptr [: 1 + num_reqs ].copy_ (query_start_loc_device , non_blocking = True )
800+ self .qo_indptr [1 + num_reqs :] = query_start_loc_device [- 1 ]
801+ qo_indptr = self .qo_indptr [: 1 + num_reqs ]
759802
760- else :
761- qo_indptr = torch . arange (
762- 0 , num_reqs + 1 , step = 1 , dtype = torch . int32 , device = device
763- )
803+ ctx_mla_ps = self . _set_mla_persistent_worker_buffers (
804+ num_reqs , query_start_loc_device , 1
805+ )
806+ self . mla_persistent_metadata . update ( ctx_mla_ps )
764807
765808 attn_metadata = AiterMLADecodeMetadataForPluginMode (
766809 block_table = block_table_tensor ,
@@ -1056,11 +1099,15 @@ def build(
10561099 decode = decode_metadata ,
10571100 )
10581101
1102+ # TODO: support mtp
1103+ ctx_mla_ps = self .mla_persistent_metadata
1104+
10591105 attn_metadata = AttentionMetaData (
10601106 max_seqlen_q = common_attn_metadata .max_query_len ,
10611107 block_tables = common_attn_metadata .block_table_tensor ,
10621108 slot_mapping = common_attn_metadata .slot_mapping ,
10631109 plugin_metadata = attn_metadata_for_plugin_mode ,
1110+ ** ctx_mla_ps ,
10641111 )
10651112 return attn_metadata
10661113
@@ -1095,6 +1142,13 @@ def init_method_under_plugin_mode(
10951142 max_num_pages_per_req = self .vllm_config .model_config .max_model_len
10961143 max_num_reqs = self .vllm_config .scheduler_config .max_num_seqs
10971144 max_num_pages = max_num_reqs * max_num_pages_per_req
1145+ self .num_attention_heads = (
1146+ config .model_config .hf_config .num_attention_heads
1147+ // get_tp_group ().world_size
1148+ )
1149+ self .padded_num_attention_heads = max (self .num_attention_heads , _MLA_MIN_HEADS )
1150+ self .block_size = kv_cache_spec .block_size
1151+ self .max_bs = max_num_reqs
10981152
10991153 # Preparing persistent buffers
11001154 # TODO: we can disambiguate between decode and mixed-prefill decode here
@@ -1107,17 +1161,54 @@ def init_method_under_plugin_mode(
11071161 max_num_reqs , dtype = torch .int32 , device = device
11081162 )
11091163
1110- if self .compilation_config .cudagraph_mode .has_full_cudagraphs ():
1111- self .paged_kv_indptr = torch .zeros (
1112- max_num_reqs + 1 , dtype = torch .int32 , device = device
1113- )
1114- self .paged_kv_indices = torch .zeros (
1115- max_num_pages , dtype = torch .int32 , device = device
1116- )
1164+ self .paged_kv_indptr = torch .zeros (
1165+ max_num_reqs + 1 , dtype = torch .int32 , device = device
1166+ )
1167+ self .paged_kv_indices = torch .zeros (
1168+ max_num_pages , dtype = torch .int32 , device = device
1169+ )
11171170
1118- self .qo_indptr = torch .zeros (
1119- max_num_reqs + 1 , dtype = torch .int32 , device = device
1120- )
1171+ self .qo_indptr = torch .zeros (max_num_reqs + 1 , dtype = torch .int32 , device = device )
1172+
1173+ (
1174+ (work_meta_data_size , work_meta_data_type ),
1175+ (work_indptr_size , work_indptr_type ),
1176+ (work_info_set_size , work_info_set_type ),
1177+ (reduce_indptr_size , reduce_indptr_type ),
1178+ (reduce_final_map_size , reduce_final_map_type ),
1179+ (reduce_partial_map_size , reduce_partial_map_type ),
1180+ ) = get_mla_metadata_info_v1 (
1181+ max_num_reqs ,
1182+ 1 ,
1183+ self .padded_num_attention_heads ,
1184+ torch .bfloat16 ,
1185+ dtypes .d_dtypes [config .cache_config .cache_dtype ],
1186+ is_sparse = False , # TODO: support sparse
1187+ fast_mode = True ,
1188+ )
1189+
1190+ self .mla_persistent_metadata = {
1191+ "work_meta_data" : torch .empty (
1192+ work_meta_data_size , dtype = work_meta_data_type , device = self .device
1193+ ),
1194+ "work_indptr" : torch .empty (
1195+ work_indptr_size , dtype = work_indptr_type , device = self .device
1196+ ),
1197+ "work_info_set" : torch .empty (
1198+ work_info_set_size , dtype = work_info_set_type , device = self .device
1199+ ),
1200+ "reduce_indptr" : torch .empty (
1201+ reduce_indptr_size , dtype = reduce_indptr_type , device = self .device
1202+ ),
1203+ "reduce_final_map" : torch .empty (
1204+ reduce_final_map_size , dtype = reduce_final_map_type , device = self .device
1205+ ),
1206+ "reduce_partial_map" : torch .empty (
1207+ reduce_partial_map_size ,
1208+ dtype = reduce_partial_map_type ,
1209+ device = self .device ,
1210+ ),
1211+ }
11211212
11221213 return init_method_under_plugin_mode
11231214
@@ -1206,6 +1297,7 @@ def decorator(cls):
12061297class vllmAiterMLABackendMethods :
12071298 accept_output_buffer : bool = True
12081299 supported_dtypes : list = [torch .float16 , torch .bfloat16 ]
1300+ forward_includes_kv_cache_update : bool = True
12091301
12101302 def __init__ (self ):
12111303 raise TypeError (
@@ -1288,9 +1380,6 @@ def unified_attention_with_output_base_for_plugin_mode(
12881380 use_mla : bool ,
12891381 qkv : torch .Tensor ,
12901382) -> torch .Tensor :
1291- from atom .config import get_current_atom_config
1292- from atom .utils import envs
1293-
12941383 atom_config = get_current_atom_config ()
12951384 if use_mla :
12961385 # raise NotImplementedError("MLA is not supported for plugin mode for now")
@@ -1300,7 +1389,7 @@ def unified_attention_with_output_base_for_plugin_mode(
13001389 q = self .q_proj (q , q_scale )
13011390 q = q .view (- 1 , self .num_heads , self .qk_head_dim )
13021391 # Add head dim of 1 to k_pe
1303- if os . getenv ( "ATOM_DISABLE_VLLM_PLUGIN_ATTENTION" , "0" ). lower () == "1" :
1392+ if disable_vllm_plugin_attention :
13041393 k_pe = k_pe .unsqueeze (1 )
13051394 if self .rotary_emb is not None :
13061395 q [..., self .qk_nope_head_dim :], k_pe = self .rotary_emb (
0 commit comments