Skip to content

Commit fec7bd7

Browse files
committed
Model Health
1 parent 0731ac1 commit fec7bd7

4 files changed

Lines changed: 47 additions & 66 deletions

File tree

src/swiss_ai_model_launch/cli/display/live.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
ModelHealth.WAITING: "[yellow]WAITING[/yellow]",
2323
ModelHealth.HEALTHY: "[green]HEALTHY[/green]",
2424
ModelHealth.ERROR: "[orange]ERROR[/orange]",
25+
ModelHealth.NOT_DEPLOYED: "[dim]NOT DEPLOYED[/dim]",
2526
ModelHealth.NOT_RESPONDING: "[red]NOT RESPONDING[/red]",
2627
}
2728

src/swiss_ai_model_launch/cli/healthcheck/checker.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44

55
_HEALTH_CHECK_URL = "https://api.swissai.svc.cscs.ch/v1/chat/completions"
66
_MESSAGE = {"role": "user", "content": "Say hello."}
7-
_TIMEOUT_SECONDS = 60
7+
_TIMEOUT_SECONDS = 10
88

99

10-
async def check_model_health(served_model_name: str, api_key: str) -> ModelHealth:
10+
async def check_model_health(
11+
served_model_name: str, api_key: str, ever_healthy: bool = False
12+
) -> ModelHealth:
1113
try:
1214
async with httpx.AsyncClient() as client:
1315
response = await client.post(
@@ -23,8 +25,8 @@ async def check_model_health(served_model_name: str, api_key: str) -> ModelHealt
2325
},
2426
timeout=_TIMEOUT_SECONDS,
2527
)
26-
return (
27-
ModelHealth.HEALTHY if response.is_success else ModelHealth.NOT_RESPONDING
28-
)
28+
if response.is_success:
29+
return ModelHealth.HEALTHY
30+
return ModelHealth.NOT_RESPONDING if ever_healthy else ModelHealth.NOT_DEPLOYED
2931
except (httpx.TransportError, httpx.TimeoutException):
3032
return ModelHealth.ERROR

src/swiss_ai_model_launch/cli/healthcheck/model_health.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ class ModelHealth(Enum):
55
WAITING = "WAITING"
66
HEALTHY = "HEALTHY"
77
ERROR = "ERROR"
8+
NOT_DEPLOYED = "NOT_DEPLOYED"
89
NOT_RESPONDING = "NOT_RESPONDING"

src/swiss_ai_model_launch/cli/main.py

Lines changed: 38 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import grp
55
import os
66
import re
7-
from collections.abc import Awaitable, Callable
8-
from typing import cast
7+
from collections.abc import Awaitable, Callable, Coroutine
8+
from typing import Any, cast
99

1010
import firecrest as f7t
1111

@@ -19,6 +19,7 @@
1919
)
2020
from swiss_ai_model_launch.cli.display import DisplayState, LiveDisplay
2121
from swiss_ai_model_launch.cli.healthcheck import check_model_health
22+
from swiss_ai_model_launch.cli.healthcheck.model_health import ModelHealth
2223
from swiss_ai_model_launch.launchers import FirecRESTLauncher, Launcher, SlurmLauncher
2324
from swiss_ai_model_launch.launchers.launch_args import LaunchArgs
2425
from swiss_ai_model_launch.launchers.launch_request import LaunchRequest
@@ -436,17 +437,12 @@ async def _get_frameworks(
436437
)
437438

438439

