File tree Expand file tree Collapse file tree
torchrl/modules/llm/backends/vllm Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2+ #
3+ # This source code is licensed under the MIT license found in the
4+ # LICENSE file in the root directory of this source tree.
5+ from __future__ import annotations
6+
7+ import importlib
8+
9+ import pytest
10+
11+ from torchrl .modules .llm .backends .vllm .vllm_plugin import FP32_MODEL_OVERRIDES
12+
13+
@pytest.mark.parametrize("arch", sorted(FP32_MODEL_OVERRIDES))
def test_fp32_override_paths_importable(arch):
    """Check that the override registered for ``arch`` resolves to a real attribute.

    Each entry is a "module.path:ClassName" string that vLLM resolves
    lazily, so a stale path would otherwise only surface at server startup
    when vLLM inspects the architecture. vLLM itself is not needed to run
    this test: ``_models`` falls back to placeholder classes when vLLM is
    absent, which keeps the import path valid.
    """
    target = FP32_MODEL_OVERRIDES[arch]
    # Split "module.path:ClassName" on the first colon.
    mod_name, _sep, attr_name = target.partition(":")
    resolved = importlib.import_module(mod_name)
    # Show the full override string on failure so the stale entry is obvious.
    assert hasattr(resolved, attr_name), target
Original file line number Diff line number Diff line change @@ -1928,7 +1928,7 @@ def make_async_vllm_engine(
19281928 compile (bool, optional): Whether to enable model compilation for better performance. Defaults to True.
19291929 enable_fp32_output (bool, optional): Whether to enable FP32 output for the final layer. Defaults to False.
19301930 This can help with numerical stability for certain models. Requires model-specific support in
1931- torchrl.modules.llm.backends._models.
1931+ torchrl.modules.llm.backends.vllm._models.
19321932 tensor_parallel_size (int, optional): Number of devices to use, per replica. Defaults to None.
19331933 data_parallel_size (int, optional): Number of data parallel groups to use. Defaults to None.
19341934 pipeline_parallel_size (int, optional): Number of pipeline parallel groups to use. Defaults to None.
Original file line number Diff line number Diff line change 77
88from torchrl ._utils import logger
99
# Architecture name -> "module.path:ClassName" overrides registered with vLLM.
# Each path must stay importable; test_vllm_plugin.py guards against drift.
FP32_MODEL_OVERRIDES: dict[str, str] = {
    "Qwen3ForCausalLM": "torchrl.modules.llm.backends.vllm._models:Qwen3ForCausalLMFP32",
}


def register_fp32_overrides() -> None:
    """Register every FP32 model override with vLLM's model registry.

    Walks ``FP32_MODEL_OVERRIDES`` and passes each
    ``architecture -> "module.path:ClassName"`` pair to
    ``ModelRegistry.register_model``, then logs which architectures were
    registered.

    vLLM is imported inside the function body, so importing this module
    does not by itself import vLLM.
    """
    from vllm.model_executor.models.registry import ModelRegistry

    for arch, model_cls_path in FP32_MODEL_OVERRIDES.items():
        ModelRegistry.register_model(arch, model_cls_path)

    # Derive the message from the table so it stays accurate as entries are
    # added, instead of naming a single hard-coded model family.
    logger.info(
        "Registered FP32 model overrides for: %s",
        ", ".join(FP32_MODEL_OVERRIDES),
    )
Original file line number Diff line number Diff line change @@ -402,7 +402,7 @@ def make_vllm_worker(
402402 enforce_eager (bool, optional): Whether to enforce eager execution. Defaults to `False`.
403403 enable_fp32_output (bool, optional): Whether to enable FP32 output for the final layer. Defaults to False.
404404 This can help with numerical stability for certain models. Requires model-specific support in
405- torchrl.modules.llm.backends._models.
405+ torchrl.modules.llm.backends.vllm._models.
406406 **kwargs: Additional arguments passed to vLLM.LLM.__init__.
407407
408408 Returns:
You can’t perform that action at this time.
0 commit comments