diff --git a/cluv/cli/submit.py b/cluv/cli/submit.py index 53d5da1..779211c 100644 --- a/cluv/cli/submit.py +++ b/cluv/cli/submit.py @@ -5,8 +5,13 @@ import shlex import subprocess import sys +from collections.abc import Iterable, Sequence from pathlib import Path +from rich import box +from rich.table import Table +from rich.text import Text + from cluv.cli.sync import sync from cluv.config import ClusterConfig, find_pyproject, get_config from cluv.remote import Remote @@ -83,6 +88,53 @@ async def submit( return job_id +def _build_commands_table(cluster_to_command: dict[str, str]) -> Table: + """Build a rich Table showing the sbatch command that will be run on each cluster.""" + table = Table( + title="[bold cyan]sbatch Commands[/bold cyan]", + box=box.ROUNDED, + show_lines=True, + header_style="bold white", + ) + table.add_column("Cluster", style="bold magenta", min_width=12) + table.add_column("Command", style="green", overflow="fold") + for cluster, command in cluster_to_command.items(): + table.add_row(cluster, command) + return table + + +def _build_submission_table( + cluster_names: Iterable[str], + sbatch_results: Sequence[subprocess.CompletedProcess[str] | BaseException], + cluster_to_jobid: dict[str, int], +) -> Table: + """Build a rich Table summarising sbatch submission results. + + Populates *cluster_to_jobid* as a side-effect for callers that need it. + """ + table = Table( + title="[bold cyan]Job Submission Results[/bold cyan]", + box=box.ROUNDED, + show_lines=True, + header_style="bold white", + ) + table.add_column("Cluster", style="bold magenta", min_width=12) + table.add_column("Job ID / Status", min_width=30) + + for cluster, result in zip(cluster_names, sbatch_results, strict=True): + if isinstance(result, BaseException): + status_cell = Text(f"error: {result}", style="red") + elif result.returncode == 0: + job_id = int(result.stdout.strip()) + cluster_to_jobid[cluster] = job_id + status_cell = Text(str(job_id), style="green") + else: + status_cell = Text(result.stderr.strip(), style="red") + table.add_row(cluster, status_cell) + + return table + + async def submit_first( job_script: Path, sbatch_args: list[str], @@ -96,6 +148,15 @@ async def submit_first( remotes = await sync() clusters_to_remote = {remote.hostname: remote for remote in remotes} + # Pre-compute and display the sbatch command for each cluster. + cluster_to_command = { + remote.hostname: get_sbatch_command( + remote.hostname, job_script, sbatch_args, program_args, git_commit + ) + for remote in remotes + } + console.print(_build_commands_table(cluster_to_command)) + # Submit the job on all the clusters sbatch_results = await asyncio.gather( *[ @@ -105,6 +166,7 @@ async def submit_first( sbatch_args, program_args, git_commit, + display=False, ) for remote in remotes ], @@ -113,22 +175,9 @@ async def submit_first( # Get the results of the sbatch command. We expect an int (the job id) or the exception # if the command failed on the remote cluster. - console.print("Jobs submitted on the clusters:") cluster_to_jobid: dict[str, int] = {} - for cluster, result in zip(clusters_to_remote.keys(), sbatch_results): - if isinstance(result, BaseException): - console.print( - f" - [bold]{cluster}[/bold]: error when trying to use remote, [red]{result}[/red]" - ) - else: - if result.returncode == 0: - job_id = int(result.stdout.strip()) - cluster_to_jobid[cluster] = job_id - console.print(f" - [bold]{cluster}[/bold]: job {job_id}") - else: - console.print( - f" - [bold]{cluster}[/bold]: no job, [red]{result.stderr.strip()}[/red]" - ) + submission_table = _build_submission_table(clusters_to_remote.keys(), sbatch_results, cluster_to_jobid) + console.print(submission_table) if len(cluster_to_jobid) == 0: console.print("No job submitted on clusters. See errors above.") @@ -175,14 +224,18 @@ async def submit_first( await asyncio.sleep(wait_time) wait_time = min(wait_time*2, 20) console.log( - f"Job {start_job_id} on cluster {start_cluster} is running. Cancelling the other jobs...\n", - f"Use `ssh {start_cluster} sacct -j {start_job_id}` to view its status.", + f"Job {start_job_id} on cluster {start_cluster} is running. Cancelling the other jobs..." ) except (KeyboardInterrupt, asyncio.CancelledError): console.log("Interrupted by user. Cancelling all jobs...") finally: await cancel_all_jobs(clusters_to_remote, cluster_to_jobid, start_cluster) + if start_cluster is not None and start_job_id is not None: + console.print( + f"\nTo watch the job: [bold]ssh {start_cluster} sacct -j {start_job_id}[/bold]" + ) + return start_job_id @@ -262,6 +315,8 @@ async def sbatch( sbatch_args: list[str], program_args: list[str], git_commit: str, + *, + display: bool = True, ) -> subprocess.CompletedProcess[str]: """Submit the job via sbatch on the remote cluster, and return the job id.""" cluster = remote.hostname @@ -269,7 +324,7 @@ async def sbatch( remote_cmd = get_sbatch_command( cluster, job_script, sbatch_args, program_args, git_commit ) - return await remote.run(remote_cmd, display=True, warn=True, hide=True) + return await remote.run(remote_cmd, display=display, warn=True, hide=True) async def get_job_status(remote: Remote, job_id: int) -> str: diff --git a/tests/test_submit.py b/tests/test_submit.py index 8706ae4..a075083 100644 --- a/tests/test_submit.py +++ b/tests/test_submit.py @@ -1,8 +1,11 @@ import textwrap import subprocess +from io import StringIO from pathlib import Path -from cluv.cli.submit import ensure_clean_git_state, get_sbatch_command, get_config +from rich.console import Console + +from cluv.cli.submit import _build_commands_table, _build_submission_table, ensure_clean_git_state, get_sbatch_command, get_config import pytest @@ -139,3 +142,75 @@ def mock_subprocess_check_output(command: list[str], **kwargs) -> str: monkeypatch.setattr(subprocess, "check_output", mock_subprocess_check_output) assert ensure_clean_git_state() == "cccccccccccccccccccccccccccccccccccccccc" + + +class TestBuildSubmissionTable: + def _make_ok(self, job_id: int) -> subprocess.CompletedProcess[str]: + return subprocess.CompletedProcess([], 0, stdout=f"{job_id}\n", stderr="") + + def _make_err(self, msg: str) -> subprocess.CompletedProcess[str]: + return subprocess.CompletedProcess([], 1, stdout="", stderr=msg) + + def test_successful_submissions_populate_cluster_to_jobid(self) -> None: + cluster_to_jobid: dict[str, int] = {} + table = _build_submission_table( + ["mila", "narval"], + [self._make_ok(12345), self._make_ok(67890)], + cluster_to_jobid, + ) + assert cluster_to_jobid == {"mila": 12345, "narval": 67890} + # Two data rows expected + assert table.row_count == 2 + + def test_failed_submission_not_added_to_cluster_to_jobid(self) -> None: + cluster_to_jobid: dict[str, int] = {} + _build_submission_table( + ["mila", "narval"], + [self._make_ok(42), self._make_err("out of memory")], + cluster_to_jobid, + ) + assert "narval" not in cluster_to_jobid + assert cluster_to_jobid == {"mila": 42} + + def test_exception_result_not_added_to_cluster_to_jobid(self) -> None: + cluster_to_jobid: dict[str, int] = {} + _build_submission_table( + ["mila"], + [RuntimeError("connection refused")], + cluster_to_jobid, + ) + assert cluster_to_jobid == {} + + def test_table_cells_contain_expected_text(self) -> None: + cluster_to_jobid: dict[str, int] = {} + table = _build_submission_table( + ["mila", "narval", "rorqual"], + [ + self._make_ok(99), + self._make_err("sbatch: error: ..."), + RuntimeError("timeout"), + ], + cluster_to_jobid, + ) + buf = StringIO() + Console(file=buf, no_color=True, highlight=False).print(table) + rendered = buf.getvalue() + assert "99" in rendered + assert "sbatch: error:" in rendered + assert "timeout" in rendered + + +class TestBuildCommandsTable: + def test_renders_all_clusters_and_commands(self) -> None: + commands = { + "mila": "bash --login -c 'GIT_COMMIT=abc sbatch --parsable --chdir=proj job.sh'", + "narval": "bash --login -c 'GIT_COMMIT=abc sbatch --parsable --chdir=proj job.sh'", + } + table = _build_commands_table(commands) + buf = StringIO() + Console(file=buf, no_color=True, highlight=False).print(table) + rendered = buf.getvalue() + assert "mila" in rendered + assert "narval" in rendered + assert "GIT_COMMIT=abc" in rendered + assert table.row_count == 2