-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtest_submit.py
More file actions
221 lines (188 loc) · 8.88 KB
/
test_submit.py
File metadata and controls
221 lines (188 loc) · 8.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import subprocess
import textwrap
from collections.abc import Iterator
from pathlib import Path

import pytest

from cluv.cli.submit import (
    TERMINAL_JOB_STATES,
    ensure_clean_git_state,
    get_config,
    get_sbatch_command,
)
@pytest.fixture
def project_dir(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Iterator[Path]:
    """Create an empty ``my_project`` directory under a fake home and ``cd`` into it.

    ``Path.home()`` is redirected to ``tmp_path`` and the current working
    directory is switched to ``tmp_path / "my_project"`` for the duration of
    the test (both undone by ``monkeypatch`` afterwards).

    Yields:
        The freshly created (empty) project directory.
    """
    # get_config caches its result, so without clearing it one test could read
    # another test's config. Clear before the test so this test reads its own
    # pyproject.toml...
    get_config.cache_clear()
    # Path.home is a classmethod; patch it with one so both `Path.home()` and
    # `some_path.home()` resolve to the fake home directory.
    monkeypatch.setattr(Path, "home", classmethod(lambda cls: tmp_path))
    project_dir = tmp_path / "my_project"
    project_dir.mkdir()
    monkeypatch.chdir(project_dir)
    yield project_dir
    # ...and after it, so a config cached against this (soon deleted) tmp_path
    # cannot leak into later tests, including ones in other modules.
    get_config.cache_clear()
class TestGetSbatchCommand:
    """Tests for how ``get_sbatch_command`` builds the remote ``sbatch`` command line."""

    @staticmethod
    def _write_pyproject(project_dir: Path, toml: str) -> None:
        """Write ``textwrap.dedent(toml)`` as ``project_dir / "pyproject.toml"``."""
        (project_dir / "pyproject.toml").write_text(textwrap.dedent(toml))

    def test_generate_command_for_selected_cluster_with_correct_args_and_vars(self, project_dir: Path) -> None:
        """Global env vars plus only the selected cluster's env vars are exported."""
        self._write_pyproject(
            project_dir,
            """\
            [tool.cluv]
            results_path = "results"

            [tool.cluv.env]
            MY_VAR="1"

            [tool.cluv.clusters.mila.env]
            SPECIAL_MILA_VAR="xyz"

            [tool.cluv.clusters.vulcan.env]
            SPECIAL_VULCAN_VAR="kij"
            """,
        )
        sbatch_command = get_sbatch_command(
            cluster="mila",
            job_script=Path("scripts/my_script.sh"),
            sbatch_args=["--account=my_account", "--mem=8G"],
            program_args=["program_arg_1", "program_arg_2"],
            git_commit="abecdef",
        )
        assert (
            sbatch_command
            == "bash --login -c 'MY_VAR=1 SPECIAL_MILA_VAR=xyz GIT_COMMIT=abecdef sbatch --parsable --chdir=my_project --job-name=cluv-my_script --account=my_account --mem=8G ~/my_project/scripts/my_script.sh program_arg_1 program_arg_2'"
        )

    def test_only_override_slurm_vars_with_selected_cluster_vars(self, project_dir: Path) -> None:
        """A global env var is overridden by the selected cluster's value only."""
        self._write_pyproject(
            project_dir,
            """\
            [tool.cluv]
            results_path = "results"

            [tool.cluv.env]
            MY_VAR="1"

            [tool.cluv.clusters.mila.env]
            MY_VAR="2"

            [tool.cluv.clusters.vulcan.env]
            MY_VAR="3"
            """,
        )
        sbatch_command = get_sbatch_command(
            cluster="mila",
            job_script=Path("scripts/my_script.sh"),
            sbatch_args=[],
            program_args=[],
            git_commit="abecdef",
        )
        assert (
            sbatch_command
            == "bash --login -c 'MY_VAR=2 GIT_COMMIT=abecdef sbatch --parsable --chdir=my_project --job-name=cluv-my_script ~/my_project/scripts/my_script.sh '"
        )

    def test_sbatch_env_vars_are_translated_to_cli_flags(self, project_dir: Path) -> None:
        """SBATCH_* env vars must reach sbatch as CLI flags, not as env vars.

        DRAC clusters re-source their site profile inside `bash --login -c`,
        which clobbers SBATCH_* defaults before sbatch reads them. Flags are
        parsed by sbatch directly and survive the login shell.
        """
        self._write_pyproject(
            project_dir,
            """\
            [tool.cluv]
            results_path = "results"

            [tool.cluv.env]
            SBATCH_MEM = "2G"
            SBATCH_TIME = "00:05:00"
            SBATCH_CPUS_PER_TASK = "1"

            [tool.cluv.clusters.mila.env]
            SBATCH_ACCOUNT = "rrg-foo"
            """,
        )
        sbatch_command = get_sbatch_command(
            cluster="mila",
            job_script=Path("scripts/my_script.sh"),
            sbatch_args=[],
            program_args=[],
            git_commit="abecdef",
        )
        # No SBATCH_* variable may be exported as a `NAME=value` env prefix.
        for env_prefix in ("SBATCH_MEM=", "SBATCH_TIME=", "SBATCH_CPUS_PER_TASK=", "SBATCH_ACCOUNT="):
            assert env_prefix not in sbatch_command

        # Resource requests must appear as `--flag=value` between `sbatch --parsable
        # --chdir=...` and the job script path: anything after the script path would be
        # handed to the job script as a program argument instead of to sbatch.
        assert "sbatch --parsable" in sbatch_command
        assert "~/my_project/scripts/my_script.sh" in sbatch_command
        sbatch_start = sbatch_command.index("sbatch --parsable")
        script_start = sbatch_command.index("~/my_project/scripts/my_script.sh")
        expected_flags = (
            "--mem=2G",
            "--time=00:05:00",
            "--cpus-per-task=1",
            "--account=rrg-foo",
            # SBATCH_JOB_NAME injected by get_sbatch_command is also a flag.
            "--job-name=cluv-my_script",
        )
        for flag in expected_flags:
            assert flag in sbatch_command
            assert sbatch_start < sbatch_command.index(flag) < script_start, (
                f"{flag!r} is not passed to sbatch as an option"
            )

        # Non-SBATCH vars stay as env vars, i.e. they prefix the sbatch call.
        assert "GIT_COMMIT=abecdef" in sbatch_command
        assert sbatch_command.index("GIT_COMMIT=abecdef") < sbatch_start
class TestTerminalJobStates:
    """Pin down membership in ``TERMINAL_JOB_STATES``.

    submit_first's wait-loop keeps polling while a job's state is *not* in
    ``TERMINAL_JOB_STATES``. So every finished-job state must be a member,
    while transient states must not be. That includes transient states the
    old ``RUNNING_JOB_STATES`` never listed, as well as empty and unknown
    states.
    """

    _FINISHED = ("COMPLETED", "FAILED", "CANCELLED", "TIMEOUT", "OUT_OF_MEMORY", "NODE_FAIL")
    _IN_FLIGHT = ("PENDING", "RUNNING", "COMPLETING", "CONFIGURING", "SUSPENDED", "REQUEUED", "RESIZING")
    _BLANK_OR_UNRECOGNISED = ("", "FUTURE_SLURM_STATE")

    @pytest.mark.parametrize("state", _FINISHED)
    def test_terminal_states_stop_polling(self, state: str) -> None:
        assert state in TERMINAL_JOB_STATES

    @pytest.mark.parametrize("state", _IN_FLIGHT)
    def test_transient_states_keep_polling(self, state: str) -> None:
        assert state not in TERMINAL_JOB_STATES

    @pytest.mark.parametrize("state", _BLANK_OR_UNRECOGNISED)
    def test_empty_or_unknown_state_keeps_polling(self, state: str) -> None:
        # Only membership is checked here. Per the original note, the
        # submit_first call site reads
        # `if job_status and job_status not in TERMINAL_JOB_STATES`, so the
        # empty string is additionally handled by the truthiness guard there.
        # NOTE(review): confirm that guard still maps "" to keep-polling.
        assert state not in TERMINAL_JOB_STATES
class TestEnsureCleanGitState:
    """``ensure_clean_git_state`` under GitHub Actions on a detached HEAD.

    Every git call goes through faked ``subprocess.run`` /
    ``subprocess.check_output``. The fakes reject any command they were not
    told to expect.
    """

    @staticmethod
    def _fake_git(
        monkeypatch: pytest.MonkeyPatch,
        head_ref: str,
        verify_result: tuple[int, str, str],
        head_sha_line: str,
    ) -> None:
        """Set the GitHub Actions env vars and install fake git subprocess calls.

        Args:
            monkeypatch: used for all env-var and attribute patching.
            head_ref: value for ``GITHUB_HEAD_REF``; the fake answers
                ``git rev-parse --verify origin/<head_ref>``.
            verify_result: ``(returncode, stdout, stderr)`` returned for that
                ``--verify`` call.
            head_sha_line: stdout of ``git rev-parse HEAD``.
        """
        monkeypatch.setenv("GITHUB_ACTIONS", "true")
        monkeypatch.setenv("GITHUB_HEAD_REF", head_ref)

        status_cmd = ["git", "status", "--porcelain"]
        verify_cmd = ["git", "rev-parse", "--verify", f"origin/{head_ref}"]
        verify_code, verify_out, verify_err = verify_result

        def fake_run(command: list[str], **kwargs) -> subprocess.CompletedProcess[str]:
            assert kwargs.get("capture_output") is True
            assert kwargs.get("text") is True
            if command == status_cmd:
                # A clean working tree: empty porcelain output.
                return subprocess.CompletedProcess(command, 0, stdout="", stderr="")
            if command == verify_cmd:
                return subprocess.CompletedProcess(command, verify_code, stdout=verify_out, stderr=verify_err)
            raise AssertionError(f"Unexpected subprocess.run call: {command}")

        # "HEAD" from --abbrev-ref is what git prints on a detached HEAD.
        canned_outputs = [
            (["git", "rev-parse", "--abbrev-ref", "HEAD"], "HEAD\n"),
            (["git", "rev-parse", "HEAD"], head_sha_line),
        ]

        def fake_check_output(command: list[str], **kwargs) -> str:
            assert kwargs.get("text") is True
            for expected_command, output in canned_outputs:
                if command == expected_command:
                    return output
            raise AssertionError(f"Unexpected subprocess.check_output call: {command}")

        monkeypatch.setattr(subprocess, "run", fake_run)
        monkeypatch.setattr(subprocess, "check_output", fake_check_output)

    def test_prefers_branch_tip_in_github_actions_detached_head(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """When origin/<GITHUB_HEAD_REF> resolves, its commit wins over HEAD's."""
        self._fake_git(
            monkeypatch,
            head_ref="proper_integration_tests",
            verify_result=(0, "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb\n", ""),
            head_sha_line="aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\n",
        )
        assert ensure_clean_git_state() == "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"

    def test_falls_back_to_head_if_remote_branch_ref_missing(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """When origin/<GITHUB_HEAD_REF> cannot be verified, HEAD's commit is used."""
        self._fake_git(
            monkeypatch,
            head_ref="missing_branch",
            verify_result=(1, "", "unknown revision"),
            head_sha_line="cccccccccccccccccccccccccccccccccccccccc\n",
        )
        assert ensure_clean_git_state() == "cccccccccccccccccccccccccccccccccccccccc"