Skip to content

Commit e797c19

Browse files
authored
[BugFix] Make the vLLM FP32 plugin opt-in so importing torchrl can't hijack a host vLLM (#3868)
1 parent ad8ea7f commit e797c19

6 files changed

Lines changed: 92 additions & 8 deletions

File tree

pyproject.toml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,10 @@ ignore-words-list = "multicat,nd,splitted,te,uncompressible,dout"
240240
first_party_detection = false
241241

242242
[project.entry-points."vllm.general_plugins"]
243-
# Ensure FP32 overrides are registered in all vLLM processes (main, workers, and
244-
# the registry subprocess) before resolving model classes.
245-
fp32_overrides = "torchrl.modules.llm.backends.vllm.vllm_plugin:register_fp32_overrides"
243+
# Auto-loaded by vLLM in every process (main, workers, and the registry
244+
# subprocess), but a NO-OP unless that process opted in via
245+
# TORCHRL_VLLM_FP32_OVERRIDES (set by torchrl's vLLM backend when
246+
# enable_fp32_output=True). The unique name avoids colliding with other projects'
247+
# vllm.general_plugins entries: vLLM keys discovered plugins by name, so a shared
248+
# name would silently drop one of them.
249+
torchrl_fp32_overrides = "torchrl.modules.llm.backends.vllm.vllm_plugin:register_fp32_overrides"

test/llm/test_vllm_plugin.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88

99
import pytest
1010

11-
from torchrl.modules.llm.backends.vllm.vllm_plugin import FP32_MODEL_OVERRIDES
11+
from torchrl.modules.llm.backends.vllm.vllm_plugin import (
12+
FP32_MODEL_OVERRIDES,
13+
fp32_overrides_enabled,
14+
FP32_OVERRIDES_ENV_VAR,
15+
register_fp32_overrides,
16+
)
1217

1318

1419
@pytest.mark.parametrize("arch", sorted(FP32_MODEL_OVERRIDES))
@@ -23,3 +28,35 @@ def test_fp32_override_paths_importable(arch):
2328
module_path, _, class_name = FP32_MODEL_OVERRIDES[arch].partition(":")
2429
module = importlib.import_module(module_path)
2530
assert hasattr(module, class_name), FP32_MODEL_OVERRIDES[arch]
31+
32+
33+
@pytest.mark.parametrize(
34+
"value,expected",
35+
[
36+
(None, False),
37+
("0", False),
38+
("", False),
39+
("no", False),
40+
("1", True),
41+
("true", True),
42+
("True", True),
43+
("yes", True),
44+
],
45+
)
46+
def test_fp32_overrides_enabled_reads_env(monkeypatch, value, expected):
47+
if value is None:
48+
monkeypatch.delenv(FP32_OVERRIDES_ENV_VAR, raising=False)
49+
else:
50+
monkeypatch.setenv(FP32_OVERRIDES_ENV_VAR, value)
51+
assert fp32_overrides_enabled() is expected
52+
53+
54+
def test_register_fp32_overrides_is_noop_without_optin(monkeypatch):
55+
"""Without the opt-in, registration must do nothing -- and must not even
56+
import vLLM. This is what lets another project install torchrl without its
57+
vLLM ``ModelRegistry`` being mutated. Returning before the vLLM import keeps
58+
the no-op path safe on machines with no vLLM at all.
59+
"""
60+
monkeypatch.delenv(FP32_OVERRIDES_ENV_VAR, raising=False)
61+
# Must not raise even where vLLM is absent (early return precedes the import).
62+
register_fp32_overrides()

test/transforms/test_reward_transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -552,9 +552,9 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
552552
except RuntimeError:
553553
pass
554554

