Skip to content

Commit c365ec5

Browse files
test: add loadtest unit and integration coverage
1 parent ff41320 commit c365ec5

6 files changed

Lines changed: 815 additions & 43 deletions

File tree

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import os
2+
from collections.abc import AsyncIterator
3+
from pathlib import Path
4+
5+
import firecrest as f7t
6+
import pytest
7+
8+
from swiss_ai_model_launch.launchers.firecrest_launcher import FirecRESTLauncher
9+
from swiss_ai_model_launch.loadtest.cluster import ClusterLoadtestConfig, submit_cluster_loadtest
10+
from swiss_ai_model_launch.loadtest.models import LoadtestConfig, ServerConfig
11+
from swiss_ai_model_launch.loadtest.setup import DEFAULT_CLUSTER_CONTAINER_IMAGE, K6_SCRIPT
12+
from tests.integration.utils import wait_for_job_running
13+
14+
_LOADTEST_RUNNING_TIMEOUT_MIN = 20
15+
_LOADTEST_SERVER_URL = "https://api.swissai.svc.cscs.ch"
16+
17+
_REQUIRED_ENV_VARS = [
18+
"SML_CSCS_API_KEY",
19+
"SML_FIRECREST_CLIENT_ID",
20+
"SML_FIRECREST_CLIENT_SECRET",
21+
"SML_FIRECREST_SYSTEM",
22+
"SML_FIRECREST_TOKEN_URI",
23+
"SML_FIRECREST_URL",
24+
"SML_LOADTEST_MODEL",
25+
"SML_LOADTEST_PROMPTS_FILE",
26+
"SML_PARTITION",
27+
"SML_RESERVATION",
28+
]
29+
30+
31+
@pytest.fixture(scope="function") # type: ignore[misc]
32+
def env() -> dict[str, str]:
33+
missing = [v for v in _REQUIRED_ENV_VARS if os.environ.get(v) is None]
34+
if missing:
35+
pytest.fail(
36+
"Missing required environment variables: " + ", ".join(missing),
37+
pytrace=False,
38+
)
39+
return {v: os.environ[v] for v in _REQUIRED_ENV_VARS}
40+
41+
42+
@pytest.fixture(scope="function") # type: ignore[misc]
43+
async def launcher(env: dict[str, str]) -> AsyncIterator[FirecRESTLauncher]:
44+
client = f7t.v2.AsyncFirecrest(
45+
firecrest_url=env["SML_FIRECREST_URL"],
46+
authorization=f7t.ClientCredentialsAuth(
47+
client_id=env["SML_FIRECREST_CLIENT_ID"],
48+
client_secret=env["SML_FIRECREST_CLIENT_SECRET"],
49+
token_uri=env["SML_FIRECREST_TOKEN_URI"],
50+
min_token_validity=90,
51+
),
52+
)
53+
try:
54+
yield await FirecRESTLauncher.from_client(
55+
client=client,
56+
system_name=env["SML_FIRECREST_SYSTEM"],
57+
partition=env["SML_PARTITION"],
58+
reservation=env["SML_RESERVATION"] or None,
59+
)
60+
finally:
61+
await client.close_session()
62+
63+
64+
@pytest.mark.comprehensive
65+
async def test_submit_cluster_loadtest_starts_cluster_job(
66+
launcher: FirecRESTLauncher,
67+
env: dict[str, str],
68+
tmp_path: Path,
69+
) -> None:
70+
submission = await submit_cluster_loadtest(
71+
launcher=launcher,
72+
server=ServerConfig(
73+
url=os.environ.get("SML_LOADTEST_SERVER_URL", _LOADTEST_SERVER_URL),
74+
api_key=env["SML_CSCS_API_KEY"],
75+
model=env["SML_LOADTEST_MODEL"],
76+
is_swissai=True,
77+
),
78+
bench=LoadtestConfig(
79+
scenario=os.environ.get("SML_LOADTEST_SCENARIO", "throughput"),
80+
think_time="0",
81+
max_tokens=os.environ.get("SML_LOADTEST_MAX_TOKENS", "16"),
82+
),
83+
k6_script=K6_SCRIPT,
84+
prompts_file=Path(env["SML_LOADTEST_PROMPTS_FILE"]),
85+
summary_path=tmp_path / "summary.json",
86+
cluster=ClusterLoadtestConfig(
87+
container_image=str(DEFAULT_CLUSTER_CONTAINER_IMAGE),
88+
wait=False,
89+
reservation=env["SML_RESERVATION"] or None,
90+
),
91+
)
92+
93+
try:
94+
await wait_for_job_running(launcher, submission.job_id, _LOADTEST_RUNNING_TIMEOUT_MIN)
95+
finally:
96+
await launcher.cancel_job(submission.job_id)

tests/unit/test_cluster_loadtest.py

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,12 @@
33

44
import pytest
55

6-
import swiss_ai_model_launch.loadtest.cluster as cluster_module
7-
from swiss_ai_model_launch.launchers import SlurmLauncher
86
from swiss_ai_model_launch.loadtest.cluster import (
97
ClusterLoadtestConfig,
108
_container_mounts_for_external_prompts,
119
build_cluster_loadtest_script,
12-
submit_cluster_loadtest,
1310
)
14-
from swiss_ai_model_launch.loadtest.models import LoadtestConfig, ServerConfig
11+
from swiss_ai_model_launch.loadtest.models import LoadtestConfig
1512

1613

1714
@pytest.fixture
@@ -150,42 +147,3 @@ def test_container_mounts_includes_top_level_dir() -> None:
150147
def test_container_mounts_prompts_path_returned() -> None:
151148
prompts_path, _ = _container_mounts_for_external_prompts("my_run", Path("/scratch/data/prompts.jsonl"))
152149
assert prompts_path == "/scratch/data/prompts.jsonl"
153-
154-
155-
@pytest.mark.asyncio
156-
async def test_submit_cluster_loadtest_returns_job_id_and_run_label(
157-
tmp_path: Path,
158-
monkeypatch: pytest.MonkeyPatch,
159-
bench: LoadtestConfig,
160-
) -> None:
161-
async def fake_run_checked(*cmd: str) -> str:
162-
assert cmd[0] == "sbatch"
163-
return "Submitted batch job 123\n"
164-
165-
monkeypatch.setenv("HOME", str(tmp_path))
166-
monkeypatch.setattr(cluster_module, "_run_checked", fake_run_checked)
167-
monkeypatch.setattr(cluster_module, "create_salt", lambda length: "X" * length)
168-
169-
k6_script = tmp_path / "script.js"
170-
k6_script.write_text("export default function() {}\n")
171-
prompts_file = tmp_path / "prompts.jsonl"
172-
prompts_file.write_text("{}\n")
173-
174-
submission = await submit_cluster_loadtest(
175-
launcher=SlurmLauncher(
176-
system_name="test-system",
177-
username="test-user",
178-
account="test-account",
179-
partition="test-partition",
180-
),
181-
server=ServerConfig(url="https://example.test", api_key="secret", model="test-model", is_swissai=True),
182-
bench=bench,
183-
k6_script=k6_script,
184-
prompts_file=prompts_file,
185-
summary_path=tmp_path / "summary.json",
186-
cluster=ClusterLoadtestConfig(container_image="/images/k6.sqsh", wait=False),
187-
)
188-
189-
assert submission.job_id == 123
190-
assert submission.run_label.startswith("loadtest_throughput_")
191-
assert submission.run_label.endswith("_XXXXXX")

0 commit comments

Comments
 (0)