@@ -73,6 +73,19 @@ def _make_partition_config(
7373 )
7474
7575
76+ def _make_reservation_config () -> ChainConfiguration :
77+ return ChainConfiguration (
78+ name = "reservation_configuration" ,
79+ chain = [
80+ TextConfiguration (
81+ name = "reservation" ,
82+ prompt = "SLURM reservation name (optional, leave blank to skip)." ,
83+ env_var = "SML_RESERVATION" ,
84+ ),
85+ ],
86+ )
87+
88+
7689def _make_launch_request_config (
7790 vendor_models_factory : _OptionsFactory = None ,
7891 frameworks_factory : _OptionsFactory = None ,
@@ -148,6 +161,7 @@ def _build_parser() -> argparse.ArgumentParser:
148161 )
149162 _make_firecrest_launcher_config ().add_to_parser (preconfigured_parser )
150163 _make_partition_config ().add_to_parser (preconfigured_parser )
164+ _make_reservation_config ().add_to_parser (preconfigured_parser )
151165 _make_launch_request_config ().add_to_parser (preconfigured_parser )
152166
153167 advanced_parser = subparsers .add_parser (
@@ -203,6 +217,13 @@ def _build_parser() -> argparse.ArgumentParser:
203217 metavar = "HH:MM:SS" ,
204218 help = "Job time limit (default: 00:05:00)." ,
205219 )
220+ advanced_parser .add_argument (
221+ "--slurm-reservation" ,
222+ dest = "reservation" ,
223+ default = None ,
224+ metavar = "RESERVATION" ,
225+ help = "SLURM reservation name (optional)." ,
226+ )
206227 advanced_parser .add_argument (
207228 "--served-model-name" ,
208229 dest = "served_model_name" ,
@@ -290,12 +311,16 @@ async def _get_partitions() -> dict[str, tuple[str, str]]:
290311 partition_config = _make_partition_config (partitions_factory = _get_partitions )
291312 await partition_config .aconfigure (args = args )
292313
314+ reservation_config = _make_reservation_config ()
315+ await reservation_config .aconfigure (args = args )
316+
293317 return FirecRESTLauncher (
294318 client ,
295319 system_name = system_name ,
296320 username = user_info ["user" ]["name" ],
297321 account = user_info ["group" ]["name" ],
298322 partition = partition_config .get_non_none_value ("partition" ),
323+ reservation = reservation_config .get_value ("reservation" ) or None ,
299324 telemetry_endpoint = telemetry_endpoint ,
300325 )
301326
@@ -320,11 +345,15 @@ async def _get_partitions() -> dict[str, tuple[str, str]]:
320345 partition_config = _make_partition_config (partitions_factory = _get_partitions )
321346 await partition_config .aconfigure (args = args )
322347
348+ reservation_config = _make_reservation_config ()
349+ await reservation_config .aconfigure (args = args )
350+
323351 return SlurmLauncher (
324352 system_name = "local" ,
325353 username = getpass .getuser (),
326354 account = grp .getgrgid (os .getgid ()).gr_name ,
327355 partition = partition_config .get_non_none_value ("partition" ),
356+ reservation = reservation_config .get_value ("reservation" ) or None ,
328357 telemetry_endpoint = telemetry_endpoint ,
329358 )
330359
@@ -541,6 +570,7 @@ async def _run_advanced(args: argparse.Namespace) -> None:
541570 nodes_per_worker = args .nodes_per_worker ,
542571 nodes = args .nodes ,
543572 time = args .time ,
573+ reservation = args .reservation or None ,
544574 environment = args .slurm_environment ,
545575 framework = args .framework ,
546576 framework_args = args .framework_args ,
0 commit comments