555-
@pytest.mark.parametrize("has_in_keys,", [True, False])
555+
@pytest.mark.parametrize("has_in_keys", [True, False])
556556
@pytest.mark.parametrize(
557-
"reset_keys,", [[("some", "nested", "reset")], ["_reset"] * 3, None]
557+
"reset_keys", [[("some", "nested", "reset")], ["_reset"] * 3, None]
558558
)
559559
def test_trans_multi_key(
560560
self, has_in_keys, reset_keys, n_workers=2, batch_size=(3, 2), max_steps=5

torchrl/modules/llm/backends/vllm/vllm_async.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
# Import RLvLLMEngine and shared utilities
3030
from .base import RLvLLMEngine
31+
from .vllm_plugin import FP32_OVERRIDES_ENV_VAR
3132

3233

3334
_has_vllm = True
@@ -1966,6 +1967,9 @@ def make_async_vllm_engine(
19661967
# Set FP32 output environment variable if requested
19671968
if enable_fp32_output:
19681969
os.environ["VLLM_ENABLE_FP32_OUTPUT"] = "1"
1970+
# Opt the engine + its child vLLM processes into torchrl's FP32 model
1971+
# overrides (the general-plugin no-ops without this).
1972+
os.environ[FP32_OVERRIDES_ENV_VAR] = "1"
19691973
torchrl_logger.info(
19701974
"Enabled FP32 output for vLLM (VLLM_ENABLE_FP32_OUTPUT=1). "
19711975
"This will use FP32 for the final output layer if the model supports it."

torchrl/modules/llm/backends/vllm/vllm_plugin.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,55 @@
55

66
from __future__ import annotations
77

8+
import os
9+
810
from torchrl._utils import logger
911

12+
# Env var that opts a vLLM process into torchrl's FP32 model overrides. torchrl's
13+
# vLLM backend sets it (enable_fp32_output=True) before the engine and its
14+
# subprocesses start, so the overrides register in every vLLM process torchrl
15+
# owns -- and in none that it does not.
16+
FP32_OVERRIDES_ENV_VAR = "TORCHRL_VLLM_FP32_OVERRIDES"
17+
1018
# Architecture name -> "module.path:ClassName" overrides registered with vLLM.
1119
# Each path must stay importable; test_vllm_plugin.py guards against drift.
1220
FP32_MODEL_OVERRIDES: dict[str, str] = {
1321
"Qwen3ForCausalLM": "torchrl.modules.llm.backends.vllm._models:Qwen3ForCausalLMFP32",
1422
}
1523

1624

25+
def fp32_overrides_enabled() -> bool:
26+
"""Whether this process opted into torchrl's vLLM FP32 model overrides."""
27+
return os.environ.get(FP32_OVERRIDES_ENV_VAR, "0").lower() in ("1", "true", "yes")
28+
29+
1730
def register_fp32_overrides() -> None:
18-
"""Register FP32 overrides for vLLM models."""
31+
"""Register torchrl's FP32 vLLM model overrides -- only when opted in.
32+
33+
vLLM auto-loads this through the ``vllm.general_plugins`` entry point in
34+
*every* vLLM process, so it must do nothing unless this process explicitly
35+
asked for torchrl's overrides via ``TORCHRL_VLLM_FP32_OVERRIDES``. Otherwise
36+
merely *installing* torchrl would mutate an unrelated project's vLLM
37+
``ModelRegistry`` -- replacing its model classes with torchrl's, which track
38+
a newer vLLM API and would break an older host vLLM at logits time.
39+
"""
40+
if not fp32_overrides_enabled():
41+
return
42+
1943
from vllm.model_executor.models.registry import ModelRegistry
2044

2145
for arch, model_cls_path in FP32_MODEL_OVERRIDES.items():
2246
ModelRegistry.register_model(arch, model_cls_path)
2347

24-
logger.info("Registered Qwen3 FP32 model overrides")
48+
logger.info("Registered torchrl FP32 vLLM model overrides")
49+
50+
51+
def enable_fp32_overrides() -> None:
52+
"""Opt this process and its child vLLM processes into torchrl's overrides.
53+
54+
Sets ``TORCHRL_VLLM_FP32_OVERRIDES`` so spawned vLLM workers and the registry
55+
subprocess inherit the opt-in, then registers in-process. Call before
56+
constructing a vLLM engine when you want torchrl's FP32 model overrides.
57+
"""
58+
os.environ[FP32_OVERRIDES_ENV_VAR] = "1"
59+
register_fp32_overrides()

torchrl/modules/llm/backends/vllm/vllm_sync.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from torchrl.modules.llm.utils import _cuda_visible_devices
2020

2121
from .base import RLvLLMEngine
22+
from .vllm_plugin import FP32_OVERRIDES_ENV_VAR
2223

2324
try:
2425
from vllm import LLM
@@ -424,6 +425,9 @@ def make_vllm_worker(
424425
# Set FP32 output environment variable if requested
425426
if enable_fp32_output:
426427
os.environ["VLLM_ENABLE_FP32_OUTPUT"] = "1"
428+
# Opt the engine + its child vLLM processes into torchrl's FP32 model
429+
# overrides (the general-plugin no-ops without this).
430+
os.environ[FP32_OVERRIDES_ENV_VAR] = "1"
427431
torchrl_logger.info(
428432
"Enabled FP32 output for vLLM (VLLM_ENABLE_FP32_OUTPUT=1). "
429433
"This will use FP32 for the final output layer if the model supports it."

0 commit comments

Comments
 (0)