5 changes: 5 additions & 0 deletions miles/backends/megatron_utils/bridge_lora_helpers.py
@@ -93,9 +93,14 @@ def _setup_lora_model_via_bridge(args: Namespace) -> list:
provider.sequence_parallel = args.sequence_parallel
provider.virtual_pipeline_model_parallel_size = args.virtual_pipeline_model_parallel_size
provider.context_parallel_size = args.context_parallel_size
provider.gradient_accumulation_fusion = args.gradient_accumulation_fusion
provider.variable_seq_lengths = True
provider.moe_token_dispatcher_type = "alltoall"
provider.moe_router_load_balancing_type = "none"
if getattr(args, "decoder_first_pipeline_num_layers", None) is not None:
provider.num_layers_in_first_pipeline_stage = args.decoder_first_pipeline_num_layers
if getattr(args, "decoder_last_pipeline_num_layers", None) is not None:
provider.num_layers_in_last_pipeline_stage = args.decoder_last_pipeline_num_layers
provider.finalize()

lora = create_lora_instance(args)
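The guarded pipeline-stage overrides matter when the decoder layer count does not divide evenly across pipeline stages. A minimal sketch of the pattern with stand-in objects (arg and field names from this PR; values illustrative):

```python
from argparse import Namespace
from types import SimpleNamespace

# Stand-ins for the real args/provider; only the fields used here exist.
args = Namespace(decoder_first_pipeline_num_layers=13, decoder_last_pipeline_num_layers=None)
provider = SimpleNamespace(num_layers_in_first_pipeline_stage=None,
                           num_layers_in_last_pipeline_stage=None)

# None means "let Megatron split layers evenly"; only explicit values are
# forwarded, so the provider never sees a spurious override.
if getattr(args, "decoder_first_pipeline_num_layers", None) is not None:
    provider.num_layers_in_first_pipeline_stage = args.decoder_first_pipeline_num_layers
if getattr(args, "decoder_last_pipeline_num_layers", None) is not None:
    provider.num_layers_in_last_pipeline_stage = args.decoder_last_pipeline_num_layers

assert provider.num_layers_in_first_pipeline_stage == 13
assert provider.num_layers_in_last_pipeline_stage is None
```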
89 changes: 80 additions & 9 deletions miles/backends/megatron_utils/lora_utils.py
@@ -73,6 +73,19 @@

_HF_MODULE_NAMES = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"}

# DeepSeek / Kimi MLA (HF names on checkpoint; Megatron uses linear_* from Megatron-Bridge mappings).
_MLA_HF_TO_MEGATRON = {
"q_a_proj": "linear_q_down_proj",
"kv_a_proj_with_mqa": "linear_kv_down_proj",
"q_b_proj": "linear_q_up_proj",
"kv_b_proj": "linear_kv_up_proj",
}
_MEGATRON_MLA_TO_HF = {v: k for k, v in _MLA_HF_TO_MEGATRON.items()}

# SGLang default get_hidden_dim (lora/utils.py) handles fused_qkv_a_proj_with_mqa via q_a / kv_a mapping,
# but not separate q_b_proj / kv_b_proj yet — omit from rollout adapter config to avoid init crashes.
_SGLANG_UNSUPPORTED_HF_TARGETS = frozenset({"q_b_proj", "kv_b_proj"})
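
Since `_MEGATRON_MLA_TO_HF` is built by inverting `_MLA_HF_TO_MEGATRON`, the table must be bijective for adapter names to round-trip during sync. A self-contained check (tables copied here for illustration):

```python
mla_hf_to_megatron = {
    "q_a_proj": "linear_q_down_proj",
    "kv_a_proj_with_mqa": "linear_kv_down_proj",
    "q_b_proj": "linear_q_up_proj",
    "kv_b_proj": "linear_kv_up_proj",
}
megatron_to_hf = {v: k for k, v in mla_hf_to_megatron.items()}

# Inversion is lossless only if no two HF names share a Megatron name.
assert len(megatron_to_hf) == len(mla_hf_to_megatron)
for hf_name, megatron_name in mla_hf_to_megatron.items():
    assert megatron_to_hf[megatron_name] == hf_name
```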


# ---------------------------------------------------------------------------
# Core helpers
@@ -182,14 +195,20 @@ def convert_target_modules_to_megatron(
if hf_modules[0] in ("all", "all-linear", "all_linear"):
return list(all_modules)

# Check if already in Megatron format
if all(m not in _HF_MODULE_NAMES for m in hf_modules if "*" not in m):
return hf_modules
if isinstance(hf_modules, tuple):
hf_modules = list(hf_modules)

# Check if already in Megatron format (standard / canonical / Kimi MLA linear_*).
if all(m not in _HF_MODULE_NAMES and m not in _MLA_HF_TO_MEGATRON for m in hf_modules if "*" not in m):
return list(hf_modules)

# Convert HF names to Megatron names (dedup while preserving order)
megatron_modules: list[str] = []
for module in hf_modules:
megatron_name = hf_to_megatron.get(module, module)
if module in _MLA_HF_TO_MEGATRON:
megatron_name = _MLA_HF_TO_MEGATRON[module]
else:
megatron_name = hf_to_megatron.get(module, module)
if megatron_name not in megatron_modules:
megatron_modules.append(megatron_name)

@@ -205,14 +224,60 @@ def convert_target_modules_to_hf(megatron_modules: list[str]) -> list[str]:
Megatron canonical: linear_q, linear_k, linear_v, linear_proj,
linear_fc1_up, linear_fc1_gate, linear_fc2
HF: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj
Kimi MLA Megatron: linear_q_down_proj -> q_a_proj, linear_kv_down_proj -> kv_a_proj_with_mqa, ...

Wildcard patterns (e.g. ``*.layers.2.mlp.experts.linear_fc1``) have their
last dotted segment extracted and mapped to HF leaf names; the layer-
scoping information is dropped. Rationale: SGLang consumes this list to
decide which adapter-buffer TYPES to allocate (``gate_proj``/``up_proj``
/``down_proj``/...). It does not scope by layer — per-layer LoRA
enablement is enforced on the training side by Megatron-Bridge's
:class:`ModuleMatcher`. If we passed wildcards through verbatim, SGLang
would fail to allocate MoE LoRA buffers when the config only has
wildcards (no plain HF names) for MoE modules.
"""
if isinstance(megatron_modules, tuple):
megatron_modules = list(megatron_modules)
hf_modules: list[str] = []
for module in megatron_modules:
if module in _MEGATRON_TO_HF_MODULES:
hf_modules.extend(_MEGATRON_TO_HF_MODULES[module])
# Wildcards: extract the last path segment so lookups hit the map
# tables. Non-wildcards go through as-is.
lookup_key = module.rsplit(".", 1)[-1] if "*" in module else module
if lookup_key in _MEGATRON_MLA_TO_HF:
hf_modules.append(_MEGATRON_MLA_TO_HF[lookup_key])
elif lookup_key in _MEGATRON_TO_HF_MODULES:
hf_modules.extend(_MEGATRON_TO_HF_MODULES[lookup_key])
else:
# Unknown wildcard tail or already-HF name — pass through as-is.
hf_modules.append(module)
return hf_modules
# Dedup preserving order.
seen: set[str] = set()
unique: list[str] = []
for m in hf_modules:
if m not in seen:
seen.add(m)
unique.append(m)
return unique
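
A stand-alone sketch of the wildcard leaf extraction and ordered dedup above, with a toy mapping table in place of `_MEGATRON_TO_HF_MODULES`:

```python
megatron_to_hf = {"linear_fc1": ["gate_proj", "up_proj"], "linear_fc2": ["down_proj"]}

def leaf(module: str) -> str:
    # Wildcards keep only their last dotted segment; plain names pass through.
    return module.rsplit(".", 1)[-1] if "*" in module else module

mods = ["*.layers.2.mlp.experts.linear_fc1", "linear_fc2", "*.layers.5.mlp.experts.linear_fc1"]
hf: list[str] = []
for m in mods:
    for h in megatron_to_hf.get(leaf(m), [m]):
        if h not in hf:  # dedup while preserving order, as in the function above
            hf.append(h)

assert hf == ["gate_proj", "up_proj", "down_proj"]  # layer scoping is dropped
```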


def target_modules_hf_for_sglang_rollout(args: Namespace) -> list[str]:
"""HF-style target_modules for SGLang LoRA init / sync (after Megatron→HF name map).

Drops MLA tensors that SGLang's default ``get_hidden_dim`` does not implement yet, so colocated
engines can build LoRAMemoryPool without NotImplementedError. Keep Megatron ``target_modules``
aligned with this list if you rely on online adapter sync.
"""
raw = list(args.target_modules) if args.target_modules else []
hf = convert_target_modules_to_hf(raw)
out = [m for m in hf if m not in _SGLANG_UNSUPPORTED_HF_TARGETS]
dropped = set(hf) - set(out)
if dropped:
logger.warning(
"target_modules_hf_for_sglang_rollout: omitting %s for SGLang (unsupported by default "
"get_hidden_dim); Megatron should not train LoRA on these if rollout sync is required.",
sorted(dropped),
)
return out
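
For example, with an MLA target list the rollout config keeps the `*_a_*` projections and drops the two `*_b_*` ones (a sketch; the real list comes from `args.target_modules`):

```python
unsupported = frozenset({"q_b_proj", "kv_b_proj"})
hf = ["q_a_proj", "kv_a_proj_with_mqa", "q_b_proj", "kv_b_proj", "o_proj"]

rollout_targets = [m for m in hf if m not in unsupported]
assert rollout_targets == ["q_a_proj", "kv_a_proj_with_mqa", "o_proj"]
assert sorted(set(hf) - set(rollout_targets)) == ["kv_b_proj", "q_b_proj"]  # what gets logged
```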


# ---------------------------------------------------------------------------
@@ -252,7 +317,7 @@ def create_lora_instance(args: Namespace):
target_modules = convert_target_modules_to_megatron(args.target_modules, lora_type=lora_cls)
exclude_modules = parse_exclude_modules(args, lora_type=lora_cls)

lora = lora_cls(
lora_kwargs = dict(
target_modules=target_modules,
exclude_modules=exclude_modules,
dim=args.lora_rank,
@@ -261,6 +326,12 @@
lora_A_init_method=getattr(args, "lora_A_init_method", "xavier"),
lora_B_init_method=getattr(args, "lora_B_init_method", "zero"),
)
# Opt-in to SGLang PR #21466's shared-outer grouped-expert LoRA. Only the
# standard ``LoRA`` class supports the flag today.
if lora_cls is LoRA and getattr(args, "experts_shared_outer_loras", False):
lora_kwargs["experts_shared_outer_loras"] = True

lora = lora_cls(**lora_kwargs)

logger.info(
f"Created {lora_cls.__name__}: rank={args.lora_rank}, alpha={args.lora_alpha}, "
Expand Down Expand Up @@ -491,7 +562,7 @@ def _load_training_state(
def build_lora_sync_config(args: Namespace) -> dict[str, Any]:
"""Build LoRA config dict for syncing weights to SGLang engines."""
target_modules_hf = (
convert_target_modules_to_hf(list(args.target_modules))
target_modules_hf_for_sglang_rollout(args)
if args.target_modules
else ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
)
5 changes: 5 additions & 0 deletions miles/backends/megatron_utils/megatron_to_hf/__init__.py
@@ -1,6 +1,7 @@
from .deepseekv3 import convert_deepseekv3_to_hf
from .glm4 import convert_glm4_to_hf
from .glm4moe import convert_glm4moe_to_hf
from .kimi_vl import convert_kimi_k25_to_hf, convert_kimivl_to_hf
from .llama import convert_llama_to_hf
from .mimo import convert_mimo_to_hf
from .processors import quantize_params, remove_padding
@@ -50,6 +51,10 @@ def _convert_to_hf_core(args, model_name, name, param):
converted_named_tensors = convert_llama_to_hf(args, name, param)
elif "mimo" in model_name:
converted_named_tensors = convert_mimo_to_hf(args, name, param)
elif "kimivl" in model_name:
converted_named_tensors = convert_kimivl_to_hf(args, name, param)
elif "kimi_k25" in model_name:
converted_named_tensors = convert_kimi_k25_to_hf(args, name, param)
else:
raise ValueError(f"Unsupported model: {model_name}")

139 changes: 139 additions & 0 deletions miles/backends/megatron_utils/megatron_to_hf/kimi_vl.py
@@ -0,0 +1,139 @@
import re

import torch


def convert_kimivl_to_hf(args, name, param):
if name.startswith("module.module.vision_model."):
hf_name = "vision_tower." + name[len("module.module.vision_model.") :]
return [(hf_name, param)]

if name.startswith("module.module.multi_modal_projector."):
hf_name = "multi_modal_projector." + name[len("module.module.multi_modal_projector.") :]
return [(hf_name, param)]

return convert_language_model_to_hf(args, name, param)


def convert_kimi_k25_to_hf(args, name, param):
if name.startswith("module.module.vision_tower."):
hf_name = "vision_tower." + name[len("module.module.vision_tower.") :]
return [(hf_name, param)]

if name.startswith("module.module.mm_projector."):
hf_name = "mm_projector." + name[len("module.module.mm_projector.") :]
return [(hf_name, param)]

return convert_language_model_to_hf(args, name, param)


def convert_language_model_to_hf(args, name, param):
if name == "module.module.language_model.embedding.word_embeddings.weight":
return [("language_model.model.embed_tokens.weight", param)]
if name == "module.module.language_model.output_layer.weight":
return [("language_model.lm_head.weight", param)]
if name == "module.module.language_model.decoder.final_layernorm.weight":
return [("language_model.model.norm.weight", param)]

kv_channels = getattr(args, "kv_channels", None)
head_dim = kv_channels if kv_channels is not None else args.hidden_size // args.num_attention_heads
value_num_per_group = args.num_attention_heads // args.num_query_groups

decoder_layers_pattern = r"module\.module\.(?:language_model\.)?decoder\.layers\.(\d+)\.(.+)"
match = re.match(decoder_layers_pattern, name)
if match:
layer_idx, rest = match.groups()

# experts
expert_pattern = r"mlp\.experts\.(.+)\.weight(\d+)"
match = re.match(expert_pattern, rest)
if match:
rest, expert_idx = match.groups()
if rest == "linear_fc1":
gate_weight, up_weight = param.chunk(2, dim=0)
outputs = [
(
f"language_model.model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight",
gate_weight,
),
(f"language_model.model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight", up_weight),
]
return outputs
elif rest == "linear_fc2":
outputs = [
(f"language_model.model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight", param),
]
return outputs
else:
raise ValueError(f"Unknown expert parameter name: {name}")

# shared expert
shared_expert_pattern = r"mlp\.shared_experts\.(.+)"
match = re.match(shared_expert_pattern, rest)
if match:
rest = match.group(1)
if rest == "linear_fc1.weight":
gate_weight, up_weight = param.chunk(2, dim=0)
return [
(f"language_model.model.layers.{layer_idx}.mlp.shared_experts.gate_proj.weight", gate_weight),
(f"language_model.model.layers.{layer_idx}.mlp.shared_experts.up_proj.weight", up_weight),
]
elif rest == "linear_fc2.weight":
return [(f"language_model.model.layers.{layer_idx}.mlp.shared_experts.down_proj.weight", param)]
else:
raise ValueError(f"Unknown shared expert parameter name: {name}")

if rest == "self_attention.linear_proj.weight":
return [(f"language_model.model.layers.{layer_idx}.self_attn.o_proj.weight", param)]
elif rest == "self_attention.linear_q_proj.weight":
return [(f"language_model.model.layers.{layer_idx}.self_attn.q_proj.weight", param)]
elif rest == "self_attention.linear_q_down_proj.weight":
return [(f"language_model.model.layers.{layer_idx}.self_attn.q_a_proj.weight", param)]
elif rest == "self_attention.linear_q_up_proj.layer_norm_weight":
return [(f"language_model.model.layers.{layer_idx}.self_attn.q_a_layernorm.weight", param)]
elif rest == "self_attention.linear_q_up_proj.weight":
return [(f"language_model.model.layers.{layer_idx}.self_attn.q_b_proj.weight", param)]
elif rest == "self_attention.linear_qkv.bias":
param = param.view(args.num_query_groups, -1)
q_bias, k_bias, v_bias = torch.split(
param,
split_size_or_sections=[value_num_per_group * head_dim, head_dim, head_dim],
dim=1,
)
q_bias = q_bias.contiguous().flatten()
k_bias = k_bias.contiguous().flatten()
v_bias = v_bias.contiguous().flatten()
return [
(f"language_model.model.layers.{layer_idx}.self_attn.q_proj.bias", q_bias),
(f"language_model.model.layers.{layer_idx}.self_attn.k_proj.bias", k_bias),
(f"language_model.model.layers.{layer_idx}.self_attn.v_proj.bias", v_bias),
]
elif rest == "mlp.linear_fc1.weight":
gate_weight, up_weight = param.chunk(2, dim=0)
return [
(f"language_model.model.layers.{layer_idx}.mlp.gate_proj.weight", gate_weight),
(f"language_model.model.layers.{layer_idx}.mlp.up_proj.weight", up_weight),
]
elif rest == "mlp.linear_fc2.weight":
return [(f"language_model.model.layers.{layer_idx}.mlp.down_proj.weight", param)]
elif rest == "self_attention.linear_qkv.layer_norm_weight" or rest == "input_layernorm.weight":
return [(f"language_model.model.layers.{layer_idx}.input_layernorm.weight", param)]
elif rest == "mlp.linear_fc1.layer_norm_weight":
return [(f"language_model.model.layers.{layer_idx}.post_attention_layernorm.weight", param)]
elif rest == "self_attention.linear_kv_down_proj.weight":
return [(f"language_model.model.layers.{layer_idx}.self_attn.kv_a_proj_with_mqa.weight", param)]
elif rest == "self_attention.linear_kv_up_proj.layer_norm_weight":
return [(f"language_model.model.layers.{layer_idx}.self_attn.kv_a_layernorm.weight", param)]
elif rest == "self_attention.linear_kv_up_proj.weight":
return [(f"language_model.model.layers.{layer_idx}.self_attn.kv_b_proj.weight", param)]
elif rest == "pre_mlp_layernorm.weight":
return [(f"language_model.model.layers.{layer_idx}.post_attention_layernorm.weight", param)]
elif rest == "mlp.router.weight":
return [(f"language_model.model.layers.{layer_idx}.mlp.gate.weight", param)]
elif rest == "mlp.router.expert_bias":
return [(f"language_model.model.layers.{layer_idx}.mlp.gate.e_score_correction_bias", param)]

raise ValueError(f"Unknown parameter name: {name}")
11 changes: 9 additions & 2 deletions miles/backends/megatron_utils/sglang.py
@@ -21,9 +21,15 @@
from sglang.srt.utils import MultiprocessingSerializer

try:
from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket # type: ignore[import]
from sglang.srt.weight_sync.tensor_bucket import ( # type: ignore[import]
FlattenedTensorBucket,
FlattenedTensorMetadata,
)
except ImportError:
from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import]
from sglang.srt.model_executor.model_runner import ( # type: ignore[import]
FlattenedTensorBucket,
FlattenedTensorMetadata,
)

__all__ = [
"mxfp8_group_quantize",
@@ -33,4 +39,5 @@
"monkey_patch_torch_reductions",
"MultiprocessingSerializer",
"FlattenedTensorBucket",
"FlattenedTensorMetadata",
]
4 changes: 2 additions & 2 deletions miles/backends/megatron_utils/update_weight/common.py
@@ -201,7 +201,7 @@ def _named_params_and_buffers_global(
if not name.startswith("module.module."):
name = "module." + name

decoder_layers_pattern = r"module\.module\.decoder\.layers\.(\d+)\.(.+)"
decoder_layers_pattern = r"module\.module\.(?:language_model\.)?decoder\.layers\.(\d+)\.(.+)"
match = re.match(decoder_layers_pattern, name)
if not match:
# MTP (Multi-Token Prediction) layers for speculative decoding
@@ -246,7 +246,7 @@ def _named_params_and_buffers_global(
if not name.startswith("module.module."):
name = "module." + name

decoder_layers_pattern = r"module\.module\.decoder\.layers\.(\d+)\.(.+)"
decoder_layers_pattern = r"module\.module\.(?:language_model\.)?decoder\.layers\.(\d+)\.(.+)"
match = re.match(decoder_layers_pattern, name)
if not match:
yield name, buffer
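
The only change to the pattern is the optional non-capturing `language_model.` prefix, so capture-group indices stay stable for existing callers. A quick check:

```python
import re

pattern = r"module\.module\.(?:language_model\.)?decoder\.layers\.(\d+)\.(.+)"
for name in (
    "module.module.decoder.layers.3.mlp.linear_fc2.weight",                 # dense LMs
    "module.module.language_model.decoder.layers.3.mlp.linear_fc2.weight",  # VLMs like Kimi-VL
):
    m = re.match(pattern, name)
    assert m is not None
    assert m.groups() == ("3", "mlp.linear_fc2.weight")
```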
@@ -36,6 +36,7 @@ def get_hf_weight_chunks(self, megatron_local_weights, weight_type: str = "base"
self.model,
cpu=False,
conversion_tasks=conversion_tasks,
merge_adapter_weights=False,
)

# TODO: verify if postprocess_hf_param is needed for LoRA weights
@@ -53,7 +54,11 @@
)

if weight_type == "base":
named_weights = ((n, t) for n, t in named_weights if not is_lora_weight_name(n))
named_weights = (
(n.replace(".base_layer.", "."), t)
for n, t in named_weights
if not is_lora_weight_name(n)
)
elif weight_type == "lora":
named_weights = ((n, t) for n, t in named_weights if is_lora_weight_name(n))
