44import grp
55import os
66import re
7- from collections .abc import Awaitable , Callable
8- from typing import cast
7+ from collections .abc import Awaitable , Callable , Coroutine
8+ from typing import Any , cast
99
1010import firecrest as f7t
1111
1919)
2020from swiss_ai_model_launch .cli .display import DisplayState , LiveDisplay
2121from swiss_ai_model_launch .cli .healthcheck import check_model_health
22+ from swiss_ai_model_launch .cli .healthcheck .model_health import ModelHealth
2223from swiss_ai_model_launch .launchers import FirecRESTLauncher , Launcher , SlurmLauncher
2324from swiss_ai_model_launch .launchers .launch_args import LaunchArgs
2425from swiss_ai_model_launch .launchers .launch_request import LaunchRequest
@@ -436,17 +437,12 @@ async def _get_frameworks(
436437 )
437438
438439
439- async def _run_preconfigured (args : argparse .Namespace ) -> None :
440- if not InitConfig .exists ():
441- print ("SML is not configured. Run `sml init` first." )
442- return
443-
444- config = InitConfig .load ()
440+ async def _create_launcher (config : InitConfig , args : argparse .Namespace ) -> Launcher :
445441 launcher_type = config .get_non_none_value ("launcher" )
446442 telemetry_endpoint = config .get_value ("telemetry_endpoint" )
447443 if launcher_type == "firecrest" :
448444 firecrest_client = _get_firecrest_client_from_init_config (config )
449- launcher = cast (
445+ return cast (
450446 Launcher ,
451447 await _get_firecrest_launcher_with_client (
452448 firecrest_client ,
@@ -455,7 +451,7 @@ async def _run_preconfigured(args: argparse.Namespace) -> None:
455451 ),
456452 )
457453 elif launcher_type == "slurm" :
458- launcher = cast (
454+ return cast (
459455 Launcher ,
460456 await _get_slurm_launcher (
461457 telemetry_endpoint = telemetry_endpoint ,
@@ -465,22 +461,33 @@ async def _run_preconfigured(args: argparse.Namespace) -> None:
465461 else :
466462 raise NotImplementedError (f"Launcher { launcher_type } is not supported yet." )
467463
468- cscs_api_key = config .get_non_none_value ("cscs_api_key" )
469- launch_request = await _get_launch_request (launcher , args )
470464
465+ async def _run_monitor (
466+ launcher : Launcher ,
467+ launch_coro : Coroutine [Any , Any , tuple [int , str ]],
468+ cscs_api_key : str ,
469+ ) -> None :
471470 state = DisplayState ()
472471 state .update (cluster = launcher .system_name , partition = launcher .partition )
473472
474473 async def _monitor () -> None :
475- job_id , served_model_name = await launcher .launch_model (launch_request )
476- state .update (job_id = job_id , served_model_name = served_model_name )
474+ job_id , served = await launch_coro
475+ state .update (
476+ job_id = job_id ,
477+ served_model_name = served ,
478+ model_health = ModelHealth .NOT_DEPLOYED ,
479+ )
480+ ever_healthy = False
477481 while True :
478482 await asyncio .sleep (5 )
479483
480484 job_status = await launcher .get_job_status (job_id )
481485 state .update (job_status = job_status )
482486
483- model_health = await check_model_health (served_model_name , cscs_api_key )
487+ model_health = await check_model_health (
488+ served , cscs_api_key , ever_healthy = ever_healthy
489+ )
490+ ever_healthy = ever_healthy or model_health == ModelHealth .HEALTHY
484491 state .update (model_health = model_health )
485492
486493 o , e = await launcher .get_job_logs (job_id )
@@ -492,36 +499,27 @@ async def _monitor() -> None:
492499 await launcher .cancel_job (state .job_id )
493500
494501
495- async def _run_advanced (args : argparse .Namespace ) -> None :
502+ async def _run_preconfigured (args : argparse .Namespace ) -> None :
496503 if not InitConfig .exists ():
497504 print ("SML is not configured. Run `sml init` first." )
498505 return
499506
500507 config = InitConfig .load ()
501- launcher_type = config .get_non_none_value ("launcher" )
502- telemetry_endpoint = config .get_value ("telemetry_endpoint" )
503- if launcher_type == "firecrest" :
504- firecrest_client = _get_firecrest_client_from_init_config (config )
505- launcher = cast (
506- Launcher ,
507- await _get_firecrest_launcher_with_client (
508- firecrest_client ,
509- telemetry_endpoint = telemetry_endpoint ,
510- args = args ,
511- ),
512- )
513- elif launcher_type == "slurm" :
514- launcher = cast (
515- Launcher ,
516- await _get_slurm_launcher (
517- telemetry_endpoint = telemetry_endpoint ,
518- args = args ,
519- ),
520- )
521- else :
522- raise NotImplementedError (f"Launcher { launcher_type } is not supported yet." )
508+ launcher = await _create_launcher (config , args )
509+ cscs_api_key = config .get_non_none_value ("cscs_api_key" )
510+ launch_request = await _get_launch_request (launcher , args )
511+ await _run_monitor (launcher , launcher .launch_model (launch_request ), cscs_api_key )
512+
513+
514+ async def _run_advanced (args : argparse .Namespace ) -> None :
515+ if not InitConfig .exists ():
516+ print ("SML is not configured. Run `sml init` first." )
517+ return
523518
519+ config = InitConfig .load ()
520+ launcher = await _create_launcher (config , args )
524521 cscs_api_key = config .get_non_none_value ("cscs_api_key" )
522+
525523 if args .served_model_name :
526524 served_model_name = args .served_model_name
527525 else :
@@ -551,31 +549,10 @@ async def _run_advanced(args: argparse.Namespace) -> None:
551549 use_router = args .use_router ,
552550 router_args = args .router_args ,
553551 disable_ocf = args .disable_ocf ,
554- telemetry_endpoint = telemetry_endpoint ,
552+ telemetry_endpoint = config . get_value ( " telemetry_endpoint" ) ,
555553 )
556554
557- state = DisplayState ()
558- state .update (cluster = launcher .system_name , partition = launcher .partition )
559-
560- async def _monitor () -> None :
561- job_id , served = await launcher .launch_with_args (launch_args )
562- state .update (job_id = job_id , served_model_name = served )
563- while True :
564- await asyncio .sleep (5 )
565-
566- job_status = await launcher .get_job_status (job_id )
567- state .update (job_status = job_status )
568-
569- model_health = await check_model_health (served , cscs_api_key )
570- state .update (model_health = model_health )
571-
572- o , e = await launcher .get_job_logs (job_id )
573- state .set_out_log (o )
574- state .set_err_log (e )
575-
576- kill_job = await LiveDisplay (state ).run (_monitor ())
577- if kill_job and state .job_id is not None :
578- await launcher .cancel_job (state .job_id )
555+ await _run_monitor (launcher , launcher .launch_with_args (launch_args ), cscs_api_key )
579556
580557
581558async def _main (args : argparse .Namespace ) -> None :
0 commit comments