|
| 1 | +"""Pytest configuration and fixtures for catalog tests. |
| 2 | +
|
| 3 | +This module follows the model-registry Python client pattern: |
| 4 | +- Assumes catalog service is already running (K8s, local, etc.) |
| 5 | +- Uses environment variables for configuration |
| 6 | +""" |
| 7 | + |
| 8 | +import logging |
| 9 | +import os |
| 10 | +import time |
| 11 | +from collections.abc import Generator |
| 12 | +from pathlib import Path |
| 13 | + |
| 14 | +import pytest |
| 15 | +import requests |
| 16 | + |
| 17 | +from model_catalog import CatalogAPIClient |
| 18 | + |
| 19 | +from .constants import ( |
| 20 | + API_BASE_PATH, |
| 21 | + CATALOG_URL, |
| 22 | + CLIENT_TIMEOUT, |
| 23 | + MAX_BACKOFF, |
| 24 | + MAX_POLL_TIME, |
| 25 | + POLL_INTERVAL, |
| 26 | + get_verify_ssl, |
| 27 | +) |
| 28 | + |
# Module-wide logging: timestamped records at WARNING level by default.
logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s.%(msecs)03d - %(name)s:%(levelname)s: %(message)s",
    datefmt="%H:%M:%S",
)

# Shared logger for all catalog test helpers in this module.
logger = logging.getLogger("model-catalog")
| 37 | + |
| 38 | + |
def pytest_addoption(parser):
    """Register the test-selection command line flags for this suite."""
    for flag, description in (
        ("--e2e", "run end-to-end tests"),
        ("--fuzz", "run fuzzing tests"),
    ):
        parser.addoption(flag, action="store_true", help=description)
| 43 | + |
| 44 | + |
def pytest_configure(config):
    """Declare the custom markers this suite applies to collected tests."""
    marker_definitions = (
        "e2e: mark test as end-to-end test",
        "fuzz: mark test as fuzzing test",
        "huggingface: mark test as requiring HuggingFace API",
    )
    for definition in marker_definitions:
        config.addinivalue_line("markers", definition)
| 50 | + |
| 51 | + |
def _auto_mark_test(item) -> None:
    """Attach an e2e or fuzz marker to *item* based on its file location."""
    location = str(item.fspath)
    # fuzz_api paths win over the generic "tests" match.
    if "fuzz_api" in location:
        item.add_marker(pytest.mark.fuzz)
        return
    if "tests" in location:
        item.add_marker(pytest.mark.e2e)
| 59 | + |
| 60 | + |
def _apply_skip_markers(item, *, e2e: bool, fuzz: bool) -> None:
    """Skip *item* when it falls outside the category selected via CLI flags."""
    not_selected = pytest.mark.skip(reason="skipping non-selected tests")

    if e2e:
        # --e2e runs only e2e-marked tests; everything else is skipped.
        if "e2e" not in item.keywords:
            item.add_marker(not_selected)
        return
    if fuzz:
        # --fuzz runs only fuzz-marked tests.
        if "fuzz" not in item.keywords:
            item.add_marker(not_selected)
        return
    # Neither flag given: both categories are opt-in, so skip each with
    # a message naming the flag that would enable it.
    if "e2e" in item.keywords:
        item.add_marker(pytest.mark.skip(reason="need --e2e option to run E2E tests"))
    if "fuzz" in item.keywords:
        item.add_marker(pytest.mark.skip(reason="need --fuzz option to run fuzzing tests"))
| 79 | + |
| 80 | + |
def pytest_collection_modifyitems(config, items):
    """Auto-mark collected tests, then skip those outside the selected set."""
    run_e2e = config.getoption("--e2e")
    run_fuzz = config.getoption("--fuzz")

    for test_item in items:
        _auto_mark_test(test_item)
        _apply_skip_markers(test_item, e2e=run_e2e, fuzz=run_fuzz)
| 89 | + |
| 90 | + |
def pytest_report_teststatus(report, config):
    """Print a colorized per-test status line unless --quiet is set."""
    if config.getoption("--quiet", default=False):
        return

    test_name = report.head_line
    # Passes are only reported for the call phase to avoid duplicate lines
    # for setup/teardown; failures outside the call phase are shown as ERROR.
    if report.passed and report.when == "call":
        print(f"\nTEST: {test_name} STATUS: \033[0;32mPASSED\033[0m")
    elif report.skipped:
        print(f"\nTEST: {test_name} STATUS: \033[1;33mSKIPPED\033[0m")
    elif report.failed:
        if report.when == "call":
            print(f"\nTEST: {test_name} STATUS: \033[0;31mFAILED\033[0m")
        else:
            print(f"\nTEST: {test_name} [{report.when}] STATUS: \033[0;31mERROR\033[0m")
| 107 | + |
| 108 | + |
# Maximum number of parent directories to climb when searching for the repo
# root; bounds the upward walk so a missing .git cannot cause an endless loop.
_MAX_PARENT_LEVELS = 10
| 111 | + |
| 112 | + |
@pytest.fixture(scope="session")
def root(request) -> Path:
    """Locate the repository root directory.

    Walks upward from the pytest rootpath, checking at most
    _MAX_PARENT_LEVELS ancestors for a directory containing .git
    (the repo-root marker).

    Raises:
        RuntimeError: If no .git directory is found within the search limit.
    """
    candidate = request.config.rootpath
    for _level in range(_MAX_PARENT_LEVELS):
        if (candidate / ".git").exists():
            return candidate
        candidate = candidate.parent

    # Exhausted the search budget without finding the marker — fail loudly.
    msg = (
        f"Could not find repository root (.git directory) starting from "
        f"{request.config.rootpath}. Searched {_MAX_PARENT_LEVELS} levels up."
    )
    raise RuntimeError(msg)
