Skip to content

Commit 153405d

Browse files
committed
Add Reservation
1 parent 5b4c337 commit 153405d

10 files changed

Lines changed: 51 additions & 1 deletion

File tree

.github/workflows/ci-comprehensive.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ jobs:
2323
FIRECREST_TOKEN_URI: ${{ secrets.FIRECREST_TOKEN_URI }}
2424
FIRECREST_URL: ${{ secrets.FIRECREST_URL }}
2525
FIRECREST_USERNAME: ${{ secrets.FIRECREST_USERNAME }}
26+
RESERVATION: ${{ secrets.RESERVATION }}
2627
steps:
2728
- uses: actions/checkout@v5
2829
- uses: ./.github/actions/setup

.github/workflows/ci-lightweight.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ jobs:
2323
FIRECREST_TOKEN_URI: ${{ secrets.FIRECREST_TOKEN_URI }}
2424
FIRECREST_URL: ${{ secrets.FIRECREST_URL }}
2525
FIRECREST_USERNAME: ${{ secrets.FIRECREST_USERNAME }}
26+
RESERVATION: ${{ secrets.RESERVATION }}
2627
steps:
2728
- uses: actions/checkout@v5
2829
- uses: ./.github/actions/setup

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ If you want to skip the interactive prompts and launch a pre-configured model di
7979
| ----------------------- | -------------------------------- | ---------------------------------------------------------------------- |
8080
| `--firecrest-system` | `SML_FIRECREST_SYSTEM` | Target system to launch on (required only if using FirecREST launcher) |
8181
| `--partition` | `SML_PARTITION` | SLURM partition to use |
82+
| `--reservation` | `SML_RESERVATION` | SLURM reservation name (optional) |
8283
| `--model` | | Model to launch (`<vendor>/<model>`) |
8384
| `--framework` | | Inference framework to use |
8485
| `--workers` | | Number of workers |
@@ -117,6 +118,7 @@ For full control over the SLURM job, use `sml advanced`. This bypasses the model
117118
| -------------------------- | ------------------------- | ----------------------------------------------------------------- |
118119
| `--firecrest-system` | `SML_FIRECREST_SYSTEM` | Target HPC system to launch on |
119120
| `--partition` | `SML_PARTITION` | SLURM partition to use |
121+
| `--slurm-reservation` | `SML_RESERVATION` | SLURM reservation name (optional) |
120122
| `--serving-framework` | | Inference framework (`sglang`, `vllm`) — **required** |
121123
| `--slurm-environment` | | Local path to the environment `.toml` file — **required** |
122124
| `--framework-args` | | Arguments forwarded to the inference framework |
@@ -174,6 +176,7 @@ export FIRECREST_SYSTEM=clariden
174176
export FIRECREST_ACCOUNT=<your-account>
175177
export FIRECREST_PARTITION=normal
176178
export CSCS_API_KEY=<your-api-key>
179+
export RESERVATION=<your-reservation>
177180
```
178181

179182
This file will be sourced when running the tests with `make test`, and the environment variables will be available for the tests.

src/swiss_ai_model_launch/assets/template.jinja

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#SBATCH --exclusive
66
#SBATCH --nodes={{ nodes }}
77
#SBATCH --partition={{ partition }}
8+
{% if reservation %}#SBATCH --reservation={{ reservation }}{% endif %}
89
#SBATCH --output=logs/%j/log.out
910
#SBATCH --error=logs/%j/log.out
1011

src/swiss_ai_model_launch/cli/main.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,19 @@ def _make_partition_config(
7373
)
7474

7575

76+
def _make_reservation_config() -> ChainConfiguration:
77+
return ChainConfiguration(
78+
name="reservation_configuration",
79+
chain=[
80+
TextConfiguration(
81+
name="reservation",
82+
prompt="SLURM reservation name (optional, leave blank to skip).",
83+
env_var="SML_RESERVATION",
84+
),
85+
],
86+
)
87+
88+
7689
def _make_launch_request_config(
7790
vendor_models_factory: _OptionsFactory = None,
7891
frameworks_factory: _OptionsFactory = None,
@@ -148,6 +161,7 @@ def _build_parser() -> argparse.ArgumentParser:
148161
)
149162
_make_firecrest_launcher_config().add_to_parser(preconfigured_parser)
150163
_make_partition_config().add_to_parser(preconfigured_parser)
164+
_make_reservation_config().add_to_parser(preconfigured_parser)
151165
_make_launch_request_config().add_to_parser(preconfigured_parser)
152166

153167
advanced_parser = subparsers.add_parser(
@@ -203,6 +217,13 @@ def _build_parser() -> argparse.ArgumentParser:
203217
metavar="HH:MM:SS",
204218
help="Job time limit (default: 00:05:00).",
205219
)
220+
advanced_parser.add_argument(
221+
"--slurm-reservation",
222+
dest="reservation",
223+
default=None,
224+
metavar="RESERVATION",
225+
help="SLURM reservation name (optional).",
226+
)
206227
advanced_parser.add_argument(
207228
"--served-model-name",
208229
dest="served_model_name",
@@ -290,12 +311,16 @@ async def _get_partitions() -> dict[str, tuple[str, str]]:
290311
partition_config = _make_partition_config(partitions_factory=_get_partitions)
291312
await partition_config.aconfigure(args=args)
292313

314+
reservation_config = _make_reservation_config()
315+
await reservation_config.aconfigure(args=args)
316+
293317
return FirecRESTLauncher(
294318
client,
295319
system_name=system_name,
296320
username=user_info["user"]["name"],
297321
account=user_info["group"]["name"],
298322
partition=partition_config.get_non_none_value("partition"),
323+
reservation=reservation_config.get_value("reservation") or None,
299324
telemetry_endpoint=telemetry_endpoint,
300325
)
301326

@@ -320,11 +345,15 @@ async def _get_partitions() -> dict[str, tuple[str, str]]:
320345
partition_config = _make_partition_config(partitions_factory=_get_partitions)
321346
await partition_config.aconfigure(args=args)
322347

348+
reservation_config = _make_reservation_config()
349+
await reservation_config.aconfigure(args=args)
350+
323351
return SlurmLauncher(
324352
system_name="local",
325353
username=getpass.getuser(),
326354
account=grp.getgrgid(os.getgid()).gr_name,
327355
partition=partition_config.get_non_none_value("partition"),
356+
reservation=reservation_config.get_value("reservation") or None,
328357
telemetry_endpoint=telemetry_endpoint,
329358
)
330359

@@ -541,6 +570,7 @@ async def _run_advanced(args: argparse.Namespace) -> None:
541570
nodes_per_worker=args.nodes_per_worker,
542571
nodes=args.nodes,
543572
time=args.time,
573+
reservation=args.reservation or None,
544574
environment=args.slurm_environment,
545575
framework=args.framework,
546576
framework_args=args.framework_args,

src/swiss_ai_model_launch/launchers/firecrest_launcher.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@ def __init__(
2929
username: str,
3030
account: str,
3131
partition: str,
32+
reservation: str | None = None,
3233
telemetry_endpoint: str | None = None,
3334
):
3435
super().__init__(
3536
system_name=system_name,
3637
username=username,
3738
account=account,
3839
partition=partition,
40+
reservation=reservation,
3941
telemetry_endpoint=telemetry_endpoint,
4042
)
4143
self.client = client
@@ -64,6 +66,7 @@ def _get_launch_args_from_request(
6466
workers=launch_request.workers,
6567
nodes_per_worker=launch_request.nodes_per_worker,
6668
time=launch_request.time,
69+
reservation=self.reservation,
6770
environment=launch_request.environment,
6871
framework=launch_request.framework,
6972
served_model_name=served_model_name,
@@ -128,7 +131,9 @@ async def launch_with_args(self, launch_args: LaunchArgs) -> tuple[int, str]:
128131
remote_env_path = await self._upload_env_file(
129132
launch_args.environment, launch_args.framework
130133
)
131-
launch_args = launch_args.model_copy(update={"environment": remote_env_path})
134+
launch_args = launch_args.model_copy(
135+
update={"environment": remote_env_path, "reservation": self.reservation}
136+
)
132137
script_str = render_job_script(launch_args)
133138
job_submission_report = await self.client.submit(
134139
system_name=self.system_name,

src/swiss_ai_model_launch/launchers/launch_args.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class LaunchArgs(BaseModel):
1212
nodes: int | None = None
1313

1414
time: str = "00:05:00"
15+
reservation: str | None = None
1516
environment: str
1617

1718
framework: str

src/swiss_ai_model_launch/launchers/launcher.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,14 @@ def __init__(
2626
username: str,
2727
account: str,
2828
partition: str,
29+
reservation: str | None = None,
2930
telemetry_endpoint: str | None = None,
3031
):
3132
self.system_name = system_name
3233
self.username = username
3334
self.account = account
3435
self.partition = partition
36+
self.reservation = reservation
3537
self.telemetry_endpoint = telemetry_endpoint
3638

3739
@abstractmethod

src/swiss_ai_model_launch/launchers/slurm_launcher.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
username: str,
2828
account: str,
2929
partition: str,
30+
reservation: str | None = None,
3031
model_registry: Path = _REMOTE_MODEL_REGISTRY,
3132
telemetry_endpoint: str | None = None,
3233
):
@@ -35,6 +36,7 @@ def __init__(
3536
username=username,
3637
account=account,
3738
partition=partition,
39+
reservation=reservation,
3840
telemetry_endpoint=telemetry_endpoint,
3941
)
4042
self.model_registry = model_registry
@@ -59,6 +61,7 @@ def _get_launch_args_from_request(
5961
workers=launch_request.workers,
6062
nodes_per_worker=launch_request.nodes_per_worker,
6163
time=launch_request.time,
64+
reservation=self.reservation,
6265
environment=launch_request.environment,
6366
framework=launch_request.framework,
6467
served_model_name=served_model_name,
@@ -138,6 +141,7 @@ async def get_preconfigured_models(self) -> list[LaunchRequest]:
138141
]
139142

140143
async def launch_with_args(self, launch_args: LaunchArgs) -> tuple[int, str]:
144+
launch_args = launch_args.model_copy(update={"reservation": self.reservation})
141145
job_id = await self._sbatch(launch_args)
142146
return job_id, launch_args.served_model_name
143147

tests/integration/test_firecrest_launcher.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
"FIRECREST_ACCOUNT",
3939
"FIRECREST_PARTITION",
4040
"CSCS_API_KEY",
41+
"RESERVATION",
4142
]
4243

4344

@@ -68,6 +69,7 @@ def launcher(env: dict[str, str]) -> FirecRESTLauncher:
6869
username=env["FIRECREST_USERNAME"],
6970
account=env["FIRECREST_ACCOUNT"],
7071
partition=env["FIRECREST_PARTITION"],
72+
reservation=env["RESERVATION"] or None,
7173
)
7274

7375

0 commit comments

Comments
 (0)