Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 60 additions & 5 deletions cluv/cli/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,60 @@
from cluv.utils import console


RUNNING_JOB_STATES = ["PENDING", "RUNNING"]
# SLURM states a job cannot transition out of. Anything else (PENDING, RUNNING,
# COMPLETING, SUSPENDED, REQUEUED, RESIZING, STAGE_OUT, unknown future states)
# is treated as transient so wait-loops default to "keep polling" on unknowns.
TERMINAL_JOB_STATES = [
"COMPLETED",
"FAILED",
"CANCELLED",
"TIMEOUT",
"NODE_FAIL",
"OUT_OF_MEMORY",
"PREEMPTED",
"BOOT_FAIL",
"DEADLINE",
"REVOKED",
"SPECIAL_EXIT",
]
FAILED_JOB_STATES = ["FAILED", "CANCELLED", "TIMEOUT", "NODE_FAIL", "OUT_OF_MEMORY", "PREEMPTED"]

# SBATCH_* env vars that have an equivalent `sbatch` CLI flag. We translate
# these to flags before invoking sbatch because some clusters (notably DRAC)
# re-source their site profile inside `bash --login -c`, which can clobber
# SBATCH_* env defaults before sbatch reads them; CLI flags are parsed by
# sbatch directly and survive the login shell. Any SBATCH_* key not in this
# table falls through as a plain env var (preserving existing behavior).
SBATCH_ENV_TO_FLAG: dict[str, str] = {
"SBATCH_ACCOUNT": "--account",
"SBATCH_CONSTRAINT": "--constraint",
"SBATCH_CPUS_PER_TASK": "--cpus-per-task",
"SBATCH_ERROR": "--error",
"SBATCH_GRES": "--gres",
"SBATCH_JOB_NAME": "--job-name",
"SBATCH_MEM": "--mem",
"SBATCH_NODES": "--nodes",
"SBATCH_NTASKS": "--ntasks",
"SBATCH_OUTPUT": "--output",
"SBATCH_PARTITION": "--partition",
"SBATCH_QOS": "--qos",
"SBATCH_RESERVATION": "--reservation",
"SBATCH_TIME": "--time",
}


def _split_env_for_sbatch(env_vars: dict[str, str]) -> tuple[list[str], dict[str, str]]:
"""Translate known SBATCH_* env vars into `sbatch` CLI flags; pass the rest through."""
sbatch_flags: list[str] = []
remaining: dict[str, str] = {}
for key, value in env_vars.items():
flag = SBATCH_ENV_TO_FLAG.get(key)
if flag is None:
remaining[key] = value
else:
sbatch_flags.append(f"{flag}={shlex.quote(str(value))}")
return sbatch_flags, remaining


__all__ = ["submit"]

