55import shlex
66import subprocess
77import sys
8+ from datetime import UTC , datetime
89from pathlib import Path
910
11+ from cluv import history
1012from cluv .cli .sync import sync
11- from cluv .config import ClusterConfig , RetryConfig , find_pyproject , get_config
13+ from cluv .config import ClusterConfig , EstimateConfig , RetryConfig , find_pyproject , get_config
1214from cluv .remote import Remote
1315from cluv .utils import console
1416
@@ -126,33 +128,56 @@ async def submit(
126128
127129 # Sync.
128130 remotes = await sync (clusters = [cluster ])
129-
130- # Run the sbatch command over SSH.
131131 remote = remotes [0 ]
132- result = await sbatch (remote , job_script , sbatch_args , program_args , git_commit )
133132
133+ # Identify "the same job" across submissions; stamp the key into the sacct
134+ # Comment field and into the job's env so the history cache and the script
135+ # can both see it.
136+ from salvo .history import spec_key
137+
138+ key = spec_key (str (job_script ), git_commit , tuple (program_args ))
139+ sbatch_args = [* sbatch_args , f"--comment={ history .build_comment (key )} " ]
140+ env_overrides : dict [str , str ] = {"CLUV_SPEC_KEY" : key }
141+
142+ # Memory estimator (opt-in). Runs after sync so the cold-cache backfill can
143+ # use the remote we just connected to.
144+ estimate_cfg = cluv_config .estimate if cluv_config .estimate and cluv_config .estimate .enabled else None
145+ initial_mem = _initial_mem (cluster )
146+ if estimate_cfg is not None :
147+ estimate_mem_mb = await _resolve_estimate (remote , cluster , key , estimate_cfg )
148+ if estimate_mem_mb is not None :
149+ initial_mem = f"{ estimate_mem_mb } M"
150+ env_overrides ["SBATCH_MEM" ] = initial_mem
151+ env_overrides ["CLUV_ESTIMATED_MEM" ] = initial_mem
152+
153+ result = await sbatch (remote , job_script , sbatch_args , program_args , git_commit , env_overrides )
134154 if result .returncode != 0 :
135155 console .print (f"[red] Error during sbatch : { result .stderr } [/red]" )
136156 return None
137157
138158 job_id = int (result .stdout .strip ())
139-
140159 console .log (
141160 f"Successfully submitted job { job_id } on the { cluster } cluster.\n "
142161 f"Use `ssh { cluster } sacct -j { job_id } ` to view its status."
143162 )
144163
145- if cluv_config .retry is None :
164+ watch = cluv_config .retry is not None or estimate_cfg is not None
165+ if not watch :
146166 return job_id
147167
148- return await _retry_on_oom (
168+ return await _watch_job_chain (
149169 remote = remote ,
170+ cluster = cluster ,
171+ key = key ,
150172 job_id = job_id ,
151173 job_script = job_script ,
152174 sbatch_args = sbatch_args ,
153175 program_args = program_args ,
154176 git_commit = git_commit ,
177+ env_overrides = env_overrides ,
178+ initial_mem = initial_mem ,
155179 retry = cluv_config .retry ,
180+ write_back = estimate_cfg is not None ,
156181 )
157182
158183
@@ -270,35 +295,109 @@ async def _wait_terminal(remote: Remote, job_id: int) -> str:
270295 await asyncio .sleep (RETRY_POLL_INTERVAL_S )
271296
272297
273- async def _retry_on_oom (
298+ async def _resolve_estimate (
299+ remote : Remote , cluster : str , key : str , cfg : EstimateConfig
300+ ) -> int | None :
301+ """Return a memory override (MiB) from local history, or None to skip.
302+
303+ Loads the cache for `(cluster, key)`, optionally backfills from sacct on
304+ cold cache, then asks `salvo.history.estimate_mem` for a number. When the
305+ estimator returns `None` (insufficient history), the configured
306+ `SBATCH_MEM` is left untouched.
307+ """
308+ from salvo .history import estimate_mem
309+
310+ records = history .load (cluster , key )
311+ if not records and cfg .backfill :
312+ try :
313+ n = await history .backfill_from_sacct (remote , cluster )
314+ if n :
315+ console .log (f"estimator: backfilled { n } record(s) from sacct on { cluster } " )
316+ except Exception as err : # network/sacct hiccup should not block submit
317+ console .log (f"[yellow]estimator: backfill failed ({ err } ); continuing[/yellow]" )
318+ records = history .load (cluster , key )
319+
320+ estimate = estimate_mem (
321+ records ,
322+ safety = cfg .safety ,
323+ window = cfg .window ,
324+ min_samples = cfg .min_samples ,
325+ )
326+ if estimate .mem_mb is None :
327+ console .log (f"estimator: { estimate .rationale } ; using configured SBATCH_MEM" )
328+ return None
329+ console .log (
330+ f"estimator: { estimate .rationale } (confidence={ estimate .confidence } ); "
331+ f"overriding SBATCH_MEM"
332+ )
333+ return estimate .mem_mb
334+
335+
336+ async def _persist_terminal (
337+ remote : Remote , cluster : str , key : str , job_id : int , mem_for_job : str
338+ ) -> None :
339+ """Read the job's terminal sacct row and append a JobRecord to the cache."""
340+ from salvo .history import JobRecord
341+ from salvo .job .spec import parse_mem_mb
342+
343+ state = await get_job_status (remote , job_id )
344+ state = (state or "" ).split ()[0 ]
345+ max_rss = await get_max_rss_mb (remote , job_id )
346+ try :
347+ mem_mb = parse_mem_mb (mem_for_job )
348+ except ValueError :
349+ mem_mb = 0
350+ history .save_record (
351+ JobRecord (
352+ job_id = str (job_id ),
353+ key = key ,
354+ cluster = cluster ,
355+ state = state or "UNKNOWN" ,
356+ mem_mb = mem_mb ,
357+ max_rss_mb = max_rss ,
358+ submitted_at = datetime .now (UTC ),
359+ )
360+ )
361+
362+
363+ async def _watch_job_chain (
274364 remote : Remote ,
365+ cluster : str ,
366+ key : str ,
275367 job_id : int ,
276368 job_script : Path ,
277369 sbatch_args : list [str ],
278370 program_args : list [str ],
279371 git_commit : str ,
280- retry : RetryConfig ,
372+ env_overrides : dict [str , str ],
373+ initial_mem : str ,
374+ retry : RetryConfig | None ,
375+ write_back : bool ,
281376) -> int | None :
282- """OOM-aware resubmit loop layered on top of the single-cluster `submit()` path.
377+ """Watch a (possibly retrying) job chain to terminal state.
378+
379+ Combines two concerns so each terminal state hits one wait loop:
283380
284- Polls sacct for `job_id` until terminal. On `OUT_OF_MEMORY `, asks
285- `salvo.policy.apply_oom` for the next memory ask, mutates the env-var dict
286- passed to `sbatch`, and resubmits. On any other terminal state, returns the
287- current `job_id`. Bounded by `retry.max_hops` and by `FailStep` in the policy .
381+ * If `retry` is set, OUT_OF_MEMORY triggers `salvo.policy.apply_oom `, the
382+ memory ask gets bumped, and the job is resubmitted (up to `max_hops`).
383+ * If `write_back` is true, each terminal job persists a `JobRecord` so the
384+ estimator learns from it on the next run .
288385 """
289- # Import lazily so users who don't opt in don't pay for pysalvo at import time.
290386 from salvo .job .spec import JobSpec
291387 from salvo .policy import OomContext , apply_oom
292388
293- env_overrides : dict [str , str ] = {}
389+ current_mem = initial_mem
390+ max_hops = retry .max_hops if retry else 0
294391 hop = 0
295- # Track current memory ask through hops. None means "rely on cluster default";
296- # in that case bump_mem still works because JobSpec defaults to 4G.
297- current_mem = env_overrides .get ("SBATCH_MEM" ) or _initial_mem (remote .hostname )
298392
299- while hop < retry . max_hops :
393+ while True :
300394 state = await _wait_terminal (remote , job_id )
301- if state != "OUT_OF_MEMORY" :
395+ if write_back :
396+ await _persist_terminal (remote , cluster , key , job_id , current_mem )
397+ if retry is None or state != "OUT_OF_MEMORY" :
398+ return job_id
399+ if hop >= max_hops :
400+ console .log (f"max_hops={ max_hops } reached; last job id is { job_id } " )
302401 return job_id
303402
304403 max_rss_mb = await get_max_rss_mb (remote , job_id )
@@ -316,9 +415,9 @@ async def _retry_on_oom(
316415 hop += 1
317416 current_mem = new_spec .mem
318417 env_overrides ["SBATCH_MEM" ] = current_mem
319- env_overrides ["CLUV_HOP" ] = f"{ hop } /{ retry . max_hops } "
418+ env_overrides ["CLUV_HOP" ] = f"{ hop } /{ max_hops } "
320419 console .log (
321- f"hop { hop } /{ retry . max_hops } : resubmitting on { remote .hostname } with mem={ current_mem } "
420+ f"hop { hop } /{ max_hops } : resubmitting on { remote .hostname } with mem={ current_mem } "
322421 )
323422 result = await sbatch (
324423 remote , job_script , sbatch_args , program_args , git_commit , env_overrides
@@ -327,10 +426,7 @@ async def _retry_on_oom(
327426 console .print (f"[red]resubmit hop { hop } failed: { result .stderr } [/red]" )
328427 return None
329428 job_id = int (result .stdout .strip ())
330- console .log (f"hop { hop } /{ retry .max_hops } : submitted as job { job_id } " )
331-
332- console .log (f"max_hops={ retry .max_hops } reached; last job id is { job_id } " )
333- return job_id
429+ console .log (f"hop { hop } /{ max_hops } : submitted as job { job_id } " )
334430
335431
336432def _initial_mem (cluster : str ) -> str :
0 commit comments