50 changes: 30 additions & 20 deletions megatron/core/models/gpt/gpt_layer_specs.py
@@ -82,6 +82,29 @@
HAVE_APEX = False


_DENSE_MLP_SHARDED_STATE_DICT_KEYS_MAP = {
"mlp.0.weight": "mlp.linear_fc1.layer_norm_weight",
"mlp.0.bias": "mlp.linear_fc1.layer_norm_bias",
"mlp.1.basic_ops.0.weight": "mlp.linear_fc1.weight",
"mlp.1.basic_ops.1.bias": "mlp.linear_fc1.bias",
"mlp.3.basic_ops.0.weight": "mlp.linear_fc2.weight",
"mlp.3.basic_ops.1.bias": "mlp.linear_fc2.bias",
}


def _dense_mlp_sharded_state_dict_keys_map(num_experts: Optional[int]) -> dict[str, str]:
if num_experts is not None:
return {}
return dict(_DENSE_MLP_SHARDED_STATE_DICT_KEYS_MAP)


def _local_layer_norm_sharded_state_dict_keys_map(num_experts: Optional[int]) -> dict[str, str]:
keys_map = {"input_layernorm.": "self_attention.linear_qkv.layer_norm_"}
if num_experts is None:
keys_map["pre_mlp_layernorm."] = "mlp.linear_fc1.layer_norm_"
return keys_map


def get_gpt_layer_with_inference_spec(
qk_layernorm: Optional[bool] = False,
multi_latent_attention: Optional[bool] = False,
@@ -166,14 +189,7 @@ def get_gpt_layer_with_inference_spec(
pre_mlp_layernorm=IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
"mlp.0.weight": "mlp.linear_fc1.layer_norm_weight",
"mlp.0.bias": "mlp.linear_fc1.layer_norm_bias",
"mlp.1.basic_ops.0.weight": "mlp.linear_fc1.weight",
"mlp.1.basic_ops.1.bias": "mlp.linear_fc1.bias",
"mlp.3.basic_ops.0.weight": "mlp.linear_fc2.weight",
"mlp.3.basic_ops.1.bias": "mlp.linear_fc2.bias",
},
sharded_state_dict_keys_map=dict(_DENSE_MLP_SHARDED_STATE_DICT_KEYS_MAP),
),
)

@@ -308,14 +324,9 @@ def get_gpt_layer_with_transformer_engine_spec(
mlp=mlp,
mlp_bda=get_bias_dropout_add,
post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp,
sharded_state_dict_keys_map={
"mlp.0.weight": "mlp.linear_fc1.layer_norm_weight",
"mlp.0.bias": "mlp.linear_fc1.layer_norm_bias",
"mlp.1.basic_ops.0.weight": "mlp.linear_fc1.weight",
"mlp.1.basic_ops.1.bias": "mlp.linear_fc1.bias",
"mlp.3.basic_ops.0.weight": "mlp.linear_fc2.weight",
"mlp.3.basic_ops.1.bias": "mlp.linear_fc2.bias",
},
sharded_state_dict_keys_map=_dense_mlp_sharded_state_dict_keys_map(
num_experts
),
),
)

@@ -455,10 +466,9 @@ def get_gpt_layer_local_spec(
pre_mlp_layernorm=layer_norm,
mlp=mlp,
mlp_bda=bias_dropout_add,
sharded_state_dict_keys_map={
"input_layernorm.": "self_attention.linear_qkv.layer_norm_",
"pre_mlp_layernorm.": "mlp.linear_fc1.layer_norm_",
},
sharded_state_dict_keys_map=_local_layer_norm_sharded_state_dict_keys_map(
num_experts
),
),
)

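A minimal standalone sketch of how the new helpers pick the sharded-state-dict key maps, so the dense-vs-MoE behavior is easy to see in isolation. The functions below are local copies of the helpers added in this file; only the assertions at the end are illustrative.

```python
from typing import Optional

# Standalone copies of the helpers added above (the real ones live in
# gpt_layer_specs.py), used here only to demonstrate the selection rule.
_DENSE_MLP_MAP = {
    "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight",
    "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias",
    "mlp.1.basic_ops.0.weight": "mlp.linear_fc1.weight",
    "mlp.1.basic_ops.1.bias": "mlp.linear_fc1.bias",
    "mlp.3.basic_ops.0.weight": "mlp.linear_fc2.weight",
    "mlp.3.basic_ops.1.bias": "mlp.linear_fc2.bias",
}

def dense_mlp_keys_map(num_experts: Optional[int]) -> dict:
    # MoE layers skip the dense-MLP key remapping entirely.
    return {} if num_experts is not None else dict(_DENSE_MLP_MAP)

def local_layer_norm_keys_map(num_experts: Optional[int]) -> dict:
    keys_map = {"input_layernorm.": "self_attention.linear_qkv.layer_norm_"}
    if num_experts is None:
        # Only dense layers fold the pre-MLP layernorm into linear_fc1.
        keys_map["pre_mlp_layernorm."] = "mlp.linear_fc1.layer_norm_"
    return keys_map

assert dense_mlp_keys_map(8) == {}
assert "pre_mlp_layernorm." not in local_layer_norm_keys_map(8)
assert "pre_mlp_layernorm." in local_layer_norm_keys_map(None)
```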
13 changes: 12 additions & 1 deletion megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py
@@ -375,10 +375,21 @@ def _update_fp32_params_by_new_state(self):

def update_fp32_param_by_new_param(self):
"""
Update the fp32 parameters by the new parameters.
Update optimizer-owned parameter copies from the live model parameters.

``DistributedOptimizer.reload_model_params`` calls this after checkpoint
loading. For offloaded fp32 model parameters there is no separate
``param_to_fp32_param`` entry: the CPU copy tracked by
``gpu_params_map_cpu_copy`` is the optimizer-owned parameter. Refresh it
as well, otherwise the first optimizer step can copy constructor-time
values back over checkpoint-loaded weights.
"""
for param, fp32_param in self.param_to_fp32_param.items():
fp32_param.data.copy_(param)
for param, cpu_copy in self.gpu_params_map_cpu_copy.items():
if param in self.param_to_fp32_param:
continue
cpu_copy.data.copy_(param)

def _register_load_state_dict_hooks(self):
def pre_load_state_dict_hook(self, state_dict):
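A toy sketch of the refresh pattern the docstring describes: parameters with an fp32 master copy are refreshed through `param_to_fp32_param`, while offloaded parameters (whose CPU copy is the master) are refreshed through `gpu_params_map_cpu_copy`. The dicts and tensors below are stand-ins for the optimizer's real state, not its API.

```python
import torch

param_a = torch.zeros(4)          # non-offloaded: has a separate fp32 master copy
param_b = torch.zeros(4)          # offloaded: its CPU copy *is* the master copy
param_to_fp32_param = {param_a: torch.zeros(4, dtype=torch.float32)}
gpu_params_map_cpu_copy = {param_a: torch.zeros(4), param_b: torch.zeros(4)}

# Simulate a checkpoint load writing new values into the live model params.
param_a.fill_(1.0)
param_b.fill_(2.0)

# Same two loops as update_fp32_param_by_new_param: refresh fp32 masters,
# then refresh CPU copies for params that have no separate fp32 master.
for p, fp32 in param_to_fp32_param.items():
    fp32.data.copy_(p)
for p, cpu_copy in gpu_params_map_cpu_copy.items():
    if p in param_to_fp32_param:
        continue
    cpu_copy.data.copy_(p)

# Without the second loop, param_b's master copy would still hold the
# constructor-time zeros and the first step would overwrite the loaded weights.
assert torch.equal(gpu_params_map_cpu_copy[param_b], param_b)
```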
9 changes: 9 additions & 0 deletions megatron/core/transformer/moe/experts.py
@@ -34,6 +34,7 @@
set_tensor_model_parallel_attributes,
)
from megatron.core.tensor_parallel.utils import divide
from miles_megatron_plugins.true_on_policy.contracts import resolve_true_on_policy_runtime_policy
from megatron.core.transformer.mlp import MLP, MLPSubmodules, apply_swiglu_sharded_factory
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe import grouped_gemm_util as gg
@@ -88,6 +89,8 @@ def __init__(
), "MoE latent projection not supported in GroupedMLP yet."

self.expert_parallel = config.expert_model_parallel_size > 1
true_on_policy = resolve_true_on_policy_runtime_policy(config)
self.cast_grouped_gemm_input_to_weight_dtype = true_on_policy.use_sglang_backend
if self.config.gated_linear_unit:
if self.config.activation_func not in (F.silu, F.gelu):
raise ValueError("Activation function must be silu or gelu when using GroupedMLP.")
@@ -255,6 +258,12 @@ def forward(
# Probs already applied, so reset to 1.
permuted_probs = torch.ones_like(permuted_probs)

if (
self.cast_grouped_gemm_input_to_weight_dtype
and permuted_local_hidden_states.dtype != self.weight1.dtype
):
permuted_local_hidden_states = permuted_local_hidden_states.to(self.weight1.dtype)

if permuted_local_hidden_states.nelement() != 0:
# Reshape the weights for the grouped GEMMs.
w1 = self.weight1.view(self.num_local_experts, self.config.hidden_size, -1)
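A minimal sketch of the dtype guard added to GroupedMLP.forward: the dispatched activations are cast to the expert weight dtype before the grouped GEMM, and only when the dtypes actually differ. The flag below stands in for `true_on_policy.use_sglang_backend`; the tensors are toy stand-ins.

```python
import torch

cast_input_to_weight_dtype = True                          # assumed runtime flag
weight1 = torch.randn(16, 32, dtype=torch.bfloat16)        # toy expert weight
hidden = torch.randn(4, 16, dtype=torch.float32)           # dispatched tokens

if cast_input_to_weight_dtype and hidden.dtype != weight1.dtype:
    hidden = hidden.to(weight1.dtype)

out = hidden @ weight1                                      # dtypes now match
assert out.dtype == torch.bfloat16
```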
52 changes: 46 additions & 6 deletions megatron/core/transformer/moe/moe_layer.py
@@ -324,6 +324,10 @@ def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tenso
dispatched_input, tokens_per_expert, permuted_probs = (
self.token_dispatcher.dispatch_postprocess(hidden_states, probs)
)
if hasattr(self.experts, "set_sglang_alltoall_source_counts"):
self.experts.set_sglang_alltoall_source_counts(
getattr(self.token_dispatcher, "num_global_tokens_per_local_expert", None)
)
expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert, permuted_probs)
assert mlp_bias is None, f"mlp_bias is not supported for {type(self.token_dispatcher)}"
output = self.token_dispatcher.combine_preprocess(expert_output)
@@ -391,21 +395,38 @@ def forward(

# MoE forward: route -> dispatch -> compute -> combine
def custom_forward(hidden_states, intermediate_tensors, padding_mask=None):
sglang_exact_output = None
try:
if "route" in self.fwd_execution_map:
shared_expert_output = self.shared_experts_compute(hidden_states)

try:
from miles_megatron_plugins.true_on_policy.moe_layer_ext import (
try_sglang_ep_forward,
)

ep_result = try_sglang_ep_forward(
self,
hidden_states,
padding_mask,
shared_expert_output,
intermediate_tensors,
)
except ImportError:
ep_result = None

if ep_result is not None:
if ep_result.is_final:
return ep_result.output
sglang_exact_output = ep_result.exact_output

probs, routing_map = self.route(hidden_states, padding_mask)
hidden_states, probs = self.preprocess(hidden_states, probs, routing_map)

if intermediate_tensors is not None:
return hidden_states, probs, shared_expert_output

except MoECudaGraphPartialCaptureSignal as e:
# This signal is raised from the maybe_skip_or_early_return_by_cudagraph decorator.
# It means we should early-return from the MoE layer forward pass.
# This happens when we are partially capturing the CUDA graph of the MoE layer,
# like cuda_graph_scope=["moe_router", "moe_preprocess"].
# We need to return the intermediate tensors as CUDA graph outputs.
return e.get_early_return_outputs(hidden_states, shared_expert_output)

if "expert_compute" in self.fwd_execution_map:
@@ -431,9 +452,28 @@ def custom_forward(hidden_states, intermediate_tensors, padding_mask=None):
if intermediate_tensors is not None:
return output

if sglang_exact_output is not None:
output = sglang_exact_output + (output - output.detach())

return output, mlp_bias

if self.moe_layer_recompute:
try:
from miles_megatron_plugins.true_on_policy.moe_layer_ext import (
forward_compacted_true_on_policy_padding,
should_compact_true_on_policy_padding,
)

use_compact = should_compact_true_on_policy_padding(
self, padding_mask, intermediate_tensors
)
except ImportError:
use_compact = False

if use_compact:
outputs = forward_compacted_true_on_policy_padding(
hidden_states, padding_mask, custom_forward
)
elif self.moe_layer_recompute:
if self.config.fp8 or self.config.fp4:
outputs = te_checkpoint(
custom_forward,
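A small sketch of the straight-through substitution used in `custom_forward` above (`sglang_exact_output + (output - output.detach())`): the forward value is replaced by the SGLang "exact" output, while gradients still flow through the Megatron-computed output. The tensors below are toy stand-ins.

```python
import torch

weight = torch.tensor([1.0, 2.0], requires_grad=True)
megatron_out = weight * 3.0                        # differentiable training path
sglang_exact = torch.tensor([3.0001, 6.0002])      # value observed at rollout time

# output - output.detach() is exactly zero in the forward pass, so the layer
# emits the rollout values while the backward graph of megatron_out is kept.
combined = sglang_exact + (megatron_out - megatron_out.detach())

assert torch.equal(combined, sglang_exact)         # forward matches the rollout
combined.sum().backward()
assert torch.allclose(weight.grad, torch.tensor([3.0, 3.0]))  # grads unchanged
```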
40 changes: 33 additions & 7 deletions megatron/core/transformer/moe/moe_utils.py
@@ -620,6 +620,7 @@ def topk_routing_with_score_function(
expert_bias: Optional[torch.Tensor] = None,
fused: bool = False,
is_mtp: bool = False,
topk_tiebreak: Optional[str] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute the routing probabilities and map for top-k selection with score function.

@@ -666,6 +667,16 @@
expert_bias=expert_bias,
)

def _stable_topk(scores: torch.Tensor, k: int) -> Tuple[torch.Tensor, torch.Tensor]:
expert_ids = torch.arange(scores.shape[-1], device=scores.device, dtype=torch.float32)
scores_fp32 = scores.float()
tie_step = torch.finfo(torch.float32).eps * scores_fp32.abs().clamp_min(
1.0 / scores.shape[-1]
)
scores_for_topk = scores_fp32 - expert_ids.view(1, -1) * tie_step
_, selected_indices = torch.topk(scores_for_topk, k, dim=-1)
return torch.gather(scores, dim=-1, index=selected_indices), selected_indices

def _compute_topk(
scores: torch.Tensor,
topk: int,
@@ -694,19 +705,29 @@ def _compute_topk(
num_groups=num_groups,
group_topk=group_topk,
)
if topk_tiebreak == "stable_sort":
return _stable_topk(scores, topk)
else:
return torch.topk(scores, k=topk, dim=1)

from miles.utils.replay_base import routing_replay_manager
try:
from miles.utils.replay_base import routing_replay_manager

# MTP layers cannot use rollout routing replay
if not is_mtp:
compute_topk = routing_replay_manager.get_topk_fn(_compute_topk, return_probs=True)
else:
if not is_mtp:
compute_topk = routing_replay_manager.get_topk_fn(_compute_topk, return_probs=True)
else:
compute_topk = _compute_topk
except ImportError:
compute_topk = _compute_topk

if score_function == "softmax":
if use_pre_softmax:
if topk_tiebreak == "stable_sort" and not use_pre_softmax and group_topk is None:
# SGLang's deterministic MoE route selects experts from fp32 softmax scores
# and only casts the selected probabilities back to the activation dtype.
all_scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
probs, top_indices = compute_topk(all_scores, topk, num_groups, group_topk)
probs = (probs / probs.sum(dim=-1, keepdim=True)).type_as(logits)
elif use_pre_softmax:
scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
probs, top_indices = compute_topk(scores, topk, num_groups, group_topk)
else:
@@ -752,6 +773,7 @@ def compute_routing_scores_for_aux_loss(
score_function: str,
fused: bool = False,
padding_mask: Optional[torch.Tensor] = None,
topk_tiebreak: Optional[str] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute routing scores based on the score function.

@@ -784,7 +806,11 @@
else:
raise ValueError(f"Invalid score_function: {score_function}")

_, top_indices = torch.topk(scores, k=topk, dim=1)
if topk_tiebreak == "stable_sort":
_, top_indices = torch.sort(scores, dim=-1, descending=True, stable=True)
top_indices = top_indices[:, :topk]
else:
_, top_indices = torch.topk(scores, k=topk, dim=1)
routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()

# Apply padding mask to scores if provided
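A standalone sketch of the "stable_sort" tiebreak added in `_stable_topk`: each score is nudged down by a tiny amount proportional to its expert index, so exact ties always resolve toward the lower-numbered expert regardless of which kernel computes the top-k. The function below mirrors the diff's logic; the input tensor is illustrative.

```python
import torch

def stable_topk(scores: torch.Tensor, k: int):
    expert_ids = torch.arange(scores.shape[-1], device=scores.device, dtype=torch.float32)
    scores_fp32 = scores.float()
    # Tie-break step scaled to the score magnitude, never below 1/num_experts.
    tie_step = torch.finfo(torch.float32).eps * scores_fp32.abs().clamp_min(
        1.0 / scores.shape[-1]
    )
    biased = scores_fp32 - expert_ids.view(1, -1) * tie_step
    _, idx = torch.topk(biased, k, dim=-1)
    # Return the *original* scores for the selected experts.
    return torch.gather(scores, dim=-1, index=idx), idx

scores = torch.tensor([[0.25, 0.25, 0.25, 0.25]])    # a four-way tie
vals, idx = stable_topk(scores, k=2)
assert idx.tolist() == [[0, 1]]                       # ties broken toward lower index
assert torch.equal(vals, torch.tensor([[0.25, 0.25]]))
```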
45 changes: 37 additions & 8 deletions megatron/core/transformer/moe/router.py
@@ -24,6 +24,15 @@
)
from megatron.core.transformer.transformer_config import TransformerConfig

try:
from miles_megatron_plugins.true_on_policy.contracts import (
resolve_true_on_policy_runtime_policy,
)

_HAS_TRUE_ON_POLICY = True
except ImportError:
_HAS_TRUE_ON_POLICY = False


class Router(ABC, MegatronModule):
"""Base Router class"""
@@ -92,12 +101,20 @@ def gating(self, input: torch.Tensor):
if self.bias is not None and self.bias.device.type == 'cpu':
self.bias.data = self.bias.data.to(device=torch.cuda.current_device())

# Convert to specified datatype for routing computation if enabled
router_dtype = input.dtype
if self.config.moe_router_dtype == 'fp32':
router_dtype = torch.float32
elif self.config.moe_router_dtype == 'fp64':
router_dtype = torch.float64
# When the true-on-policy MoE contract is active, cast the post-norm
# activation back to parameter dtype for the router projection so the
# softmax/top-k numerics match SGLang's routed expert path.
if (
_HAS_TRUE_ON_POLICY
and resolve_true_on_policy_runtime_policy(self.config).deterministic_moe_routing
):
router_dtype = self.config.params_dtype
else:
router_dtype = input.dtype
if self.config.moe_router_dtype == 'fp32':
router_dtype = torch.float32
elif self.config.moe_router_dtype == 'fp64':
router_dtype = torch.float64
logits = router_gating_linear(input, self.weight, self.bias, router_dtype)
return logits

@@ -203,8 +220,12 @@ def __init__(
self.global_tokens_per_expert = None
self.ga_steps = None

from miles.utils.replay_base import routing_replay_manager
routing_replay_manager.register_to_module(self, "routing_replay")
try:
from miles.utils.replay_base import routing_replay_manager

routing_replay_manager.register_to_module(self, "routing_replay")
except ImportError:
pass

def _maintain_float32_expert_bias(self):
"""
@@ -567,6 +588,12 @@ def routing(self, logits: torch.Tensor, padding_mask: Optional[torch.Tensor] = N
# Apply Z-Loss
logits = self.apply_z_loss(logits, padding_mask=padding_mask)

if _HAS_TRUE_ON_POLICY:
_top = resolve_true_on_policy_runtime_policy(self.config)
_topk_tiebreak = _top.moe_topk_tiebreak if _top.deterministic_moe_routing else None
else:
_topk_tiebreak = None

# Calculate probs and routing_map for token dispatching
if self.routing_type == "sinkhorn":
probs, routing_map = self.sinkhorn_load_balancing(logits)
@@ -582,6 +609,7 @@ def routing(self, logits: torch.Tensor, padding_mask: Optional[torch.Tensor] = N
expert_bias=self.expert_bias,
fused=self.config.moe_router_fusion,
is_mtp=self.is_mtp,
topk_tiebreak=_topk_tiebreak,
)

# Apply token dropping to probs and routing_map.
@@ -604,6 +632,7 @@ def routing(self, logits: torch.Tensor, padding_mask: Optional[torch.Tensor] = N
self.score_function,
fused=self.config.moe_router_fusion,
padding_mask=padding_mask,
topk_tiebreak=_topk_tiebreak,
)
probs = self._apply_aux_loss(
probs,
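A sketch of the router dtype selection that `gating` now performs: when the true-on-policy plugin is importable and deterministic MoE routing is requested, the router projection runs in parameter dtype; otherwise the existing `moe_router_dtype` behavior applies. `FakePolicy`/`FakeConfig` and `pick_router_dtype` are hypothetical stand-ins, not the plugin's API.

```python
import torch

class FakePolicy:
    deterministic_moe_routing = True   # stands in for the resolved runtime policy

class FakeConfig:
    params_dtype = torch.bfloat16
    moe_router_dtype = "fp32"

def pick_router_dtype(config, policy, input_dtype, has_true_on_policy=True):
    if has_true_on_policy and policy.deterministic_moe_routing:
        # Match SGLang's routed-expert numerics: project in parameter dtype.
        return config.params_dtype
    router_dtype = input_dtype
    if config.moe_router_dtype == "fp32":
        router_dtype = torch.float32
    elif config.moe_router_dtype == "fp64":
        router_dtype = torch.float64
    return router_dtype

assert pick_router_dtype(FakeConfig, FakePolicy, torch.bfloat16) is torch.bfloat16
assert pick_router_dtype(FakeConfig, FakePolicy, torch.bfloat16,
                         has_true_on_policy=False) is torch.float32
```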