Skip to content

Commit a08a110

Browse files
Loadtest job length setting
1 parent: 2cfca15 · commit: a08a110

4 files changed

Lines changed: 92 additions & 11 deletions

File tree

src/swiss_ai_model_launch/cli/main.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,7 @@
4747
_DEFAULT_LOADTEST_READY_POLL_SECONDS = 10
4848
_LOADTEST_READY_PROGRESS_SECONDS = 300
4949
_DEFAULT_LOADTEST_METRICS_REMOTE_WRITE_URL = "https://prometheus-dev.swissai.svc.cscs.ch/api/v1/write"
50+
_DEFAULT_LOADTEST_JOB_TIME = "00:30:00"
5051

5152

5253
def _make_firecrest_launcher_config(
@@ -178,6 +179,13 @@ def _add_loadtest_arguments(
178179
default=True,
179180
help="Wait for the cluster loadtest job and download/copy the summary (default: true).",
180181
)
182+
parser.add_argument(
183+
"--loadtest-job-time",
184+
dest="loadtest_job_time",
185+
default=_DEFAULT_LOADTEST_JOB_TIME,
186+
metavar="HH:MM:SS",
187+
help=f"SLURM time limit for the cluster k6 loadtest job (default: {_DEFAULT_LOADTEST_JOB_TIME}).",
188+
)
181189
parser.add_argument(
182190
"--loadtest-server-url",
183191
dest="loadtest_server_url",
@@ -870,8 +878,11 @@ def _make_cluster_loadtest_config(
870878
raise ValueError("--cancel-after-loadtest requires --wait-for-loadtest.")
871879
if args.loadtest_ready_timeout <= 0:
872880
raise ValueError("--loadtest-ready-timeout must be greater than 0.")
881+
if not re.fullmatch(r"[0-9]{1,2}:[0-5][0-9]:[0-5][0-9]", args.loadtest_job_time):
882+
raise ValueError("--loadtest-job-time must be in HH:MM:SS format.")
873883
return ClusterLoadtestConfig(
874884
container_image=str(DEFAULT_CLUSTER_CONTAINER_IMAGE),
885+
time=args.loadtest_job_time,
875886
wait=args.wait_for_loadtest,
876887
reservation=reservation or getattr(args, "reservation", None),
877888
metrics_remote_write_url=(
@@ -899,7 +910,7 @@ async def _run_k6_on_cluster(
899910
print(f"Loadtest prompts file: {prompts_file}")
900911
if cluster_config.metrics_remote_write_url:
901912
print(f"Loadtest metrics remote write: {cluster_config.metrics_remote_write_url}")
902-
job_id = await submit_cluster_loadtest(
913+
submission = await submit_cluster_loadtest(
903914
launcher=launcher,
904915
server=server,
905916
bench=loadtest_config,
@@ -908,7 +919,8 @@ async def _run_k6_on_cluster(
908919
summary_path=summary_path,
909920
cluster=cluster_config,
910921
)
911-
print(f"Cluster loadtest job submitted: {job_id}")
922+
print(f"Cluster loadtest job submitted: {submission.job_id}")
923+
print(f"Loadtest run label: {submission.run_label}")
912924
if cluster_config.wait:
913925
print(f"Loadtest summary: {summary_path}")
914926

src/swiss_ai_model_launch/loadtest/cluster.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ class ClusterLoadtestConfig:
2828
metrics_remote_write_url: str | None = None
2929

3030

31+
@dataclass(frozen=True)
32+
class ClusterLoadtestSubmission:
33+
job_id: int
34+
run_label: str
35+
36+
3137
def build_cluster_loadtest_script(
3238
*,
3339
bench: LoadtestConfig,
@@ -205,7 +211,7 @@ async def submit_cluster_loadtest(
205211
prompts_file: Path,
206212
summary_path: Path,
207213
cluster: ClusterLoadtestConfig,
208-
) -> int:
214+
) -> ClusterLoadtestSubmission:
209215
run_label = f"loadtest_{bench.scenario}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{create_salt(6)}"
210216
prompts_path, container_mounts = _container_mounts_for_external_prompts(run_label, prompts_file)
211217
script = build_cluster_loadtest_script(
@@ -242,7 +248,7 @@ async def submit_cluster_loadtest(
242248
if remote_summary.exists():
243249
summary_path.parent.mkdir(parents=True, exist_ok=True)
244250
shutil.copyfile(remote_summary, summary_path)
245-
return job_id
251+
return ClusterLoadtestSubmission(job_id=job_id, run_label=run_label)
246252

247253
if isinstance(launcher, FirecRESTLauncher):
248254
firecrest_working_dir = str(launcher._get_working_dir())
@@ -294,6 +300,6 @@ async def submit_cluster_loadtest(
294300
)
295301
except f7t.FirecrestException as e:
296302
raise RuntimeError(f"Could not download cluster loadtest summary for job {job_id}: {e}") from e
297-
return job_id
303+
return ClusterLoadtestSubmission(job_id=job_id, run_label=run_label)
298304

299305
raise TypeError(f"Cluster loadtests are not supported for launcher type {type(launcher).__name__}")

tests/unit/test_cluster_loadtest.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33

44
import pytest
55

6+
import swiss_ai_model_launch.loadtest.cluster as cluster_module
7+
from swiss_ai_model_launch.launchers import SlurmLauncher
68
from swiss_ai_model_launch.loadtest.cluster import (
79
ClusterLoadtestConfig,
810
_container_mounts_for_external_prompts,
911
build_cluster_loadtest_script,
12+
submit_cluster_loadtest,
1013
)
11-
from swiss_ai_model_launch.loadtest.models import LoadtestConfig
14+
from swiss_ai_model_launch.loadtest.models import LoadtestConfig, ServerConfig
1215

1316

1417
@pytest.fixture
@@ -123,9 +126,9 @@ def test_script_k6_run_command_shell_quotes_dynamic_values(
123126
tokens = shlex.split(inner_command)
124127

125128
assert tokens[tokens.index("--tag") + 1] == "scenario=throughput"
126-
assert f"run_label=run-$USER's-label" in tokens
127-
assert f"model=model-$USER's-name" in tokens
128-
assert f"PROMPTS_FILE=/capstor/prompt files/$USER/prompts.json" in tokens
129+
assert "run_label=run-$USER's-label" in tokens
130+
assert "model=model-$USER's-name" in tokens
131+
assert "PROMPTS_FILE=/capstor/prompt files/$USER/prompts.json" in tokens
129132
assert 'sh -lc "' not in script
130133

131134

@@ -147,3 +150,42 @@ def test_container_mounts_includes_top_level_dir() -> None:
147150
def test_container_mounts_prompts_path_returned() -> None:
148151
prompts_path, _ = _container_mounts_for_external_prompts("my_run", Path("/scratch/data/prompts.jsonl"))
149152
assert prompts_path == "/scratch/data/prompts.jsonl"
153+
154+
155+
@pytest.mark.asyncio
156+
async def test_submit_cluster_loadtest_returns_job_id_and_run_label(
157+
tmp_path: Path,
158+
monkeypatch: pytest.MonkeyPatch,
159+
bench: LoadtestConfig,
160+
) -> None:
161+
async def fake_run_checked(*cmd: str) -> str:
162+
assert cmd[0] == "sbatch"
163+
return "Submitted batch job 123\n"
164+
165+
monkeypatch.setenv("HOME", str(tmp_path))
166+
monkeypatch.setattr(cluster_module, "_run_checked", fake_run_checked)
167+
monkeypatch.setattr(cluster_module, "create_salt", lambda length: "X" * length)
168+
169+
k6_script = tmp_path / "script.js"
170+
k6_script.write_text("export default function() {}\n")
171+
prompts_file = tmp_path / "prompts.jsonl"
172+
prompts_file.write_text("{}\n")
173+
174+
submission = await submit_cluster_loadtest(
175+
launcher=SlurmLauncher(
176+
system_name="test-system",
177+
username="test-user",
178+
account="test-account",
179+
partition="test-partition",
180+
),
181+
server=ServerConfig(url="https://example.test", api_key="secret", model="test-model", is_swissai=True),
182+
bench=bench,
183+
k6_script=k6_script,
184+
prompts_file=prompts_file,
185+
summary_path=tmp_path / "summary.json",
186+
cluster=ClusterLoadtestConfig(container_image="/images/k6.sqsh", wait=False),
187+
)
188+
189+
assert submission.job_id == 123
190+
assert submission.run_label.startswith("loadtest_throughput_")
191+
assert submission.run_label.endswith("_XXXXXX")

tests/unit/test_loadtest_cli.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ def test_loadtest_run_help_excludes_removed_scenario_owned_flags(capsys: pytest.
3636
assert "--loadtest-ignore-eos" in help_text
3737
assert "--wait-until-healthy" in help_text
3838
assert "--loadtest-metrics-remote-write" in help_text
39+
assert "--loadtest-job-time" in help_text
3940
assert "--job-id" not in help_text
4041
assert "--loadtest-chat-mode" not in help_text
4142
assert "--loadtest-cpus-per-task" not in help_text
4243
assert "--loadtest-k6-script" not in help_text
43-
assert "--loadtest-job-time" not in help_text
4444
assert "--loadtest-ready-poll-interval" not in help_text
4545
assert "--loadtest-think-time" not in help_text
4646
assert "--loadtest-request-timeout" not in help_text
@@ -143,7 +143,6 @@ def test_loadtest_parser_does_not_expose_api_key_override() -> None:
143143
("--no-loadtest-chat-mode", None),
144144
("--job-id", "123"),
145145
("--loadtest-cpus-per-task", "8"),
146-
("--loadtest-job-time", "01:00:00"),
147146
("--loadtest-k6-script", "script.js"),
148147
("--loadtest-ready-poll-interval", "30"),
149148
("--loadtest-think-time", "0"),
@@ -174,6 +173,28 @@ def test_loadtest_metrics_remote_write_enabled_by_default() -> None:
174173
)
175174

176175

176+
def test_loadtest_job_time_defaults_to_cluster_default() -> None:
177+
parser = _build_parser()
178+
args = parser.parse_args(["loadtest", "run"])
179+
180+
assert _make_cluster_loadtest_config(args).time == "00:30:00"
181+
182+
183+
def test_loadtest_job_time_can_be_overridden() -> None:
184+
parser = _build_parser()
185+
args = parser.parse_args(["loadtest", "run", "--loadtest-job-time", "01:00:00"])
186+
187+
assert _make_cluster_loadtest_config(args).time == "01:00:00"
188+
189+
190+
def test_loadtest_job_time_rejects_invalid_format() -> None:
191+
parser = _build_parser()
192+
args = parser.parse_args(["loadtest", "run", "--loadtest-job-time", "1h"])
193+
194+
with pytest.raises(ValueError, match="--loadtest-job-time"):
195+
_make_cluster_loadtest_config(args)
196+
197+
177198
def test_loadtest_metrics_remote_write_can_be_disabled() -> None:
178199
parser = _build_parser()
179200
args = parser.parse_args(["loadtest", "run", "--no-loadtest-metrics-remote-write"])

0 commit comments

Comments (0)