[feat] Add LoRA #592

Merged
hiyouga merged 37 commits into hiyouga:main from rank-Yu:main
Jan 29, 2026

Conversation

@rank-Yu (Contributor) commented Dec 31, 2025

This PR introduces LoRA support.

Here's the test script used:

#!/bin/bash

set -x

MODEL_PATH=Qwen/Qwen3-0.6B  # replace it with your local file path

python3 -m verl.trainer.main \
    config=examples/config.yaml \
    worker.actor.model.model_path=${MODEL_PATH} \
    trainer.experiment_name=qwen3_0.6b_math_grpo_lora \
    worker.rollout.n=4 \
    worker.rollout.tensor_parallel_size=1 \
    trainer.total_epochs=4 \
    trainer.n_gpus_per_node=1 \
    worker.actor.model.lora.rank=16 \
    worker.actor.model.lora.alpha=16

This run was on a single RTX 4090 and shows accuracy rising over training:

[screenshot: accuracy curve]

@gemini-code-assist (bot) commented

Summary of Changes

Hello @rank-Yu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the system's capabilities by introducing comprehensive LoRA support for Large Language Models. It enables efficient fine-tuning of LLMs by integrating the PEFT library and adapting FSDP for optimal distributed training. A key innovation is the custom handling of LoRA weights within vLLM, allowing direct tensor loading for seamless model synchronization. This feature aims to reduce computational overhead and memory footprint during LLM fine-tuning and inference, ultimately leading to faster experimentation and deployment.

Highlights

  • LoRA Configuration: Introduced new configuration options for LoRA (Low-Rank Adaptation) within the ModelConfig to specify rank, alpha, and target_modules for fine-tuning Large Language Models (LLMs).
  • PEFT Integration: Integrated the PEFT (Parameter-Efficient Fine-Tuning) library to apply LoRA to the actor model, enabling memory-efficient fine-tuning by training only a small number of additional parameters (see the sketch after this list).
  • FSDP Compatibility for LoRA: Enhanced FSDP (Fully Sharded Data Parallel) wrapping policies to correctly identify and shard LoRA modules, ensuring efficient distributed training with LoRA-enabled models. This includes version-dependent FSDP imports for broader compatibility.
  • vLLM LoRA Tensor Loading: Implemented a custom mechanism to load LoRA weights directly from tensors into vLLM, bypassing its default file-based loading. This allows for dynamic synchronization of LoRA tensors between the actor model and the vLLM inference engine.
  • Reference Model Log Probability Calculation: Added logic to compute reference model log probabilities by temporarily disabling the LoRA adapter on the actor model, ensuring that the baseline (non-LoRA) model's output is used for comparison.
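
To make the PEFT integration and the reference-model log-probability items above concrete, here is a minimal sketch (not the PR's actual code; the model name, target modules, and placeholder batch are assumptions) of how the pieces fit together with the peft library:

import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype=torch.bfloat16)

# mirrors worker.actor.model.lora.{rank,alpha,target_modules} in the config
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
model = get_peft_model(model, lora_config)  # only the LoRA parameters remain trainable

# reference-model log probabilities without a second model copy:
# temporarily disable the adapter so the forward pass uses the frozen base weights
input_ids = torch.tensor([[1, 2, 3]])  # placeholder batch; in verl this comes from the rollout
with torch.no_grad(), model.disable_adapter():
    ref_logits = model(input_ids).logits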


@gemini-code-assist (bot) left a comment

Code Review

This PR introduces LoRA support for LLMs, a valuable enhancement. The implementation correctly leverages peft for adapter management and adapts the FSDP strategy accordingly. A clever optimization is using the base model as the reference model during LoRA training to conserve memory. However, the integration with vLLM for rollouts relies on monkey-patching and includes some fragile logic for weight synchronization, particularly hardcoded strings and parameter-name manipulation, which could hurt future maintainability. I've provided specific suggestions to improve these areas.

@hiyouga (Owner) commented Dec 31, 2025

@Kuangdd01 Could you please verify the performance on a 7B model?

@Kuangdd01 (Contributor) commented

🫡 Okay, I'll report results with rank=128 and lr=1e-5 on Qwen2.5-7B-Instruct here.

@Kuangdd01 (Contributor) commented Dec 31, 2025

THX AND HAPPY NEW YEAR!
I encountered an issue when doing LoRA tuning with multiple GPUs: when loading weights in FSDPVLLMShardingManager, some weight tensors did not seem to be DTensor instances with a full_tensor() method.

def _make_weight_iterator(
    self, actor_weights: dict[str, Union[torch.Tensor, DTensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
    for name, tensor in actor_weights.items():
        yield name, tensor.full_tensor() if self.world_size != 1 else tensor

Although I don't know what happened here, I added an if-condition to work around it. Can @rank-Yu explain what happened here?

def _make_weight_iterator(
    self, actor_weights: dict[str, Union[torch.Tensor, DTensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
    items = actor_weights.items() if isinstance(actor_weights, dict) else actor_weights
    for name, tensor in items:
        if not isinstance(name, str):
            name = str(name)

        if hasattr(tensor, "full_tensor"):
            weight = tensor.full_tensor()
        else:
            weight = tensor  # torch.Tensor

        yield name, weight

@hiyouga (Owner) commented Dec 31, 2025

@Kuangdd01 @rank-Yu we need to apply the LoRA adapters first and then do the FSDP sharding, so that all parameters end up as DTensors.
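
A hedged sketch of that ordering (illustrative only; it assumes an initialized distributed process group and device mesh, and omits the wrap-policy details in verl/utils/fsdp_utils.py):

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

# fully_shard moved between modules across torch releases, which is why the PR
# adds version-dependent FSDP imports
try:
    from torch.distributed.fsdp import fully_shard
except ImportError:
    from torch.distributed._composable.fsdp import fully_shard

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")

# 1) attach the LoRA adapters to the unsharded model ...
model = get_peft_model(model, LoraConfig(task_type=TaskType.CAUSAL_LM, r=16, lora_alpha=16))

# 2) ... then shard, so that base and LoRA parameters alike end up as DTensors
#    and the weight iterator in FSDPVLLMShardingManager sees a uniform layout
fully_shard(model)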

@rank-Yu (Contributor, Author) commented Jan 6, 2026

Quoting @Kuangdd01's question above: "I encountered an issue when doing LoRA tuning with multiple GPUs ... some weight tensors seemed not to be DTensor instances with a full_tensor() method. ... Can @rank-Yu explain what happened here?"

Thanks! This can happen because actor_weights is not guaranteed to contain only DTensors, even when world_size != 1.

In our LoRA path (e.g. _collect_lora_params()), we often detach().cpu() the tensors or materialize them during summon_full_params, so the resulting weights can be plain torch.Tensor (typically on CPU) rather than DTensor. In that case, gating on world_size and calling .full_tensor() will fail.

So instead of using world_size as the condition, I updated the iterator to only materialize when the tensor is actually a DTensor, and also move it to the current CUDA device before full_tensor() to better support FSDP + CPU offload:

def _make_weight_iterator(
    self, actor_weights: dict[str, Union[torch.Tensor, DTensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
    device = torch.device("cuda", torch.cuda.current_device())
    for name, tensor in actor_weights.items():
        yield name, tensor.to(device, non_blocking=True).full_tensor() if isinstance(tensor, DTensor) else tensor

Your hasattr(full_tensor) workaround is also safe; I just kept the check stricter (isinstance(DTensor)) and aligned device placement for the DTensor case.

@Kuangdd01 (Contributor) commented

Qwen2.5-7B-Instruct LoRA experiment log.

@rank-Yu changed the title from "[feat] Add LoRA support for LLM" to "[feat] Add LoRA support" on Jan 14, 2026
@rank-Yu changed the title from "[feat] Add LoRA support" to "[feat] Add LoRA" on Jan 15, 2026
@hiyouga (Owner) left a comment

LGTM!

@hiyouga merged commit 938d4ea into hiyouga:main on Jan 29, 2026
1 check passed
@hiyouga mentioned this pull request on Jan 29, 2026
@xlg-go commented Jan 29, 2026

Amazing!!!!!!! Excellent work~~~
@rank-Yu
@hiyouga

Starting from vLLM 0.14.0, support for LoRA on the ViT has been added.

exclude_modules: .*visual.* # Exclude modules from applying LoRA; example: '.*visual.*' for excluding the ViT in Qwen2.5-VL, as vLLM does not support ViT LoRA
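
For illustration (not verl's actual implementation; the helper and module names below are made up), one way such an exclude pattern can be applied when collecting LoRA target modules:

import re

def filter_lora_targets(module_names, target_suffixes, exclude_pattern=r".*visual.*"):
    """Collect LoRA target module names, skipping anything that matches the exclude
    pattern (e.g. the ViT tower in Qwen2.5-VL when the engine cannot serve ViT LoRA)."""
    exclude = re.compile(exclude_pattern)
    targets = []
    for name in module_names:
        if exclude.match(name):
            continue
        if any(name.endswith(suffix) for suffix in target_suffixes):
            targets.append(name)
    return targets

# example: keep the language-model projection, drop everything under `visual`
names = ["model.layers.0.self_attn.q_proj", "visual.blocks.0.attn.qkv"]
print(filter_lora_targets(names, ["q_proj", "qkv"]))
# -> ['model.layers.0.self_attn.q_proj']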

@hiyouga (Owner) commented Jan 29, 2026

@xlg-go Thanks for the information! We'll upgrade to vLLM 0.14.0 soon.

@xlg-go commented Jan 29, 2026

Nice work!!!!!!!
Leading verl~~~~~
