1313# limitations under the License.
1414import copy
1515import hashlib
16+ import json
1617import logging
1718import os
19+ import re
1820import shlex
1921import subprocess
2022from collections import defaultdict
@@ -180,7 +182,7 @@ def get_expected_done_files(output_dir, random_seeds, chunk_ids):
180182 return file_map
181183
182184
183- def get_remaining_jobs (cluster_config , output_dir , random_seeds , chunk_ids , rerun_done ):
185+ def get_remaining_jobs (cluster_config , output_dir , random_seeds , chunk_ids , rerun_done , rerun_ratelimit_errors = False ):
184186 """
185187 Determines which jobs still need to be run based on missing .done files.
186188 Returns a mapping from random_seed to list of chunk_ids that need processing.
@@ -189,14 +191,13 @@ def get_remaining_jobs(cluster_config, output_dir, random_seeds, chunk_ids, reru
189191 return {seed : copy .deepcopy (chunk_ids ) for seed in random_seeds }
190192
191193 status_dir = get_unmounted_path (cluster_config , output_dir )
192- expected_files = get_expected_done_files (output_dir , random_seeds , chunk_ids )
194+ expected_files = get_expected_done_files (status_dir , random_seeds , chunk_ids )
193195 check_commands = []
194196 for (seed , chunk_id ), filepath in expected_files .items ():
195- unmounted_path = filepath .replace (output_dir , status_dir )
196197 # Create identifiers that can be parsed from output
197198 seed_str = "NONE" if seed is None else str (seed )
198199 chunk_str = "NONE" if chunk_id is None else str (chunk_id )
199- check_commands .append (f'if [ ! -f "{ unmounted_path } " ]; then echo "MISSING:{ seed_str } :{ chunk_str } "; fi' )
200+ check_commands .append (f'if [ ! -f "{ filepath } " ]; then echo "MISSING:{ seed_str } :{ chunk_str } "; fi' )
200201
201202 # Process commands in batches to avoid "Argument list too long" error
202203 # Use a conservative batch size that works well even with long paths
@@ -255,9 +256,41 @@ def get_remaining_jobs(cluster_config, output_dir, random_seeds, chunk_ids, reru
255256
256257 done_jobs = defaultdict (list )
257258 for seed , chunk_id in expected_files .keys ():
258- if chunk_id not in missing_jobs [ seed ] :
259+ if chunk_id not in missing_jobs . get ( seed , []) :
259260 done_jobs [seed ].append (chunk_id )
260261
262+ if rerun_ratelimit_errors :
263+ for seed in random_seeds :
264+ for chunk_id in list (missing_jobs .get (seed , [])):
265+ # for partially done seed/chunk_id combo rewrite the current -async file by simply dropping ratelimit error rows
266+ output_file = get_chunked_rs_filename (status_dir , random_seed = seed , chunk_id = chunk_id )
267+ _rewrite_async_resume_file_if_needed (output_file )
268+
269+ for seed in random_seeds :
270+ for chunk_id in list (done_jobs [seed ]):
271+ # for fully done seed/chunk_id combo - if the chunks have been merged, there is no easy way to rerun just the ratelimit error rows so raise and ask user to --rerun-done
272+ if chunk_id is not None and len (chunk_ids ) > 1 :
273+ merged_output_file = get_chunked_rs_filename (status_dir , random_seed = seed , chunk_id = None )
274+ if Path (merged_output_file ).exists ():
275+ raise ValueError (
276+ "Cannot use --rerun-rate-limit-error for completed chunked outputs because "
277+ f"`{ merged_output_file } ` has already been merged. Use --rerun-done to fully rerun the seed."
278+ )
279+ # for fully done seed/chunk_id combo - if the chunks have not been merged, rewrite the .jsonl to .jsonl-async file by dropping ratelimit error rows and remove .done files
280+ output_file = get_chunked_rs_filename (status_dir , random_seed = seed , chunk_id = chunk_id )
281+ if _rewrite_async_resume_file_if_needed (output_file ):
282+ done_file = Path (expected_files [(seed , chunk_id )])
283+ if done_file .exists ():
284+ done_file .unlink ()
285+ missing_jobs [seed ].append (chunk_id )
286+ done_jobs [seed ].remove (chunk_id )
287+
288+ missing_job_keys = [(seed , chunk_id ) for seed in random_seeds for chunk_id in missing_jobs .get (seed , [])]
289+ if len (missing_job_keys ) != len (set (missing_job_keys )):
290+ raise RuntimeError (
291+ "Internal error: duplicate jobs were scheduled for rerun under --rerun-rate-limit-error - this should never happen"
292+ )
293+
261294 done_jobs_str = ", " .join (
262295 [
263296 (
@@ -298,6 +331,59 @@ def get_remaining_jobs(cluster_config, output_dir, random_seeds, chunk_ids, reru
298331 return missing_jobs
299332
300333
334+ def _is_rerunnable_ratelimit_error (output_row : dict ) -> bool :
335+ values_to_check = [output_row .get ("detailed_error" ), output_row .get ("error" )]
336+
337+ while values_to_check :
338+ value = values_to_check .pop ()
339+ if isinstance (value , str ) and "ratelimit" in re .sub (r"[\s_-]+" , "" , value .lower ()):
340+ return True
341+ if isinstance (value , list ):
342+ values_to_check .extend (value )
343+
344+ return False
345+
346+
def _rewrite_async_resume_file_if_needed(output_file: str, async_position_key: str = "_async_position") -> bool:
    """Drop rate-limit failure rows from an output file and stage the rest for resume.

    The rows are read from ``output_file`` when it exists, otherwise from
    ``<output_file>-async``. Rows for which ``_is_rerunnable_ratelimit_error`` is
    true are discarded; every other row is kept. If at least one row was discarded,
    the kept rows are written to ``<output_file>-async`` and, when they came from
    ``output_file``, that file is deleted afterwards.

    Each kept row gets ``async_position_key`` set to its zero-based line index in
    the source file, except that rows read from an existing ``-async`` file keep a
    value they already carry.
    NOTE(review): presumably the resume logic uses this key to restore row order -
    confirm against the generation code.

    Args:
        output_file: Path of the final output file (without the ``-async`` suffix).
        async_position_key: Name of the per-row key that stores the line index.

    Returns:
        True if rate-limit rows were found and the ``-async`` file was (re)written;
        False if neither file exists or no rate-limit row was found (nothing is
        modified in that case).

    Raises:
        json.JSONDecodeError: If a non-blank line of the source file is not valid JSON.
    """
    output_path = Path(output_file)
    async_output_path = Path(f"{output_file}-async")
    if output_path.exists():
        source_path = output_path
        rewriting_existing_async = False
    elif async_output_path.exists():
        source_path = async_output_path
        rewriting_existing_async = True
    else:
        return False

    preserved_rows = []
    rerunnable_found = False
    with open(source_path, "rt", encoding="utf-8") as fin:
        for idx, line in enumerate(fin):
            # Tolerate blank lines (e.g. a trailing newline) instead of failing in
            # json.loads; idx still counts them so positions match the source file.
            if not line.strip():
                continue
            row = json.loads(line)
            if _is_rerunnable_ratelimit_error(row):
                rerunnable_found = True
                continue
            if not rewriting_existing_async or async_position_key not in row:
                row[async_position_key] = idx
            preserved_rows.append(row)

    if not rerunnable_found:
        return False

    # The destination may be the very file that was just read. Write to a sibling
    # temp file and swap it in, so an interrupted write cannot destroy the
    # already-completed rows.
    tmp_path = async_output_path.with_name(f"{async_output_path.name}.tmp")
    try:
        with open(tmp_path, "wt", encoding="utf-8") as fout:
            fout.writelines(json.dumps(row) + "\n" for row in preserved_rows)
        tmp_path.replace(async_output_path)
    except BaseException:
        # Do not leave a partial temp file behind; the original files are untouched.
        tmp_path.unlink(missing_ok=True)
        raise

    # Only remove the final output once the replacement async file is in place.
    if not rewriting_existing_async:
        output_path.unlink()
    LOG.info(
        "Prepared `%s` for resume by preserving %d completed rows and removing rate-limit failures",
        source_path,
        len(preserved_rows),
    )
    return True
385+
386+
301387def separate_hydra_args (extra_arguments : str ) -> tuple [str , str ]:
302388 """
303389 Separate Hydra config args (--config-*, --cfg, --info, etc.) and
0 commit comments