Skip to content

Commit f66c94c

Browse files
alec-flowers and claude committed
Store per-worker fingerprints instead of aggregating
Each worker (prefill_w0, decode_w0, etc.) keeps its own fingerprint in the lockfile rather than being unioned into one blob. Prefill and decode nodes can have different GPU types, drivers, and packages — collapsing them hides real differences. srtctl diff now compares each worker against its counterpart between runs. srtctl check verifies each worker independently. Backward compatible: old lockfiles with a single 'fingerprint' key are loaded as {"worker": fingerprint}. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 72d3c5c commit f66c94c

3 files changed

Lines changed: 183 additions & 142 deletions

File tree

src/srtctl/cli/submit.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
resolve_config_with_defaults,
4343
)
4444
from srtctl.core.fingerprint import check_against_fingerprint, diff_fingerprints, format_check_results, format_diff
45-
from srtctl.core.lockfile import load_lockfile_fingerprint
45+
from srtctl.core.lockfile import load_lockfile_fingerprints
4646
from srtctl.core.schema import SrtConfig
4747
from srtctl.core.status import create_job_record
4848
from srtctl.core.validation import run_validations_background
@@ -933,38 +933,58 @@ def add_common_args(p):
933933

934934
# Handle diff and check commands first (they don't use -f/config)
935935
if args.command == "diff":
936-
fp_a = load_lockfile_fingerprint(args.path_a)
937-
fp_b = load_lockfile_fingerprint(args.path_b)
938-
if fp_a is None or fp_b is None:
936+
fps_a = load_lockfile_fingerprints(args.path_a)
937+
fps_b = load_lockfile_fingerprints(args.path_b)
938+
if fps_a is None or fps_b is None:
939939
missing = []
940-
if fp_a is None:
940+
if fps_a is None:
941941
missing.append(str(args.path_a))
942-
if fp_b is None:
942+
if fps_b is None:
943943
missing.append(str(args.path_b))
944-
console.print(f"[bold red]Could not load fingerprint from:[/] {', '.join(missing)}")
944+
console.print(f"[bold red]Could not load fingerprints from:[/] {', '.join(missing)}")
945945
sys.exit(1)
946-
diff = diff_fingerprints(fp_a, fp_b)
947-
console.print(format_diff(diff, verbose=args.verbose))
946+
947+
# Diff each worker against its counterpart
948+
all_workers = sorted(set(fps_a.keys()) | set(fps_b.keys()))
949+
for worker in all_workers:
950+
if worker not in fps_a:
951+
console.print(f"\n[bold]{worker}:[/] only in {args.path_b}")
952+
continue
953+
if worker not in fps_b:
954+
console.print(f"\n[bold]{worker}:[/] only in {args.path_a}")
955+
continue
956+
diff = diff_fingerprints(fps_a[worker], fps_b[worker])
957+
console.print(f"\n[bold]{worker}:[/]")
958+
console.print(format_diff(diff, verbose=args.verbose))
948959
return
949960

950961
if args.command == "check":
951962
import json as json_mod
952963

953-
ref = load_lockfile_fingerprint(args.path)
954-
if ref is None:
955-
console.print(f"[bold red]Could not load fingerprint from:[/] {args.path}")
964+
fps = load_lockfile_fingerprints(args.path)
965+
if fps is None:
966+
console.print(f"[bold red]Could not load fingerprints from:[/] {args.path}")
956967
sys.exit(1)
957-
results = check_against_fingerprint(ref)
958-
if args.json_output:
959-
console.print(
960-
json_mod.dumps(
961-
[{"field": r.field, "status": r.status.value, "message": r.message} for r in results],
962-
indent=2,
963-
)
964-
)
965-
else:
966-
console.print(format_check_results(results))
967-
sys.exit(1 if results else 0)
968+
969+
# Check each worker's fingerprint against current environment
970+
all_results = []
971+
for worker in sorted(fps.keys()):
972+
results = check_against_fingerprint(fps[worker])
973+
if results:
974+
all_results.extend(results)
975+
console.print(f"\n[bold]{worker}:[/]")
976+
if args.json_output:
977+
console.print(
978+
json_mod.dumps(
979+
[{"field": r.field, "status": r.status.value, "message": r.message} for r in results],
980+
indent=2,
981+
)
982+
)
983+
else:
984+
console.print(format_check_results(results))
985+
if not all_results:
986+
console.print(format_check_results([]))
987+
sys.exit(1 if all_results else 0)
968988

969989
# Parse config arg: supports path:selector format for overrides
970990
config_path, selector = parse_config_arg(args.config)

src/srtctl/core/lockfile.py

Lines changed: 51 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from __future__ import annotations
1515

16+
import contextlib
1617
import getpass
1718
import logging
1819
import os
@@ -61,10 +62,8 @@ def collect_slurm_context() -> dict[str, Any]:
6162
ctx[key] = val
6263

6364
# User and working directory (always available)
64-
try:
65+
with contextlib.suppress(Exception):
6566
ctx["user"] = getpass.getuser()
66-
except Exception:
67-
pass
6867

6968
ctx["cwd"] = str(Path.cwd())
7069

@@ -76,12 +75,13 @@ def collect_slurm_context() -> dict[str, Any]:
7675
return ctx
7776

7877

79-
def aggregate_fingerprints(log_dir: Path) -> dict[str, Any] | None:
80-
"""Aggregate per-worker fingerprint files into a single fingerprint.
78+
def collect_worker_fingerprints(log_dir: Path) -> dict[str, Any] | None:
79+
"""Load per-worker fingerprint files into a dict keyed by worker name.
8180
82-
Reads all fingerprint_*.json files from the log directory. Scalar fields
83-
are taken from the first file. pip_packages are merged (sorted union).
81+
Returns a dict like:
82+
{"prefill_w0": {...}, "decode_w0": {...}, "decode_w1": {...}}
8483
84+
The key is derived from the filename: fingerprint_prefill_w0.json -> "prefill_w0".
8585
Returns None if no fingerprint files are found or all fail to load.
8686
"""
8787
try:
@@ -93,38 +93,27 @@ def aggregate_fingerprints(log_dir: Path) -> dict[str, Any] | None:
9393
if not fp_files:
9494
return None
9595

96-
fingerprints = []
96+
result: dict[str, Any] = {}
9797
for fp_file in fp_files:
9898
fp = load_fingerprint(fp_file)
9999
if fp is not None:
100-
fingerprints.append(fp)
101-
102-
if not fingerprints:
103-
return None
104-
105-
# Use first fingerprint as base for scalar fields
106-
result = {k: v for k, v in fingerprints[0].items() if k != "pip_packages"}
107-
108-
# Merge pip packages: sorted union across all workers
109-
all_packages: set[str] = set()
110-
for fp in fingerprints:
111-
for pkg in fp.get("pip_packages", []):
112-
all_packages.add(pkg)
113-
result["pip_packages"] = sorted(all_packages, key=lambda s: s.lower())
100+
# fingerprint_prefill_w0.json -> prefill_w0
101+
worker_key = fp_file.stem.removeprefix("fingerprint_")
102+
result[worker_key] = fp
114103

115-
return result
104+
return result if result else None
116105

117106

118107
def build_lockfile(
119108
config: SrtConfig,
120-
runtime_fingerprint: dict[str, Any] | None = None,
109+
worker_fingerprints: dict[str, Any] | None = None,
121110
) -> dict[str, Any]:
122-
"""Build the lockfile dict from a resolved config and optional fingerprint.
111+
"""Build the lockfile dict from a resolved config and optional per-worker fingerprints.
123112
124113
Returns a dict with:
125114
- _meta: lockfile version, timestamp, SLURM context
126115
- config: the full resolved config as a dict
127-
- fingerprint: the aggregated runtime fingerprint (or None)
116+
- fingerprints: per-worker fingerprints keyed by worker name (or None)
128117
"""
129118
from srtctl.core.schema import SrtConfig
130119

@@ -137,7 +126,7 @@ def build_lockfile(
137126
"slurm": collect_slurm_context(),
138127
},
139128
"config": config_dict,
140-
"fingerprint": runtime_fingerprint,
129+
"fingerprints": worker_fingerprints,
141130
}
142131

143132

@@ -149,14 +138,14 @@ def write_lockfile(
149138
"""Write recipe.lock.yaml to the output directory.
150139
151140
Called twice per job:
152-
1. At job start (log_dir=None) — writes config + SLURM context, fingerprint=null
153-
2. At job end (log_dir set) — rewrites with aggregated runtime fingerprint
141+
1. At job start (log_dir=None) — writes config + SLURM context, fingerprints=null
142+
2. At job end (log_dir set) — rewrites with per-worker fingerprints
154143
155144
Returns True on success, False on any failure. Never raises.
156145
"""
157146
try:
158-
fingerprint = aggregate_fingerprints(log_dir) if log_dir else None
159-
lockfile_data = build_lockfile(config, fingerprint)
147+
fingerprints = collect_worker_fingerprints(log_dir) if log_dir else None
148+
lockfile_data = build_lockfile(config, fingerprints)
160149

161150
lockfile_path = output_dir / "recipe.lock.yaml"
162151
lockfile_path.write_text(yaml.dump(lockfile_data, default_flow_style=False, sort_keys=False))
@@ -167,48 +156,59 @@ def write_lockfile(
167156
return False
168157

169158

170-
def load_lockfile_fingerprint(path: Path) -> dict[str, Any] | None:
171-
"""Load a fingerprint from a lockfile, output directory, or raw JSON.
159+
def load_lockfile_fingerprints(path: Path) -> dict[str, Any] | None:
160+
"""Load per-worker fingerprints from a lockfile, output directory, or raw JSON.
172161
173162
Accepts:
174-
- Path to recipe.lock.yaml → reads the 'fingerprint' section
175-
- Path to an output directory → looks for recipe.lock.yaml inside
176-
- Path to a fingerprint JSON file → loads directly
163+
- Path to recipe.lock.yaml → reads the 'fingerprints' section (per-worker dict)
164+
- Path to an output directory → looks for recipe.lock.yaml or raw fingerprint files
165+
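- Path to a single fingerprint JSON → wraps as {worker_key: fingerprint}, with the key derived from the filename (fingerprint_prefill_w0.json -> "prefill_w0"; "worker" if nothing remains after the prefix)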
177166
178-
Returns None if the fingerprint cannot be loaded.
167+
Returns a dict keyed by worker name, e.g.:
168+
{"prefill_w0": {...}, "decode_w0": {...}}
169+
Returns None if no fingerprints can be loaded.
179170
"""
180171
try:
181-
# If it's a directory, look for lockfile or fingerprint files
182172
if path.is_dir():
183173
lockfile = path / "recipe.lock.yaml"
184174
if lockfile.exists():
185-
return _load_fingerprint_from_lockfile(lockfile)
186-
# Fall back to aggregating raw fingerprint files from logs/
175+
return _load_fingerprints_from_lockfile(lockfile)
176+
# Fall back to collecting raw fingerprint files
187177
logs_dir = path / "logs"
188178
if logs_dir.is_dir():
189-
return aggregate_fingerprints(logs_dir)
190-
return aggregate_fingerprints(path)
179+
return collect_worker_fingerprints(logs_dir)
180+
return collect_worker_fingerprints(path)
191181

192-
# If it's a YAML file, try loading as lockfile
193182
if path.suffix in (".yaml", ".yml"):
194-
return _load_fingerprint_from_lockfile(path)
183+
return _load_fingerprints_from_lockfile(path)
195184

196-
# Otherwise try loading as raw fingerprint JSON
197185
if path.suffix == ".json":
198-
return load_fingerprint(path)
186+
fp = load_fingerprint(path)
187+
if fp is not None:
188+
# Single file — derive worker key from filename
189+
worker_key = path.stem.removeprefix("fingerprint_") or "worker"
190+
return {worker_key: fp}
191+
return None
199192

200193
return None
201194
except Exception as e:
202-
logger.debug("Failed to load fingerprint from %s: %s", path, e)
195+
logger.debug("Failed to load fingerprints from %s: %s", path, e)
203196
return None
204197

205198

206-
def _load_fingerprint_from_lockfile(path: Path) -> dict[str, Any] | None:
207-
"""Extract the fingerprint section from a lockfile YAML."""
199+
def _load_fingerprints_from_lockfile(path: Path) -> dict[str, Any] | None:
200+
"""Extract the per-worker fingerprints from a lockfile YAML."""
208201
try:
209202
data = yaml.safe_load(path.read_text())
210-
if isinstance(data, dict):
211-
return data.get("fingerprint")
203+
if not isinstance(data, dict):
204+
return None
205+
# Support both 'fingerprints' (new, per-worker) and 'fingerprint' (old, single)
206+
fps = data.get("fingerprints")
207+
if isinstance(fps, dict):
208+
return fps
209+
fp = data.get("fingerprint")
210+
if isinstance(fp, dict):
211+
return {"worker": fp}
212212
return None
213213
except Exception as e:
214214
logger.debug("Failed to parse lockfile %s: %s", path, e)

0 commit comments

Comments
 (0)