Skip to content

Commit ad8ea7f

Browse files
authored
[BugFix] Fix FP32 override registration path in vLLM plugin (#3861)
1 parent 393042c commit ad8ea7f

4 files changed

Lines changed: 35 additions & 8 deletions

File tree

test/llm/test_vllm_plugin.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
6+
7+
import importlib
8+
9+
import pytest
10+
11+
from torchrl.modules.llm.backends.vllm.vllm_plugin import FP32_MODEL_OVERRIDES
12+
13+
14+
@pytest.mark.parametrize("arch", sorted(FP32_MODEL_OVERRIDES))
15+
def test_fp32_override_paths_importable(arch):
16+
"""Every registered override must point at an importable class.
17+
18+
vLLM resolves these "module.path:ClassName" strings lazily, so a stale
19+
path is only discovered at server startup when vLLM inspects the
20+
architecture. This test does not require vLLM: ``_models`` falls back to
21+
placeholder classes when vLLM is absent, keeping the import path valid.
22+
"""
23+
module_path, _, class_name = FP32_MODEL_OVERRIDES[arch].partition(":")
24+
module = importlib.import_module(module_path)
25+
assert hasattr(module, class_name), FP32_MODEL_OVERRIDES[arch]

torchrl/modules/llm/backends/vllm/vllm_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1928,7 +1928,7 @@ def make_async_vllm_engine(
19281928
compile (bool, optional): Whether to enable model compilation for better performance. Defaults to True.
19291929
enable_fp32_output (bool, optional): Whether to enable FP32 output for the final layer. Defaults to False.
19301930
This can help with numerical stability for certain models. Requires model-specific support in
1931-
torchrl.modules.llm.backends._models.
1931+
torchrl.modules.llm.backends.vllm._models.
19321932
tensor_parallel_size (int, optional): Number of devices to use, per replica. Defaults to None.
19331933
data_parallel_size (int, optional): Number of data parallel groups to use. Defaults to None.
19341934
pipeline_parallel_size (int, optional): Number of pipeline parallel groups to use. Defaults to None.

torchrl/modules/llm/backends/vllm/vllm_plugin.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,18 @@
77

88
from torchrl._utils import logger
99

10+
# Architecture name -> "module.path:ClassName" overrides registered with vLLM.
11+
# Each path must stay importable; test/llm/test_vllm_plugin.py guards against drift.
12+
FP32_MODEL_OVERRIDES: dict[str, str] = {
13+
"Qwen3ForCausalLM": "torchrl.modules.llm.backends.vllm._models:Qwen3ForCausalLMFP32",
14+
}
15+
1016

1117
def register_fp32_overrides() -> None:
1218
"""Register FP32 overrides for vLLM models."""
1319
from vllm.model_executor.models.registry import ModelRegistry
1420

15-
# ======= Register models here =======
16-
# Register Qwen3 models with FP32 override
17-
ModelRegistry.register_model(
18-
"Qwen3ForCausalLM",
19-
"torchrl.modules.llm.backends._models:Qwen3ForCausalLMFP32",
20-
)
21+
for arch, model_cls_path in FP32_MODEL_OVERRIDES.items():
22+
ModelRegistry.register_model(arch, model_cls_path)
2123

2224
logger.info("Registered Qwen3 FP32 model overrides")

torchrl/modules/llm/backends/vllm/vllm_sync.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ def make_vllm_worker(
402402
enforce_eager (bool, optional): Whether to enforce eager execution. Defaults to `False`.
403403
enable_fp32_output (bool, optional): Whether to enable FP32 output for the final layer. Defaults to False.
404404
This can help with numerical stability for certain models. Requires model-specific support in
405-
torchrl.modules.llm.backends._models.
405+
torchrl.modules.llm.backends.vllm._models.
406406
**kwargs: Additional arguments passed to vLLM.LLM.__init__.
407407
408408
Returns:

0 commit comments

Comments
 (0)