14 changes: 8 additions & 6 deletions docs/advance/ppo_lora.rst
@@ -71,7 +71,7 @@ Make sure you use Megatron-Bridge later than 0.2.0, and we recommended using `th

1. **LoRA Implementation**: Verl Megatron backend uses Megatron-Bridge's native LoRA implementation, which differs from HuggingFace PEFT.

2. **Weight Sync Mechanism**: Currently, Megatron-Bridge syncs weights by merging LoRA adapters into the base model weights before transferring to vLLM rather than loading separate adapters. This is necessary because Megatron-Bridge's LoRA format is not directly integratable with vLLM's LoRA loading mechanism (HF PEFT format), and LoRA bridge is not yet supported.
2. **Weight Sync / Refit Mechanism**: Currently, Megatron-Bridge supports syncing weights either by merging LoRA adapters into the base model weights before transferring to vLLM (better inference speed, but longer refit time and potential precision loss) or by loading separate adapters.

**Configuration for Megatron LoRA:**

@@ -83,6 +83,9 @@ Make sure you use Megatron-Bridge later than 0.2.0, and we recommended using `th
# LoRA type: "lora", "vlm_lora", "canonical_lora", or "dora"
type: lora

# Whether to sync weights (refit) by merging LoRA adapters into the base model weights before transferring to vLLM: merging gives better inference speed but longer refit time and potential precision loss. If False, separate adapters are loaded instead.
merge: False

# LoRA rank (dimension of the low-rank projection space). Set to 0 to disable LoRA
rank: 0

@@ -136,12 +139,11 @@ Make sure you use Megatron-Bridge later than 0.2.0, and we recommended using `th
freeze_vision_projection: True
freeze_language_model: True

LoRA training experiment with Qwen3-8B on a single node with 8 * H200 GPUs, comparing the FSDP and Megatron backends (script adapted from examples/grpo_trainer/run_qwen2-7b_math_megatron_lora.sh):

**Current Limitations:**

1. **No HuggingFace PEFT Export**: Currently there is no built-in way to export Megatron LoRA adapters to HuggingFace PEFT format for inference with standard HF/vLLM pipelines; such support is coming soon with the Megatron-Bridge `LoRA bridge <https://github.com/NVIDIA-NeMo/Megatron-Bridge/issues/1536>`_.

2. **LoRA Merge Overhead**: As we don't have LoRA bridge for now, each weight sync (refit) requires merging LoRA weights, which adds some overhead compared to direct dynamic adapter loading.
.. image:: https://github.com/user-attachments/assets/0482f423-01a3-4e52-a7ee-8b9cd79b7b1a
.. image:: https://github.com/user-attachments/assets/6ce10400-8164-47d8-90a6-c1bf002fb9e8
.. image:: https://github.com/user-attachments/assets/092d3a43-4eba-425e-a584-8d83c1f02de4


Best Practices and Notes
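For intuition on the merge-based refit described in the doc above, folding a LoRA adapter into its base weight is just a low-rank update. The helper below is an illustrative sketch, not an API from this PR or from Megatron-Bridge:

import torch

def merge_lora(base_weight: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor,
               alpha: float, rank: int) -> torch.Tensor:
    # Fold the adapter into the base weight: W' = W + (alpha / r) * B @ A.
    # With merge=True something equivalent happens before every weight sync,
    # which is where the extra refit time and potential precision loss come from;
    # with merge=False only the small A/B matrices are shipped to vLLM.
    return base_weight + (alpha / rank) * (lora_b @ lora_a)

# Toy shapes: W is (out, in), A is (r, in), B is (out, r).
w = torch.randn(64, 32)
a = torch.randn(8, 32)
b = torch.zeros(64, 8)  # B is zero-initialized, so an untrained adapter merges to a no-op.
assert torch.allclose(merge_lora(w, a, b, alpha=32, rank=8), w)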
12 changes: 7 additions & 5 deletions recipe/r1_ascend/vllm_rollout_spmd.py
@@ -54,6 +54,7 @@ def __init__(
)

from recipe.r1_ascend import engine_core # noqa: F401

# NPU-ADAPTATION END

if config.layered_summon:
@@ -65,11 +66,12 @@
tokenizer = model_config.tokenizer
model_hf_config = model_config.hf_config
trust_remote_code = model_config.trust_remote_code
self.lora_kwargs = (
{"enable_lora": True, "max_loras": 1, "max_lora_rank": model_config.lora_rank}
if model_config.lora_rank > 0
else {}
)
# Resolve the LoRA rank that vLLM should serve: with merge=True the nested Megatron
# adapter is not exposed to vLLM, so its rank is ignored; whenever the resolved rank
# is still <= 0, fall back to the legacy flat ``lora_rank`` field.
lora_rank = model_config.lora.get("rank", 0)
if model_config.lora.get("merge", False):
lora_rank = 0
if lora_rank <= 0:
lora_rank = model_config.lora_rank
self.lora_kwargs = {"enable_lora": True, "max_loras": 1, "max_lora_rank": lora_rank} if lora_rank > 0 else {}

tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1)
assert tensor_parallel_size <= torch.distributed.get_world_size(), (
1 change: 1 addition & 0 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -320,6 +320,7 @@ actor_rollout_ref:
impl_backend: torch
lora:
type: lora
merge: false
rank: 0
alpha: 32
dropout: 0.0
3 changes: 3 additions & 0 deletions verl/trainer/config/ppo_megatron_trainer.yaml
@@ -46,6 +46,9 @@ actor_rollout_ref:
# LoRA type: "lora", "vlm_lora", "canonical_lora", or "dora"
type: lora

# Whether to sync weights (refit) by merging LoRA adapters into the base model weights before transferring to vLLM: merging gives better inference speed but longer refit time and potential precision loss. If False, separate adapters are loaded instead.
merge: False

# LoRA rank (dimension of the low-rank projection space). Set to 0 to disable LoRA
rank: 0 # typical values: 8, 16, 32, 64

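A quick sketch of reading the new flag with OmegaConf (which verl already uses for its trainer configs); the inlined YAML is a trimmed-down stand-in for the file above:

from omegaconf import OmegaConf

yaml_snippet = """
actor_rollout_ref:
  model:
    lora:
      type: lora
      merge: false
      rank: 16
      alpha: 32
"""
cfg = OmegaConf.create(yaml_snippet)
lora_cfg = cfg.actor_rollout_ref.model.lora
print(lora_cfg.merge, lora_cfg.rank)  # -> False 16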
10 changes: 8 additions & 2 deletions verl/utils/config.py
@@ -199,7 +199,13 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
)

# check LoRA rank in vLLM
if config.actor_rollout_ref.model.get("lora_rank", 0) > 0 and config.actor_rollout_ref.rollout.name == "vllm":
assert config.actor_rollout_ref.model.lora_rank <= 512, "LoRA rank in vLLM must be less than or equal to 512"
lora_config = config.actor_rollout_ref.model.get("lora", {})
lora_rank = lora_config.get("rank", 0)
if lora_config.get("merge", False):
lora_rank = 0
if lora_rank <= 0:
lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0)
if lora_rank > 0 and config.actor_rollout_ref.rollout.name == "vllm":
assert lora_rank <= 512, "LoRA rank in vLLM must be less than or equal to 512"

print("[validate_config] All configuration checks passed successfully!")
102 changes: 102 additions & 0 deletions verl/utils/megatron_peft_utils.py
@@ -15,9 +15,47 @@

import os
from pathlib import Path
from typing import Iterator

import torch

# Map megatron lora target modules to HF-style module names for vLLM
MEGATRON_TO_HF_MODULES = {
"linear_qkv": ["q_proj", "k_proj", "v_proj"],
"linear_proj": ["o_proj"],
"linear_fc1": ["gate_proj", "up_proj"],
"linear_fc2": ["down_proj"],
# Canonical LoRA mappings
"linear_q": ["q_proj"],
"linear_k": ["k_proj"],
"linear_v": ["v_proj"],
"linear_fc1_up": ["up_proj"],
"linear_fc1_gate": ["gate_proj"],
# MLA mappings
"linear_kv_down_proj": ["kv_a_proj_with_mqa"],
"linear_kv_up_proj": ["kv_b_proj"],
"linear_q_down_proj": ["q_a_proj"],
"linear_q_up_proj": ["q_b_proj"],
"linear_q_proj": ["q_proj"],
}

# Modules with stacked parameters that need .base_layer suffix in vLLM
stacked_params = [
".q_proj.weight",
".k_proj.weight",
".v_proj.weight",
".o_proj.weight",
".gate_proj.weight",
".up_proj.weight",
".down_proj.weight",
".mlp.gate.weight",
".mlp.gate.e_score_correction_bias",
".kv_a_proj_with_mqa.weight",
".kv_b_proj.weight",
".q_a_proj.weight",
".q_b_proj.weight",
]


def _get_rank_checkpoint_path(base_path: str) -> str:
"""Get rank-specific checkpoint path following Megatron's convention.
@@ -224,10 +262,74 @@ def print_adapter_info(model):
print(f"{'=' * 60}\n")


def convert_megatron_to_hf_target_modules(megatron_modules: list[str]) -> list[str]:
"""Convert megatron lora target modules to HF-style module names.

Args:
megatron_modules: List of megatron-style module names.

Returns:
List of HF-style module names with duplicates removed.
"""
hf_target_modules = []
for module in megatron_modules:
if module in MEGATRON_TO_HF_MODULES:
hf_target_modules.extend(MEGATRON_TO_HF_MODULES[module])
else:
hf_target_modules.append(module)
# Remove duplicates while preserving order
return list(dict.fromkeys(hf_target_modules))


def build_peft_config_for_vllm(lora_config: dict) -> dict:
"""Build a peft_config dict compatible with vLLM's PEFTHelper from megatron lora config.

Args:
lora_config: Megatron lora configuration dictionary.

Returns:
A dictionary compatible with vLLM's PEFTHelper.from_dict().
"""
from peft import TaskType

target_modules = lora_config.get("target_modules", ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"])
exclude_modules = lora_config.get("exclude_modules", [])
hf_target_modules = convert_megatron_to_hf_target_modules(target_modules)
hf_exclude_modules = convert_megatron_to_hf_target_modules(exclude_modules)

return {
"task_type": TaskType.CAUSAL_LM,
"r": lora_config.get("rank", 0),
"lora_alpha": lora_config.get("alpha", 32),
"target_modules": hf_target_modules,
"exclude_modules": hf_exclude_modules,
"bias": "none",
"lora_dropout": lora_config.get("dropout", 0.0),
}


# vLLM targets all linear layers regardless of the specific LoRA config
def add_base_layer_suffix(params: Iterator[tuple[str, torch.Tensor]]) -> Iterator[tuple[str, torch.Tensor]]:
"""Yield param pairs with a base-layer suffix added to the param name."""
for name, param in params:
ending_suffix = ""
for suffix in stacked_params:
if name.endswith(suffix):
ending_suffix = suffix
break
if ending_suffix:
suffix = ending_suffix.rsplit(".", 1)[-1]
name = f"{name[: -len(suffix)]}base_layer.{suffix}"
yield name, param


__all__ = [
"get_adapter_state_dict",
"save_adapter_checkpoint",
"load_adapter_checkpoint",
"count_adapter_parameters",
"print_adapter_info",
"convert_megatron_to_hf_target_modules",
"build_peft_config_for_vllm",
"add_base_layer_suffix",
]
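A usage sketch of the new helpers; the expected outputs are inferred from the mapping tables above rather than from running this PR:

import torch

from verl.utils.megatron_peft_utils import (
    add_base_layer_suffix,
    build_peft_config_for_vllm,
    convert_megatron_to_hf_target_modules,
)

print(convert_megatron_to_hf_target_modules(["linear_qkv", "linear_proj"]))
# -> ['q_proj', 'k_proj', 'v_proj', 'o_proj']

peft_cfg = build_peft_config_for_vllm({"rank": 16, "alpha": 32, "target_modules": ["linear_qkv"]})
print(peft_cfg["r"], peft_cfg["target_modules"])
# -> 16 ['q_proj', 'k_proj', 'v_proj']

params = [("model.layers.0.self_attn.q_proj.weight", torch.zeros(4, 4))]
name, _ = next(iter(add_base_layer_suffix(params)))
print(name)
# -> model.layers.0.self_attn.q_proj.base_layer.weight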
39 changes: 16 additions & 23 deletions verl/workers/engine/megatron/transformer_impl.py
@@ -32,28 +32,20 @@
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.device import get_device_id, get_device_name
from verl.utils.megatron.pipeline_parallel import make_batch_generator
from verl.utils.megatron.tensor_parallel import (
vocab_parallel_entropy,
vocab_parallel_log_probs_from_logits,
)
from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits
from verl.utils.megatron_peft_utils import build_peft_config_for_vllm
from verl.utils.megatron_utils import (
load_megatron_model_to_gpu,
load_megatron_optimizer,
offload_megatron_model_to_cpu,
offload_megatron_optimizer,
register_megatron_training_hooks,
)
from verl.utils.model import (
extract_multi_modal_inputs,
load_mcore_dist_weights,
)
from verl.utils.model import extract_multi_modal_inputs, load_mcore_dist_weights
from verl.workers.config import HFModelConfig, McoreEngineConfig, McoreOptimizerConfig

from ..base import BaseEngine, BaseEngineCtx, EngineRegistry
from ..utils import (
postprocess_batch_func,
prepare_micro_batches,
)
from ..utils import postprocess_batch_func, prepare_micro_batches
from .utils import set_random_seed

logger = logging.getLogger(__file__)
@@ -180,10 +172,7 @@ def _build_tf_config(self):
)

def _build_megatron_module(self):
from verl.utils.megatron_utils import (
McoreModuleWrapperConfig,
make_megatron_module,
)
from verl.utils.megatron_utils import McoreModuleWrapperConfig, make_megatron_module
from verl.utils.model import print_model_size

# TODO: add more cases
@@ -238,10 +227,7 @@ def _build_megatron_module(self):
return module

def _build_optimizer(self):
from verl.utils.megatron.optimizer import (
get_megatron_optimizer,
init_megatron_optim_config,
)
from verl.utils.megatron.optimizer import get_megatron_optimizer, init_megatron_optim_config

optim_config_megatron = init_megatron_optim_config(
self.optimizer_config,
@@ -538,9 +524,16 @@ def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forw
def get_per_tensor_param(self):
if self._is_offload_param:
load_megatron_model_to_gpu(self.module, load_grad=False)
per_tensor_param = self.bridge.export_weights(self.module)
# TODO: support megatron LoRA
return per_tensor_param, None
peft_config = None
if self.vanilla_bridge:
per_tensor_param = self.bridge.export_weights(self.module)
elif not self.model_config.lora.get("merge", False) and self.peft_cls is not None:
# Only export adapter weights
peft_config = build_peft_config_for_vllm(self.model_config.lora)
per_tensor_param = self.bridge.export_adapter_weights(self.module)
else:
per_tensor_param = self.bridge.export_hf_weights(self.module)
return per_tensor_param, peft_config

def forward_step(self, batch_iter, model, postprocess_micro_batch_func):
raise NotImplementedError("forward_step must be implemented in subclass")
47 changes: 38 additions & 9 deletions verl/workers/megatron_workers.py
@@ -52,6 +52,7 @@
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fs import copy_to_local
from verl.utils.megatron.router_replay_patch import RouterReplay, RouterReplayAction, apply_router_replay_patch
from verl.utils.megatron_peft_utils import add_base_layer_suffix, build_peft_config_for_vllm
from verl.utils.megatron_utils import (
load_megatron_model_to_gpu,
load_megatron_optimizer,
@@ -329,6 +330,10 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
self._is_offload_grad = False
self._is_offload_optimizer = False

# Initialize LoRA-related attributes (will be updated in _build_rollout if needed)
self.base_sync_done = False
self.peft_merge = False

# normalize config
if self._is_actor:
self.config.actor.ppo_mini_batch_size *= self.config.rollout.n
@@ -488,14 +493,7 @@ def _build_rollout(self, trust_remote_code=False):

# 1. parse rollout and huggingface model config
rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)

# Convert megatron lora config to HFModelConfig
model_config_dict = OmegaConf.to_container(self.config.model)
model_config_dict.pop("lora", None)

model_config: HFModelConfig = omega_conf_to_dataclass(
OmegaConf.create(model_config_dict), dataclass_type=HFModelConfig
)
model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model)

# 2. build rollout device mesh
infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size
@@ -531,6 +529,10 @@
)
log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=logger)

# Initialize base_sync_done for LoRA
self.base_sync_done: bool = "dummy" not in self.config.rollout.load_format
self.peft_merge: bool = model_config.lora.get("merge", False)

# 5. switch to trainer mode
# NOTE: It's critical that hybrid engine in trainer mode initially to load checkpoint.
# For async mode, we can't call run_until_complete here, so we will switch to trainer mode in AgentLoopManager.
@@ -670,9 +672,24 @@ async def rollout_mode(self):
load_megatron_model_to_gpu(self.actor.actor_module, load_grad=False)
log_gpu_memory_usage("After load actor params during rollout_mode", logger=logger)

# Build peft_config for vLLM LoRA support
peft_config = None
do_lora_base_sync = False
if not self.peft_merge and self.peft_cls is not None:
peft_config = build_peft_config_for_vllm(self.config.model.get("lora", {}))
# set sleep level for LoRA adapter weights only sync
# TODO: make this configurable so that users with small
# main memory can trade sync time to avoid OOM
self.rollout.sleep_level = 1

do_lora_base_sync = not self.base_sync_done or self.rollout.sleep_level != 1

if self.bridge is not None:
if self.vanilla_bridge:
per_tensor_param = self.bridge.export_weights(self.actor.actor_module)
elif not self.peft_merge and self.peft_cls is not None:
# Only export adapter weights
per_tensor_param = self.bridge.export_adapter_weights(self.actor.actor_module)
else:
per_tensor_param = self.bridge.export_hf_weights(self.actor.actor_module)
else:
@@ -686,7 +703,19 @@

if self.config.rollout.free_cache_engine:
await self.rollout.resume(tags=["weights"])
await self.rollout.update_weights(per_tensor_param)
if do_lora_base_sync:
# Base layer sync
per_tensor_param_lora_base = self.bridge.export_hf_weights(
self.actor.actor_module, merge_adapter_weights=False
)
await self.rollout.update_weights(
add_base_layer_suffix(per_tensor_param_lora_base), peft_config=peft_config, base_sync_done=False
)

# Mark base sync as done after first successful sync
self.base_sync_done = True

await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=True)
if self._is_offload_param:
offload_megatron_model_to_cpu(self.actor.actor_module)
aggressive_empty_cache(force_sync=True)
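Condensed control flow of the adapter-only (merge: False) refit path above, as a sketch with the offload, cache-resume, and non-LoRA branches omitted; the bridge and rollout calls mirror the ones used in rollout_mode:

from verl.utils.megatron_peft_utils import add_base_layer_suffix, build_peft_config_for_vllm

async def lora_adapter_refit(rollout, bridge, module, lora_cfg: dict, base_sync_done: bool) -> bool:
    peft_config = build_peft_config_for_vllm(lora_cfg)
    if not base_sync_done:
        # One-time sync of the frozen base weights, renamed with a ``.base_layer``
        # suffix so vLLM's LoRA-wrapped linear modules receive them.
        base_params = bridge.export_hf_weights(module, merge_adapter_weights=False)
        await rollout.update_weights(
            add_base_layer_suffix(base_params), peft_config=peft_config, base_sync_done=False
        )
        base_sync_done = True
    # Every refit afterwards only ships the small LoRA A/B matrices.
    adapter_params = bridge.export_adapter_weights(module)
    await rollout.update_weights(adapter_params, peft_config=peft_config, base_sync_done=True)
    return base_sync_done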