Expand Down Expand Up @@ -149,7 +200,7 @@ async def submit_first(
continue
job_status = await get_job_status(remote, job_id)

if job_status in RUNNING_JOB_STATES:
if job_status and job_status not in TERMINAL_JOB_STATES:
start_cluster = cluster
start_job_id = job_id
break
Expand Down Expand Up @@ -246,13 +297,15 @@ def get_sbatch_command(
env_vars["SBATCH_JOB_NAME"] = f"cluv-{base_name}"
env_vars["GIT_COMMIT"] = git_commit

env_vars_prefix = " ".join(f"{k}={shlex.quote(str(v))}" for k, v in env_vars.items())
sbatch_flags, env_remaining = _split_env_for_sbatch(env_vars)
env_vars_prefix = " ".join(f"{k}={shlex.quote(str(v))}" for k, v in env_remaining.items())
sbatch_flags_str = " ".join(sbatch_flags)
sbatch_args_str = " ".join(shlex.quote(f) for f in sbatch_args)
program_args_str = shlex.join(program_args)

return (
f"bash --login -c '{env_vars_prefix} sbatch --parsable --chdir={project_path} "
f"{sbatch_args_str} {remote_job_script} {program_args_str}'"
f"{sbatch_flags_str} {sbatch_args_str} {remote_job_script} {program_args_str}'"
)


Expand All @@ -274,7 +327,9 @@ async def sbatch(

async def get_job_status(remote: Remote, job_id: int) -> str:
"""Get the status of the job with the given id on the remote cluster."""
sacct_command = f"sacct -j {job_id} --format=State --noheader --allocations"
# --parsable2 prevents sacct from truncating wider state names to 10 chars
# (e.g. "OUT_OF_ME+" for OUT_OF_MEMORY); we want the full canonical string.
sacct_command = f"sacct -j {job_id} --format=State --noheader --allocations --parsable2"
return await remote.get_output(sacct_command)


Expand Down
86 changes: 83 additions & 3 deletions tests/test_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
import subprocess
from pathlib import Path

from cluv.cli.submit import ensure_clean_git_state, get_sbatch_command, get_config
from cluv.cli.submit import (
TERMINAL_JOB_STATES,
ensure_clean_git_state,
get_config,
get_sbatch_command,
)

import pytest

Expand Down Expand Up @@ -47,7 +52,7 @@ def test_generate_command_for_selected_cluster_with_correct_args_and_vars(self,

assert (
sbatch_command
== "bash --login -c 'MY_VAR=1 SPECIAL_MILA_VAR=xyz SBATCH_JOB_NAME=cluv-my_script GIT_COMMIT=abecdef sbatch --parsable --chdir=my_project --account=my_account --mem=8G ~/my_project/scripts/my_script.sh program_arg_1 program_arg_2'"
== "bash --login -c 'MY_VAR=1 SPECIAL_MILA_VAR=xyz GIT_COMMIT=abecdef sbatch --parsable --chdir=my_project --job-name=cluv-my_script --account=my_account --mem=8G ~/my_project/scripts/my_script.sh program_arg_1 program_arg_2'"
)

def test_only_override_slurm_vars_with_selected_cluster_vars(self, project_dir: Path) -> None:
Expand Down Expand Up @@ -77,9 +82,84 @@ def test_only_override_slurm_vars_with_selected_cluster_vars(self, project_dir:

assert (
sbatch_command
== "bash --login -c 'MY_VAR=2 SBATCH_JOB_NAME=cluv-my_script GIT_COMMIT=abecdef sbatch --parsable --chdir=my_project ~/my_project/scripts/my_script.sh '"
== "bash --login -c 'MY_VAR=2 GIT_COMMIT=abecdef sbatch --parsable --chdir=my_project --job-name=cluv-my_script ~/my_project/scripts/my_script.sh '"
)

def test_sbatch_env_vars_are_translated_to_cli_flags(self, project_dir: Path) -> None:
"""SBATCH_* env vars must reach sbatch as CLI flags, not as env vars.

DRAC clusters re-source their site profile inside `bash --login -c`,
which clobbers SBATCH_* defaults before sbatch reads them. Flags are
parsed by sbatch directly and survive the login shell.
"""
p = project_dir / "pyproject.toml"
p.write_text(
textwrap.dedent(
"""\
[tool.cluv]
results_path = "results"
[tool.cluv.env]
SBATCH_MEM = "2G"
SBATCH_TIME = "00:05:00"
SBATCH_CPUS_PER_TASK = "1"
[tool.cluv.clusters.mila.env]
SBATCH_ACCOUNT = "rrg-foo"
"""
)
)

sbatch_command = get_sbatch_command(
cluster="mila",
job_script=Path("scripts/my_script.sh"),
sbatch_args=[],
program_args=[],
git_commit="abecdef",
)

# Resource requests must appear as `--flag=value` between `sbatch --parsable
# --chdir=...` and the job script path, not as a `SBATCH_*=...` env prefix.
assert "SBATCH_MEM=" not in sbatch_command
assert "SBATCH_TIME=" not in sbatch_command
assert "SBATCH_CPUS_PER_TASK=" not in sbatch_command
assert "SBATCH_ACCOUNT=" not in sbatch_command
assert "--mem=2G" in sbatch_command
assert "--time=00:05:00" in sbatch_command
assert "--cpus-per-task=1" in sbatch_command
assert "--account=rrg-foo" in sbatch_command
# SBATCH_JOB_NAME injected by get_sbatch_command is also a flag.
assert "--job-name=cluv-my_script" in sbatch_command
# Non-SBATCH vars stay as env vars.
assert "GIT_COMMIT=abecdef" in sbatch_command


class TestTerminalJobStates:
"""The wait-loop in submit_first uses `state not in TERMINAL_JOB_STATES`.

Sanity-check the predicate: known terminal states stop the loop, transient
states (including ones absent from old `RUNNING_JOB_STATES`) keep it going,
and an empty/unknown state defaults to keep-polling.
"""

@pytest.mark.parametrize(
"state",
["COMPLETED", "FAILED", "CANCELLED", "TIMEOUT", "OUT_OF_MEMORY", "NODE_FAIL"],
)
def test_terminal_states_stop_polling(self, state: str) -> None:
assert state in TERMINAL_JOB_STATES

@pytest.mark.parametrize(
"state",
["PENDING", "RUNNING", "COMPLETING", "CONFIGURING", "SUSPENDED", "REQUEUED", "RESIZING"],
)
def test_transient_states_keep_polling(self, state: str) -> None:
assert state not in TERMINAL_JOB_STATES

@pytest.mark.parametrize("state", ["", "FUTURE_SLURM_STATE"])
def test_empty_or_unknown_state_keeps_polling(self, state: str) -> None:
# The submit_first call site is `if job_status and job_status not in TERMINAL_JOB_STATES`,
# so an empty string short-circuits to "still running" via the truthiness guard.
assert state not in TERMINAL_JOB_STATES


class TestEnsureCleanGitState:
def test_prefers_branch_tip_in_github_actions_detached_head(
Expand Down
Loading