Skip to content
5 changes: 5 additions & 0 deletions cluv/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ def add_submit_args(
formatter_class=rich_argparse.RichHelpFormatter,
usage="cluv submit <cluster> <job.sh> [sbatch-args...] [-- program-args...]",
)
submit_parser.add_argument(
"--make-commit",
action="store_true",
help="Create a local commit with tracked changes before submitting the job.",
)
submit_parser.add_argument(
"cluster",
metavar="<cluster>",
Expand Down
72 changes: 65 additions & 7 deletions cluv/cli/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import subprocess
import sys
from pathlib import Path
from typing import Callable

from cluv.cli.sync import sync
from cluv.config import ClusterConfig, find_pyproject, get_config
Expand All @@ -25,6 +26,7 @@ async def submit(
job_script: Path,
sbatch_args: list[str],
program_args: list[str],
make_commit: bool = False,
) -> int | None:
"""Submit a SLURM job on a remote cluster.

Expand All @@ -41,6 +43,7 @@ async def submit(
job_script: Path to the job script to submit, relative to the project root.
sbatch_args: List of additional flags to pass to `sbatch`.
program_args: List of arguments to pass to the job script, for example `["python", "main.py"]`.
make_commit: If True, automatically create a local commit with tracked changes before submitting.

Returns:
The job ID of the submitted job or None if the sbatch command fails.
Expand All @@ -57,7 +60,14 @@ async def submit(
```
"""
# Check git is clean locally (untracked files are fine) and capture current commit hash.
git_commit = ensure_clean_git_state()
git_commit = ensure_clean_git_state(
make_commit=make_commit,
launched_job_command_builder=(
(lambda: build_submit_command(cluster, job_script, sbatch_args, program_args))
if make_commit
else None
),
)

if cluster == "first":
return await submit_first(job_script, sbatch_args, program_args, git_commit)
Expand Down Expand Up @@ -186,17 +196,65 @@ async def submit_first(
return start_job_id


def ensure_clean_git_state() -> str:
def build_submit_command(
    cluster: str,
    job_script: Path,
    sbatch_args: list[str],
    program_args: list[str],
) -> str:
    """Reconstruct the local `cluv submit` invocation as one shell-quoted string.

    Args:
        cluster: Cluster name, placed right after `cluv submit`.
        job_script: Path of the job script; inserted in its `str()` form.
        sbatch_args: Extra `sbatch` flags, appended after the job script.
        program_args: Arguments for the job script. When non-empty they are
            appended after a `--` separator; when empty no separator is added.

    Returns:
        All tokens joined with `shlex.join`, so tokens needing it are quoted.
    """
    # The `--` separator is only emitted when something follows it.
    trailing = ["--", *program_args] if program_args else []
    return shlex.join(["cluv", "submit", cluster, str(job_script), *sbatch_args, *trailing])


def create_submit_commit(launched_job_command: str) -> None:
    """Stage tracked changes and commit them, recording the launch command in the body.

    Runs `git add -u` followed by `git commit` with a fixed subject line and a
    second message paragraph that embeds `launched_job_command`.

    Args:
        launched_job_command: Command line text written into the commit body.

    Raises:
        subprocess.CalledProcessError: If either git command exits non-zero.
            The captured stderr (falling back to stdout, then to the exception
            text) is printed to the console before the error is re-raised.
    """
    commit_subject = "cluv submit: auto-commit tracked changes"
    commit_body = f"Launched job command:\n\n{launched_job_command}"
    git_invocations = (
        ["git", "add", "-u"],
        ["git", "commit", "-m", commit_subject, "-m", commit_body],
    )
    try:
        # Staging must succeed before committing; `check=True` stops at the first failure.
        for argv in git_invocations:
            subprocess.run(argv, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as err:
        details = (err.stderr or err.stdout or str(err)).strip()
        console.print(
            "[red]Failed to create automatic submit commit before job submission:[/red] "
            f"{details}"
        )
        raise


def ensure_clean_git_state(
make_commit: bool = False, launched_job_command_builder: Callable[[], str] | None = None
) -> str:
"""
Check git is clean locally and return the current commit hash.
"""
git_status = subprocess.run(["git", "status", "--porcelain"], capture_output=True, text=True)
dirty_lines = [line for line in git_status.stdout.splitlines() if not line.startswith("??")]
if dirty_lines and not (os.environ.get("SKIP_CLEAN_GIT_CHECK", "0") == "1"):
console.print(
"[red]Working directory is dirty. Please commit your changes before submitting.[/red]",
)
sys.exit(1)
if dirty_lines:
if make_commit:
if launched_job_command_builder is None:
raise ValueError("launched_job_command_builder is required when make_commit=True")
launched_job_command = launched_job_command_builder()
create_submit_commit(launched_job_command=launched_job_command)
elif not (os.environ.get("SKIP_CLEAN_GIT_CHECK", "0") == "1"):
console.print(
"[red]Working directory is dirty. Please commit your changes before submitting.[/red]",
)
sys.exit(1)

# In GitHub Actions PR jobs we can be on a detached merge commit that doesn't exist on
# the synced remote checkout. Prefer the branch tip commit in that case.
Expand Down
104 changes: 103 additions & 1 deletion tests/test_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import subprocess
from pathlib import Path

from cluv.cli.submit import ensure_clean_git_state, get_sbatch_command, get_config
from cluv.cli.submit import build_submit_command, ensure_clean_git_state, get_sbatch_command, get_config

import pytest

Expand Down Expand Up @@ -81,7 +81,109 @@ def test_only_override_slurm_vars_with_selected_cluster_vars(self, project_dir:
)


class TestBuildSubmitCommand:
    """Checks for the command line reconstructed by `build_submit_command`."""

    def test_build_submit_command_with_program_args(self) -> None:
        """Program arguments end up after a `--` separator in the rebuilt command."""
        rebuilt = build_submit_command(
            cluster="mila",
            job_script=Path("scripts/job.sh"),
            sbatch_args=[],
            program_args=["--flag"],
        )
        assert rebuilt == "cluv submit mila scripts/job.sh -- --flag"


class TestEnsureCleanGitState:
def test_ensure_clean_git_state_exits_when_repo_dirty_without_make_commit(
    self, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A tracked modification without `make_commit` ends in `SystemExit`."""
    dirty_status = " M cluv/cli/submit.py\n"

    def fake_run(argv: list[str], **options) -> subprocess.CompletedProcess[str]:
        # Every call must capture text output; only the porcelain status query is allowed.
        assert options.get("capture_output") is True
        assert options.get("text") is True
        if argv != ["git", "status", "--porcelain"]:
            raise AssertionError(f"Unexpected subprocess.run call: {argv}")
        return subprocess.CompletedProcess(argv, 0, stdout=dirty_status, stderr="")

    monkeypatch.setattr(subprocess, "run", fake_run)

    with pytest.raises(SystemExit):
        ensure_clean_git_state()

def test_ensure_clean_git_state_creates_commit_when_make_commit_enabled(
    self, monkeypatch: pytest.MonkeyPatch
) -> None:
    """With `make_commit=True`, a dirty tree is staged and committed, then the hash is returned."""
    launched_job_command = "cluv submit mila scripts/job.sh -- --flag"
    expected_commit_body = f"Launched job command:\n\n{launched_job_command}"
    head_sha = "dddddddddddddddddddddddddddddddddddddddd"
    status_argv = ["git", "status", "--porcelain"]
    stage_argv = ["git", "add", "-u"]
    commit_argv = [
        "git",
        "commit",
        "-m",
        "cluv submit: auto-commit tracked changes",
        "-m",
        expected_commit_body,
    ]
    recorded: list[tuple[list[str], dict]] = []

    def fake_run(argv: list[str], **options) -> subprocess.CompletedProcess[str]:
        recorded.append((argv, options))
        if argv == status_argv:
            # One tracked modification plus one untracked file.
            return subprocess.CompletedProcess(
                argv, 0, stdout=" M cluv/cli/submit.py\n?? notes.txt\n", stderr=""
            )
        is_stage = argv == stage_argv
        is_commit = not is_stage and argv[:2] == ["git", "commit"]
        if not (is_stage or is_commit):
            raise AssertionError(f"Unexpected subprocess.run call: {argv}")
        # Both mutating git calls are expected to be checked and to capture text output.
        assert options.get("check") is True
        assert options.get("capture_output") is True
        assert options.get("text") is True
        if is_commit:
            assert argv[2:4] == ["-m", "cluv submit: auto-commit tracked changes"]
            assert argv[4] == "-m"
            assert argv[5] == expected_commit_body
        return subprocess.CompletedProcess(argv, 0, stdout="", stderr="")

    def fake_check_output(argv: list[str], **options) -> str:
        assert options.get("text") is True
        if argv == ["git", "rev-parse", "--abbrev-ref", "HEAD"]:
            return "main\n"
        if argv == ["git", "rev-parse", "HEAD"]:
            return "dddddddddddddddddddddddddddddddddddddddd\n"
        raise AssertionError(f"Unexpected subprocess.check_output call: {argv}")

    monkeypatch.setattr(subprocess, "run", fake_run)
    monkeypatch.setattr(subprocess, "check_output", fake_check_output)

    returned_sha = ensure_clean_git_state(
        make_commit=True,
        launched_job_command_builder=lambda: launched_job_command,
    )
    assert returned_sha == head_sha

    # The first three subprocess.run calls must be status, stage, commit, in that order.
    first_three = [argv for argv, _ in recorded[:3]]
    assert first_three == [status_argv, stage_argv, commit_argv]

def test_ensure_clean_git_state_raises_when_make_commit_without_builder(
    self, monkeypatch: pytest.MonkeyPatch
) -> None:
    """`make_commit=True` on a dirty tree without a command builder raises `ValueError`."""
    dirty_status = " M cluv/cli/submit.py\n"

    def fake_run(argv: list[str], **options) -> subprocess.CompletedProcess[str]:
        # Every call must capture text output; only the porcelain status query is allowed.
        assert options.get("capture_output") is True
        assert options.get("text") is True
        if argv != ["git", "status", "--porcelain"]:
            raise AssertionError(f"Unexpected subprocess.run call: {argv}")
        return subprocess.CompletedProcess(argv, 0, stdout=dirty_status, stderr="")

    monkeypatch.setattr(subprocess, "run", fake_run)

    with pytest.raises(ValueError, match="launched_job_command_builder is required"):
        ensure_clean_git_state(make_commit=True)

def test_prefers_branch_tip_in_github_actions_detached_head(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
Expand Down