diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py
index dd7af272546..25eb92fa6e1 100644
--- a/megatron/core/inference/contexts/dynamic_context.py
+++ b/megatron/core/inference/contexts/dynamic_context.py
@@ -258,10 +258,6 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC
         self.cache_mla_latent = (
             isinstance(model_config, MLATransformerConfig) and model_config.cache_mla_latents
         )
-        if self.cache_mla_latent:
-            assert (
-                inference_config.block_size_tokens == 64
-            ), "Flash MLA requires a block size of 64. Set --inference-dynamic-batching-block-size 64 to fix this assert"
 
         # Per partition num heads and hidden size.
         num_attention_heads = model_config.num_query_groups or model_config.num_attention_heads
diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py
index 28e3dde01c4..2c0e9b46dec 100644
--- a/megatron/core/transformer/attention.py
+++ b/megatron/core/transformer/attention.py
@@ -117,6 +117,8 @@ except ImportError:
 
     HAVE_FUSED_QKV_ROPE = False
 
+FMLA_REQUIRED_BLOCK_SIZE = 64
+
 
 class LinearQkv(Protocol):
     """Protocol for linear_qkv modules."""
@@ -620,18 +622,19 @@ def _adjust_key_value_for_inference(
             )
             _, max_seqlen_q = inference_context.cu_query_lengths()
-            if getattr(self.config, "cache_mla_latents", None) and max_seqlen_q > 1:
+            # Read key/value *pointer* tensors from cache.
+            key, value, block_table = inference_context.key_value_cache(
+                self.layer_number - pp_layer_offset
+            )
+            block_size_tokens = key.size(1)
+            # Do not use absorption when block size doesn't match what's expected by FlashMLA.
+            if getattr(self.config, "cache_mla_latents", None) and (
+                max_seqlen_q > 1 or block_size_tokens != FMLA_REQUIRED_BLOCK_SIZE
+            ):
                 # Doing unabsorbed MLA Attention with cached mla latents (prefill/mixed mode)
-                kv_cache, _, block_table = inference_context.key_value_cache(
-                    self.layer_number - pp_layer_offset
-                )
+                kv_cache = key
                 # Uncompress the KV cache for prefill/mixed mode
                 key, value = self.uncompress_kv_from_cache(kv_cache)
-            else:
-                # Read key/value *pointer* tensors from cache.
-                key, value, block_table = inference_context.key_value_cache(
-                    self.layer_number - pp_layer_offset
-                )
 
         return query, key, value, rotary_pos_emb, attn_mask_type, block_table
 
     @abstractmethod
@@ -837,8 +840,13 @@ def flash_decode_and_prefill(
             )
             output_total = output_total.unsqueeze(1)
         else:  # decode only
-            # If using MLA we use the FlashMLA kernel
-            if isinstance(self.config, MLATransformerConfig):
+            # If using MLA we use the FlashMLA kernel when possible.
+            block_size_tokens = k.size(1)
+            if (
+                isinstance(self.config, MLATransformerConfig)
+                # Only use FlashMLA when the block size matches
+                and block_size_tokens == FMLA_REQUIRED_BLOCK_SIZE
+            ):
                 softmax_scale = self.softmax_scale
                 num_heads_k = 1  # Only a single head for MLA Flash
 
@@ -874,6 +882,7 @@ def flash_decode_and_prefill(
                 "causal": True,
                 "page_table" if HAVE_FA3 else "block_table": block_table,
                 "num_splits": 0 if not self.batch_invariant_mode else 1,
+                "softmax_scale": getattr(self, "softmax_scale", self.config.softmax_scale),
             }
             if HAVE_FA3:
                 output_total = flash_attn3_with_kvcache(**flash_attn_args)
diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py
index a9cdc697cc8..76fafb0bfa3 100644
--- a/megatron/core/transformer/multi_latent_attention.py
+++ b/megatron/core/transformer/multi_latent_attention.py
@@ -16,7 +16,10 @@
 
 from megatron.core import tensor_parallel
-from megatron.core.extensions.transformer_engine import split_te_layernorm_column_parallel_linear
+from megatron.core.extensions.transformer_engine import (
+    TELayerNormColumnParallelLinear,
+    split_te_layernorm_column_parallel_linear,
+)
 from megatron.core.models.common.embeddings import (
     RotaryEmbedding,
     YarnRotaryEmbedding,
@@ -33,7 +36,7 @@
     gather_from_tensor_model_parallel_region,
     scatter_to_sequence_parallel_region,
 )
-from megatron.core.transformer.attention import Attention
+from megatron.core.transformer.attention import FMLA_REQUIRED_BLOCK_SIZE, Attention
 from megatron.core.transformer.enums import AttnMaskType
 from megatron.core.transformer.spec_utils import ModuleSpec, build_module
 from megatron.core.transformer.torch_norm import LayerNormBuilder
@@ -318,9 +321,13 @@ def forward(
                 cu_kv_lengths,
                 kv_lengths,
                 block_table,
+                inference_context.is_decode_only(),
             )
             # Only rearrange if not in absorption mode (Flash MLA handles format correctly)
