Skip to content

Commit 91d4732

Browse files
committed
Update docs and make lora merge optional
Signed-off-by: Hollow Man <[email protected]>
1 parent a090cd8 commit 91d4732

File tree

10 files changed

+218
-64
lines changed

10 files changed

+218
-64
lines changed

docs/advance/ppo_lora.rst

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ Make sure you use Megatron-Bridge later than 0.2.0, and we recommended using `th
7171

7272
1. **LoRA Implementation**: Verl Megatron backend uses Megatron-Bridge's native LoRA implementation, which differs from HuggingFace PEFT.
7373

74-
2. **Weight Sync Mechanism**: Currently, Megatron-Bridge syncs weights by merging LoRA adapters into the base model weights before transferring to vLLM rather than loading separate adapters. This is necessary because Megatron-Bridge's LoRA format is not directly integratable with vLLM's LoRA loading mechanism (HF PEFT format), and LoRA bridge is not yet supported.
74+
2. **Weight Sync / Refit Mechanism**: Currently, Megatron-Bridge supports syncing weights either by merging LoRA adapters into the base model weights before transferring them to vLLM (better inference speed, but longer refit time and potential precision loss) or by loading the adapters separately.
7575

7676
**Configuration for Megatron LoRA:**
7777

@@ -83,6 +83,9 @@ Make sure you use Megatron-Bridge later than 0.2.0, and we recommended using `th
8383
# LoRA type: "lora", "vlm_lora", "canonical_lora", or "dora"
8484
type: lora
8585
86+
# Whether to sync weights / refit by merging LoRA adapters into the base model weights before transferring to vLLM (better inference speed, but longer refit time and potential precision loss). If False, the adapters are loaded separately.
87+
merge: False
88+
8689
# LoRA rank (Dimension of the low-rank projection space.). Set to 0 to disable LoRA
8790
rank: 0
8891
@@ -136,12 +139,11 @@ Make sure you use Megatron-Bridge later than 0.2.0, and we recommended using `th
136139
freeze_vision_projection: True
137140
freeze_language_model: True
138141
142+
LoRA training experiment with Qwen3-8B on a single node with 8 * H200 GPUs, comparing the FSDP and Megatron backends (script adapted from examples/grpo_trainer/run_qwen2-7b_math_megatron_lora.sh):
139143

140-
**Current Limitations:**
141-
142-
1. **No HuggingFace PEFT Export**: Currently there is no built-in way to export Megatron LoRA adapters to HuggingFace PEFT format for inference with standard HF/vLLM pipelines, such support is coming soon with Megatron-Bridge `LoRA bridge <https://github.com/NVIDIA-NeMo/Megatron-Bridge/issues/1536>`_.
143-
144-
2. **LoRA Merge Overhead**: As we don't have LoRA bridge for now, each weight sync (refit) requires merging LoRA weights, which adds some overhead compared to direct dynamic adapter loading.
144+
.. image:: https://github.com/user-attachments/assets/0482f423-01a3-4e52-a7ee-8b9cd79b7b1a
145+
.. image:: https://github.com/user-attachments/assets/6ce10400-8164-47d8-90a6-c1bf002fb9e8
146+
.. image:: https://github.com/user-attachments/assets/092d3a43-4eba-425e-a584-8d83c1f02de4
145147

146148

147149
Best Practices and Notes

recipe/r1_ascend/vllm_rollout_spmd.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(
5454
)
5555

5656
from recipe.r1_ascend import engine_core # noqa: F401
57+
5758
# NPU-ADAPTATION END
5859

