1 change: 1 addition & 0 deletions src/MaxText/gradient_accumulation.py
@@ -92,6 +92,7 @@ def convert_to_bf16(param):
def accumulate_gradient(acc_grad_and_loss, data):
ga_params = acc_grad_and_loss["ga_params"]
(_, aux), cur_batch_gradient = grad_func(model, config, data, dropout_rng, ga_params, *extra_dpo_args, is_train=True)
cur_batch_gradient = jax.tree.map(_maybe_shard_with_name, cur_batch_gradient, grad_shardings)
acc_grad_and_loss["loss"] += aux["total_loss"]
acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"]
acc_grad_and_loss["mtp_loss"] += aux["mtp_loss"]
150 changes: 113 additions & 37 deletions src/MaxText/layers/attention_mla.py
@@ -14,15 +14,17 @@

"""MLA Attention Layer."""

from functools import partial
import dataclasses
import math
from typing import Any, Optional, Tuple

from jax.ad_checkpoint import checkpoint_name
from jax.sharding import Mesh, NamedSharding
import jax.numpy as jnp

from flax import linen as nn
from flax import nnx
from flax import linen as nn

from MaxText.common_types import (
Array,
@@ -46,9 +48,11 @@
LENGTH_NO_EXP,
MODEL_MODE_PREFILL,
MODEL_MODE_TRAIN,
MODEL_MODE_AUTOREGRESSIVE,
PREFILL_KV_BATCH,
PREFILL_LENGTH,
AttentionType,
ShardMode,
)
from MaxText.inference import kvcache
from MaxText.inference import page_manager
@@ -60,6 +64,19 @@
from MaxText.layers.linears import DenseGeneral
from MaxText.layers.normalizations import RMSNorm
from MaxText.layers.quantizations import AqtQuantization as Quant
from MaxText.sharding import maybe_shard_with_logical


@dataclasses.dataclass(frozen=True)
class AttentionMLALogicalAxisNames:
"""Holds the set of axis names for Attention MLA tensors."""
query: Tuple[str, ...]
key: Tuple[str, ...]
value: Tuple[str, ...]
inputs: Tuple[str, ...]
out: Tuple[str, ...]
wqa: Tuple[str, ...]
wkva: Tuple[str, ...]


def mla_as_linen(
@@ -361,9 +378,82 @@ def __init__(
rngs=rngs,
)

self.setup_sharding(model_mode)

self._maybe_shard_with_logical = partial(
maybe_shard_with_logical,
mesh=mesh,
shard_mode=config.shard_mode,
)
# Module attribute names must match names previously passed to Linen for checkpointing
self.MlaKVCache_0 = self.init_mla_kv_caches(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None

def _create_sharding(self, axis_names):
"""Creates NamedSharding if shard_mode is EXPLICIT, otherwise None."""
if self.config.shard_mode == ShardMode.EXPLICIT:
return NamedSharding(self.mesh, nn.logical_to_mesh_axes(axis_names))
return None

def _get_mla_axis_names(self, model_mode: str):
"""Determines the correct set of axis names based on the model mode."""
if model_mode == MODEL_MODE_PREFILL:
return AttentionMLALogicalAxisNames(
query=self.prefill_query_axis_names,
key=self.prefill_key_axis_names,
value=self.prefill_value_axis_names,
inputs=self.prefill_input_axis_names,
out=self.prefill_out_axis_names,
wqa=(PREFILL_KV_BATCH, PREFILL_LENGTH, "q_lora_up_proj"),
wkva=(PREFILL_KV_BATCH, PREFILL_LENGTH, "kv_lora_up_proj")
)
elif model_mode == MODEL_MODE_AUTOREGRESSIVE:
return AttentionMLALogicalAxisNames(
query=(DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV),
key=(DECODE_BATCH, DECODE_LENGTH, KV_HEAD, D_KV),
value=(DECODE_BATCH, DECODE_LENGTH, KV_HEAD, D_KV),
inputs=self.decode_input_axis_names,
out=self.decode_out_axis_names,
wqa=(DECODE_BATCH, DECODE_LENGTH, "q_lora_up_proj"),
wkva=(DECODE_BATCH, DECODE_LENGTH, "kv_lora_up_proj")
)
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
return AttentionMLALogicalAxisNames(
query=self.ep_query_axis_names,
key=self.ep_key_axis_names,
value=self.ep_value_axis_names,
inputs=self.ep_input_axis_names,
out=self.ep_out_axis_names,
wqa=(KV_BATCH_NO_EXP, LENGTH, "q_lora_up_proj"),
wkva=(KV_BATCH_NO_EXP, LENGTH, "kv_lora_up_proj")
)
else:
return AttentionMLALogicalAxisNames(
query=self.query_axis_names,
key=self.key_axis_names,
value=self.value_axis_names,
inputs=self.input_axis_names,
out=self.out_axis_names,
wqa=(KV_BATCH, LENGTH_NO_EXP, "q_lora_up_proj"),
wkva=(KV_BATCH, LENGTH_NO_EXP, "kv_lora_up_proj")
)

def setup_sharding(self, model_mode: str):
"""Sets up the sharding attributes based on the model mode."""
axis_names = self._get_mla_axis_names(model_mode)

self.query_axis_names = axis_names.query
self.key_axis_names = axis_names.key
self.value_axis_names = axis_names.value
self.input_axis_names = axis_names.inputs
self.out_axis_names = axis_names.out

# Create sharding objects using the helper method
self.inputs_sharding = self._create_sharding(axis_names.inputs)
self.query_sharding = self._create_sharding(axis_names.query)
self.wqa_out_sharding = self._create_sharding(axis_names.wqa)
self.wkvb_out_sharding = self._create_sharding(axis_names.key) # Uses key axes
self.wkva_out_sharding = self._create_sharding(axis_names.wkva)
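For context, a standalone sketch of the logical-name to NamedSharding translation that _create_sharding performs under ShardMode.EXPLICIT; the mesh axes and logical rules below are illustrative placeholders, not MaxText's real configuration:

import jax
import numpy as np
from flax import linen as nn
from jax.sharding import Mesh, NamedSharding

# Illustrative two-axis mesh and logical-to-mesh rules (placeholders).
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))
rules = (("activation_kv_batch", "data"), ("activation_kv_heads", "model"))

with nn.logical_axis_rules(rules):
  # Logical axis names resolve to a PartitionSpec over mesh axes; unmapped names become None.
  spec = nn.logical_to_mesh_axes(("activation_kv_batch", "activation_length", "activation_kv_heads", "activation_kv"))
  sharding = NamedSharding(mesh, spec)  # what _create_sharding returns in explicit shard mode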

def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None:
"""Initializes the MLA-specific projections."""
# Assert required configuration parameters for MLA attention.
@@ -389,6 +479,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
weight_dtype=self.weight_dtype,
quant=self.quant,
matmul_precision=self.config.matmul_precision,
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)
else:
@@ -403,6 +494,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
weight_dtype=self.weight_dtype,
quant=self.quant,
matmul_precision=self.config.matmul_precision,
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)
self.q_norm = RMSNorm(
@@ -423,6 +515,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
weight_dtype=self.weight_dtype,
quant=self.quant,
matmul_precision=self.config.matmul_precision,
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)

@@ -437,6 +530,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
weight_dtype=self.weight_dtype,
quant=self.quant,
matmul_precision=self.config.matmul_precision,
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)
self.kv_norm = RMSNorm(
@@ -460,6 +554,7 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
weight_dtype=self.weight_dtype,
quant=self.quant,
matmul_precision=self.config.matmul_precision,
shard_mode=self.config.shard_mode,
rngs=self.rngs,
)

@@ -506,47 +601,37 @@ def mla_query_projection(self, inputs_q: Array, inputs_positions: Array, model_m
self.softmax_scale = self.softmax_scale * mscale * mscale

if self.q_lora_rank == 0:
q = self.query(inputs_q)
q = self.query(inputs_q, out_sharding=self.query_sharding)
else:
# LoRA path
low_rank_q = self.wq_a(inputs_q) # [B, L, q_lora_rank]
low_rank_q = self.wq_a(inputs_q, out_sharding=self.wqa_out_sharding) # [B, L, q_lora_rank]
low_rank_q = self.q_norm(low_rank_q) # RMSNorm on low rank
q = self.wq_b(low_rank_q) # [B, L, n_heads * qk_head_dim]
q = self.wq_b(low_rank_q, out_sharding=self.query_sharding) # [B, L, n_heads * qk_head_dim]

# Split into non-positional and rotary parts.
q_nope, q_pe = jnp.split(q, [self.qk_nope_head_dim], axis=-1)
q_pe = self.apply_rotary_embedding(q_pe, inputs_positions=inputs_positions)
# Query projection is scaled by self.softmax_scale to stay consistent with the MaxText implementation.
# DeepSeek v3 instead applied this scale in the attention score computation.
q_pe = self._maybe_shard_with_logical(q_pe, self.query_axis_names)
q_nope = self._maybe_shard_with_logical(q_nope, self.query_axis_names)
query = jnp.concatenate([q_nope, q_pe], axis=-1) * self.softmax_scale

if model_mode == MODEL_MODE_PREFILL:
query = nn.with_logical_constraint(query, self.prefill_query_axis_names)
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
query = nn.with_logical_constraint(query, self.ep_query_axis_names)
else:
query = nn.with_logical_constraint(query, self.query_axis_names)
query = self._maybe_shard_with_logical(query, self.query_axis_names)
return query

def mla_get_key_value(self, low_rank_main, key_rope, model_mode):
"""get (key,value) pair from mla"""
kv_out = self.wkv_b(low_rank_main)

kv_out = self.wkv_b(low_rank_main, out_sharding=self.wkvb_out_sharding)
# Split kv_out into key_nope and value parts.
key_nope, value = jnp.split(kv_out, [self.qk_nope_head_dim], axis=-1)
key_rope = jnp.broadcast_to(key_rope, (key_nope.shape[0], key_nope.shape[1], self.num_query_heads, key_rope.shape[3]))

key_nope = self._maybe_shard_with_logical(key_nope, self.key_axis_names)
key_rope = self._maybe_shard_with_logical(key_rope, self.key_axis_names)
key = jnp.concatenate([key_nope, key_rope], axis=-1)

if model_mode == MODEL_MODE_PREFILL:
key = nn.with_logical_constraint(key, self.prefill_key_axis_names)
value = nn.with_logical_constraint(value, self.prefill_value_axis_names)
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
key = nn.with_logical_constraint(key, self.ep_key_axis_names)
value = nn.with_logical_constraint(value, self.ep_value_axis_names)
else:
key = nn.with_logical_constraint(key, self.key_axis_names)
value = nn.with_logical_constraint(value, self.value_axis_names)
key = self._maybe_shard_with_logical(key, self.key_axis_names)
value = self._maybe_shard_with_logical(value, self.value_axis_names)
return key, value
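A rough sketch of the behaviour the refactor assumes from MaxText.sharding.maybe_shard_with_logical, which replaces the previous per-mode nn.with_logical_constraint branches; the real helper may differ in detail:

import jax
from flax import linen as nn
from jax.sharding import Mesh, NamedSharding
from MaxText.common_types import ShardMode

def maybe_shard_with_logical(x, logical_axis_names, mesh: Mesh, shard_mode: ShardMode):
  # Explicit mode: resolve logical names to mesh axes and attach a concrete NamedSharding.
  if shard_mode == ShardMode.EXPLICIT:
    spec = nn.logical_to_mesh_axes(logical_axis_names)
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))
  # Otherwise keep the usual logical constraint hint for the auto partitioner.
  return nn.with_logical_constraint(x, logical_axis_names)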

def init_mla_kv_caches(self, inputs_kv_shape: Tuple):
@@ -637,7 +722,8 @@ def update_mla_kv_caches(self, low_rank_main, key_rope, decoder_segment_ids, mod

def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segment_ids, model_mode, previous_chunk):
"""MLA key/value projection with integrated rotary embedding."""
low_rank = self.wkv_a(inputs)

low_rank = self.wkv_a(inputs, out_sharding=self.wkva_out_sharding)
low_rank_main, low_rank_rope = jnp.split(low_rank, [self.kv_lora_rank], axis=-1)
low_rank_main = self.kv_norm(low_rank_main)

@@ -690,15 +776,8 @@ def __call__(
A tensor of shape [batch, length, embed_dim] containing the
MLA-attended outputs.
"""
if model_mode == MODEL_MODE_PREFILL:
inputs_q = nn.with_logical_constraint(inputs_q, self.prefill_input_axis_names)
inputs_kv = nn.with_logical_constraint(inputs_kv, self.prefill_input_axis_names)
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
inputs_q = nn.with_logical_constraint(inputs_q, self.ep_input_axis_names)
inputs_kv = nn.with_logical_constraint(inputs_kv, self.ep_input_axis_names)
else:
inputs_q = nn.with_logical_constraint(inputs_q, self.input_axis_names)
inputs_kv = nn.with_logical_constraint(inputs_kv, self.input_axis_names)
inputs_q = self._maybe_shard_with_logical(inputs_q, self.input_axis_names)
inputs_kv = self._maybe_shard_with_logical(inputs_kv, self.input_axis_names)

query = self.mla_query_projection(inputs_q, inputs_positions, model_mode)
key, value, cached_values = self.mla_kv_projection(
Expand All @@ -718,11 +797,8 @@ def __call__(
else:
out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values)

if model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
out = nn.with_logical_constraint(out, self.ep_out_axis_names)
else:
out = nn.with_logical_constraint(out, self.out_axis_names)
out = self._maybe_shard_with_logical(out, self.out_axis_names)

out = self.out_projection(out)
out = self.out_projection(out, out_sharding=out_sharding)
out = checkpoint_name(out, "out_proj")
return out