439-
async def _run_preconfigured(args: argparse.Namespace) -> None:
440-
if not InitConfig.exists():
441-
print("SML is not configured. Run `sml init` first.")
442-
return
443-
444-
config = InitConfig.load()
440+
async def _create_launcher(config: InitConfig, args: argparse.Namespace) -> Launcher:
445441
launcher_type = config.get_non_none_value("launcher")
446442
telemetry_endpoint = config.get_value("telemetry_endpoint")
447443
if launcher_type == "firecrest":
448444
firecrest_client = _get_firecrest_client_from_init_config(config)
449-
launcher = cast(
445+
return cast(
450446
Launcher,
451447
await _get_firecrest_launcher_with_client(
452448
firecrest_client,
@@ -455,7 +451,7 @@ async def _run_preconfigured(args: argparse.Namespace) -> None:
455451
),
456452
)
457453
elif launcher_type == "slurm":
458-
launcher = cast(
454+
return cast(
459455
Launcher,
460456
await _get_slurm_launcher(
461457
telemetry_endpoint=telemetry_endpoint,
@@ -465,22 +461,33 @@ async def _run_preconfigured(args: argparse.Namespace) -> None:
465461
else:
466462
raise NotImplementedError(f"Launcher {launcher_type} is not supported yet.")
467463

468-
cscs_api_key = config.get_non_none_value("cscs_api_key")
469-
launch_request = await _get_launch_request(launcher, args)
470464

465+
async def _run_monitor(
466+
launcher: Launcher,
467+
launch_coro: Coroutine[Any, Any, tuple[int, str]],
468+
cscs_api_key: str,
469+
) -> None:
471470
state = DisplayState()
472471
state.update(cluster=launcher.system_name, partition=launcher.partition)
473472

474473
async def _monitor() -> None:
475-
job_id, served_model_name = await launcher.launch_model(launch_request)
476-
state.update(job_id=job_id, served_model_name=served_model_name)
474+
job_id, served = await launch_coro
475+
state.update(
476+
job_id=job_id,
477+
served_model_name=served,
478+
model_health=ModelHealth.NOT_DEPLOYED,
479+
)
480+
ever_healthy = False
477481
while True:
478482
await asyncio.sleep(5)
479483

480484
job_status = await launcher.get_job_status(job_id)
481485
state.update(job_status=job_status)
482486

483-
model_health = await check_model_health(served_model_name, cscs_api_key)
487+
model_health = await check_model_health(
488+
served, cscs_api_key, ever_healthy=ever_healthy
489+
)
490+
ever_healthy = ever_healthy or model_health == ModelHealth.HEALTHY
484491
state.update(model_health=model_health)
485492

486493
o, e = await launcher.get_job_logs(job_id)
@@ -492,36 +499,27 @@ async def _monitor() -> None:
492499
await launcher.cancel_job(state.job_id)
493500

494501

495-
async def _run_advanced(args: argparse.Namespace) -> None:
502+
async def _run_preconfigured(args: argparse.Namespace) -> None:
496503
if not InitConfig.exists():
497504
print("SML is not configured. Run `sml init` first.")
498505
return
499506

500507
config = InitConfig.load()
501-
launcher_type = config.get_non_none_value("launcher")
502-
telemetry_endpoint = config.get_value("telemetry_endpoint")
503-
if launcher_type == "firecrest":
504-
firecrest_client = _get_firecrest_client_from_init_config(config)
505-
launcher = cast(
506-
Launcher,
507-
await _get_firecrest_launcher_with_client(
508-
firecrest_client,
509-
telemetry_endpoint=telemetry_endpoint,
510-
args=args,
511-
),
512-
)
513-
elif launcher_type == "slurm":
514-
launcher = cast(
515-
Launcher,
516-
await _get_slurm_launcher(
517-
telemetry_endpoint=telemetry_endpoint,
518-
args=args,
519-
),
520-
)
521-
else:
522-
raise NotImplementedError(f"Launcher {launcher_type} is not supported yet.")
508+
launcher = await _create_launcher(config, args)
509+
cscs_api_key = config.get_non_none_value("cscs_api_key")
510+
launch_request = await _get_launch_request(launcher, args)
511+
await _run_monitor(launcher, launcher.launch_model(launch_request), cscs_api_key)
512+
513+
514+
async def _run_advanced(args: argparse.Namespace) -> None:
515+
if not InitConfig.exists():
516+
print("SML is not configured. Run `sml init` first.")
517+
return
523518

519+
config = InitConfig.load()
520+
launcher = await _create_launcher(config, args)
524521
cscs_api_key = config.get_non_none_value("cscs_api_key")
522+
525523
if args.served_model_name:
526524
served_model_name = args.served_model_name
527525
else:
@@ -551,31 +549,10 @@ async def _run_advanced(args: argparse.Namespace) -> None:
551549
use_router=args.use_router,
552550
router_args=args.router_args,
553551
disable_ocf=args.disable_ocf,
554-
telemetry_endpoint=telemetry_endpoint,
552+
telemetry_endpoint=config.get_value("telemetry_endpoint"),
555553
)
556554

557-
state = DisplayState()
558-
state.update(cluster=launcher.system_name, partition=launcher.partition)
559-
560-
async def _monitor() -> None:
561-
job_id, served = await launcher.launch_with_args(launch_args)
562-
state.update(job_id=job_id, served_model_name=served)
563-
while True:
564-
await asyncio.sleep(5)
565-
566-
job_status = await launcher.get_job_status(job_id)
567-
state.update(job_status=job_status)
568-
569-
model_health = await check_model_health(served, cscs_api_key)
570-
state.update(model_health=model_health)
571-
572-
o, e = await launcher.get_job_logs(job_id)
573-
state.set_out_log(o)
574-
state.set_err_log(e)
575-
576-
kill_job = await LiveDisplay(state).run(_monitor())
577-
if kill_job and state.job_id is not None:
578-
await launcher.cancel_job(state.job_id)
555+
await _run_monitor(launcher, launcher.launch_with_args(launch_args), cscs_api_key)
579556

580557

581558
async def _main(args: argparse.Namespace) -> None:

0 commit comments

Comments
 (0)