|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +"""Self-contained wrapper tests for the verl repo. |
| 16 | +
|
| 17 | +All setup (dependency installation, repo cloning, env vars) is handled by |
| 18 | +a session-scoped pytest fixture. Configuration is read from verl_config.yml. |
| 19 | +""" |
| 20 | + |
| 21 | +import os |
| 22 | +import subprocess |
| 23 | +import sys |
| 24 | + |
| 25 | +import pytest |
| 26 | +import yaml |
| 27 | + |
| 28 | +_HERE = os.path.dirname(os.path.abspath(__file__)) |
| 29 | +_CONFIG_PATH = os.path.join(_HERE, "verl_config.yml") |
| 30 | +VERL_ROOT = os.path.join(_HERE, "verl_repo") |
| 31 | + |
| 32 | + |
| 33 | +def _load_config(): |
| 34 | + with open(_CONFIG_PATH) as f: |
| 35 | + return yaml.safe_load(f)["verl_config"] |
| 36 | + |
| 37 | + |
| 38 | +def _export_env_vars(config): |
| 39 | + """Export env vars from config into the current process environment.""" |
| 40 | + for entry in config.get("env_vars", []): |
| 41 | + key, val = entry.split("=", 1) |
| 42 | + val = val.strip('"') |
| 43 | + val = os.path.expandvars(val) |
| 44 | + os.environ[key] = val |
| 45 | + |
| 46 | + |
| 47 | +def _run_install_commands(config): |
| 48 | + """Run install commands from config with env vars already set.""" |
| 49 | + for cmd in config.get("install_commands", []): |
| 50 | + print(f"[verl setup] Running: {cmd}") |
| 51 | + subprocess.check_call(cmd, shell=True) |
| 52 | + |
| 53 | + |
| 54 | +def _clone_verl_repo(config): |
| 55 | + """Clone the verl repo and checkout the specified tag.""" |
| 56 | + if os.path.isdir(VERL_ROOT): |
| 57 | + print(f"[verl setup] Repo already exists at {VERL_ROOT}, skipping clone") |
| 58 | + return |
| 59 | + repo_url = config["repo_url"] |
| 60 | + repo_tag = config["repo_tag"] |
| 61 | + print(f"[verl setup] Cloning {repo_url} (tag={repo_tag}) into {VERL_ROOT}") |
| 62 | + subprocess.check_call( |
| 63 | + f"git clone {repo_url} {VERL_ROOT} && cd {VERL_ROOT} && git checkout {repo_tag}", |
| 64 | + shell=True, |
| 65 | + ) |
| 66 | + assert os.path.isdir(VERL_ROOT), f"Failed to clone verl repo to {VERL_ROOT}" |
| 67 | + print(f"[verl setup] Installing verl package from {VERL_ROOT}") |
| 68 | + subprocess.check_call( |
| 69 | + [sys.executable, "-m", "pip", "install", "-e", VERL_ROOT], |
| 70 | + ) |
| 71 | + |
| 72 | + |
| 73 | +def _setup_model_symlinks(config): |
| 74 | + """Create symlinks from HF-style paths to CI cache paths. |
| 75 | +
|
| 76 | + Verl tests expect models at {model_root}/Qwen/ModelName but the CI cache |
| 77 | + stores them at {ci_cache}/ModelName (flat structure). We create symlinks |
| 78 | + in a writable staging directory that point to the read-only CI cache. |
| 79 | + """ |
| 80 | + model_root = os.environ.get("TRTLLM_TEST_MODEL_PATH_ROOT", "") |
| 81 | + ci_cache = config.get("ci_model_cache", "") |
| 82 | + if not model_root or not ci_cache: |
| 83 | + return |
| 84 | + for model_id in config.get("required_models", []): |
| 85 | + if "/" not in model_id: |
| 86 | + continue |
| 87 | + namespace, name = model_id.split("/", 1) |
| 88 | + ns_dir = os.path.join(model_root, namespace) |
| 89 | + src = os.path.join(ci_cache, name) |
| 90 | + dst = os.path.join(ns_dir, name) |
| 91 | + if os.path.exists(dst): |
| 92 | + print(f"[verl setup] Model symlink already exists: {dst}") |
| 93 | + continue |
| 94 | + if not os.path.isdir(src): |
| 95 | + print(f"[verl setup] Model not found in CI cache: {src}, skipping") |
| 96 | + continue |
| 97 | + os.makedirs(ns_dir, exist_ok=True) |
| 98 | + os.symlink(src, dst) |
| 99 | + print(f"[verl setup] Created symlink: {dst} -> {src}") |
| 100 | + |
| 101 | + |
| 102 | +@pytest.fixture(scope="session", autouse=True) |
| 103 | +def verl_setup(): |
| 104 | + """Session-scoped fixture: install deps, set env vars, clone verl repo.""" |
| 105 | + config = _load_config() |
| 106 | + _export_env_vars(config) |
| 107 | + _run_install_commands(config) |
| 108 | + _clone_verl_repo(config) |
| 109 | + _setup_model_symlinks(config) |
| 110 | + yield VERL_ROOT |
| 111 | + |
| 112 | + |
| 113 | +def _run_verl_test(test_path, extra_args=None, timeout=600): |
| 114 | + """Run a test from the verl repo via subprocess.""" |
| 115 | + full_path = os.path.join(VERL_ROOT, test_path) |
| 116 | + assert os.path.exists(full_path), f"Verl test not found: {full_path}" |
| 117 | + cmd = [sys.executable, "-m", "pytest", full_path, "-v", "--tb=short"] |
| 118 | + if extra_args: |
| 119 | + cmd.extend(extra_args) |
| 120 | + result = subprocess.run( |
| 121 | + cmd, |
| 122 | + cwd=VERL_ROOT, |
| 123 | + env=os.environ.copy(), |
| 124 | + timeout=timeout, |
| 125 | + ) |
| 126 | + assert result.returncode == 0, f"Verl test failed with return code {result.returncode}" |
| 127 | + |
| 128 | + |
| 129 | +def test_async_server(): |
| 130 | + _run_verl_test("tests/workers/rollout/rollout_trtllm/test_async_server.py") |
| 131 | + |
| 132 | + |
| 133 | +def test_adapter(): |
| 134 | + _run_verl_test("tests/workers/rollout/rollout_trtllm/test_adapter.py") |
| 135 | + |
| 136 | + |
| 137 | +def test_rollout_utils(): |
| 138 | + _run_verl_test( |
| 139 | + "tests/workers/rollout/rollout_trtllm/test_trtllm_rollout_utils.py", |
| 140 | + extra_args=[ |
| 141 | + "-k", |
| 142 | + "not (test_unimodal_generate or test_unimodal_batch_generate)", |
| 143 | + ], |
| 144 | + timeout=900, |
| 145 | + ) |
0 commit comments