Skip to content

Commit 3f52982

Browse files
committed
Add rerun rate-limit error resume support
1 parent 74650dd commit 3f52982

5 files changed

Lines changed: 108 additions & 5 deletions

File tree

nemo_skills/pipeline/eval.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,12 @@ def eval(
256256
rerun_done: bool = typer.Option(
257257
False, help="If True, will re-run jobs even if a corresponding '.done' file already exists"
258258
),
259+
rerun_ratelimit_errors: bool = typer.Option(
260+
False,
261+
"--rerun-rate-limit-error",
262+
"--rerun-ratelimit-errors",
263+
help="If True, rerun rows whose stored soft-failure error indicates rate limiting, even when '.done' exists.",
264+
),
259265
with_sandbox: bool = typer.Option(False, help="If True, will start a sandbox container alongside this job"),
260266
keep_mounts_for_sandbox: bool = typer.Option(
261267
False,
@@ -400,6 +406,7 @@ def eval(
400406
num_chunks,
401407
chunk_ids,
402408
rerun_done,
409+
rerun_ratelimit_errors,
403410
server_parameters,
404411
extra_arguments,
405412
data_dir,

nemo_skills/pipeline/generate.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,12 @@ def generate(
327327
rerun_done: bool = typer.Option(
328328
False, help="If True, will re-run jobs even if a corresponding '.done' file already exists"
329329
),
330+
rerun_ratelimit_errors: bool = typer.Option(
331+
False,
332+
"--rerun-rate-limit-error",
333+
"--rerun-ratelimit-errors",
334+
help="If True, rerun rows whose stored soft-failure error indicates rate limiting, even when '.done' exists.",
335+
),
330336
with_sandbox: bool = typer.Option(False, help="If True, will start a sandbox container alongside this job"),
331337
sandbox_env_overrides: List[str] = typer.Option(
332338
None,
@@ -503,6 +509,7 @@ def convert_server_type_to_string(server_type):
503509
random_seeds=random_seeds,
504510
chunk_ids=chunk_ids,
505511
rerun_done=rerun_done,
512+
rerun_ratelimit_errors=rerun_ratelimit_errors,
506513
)
507514

508515
if _task_dependencies is None:

nemo_skills/pipeline/utils/eval.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ def prepare_eval_commands(
295295
num_chunks,
296296
chunk_ids,
297297
rerun_done,
298+
rerun_ratelimit_errors,
298299
server_parameters,
299300
extra_arguments,
300301
data_dir,
@@ -388,6 +389,7 @@ def prepare_eval_commands(
388389
random_seeds=random_seeds,
389390
chunk_ids=benchmark_chunk_ids,
390391
rerun_done=rerun_done,
392+
rerun_ratelimit_errors=rerun_ratelimit_errors,
391393
)
392394
for seed_idx, (seed, benchmark_chunk_ids) in enumerate(benchmark_args.remaining_jobs.items()):
393395
total_evals += len(benchmark_chunk_ids)

nemo_skills/pipeline/utils/generation.py

Lines changed: 91 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
# limitations under the License.
1414
import copy
1515
import hashlib
16+
import json
1617
import logging
1718
import os
19+
import re
1820
import shlex
1921
import subprocess
2022
from collections import defaultdict
@@ -180,7 +182,7 @@ def get_expected_done_files(output_dir, random_seeds, chunk_ids):
180182
return file_map
181183

182184

183-
def get_remaining_jobs(cluster_config, output_dir, random_seeds, chunk_ids, rerun_done):
185+
def get_remaining_jobs(cluster_config, output_dir, random_seeds, chunk_ids, rerun_done, rerun_ratelimit_errors=False):
184186
"""
185187
Determines which jobs still need to be run based on missing .done files.
186188
Returns a mapping from random_seed to list of chunk_ids that need processing.
@@ -189,14 +191,13 @@ def get_remaining_jobs(cluster_config, output_dir, random_seeds, chunk_ids, reru
189191
return {seed: copy.deepcopy(chunk_ids) for seed in random_seeds}
190192

191193
status_dir = get_unmounted_path(cluster_config, output_dir)
192-
expected_files = get_expected_done_files(output_dir, random_seeds, chunk_ids)
194+
expected_files = get_expected_done_files(status_dir, random_seeds, chunk_ids)
193195
check_commands = []
194196
for (seed, chunk_id), filepath in expected_files.items():
195-
unmounted_path = filepath.replace(output_dir, status_dir)
196197
# Create identifiers that can be parsed from output
197198
seed_str = "NONE" if seed is None else str(seed)
198199
chunk_str = "NONE" if chunk_id is None else str(chunk_id)
199-
check_commands.append(f'if [ ! -f "{unmounted_path}" ]; then echo "MISSING:{seed_str}:{chunk_str}"; fi')
200+
check_commands.append(f'if [ ! -f "{filepath}" ]; then echo "MISSING:{seed_str}:{chunk_str}"; fi')
200201

201202
# Process commands in batches to avoid "Argument list too long" error
202203
# Use a conservative batch size that works well even with long paths
@@ -255,9 +256,41 @@ def get_remaining_jobs(cluster_config, output_dir, random_seeds, chunk_ids, reru
255256

256257
done_jobs = defaultdict(list)
257258
for seed, chunk_id in expected_files.keys():
258-
if chunk_id not in missing_jobs[seed]:
259+
if chunk_id not in missing_jobs.get(seed, []):
259260
done_jobs[seed].append(chunk_id)
260261

262+
if rerun_ratelimit_errors:
263+
for seed in random_seeds:
264+
for chunk_id in list(missing_jobs.get(seed, [])):
265+
# for partially done seed/chunk_id combo rewrite the current -async file by simply dropping ratelimit error rows
266+
output_file = get_chunked_rs_filename(status_dir, random_seed=seed, chunk_id=chunk_id)
267+
_rewrite_async_resume_file_if_needed(output_file)
268+
269+
for seed in random_seeds:
270+
for chunk_id in list(done_jobs[seed]):
271+
# for fully done seed/chunk_id combo - of the chunks have been merged, there is no easy way to rerun just the ratelimit error rows so raise and ask user to --rerun-done
272+
if chunk_id is not None and len(chunk_ids) > 1:
273+
merged_output_file = get_chunked_rs_filename(status_dir, random_seed=seed, chunk_id=None)
274+
if Path(merged_output_file).exists():
275+
raise ValueError(
276+
"Cannot use --rerun-rate-limit-error for completed chunked outputs because "
277+
f"`{merged_output_file}` has already been merged. Use --rerun-done to fully rerun the seed."
278+
)
279+
# for fully done seed/chunk_id combo - if the chunks have not been merged, rewrite the .jsonl to .jsonl-async file by dropping ratelimit error rows and remove .done files
280+
output_file = get_chunked_rs_filename(status_dir, random_seed=seed, chunk_id=chunk_id)
281+
if _rewrite_async_resume_file_if_needed(output_file):
282+
done_file = Path(expected_files[(seed, chunk_id)])
283+
if done_file.exists():
284+
done_file.unlink()
285+
missing_jobs[seed].append(chunk_id)
286+
done_jobs[seed].remove(chunk_id)
287+
288+
missing_job_keys = [(seed, chunk_id) for seed in random_seeds for chunk_id in missing_jobs.get(seed, [])]
289+
if len(missing_job_keys) != len(set(missing_job_keys)):
290+
raise RuntimeError(
291+
"Internal error: duplicate jobs were scheduled for rerun under --rerun-rate-limit-error - this should never happen"
292+
)
293+
261294
done_jobs_str = ", ".join(
262295
[
263296
(
@@ -298,6 +331,59 @@ def get_remaining_jobs(cluster_config, output_dir, random_seeds, chunk_ids, reru
298331
return missing_jobs
299332

300333

334+
def _is_rerunnable_ratelimit_error(output_row: dict) -> bool:
335+
values_to_check = [output_row.get("detailed_error"), output_row.get("error")]
336+
337+
while values_to_check:
338+
value = values_to_check.pop()
339+
if isinstance(value, str) and "ratelimit" in re.sub(r"[\s_-]+", "", value.lower()):
340+
return True
341+
if isinstance(value, list):
342+
values_to_check.extend(value)
343+
344+
return False
345+
346+
347+
def _rewrite_async_resume_file_if_needed(output_file: str, async_position_key: str = "_async_position") -> bool:
348+
output_path = Path(output_file)
349+
async_output_path = Path(f"{output_file}-async")
350+
if output_path.exists():
351+
source_path = output_path
352+
rewriting_existing_async = False
353+
elif async_output_path.exists():
354+
source_path = async_output_path
355+
rewriting_existing_async = True
356+
else:
357+
return False
358+
359+
preserved_rows = []
360+
rerunnable_found = False
361+
with open(source_path, "rt", encoding="utf-8") as fin:
362+
for idx, line in enumerate(fin):
363+
row = json.loads(line)
364+
if _is_rerunnable_ratelimit_error(row):
365+
rerunnable_found = True
366+
continue
367+
if not rewriting_existing_async or async_position_key not in row:
368+
row[async_position_key] = idx
369+
preserved_rows.append(row)
370+
371+
if not rerunnable_found:
372+
return False
373+
374+
with open(async_output_path, "wt", encoding="utf-8") as fout:
375+
for row in preserved_rows:
376+
fout.write(json.dumps(row) + "\n")
377+
if not rewriting_existing_async:
378+
output_path.unlink()
379+
LOG.info(
380+
"Prepared `%s` for resume by preserving %d completed rows and removing rate-limit failures",
381+
source_path,
382+
len(preserved_rows),
383+
)
384+
return True
385+
386+
301387
def separate_hydra_args(extra_arguments: str) -> tuple[str, str]:
302388
"""
303389
Separate Hydra config args (--config-*, --cfg, --info, etc.) and

tests/test_generation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ def fake_get_generation_cmd(*args, **kwargs):
367367
num_chunks=None,
368368
chunk_ids=None,
369369
rerun_done=False,
370+
rerun_ratelimit_errors=False,
370371
server_parameters={
371372
"model": "test-model",
372373
"server_type": "openai",

0 commit comments

Comments
 (0)