Skip to content

Commit 2520d97

Browse files
committed
add L1 routing
1 parent f4b8fa1 commit 2520d97

8 files changed

Lines changed: 516 additions & 26 deletions

File tree

backend/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ class Settings(BaseSettings):
3333
default="",
3434
validation_alias=AliasChoices("otela_fixture_path", "ocf_fixture_path"),
3535
)
36+
# CSCS L1 passthrough — when set, chat/completion requests for the
37+
# hardcoded L1 model list in backend/services/cscs_l1_service.py are
38+
# forwarded here instead of the OpenTela network. Lets us expose
39+
# Apertus 8B/70B from the upstream L1 service without launching our
40+
# own k8s pods. Both must be provided via env in k8s secrets.
41+
cscs_l1_base_url: str = ""
42+
cscs_l1_api_key: str = ""
3643
langfuse_host: str = ""
3744
langfuse_public_key: str = ""
3845
langfuse_secret_key: str = ""

backend/routers/completions.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,22 @@
66
llm_proxy_completions,
77
response_generator,
88
)
9+
from backend.services.cscs_l1_service import is_l1_model, l1_endpoint, l1_api_key
910
from backend.models.protocols import LLMRequest, LLMCompletionsRequest
1011
from backend.config import get_settings
1112

1213
router = APIRouter()
1314
settings = get_settings()
1415