-            if not inference_context.is_decode_only():
+            if (
+                not inference_context.is_decode_only()
+                or inference_context.block_size_tokens != FMLA_REQUIRED_BLOCK_SIZE
+            ):
                 core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)')
         if self.offload_core_attention and self.training:
             core_attn_out = off_interface.group_commit(
@@ -328,7 +335,11 @@ def forward(
             )
 
         # We are doing absorption with cache mla latents and decode mode.
-        if self.cache_mla_latents and inference_context.is_decode_only():
+        if (
+            self.cache_mla_latents
+            and inference_context.is_decode_only()
+            and inference_context.block_size_tokens == FMLA_REQUIRED_BLOCK_SIZE
+        ):
             # core_attn_out = self.self.up_v_layer(core_attn_out)
             core_attn_out = torch.einsum("sbhc,hdc->sbhd", core_attn_out, self.up_v_weight)
             core_attn_out = core_attn_out.contiguous()
@@ -618,7 +629,10 @@ def get_query_key_value_tensors(
             kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1
         )
         if get_pg_size(self.tp_group) > 1 and self.config.sequence_parallel:
-            # k_pos_emb: [s, b, qk_pos_emb_head_dim]
+            # kv_compressed: [s, b, kv_lora_rank], k_pos_emb: [s, b, qk_pos_emb_head_dim]
+            kv_compressed = gather_from_sequence_parallel_region(
+                kv_compressed, group=self.tp_group
+            )
             k_pos_emb = gather_from_sequence_parallel_region(k_pos_emb, group=self.tp_group)
 
         if packed_seq_params is not None:
@@ -693,6 +707,7 @@ def qkv_up_proj_and_rope_apply_for_cached_latent_kv(
             self.config.cache_mla_latents
             and inference_context
            and inference_context.is_decode_only()
+            and inference_context.block_size_tokens == FMLA_REQUIRED_BLOCK_SIZE
         )
         # Compute query components. Multiply by up k if absorbing
         q_content = (
@@ -877,6 +892,8 @@ def uncompress_kv_from_cache(self, kv_cached):
 
         # Add head dimension
         k_pos_emb = k_pos_emb.unsqueeze(-2)
+        if get_pg_size(self.tp_group) > 1 and self.config.sequence_parallel:
+            k_pos_emb = gather_from_sequence_parallel_region(k_pos_emb, group=self.tp_group)
         k_pos_emb = k_pos_emb.expand(-1, -1, self.num_attention_heads_per_partition, -1)
 
         key = torch.cat([k_no_pe, k_pos_emb], dim=-1)
@@ -908,16 +925,22 @@ def prepare_for_absorption(self):
         # We should only have to call to set once at start
         if not hasattr(self, "up_k_weight"):
             with torch.no_grad():
-                linear_kv_up_proj_norm, linear_kv_up_proj_linear = (
-                    split_te_layernorm_column_parallel_linear(
-                        self.linear_kv_up_proj, self.config, None, self.linear_kv_up_proj.tp_group
+                if isinstance(self.linear_kv_up_proj, TELayerNormColumnParallelLinear):
+                    linear_kv_up_proj_norm, linear_kv_up_proj_linear = (
+                        split_te_layernorm_column_parallel_linear(
+                            self.linear_kv_up_proj,
+                            self.config,
+                            None,
+                            self.linear_kv_up_proj.tp_group,
+                        )
                     )
-                )
-                # Note: When caching latents we overide the kv_layernorm
-                # which was an identity before because in the is path
-                # we unfused the linear_kv_up_proj
-                self.kv_layernorm = linear_kv_up_proj_norm
+                    # Note: When caching latents we override the kv_layernorm
+                    # which was an identity before because in this path
+                    # we unfused the linear_kv_up_proj
+                    self.kv_layernorm = linear_kv_up_proj_norm
+                else:
+                    linear_kv_up_proj_linear = self.linear_kv_up_proj
 
             # This is used in absorption when we are
             # uncompressing the KV cache in prefill/mixed stages
diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py
index 6a07d7a35ae..ff9a070c188 100644
--- a/tests/unit_tests/inference/engines/test_dynamic_engine.py
+++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py
@@ -49,7 +49,7 @@
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from megatron.core.transformer.cuda_graphs import CudaGraphManager, _CudagraphGlobalRecord
 from megatron.core.transformer.enums import CudaGraphScope
-from megatron.core.transformer.transformer_config import TransformerConfig
+from megatron.core.transformer.transformer_config import MLATransformerConfig, TransformerConfig
 from megatron.core.utils import is_fa_min_version, is_te_min_version
 from tests.unit_tests.test_utilities import Utils
 
@@ -140,17 +140,24 @@ class DynamicEngineTestConfig:
     static_kv_memory_pointers: bool = True
     track_generated_token_events: bool = False
-    fp8: bool = False
+    use_mla: bool = False
+    cache_mla_latent: bool = False
 
     def __post_init__(self):
+        assert self.max_sequence_length is None
+        assert (
+            self.num_tokens_to_generate is None or self.num_tokens_total is None
+        ) and self.num_tokens_to_generate != self.num_tokens_total
+
+        if self.use_mla and self.cache_mla_latent:
+            # Fix paged KV cache block size requirement (needs to be divisible by 256).
+            # Note, this doesn't work with FlashMLA, which requires a block size of exactly 64.
+            self.context_block_size_tokens = 256
+
         # Compute max_sequence_length.
-        assert self.max_sequence_length is None
-        assert self.num_tokens_to_generate is None or self.num_tokens_total is None
         if self.num_tokens_to_generate is not None:
             self.max_sequence_length = self.max_prompt_length + self.num_tokens_to_generate
         else:
-            assert self.num_tokens_total is not None
             self.max_sequence_length = self.num_tokens_total
 
         # Default paused buffer size.
@@ -291,9 +298,28 @@ def _build_test_env(cls, test_config):
         # Requests.
         requests = cls._build_requests(test_config)
 
+        # Values required for proper cache_mla_latent functioning
+        qk_head_dim = 128
+        qk_pos_emb_head_dim = 64
+        transformer_config_cls = (
+            partial(
+                MLATransformerConfig,
+                cache_mla_latents=test_config.cache_mla_latent,
+                qk_head_dim=qk_head_dim,
+                qk_pos_emb_head_dim=qk_pos_emb_head_dim,
+                # For cache_mla_latent, the following needs to hold:
+                # v_head_dim == qk_head_dim + qk_pos_emb_head_dim
+                v_head_dim=(
+                    (qk_head_dim + qk_pos_emb_head_dim) if test_config.cache_mla_latent else 128
+                ),
+            )
+            if test_config.use_mla
+            else TransformerConfig
+        )
+
         if test_config.model_provider == "gpt":
             # Transformer config.
-            transformer_config = TransformerConfig(
+            transformer_config = transformer_config_cls(
                 params_dtype=torch.bfloat16,
                 num_layers=4,
                 hidden_size=128 if test_config.fp8 else 32,
@@ -331,11 +357,15 @@ def _build_test_env(cls, test_config):
                 # inference optimized currently only supports RMS Norm
             )
             if test_config.fp8 or test_config.transformer_impl == "transformer_engine":
-                layer_spec = get_gpt_layer_with_transformer_engine_spec()
+                layer_spec = get_gpt_layer_with_transformer_engine_spec(
+                    multi_latent_attention=test_config.use_mla
+                )
             elif test_config.transformer_impl == "local":
-                layer_spec = get_gpt_layer_local_spec()
+                layer_spec = get_gpt_layer_local_spec(multi_latent_attention=test_config.use_mla)
             elif test_config.transformer_impl == "inference_optimized":
-                layer_spec = get_gpt_layer_with_inference_spec()
+                layer_spec = get_gpt_layer_with_inference_spec(
+                    multi_latent_attention=test_config.use_mla
+                )
 
             # GPT model.
             model = GPTModel(
@@ -350,7 +380,7 @@ def _build_test_env(cls, test_config):
         elif test_config.model_provider == "mamba":
             pp_size = test_config.pipeline_model_parallel_size
             # Transformer config.
-            transformer_config = TransformerConfig(
+            transformer_config = transformer_config_cls(
                 params_dtype=torch.bfloat16,
                 num_layers=(
                     3 if pp_size == 1 else 6
@@ -1037,6 +1067,8 @@ def test_log_probs_token_correspondence(self):
     @pytest.mark.parametrize("tp_size", [1, 2])
     @pytest.mark.parametrize("model_provider", ["gpt", "mamba"])
     @pytest.mark.parametrize("transformer_impl", ["local", "inference_optimized"])
+    @pytest.mark.parametrize("use_mla", [False, True])
+    @pytest.mark.parametrize("cache_mla_latent", [False, True])
     @torch.inference_mode()
     def test_parallel_inference(
         self,
@@ -1047,6 +1079,8 @@ def test_parallel_inference(
         sequence_parallel,
         materialize_only_last_token_logits,
         transformer_impl,
+        use_mla,
+        cache_mla_latent,
     ):
         skip_if_mamba_sequence_packing_not_available(model_provider)
 
@@ -1074,10 +1108,14 @@ def test_parallel_inference(
                     "when tp_size > 1."
                 )
             )
-        if model_provider == "mamba":
-            pytest.skip(
-                reason="Mamba model is not supported with the inference optimized transformer."
-            )
+        if use_mla and transformer_impl == "local":
+            pytest.skip(reason="MLA does not work with the local implementation.")
+        if cache_mla_latent and not use_mla:
+            pytest.skip(reason="MLA latent caching requires MLA use.")
+        if use_mla and not cache_mla_latent:
+            pytest.skip(
+                reason="MLA use for dynamic inference currently requires `cache_mla_latents=True`."
+            )
 
         env = self._run_test(
             model_provider=model_provider,
@@ -1087,6 +1125,8 @@ def test_parallel_inference(
             sequence_parallel=sequence_parallel,
             materialize_only_last_token_logits=materialize_only_last_token_logits,
             transformer_impl=transformer_impl,
+            use_mla=use_mla,
+            cache_mla_latent=cache_mla_latent,
         )
 
     @pytest.mark.internal