Skip to content

Commit f8fe1f2

Browse files
committed
Refactor Heath Checkers
1 parent ee6a215 commit f8fe1f2

7 files changed

Lines changed: 51 additions & 52 deletions

File tree

src/swiss_ai_model_launch/cli/display/live.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
}
2020

2121
_MODEL_HEALTH_STYLE: dict[ModelHealth, str] = {
22-
ModelHealth.WAITING: "[yellow]WAITING[/yellow]",
2322
ModelHealth.HEALTHY: "[green]HEALTHY[/green]",
2423
ModelHealth.ERROR: "[orange]ERROR[/orange]",
2524
ModelHealth.NOT_DEPLOYED: "[dim]NOT DEPLOYED[/dim]",

src/swiss_ai_model_launch/cli/display/state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def __init__(self) -> None:
1111
self.partition: str | None = None
1212
self.job_id: int | None = None
1313
self.job_status: JobStatus | None = None
14-
self.model_health: ModelHealth = ModelHealth.WAITING
14+
self.model_health: ModelHealth = ModelHealth.NOT_DEPLOYED
1515
self.served_model_name: str | None = None
1616
self.out_logs: deque[str] = deque()
1717
self.err_logs: deque[str] = deque()

src/swiss_ai_model_launch/cli/healthcheck/checker.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
_TIMEOUT_SECONDS = 10
88

99

10-
async def check_model_health(
11-
served_model_name: str, api_key: str, ever_healthy: bool = False
12-
) -> ModelHealth:
10+
async def check_model_health(model_name: str, api_key: str) -> ModelHealth:
1311
try:
1412
async with httpx.AsyncClient() as client:
1513
response = await client.post(
@@ -19,14 +17,14 @@ async def check_model_health(
1917
"Authorization": f"Bearer {api_key}",
2018
},
2119
json={
22-
"model": served_model_name,
20+
"model": model_name,
2321
"messages": [_MESSAGE],
2422
"stream": False,
2523
},
2624
timeout=_TIMEOUT_SECONDS,
2725
)
2826
if response.is_success:
2927
return ModelHealth.HEALTHY
30-
return ModelHealth.NOT_RESPONDING if ever_healthy else ModelHealth.NOT_DEPLOYED
28+
return ModelHealth.NOT_RESPONDING
3129
except (httpx.TransportError, httpx.TimeoutException):
3230
return ModelHealth.ERROR

src/swiss_ai_model_launch/cli/healthcheck/model_health.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33

44
class ModelHealth(Enum):
5-
WAITING = "WAITING"
65
HEALTHY = "HEALTHY"
76
ERROR = "ERROR"
87
NOT_DEPLOYED = "NOT_DEPLOYED"

src/swiss_ai_model_launch/cli/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -509,9 +509,9 @@ async def _monitor() -> None:
509509
job_status = await launcher.get_job_status(job_id)
510510
state.update(job_status=job_status)
511511

512-
model_health = await check_model_health(
513-
served, cscs_api_key, ever_healthy=ever_healthy
514-
)
512+
model_health = await check_model_health(served, cscs_api_key)
513+
if model_health == ModelHealth.NOT_RESPONDING and not ever_healthy:
514+
model_health = ModelHealth.NOT_DEPLOYED
515515
ever_healthy = ever_healthy or model_health == ModelHealth.HEALTHY
516516
state.update(model_health=model_health)
517517

tests/integration/test_firecrest_launcher.py

Lines changed: 3 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1-
import asyncio
21
import importlib.resources
32
import json
43
import os
54

65
import firecrest as f7t
76
import pytest
87

9-
from swiss_ai_model_launch.cli.healthcheck import ModelHealth, check_model_health
108
from swiss_ai_model_launch.launchers.firecrest_launcher import FirecRESTLauncher
119
from swiss_ai_model_launch.launchers.launch_request import LaunchRequest
12-
from swiss_ai_model_launch.launchers.launcher import JobStatus
10+
from tests.integration.utils import wait_for_job_running, wait_for_model_healthy
1311

1412
_LAUNCH_TIMEOUT = 60
1513
_HEALTH_TIMEOUT = 120
@@ -74,42 +72,6 @@ def cscs_api_key(env: dict[str, str]) -> str:
7472
return env["SML_CSCS_API_KEY"]
7573

7674

77-
async def _wait_for_job_running(
78-
launcher: FirecRESTLauncher,
79-
job_id: int,
80-
timeout_min: int,
81-
poll_interval_seconds: int = 15,
82-
) -> None:
83-
deadline = asyncio.get_event_loop().time() + timeout_min * 60
84-
while asyncio.get_event_loop().time() < deadline:
85-
await asyncio.sleep(poll_interval_seconds)
86-
status = await launcher.get_job_status(job_id)
87-
print(f"[job {job_id}] status: {status.value}")
88-
if status == JobStatus.RUNNING:
89-
return
90-
if status == JobStatus.TIMEOUT:
91-
pytest.fail(f"Job {job_id} timed out before becoming RUNNING.")
92-
pytest.fail(f"Job {job_id} didn't reach RUNNING within {timeout_min} mins.")
93-
94-
95-
async def _wait_for_model_healthy(
96-
served_model_name: str,
97-
api_key: str,
98-
timeout_min: int,
99-
poll_interval_seconds: int = 30,
100-
) -> None:
101-
deadline = asyncio.get_event_loop().time() + timeout_min * 60
102-
while asyncio.get_event_loop().time() < deadline:
103-
await asyncio.sleep(poll_interval_seconds)
104-
health = await check_model_health(served_model_name, api_key)
105-
print(f"[{served_model_name}] health: {health.value}")
106-
if health == ModelHealth.HEALTHY:
107-
return
108-
pytest.fail(
109-
f"Model '{served_model_name}' didn't become HEALTHY within {timeout_min} mins."
110-
)
111-
112-
11375
@pytest.mark.parametrize("launch_request", _LAUNCH_REQUESTS) # type: ignore[misc]
11476
async def test_launch_apertus_and_health(
11577
launcher: FirecRESTLauncher,
@@ -123,7 +85,7 @@ async def test_launch_apertus_and_health(
12385
assert served_model_name
12486

12587
try:
126-
await _wait_for_job_running(launcher, job_id, _LAUNCH_TIMEOUT)
127-
await _wait_for_model_healthy(served_model_name, cscs_api_key, _HEALTH_TIMEOUT)
88+
await wait_for_job_running(launcher, job_id, _LAUNCH_TIMEOUT)
89+
await wait_for_model_healthy(served_model_name, cscs_api_key, _HEALTH_TIMEOUT)
12890
finally:
12991
await launcher.cancel_job(job_id)

tests/integration/utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import asyncio
2+
3+
import pytest
4+
5+
from swiss_ai_model_launch.cli.healthcheck import ModelHealth, check_model_health
6+
from swiss_ai_model_launch.launchers.firecrest_launcher import FirecRESTLauncher
7+
from swiss_ai_model_launch.launchers.launcher import JobStatus
8+
9+
10+
async def wait_for_job_running(
11+
launcher: FirecRESTLauncher,
12+
job_id: int,
13+
timeout_min: int,
14+
poll_interval_seconds: int = 15,
15+
) -> None:
16+
deadline = asyncio.get_event_loop().time() + timeout_min * 60
17+
while asyncio.get_event_loop().time() < deadline:
18+
await asyncio.sleep(poll_interval_seconds)
19+
status = await launcher.get_job_status(job_id)
20+
print(f"[job {job_id}] status: {status.value}")
21+
if status == JobStatus.RUNNING:
22+
return
23+
if status == JobStatus.TIMEOUT:
24+
pytest.fail(f"Job {job_id} timed out before becoming RUNNING.")
25+
pytest.fail(f"Job {job_id} didn't reach RUNNING within {timeout_min} mins.")
26+
27+
28+
async def wait_for_model_healthy(
29+
model_name: str,
30+
api_key: str,
31+
timeout_min: int,
32+
poll_interval_seconds: int = 30,
33+
) -> None:
34+
deadline = asyncio.get_event_loop().time() + timeout_min * 60
35+
while asyncio.get_event_loop().time() < deadline:
36+
await asyncio.sleep(poll_interval_seconds)
37+
health = await check_model_health(model_name, api_key)
38+
print(f"[{model_name}] health: {health.value}")
39+
if health == ModelHealth.HEALTHY:
40+
return
41+
pytest.fail(f"'{model_name}' didn't become HEALTHY within {timeout_min} mins.")

0 commit comments

Comments
 (0)