4 changes: 0 additions & 4 deletions megatron/core/inference/contexts/dynamic_context.py
@@ -258,10 +258,6 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC
self.cache_mla_latent = (
isinstance(model_config, MLATransformerConfig) and model_config.cache_mla_latents
)
if self.cache_mla_latent:
assert (
inference_config.block_size_tokens == 64
), "Flash MLA requires a block size of 64. Set --inference-dynamic-batching-block-size 64 to fix this assert"

# Per partition num heads and hidden size.
num_attention_heads = model_config.num_query_groups or model_config.num_attention_heads
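Note (illustrative, not part of the diff): removing this assert means a latent-caching MLA context no longer fails at construction when the dynamic-batching block size is not 64; the FlashMLA requirement is instead checked at attention time, via the FMLA_REQUIRED_BLOCK_SIZE constant introduced in attention.py below. A minimal sketch of the old versus new behavior, with hypothetical values:

FMLA_REQUIRED_BLOCK_SIZE = 64

def old_context_check(cache_mla_latent: bool, block_size_tokens: int) -> None:
    # Previous behavior: constructing the context with a mismatched block size raised.
    if cache_mla_latent:
        assert block_size_tokens == FMLA_REQUIRED_BLOCK_SIZE, "Flash MLA requires a block size of 64."

def new_flash_mla_eligible(cache_mla_latent: bool, block_size_tokens: int) -> bool:
    # New behavior: any block size is accepted; FlashMLA is simply skipped when it is not 64.
    return cache_mla_latent and block_size_tokens == FMLA_REQUIRED_BLOCK_SIZE

print(new_flash_mla_eligible(True, 256))  # False -> the unabsorbed MLA path is used instead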
31 changes: 20 additions & 11 deletions megatron/core/transformer/attention.py
@@ -117,6 +117,8 @@
except ImportError:
HAVE_FUSED_QKV_ROPE = False

FMLA_REQUIRED_BLOCK_SIZE = 64


class LinearQkv(Protocol):
"""Protocol for linear_qkv modules."""
@@ -620,18 +622,19 @@ def _adjust_key_value_for_inference(
)

_, max_seqlen_q = inference_context.cu_query_lengths()
if getattr(self.config, "cache_mla_latents", None) and max_seqlen_q > 1:
# Read key/value *pointer* tensors from cache.
key, value, block_table = inference_context.key_value_cache(
self.layer_number - pp_layer_offset
)
block_size_tokens = key.size(1)
# Do not use absorption when block size doesn't match what's expected by FlashMLA.
if getattr(self.config, "cache_mla_latents", None) and (
max_seqlen_q > 1 or block_size_tokens != FMLA_REQUIRED_BLOCK_SIZE
):
# Doing unabsorbed MLA Attention with cached mla latents (prefill/mixed mode)
kv_cache, _, block_table = inference_context.key_value_cache(
self.layer_number - pp_layer_offset
)
kv_cache = key
# Uncompress the KV cache for prefill/mixed mode
key, value = self.uncompress_kv_from_cache(kv_cache)
else:
# Read key/value *pointer* tensors from cache.
key, value, block_table = inference_context.key_value_cache(
self.layer_number - pp_layer_offset
)
return query, key, value, rotary_pos_emb, attn_mask_type, block_table

@abstractmethod
@@ -837,8 +840,13 @@ def flash_decode_and_prefill(
)
output_total = output_total.unsqueeze(1)
else: # decode only
# If using MLA we use the FlashMLA kernel
if isinstance(self.config, MLATransformerConfig):
# If using MLA we use the FlashMLA kernel when possible.
block_size_tokens = k.size(1)
if (
isinstance(self.config, MLATransformerConfig)
# Only use FlashMLA when the block size matches
and block_size_tokens == FMLA_REQUIRED_BLOCK_SIZE
):
softmax_scale = self.softmax_scale

num_heads_k = 1 # Only a single head for MLA Flash
@@ -874,6 +882,7 @@
"causal": True,
"page_table" if HAVE_FA3 else "block_table": block_table,
"num_splits": 0 if not self.batch_invariant_mode else 1,
"softmax_scale": getattr(self, "softmax_scale", self.config.softmax_scale),
}
if HAVE_FA3:
output_total = flash_attn3_with_kvcache(**flash_attn_args)
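Illustrative sketch (not part of the diff): the attention.py changes together implement a per-call dispatch in which the absorbed FlashMLA decode kernel is used only when the KV-cache block size equals FMLA_REQUIRED_BLOCK_SIZE; otherwise the cached latents are uncompressed and the unabsorbed path runs. A simplified, self-contained rendering of that decision (function and label names are hypothetical):

FMLA_REQUIRED_BLOCK_SIZE = 64

def choose_mla_inference_path(
    cache_mla_latents: bool,
    max_seqlen_q: int,
    block_size_tokens: int,
    is_decode_only: bool,
) -> str:
    if not cache_mla_latents:
        # Regular attention: key/value pointer tensors are read straight from the cache.
        return "regular_kv_cache"
    if max_seqlen_q > 1 or block_size_tokens != FMLA_REQUIRED_BLOCK_SIZE:
        # Prefill/mixed mode, or a block size FlashMLA cannot handle:
        # uncompress the latent cache and run unabsorbed MLA attention.
        return "unabsorbed_uncompressed_kv"
    if is_decode_only:
        # Decode-only with a 64-token block size: absorbed FlashMLA kernel.
        return "absorbed_flash_mla"
    return "regular_kv_cache"

assert choose_mla_inference_path(True, 1, 256, True) == "unabsorbed_uncompressed_kv"
assert choose_mla_inference_path(True, 1, 64, True) == "absorbed_flash_mla"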
49 changes: 36 additions & 13 deletions megatron/core/transformer/multi_latent_attention.py
@@ -16,7 +16,10 @@


from megatron.core import tensor_parallel
from megatron.core.extensions.transformer_engine import split_te_layernorm_column_parallel_linear
from megatron.core.extensions.transformer_engine import (
TELayerNormColumnParallelLinear,
split_te_layernorm_column_parallel_linear,
)
from megatron.core.models.common.embeddings import (
RotaryEmbedding,
YarnRotaryEmbedding,
@@ -33,7 +36,7 @@
gather_from_tensor_model_parallel_region,
scatter_to_sequence_parallel_region,
)
from megatron.core.transformer.attention import Attention
from megatron.core.transformer.attention import FMLA_REQUIRED_BLOCK_SIZE, Attention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.torch_norm import LayerNormBuilder
@@ -318,17 +321,25 @@ def forward(
cu_kv_lengths,
kv_lengths,
block_table,
inference_context.is_decode_only(),
)
# Only rearrange if not in absorption mode (Flash MLA handles format correctly)
if not inference_context.is_decode_only():
if (
not inference_context.is_decode_only()
or inference_context.block_size_tokens != FMLA_REQUIRED_BLOCK_SIZE
):
core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)')
if self.offload_core_attention and self.training:
core_attn_out = off_interface.group_commit(
core_attn_out, name="core_attn", forced_released_tensors=[query, key, value]
)

# We are doing absorption with cache mla latents and decode mode.
if self.cache_mla_latents and inference_context.is_decode_only():
if (
self.cache_mla_latents
and inference_context.is_decode_only()
and inference_context.block_size_tokens == FMLA_REQUIRED_BLOCK_SIZE
):
# core_attn_out = self.self.up_v_layer(core_attn_out)
core_attn_out = torch.einsum("sbhc,hdc->sbhd", core_attn_out, self.up_v_weight)
core_attn_out = core_attn_out.contiguous()
@@ -618,7 +629,10 @@ def get_query_key_value_tensors(
kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1
)
if get_pg_size(self.tp_group) > 1 and self.config.sequence_parallel:
# k_pos_emb: [s, b, qk_pos_emb_head_dim]
# kv_compressed: [s, b, kv_lora_rank], k_pos_emb: [s, b, qk_pos_emb_head_dim]
kv_compressed = gather_from_sequence_parallel_region(
kv_compressed, group=self.tp_group
)
k_pos_emb = gather_from_sequence_parallel_region(k_pos_emb, group=self.tp_group)

if packed_seq_params is not None:
@@ -693,6 +707,7 @@ def qkv_up_proj_and_rope_apply_for_cached_latent_kv(
self.config.cache_mla_latents
and inference_context
and inference_context.is_decode_only()
and inference_context.block_size_tokens == FMLA_REQUIRED_BLOCK_SIZE
)
# Compute query components. Multiply by up k if absorbing
q_content = (
@@ -877,6 +892,8 @@ def uncompress_kv_from_cache(self, kv_cached):

# Add head dimension
k_pos_emb = k_pos_emb.unsqueeze(-2)
if get_pg_size(self.tp_group) > 1 and self.config.sequence_parallel:
k_pos_emb = gather_from_sequence_parallel_region(k_pos_emb, group=self.tp_group)
k_pos_emb = k_pos_emb.expand(-1, -1, self.num_attention_heads_per_partition, -1)

key = torch.cat([k_no_pe, k_pos_emb], dim=-1)
@@ -908,16 +925,22 @@ def prepare_for_absorption(self):
# We should only have to set this once, at start
if not hasattr(self, "up_k_weight"):
with torch.no_grad():
linear_kv_up_proj_norm, linear_kv_up_proj_linear = (
split_te_layernorm_column_parallel_linear(
self.linear_kv_up_proj, self.config, None, self.linear_kv_up_proj.tp_group
if isinstance(self.linear_kv_up_proj, TELayerNormColumnParallelLinear):
linear_kv_up_proj_norm, linear_kv_up_proj_linear = (
split_te_layernorm_column_parallel_linear(
self.linear_kv_up_proj,
self.config,
None,
self.linear_kv_up_proj.tp_group,
)
)
)

# Note: When caching latents we override the kv_layernorm
# which was an identity before because in this path
# we unfused the linear_kv_up_proj
self.kv_layernorm = linear_kv_up_proj_norm
# Note: When caching latents we override the kv_layernorm
# which was an identity before because in this path
# we unfused the linear_kv_up_proj
self.kv_layernorm = linear_kv_up_proj_norm
else:
linear_kv_up_proj_linear = self.linear_kv_up_proj

# This is used in absorption when we are
# uncompressing the KV cache in prefill/mixed stages
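Illustrative shape check (not part of the diff): in the decode-only absorbed path, core_attn_out comes back per head in the latent (kv_lora_rank) space, and the torch.einsum("sbhc,hdc->sbhd", ...) above applies the value up-projection to lift it to v_head_dim. A small sketch with hypothetical sizes, assuming up_v_weight is laid out [heads, v_head_dim, kv_lora_rank] to match the 'hdc' subscripts:

import torch

s, b = 4, 2      # sequence length, micro-batch size (hypothetical)
h = 8            # attention heads on this partition
c = 512          # kv_lora_rank, i.e. the cached latent width
d = 128          # v_head_dim

core_attn_out = torch.randn(s, b, h, c)      # latent-space attention output
up_v_weight = torch.randn(h, d, c)           # assumed per-head value up-projection

out = torch.einsum("sbhc,hdc->sbhd", core_attn_out, up_v_weight)
assert out.shape == (s, b, h, d)

# Outside the absorbed path, the output is flattened before the output projection,
# mirroring rearrange('s b h d -> s b (h d)') in the non-FlashMLA branch.
flat = out.reshape(s, b, h * d)
assert flat.shape == (s, b, h * d)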
68 changes: 54 additions & 14 deletions tests/unit_tests/inference/engines/test_dynamic_engine.py
@@ -49,7 +49,7 @@
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.cuda_graphs import CudaGraphManager, _CudagraphGlobalRecord
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_config import MLATransformerConfig, TransformerConfig
from megatron.core.utils import is_fa_min_version, is_te_min_version
from tests.unit_tests.test_utilities import Utils

@@ -140,17 +140,24 @@ class DynamicEngineTestConfig:
static_kv_memory_pointers: bool = True
track_generated_token_events: bool = False

fp8: bool = False
use_mla: bool = False
cache_mla_latent: bool = False

def __post_init__(self):
assert self.max_sequence_length is None
assert (
self.num_tokens_to_generate is None or self.num_tokens_total is None
) and self.num_tokens_to_generate != self.num_tokens_total

if self.use_mla and self.cache_mla_latent:
# Fix paged KV cache block size requirement (needs to be divisible by 256).
# Note, this doesn't work with FlashMLA, which requires a block size of exactly 64.
self.context_block_size_tokens = 256

# Compute max_sequence_length.
assert self.max_sequence_length is None
assert self.num_tokens_to_generate is None or self.num_tokens_total is None
if self.num_tokens_to_generate is not None:
self.max_sequence_length = self.max_prompt_length + self.num_tokens_to_generate
else:
assert self.num_tokens_total is not None
self.max_sequence_length = self.num_tokens_total

# Default paused buffer size.
@@ -291,9 +298,28 @@ def _build_test_env(cls, test_config):
# Requests.
requests = cls._build_requests(test_config)

# Values required for proper cache_mla_latent functioning
qk_head_dim = 128
qk_pos_emb_head_dim = 64
transformer_config_cls = (
partial(
MLATransformerConfig,
cache_mla_latents=test_config.cache_mla_latent,
qk_head_dim=qk_head_dim,
qk_pos_emb_head_dim=qk_pos_emb_head_dim,
# For cache_mla_latent, the following needs to hold:
# v_head_dim == qk_head_dim + qk_pos_emb_head_dim
v_head_dim=(
(qk_head_dim + qk_pos_emb_head_dim) if test_config.cache_mla_latent else 128
),
)
if test_config.use_mla
else TransformerConfig
)

if test_config.model_provider == "gpt":
# Transformer config.
transformer_config = TransformerConfig(
transformer_config = transformer_config_cls(
params_dtype=torch.bfloat16,
num_layers=4,
hidden_size=128 if test_config.fp8 else 32,
@@ -331,11 +357,15 @@
# inference optimized currently only supports RMS Norm
)
if test_config.fp8 or test_config.transformer_impl == "transformer_engine":
layer_spec = get_gpt_layer_with_transformer_engine_spec()
layer_spec = get_gpt_layer_with_transformer_engine_spec(
multi_latent_attention=test_config.use_mla
)
elif test_config.transformer_impl == "local":
layer_spec = get_gpt_layer_local_spec()
layer_spec = get_gpt_layer_local_spec(multi_latent_attention=test_config.use_mla)
elif test_config.transformer_impl == "inference_optimized":
layer_spec = get_gpt_layer_with_inference_spec()
layer_spec = get_gpt_layer_with_inference_spec(
multi_latent_attention=test_config.use_mla
)

# GPT model.
model = GPTModel(
@@ -350,7 +380,7 @@
elif test_config.model_provider == "mamba":
pp_size = test_config.pipeline_model_parallel_size
# Transformer config.
transformer_config = TransformerConfig(
transformer_config = transformer_config_cls(
params_dtype=torch.bfloat16,
num_layers=(
3 if pp_size == 1 else 6
@@ -1037,6 +1067,8 @@ def test_log_probs_token_correspondence(self):
@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("model_provider", ["gpt", "mamba"])
@pytest.mark.parametrize("transformer_impl", ["local", "inference_optimized"])
@pytest.mark.parametrize("use_mla", [False, True])
@pytest.mark.parametrize("cache_mla_latent", [False, True])
@torch.inference_mode()
def test_parallel_inference(
self,
@@ -1047,6 +1079,8 @@ def test_parallel_inference(
sequence_parallel,
materialize_only_last_token_logits,
transformer_impl,
use_mla,
cache_mla_latent,
):
skip_if_mamba_sequence_packing_not_available(model_provider)

@@ -1074,10 +1108,14 @@
"when tp_size > 1."
)
)
if model_provider == "mamba":
pytest.skip(
reason="Mamba model is not supported with the inference optimized transformer."
)
if use_mla and transformer_impl == "local":
pytest.skip(reason="MLA does not work with the local implementation.")
if cache_mla_latent and not use_mla:
pytest.skip(reason="MLA latent caching requires MLA use.")
if use_mla and not cache_mla_latent:
pytest.skip(
reason="MLA use for dynamic inference currently requires `cache_mla_latents=True`."
)

env = self._run_test(
model_provider=model_provider,
Expand All @@ -1087,6 +1125,8 @@ def test_parallel_inference(
sequence_parallel=sequence_parallel,
materialize_only_last_token_logits=materialize_only_last_token_logits,
transformer_impl=transformer_impl,
use_mla=use_mla,
cache_mla_latent=cache_mla_latent,
)

@pytest.mark.internal
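For reference, a hedged sketch of which parameter combinations the new skips leave active in test_parallel_inference (this mirrors the skip conditions added above and is not part of the test file):

def is_skipped(use_mla: bool, cache_mla_latent: bool, transformer_impl: str) -> bool:
    # Simplified reproduction of the MLA-related skips added in the diff.
    if use_mla and transformer_impl == "local":
        return True   # MLA does not work with the local implementation
    if cache_mla_latent and not use_mla:
        return True   # MLA latent caching requires MLA
    if use_mla and not cache_mla_latent:
        return True   # dynamic inference with MLA currently requires cache_mla_latents=True
    return False

# Only (use_mla=True, cache_mla_latent=True) exercises the new latent-caching path;
# (False, False) preserves the pre-existing coverage.
for use_mla in (False, True):
    for cache_mla_latent in (False, True):
        status = "skipped" if is_skipped(use_mla, cache_mla_latent, "inference_optimized") else "runs"
        print(f"use_mla={use_mla}, cache_mla_latent={cache_mla_latent}: {status}")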