Skip to content

Commit 0245103

Browse files
committed
feat(submit): integrate estimator and write-back into watch loop
Stamps each submission with --comment=cluv:v1:<spec_key> so sacct backfill can identify cluv jobs later. When [tool.cluv.estimate] is configured, _resolve_estimate loads the local cache (backfilling from sacct on cold cache) and overrides SBATCH_MEM with the estimator's prediction. Renames _retry_on_oom to _watch_job_chain and threads write_back through it so each terminal job persists a JobRecord for future estimates. Both retry-only and estimate-only paths reuse the same loop.
1 parent 916d091 commit 0245103

2 files changed

Lines changed: 164 additions & 65 deletions

File tree

cluv/cli/submit.py

Lines changed: 123 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
import shlex
66
import subprocess
77
import sys
8+
from datetime import UTC, datetime
89
from pathlib import Path
910

11+
from cluv import history
1012
from cluv.cli.sync import sync
11-
from cluv.config import ClusterConfig, RetryConfig, find_pyproject, get_config
13+
from cluv.config import ClusterConfig, EstimateConfig, RetryConfig, find_pyproject, get_config
1214
from cluv.remote import Remote
1315
from cluv.utils import console
1416

@@ -126,33 +128,56 @@ async def submit(
126128

127129
# Sync.
128130
remotes = await sync(clusters=[cluster])
129-
130-
# Run the sbatch command over SSH.
131131
remote = remotes[0]
132-
result = await sbatch(remote, job_script, sbatch_args, program_args, git_commit)
133132

133+
# Identify "the same job" across submissions; stamp the key into the sacct
134+
# Comment field and into the job's env so the history cache and the script
135+
# can both see it.
136+
from salvo.history import spec_key
137+
138+
key = spec_key(str(job_script), git_commit, tuple(program_args))
139+
sbatch_args = [*sbatch_args, f"--comment={history.build_comment(key)}"]
140+
env_overrides: dict[str, str] = {"CLUV_SPEC_KEY": key}
141+
142+
# Memory estimator (opt-in). Runs after sync so the cold-cache backfill can
143+
# use the remote we just connected to.
144+
estimate_cfg = cluv_config.estimate if cluv_config.estimate and cluv_config.estimate.enabled else None
145+
initial_mem = _initial_mem(cluster)
146+
if estimate_cfg is not None:
147+
estimate_mem_mb = await _resolve_estimate(remote, cluster, key, estimate_cfg)
148+
if estimate_mem_mb is not None:
149+
initial_mem = f"{estimate_mem_mb}M"
150+
env_overrides["SBATCH_MEM"] = initial_mem
151+
env_overrides["CLUV_ESTIMATED_MEM"] = initial_mem
152+
153+
result = await sbatch(remote, job_script, sbatch_args, program_args, git_commit, env_overrides)
134154
if result.returncode != 0:
135155
console.print(f"[red] Error during sbatch : {result.stderr}[/red]")
136156
return None
137157

138158
job_id = int(result.stdout.strip())
139-
140159
console.log(
141160
f"Successfully submitted job {job_id} on the {cluster} cluster.\n"
142161
f"Use `ssh {cluster} sacct -j {job_id}` to view its status."
143162
)
144163

145-
if cluv_config.retry is None:
164+
watch = cluv_config.retry is not None or estimate_cfg is not None
165+
if not watch:
146166
return job_id
147167

148-
return await _retry_on_oom(
168+
return await _watch_job_chain(
149169
remote=remote,
170+
cluster=cluster,
171+
key=key,
150172
job_id=job_id,
151173
job_script=job_script,
152174
sbatch_args=sbatch_args,
153175
program_args=program_args,
154176
git_commit=git_commit,
177+
env_overrides=env_overrides,
178+
initial_mem=initial_mem,
155179
retry=cluv_config.retry,
180+
write_back=estimate_cfg is not None,
156181
)
157182

158183

@@ -270,35 +295,109 @@ async def _wait_terminal(remote: Remote, job_id: int) -> str:
270295
await asyncio.sleep(RETRY_POLL_INTERVAL_S)
271296

272297

273-
async def _retry_on_oom(
298+
async def _resolve_estimate(
    remote: Remote, cluster: str, key: str, cfg: EstimateConfig
) -> int | None:
    """Derive a memory override (MiB) from cached job history.

    The cache for `(cluster, key)` is read first. If it is empty and
    `cfg.backfill` is set, a sacct backfill is attempted and the cache is read
    again; a failed backfill is logged and otherwise ignored. The resulting
    records are handed to `salvo.history.estimate_mem` together with the
    `safety`, `window` and `min_samples` settings from `cfg`.

    Returns:
        The estimator's `mem_mb`, or `None` when the estimator produced no
        number, in which case the configured `SBATCH_MEM` stays in effect.
    """
    from salvo.history import estimate_mem

    records = history.load(cluster, key)
    cold_cache = not records
    if cold_cache and cfg.backfill:
        try:
            n = await history.backfill_from_sacct(remote, cluster)
            if n:
                console.log(f"estimator: backfilled {n} record(s) from sacct on {cluster}")
        except Exception as err:  # network/sacct hiccup should not block submit
            console.log(f"[yellow]estimator: backfill failed ({err}); continuing[/yellow]")
        # Re-read regardless of the backfill outcome so the estimator sees
        # whatever made it into the cache.
        records = history.load(cluster, key)

    estimator_kwargs = {
        "safety": cfg.safety,
        "window": cfg.window,
        "min_samples": cfg.min_samples,
    }
    estimate = estimate_mem(records, **estimator_kwargs)

    if estimate.mem_mb is not None:
        console.log(
            f"estimator: {estimate.rationale} (confidence={estimate.confidence}); "
            f"overriding SBATCH_MEM"
        )
        return estimate.mem_mb

    console.log(f"estimator: {estimate.rationale}; using configured SBATCH_MEM")
    return None
334+
335+
336+
async def _persist_terminal(
    remote: Remote, cluster: str, key: str, job_id: int, mem_for_job: str
) -> None:
    """Read the job's terminal sacct row and append a JobRecord to the cache.

    Args:
        remote: Connection used to look up the job's status and peak RSS.
        cluster: Cluster name stored on the record.
        key: Spec key stored on the record.
        job_id: Job whose terminal state is being recorded.
        mem_for_job: Memory ask the job ran with, as a string such as
            ``"16G"``; parsed with `salvo.job.spec.parse_mem_mb`.
    """
    from salvo.history import JobRecord
    from salvo.job.spec import parse_mem_mb

    raw_state = await get_job_status(remote, job_id)
    # Keep only the first whitespace-separated token of the reported state.
    # A missing, empty or blank status splits into no tokens at all, so fall
    # back to "UNKNOWN" instead of indexing into an empty list.
    tokens = (raw_state or "").split()
    state = tokens[0] if tokens else "UNKNOWN"

    max_rss = await get_max_rss_mb(remote, job_id)
    try:
        mem_mb = parse_mem_mb(mem_for_job)
    except ValueError:
        # An unparseable memory ask is recorded as 0 rather than aborting the
        # write-back.
        mem_mb = 0

    history.save_record(
        JobRecord(
            job_id=str(job_id),
            key=key,
            cluster=cluster,
            state=state,
            mem_mb=mem_mb,
            max_rss_mb=max_rss,
            # NOTE(review): this is the time of the write-back, not the time
            # sbatch accepted the job — confirm that is what consumers expect.
            submitted_at=datetime.now(UTC),
        )
    )
361+
362+
363+
async def _watch_job_chain(
274364
remote: Remote,
365+
cluster: str,
366+
key: str,
275367
job_id: int,
276368
job_script: Path,
277369
sbatch_args: list[str],
278370
program_args: list[str],
279371
git_commit: str,
280-
retry: RetryConfig,
372+
env_overrides: dict[str, str],
373+
initial_mem: str,
374+
retry: RetryConfig | None,
375+
write_back: bool,
281376
) -> int | None:
282-
"""OOM-aware resubmit loop layered on top of the single-cluster `submit()` path.
377+
"""Watch a (possibly retrying) job chain to terminal state.
378+
379+
Combines two concerns so each terminal state hits one wait loop:
283380
284-
Polls sacct for `job_id` until terminal. On `OUT_OF_MEMORY`, asks
285-
`salvo.policy.apply_oom` for the next memory ask, mutates the env-var dict
286-
passed to `sbatch`, and resubmits. On any other terminal state, returns the
287-
current `job_id`. Bounded by `retry.max_hops` and by `FailStep` in the policy.
381+
* If `retry` is set, OUT_OF_MEMORY triggers `salvo.policy.apply_oom`, the
382+
memory ask gets bumped, and the job is resubmitted (up to `max_hops`).
383+
* If `write_back` is true, each terminal job persists a `JobRecord` so the
384+
estimator learns from it on the next run.
288385
"""
289-
# Import lazily so users who don't opt in don't pay for pysalvo at import time.
290386
from salvo.job.spec import JobSpec
291387
from salvo.policy import OomContext, apply_oom
292388

293-
env_overrides: dict[str, str] = {}
389+
current_mem = initial_mem
390+
max_hops = retry.max_hops if retry else 0
294391
hop = 0
295-
# Track current memory ask through hops. None means "rely on cluster default";
296-
# in that case bump_mem still works because JobSpec defaults to 4G.
297-
current_mem = env_overrides.get("SBATCH_MEM") or _initial_mem(remote.hostname)
298392

299-
while hop < retry.max_hops:
393+
while True:
300394
state = await _wait_terminal(remote, job_id)
301-
if state != "OUT_OF_MEMORY":
395+
if write_back:
396+
await _persist_terminal(remote, cluster, key, job_id, current_mem)
397+
if retry is None or state != "OUT_OF_MEMORY":
398+
return job_id
399+
if hop >= max_hops:
400+
console.log(f"max_hops={max_hops} reached; last job id is {job_id}")
302401
return job_id
303402

304403
max_rss_mb = await get_max_rss_mb(remote, job_id)
@@ -316,9 +415,9 @@ async def _retry_on_oom(
316415
hop += 1
317416
current_mem = new_spec.mem
318417
env_overrides["SBATCH_MEM"] = current_mem
319-
env_overrides["CLUV_HOP"] = f"{hop}/{retry.max_hops}"
418+
env_overrides["CLUV_HOP"] = f"{hop}/{max_hops}"
320419
console.log(
321-
f"hop {hop}/{retry.max_hops}: resubmitting on {remote.hostname} with mem={current_mem}"
420+
f"hop {hop}/{max_hops}: resubmitting on {remote.hostname} with mem={current_mem}"
322421
)
323422
result = await sbatch(
324423
remote, job_script, sbatch_args, program_args, git_commit, env_overrides
@@ -327,10 +426,7 @@ async def _retry_on_oom(
327426
console.print(f"[red]resubmit hop {hop} failed: {result.stderr}[/red]")
328427
return None
329428
job_id = int(result.stdout.strip())
330-
console.log(f"hop {hop}/{retry.max_hops}: submitted as job {job_id}")
331-
332-
console.log(f"max_hops={retry.max_hops} reached; last job id is {job_id}")
333-
return job_id
429+
console.log(f"hop {hop}/{max_hops}: submitted as job {job_id}")
334430

335431

336432
def _initial_mem(cluster: str) -> str:

tests/test_submit_retry.py

Lines changed: 41 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,38 @@
1414
import pytest
1515

1616
from cluv.cli import submit as submit_module
17-
from cluv.cli.submit import _retry_on_oom
17+
from cluv.cli.submit import _watch_job_chain
1818
from cluv.config import RetryConfig, get_config
1919
from cluv.remote import Remote
2020

2121

22+
def _call_watch(
    *,
    job_id: int,
    retry: RetryConfig,
    initial_mem: str = "16G",
):
    """Build a `_watch_job_chain` call with fixed retry-only arguments.

    Only `job_id`, `retry` and `initial_mem` vary between tests; everything
    else is pinned here. `write_back=False` is passed so these tests stay
    about the retry math; history cache writes are covered separately in
    test_history.py.

    Returns whatever `_watch_job_chain` returns, un-awaited, so callers
    `await` the result themselves.
    """
    fixed_kwargs = {
        "remote": Remote(hostname="mila"),
        "cluster": "mila",
        "key": "testkey",
        "job_script": Path("scripts/job.sh"),
        "sbatch_args": [],
        "program_args": [],
        "git_commit": "abcdef",
        "env_overrides": {},
        "write_back": False,
    }
    return _watch_job_chain(
        job_id=job_id,
        initial_mem=initial_mem,
        retry=retry,
        **fixed_kwargs,
    )
47+
48+
2249
@pytest.fixture
2350
def project_dir(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
2451
"""Mirror the fixture in `tests/test_submit.py` so retry tests pick up a config."""
@@ -90,13 +117,8 @@ async def test_retry_bumps_mem_then_completes(
90117
runner = _ScriptedRunner(states=["OUT_OF_MEMORY", "COMPLETED"], sbatch_jobs=[1002])
91118
runner.install(monkeypatch)
92119

93-
job_id = await _retry_on_oom(
94-
remote=Remote(hostname="mila"),
120+
job_id = await _call_watch(
95121
job_id=1001,
96-
job_script=Path("scripts/job.sh"),
97-
sbatch_args=[],
98-
program_args=[],
99-
git_commit="abcdef",
100122
retry=RetryConfig(on_oom=["bump_mem(2x, max=128G)", "fail"], max_hops=5),
101123
)
102124

@@ -117,13 +139,8 @@ async def test_retry_grows_mem_across_hops(
117139
)
118140
runner.install(monkeypatch)
119141

120-
job_id = await _retry_on_oom(
121-
remote=Remote(hostname="mila"),
142+
job_id = await _call_watch(
122143
job_id=2001,
123-
job_script=Path("scripts/job.sh"),
124-
sbatch_args=[],
125-
program_args=[],
126-
git_commit="abcdef",
127144
retry=RetryConfig(on_oom=["bump_mem(2x, max=128G)", "fail"], max_hops=5),
128145
)
129146

@@ -138,19 +155,17 @@ async def test_retry_grows_mem_across_hops(
138155
async def test_retry_caps_at_max_hops(
139156
project_dir: Path, monkeypatch: pytest.MonkeyPatch
140157
) -> None:
158+
# max_hops=2 caps resubmits at two even if the last one still OOMs. The
159+
# watch loop polls the final job to terminal state so we get a state to
160+
# report back, then exits without a third resubmit.
141161
runner = _ScriptedRunner(
142-
states=["OUT_OF_MEMORY", "OUT_OF_MEMORY"],
162+
states=["OUT_OF_MEMORY", "OUT_OF_MEMORY", "OUT_OF_MEMORY"],
143163
sbatch_jobs=[3002, 3003],
144164
)
145165
runner.install(monkeypatch)
146166

147-
job_id = await _retry_on_oom(
148-
remote=Remote(hostname="mila"),
167+
job_id = await _call_watch(
149168
job_id=3001,
150-
job_script=Path("scripts/job.sh"),
151-
sbatch_args=[],
152-
program_args=[],
153-
git_commit="abcdef",
154169
retry=RetryConfig(on_oom=["bump_mem(1.5x, max=128G)", "fail"], max_hops=2),
155170
)
156171

@@ -167,13 +182,8 @@ async def test_retry_terminates_on_fail_step(
167182

168183
# 100G * 5 capped at 128G is reachable; force fall-through to fail by
169184
# asking for a bump that the policy declines (already at the cap).
170-
job_id = await _retry_on_oom(
171-
remote=Remote(hostname="mila"),
185+
job_id = await _call_watch(
172186
job_id=4001,
173-
job_script=Path("scripts/job.sh"),
174-
sbatch_args=[],
175-
program_args=[],
176-
git_commit="abcdef",
177187
retry=RetryConfig(on_oom=["fail"], max_hops=5),
178188
)
179189

@@ -187,26 +197,19 @@ async def test_retry_returns_immediately_on_non_oom_terminal(
187197
runner = _ScriptedRunner(states=["COMPLETED"], sbatch_jobs=[])
188198
runner.install(monkeypatch)
189199

190-
job_id = await _retry_on_oom(
191-
remote=Remote(hostname="mila"),
200+
job_id = await _call_watch(
192201
job_id=5001,
193-
job_script=Path("scripts/job.sh"),
194-
sbatch_args=[],
195-
program_args=[],
196-
git_commit="abcdef",
197202
retry=RetryConfig(on_oom=["bump_mem(2x, max=128G)", "fail"], max_hops=5),
198203
)
199204

200205
assert job_id == 5001
201206
assert runner.recorded == []
202207

203208

204-
def test_submit_is_noop_path_when_retry_is_none(project_dir: Path) -> None:
205-
"""When `[tool.cluv.retry]` is absent, `cluv_config.retry` is None.
206-
207-
This is the cheap structural guarantee: a config without the retry section
208-
deserializes to `retry=None`, so `submit()` skips `_retry_on_oom` on the
209-
very first branch and the new code path is dormant.
209+
def test_submit_is_noop_path_when_retry_and_estimate_are_none(project_dir: Path) -> None:
210+
"""When neither `[tool.cluv.retry]` nor `[tool.cluv.estimate]` is configured,
211+
`submit()` exits the moment sbatch returns and the watch loop never runs.
210212
"""
211213
config = get_config()
212214
assert config.retry is None
215+
assert config.estimate is None

0 commit comments

Comments
 (0)