Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions cluv/cli/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import shlex
import subprocess
import sys
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path

from cluv.cli.sync import sync
Expand All @@ -17,15 +19,24 @@
# Slurm job states that this module groups together as unsuccessful terminal states.
FAILED_JOB_STATES = ["FAILED", "CANCELLED", "TIMEOUT", "NODE_FAIL", "OUT_OF_MEMORY", "PREEMPTED"]


# Public API of the module. `JobInfo` is exported alongside `submit` because `submit`
# returns it; assigning `__all__` once avoids a dead first assignment that omitted it.
__all__ = ["JobInfo", "submit"]


@dataclass(frozen=True)
class JobInfo:
    """Information about a submitted Slurm job.

    Instances are immutable (`frozen=True`): a field cannot be reassigned after
    construction, and instances compare equal when all three fields are equal.
    """

    # Cluster on which the job was submitted.
    cluster: str
    # Job id of the submitted job.
    job_id: int
    # Moment the submission was recorded.
    # NOTE(review): producers in this module appear to use timezone-aware UTC
    # datetimes (`datetime.now(timezone.utc)`) -- confirm before comparing with naive ones.
    submit_time: datetime


async def submit(
cluster: str,
job_script: Path,
sbatch_args: list[str],
program_args: list[str],
) -> int | None:
) -> JobInfo | None:
"""Submit a SLURM job on a remote cluster.

Enforces a clean git state, syncs the project, sets `GIT_COMMIT` and any
Expand All @@ -43,7 +54,7 @@ async def submit(
program_args: List of arguments to pass to the job script, for example `["python", "main.py"]`.

Returns:
The job ID of the submitted job or None if the sbatch command fails.
A `JobInfo` with the cluster hostname, job ID, and submission time, or None if the sbatch command fails.

Examples:

Expand Down Expand Up @@ -73,22 +84,23 @@ async def submit(
console.print(f"[red] Error during sbatch : {result.stderr}[/red]")
return None

submit_time = datetime.now(timezone.utc)
job_id = int(result.stdout.strip())

console.log(
f"Successfully submitted job {job_id} on the {cluster} cluster.\n"
f"Use `ssh {cluster} sacct -j {job_id}` to view its status."
)

return job_id
return JobInfo(cluster=cluster, job_id=job_id, submit_time=submit_time)


async def submit_first(
job_script: Path,
sbatch_args: list[str],
program_args: list[str],
git_commit: str,
) -> int | None:
) -> JobInfo | None:
"""Submit the job on all clusters, and wait until one of them starts.
Once one starts, cancel the others.
"""
Expand All @@ -110,6 +122,7 @@ async def submit_first(
],
return_exceptions=True,
)
submit_time = datetime.now(timezone.utc)

# Get the results of the sbatch command. We expect an int (the job id) or the exception
# if the command failed on the remote cluster.
Expand Down Expand Up @@ -183,7 +196,9 @@ async def submit_first(
finally:
await cancel_all_jobs(clusters_to_remote, cluster_to_jobid, start_cluster)

return start_job_id
if start_cluster is None or start_job_id is None:
return None
return JobInfo(cluster=start_cluster, job_id=start_job_id, submit_time=submit_time)
Comment thread
lebrice marked this conversation as resolved.


def ensure_clean_git_state() -> str:
Expand Down Expand Up @@ -281,7 +296,7 @@ async def get_job_status(remote: Remote, job_id: int) -> str:
async def cancel_job(remote: Remote, job_id: int) -> str:
    """Cancel the job with the given id on the remote cluster.

    Args:
        remote: Connection to the cluster on which the job was submitted.
        job_id: Id of the job to pass to `scancel`.

    Returns:
        Whatever `remote.get_output` returns for the `scancel` command.
    """
    scancel_command = f"scancel {job_id}"
    # Run `scancel` exactly once, with `warn=True`.
    # NOTE(review): `warn=True` presumably turns a non-zero exit status (for example when
    # the job has already finished) into a warning instead of an exception -- confirm
    # against `Remote.get_output`.
    output = await remote.get_output(scancel_command, warn=True)
    console.print(f"Cancelled job {job_id} on cluster {remote.hostname}.")
    return output

Expand Down
42 changes: 39 additions & 3 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from cluv.cli.init import DEFAULT_RESULTS_PATH, init
from cluv.cli.login import get_remote_without_2fa_prompt, login
from cluv.cli.status import ClusterStatus, get_real_cluster_status
from cluv.cli.submit import submit
from cluv.cli.submit import JobInfo, submit
from cluv.cli.sync import sync
from cluv.remote import Remote, control_socket_is_running

Expand Down Expand Up @@ -192,13 +192,15 @@ async def test_submit(remote: Remote):
if remote.hostname not in SUBMIT_SUPPORTED_CLUSTERS:
pytest.xfail(f"Submit integration test not supported on cluster {remote.hostname}.")
should_cancel_job = True
job_id = await submit(
job_info = await submit(
cluster=remote.hostname,
job_script=Path("scripts/safe_job.sh"),
sbatch_args=["--time=00:00:30"],
program_args=["python", "--version"],
)
assert isinstance(job_id, int)
assert isinstance(job_info, JobInfo)
assert job_info.cluster == remote.hostname
job_id = job_info.job_id
try:
job_name = await remote.get_output(
f"sacct -j {job_id} --format=JobName --noheader --parsable2 | head -1"
Expand Down Expand Up @@ -254,6 +256,40 @@ async def test_submit(remote: Remote):
await remote.run(f"scancel {job_id}", warn=True, hide=True, display=True)


@pytest.mark.slow
@pytest.mark.timeout(300)
async def test_submit_first():
    """End-to-end: test the 'submit first' command.

    Calls `submit(cluster="first", ...)` which internally submits the job on all
    clusters with active SSH connections, waits for the first job to start, and
    cancels the rest.
    Requires at least one cluster in SUBMIT_SUPPORTED_CLUSTERS to have an active
    SSH connection.

    The winning job is always cancelled at the end of the test, including when it
    started on a cluster that is not listed in SUBMIT_SUPPORTED_CLUSTERS.
    """
    available_clusters = [
        c for c in SUBMIT_SUPPORTED_CLUSTERS if await control_socket_is_running(c)
    ]
    if not available_clusters:
        pytest.skip("None of the designated clusters for this test have an active SSH connections available.")
    job_info: JobInfo | None = None
    remotes = [await get_remote_without_2fa_prompt(c) for c in available_clusters]
    try:
        job_info = await submit(
            cluster="first",
            job_script=Path("scripts/safe_job.sh"),
            sbatch_args=["--time=00:00:30"],
            program_args=["python", "--version"],
        )
        assert isinstance(job_info, JobInfo)
    finally:
        if job_info is not None:
            # Cancel the winning job only on the cluster where it was submitted.
            remote = next((r for r in remotes if r.hostname == job_info.cluster), None)
            if remote is None:
                # "first" may pick a cluster outside SUBMIT_SUPPORTED_CLUSTERS; without
                # this fallback the job would be left running there after the test.
                remote = await get_remote_without_2fa_prompt(job_info.cluster)
            await remote.run(f"scancel {job_info.job_id}", warn=True, hide=True, display=True)
Comment thread
lebrice marked this conversation as resolved.
Outdated


@pytest.fixture
def fake_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
monkeypatch.setattr(Path, "home", lambda: tmp_path) # Set the home directory to tmp_path
Expand Down
Loading