diff --git a/cluv/__main__.py b/cluv/__main__.py index 1c71424..af6aa1b 100644 --- a/cluv/__main__.py +++ b/cluv/__main__.py @@ -112,6 +112,11 @@ def add_submit_args( formatter_class=rich_argparse.RichHelpFormatter, usage="cluv submit [sbatch-args...] [-- program-args...]", ) + submit_parser.add_argument( + "--make-commit", + action="store_true", + help="Create a local commit with tracked changes before submitting the job.", + ) submit_parser.add_argument( "cluster", metavar="", diff --git a/cluv/cli/submit.py b/cluv/cli/submit.py index 53d5da1..65c191f 100644 --- a/cluv/cli/submit.py +++ b/cluv/cli/submit.py @@ -6,6 +6,7 @@ import subprocess import sys from pathlib import Path +from typing import Callable from cluv.cli.sync import sync from cluv.config import ClusterConfig, find_pyproject, get_config @@ -25,6 +26,7 @@ async def submit( job_script: Path, sbatch_args: list[str], program_args: list[str], + make_commit: bool = False, ) -> int | None: """Submit a SLURM job on a remote cluster. @@ -41,6 +43,7 @@ async def submit( job_script: Path to the job script to submit, relative to the project root. sbatch_args: List of additional flags to pass to `sbatch`. program_args: List of arguments to pass to the job script, for example `["python", "main.py"]`. + make_commit: If True, automatically create a local commit with tracked changes before submitting. Returns: The job ID of the submitted job or None if the sbatch command fails. @@ -57,7 +60,14 @@ async def submit( ``` """ # Check git is clean locally (untracked files are fine) and capture current commit hash. 
def build_submit_command(
    cluster: str,
    job_script: Path,
    sbatch_args: list[str],
    program_args: list[str],
) -> str:
    """Build the local `cluv submit` command line used to launch the job.

    The parts are laid out as
    `cluv submit <cluster> <job_script> [sbatch-args...] [-- program-args...]`
    and joined with `shlex.join`, so every argument is shell-quoted when needed.

    Args:
        cluster: Cluster name, emitted right after `cluv submit`.
        job_script: Path to the job script; rendered with `str()`.
        sbatch_args: Additional `sbatch` flags, emitted after the job script in order.
        program_args: Arguments for the job script. When non-empty they follow a
            `--` separator; when empty, no separator is emitted.

    Returns:
        The command line as a single shell-escaped string.
    """
    command_parts = ["cluv", "submit", cluster, str(job_script), *sbatch_args]
    if program_args:
        command_parts += ["--", *program_args]
    return shlex.join(command_parts)


def _git_failure_text(err: subprocess.CalledProcessError) -> str:
    """Return the first non-blank message carried by a failed git invocation."""
    # Strip each stream *before* testing it: a whitespace-only stderr must not
    # hide the real explanation that git wrote to stdout.
    for stream in (err.stderr, err.stdout):
        text = (stream or "").strip()
        if text:
            return text
    return str(err).strip()


def create_submit_commit(launched_job_command: str) -> None:
    """Create a commit with tracked changes and include the launched job command in the body.

    Runs `git add -u` and then `git commit` with two `-m` options: a fixed
    subject line and a body recording `launched_job_command`. Both commands run
    in the current working directory with their output captured.

    Args:
        launched_job_command: Command line to record in the commit message body.

    Raises:
        subprocess.CalledProcessError: If either git command exits with a
            non-zero status. The failure reason is printed to the console first
            and the original exception is re-raised unchanged.
    """
    commit_command = [
        "git",
        "commit",
        "-m",
        "cluv submit: auto-commit tracked changes",
        "-m",
        f"Launched job command:\n\n{launched_job_command}",
    ]
    try:
        subprocess.run(["git", "add", "-u"], check=True, capture_output=True, text=True)
        subprocess.run(commit_command, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as err:
        console.print(
            "[red]Failed to create automatic submit commit before job submission:[/red] "
            f"{_git_failure_text(err)}"
        )
        raise
""" git_status = subprocess.run(["git", "status", "--porcelain"], capture_output=True, text=True) dirty_lines = [line for line in git_status.stdout.splitlines() if not line.startswith("??")] - if dirty_lines and not (os.environ.get("SKIP_CLEAN_GIT_CHECK", "0") == "1"): - console.print( - "[red]Working directory is dirty. Please commit your changes before submitting.[/red]", - ) - sys.exit(1) + if dirty_lines: + if make_commit: + if launched_job_command_builder is None: + raise ValueError("launched_job_command_builder is required when make_commit=True") + launched_job_command = launched_job_command_builder() + create_submit_commit(launched_job_command=launched_job_command) + elif not (os.environ.get("SKIP_CLEAN_GIT_CHECK", "0") == "1"): + console.print( + "[red]Working directory is dirty. Please commit your changes before submitting.[/red]", + ) + sys.exit(1) # In GitHub Actions PR jobs we can be on a detached merge commit that doesn't exist on # the synced remote checkout. Prefer the branch tip commit in that case. 
def _fake_run_reporting_tracked_change(
    command: list[str], **kwargs
) -> subprocess.CompletedProcess[str]:
    """Stand-in for `subprocess.run` that only answers `git status --porcelain`.

    Reports a single modified tracked file and fails the test on any other
    command, so a test using it proves that no further git call was made.
    """
    assert kwargs.get("capture_output") is True
    assert kwargs.get("text") is True
    if command == ["git", "status", "--porcelain"]:
        return subprocess.CompletedProcess(command, 0, stdout=" M cluv/cli/submit.py\n", stderr="")
    raise AssertionError(f"Unexpected subprocess.run call: {command}")


class TestBuildSubmitCommand:
    """Pin the exact command line rendered by `build_submit_command`."""

    def test_build_submit_command_with_program_args(self) -> None:
        assert (
            build_submit_command(
                cluster="mila",
                job_script=Path("scripts/job.sh"),
                sbatch_args=[],
                program_args=["--flag"],
            )
            == "cluv submit mila scripts/job.sh -- --flag"
        )

    def test_build_submit_command_with_sbatch_args_only(self) -> None:
        # Without program args there must be no dangling `--` separator.
        assert (
            build_submit_command(
                cluster="mila",
                job_script=Path("scripts/job.sh"),
                sbatch_args=["--gres=gpu:1"],
                program_args=[],
            )
            == "cluv submit mila scripts/job.sh --gres=gpu:1"
        )

    def test_build_submit_command_quotes_arguments_with_spaces(self) -> None:
        # The command is recorded in a commit message, so it must stay copy-pasteable.
        assert (
            build_submit_command(
                cluster="mila",
                job_script=Path("scripts/job.sh"),
                sbatch_args=[],
                program_args=["python", "main.py", "my run"],
            )
            == "cluv submit mila scripts/job.sh -- python main.py 'my run'"
        )


class TestEnsureCleanGitState:
    def test_ensure_clean_git_state_exits_when_repo_dirty_without_make_commit(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        # SKIP_CLEAN_GIT_CHECK=1 in the ambient environment would bypass the exit
        # under test, so make sure the variable is absent for this test.
        monkeypatch.delenv("SKIP_CLEAN_GIT_CHECK", raising=False)
        monkeypatch.setattr(subprocess, "run", _fake_run_reporting_tracked_change)

        with pytest.raises(SystemExit):
            ensure_clean_git_state()

    def test_ensure_clean_git_state_creates_commit_when_make_commit_enabled(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        launched_job_command = "cluv submit mila scripts/job.sh -- --flag"
        expected_commit_body = f"Launched job command:\n\n{launched_job_command}"
        expected_commit_command = [
            "git",
            "commit",
            "-m",
            "cluv submit: auto-commit tracked changes",
            "-m",
            expected_commit_body,
        ]
        command_calls: list[tuple[list[str], dict]] = []

        def mock_subprocess_run(command: list[str], **kwargs) -> subprocess.CompletedProcess[str]:
            command_calls.append((command, kwargs))
            if command == ["git", "status", "--porcelain"]:
                # One tracked modification plus one untracked file.
                return subprocess.CompletedProcess(
                    command, 0, stdout=" M cluv/cli/submit.py\n?? notes.txt\n", stderr=""
                )
            if command == ["git", "add", "-u"] or command[:2] == ["git", "commit"]:
                assert kwargs.get("check") is True
                assert kwargs.get("capture_output") is True
                assert kwargs.get("text") is True
                if command[:2] == ["git", "commit"]:
                    assert command[2:4] == ["-m", "cluv submit: auto-commit tracked changes"]
                    assert command[4] == "-m"
                    assert command[5] == expected_commit_body
                return subprocess.CompletedProcess(command, 0, stdout="", stderr="")
            raise AssertionError(f"Unexpected subprocess.run call: {command}")

        def mock_subprocess_check_output(command: list[str], **kwargs) -> str:
            assert kwargs.get("text") is True
            if command == ["git", "rev-parse", "--abbrev-ref", "HEAD"]:
                return "main\n"
            if command == ["git", "rev-parse", "HEAD"]:
                return "dddddddddddddddddddddddddddddddddddddddd\n"
            raise AssertionError(f"Unexpected subprocess.check_output call: {command}")

        monkeypatch.setattr(subprocess, "run", mock_subprocess_run)
        monkeypatch.setattr(subprocess, "check_output", mock_subprocess_check_output)

        assert (
            ensure_clean_git_state(
                make_commit=True,
                launched_job_command_builder=lambda: launched_job_command,
            )
            == "dddddddddddddddddddddddddddddddddddddddd"
        )
        # Status check first, then staging, then the commit itself.
        assert [call[0] for call in command_calls[:3]] == [
            ["git", "status", "--porcelain"],
            ["git", "add", "-u"],
            expected_commit_command,
        ]

    def test_ensure_clean_git_state_raises_when_make_commit_without_builder(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        # The shared fake rejects every command except `git status`, which also
        # proves that no commit is attempted before the ValueError is raised.
        monkeypatch.setattr(subprocess, "run", _fake_run_reporting_tracked_change)

        with pytest.raises(ValueError, match="launched_job_command_builder is required"):
            ensure_clean_git_state(make_commit=True)