5960
if config.layered_summon:
@@ -65,11 +66,12 @@ def __init__(
6566
tokenizer = model_config.tokenizer
6667
model_hf_config = model_config.hf_config
6768
trust_remote_code = model_config.trust_remote_code
68-
self.lora_kwargs = (
69-
{"enable_lora": True, "max_loras": 1, "max_lora_rank": model_config.lora_rank}
70-
if model_config.lora_rank > 0
71-
else {}
72-
)
69+
lora_rank = model_config.lora.get("rank", 0)
70+
if model_config.lora.get("merge", False):
71+
lora_rank = 0
72+
if lora_rank <= 0:
73+
lora_rank = model_config.lora_rank
74+
self.lora_kwargs = {"enable_lora": True, "max_loras": 1, "max_lora_rank": lora_rank} if lora_rank > 0 else {}
7375

7476
tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1)
7577
assert tensor_parallel_size <= torch.distributed.get_world_size(), (

verl/trainer/config/_generated_ppo_megatron_trainer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ actor_rollout_ref:
320320
impl_backend: torch
321321
lora:
322322
type: lora
323+
merge: false
323324
rank: 0
324325
alpha: 32
325326
dropout: 0.0

verl/trainer/config/ppo_megatron_trainer.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ actor_rollout_ref:
4646
# LoRA type: "lora", "vlm_lora", "canonical_lora", or "dora"
4747
type: lora
4848

49+
# Whether to sync weights / refit by merging LoRA adapters into the base model weights before transferring to vLLM (better inference speed, but longer refit time and potential precision loss). If False, the adapters are loaded separately.
50+
merge: False
51+
4952
# LoRA rank (Dimension of the low-rank projection space.). Set to 0 to disable LoRA
5053
rank: 0 # typical values: 8, 16, 32, 64
5154

verl/utils/config.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,13 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
199199
)
200200

201201
# check LoRA rank in vLLM
202-
if config.actor_rollout_ref.model.get("lora_rank", 0) > 0 and config.actor_rollout_ref.rollout.name == "vllm":
203-
assert config.actor_rollout_ref.model.lora_rank <= 512, "LoRA rank in vLLM must be less than or equal to 512"
202+
lora_config = config.actor_rollout_ref.model.get("lora", {})
203+
lora_rank = lora_config.get("rank", 0)
204+
if lora_config.get("merge", False):
205+
lora_rank = 0
206+
if lora_rank <= 0:
207+
lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0)
208+
if lora_rank > 0 and config.actor_rollout_ref.rollout.name == "vllm":
209+
assert lora_rank <= 512, "LoRA rank in vLLM must be less than or equal to 512"
204210

205211
print("[validate_config] All configuration checks passed successfully!")

verl/utils/megatron_peft_utils.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,47 @@
1515

1616
import os
1717
from pathlib import Path
18+
from typing import Iterator
1819

1920
import torch
2021

22+
# Map megatron lora target modules to HF-style module names for vLLM
23+
MEGATRON_TO_HF_MODULES = {
24+
"linear_qkv": ["q_proj", "k_proj", "v_proj"],
25+
"linear_proj": ["o_proj"],
26+
"linear_fc1": ["gate_proj", "up_proj"],
27+
"linear_fc2": ["down_proj"],
28+
# Canonical LoRA mappings
29+
"linear_q": ["q_proj"],
30+
"linear_k": ["k_proj"],
31+
"linear_v": ["v_proj"],
32+
"linear_fc1_up": ["up_proj"],
33+
"linear_fc1_gate": ["gate_proj"],
34+
# MLA mappings
35+
"linear_kv_down_proj": ["kv_a_proj_with_mqa"],
36+
"linear_kv_up_proj": ["kv_b_proj"],
37+
"linear_q_down_proj": ["q_a_proj"],
38+
"linear_q_up_proj": ["q_b_proj"],
39+
"linear_q_proj": ["q_proj"],
40+
}
41+
42+
# Modules with stacked parameters that need .base_layer suffix in vLLM
43+
stacked_params = [
44+
".q_proj.weight",
45+
".k_proj.weight",
46+
".v_proj.weight",
47+
".o_proj.weight",
48+
".gate_proj.weight",
49+
".up_proj.weight",
50+
".down_proj.weight",
51+
".mlp.gate.weight",
52+
".mlp.gate.e_score_correction_bias",
53+
".kv_a_proj_with_mqa.weight",
54+
".kv_b_proj.weight",
55+
".q_a_proj.weight",
56+
".q_b_proj.weight",
57+
]
58+
2159