| 135 | + |
| 136 | + |
@pytest.fixture(scope="session")
def user_token() -> str | None:
    """Read the optional auth token from the AUTH_TOKEN environment variable."""
    return os.environ.get("AUTH_TOKEN")
| 141 | + |
| 142 | + |
@pytest.fixture(scope="session")
def request_headers(user_token: str | None) -> dict[str, str]:
    """Build default JSON request headers, adding a bearer token when set."""
    if user_token:
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {user_token}",
        }
    return {"Content-Type": "application/json"}
| 150 | + |
| 151 | + |
@pytest.fixture(scope="session")
def verify_ssl() -> bool:
    """Get the SSL verification setting from the environment.

    Delegates to the shared ``get_verify_ssl`` helper from ``.constants``,
    passing this module's logger so the helper can report what it resolved.
    """
    return get_verify_ssl(logger)
| 156 | + |
| 157 | + |
def poll_for_ready(user_token: str | None, verify_ssl: bool) -> None:
    """Wait for the catalog service to be ready using exponential backoff.

    Polls ``GET {CATALOG_URL}{API_BASE_PATH}/sources`` until any non-5xx
    response is received, doubling the wait between attempts (capped at
    MAX_BACKOFF seconds).

    Args:
        user_token: Optional auth token.
        verify_ssl: Whether to verify SSL certificates.

    Raises:
        TimeoutError: If the service is not ready within MAX_POLL_TIME seconds.
    """
    url = f"{CATALOG_URL}{API_BASE_PATH}/sources"
    headers = {"Authorization": f"Bearer {user_token}"} if user_token else None

    # Exponential backoff: start at POLL_INTERVAL, double each time, cap at MAX_BACKOFF
    backoff = POLL_INTERVAL
    poll_start = time.time()

    while True:
        elapsed_time = time.time() - poll_start
        if elapsed_time >= MAX_POLL_TIME:
            msg = f"Catalog service not ready after {int(elapsed_time)}s at {url}"
            logger.error(msg)
            raise TimeoutError(msg)
        logger.info("Attempting to connect to server %s", url)
        try:
            response = requests.get(url, headers=headers, verify=verify_ssl, timeout=MAX_BACKOFF)
            if response.status_code < 500:  # Accept any non-5xx response
                logger.info("Server is up!")
                return
        # Bug fix: requests.get is given a timeout, so it can raise Timeout as
        # well as ConnectionError. Previously a Timeout escaped the loop and
        # aborted the readiness wait; treat both as "still starting" and retry.
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            pass

        time.sleep(backoff)
        backoff = min(backoff * 2, MAX_BACKOFF)  # Exponential backoff with cap
| 189 | + |
| 190 | + |
@pytest.fixture(scope="session")
def api_client(user_token: str | None, verify_ssl: bool) -> Generator[CatalogAPIClient, None, None]:
    """Yield a session-scoped API client for the catalog service.

    Connects to the already-running catalog service at CATALOG_URL after
    blocking until the service answers. The request timeout comes from
    CLIENT_TIMEOUT (CATALOG_CLIENT_TIMEOUT env var, default 30s).
    """
    # Block until the service responds before handing out a client.
    poll_for_ready(user_token=user_token, verify_ssl=verify_ssl)
    client = CatalogAPIClient(CATALOG_URL, timeout=CLIENT_TIMEOUT, verify_ssl=verify_ssl)
    with client as active_client:
        yield active_client
| 203 | + |
| 204 | + |
@pytest.fixture(scope="session")
def model_with_artifacts(api_client: CatalogAPIClient) -> tuple[str, str]:
    """Find a (source_id, model_name) pair suitable for artifact tests.

    Prefers the first listed model whose artifact listing is non-empty;
    otherwise falls back to the first model, provided it carries both
    identifiers.

    Returns:
        Tuple of (source_id, model_name) for a model with artifacts.

    Raises:
        pytest.fail: If no models are available or the fallback model is
            missing its source_id or name.
    """
    listing = api_client.get_models()
    available = listing.get("items")
    if not available:
        pytest.fail("No models available - test data may not be loaded")

    # First pass: prefer a model that actually has artifacts.
    for candidate in available:
        sid = candidate.get("source_id")
        name = candidate.get("name")
        if not (sid and name):
            continue
        if api_client.get_artifacts(source_id=sid, model_name=name).get("items"):
            return sid, name

    # No model had artifacts; settle for the first one with the required fields.
    fallback = available[0]
    sid = fallback.get("source_id")
    name = fallback.get("name")
    if not (sid and name):
        pytest.fail("Model missing source_id or name - test data may be malformed")

    return sid, name
| 243 | + |
| 244 | + |
@pytest.fixture(scope="session")
def testdata_dir(root) -> Path:
    """Path to the repository-level test/testdata directory."""
    return root.joinpath("test", "testdata")
| 249 | + |
| 250 | + |
@pytest.fixture(scope="session")
def local_testdata_dir() -> Path:
    """Path to the testdata directory that sits beside this conftest module."""
    return Path(__file__).with_name("testdata")
0 commit comments