Commit f6706bc

[megatron] feat: LoRA adapter only refit (TensorLoRARequest)
Signed-off-by: Hollow Man <[email protected]>
1 parent 5b0cd28 commit f6706bc

8 files changed: 164 additions & 65 deletions

docs/advance/ppo_lora.rst

Lines changed: 1 addition & 13 deletions

@@ -67,11 +67,7 @@ Make sure you use Megatron-Bridge later than 0.2.0, and we recommended using `th
 - ``actor_rollout_ref.actor.megatron.use_mbridge=True``
 - ``actor_rollout_ref.actor.megatron.vanilla_mbridge=False``
 
-**Key Differences from FSDP LoRA:**
-
-1. **LoRA Implementation**: Verl Megatron backend uses Megatron-Bridge's native LoRA implementation, which differs from HuggingFace PEFT.
-
-2. **Weight Sync Mechanism**: Currently, Megatron-Bridge syncs weights by merging LoRA adapters into the base model weights before transferring to vLLM rather than loading separate adapters. This is necessary because Megatron-Bridge's LoRA format is not directly integratable with vLLM's LoRA loading mechanism (HF PEFT format), and LoRA bridge is not yet supported.
+**Key Differences from FSDP LoRA:** Verl Megatron backend uses Megatron-Bridge's native LoRA implementation, which differs from HuggingFace PEFT.
 
 **Configuration for Megatron LoRA:**
 
@@ -136,14 +132,6 @@ Make sure you use Megatron-Bridge later than 0.2.0, and we recommended using `th
       freeze_vision_projection: True
       freeze_language_model: True
 
-
-**Current Limitations:**
-
-1. **No HuggingFace PEFT Export**: Currently there is no built-in way to export Megatron LoRA adapters to HuggingFace PEFT format for inference with standard HF/vLLM pipelines, such support is coming soon with Megatron-Bridge `LoRA bridge <https://github.com/NVIDIA-NeMo/Megatron-Bridge/issues/1536>`_.
-
-2. **LoRA Merge Overhead**: As we don't have LoRA bridge for now, each weight sync (refit) requires merging LoRA weights, which adds some overhead compared to direct dynamic adapter loading.
-
-
 Best Practices and Notes
 -------------------------

recipe/r1_ascend/vllm_rollout_spmd.py

Lines changed: 5 additions & 5 deletions

@@ -54,6 +54,7 @@ def __init__(
         )
 
         from recipe.r1_ascend import engine_core  # noqa: F401
+
         # NPU-ADAPTATION END
 
         if config.layered_summon:
@@ -65,11 +66,10 @@ def __init__(
         tokenizer = model_config.tokenizer
         model_hf_config = model_config.hf_config
         trust_remote_code = model_config.trust_remote_code
-        self.lora_kwargs = (
-            {"enable_lora": True, "max_loras": 1, "max_lora_rank": model_config.lora_rank}
-            if model_config.lora_rank > 0
-            else {}
-        )
+        lora_rank = model_config.lora.get("rank", 0)
+        if lora_rank <= 0:
+            lora_rank = model_config.lora_rank
+        self.lora_kwargs = {"enable_lora": True, "max_loras": 1, "max_lora_rank": lora_rank} if lora_rank > 0 else {}
 
         tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1)
         assert tensor_parallel_size <= torch.distributed.get_world_size(), (
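As a hedged illustration of the rewritten expression, when the resolved rank is positive (the value 16 below is an assumption, not a verl default) the kwargs handed to vLLM look like this:

# Illustrative only: self.lora_kwargs when lora_rank resolves to 16.
lora_kwargs = {"enable_lora": True, "max_loras": 1, "max_lora_rank": 16}
# When no LoRA rank is configured (rank <= 0), lora_kwargs is simply {}.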

verl/utils/config.py

Lines changed: 5 additions & 2 deletions

@@ -199,7 +199,10 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
         )
 
     # check LoRA rank in vLLM
-    if config.actor_rollout_ref.model.get("lora_rank", 0) > 0 and config.actor_rollout_ref.rollout.name == "vllm":
-        assert config.actor_rollout_ref.model.lora_rank <= 512, "LoRA rank in vLLM must be less than or equal to 512"
+    lora_rank = config.actor_rollout_ref.model.get("lora", {}).get("rank", 0)
+    if lora_rank <= 0:
+        lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0)
+    if lora_rank > 0 and config.actor_rollout_ref.rollout.name == "vllm":
+        assert lora_rank <= 512, "LoRA rank in vLLM must be less than or equal to 512"
 
     print("[validate_config] All configuration checks passed successfully!")

verl/utils/megatron_peft_utils.py

Lines changed: 87 additions & 0 deletions

@@ -15,9 +15,36 @@
 
 import os
 from pathlib import Path
+from typing import Iterator
 
 import torch
 
+# Map megatron lora target modules to HF-style module names for vLLM
+MEGATRON_TO_HF_MODULES = {
+    "linear_qkv": ["q_proj", "k_proj", "v_proj"],
+    "linear_proj": ["o_proj"],
+    "linear_fc1": ["gate_proj", "up_proj"],
+    "linear_fc2": ["down_proj"],
+    # Canonical LoRA mappings
+    "linear_q": ["q_proj"],
+    "linear_k": ["k_proj"],
+    "linear_v": ["v_proj"],
+    "linear_fc1_up": ["up_proj"],
+    "linear_fc1_gate": ["gate_proj"],
+}
+
+# Modules with stacked parameters that need .base_layer suffix in vLLM
+stacked_params = [
+    ".q_proj",
+    ".k_proj",
+    ".v_proj",
+    ".o_proj",
+    ".gate_proj",
+    ".up_proj",
+    ".down_proj",
+    ".mlp.gate",
+]
+
 
 def _get_rank_checkpoint_path(base_path: str) -> str:
     """Get rank-specific checkpoint path following Megatron's convention.
@@ -224,10 +251,70 @@ def print_adapter_info(model):
     print(f"{'=' * 60}\n")
 
 
+def convert_megatron_to_hf_target_modules(megatron_modules: list[str]) -> list[str]:
+    """Convert megatron lora target modules to HF-style module names.
+
+    Args:
+        megatron_modules: List of megatron-style module names.
+
+    Returns:
+        List of HF-style module names with duplicates removed.
+    """
+    hf_target_modules = []
+    for module in megatron_modules:
+        if module in MEGATRON_TO_HF_MODULES:
+            hf_target_modules.extend(MEGATRON_TO_HF_MODULES[module])
+        else:
+            hf_target_modules.append(module)
+    # Remove duplicates while preserving order
+    return list(dict.fromkeys(hf_target_modules))
+
+
+def build_peft_config_for_vllm(lora_config: dict) -> dict:
+    """Build a peft_config dict compatible with vLLM's PEFTHelper from megatron lora config.
+
+    Args:
+        lora_config: Megatron lora configuration dictionary.
+
+    Returns:
+        A dictionary compatible with vLLM's PEFTHelper.from_dict().
+    """
+    from peft import TaskType
+
+    target_modules = lora_config.get("target_modules", ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"])
+    exclude_modules = lora_config.get("exclude_modules", [])
+    hf_target_modules = convert_megatron_to_hf_target_modules(target_modules)
+    hf_exclude_modules = convert_megatron_to_hf_target_modules(exclude_modules)
+
+    return {
+        "task_type": TaskType.CAUSAL_LM,
+        "r": lora_config.get("rank", 0),
+        "lora_alpha": lora_config.get("alpha", 32),
+        "target_modules": hf_target_modules,
+        "exclude_modules": hf_exclude_modules,
+        "bias": "none",
+        "lora_dropout": lora_config.get("dropout", 0.0),
+    }
+
+
+# vLLM needs to target all-linear no matter about specific LoRA config
+def add_base_layer_suffix(params: Iterator[tuple[str, torch.Tensor]]) -> Iterator[tuple[str, torch.Tensor]]:
+    for name, param in params:
+        if name.endswith(".weight"):
+            module_k = name[: -len(".weight")]
+            if not module_k.endswith(".base_layer") and any(module_k.endswith(suffix) for suffix in stacked_params):
+                yield f"{module_k}.base_layer.weight", param
+                continue
+        yield name, param
+
+
 __all__ = [
     "get_adapter_state_dict",
     "save_adapter_checkpoint",
     "load_adapter_checkpoint",
     "count_adapter_parameters",
     "print_adapter_info",
+    "convert_megatron_to_hf_target_modules",
+    "build_peft_config_for_vllm",
+    "add_base_layer_suffix",
 ]
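A minimal usage sketch of the new helpers. The module names and the example lora config values are illustrative assumptions, not verl defaults; peft and torch must be importable:

# Sketch: exercising the helpers added above. Values are illustrative.
import torch

from verl.utils.megatron_peft_utils import (
    add_base_layer_suffix,
    build_peft_config_for_vllm,
    convert_megatron_to_hf_target_modules,
)

# Megatron-style targets expand to HF-style projection names, duplicates removed.
print(convert_megatron_to_hf_target_modules(["linear_qkv", "linear_fc2"]))
# -> ['q_proj', 'k_proj', 'v_proj', 'down_proj']

# The dict handed to vLLM's PEFTHelper for the adapter-only refit.
peft_config = build_peft_config_for_vllm({"rank": 16, "alpha": 32, "target_modules": ["linear_qkv", "linear_proj"]})
print(peft_config["r"], peft_config["target_modules"])
# -> 16 ['q_proj', 'k_proj', 'v_proj', 'o_proj']

# Base-model tensors on stacked modules get a ".base_layer" suffix so vLLM
# keeps them separate from the LoRA adapter tensors.
named_params = iter([("model.layers.0.self_attn.q_proj.weight", torch.zeros(1))])
print(next(add_base_layer_suffix(named_params))[0])
# -> model.layers.0.self_attn.q_proj.base_layer.weight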

verl/workers/engine/megatron/transformer_impl.py

Lines changed: 16 additions & 23 deletions

@@ -32,28 +32,20 @@
 from verl.utils.debug import log_gpu_memory_usage
 from verl.utils.device import get_device_id, get_device_name
 from verl.utils.megatron.pipeline_parallel import make_batch_generator
-from verl.utils.megatron.tensor_parallel import (
-    vocab_parallel_entropy,
-    vocab_parallel_log_probs_from_logits,
-)
+from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits
+from verl.utils.megatron_peft_utils import build_peft_config_for_vllm
 from verl.utils.megatron_utils import (
     load_megatron_model_to_gpu,
     load_megatron_optimizer,
     offload_megatron_model_to_cpu,
     offload_megatron_optimizer,
     register_megatron_training_hooks,
 )
-from verl.utils.model import (
-    extract_multi_modal_inputs,
-    load_mcore_dist_weights,
-)
+from verl.utils.model import extract_multi_modal_inputs, load_mcore_dist_weights
 from verl.workers.config import HFModelConfig, McoreEngineConfig, McoreOptimizerConfig
 
 from ..base import BaseEngine, BaseEngineCtx, EngineRegistry
-from ..utils import (
-    postprocess_batch_func,
-    prepare_micro_batches,
-)
+from ..utils import postprocess_batch_func, prepare_micro_batches
 from .utils import set_random_seed
 
 logger = logging.getLogger(__file__)
@@ -180,10 +172,7 @@ def _build_tf_config(self):
         )
 
     def _build_megatron_module(self):
-        from verl.utils.megatron_utils import (
-            McoreModuleWrapperConfig,
-            make_megatron_module,
-        )
+        from verl.utils.megatron_utils import McoreModuleWrapperConfig, make_megatron_module
         from verl.utils.model import print_model_size
 
         # TODO: add more cases
@@ -238,10 +227,7 @@ def _build_megatron_module(self):
         return module
 
     def _build_optimizer(self):
-        from verl.utils.megatron.optimizer import (
-            get_megatron_optimizer,
-            init_megatron_optim_config,
-        )
+        from verl.utils.megatron.optimizer import get_megatron_optimizer, init_megatron_optim_config
 
         optim_config_megatron = init_megatron_optim_config(
             self.optimizer_config,
@@ -538,9 +524,16 @@ def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forw
     def get_per_tensor_param(self):
         if self._is_offload_param:
             load_megatron_model_to_gpu(self.module, load_grad=False)
-        per_tensor_param = self.bridge.export_weights(self.module)
-        # TODO: support megatron LoRA
-        return per_tensor_param, None
+        peft_config = None
+        if self.vanilla_bridge:
+            per_tensor_param = self.bridge.export_weights(self.module)
+        elif self.peft_cls is not None:
+            # Only export adapter weights
+            peft_config = build_peft_config_for_vllm(self.model_config.lora)
+            per_tensor_param = self.bridge.export_adapter_weights(self.module)
+        else:
+            per_tensor_param = self.bridge.export_hf_weights(self.module)
+        return per_tensor_param, peft_config
 
     def forward_step(self, batch_iter, model, postprocess_micro_batch_func):
         raise NotImplementedError("forward_step must be implemented in subclass")

verl/workers/megatron_workers.py

Lines changed: 36 additions & 9 deletions

@@ -52,6 +52,7 @@
 from verl.utils.flops_counter import FlopsCounter
 from verl.utils.fs import copy_to_local
 from verl.utils.megatron.router_replay_patch import RouterReplay, RouterReplayAction, apply_router_replay_patch
+from verl.utils.megatron_peft_utils import add_base_layer_suffix, build_peft_config_for_vllm
 from verl.utils.megatron_utils import (
     load_megatron_model_to_gpu,
     load_megatron_optimizer,
@@ -329,6 +330,9 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
         self._is_offload_grad = False
         self._is_offload_optimizer = False
 
+        # Initialize LoRA-related attributes (will be updated in _build_rollout if needed)
+        self.base_sync_done = False
+
         # normalize config
         if self._is_actor:
             self.config.actor.ppo_mini_batch_size *= self.config.rollout.n
@@ -488,14 +492,7 @@ def _build_rollout(self, trust_remote_code=False):
 
         # 1. parse rollout and huggingface model config
         rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)
-
-        # Convert megatron lora config to HFModelConfig
-        model_config_dict = OmegaConf.to_container(self.config.model)
-        model_config_dict.pop("lora", None)
-
-        model_config: HFModelConfig = omega_conf_to_dataclass(
-            OmegaConf.create(model_config_dict), dataclass_type=HFModelConfig
-        )
+        model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model)
 
         # 2. build rollout device mesh
         infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size
@@ -531,6 +528,9 @@ def _build_rollout(self, trust_remote_code=False):
         )
         log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=logger)
 
+        # Initialize base_sync_done for LoRA
+        self.base_sync_done: bool = "dummy" not in self.config.rollout.load_format
+
         # 5. switch to trainer mode
         # NOTE: It's critical that hybrid engine in trainer mode initially to load checkpoint.
         # For sync mode, we directly switch to trainer mode here.
@@ -673,9 +673,24 @@ async def rollout_mode(self):
             load_megatron_model_to_gpu(self.actor.actor_module, load_grad=False)
         log_gpu_memory_usage("After load actor params during rollout_mode", logger=logger)
 
+        # Build peft_config for vLLM LoRA support
+        peft_config = None
+        do_lora_base_sync = False
+        if self.peft_cls is not None:
+            peft_config = build_peft_config_for_vllm(self.config.model.get("lora", {}))
+            # set sleep level for LoRA adapter weights only sync
+            # TODO: make this configurable so that users with small
+            # main memory can trade sync time to avoid OOM
+            self.rollout.sleep_level = 1
+
+            do_lora_base_sync = not self.base_sync_done or self.rollout.sleep_level != 1
+
         if self.bridge is not None:
             if self.vanilla_bridge:
                 per_tensor_param = self.bridge.export_weights(self.actor.actor_module)
+            elif self.peft_cls is not None:
+                # Only export adapter weights
+                per_tensor_param = self.bridge.export_adapter_weights(self.actor.actor_module)
             else:
                 per_tensor_param = self.bridge.export_hf_weights(self.actor.actor_module)
         else:
@@ -689,7 +704,19 @@
 
         if self.config.rollout.free_cache_engine:
             await self.rollout.resume(tags=["weights"])
-        await self.rollout.update_weights(per_tensor_param)
+        if do_lora_base_sync:
+            # Base layer sync
+            per_tensor_param_lora_base = self.bridge.export_hf_weights(
+                self.actor.actor_module, merge_adapter_weights=False
+            )
+            await self.rollout.update_weights(
+                add_base_layer_suffix(per_tensor_param_lora_base), peft_config=peft_config, base_sync_done=False
+            )
+
+            # Mark base sync as done after first successful sync
+            self.base_sync_done = True
+
+        await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=True)
         if self._is_offload_param:
             offload_megatron_model_to_cpu(self.actor.actor_module)
         aggressive_empty_cache(force_sync=True)
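To summarize the refit flow this commit introduces, here is a condensed, self-contained sketch of the control flow in rollout_mode above. The Bridge and Rollout stubs are illustrative stand-ins for Megatron-Bridge and the vLLM rollout, not verl classes; only the control flow mirrors the diff:

# Sketch of the adapter-only refit flow; stubs are illustrative.
import asyncio


class Bridge:
    def export_hf_weights(self, module, merge_adapter_weights=True):
        # Unmerged base weights; verl wraps these with add_base_layer_suffix().
        return iter([("layers.0.q_proj.base_layer.weight", "base-tensor")])

    def export_adapter_weights(self, module):
        return iter([("layers.0.q_proj.lora_A.weight", "adapter-tensor")])


class Rollout:
    sleep_level = 1  # with level 1 the base weights survive sleep/wake, so they are synced once

    async def update_weights(self, params, peft_config=None, base_sync_done=True):
        print("sync:", [name for name, _ in params], "base_sync_done =", base_sync_done)


async def refit(bridge, rollout, module, base_sync_done, peft_config):
    # Ship the full base model only once (or whenever the sleep level discards it) ...
    if not base_sync_done or rollout.sleep_level != 1:
        base = bridge.export_hf_weights(module, merge_adapter_weights=False)
        await rollout.update_weights(base, peft_config=peft_config, base_sync_done=False)
        base_sync_done = True
    # ... and afterwards refit only the LoRA adapter weights on every weight sync.
    adapters = bridge.export_adapter_weights(module)
    await rollout.update_weights(adapters, peft_config=peft_config, base_sync_done=True)
    return base_sync_done


done = asyncio.run(refit(Bridge(), Rollout(), module=None, base_sync_done=False, peft_config={"r": 16}))
asyncio.run(refit(Bridge(), Rollout(), module=None, base_sync_done=done, peft_config={"r": 16}))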

verl/workers/rollout/vllm_rollout/vllm_async_server.py

Lines changed: 7 additions & 7 deletions

@@ -29,10 +29,7 @@
 from ray.actor import ActorHandle
 from vllm import SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.entrypoints.openai.api_server import (
-    build_app,
-    init_app_state,
-)
+from vllm.entrypoints.openai.api_server import build_app, init_app_state
 from vllm.inputs import TokensPrompt
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
@@ -331,12 +328,15 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
         )
 
         # update lora-related args
-        if self.model_config.lora_rank > 0:
+        lora_rank = self.model_config.lora.get("rank", 0)
+        if lora_rank <= 0:
+            lora_rank = self.model_config.lora_rank
+        if lora_rank > 0:
             args.update(
                 {
                     "enable_lora": True,
                     "max_loras": 1,
-                    "max_lora_rank": get_vllm_max_lora_rank(self.model_config.lora_rank),
+                    "max_lora_rank": get_vllm_max_lora_rank(lora_rank),
                 }
             )
 
@@ -462,7 +462,7 @@ async def generate(
 
         # Add lora request
         lora_request = None
-        if self.model_config.lora_rank > 0:
+        if self.model_config.lora_rank > 0 or self.model_config.lora.get("rank", 0) > 0:
            # Make sure we also check that the lora is already loaded in the engine
            lora_loaded = VLLM_LORA_INT_ID in await self.engine.list_loras()
            if lora_loaded:
