Commit 89c6db2
fix: improve vllm patch gating
1 parent c047cca

2 files changed: 77 additions & 84 deletions

verl/utils/vllm_utils.py (76 additions & 82 deletions)

@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from importlib.metadata import version
 from typing import List

 from msgspec import field
+from packaging import version as vs
 from vllm.lora.models import LoRAModel
 from vllm.lora.request import LoRARequest
 from vllm.lora.utils import get_adapter_absolute_path
@@ -29,79 +31,73 @@ class TensorLoRARequest(LoRARequest):
 class VLLMHijack:
     @staticmethod
     def hijack():
-        def do_hijack(target_cls, target_method_name, hooking_method):
-            setattr(target_cls, target_method_name, hooking_method)
-
         def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel:
             """
             based on vllm.lora.worker_manager.WorkerLoRAManager._load_adapter, support load adapter with lora tensors
             Reason:
                 VLLM does not support adding LoRA from tensors directly. It only supports adding LoRA via file paths.
                 To synchronize the LoRA tensors of the actor model, we need to find a workaround to enable VLLM to load memory-based LoRA tensors.
             """
-            try:
-                supported_lora_modules = self._adapter_manager.supported_lora_modules
-                packed_modules_mapping = self._adapter_manager.packed_modules_mapping
-                expected_lora_modules: List[str] = []
-                for module in supported_lora_modules:
-                    if module in packed_modules_mapping:
-                        expected_lora_modules.extend(packed_modules_mapping[module])
-                    else:
-                        expected_lora_modules.append(module)
-
-                expected_lora_modules = list(set(expected_lora_modules))
-
-                lora_tensors = None
-                from vllm.lora.peft_helper import PEFTHelper
-
-                if isinstance(lora_request, TensorLoRARequest):
-                    peft_config = lora_request.peft_config
-                    lora_tensors = lora_request.lora_tensors
-                    peft_helper = PEFTHelper.from_dict(peft_config)
-                else:
-                    lora_path = get_adapter_absolute_path(lora_request.lora_path)
-
-                    peft_helper = PEFTHelper.from_local_dir(lora_path, self.max_position_embeddings)
-
-                # Validates the LoRA configuration against requirements before
-                # loading weights, throwing an exception if validation fails.
-                peft_helper.validate_legal(self.lora_config)
-
-                # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
-                # to ensure correct loading of lora weights.
-                model = self._adapter_manager.model
-                hf_to_vllm_mapper = None
-                if hasattr(model, "hf_to_vllm_mapper") and model.hf_to_vllm_mapper is not None:
-                    hf_to_vllm_mapper = model.hf_to_vllm_mapper
-
-                if isinstance(lora_request, TensorLoRARequest):
-                    lora = self._lora_model_cls.from_lora_tensors(
-                        lora_model_id=lora_request.lora_int_id,
-                        tensors=lora_tensors,
-                        peft_helper=peft_helper,
-                        device="cpu",
-                        dtype=self.lora_config.lora_dtype,
-                        embeddings=None,
-                        target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size,
-                        embedding_modules=self.embedding_modules,
-                        embedding_padding_modules=self.embedding_padding_modules,
-                        weights_mapper=hf_to_vllm_mapper,
-                    )
+            supported_lora_modules = self._adapter_manager.supported_lora_modules
+            packed_modules_mapping = self._adapter_manager.packed_modules_mapping
+            expected_lora_modules: List[str] = []
+            for module in supported_lora_modules:
+                if module in packed_modules_mapping:
+                    expected_lora_modules.extend(packed_modules_mapping[module])
                 else:
-                    lora = self._lora_model_cls.from_local_checkpoint(
-                        lora_path,
-                        expected_lora_modules,
-                        peft_helper=peft_helper,
-                        lora_model_id=lora_request.lora_int_id,
-                        device="cpu",
-                        dtype=self.lora_config.lora_dtype,
-                        target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size,
-                        embedding_modules=self.embedding_modules,
-                        embedding_padding_modules=self.embedding_padding_modules,
-                        weights_mapper=hf_to_vllm_mapper,
-                    )
-            except Exception as e:
-                raise e
+                    expected_lora_modules.append(module)
+
+            expected_lora_modules = list(set(expected_lora_modules))
+
+            lora_tensors = None
+            from vllm.lora.peft_helper import PEFTHelper
+
+            if isinstance(lora_request, TensorLoRARequest):
+                peft_config = lora_request.peft_config
+                lora_tensors = lora_request.lora_tensors
+                peft_helper = PEFTHelper.from_dict(peft_config)
+            else:
+                lora_path = get_adapter_absolute_path(lora_request.lora_path)
+
+                peft_helper = PEFTHelper.from_local_dir(lora_path, self.max_position_embeddings)
+
+            # Validates the LoRA configuration against requirements before
+            # loading weights, throwing an exception if validation fails.
+            peft_helper.validate_legal(self.lora_config)
+
+            # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
+            # to ensure correct loading of lora weights.
+            model = self._adapter_manager.model
+            hf_to_vllm_mapper = None
+            if hasattr(model, "hf_to_vllm_mapper") and model.hf_to_vllm_mapper is not None:
+                hf_to_vllm_mapper = model.hf_to_vllm_mapper
+
+            if isinstance(lora_request, TensorLoRARequest):
+                lora = self._lora_model_cls.from_lora_tensors(
+                    lora_model_id=lora_request.lora_int_id,
+                    tensors=lora_tensors,
+                    peft_helper=peft_helper,
+                    device="cpu",
+                    dtype=self.lora_config.lora_dtype,
+                    embeddings=None,
+                    target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size,
+                    embedding_modules=self.embedding_modules,
+                    embedding_padding_modules=self.embedding_padding_modules,
+                    weights_mapper=hf_to_vllm_mapper,
+                )
+            else:
+                lora = self._lora_model_cls.from_local_checkpoint(
+                    lora_path,
+                    expected_lora_modules,
+                    peft_helper=peft_helper,
+                    lora_model_id=lora_request.lora_int_id,
+                    device="cpu",
+                    dtype=self.lora_config.lora_dtype,
+                    target_embedding_padding=self.vocab_size + self.lora_config.lora_extra_vocab_size,
+                    embedding_modules=self.embedding_modules,
+                    embedding_padding_modules=self.embedding_padding_modules,
+                    weights_mapper=hf_to_vllm_mapper,
+                )

             if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
                 raise ValueError(
@@ -111,25 +107,23 @@ def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel:
                 )
             return lora

-        do_hijack(LRUCacheWorkerLoRAManager, "_load_adapter", hijack__load_adapter)
+        setattr(LRUCacheWorkerLoRAManager, "_load_adapter", hijack__load_adapter)

-        try:
+        if vs.parse(version("vllm")).base_version == "0.11.0":
             from vllm.model_executor.models.module_mapping import MultiModelKeys
             from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration
-        except Exception:
-            return

-        def hijack__get_mm_mapping(self) -> MultiModelKeys:
-            """
-            Patch vllm.model_executor.models.qwen3_vl.Qwen3VLForConditionalGeneration.get_mm_mapping in vLLM 0.11.0
-            Reason:
-                vLLM 0.11.0 uses "model.visual.*" prefixes for Qwen3-VL, but the real module names are "visual.*".
-                This breaks LoRA filtering for multimodal parts, so we align the prefixes to the real module names.
-            """
-            return MultiModelKeys.from_string_field(
-                language_model="language_model",
-                connector="visual.merger.",
-                tower_model="visual.",
-            )
+            def hijack__get_mm_mapping(self) -> MultiModelKeys:
+                """
+                Patch vllm.model_executor.models.qwen3_vl.Qwen3VLForConditionalGeneration.get_mm_mapping in vLLM 0.11.0
+                Reason:
+                    vLLM 0.11.0 uses "model.visual.*" prefixes for Qwen3-VL, but the real module names are "visual.*".
+                    This breaks LoRA filtering for multimodal parts, so we align the prefixes to the real module names.
+                """
+                return MultiModelKeys.from_string_field(
+                    language_model="language_model",
+                    connector="visual.merger.",
+                    tower_model="visual.",
+                )

-        do_hijack(Qwen3VLForConditionalGeneration, "get_mm_mapping", hijack__get_mm_mapping)
+            setattr(Qwen3VLForConditionalGeneration, "get_mm_mapping", hijack__get_mm_mapping)
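
Note on the new gating: the old code applied the Qwen3-VL patch under a try/except that silently returned when the import failed (and wrapped the adapter-loading body in a try/except whose only handler re-raised). The new code drops both and instead applies the Qwen3-VL patch only when the installed vLLM distribution is exactly release 0.11.0, reading the version via importlib.metadata and comparing through packaging's base_version. A minimal sketch of how that comparison behaves; the version strings are illustrative examples, not from the diff, and only the packaging package is needed to run it:

from packaging import version as vs

# base_version strips pre/post/dev and local segments, so rebuilt or
# post-released 0.11.0 distributions still match, while 0.11.1 does not.
for v in ["0.11.0", "0.11.0.post1", "0.11.0+cu124", "0.11.1"]:
    print(v, vs.parse(v).base_version == "0.11.0")
# prints: True, True, True, False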

verl/workers/rollout/vllm_rollout_spmd.py (1 addition & 2 deletions)

@@ -109,8 +109,7 @@ def __init__(
         if config.limit_images:
             engine_kwargs["limit_mm_per_prompt"] = {"image": config.limit_images}

-        if self.lora_kwargs:
-            VLLMHijack.hijack()
+        VLLMHijack.hijack()

         self.inference_engine = LLM(
             model=model_path,
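
The call-site change drops the lora_kwargs guard: since VLLMHijack.hijack() now gates itself on the installed vLLM version, the rollout applies it unconditionally before constructing the LLM engine. The hijack itself is plain class-attribute replacement via setattr; a toy sketch of that pattern, where FakeManager and patched_load are hypothetical stand-ins rather than verl or vLLM APIs:

class FakeManager:
    def load(self):
        return "original"

def patched_load(self):
    return "patched"

# Replace the method on the class; instances resolve `load` through
# the class, so they all see the patched version.
setattr(FakeManager, "load", patched_load)
print(FakeManager().load())  # prints: patched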
