|
5 | 5 |
|
6 | 6 | from __future__ import annotations |
7 | 7 |
|
| 8 | +import os |
| 9 | + |
8 | 10 | from torchrl._utils import logger |
9 | 11 |
|
# Name of the environment variable through which a vLLM process opts into
# torchrl's FP32 model overrides. torchrl's vLLM backend sets it
# (enable_fp32_output=True) before the engine and its subprocesses start, so
# the overrides register in every vLLM process torchrl owns -- and in none
# that it does not.
FP32_OVERRIDES_ENV_VAR = "TORCHRL_VLLM_FP32_OVERRIDES"
| 17 | + |
# Overrides registered with vLLM, keyed by architecture name; each value is a
# "module.path:ClassName" reference to the replacement model class.
# Every path must remain importable; test_vllm_plugin.py guards against drift.
FP32_MODEL_OVERRIDES: dict[str, str] = {
    "Qwen3ForCausalLM": (
        "torchrl.modules.llm.backends.vllm._models:Qwen3ForCausalLMFP32"
    ),
}
15 | 23 |
|
16 | 24 |
|
def fp32_overrides_enabled() -> bool:
    """Report whether this process opted into torchrl's vLLM FP32 overrides.

    Reads the environment variable named by ``FP32_OVERRIDES_ENV_VAR``. An
    unset variable is treated as ``"0"``. The comparison is case-insensitive.

    Returns:
        ``True`` if the value is one of ``"1"``, ``"true"`` or ``"yes"``;
        ``False`` for anything else.
    """
    raw_flag = os.environ.get(FP32_OVERRIDES_ENV_VAR, "0")
    return raw_flag.lower() in ("1", "true", "yes")
| 28 | + |
| 29 | + |
def register_fp32_overrides() -> None:
    """Register torchrl's FP32 vLLM model overrides, but only after an opt-in.

    vLLM loads this hook through the ``vllm.general_plugins`` entry point in
    *every* vLLM process, so it has to be a no-op unless the current process
    explicitly requested torchrl's overrides via
    ``TORCHRL_VLLM_FP32_OVERRIDES``. Without that gate, merely *installing*
    torchrl would mutate an unrelated project's vLLM ``ModelRegistry`` --
    swapping its model classes for torchrl's, which track a newer vLLM API and
    would break an older host vLLM at logits time.
    """
    if fp32_overrides_enabled():
        # Imported lazily so vLLM is only touched once the opt-in is confirmed.
        from vllm.model_executor.models.registry import ModelRegistry

        for architecture in FP32_MODEL_OVERRIDES:
            ModelRegistry.register_model(
                architecture, FP32_MODEL_OVERRIDES[architecture]
            )

        logger.info("Registered torchrl FP32 vLLM model overrides")
| 49 | + |
| 50 | + |
def enable_fp32_overrides() -> None:
    """Opt the current process, and vLLM processes it spawns, into the overrides.

    Exports ``TORCHRL_VLLM_FP32_OVERRIDES`` so that spawned vLLM workers and
    the registry subprocess inherit the opt-in, then performs the registration
    in this process. Call it before constructing a vLLM engine whenever
    torchrl's FP32 model overrides are wanted.
    """
    # The flag must be set first: register_fp32_overrides() checks it and
    # returns without registering anything when it is absent.
    os.environ[FP32_OVERRIDES_ENV_VAR] = "1"
    register_fp32_overrides()
0 commit comments