2260
def _get_rank_checkpoint_path(base_path: str) -> str:
2361
"""Get rank-specific checkpoint path following Megatron's convention.
@@ -224,10 +262,74 @@ def print_adapter_info(model):
224262
print(f"{'=' * 60}\n")
225263

226264

265+
def convert_megatron_to_hf_target_modules(megatron_modules: list[str]) -> list[str]:
    """Translate Megatron LoRA target module names into HF-style module names.

    Names present in ``MEGATRON_TO_HF_MODULES`` are expanded to the HF names they
    map to; names that are not in the mapping are passed through unchanged.

    Args:
        megatron_modules: List of Megatron-style module names.

    Returns:
        HF-style module names in first-seen order, with duplicates removed.
    """
    # A dict preserves insertion order, so its keys serve as an ordered,
    # de-duplicated collection of the names seen so far.
    ordered_unique: dict[str, None] = {}
    for megatron_name in megatron_modules:
        # Unknown names fall back to a one-element list holding the name itself.
        for hf_name in MEGATRON_TO_HF_MODULES.get(megatron_name, [megatron_name]):
            ordered_unique.setdefault(hf_name, None)
    return list(ordered_unique)
282+
283+
284+
def build_peft_config_for_vllm(lora_config: dict) -> dict:
    """Build a peft_config dict for vLLM's PEFTHelper from a Megatron LoRA config.

    Target and exclude module names are converted from Megatron-style to HF-style
    names via ``convert_megatron_to_hf_target_modules``. Keys missing from
    ``lora_config`` fall back to the defaults visible below.

    Args:
        lora_config: Megatron LoRA configuration dictionary.

    Returns:
        A dictionary intended to be compatible with vLLM's PEFTHelper.from_dict().
    """
    from peft import TaskType

    # Used when the config does not specify which modules to adapt.
    default_targets = ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"]
    megatron_targets = lora_config.get("target_modules", default_targets)
    megatron_excludes = lora_config.get("exclude_modules", [])

    peft_config = {
        "task_type": TaskType.CAUSAL_LM,
        "r": lora_config.get("rank", 0),
        "lora_alpha": lora_config.get("alpha", 32),
        "target_modules": convert_megatron_to_hf_target_modules(megatron_targets),
        "exclude_modules": convert_megatron_to_hf_target_modules(megatron_excludes),
        "bias": "none",
        "lora_dropout": lora_config.get("dropout", 0.0),
    }
    return peft_config
309+
310+
311+
# vLLM needs to target all linear layers regardless of the specific LoRA config
312+
def add_base_layer_suffix(params: Iterator[tuple[str, torch.Tensor]]) -> Iterator[tuple[str, torch.Tensor]]:
    """Yield ``(name, param)`` pairs, inserting ``base_layer.`` into matching names.

    A name that ends with one of the entries in ``stacked_params`` has
    ``base_layer.`` inserted right before its final dotted component
    (e.g. ``....q_proj.weight`` becomes ``....q_proj.base_layer.weight``).
    All other names are yielded unchanged. Parameters themselves are never modified.
    """
    for param_name, tensor in params:
        # First entry of stacked_params that the name ends with, or "" if none does.
        matched = next((candidate for candidate in stacked_params if param_name.endswith(candidate)), "")
        if matched:
            # Only the last dotted component (e.g. "weight") is moved behind "base_layer.".
            leaf = matched.rsplit(".", 1)[-1]
            param_name = f"{param_name[: -len(leaf)]}base_layer.{leaf}"
        yield param_name, tensor
324+
325+
227326
__all__ = [
228327
"get_adapter_state_dict",
229328
"save_adapter_checkpoint",
230329
"load_adapter_checkpoint",
231330
"count_adapter_parameters",
232331
"print_adapter_info",
332+
"convert_megatron_to_hf_target_modules",
333+
"build_peft_config_for_vllm",
334+
"add_base_layer_suffix",
233335
]

verl/workers/engine/megatron/transformer_impl.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -32,28 +32,20 @@
3232
from verl.utils.debug import log_gpu_memory_usage
3333
from verl.utils.device import get_device_id, get_device_name
3434
from verl.utils.megatron.pipeline_parallel import make_batch_generator
35-
from verl.utils.megatron.tensor_parallel import (
36-
vocab_parallel_entropy,
37-
vocab_parallel_log_probs_from_logits,
38-
)
35+
from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits
36+
from verl.utils.megatron_peft_utils import build_peft_config_for_vllm
3937
from verl.utils.megatron_utils import (
4038
load_megatron_model_to_gpu,
4139
load_megatron_optimizer,
4240
offload_megatron_model_to_cpu,
4341
offload_megatron_optimizer,
4442
register_megatron_training_hooks,
4543
)
46-
from verl.utils.model import (
47-
extract_multi_modal_inputs,
48-
load_mcore_dist_weights,
49-
)
44+
from verl.utils.model import extract_multi_modal_inputs, load_mcore_dist_weights
5045
from verl.workers.config import HFModelConfig, McoreEngineConfig, McoreOptimizerConfig
5146

5247
from ..base import BaseEngine, BaseEngineCtx, EngineRegistry
53-
from ..utils import (
54-
postprocess_batch_func,
55-
prepare_micro_batches,
56-
)
48+
from ..utils import postprocess_batch_func, prepare_micro_batches
5749
from .utils import set_random_seed
5850

5951
logger = logging.getLogger(__file__)
@@ -180,10 +172,7 @@ def _build_tf_config(self):
180172
)
181173

182174
def _build_megatron_module(self):
183-
from verl.utils.megatron_utils import (
184-
McoreModuleWrapperConfig,
185-
make_megatron_module,
186-
)
175+
from verl.utils.megatron_utils import McoreModuleWrapperConfig, make_megatron_module
187176
from verl.utils.model import print_model_size
188177

189178
# TODO: add more cases
@@ -238,10 +227,7 @@ def _build_megatron_module(self):
238227
return module
239228

240229
def _build_optimizer(self):
241-
from verl.utils.megatron.optimizer import (
242-
get_megatron_optimizer,
243-
init_megatron_optim_config,
244-
)
230+
from verl.utils.megatron.optimizer import get_megatron_optimizer, init_megatron_optim_config
245231

246232
optim_config_megatron = init_megatron_optim_config(
247233
self.optimizer_config,
@@ -538,9 +524,16 @@ def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forw
538524
def get_per_tensor_param(self):
539525
if self._is_offload_param:
540526
load_megatron_model_to_gpu(self.module, load_grad=False)
541-
per_tensor_param = self.bridge.export_weights(self.module)
542-
# TODO: support megatron LoRA
543-
return per_tensor_param, None
527+
peft_config = None
528+
if self.vanilla_bridge:
529+
per_tensor_param = self.bridge.export_weights(self.module)
530+
elif not self.model_config.lora.get("merge", False) and self.peft_cls is not None:
531+
# Only export adapter weights
532+
peft_config = build_peft_config_for_vllm(self.model_config.lora)
533+
per_tensor_param = self.bridge.export_adapter_weights(self.module)
534+
else:
535+
per_tensor_param = self.bridge.export_hf_weights(self.module)
536+
return per_tensor_param, peft_config
544537

545538
def forward_step(self, batch_iter, model, postprocess_micro_batch_func):
546539
raise NotImplementedError("forward_step must be implemented in subclass")

verl/workers/megatron_workers.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from verl.utils.flops_counter import FlopsCounter
5353
from verl.utils.fs import copy_to_local
5454
from verl.utils.megatron.router_replay_patch import RouterReplay, RouterReplayAction, apply_router_replay_patch
55+
from verl.utils.megatron_peft_utils import add_base_layer_suffix, build_peft_config_for_vllm
5556
from verl.utils.megatron_utils import (
5657
load_megatron_model_to_gpu,
5758
load_megatron_optimizer,
@@ -329,6 +330,10 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
329330
self._is_offload_grad = False
330331
self._is_offload_optimizer = False
331332

333+
# Initialize LoRA-related attributes (will be updated in _build_rollout if needed)
334+
self.base_sync_done = False
335+
self.peft_merge = False
336+
332337
# normalize config
333338
if self._is_actor:
334339
self.config.actor.ppo_mini_batch_size *= self.config.rollout.n
@@ -488,14 +493,7 @@ def _build_rollout(self, trust_remote_code=False):
488493

489494
# 1. parse rollout and huggingface model config
490495
rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)
491-
492-
# Convert megatron lora config to HFModelConfig
493-
model_config_dict = OmegaConf.to_container(self.config.model)
494-
model_config_dict.pop("lora", None)
495-
496-
model_config: HFModelConfig = omega_conf_to_dataclass(
497-
OmegaConf.create(model_config_dict), dataclass_type=HFModelConfig
498-
)
496+
model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model)
499497

500498
# 2. build rollout device mesh
501499
infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size
@@ -531,6 +529,10 @@ def _build_rollout(self, trust_remote_code=False):
531529
)
532530
log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=logger)
533531

532+
# Initialize base_sync_done for LoRA
533+
self.base_sync_done: bool = "dummy" not in self.config.rollout.load_format
534+
self.peft_merge: bool = model_config.lora.get("merge", False)
535+
534536
# 5. switch to trainer mode
535537
# NOTE: It's critical that hybrid engine in trainer mode initially to load checkpoint.
536538
# For async mode, we can't call run_until_complete here, so we will switch to trainer mode in AgentLoopManager.
@@ -670,9 +672,24 @@ async def rollout_mode(self):
670672
load_megatron_model_to_gpu(self.actor.actor_module, load_grad=False)
671673
log_gpu_memory_usage("After load actor params during rollout_mode", logger=logger)
672674

675+
# Build peft_config for vLLM LoRA support
676+
peft_config = None
677+
do_lora_base_sync = False
678+
if not self.peft_merge and self.peft_cls is not None:
679+
peft_config = build_peft_config_for_vllm(self.config.model.get("lora", {}))
680+
# set sleep level for LoRA adapter weights only sync
681+
# TODO: make this configurable so that users with small
682+
# main memory can trade sync time to avoid OOM
683+
self.rollout.sleep_level = 1
684+
685+
do_lora_base_sync = not self.base_sync_done or self.rollout.sleep_level != 1
686+
673687
if self.bridge is not None:
674688
if self.vanilla_bridge:
675689
per_tensor_param = self.bridge.export_weights(self.actor.actor_module)
690+
elif not self.peft_merge and self.peft_cls is not None:
691+
# Only export adapter weights
692+
per_tensor_param = self.bridge.export_adapter_weights(self.actor.actor_module)
676693
else:
677694
per_tensor_param = self.bridge.export_hf_weights(self.actor.actor_module)
678695
else:
@@ -686,7 +703,19 @@ async def rollout_mode(self):
686703

687704
if self.config.rollout.free_cache_engine:
688705
await self.rollout.resume(tags=["weights"])
689-
await self.rollout.update_weights(per_tensor_param)
706+
if do_lora_base_sync:
707+
# Base layer sync
708+
per_tensor_param_lora_base = self.bridge.export_hf_weights(
709+
self.actor.actor_module, merge_adapter_weights=False
710+
)
711+
await self.rollout.update_weights(
712+
add_base_layer_suffix(per_tensor_param_lora_base), peft_config=peft_config, base_sync_done=False
713+
)
714+
715+
# Mark base sync as done after first successful sync
716+
self.base_sync_done = True
717+
718+
await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=True)
690719
if self._is_offload_param:
691720
offload_megatron_model_to_cpu(self.actor.actor_module)
692721
aggressive_empty_cache(force_sync=True)

0 commit comments

Comments
 (0)