16+
17+
async def _resolve_endpoint_and_key(model: str, user_token: str) -> tuple[str, str]:
18+
"""L1-hosted models go to the upstream L1 endpoint with our shared L1
19+
key; everything else stays on the OpenTela proxy with the user's
20+
bearer token forwarded as-is."""
21+
if await is_l1_model(model):
22+
return l1_endpoint(), l1_api_key()
23+
return settings.otela_head_addr + "/v1/service/llm/v1/", user_token
24+
1525
CHAT_RESERVED_KEYS = [
1626
"model",
1727
"messages",
@@ -74,9 +84,10 @@ async def chat_completion(
7484
user_id=token, opt_out=opt_out, app_title=app_title, **reorg_data
7585
)
7686

87+
endpoint, api_key = await _resolve_endpoint_and_key(llm_request.model, token)
7788
response = await llm_proxy(
78-
endpoint=settings.otela_head_addr + "/v1/service/llm/v1/",
79-
api_key=token,
89+
endpoint=endpoint,
90+
api_key=api_key,
8091
request=llm_request,
8192
)
8293
if "stream" in data and data["stream"]:
@@ -124,9 +135,10 @@ async def completion(
124135
user_id=token, opt_out=opt_out, app_title=app_title, **reorg_data
125136
)
126137

138+
endpoint, api_key = await _resolve_endpoint_and_key(llm_request.model, token)
127139
response = await llm_proxy_completions(
128-
endpoint=settings.otela_head_addr + "/v1/service/llm/v1/",
129-
api_key=token,
140+
endpoint=endpoint,
141+
api_key=api_key,
130142
request=llm_request,
131143
)
132144
if "stream" in data and data["stream"]:

backend/routers/models.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from fastapi import APIRouter
22
from backend.services.model_service import get_all_models
3+
from backend.services.cscs_l1_service import get_l1_synthetic_entries
34
from backend.config import get_settings
45

56
router = APIRouter()
@@ -14,9 +15,21 @@ def _dnt_endpoint() -> str:
1415
return settings.otela_head_addr + "/v1/dnt/table"
1516

1617

18+
async def _with_l1(models: list[dict], with_details: bool) -> list[dict]:
19+
"""Append synthetic L1 entries, skipping ids already present in the
20+
OpenTela result so we don't double-list a model that's still launched
21+
locally during a migration."""
22+
existing = {m["id"] for m in models if m.get("id")}
23+
for entry in await get_l1_synthetic_entries(with_details=with_details):
24+
if entry["id"] not in existing:
25+
models.append(entry)
26+
return models
27+
28+
1729
@router.get("/v1/models_detailed")
1830
async def list_models_detailed():
1931
models = get_all_models(_dnt_endpoint(), with_details=True)
32+
models = await _with_l1(models, with_details=True)
2033
return dict(
2134
object="list",
2235
data=models,
@@ -26,6 +39,7 @@ async def list_models_detailed():
2639
@router.get("/v1/models")
2740
async def list_models():
2841
models = get_all_models(_dnt_endpoint(), with_details=False)
42+
models = await _with_l1(models, with_details=False)
2943
return dict(
3044
object="list",
3145
data=models,

backend/routers/responses.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from fastapi.responses import StreamingResponse
33
from backend.middleware.auth import require_auth
44
from backend.services.llm_service import llm_proxy_responses, response_generator_raw
5+
from backend.services.cscs_l1_service import is_l1_model, l1_endpoint, l1_api_key
56
from backend.config import get_settings
67

78
router = APIRouter()
@@ -15,13 +16,19 @@ async def create_response(
1516
):
1617
data = await request.json()
1718
stream = data.get("stream", False)
19+
model = data.get("model", "unknown")
20+
21+
if await is_l1_model(model):
22+
endpoint, api_key = l1_endpoint(), l1_api_key()
23+
else:
24+
endpoint, api_key = settings.otela_head_addr + "/v1/service/llm/v1/", token
1825

1926
response = await llm_proxy_responses(
20-
endpoint=settings.otela_head_addr + "/v1/service/llm/v1/",
21-
api_key=token,
27+
endpoint=endpoint,
28+
api_key=api_key,
2229
payload=data,
2330
stream=stream,
24-
model=data.get("model", "unknown"),
31+
model=model,
2532
)
2633

2734
if stream:
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
"""CSCS L1 passthrough.
2+
3+
CSCS already serves a small set of OpenAI-compatible models on their L1
4+
endpoint. Instead of launching duplicate pods for them ourselves, we
5+
forward those model ids to L1 and surface them in /v1/models alongside
6+
our locally-served models.
7+
8+
Discovery: we hit L1's own /models endpoint on first use (and every
9+
30 s thereafter) so the set of L1-routable models tracks whatever L1
10+
exposes, without code changes. A small `FALLBACK_MODEL_IDS` list
11+
backstops the cold-start case when L1 is unreachable on the very first
12+
fetch, so the model list isn't completely missing the Apertus rows
13+
during a brief L1 outage.
14+
15+
Secrets (base URL, API key) come from env via Settings.
16+
"""
17+
18+
import asyncio
19+
import time
20+
21+
import aiohttp
22+
23+
from backend.config import get_settings
24+
25+
26+
# Cold-start fallback. Used only if we haven't successfully fetched
27+
# /models from L1 yet AND the current fetch fails. Once we've fetched
28+
# once successfully, we keep serving the stale cache rather than fall
29+
# back, so a transient outage never drops a model that *was* there.
30+
FALLBACK_MODEL_IDS: list[str] = [
31+
"Apertus-70B-Instruct-2509",
32+
"Apertus-8B-Instruct-2509",
33+
]
34+
35+
# 30 s strikes a balance: short enough that an L1 deployment of a new
36+
# model is visible within half a minute, long enough that page reloads
37+
# + completion dispatches don't hammer L1.
38+
_CACHE_TTL_SECONDS = 30.0
39+
# Timeout for the GET /models probe — keep tight so a wedged L1 can't
40+
# stall /v1/models page loads on our side.
41+
_FETCH_TIMEOUT_SECONDS = 5.0
42+
43+
_cache_lock = asyncio.Lock()
44+
_cache: dict = {"fetched_at": 0.0, "ids": None}
45+
46+
47+
def _l1_configured() -> bool:
48+
s = get_settings()
49+
return bool(s.cscs_l1_base_url and s.cscs_l1_api_key)
50+
51+
52+
def l1_endpoint() -> str:
53+
"""Base URL for L1 OpenAI-compatible API (e.g. https://.../v1).
54+
Caller appends /chat/completions etc."""
55+
return get_settings().cscs_l1_base_url.rstrip("/")
56+
57+
58+
def l1_api_key() -> str:
59+
return get_settings().cscs_l1_api_key
60+
61+
62+
def _reset_cache_for_tests() -> None:
63+
"""Test helper — clears the cache so tests can simulate cold start
64+
without leaking state across cases."""
65+
_cache["fetched_at"] = 0.0
66+
_cache["ids"] = None
67+
68+
69+
async def _fetch_l1_model_ids() -> set[str] | None:
70+
"""GET {base}/models from L1. Returns None on any failure (network,
71+
non-200, malformed JSON) so the caller can decide whether to keep
72+
stale cache or fall back."""
73+
url = l1_endpoint() + "/models"
74+
headers = {"Authorization": f"Bearer {l1_api_key()}"}
75+
try:
76+
timeout = aiohttp.ClientTimeout(total=_FETCH_TIMEOUT_SECONDS)
77+
async with aiohttp.ClientSession(timeout=timeout) as session:
78+
async with session.get(url, headers=headers) as resp:
79+
if resp.status != 200:
80+
return None
81+
data = await resp.json()
82+
return {m["id"] for m in data.get("data", []) if m.get("id")}
83+
except Exception:
84+
return None
85+
86+
87+
async def _get_cached_ids() -> set[str]:
88+
"""Return the L1 model id set. Refreshes if TTL has expired; on
89+
fetch failure keeps stale cache, falling back to FALLBACK_MODEL_IDS
90+
only at true cold start. Never returns an empty set when L1 is
91+
configured — a transient L1 outage shouldn't make the Apertus rows
92+
disappear from the model list."""
93+
if not _l1_configured():
94+
return set()
95+
96+
now = time.time()
97+
if _cache["ids"] is not None and (now - _cache["fetched_at"]) < _CACHE_TTL_SECONDS:
98+
return _cache["ids"]
99+
100+
async with _cache_lock:
101+
# Another coroutine may have refreshed while we waited on the lock.
102+
if (
103+
_cache["ids"] is not None
104+
and (time.time() - _cache["fetched_at"]) < _CACHE_TTL_SECONDS
105+
):
106+
return _cache["ids"]
107+
108+
fetched = await _fetch_l1_model_ids()
109+
if fetched is not None:
110+
_cache["ids"] = fetched
111+
_cache["fetched_at"] = time.time()
112+
return fetched
113+
114+
if _cache["ids"] is not None:
115+
# Keep serving stale cache; don't update fetched_at so we
116+
# try again on the next call instead of waiting a full TTL.
117+
return _cache["ids"]
118+
119+
return set(FALLBACK_MODEL_IDS)
120+
121+
122+
async def is_l1_model(model_id: str) -> bool:
123+
"""True only when the model is exposed by L1 AND L1 is configured —
124+
so an unconfigured deploy doesn't try to proxy to an empty URL. With
125+
L1 unconfigured, L1 model ids fall through to OpenTela (which 404s
126+
cleanly) instead of producing an opaque connection error."""
127+
if not model_id or not _l1_configured():
128+
return False
129+
ids = await _get_cached_ids()
130+
return model_id in ids
131+
132+
133+
async def get_l1_synthetic_entries(with_details: bool = False) -> list[dict]:
134+
"""Synthesize one peer-style entry per L1 model so they appear in
135+
/v1/models* alongside OpenTela-served models. Mirrors the shape
136+
produced by services.model_service.get_all_models — the frontend
137+
can't tell the difference.
138+
139+
Returns an empty list when L1 isn't configured: we only advertise
140+
these models if we can actually serve them.
141+
"""
142+
if not _l1_configured():
143+
return []
144+
145+
ids = await _get_cached_ids()
146+
entries: list[dict] = []
147+
for model_id in sorted(ids):
148+
wg = f"cscs-l1:{model_id}"
149+
entry = {
150+
"id": model_id,
151+
"object": "model",
152+
"created": "0x",
153+
"owner": "0x",
154+
# Empty peer_id/hostname → ModelCard's L1 branch hides the
155+
# head row anyway; keep them blank rather than synthesise
156+
# fake values.
157+
"peer_id": "",
158+
"hostname": "",
159+
"otela_version": "",
160+
"status": "ready",
161+
"labels": {
162+
"launched_by": "cscs_L1",
163+
"framework": "vllm",
164+
},
165+
"worker_group_id": wg,
166+
"launched_by": "cscs_L1",
167+
"slurm_job_id": "",
168+
"framework": "vllm",
169+
"started_at": "",
170+
"expires_at": "",
171+
}
172+
if with_details:
173+
entry["device"] = "CSCS L1"
174+
entries.append(entry)
175+
return entries

0 commit comments

Comments
 (0)