Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 73 additions & 18 deletions cluv/cli/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@
import shlex
import subprocess
import sys
from collections.abc import Iterable, Sequence
from pathlib import Path

from rich import box
from rich.table import Table
from rich.text import Text

from cluv.cli.sync import sync
from cluv.config import ClusterConfig, find_pyproject, get_config
from cluv.remote import Remote
Expand Down Expand Up @@ -83,6 +88,53 @@ async def submit(
return job_id


def _build_commands_table(cluster_to_command: dict[str, str]) -> Table:
    """Render the sbatch command planned for each cluster as a rich Table.

    Every ``cluster -> command`` entry of *cluster_to_command* becomes one
    row, in the mapping's iteration order.
    """
    commands_table = Table(
        title="[bold cyan]sbatch Commands[/bold cyan]",
        box=box.ROUNDED,
        show_lines=True,
        header_style="bold white",
    )
    column_specs = (
        ("Cluster", {"style": "bold magenta", "min_width": 12}),
        # Commands are long one-liners, hence the "fold" overflow setting.
        ("Command", {"style": "green", "overflow": "fold"}),
    )
    for header, options in column_specs:
        commands_table.add_column(header, **options)
    for cluster_name, sbatch_command in cluster_to_command.items():
        commands_table.add_row(cluster_name, sbatch_command)
    return commands_table


def _build_submission_table(
    cluster_names: Iterable[str],
    sbatch_results: Sequence[subprocess.CompletedProcess[str] | BaseException],
    cluster_to_jobid: dict[str, int],
) -> Table:
    """Build a rich Table summarising sbatch submission results.

    Populates *cluster_to_jobid* as a side-effect for callers that need it.

    Args:
        cluster_names: Cluster names, in the same order as *sbatch_results*.
        sbatch_results: One entry per cluster: either the completed sbatch
            process or the exception raised while trying to run it.
        cluster_to_jobid: Mutated in place. Receives ``cluster -> job id`` for
            every submission that succeeded and whose job id could be parsed.

    Returns:
        A table with one row per cluster: the job id in green on success, or
        the error text in red otherwise.

    Raises:
        ValueError: If *cluster_names* and *sbatch_results* differ in length
            (``zip(..., strict=True)``).
    """
    table = Table(
        title="[bold cyan]Job Submission Results[/bold cyan]",
        box=box.ROUNDED,
        show_lines=True,
        header_style="bold white",
    )
    table.add_column("Cluster", style="bold magenta", min_width=12)
    table.add_column("Job ID / Status", min_width=30)

    for cluster, result in zip(cluster_names, sbatch_results, strict=True):
        if isinstance(result, BaseException):
            status_cell = Text(f"error: {result}", style="red")
        elif result.returncode == 0:
            raw_output = result.stdout.strip()
            # `sbatch --parsable` prints "<jobid>" or "<jobid>;<cluster>";
            # keep only the job id part before converting.
            raw_job_id = raw_output.partition(";")[0].strip()
            try:
                job_id = int(raw_job_id)
            except ValueError:
                # Report the problem in the table instead of aborting the
                # whole summary: jobs on other clusters were already submitted
                # and the caller still needs their ids.
                status_cell = Text(
                    f"error: could not parse job id from sbatch output {raw_output!r}",
                    style="red",
                )
            else:
                cluster_to_jobid[cluster] = job_id
                status_cell = Text(str(job_id), style="green")
        else:
            # Fall back to the exit code so a failure never shows an empty cell.
            error_text = result.stderr.strip() or f"sbatch exited with code {result.returncode}"
            status_cell = Text(error_text, style="red")
        table.add_row(cluster, status_cell)

    return table


async def submit_first(
job_script: Path,
sbatch_args: list[str],
Expand All @@ -96,6 +148,15 @@ async def submit_first(
remotes = await sync()
clusters_to_remote = {remote.hostname: remote for remote in remotes}

# Pre-compute and display the sbatch command for each cluster.
cluster_to_command = {
remote.hostname: get_sbatch_command(
remote.hostname, job_script, sbatch_args, program_args, git_commit
)
for remote in remotes
}
console.print(_build_commands_table(cluster_to_command))

# Submit the job on all the clusters
sbatch_results = await asyncio.gather(
*[
Expand All @@ -105,6 +166,7 @@ async def submit_first(
sbatch_args,
program_args,
git_commit,
display=False,
)
for remote in remotes
],
Expand All @@ -113,22 +175,9 @@ async def submit_first(

# Get the results of the sbatch command. We expect an int (the job id) or the exception
# if the command failed on the remote cluster.
console.print("Jobs submitted on the clusters:")
cluster_to_jobid: dict[str, int] = {}
for cluster, result in zip(clusters_to_remote.keys(), sbatch_results):
if isinstance(result, BaseException):
console.print(
f" - [bold]{cluster}[/bold]: error when trying to use remote, [red]{result}[/red]"
)
else:
if result.returncode == 0:
job_id = int(result.stdout.strip())
cluster_to_jobid[cluster] = job_id
console.print(f" - [bold]{cluster}[/bold]: job {job_id}")
else:
console.print(
f" - [bold]{cluster}[/bold]: no job, [red]{result.stderr.strip()}[/red]"
)
submission_table = _build_submission_table(clusters_to_remote.keys(), sbatch_results, cluster_to_jobid)
console.print(submission_table)

if len(cluster_to_jobid) == 0:
console.print("No job submitted on clusters. See errors above.")
Expand Down Expand Up @@ -175,14 +224,18 @@ async def submit_first(
await asyncio.sleep(wait_time)
wait_time = min(wait_time*2, 20)
console.log(
f"Job {start_job_id} on cluster {start_cluster} is running. Cancelling the other jobs...\n",
f"Use `ssh {start_cluster} sacct -j {start_job_id}` to view its status.",
f"Job {start_job_id} on cluster {start_cluster} is running. Cancelling the other jobs..."
)
except (KeyboardInterrupt, asyncio.CancelledError):
console.log("Interrupted by user. Cancelling all jobs...")
finally:
await cancel_all_jobs(clusters_to_remote, cluster_to_jobid, start_cluster)

if start_cluster is not None and start_job_id is not None:
console.print(
f"\nTo watch the job: [bold]ssh {start_cluster} sacct -j {start_job_id}[/bold]"
)

return start_job_id


Expand Down Expand Up @@ -262,14 +315,16 @@ async def sbatch(
sbatch_args: list[str],
program_args: list[str],
git_commit: str,
*,
display: bool = True,
) -> subprocess.CompletedProcess[str]:
"""Submit the job via sbatch on the remote cluster, and return the job id."""
cluster = remote.hostname

remote_cmd = get_sbatch_command(
cluster, job_script, sbatch_args, program_args, git_commit
)
return await remote.run(remote_cmd, display=True, warn=True, hide=True)
return await remote.run(remote_cmd, display=display, warn=True, hide=True)


async def get_job_status(remote: Remote, job_id: int) -> str:
Expand Down
77 changes: 76 additions & 1 deletion tests/test_submit.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import textwrap
import subprocess
from io import StringIO
from pathlib import Path

from cluv.cli.submit import ensure_clean_git_state, get_sbatch_command, get_config
from rich.console import Console

from cluv.cli.submit import _build_commands_table, _build_submission_table, ensure_clean_git_state, get_sbatch_command, get_config

import pytest

Expand Down Expand Up @@ -139,3 +142,75 @@ def mock_subprocess_check_output(command: list[str], **kwargs) -> str:
monkeypatch.setattr(subprocess, "check_output", mock_subprocess_check_output)

assert ensure_clean_git_state() == "cccccccccccccccccccccccccccccccccccccccc"


class TestBuildSubmissionTable:
    """Tests for `_build_submission_table`: job-id bookkeeping and rendered cells."""

    def _make_ok(self, job_id: int) -> subprocess.CompletedProcess[str]:
        """Return a fake successful sbatch result whose stdout is the job id."""
        return subprocess.CompletedProcess(
            args=[], returncode=0, stdout=f"{job_id}\n", stderr=""
        )

    def _make_err(self, msg: str) -> subprocess.CompletedProcess[str]:
        """Return a fake failed sbatch result carrying *msg* on stderr."""
        return subprocess.CompletedProcess(
            args=[], returncode=1, stdout="", stderr=msg
        )

    def test_successful_submissions_populate_cluster_to_jobid(self) -> None:
        job_ids: dict[str, int] = {}
        results = [self._make_ok(12345), self._make_ok(67890)]

        table = _build_submission_table(["mila", "narval"], results, job_ids)

        assert job_ids == {"mila": 12345, "narval": 67890}
        # One data row per cluster.
        assert table.row_count == 2

    def test_failed_submission_not_added_to_cluster_to_jobid(self) -> None:
        job_ids: dict[str, int] = {}
        results = [self._make_ok(42), self._make_err("out of memory")]

        _build_submission_table(["mila", "narval"], results, job_ids)

        assert "narval" not in job_ids
        assert job_ids == {"mila": 42}

    def test_exception_result_not_added_to_cluster_to_jobid(self) -> None:
        job_ids: dict[str, int] = {}

        _build_submission_table(
            ["mila"], [RuntimeError("connection refused")], job_ids
        )

        assert job_ids == {}

    def test_table_cells_contain_expected_text(self) -> None:
        job_ids: dict[str, int] = {}
        # One result of each kind: success, non-zero exit, raised exception.
        results = [
            self._make_ok(99),
            self._make_err("sbatch: error: ..."),
            RuntimeError("timeout"),
        ]
        table = _build_submission_table(
            ["mila", "narval", "rorqual"], results, job_ids
        )

        # Render without colour or highlighting so plain substrings can be matched.
        output = StringIO()
        Console(file=output, no_color=True, highlight=False).print(table)
        rendered = output.getvalue()

        for expected in ("99", "sbatch: error:", "timeout"):
            assert expected in rendered


class TestBuildCommandsTable:
    """Tests for `_build_commands_table`."""

    def test_renders_all_clusters_and_commands(self) -> None:
        # Both clusters get the same command string.
        sbatch_command = (
            "bash --login -c 'GIT_COMMIT=abc sbatch --parsable --chdir=proj job.sh'"
        )
        per_cluster = dict.fromkeys(("mila", "narval"), sbatch_command)

        table = _build_commands_table(per_cluster)

        # Render without colour or highlighting so plain substrings can be matched.
        output = StringIO()
        Console(file=output, no_color=True, highlight=False).print(table)
        rendered = output.getvalue()

        for expected in ("mila", "narval", "GIT_COMMIT=abc"):
            assert expected in rendered
        assert table.row_count == 2