 
 import os
 from pathlib import Path
+from typing import Iterator
 
 import torch
 
+# Map Megatron LoRA target modules to HF-style module names for vLLM.
+MEGATRON_TO_HF_MODULES = {
+    "linear_qkv": ["q_proj", "k_proj", "v_proj"],
+    "linear_proj": ["o_proj"],
+    "linear_fc1": ["gate_proj", "up_proj"],
+    "linear_fc2": ["down_proj"],
+    # Canonical LoRA mappings
+    "linear_q": ["q_proj"],
+    "linear_k": ["k_proj"],
+    "linear_v": ["v_proj"],
+    "linear_fc1_up": ["up_proj"],
+    "linear_fc1_gate": ["gate_proj"],
+    # MLA mappings
+    "linear_kv_down_proj": ["kv_a_proj_with_mqa"],
+    "linear_kv_up_proj": ["kv_b_proj"],
+    "linear_q_down_proj": ["q_a_proj"],
+    "linear_q_up_proj": ["q_b_proj"],
+    "linear_q_proj": ["q_proj"],
+}
+
+# Parameter-name suffixes of stacked modules whose base weights vLLM stores under .base_layer.
+STACKED_PARAMS = [
| 44 | + ".q_proj.weight", |
| 45 | + ".k_proj.weight", |
| 46 | + ".v_proj.weight", |
| 47 | + ".o_proj.weight", |
| 48 | + ".gate_proj.weight", |
| 49 | + ".up_proj.weight", |
| 50 | + ".down_proj.weight", |
| 51 | + ".mlp.gate.weight", |
| 52 | + ".mlp.gate.e_score_correction_bias", |
| 53 | + ".kv_a_proj_with_mqa.weight", |
| 54 | + ".kv_b_proj.weight", |
| 55 | + ".q_a_proj.weight", |
| 56 | + ".q_b_proj.weight", |
| 57 | +] |
+
 
 def _get_rank_checkpoint_path(base_path: str) -> str:
     """Get rank-specific checkpoint path following Megatron's convention.
@@ -224,10 +262,74 @@ def print_adapter_info(model): |
     print(f"{'=' * 60}\n")
 
 
+def convert_megatron_to_hf_target_modules(megatron_modules: list[str]) -> list[str]:
+    """Convert Megatron LoRA target modules to HF-style module names.
+
+    Args:
+        megatron_modules: List of Megatron-style module names.
+
+    Returns:
+        List of HF-style module names with duplicates removed.
+    """
+    hf_target_modules = []
+    for module in megatron_modules:
+        if module in MEGATRON_TO_HF_MODULES:
+            hf_target_modules.extend(MEGATRON_TO_HF_MODULES[module])
+        else:
+            hf_target_modules.append(module)
+    # Remove duplicates while preserving order.
+    return list(dict.fromkeys(hf_target_modules))
+
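+# Illustrative example (not executed; follows MEGATRON_TO_HF_MODULES above):
+#   convert_megatron_to_hf_target_modules(["linear_qkv", "linear_proj"])
+#   returns ["q_proj", "k_proj", "v_proj", "o_proj"]; duplicates are dropped
+#   and names without a mapping pass through unchanged.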
+
+def build_peft_config_for_vllm(lora_config: dict) -> dict:
+    """Build a peft_config dict compatible with vLLM's PEFTHelper from a Megatron LoRA config.
+
+    Args:
+        lora_config: Megatron LoRA configuration dictionary.
+
+    Returns:
+        A dictionary compatible with vLLM's PEFTHelper.from_dict().
+    """
+    from peft import TaskType
+
+    target_modules = lora_config.get("target_modules", ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"])
+    exclude_modules = lora_config.get("exclude_modules", [])
+    hf_target_modules = convert_megatron_to_hf_target_modules(target_modules)
+    hf_exclude_modules = convert_megatron_to_hf_target_modules(exclude_modules)
+
+    return {
+        "task_type": TaskType.CAUSAL_LM,
+        "r": lora_config.get("rank", 0),
+        "lora_alpha": lora_config.get("alpha", 32),
+        "target_modules": hf_target_modules,
+        "exclude_modules": hf_exclude_modules,
+        "bias": "none",
+        "lora_dropout": lora_config.get("dropout", 0.0),
+    }
+
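+# Illustrative example (hypothetical config values):
+#   build_peft_config_for_vllm({"rank": 16, "alpha": 32, "target_modules": ["linear_qkv"]})
+#   returns {"task_type": TaskType.CAUSAL_LM, "r": 16, "lora_alpha": 32,
+#            "target_modules": ["q_proj", "k_proj", "v_proj"], "exclude_modules": [],
+#            "bias": "none", "lora_dropout": 0.0}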
+
+# vLLM targets all linear modules regardless of the specific LoRA config, so the
+# base weights of stacked modules must be renamed with a .base_layer suffix.
+def add_base_layer_suffix(params: Iterator[tuple[str, torch.Tensor]]) -> Iterator[tuple[str, torch.Tensor]]:
+    """Yield (name, param) pairs with a .base_layer suffix inserted into stacked param names."""
+    for name, param in params:
+        ending_suffix = ""
+        for suffix in STACKED_PARAMS:
+            if name.endswith(suffix):
+                ending_suffix = suffix
+                break
+        if ending_suffix:
+            leaf = ending_suffix.rsplit(".", 1)[-1]
+            name = f"{name[: -len(leaf)]}base_layer.{leaf}"
+        yield name, param
+
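+# Illustrative example (hypothetical parameter name):
+#   add_base_layer_suffix(iter([("model.layers.0.self_attn.q_proj.weight", w)]))
+#   yields ("model.layers.0.self_attn.q_proj.base_layer.weight", w);
+#   names that match no stacked suffix pass through unchanged.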
+
 __all__ = [
     "get_adapter_state_dict",
     "save_adapter_checkpoint",
     "load_adapter_checkpoint",
     "count_adapter_parameters",
     "print_adapter_info",
+    "convert_megatron_to_hf_target_modules",
+    "build_peft_config_for_vllm",
+    "add_base_layer_suffix",
 ]