|
| 1 | +"""CSCS L1 passthrough. |
| 2 | +
|
| 3 | +CSCS already serves a small set of OpenAI-compatible models on their L1 |
| 4 | +endpoint. Instead of launching duplicate pods for them ourselves, we |
| 5 | +forward those model ids to L1 and surface them in /v1/models alongside |
| 6 | +our locally-served models. |
| 7 | +
|
| 8 | +Discovery: we hit L1's own /models endpoint on first use (and every |
| 9 | +30 s thereafter) so the set of L1-routable models tracks whatever L1 |
| 10 | +exposes, without code changes. A small `FALLBACK_MODEL_IDS` list |
| 11 | +backstops the cold-start case when L1 is unreachable on the very first |
| 12 | +fetch, so the model list isn't completely missing the Apertus rows |
| 13 | +during a brief L1 outage. |
| 14 | +
|
| 15 | +Secrets (base URL, API key) come from env via Settings. |
| 16 | +""" |
| 17 | + |
| 18 | +import asyncio |
| 19 | +import time |
| 20 | + |
| 21 | +import aiohttp |
| 22 | + |
| 23 | +from backend.config import get_settings |
| 24 | + |
| 25 | + |
| 26 | +# Cold-start fallback. Used only if we haven't successfully fetched |
| 27 | +# /models from L1 yet AND the current fetch fails. Once we've fetched |
| 28 | +# once successfully, we keep serving the stale cache rather than fall |
| 29 | +# back, so a transient outage never drops a model that *was* there. |
| 30 | +FALLBACK_MODEL_IDS: list[str] = [ |
| 31 | + "Apertus-70B-Instruct-2509", |
| 32 | + "Apertus-8B-Instruct-2509", |
| 33 | +] |
| 34 | + |
| 35 | +# 30 s strikes a balance: short enough that an L1 deployment of a new |
| 36 | +# model is visible within half a minute, long enough that page reloads |
| 37 | +# + completion dispatches don't hammer L1. |
| 38 | +_CACHE_TTL_SECONDS = 30.0 |
| 39 | +# Timeout for the GET /models probe — keep tight so a wedged L1 can't |
| 40 | +# stall /v1/models page loads on our side. |
| 41 | +_FETCH_TIMEOUT_SECONDS = 5.0 |
| 42 | + |
| 43 | +_cache_lock = asyncio.Lock() |
| 44 | +_cache: dict = {"fetched_at": 0.0, "ids": None} |
| 45 | + |
| 46 | + |
| 47 | +def _l1_configured() -> bool: |
| 48 | + s = get_settings() |
| 49 | + return bool(s.cscs_l1_base_url and s.cscs_l1_api_key) |
| 50 | + |
| 51 | + |
| 52 | +def l1_endpoint() -> str: |
| 53 | + """Base URL for L1 OpenAI-compatible API (e.g. https://.../v1). |
| 54 | + Caller appends /chat/completions etc.""" |
| 55 | + return get_settings().cscs_l1_base_url.rstrip("/") |
| 56 | + |
| 57 | + |
| 58 | +def l1_api_key() -> str: |
| 59 | + return get_settings().cscs_l1_api_key |
| 60 | + |
| 61 | + |
| 62 | +def _reset_cache_for_tests() -> None: |
| 63 | + """Test helper — clears the cache so tests can simulate cold start |
| 64 | + without leaking state across cases.""" |
| 65 | + _cache["fetched_at"] = 0.0 |
| 66 | + _cache["ids"] = None |
| 67 | + |
| 68 | + |
| 69 | +async def _fetch_l1_model_ids() -> set[str] | None: |
| 70 | + """GET {base}/models from L1. Returns None on any failure (network, |
| 71 | + non-200, malformed JSON) so the caller can decide whether to keep |
| 72 | + stale cache or fall back.""" |
| 73 | + url = l1_endpoint() + "/models" |
| 74 | + headers = {"Authorization": f"Bearer {l1_api_key()}"} |
| 75 | + try: |
| 76 | + timeout = aiohttp.ClientTimeout(total=_FETCH_TIMEOUT_SECONDS) |
| 77 | + async with aiohttp.ClientSession(timeout=timeout) as session: |
| 78 | + async with session.get(url, headers=headers) as resp: |
| 79 | + if resp.status != 200: |
| 80 | + return None |
| 81 | + data = await resp.json() |
| 82 | + return {m["id"] for m in data.get("data", []) if m.get("id")} |
| 83 | + except Exception: |
| 84 | + return None |
| 85 | + |
| 86 | + |
| 87 | +async def _get_cached_ids() -> set[str]: |
| 88 | + """Return the L1 model id set. Refreshes if TTL has expired; on |
| 89 | + fetch failure keeps stale cache, falling back to FALLBACK_MODEL_IDS |
| 90 | + only at true cold start. Never returns an empty set when L1 is |
| 91 | + configured — a transient L1 outage shouldn't make the Apertus rows |
| 92 | + disappear from the model list.""" |
| 93 | + if not _l1_configured(): |
| 94 | + return set() |
| 95 | + |
| 96 | + now = time.time() |
| 97 | + if _cache["ids"] is not None and (now - _cache["fetched_at"]) < _CACHE_TTL_SECONDS: |
| 98 | + return _cache["ids"] |
| 99 | + |
| 100 | + async with _cache_lock: |
| 101 | + # Another coroutine may have refreshed while we waited on the lock. |
| 102 | + if ( |
| 103 | + _cache["ids"] is not None |
| 104 | + and (time.time() - _cache["fetched_at"]) < _CACHE_TTL_SECONDS |
| 105 | + ): |
| 106 | + return _cache["ids"] |
| 107 | + |
| 108 | + fetched = await _fetch_l1_model_ids() |
| 109 | + if fetched is not None: |
| 110 | + _cache["ids"] = fetched |
| 111 | + _cache["fetched_at"] = time.time() |
| 112 | + return fetched |
| 113 | + |
| 114 | + if _cache["ids"] is not None: |
| 115 | + # Keep serving stale cache; don't update fetched_at so we |
| 116 | + # try again on the next call instead of waiting a full TTL. |
| 117 | + return _cache["ids"] |
| 118 | + |
| 119 | + return set(FALLBACK_MODEL_IDS) |
| 120 | + |
| 121 | + |
| 122 | +async def is_l1_model(model_id: str) -> bool: |
| 123 | + """True only when the model is exposed by L1 AND L1 is configured — |
| 124 | + so an unconfigured deploy doesn't try to proxy to an empty URL. With |
| 125 | + L1 unconfigured, L1 model ids fall through to OpenTela (which 404s |
| 126 | + cleanly) instead of producing an opaque connection error.""" |
| 127 | + if not model_id or not _l1_configured(): |
| 128 | + return False |
| 129 | + ids = await _get_cached_ids() |
| 130 | + return model_id in ids |
| 131 | + |
| 132 | + |
| 133 | +async def get_l1_synthetic_entries(with_details: bool = False) -> list[dict]: |
| 134 | + """Synthesize one peer-style entry per L1 model so they appear in |
| 135 | + /v1/models* alongside OpenTela-served models. Mirrors the shape |
| 136 | + produced by services.model_service.get_all_models — the frontend |
| 137 | + can't tell the difference. |
| 138 | +
|
| 139 | + Returns an empty list when L1 isn't configured: we only advertise |
| 140 | + these models if we can actually serve them. |
| 141 | + """ |
| 142 | + if not _l1_configured(): |
| 143 | + return [] |
| 144 | + |
| 145 | + ids = await _get_cached_ids() |
| 146 | + entries: list[dict] = [] |
| 147 | + for model_id in sorted(ids): |
| 148 | + wg = f"cscs-l1:{model_id}" |
| 149 | + entry = { |
| 150 | + "id": model_id, |
| 151 | + "object": "model", |
| 152 | + "created": "0x", |
| 153 | + "owner": "0x", |
| 154 | + # Empty peer_id/hostname → ModelCard's L1 branch hides the |
| 155 | + # head row anyway; keep them blank rather than synthesise |
| 156 | + # fake values. |
| 157 | + "peer_id": "", |
| 158 | + "hostname": "", |
| 159 | + "otela_version": "", |
| 160 | + "status": "ready", |
| 161 | + "labels": { |
| 162 | + "launched_by": "cscs_L1", |
| 163 | + "framework": "vllm", |
| 164 | + }, |
| 165 | + "worker_group_id": wg, |
| 166 | + "launched_by": "cscs_L1", |
| 167 | + "slurm_job_id": "", |
| 168 | + "framework": "vllm", |
| 169 | + "started_at": "", |
| 170 | + "expires_at": "", |
| 171 | + } |
| 172 | + if with_details: |
| 173 | + entry["device"] = "CSCS L1" |
| 174 | + entries.append(entry) |
| 175 | + return entries |
0 commit comments