1313
1414from __future__ import annotations
1515
16+ import contextlib
1617import getpass
1718import logging
1819import os
@@ -61,10 +62,8 @@ def collect_slurm_context() -> dict[str, Any]:
6162 ctx [key ] = val
6263
6364 # User and working directory (always available)
64- try :
65+ with contextlib . suppress ( Exception ) :
6566 ctx ["user" ] = getpass .getuser ()
66- except Exception :
67- pass
6867
6968 ctx ["cwd" ] = str (Path .cwd ())
7069
@@ -76,12 +75,13 @@ def collect_slurm_context() -> dict[str, Any]:
7675 return ctx
7776
7877
79- def aggregate_fingerprints (log_dir : Path ) -> dict [str , Any ] | None :
80- """Aggregate per-worker fingerprint files into a single fingerprint .
78+ def collect_worker_fingerprints (log_dir : Path ) -> dict [str , Any ] | None :
79+ """Load per-worker fingerprint files into a dict keyed by worker name .
8180
82- Reads all fingerprint_*.json files from the log directory. Scalar fields
83- are taken from the first file. pip_packages are merged (sorted union).
81+ Returns a dict like:
82+ {"prefill_w0": {...}, "decode_w0": {...}, "decode_w1": {...}}
8483
84+ The key is derived from the filename: fingerprint_prefill_w0.json -> "prefill_w0".
8585 Returns None if no fingerprint files are found or all fail to load.
8686 """
8787 try :
@@ -93,38 +93,27 @@ def aggregate_fingerprints(log_dir: Path) -> dict[str, Any] | None:
9393 if not fp_files :
9494 return None
9595
96- fingerprints = []
96+ result : dict [ str , Any ] = {}
9797 for fp_file in fp_files :
9898 fp = load_fingerprint (fp_file )
9999 if fp is not None :
100- fingerprints .append (fp )
101-
102- if not fingerprints :
103- return None
104-
105- # Use first fingerprint as base for scalar fields
106- result = {k : v for k , v in fingerprints [0 ].items () if k != "pip_packages" }
107-
108- # Merge pip packages: sorted union across all workers
109- all_packages : set [str ] = set ()
110- for fp in fingerprints :
111- for pkg in fp .get ("pip_packages" , []):
112- all_packages .add (pkg )
113- result ["pip_packages" ] = sorted (all_packages , key = lambda s : s .lower ())
100+ # fingerprint_prefill_w0.json -> prefill_w0
101+ worker_key = fp_file .stem .removeprefix ("fingerprint_" )
102+ result [worker_key ] = fp
114103
115- return result
104+ return result if result else None
116105
117106
118107def build_lockfile (
119108 config : SrtConfig ,
120- runtime_fingerprint : dict [str , Any ] | None = None ,
109+ worker_fingerprints : dict [str , Any ] | None = None ,
121110) -> dict [str , Any ]:
122- """Build the lockfile dict from a resolved config and optional fingerprint .
111+ """Build the lockfile dict from a resolved config and optional per-worker fingerprints .
123112
124113 Returns a dict with:
125114 - _meta: lockfile version, timestamp, SLURM context
126115 - config: the full resolved config as a dict
127- - fingerprint: the aggregated runtime fingerprint (or None)
116+ - fingerprints: per-worker fingerprints keyed by worker name (or None)
128117 """
129118 from srtctl .core .schema import SrtConfig
130119
@@ -137,7 +126,7 @@ def build_lockfile(
137126 "slurm" : collect_slurm_context (),
138127 },
139128 "config" : config_dict ,
140- "fingerprint " : runtime_fingerprint ,
129+ "fingerprints " : worker_fingerprints ,
141130 }
142131
143132
@@ -149,14 +138,14 @@ def write_lockfile(
149138 """Write recipe.lock.yaml to the output directory.
150139
151140 Called twice per job:
152- 1. At job start (log_dir=None) — writes config + SLURM context, fingerprint =null
153- 2. At job end (log_dir set) — rewrites with aggregated runtime fingerprint
141+ 1. At job start (log_dir=None) — writes config + SLURM context, fingerprints =null
142+ 2. At job end (log_dir set) — rewrites with per-worker fingerprints
154143
155144 Returns True on success, False on any failure. Never raises.
156145 """
157146 try :
158- fingerprint = aggregate_fingerprints (log_dir ) if log_dir else None
159- lockfile_data = build_lockfile (config , fingerprint )
147+ fingerprints = collect_worker_fingerprints (log_dir ) if log_dir else None
148+ lockfile_data = build_lockfile (config , fingerprints )
160149
161150 lockfile_path = output_dir / "recipe.lock.yaml"
162151 lockfile_path .write_text (yaml .dump (lockfile_data , default_flow_style = False , sort_keys = False ))
@@ -167,48 +156,59 @@ def write_lockfile(
167156 return False
168157
169158
170- def load_lockfile_fingerprint (path : Path ) -> dict [str , Any ] | None :
171- """Load a fingerprint from a lockfile, output directory, or raw JSON.
159+ def load_lockfile_fingerprints (path : Path ) -> dict [str , Any ] | None :
160+ """Load per-worker fingerprints from a lockfile, output directory, or raw JSON.
172161
173162 Accepts:
174- - Path to recipe.lock.yaml → reads the 'fingerprint ' section
175- - Path to an output directory → looks for recipe.lock.yaml inside
176- - Path to a fingerprint JSON file → loads directly
163+ - Path to recipe.lock.yaml → reads the 'fingerprints ' section (per-worker dict)
164+ - Path to an output directory → looks for recipe.lock.yaml or raw fingerprint files
165+ - Path to a single fingerprint JSON → wraps as {"worker": fingerprint}
177166
178- Returns None if the fingerprint cannot be loaded.
167+ Returns a dict keyed by worker name, e.g.:
168+ {"prefill_w0": {...}, "decode_w0": {...}}
169+ Returns None if no fingerprints can be loaded.
179170 """
180171 try :
181- # If it's a directory, look for lockfile or fingerprint files
182172 if path .is_dir ():
183173 lockfile = path / "recipe.lock.yaml"
184174 if lockfile .exists ():
185- return _load_fingerprint_from_lockfile (lockfile )
186- # Fall back to aggregating raw fingerprint files from logs/
175+ return _load_fingerprints_from_lockfile (lockfile )
176+ # Fall back to collecting raw fingerprint files
187177 logs_dir = path / "logs"
188178 if logs_dir .is_dir ():
189- return aggregate_fingerprints (logs_dir )
190- return aggregate_fingerprints (path )
179+ return collect_worker_fingerprints (logs_dir )
180+ return collect_worker_fingerprints (path )
191181
192- # If it's a YAML file, try loading as lockfile
193182 if path .suffix in (".yaml" , ".yml" ):
194- return _load_fingerprint_from_lockfile (path )
183+ return _load_fingerprints_from_lockfile (path )
195184
196- # Otherwise try loading as raw fingerprint JSON
197185 if path .suffix == ".json" :
198- return load_fingerprint (path )
186+ fp = load_fingerprint (path )
187+ if fp is not None :
188+ # Single file — derive worker key from filename
189+ worker_key = path .stem .removeprefix ("fingerprint_" ) or "worker"
190+ return {worker_key : fp }
191+ return None
199192
200193 return None
201194 except Exception as e :
202- logger .debug ("Failed to load fingerprint from %s: %s" , path , e )
195+ logger .debug ("Failed to load fingerprints from %s: %s" , path , e )
203196 return None
204197
205198
206- def _load_fingerprint_from_lockfile (path : Path ) -> dict [str , Any ] | None :
207- """Extract the fingerprint section from a lockfile YAML."""
199+ def _load_fingerprints_from_lockfile (path : Path ) -> dict [str , Any ] | None :
200+ """Extract the per-worker fingerprints from a lockfile YAML."""
208201 try :
209202 data = yaml .safe_load (path .read_text ())
210- if isinstance (data , dict ):
211- return data .get ("fingerprint" )
203+ if not isinstance (data , dict ):
204+ return None
205+ # Support both 'fingerprints' (new, per-worker) and 'fingerprint' (old, single)
206+ fps = data .get ("fingerprints" )
207+ if isinstance (fps , dict ):
208+ return fps
209+ fp = data .get ("fingerprint" )
210+ if isinstance (fp , dict ):
211+ return {"worker" : fp }
212212 return None
213213 except Exception as e :
214214 logger .debug ("Failed to parse lockfile %s: %s" , path , e )
0 commit comments