|
| 1 | +"""Trend/health check for an SFT CI run, sourced from wandb. |
| 2 | +
|
| 3 | +Replaces the bash stdout-parsing block in ``sft_tulu3_megatron.sh``: |
| 4 | +pulls the run's logged ``train/loss`` history from wandb and asserts on it. |
| 5 | +
|
| 6 | +Checks performed (any failure exits non-zero): |
| 7 | + * Run exists in the given project (matched by display name; most recent wins). |
| 8 | + * At least ``--min_steps`` ``train/loss`` rows are logged |
| 9 | + (defaults to ``2 * window``, i.e. enough for non-overlapping windows). |
| 10 | + * No NaN/inf in the logged loss history. |
| 11 | + * ``mean(last N losses) < mean(first N losses)`` where N is ``--window``. |
| 12 | + * Optionally: the run's final ``_step`` >= ``--expected_steps`` (skipped if |
| 13 | + ``--expected_steps`` is not provided). |
| 14 | +
|
| 15 | +The first 4 checks are CI-critical; the last is opt-in because some callers |
| 16 | +don't know the exact step count up front. |
| 17 | +""" |
| 18 | + |
| 19 | +import argparse |
| 20 | +import math |
| 21 | +import sys |
| 22 | + |
| 23 | +import wandb |
| 24 | + |
| 25 | + |
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the SFT CI health check.

    ``--window`` is validated here as a positive integer: the trend check
    divides by it and slices the history with it, so zero or a negative value
    would otherwise surface later as a ``ZeroDivisionError`` or a meaningless
    comparison instead of a usage error.

    Returns:
        Namespace with ``run_name``, ``project_name``, ``entity``, ``metric``,
        ``window``, ``min_steps`` and ``expected_steps``.

    Raises:
        SystemExit: Raised by argparse (exit code 2) when a required option is
            missing or a value fails validation.
    """

    def positive_int(raw: str) -> int:
        """Convert ``raw`` to an int and reject anything below 1."""
        # A ValueError from int() is turned into a usage error by argparse.
        value = int(raw)
        if value < 1:
            raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
        return value

    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("--run_name", type=str, required=True, help="wandb run display name")
    parser.add_argument("--project_name", type=str, required=True, help="wandb project name")
    parser.add_argument(
        "--entity",
        type=str,
        default=None,
        help="wandb entity. If omitted, project_name is passed as-is "
        "(matching the convention used by get_summary.py).",
    )
    parser.add_argument(
        "--metric",
        type=str,
        default="train/loss",
        help="History metric to pull (default: train/loss).",
    )
    parser.add_argument(
        "--window",
        type=positive_int,
        default=5,
        help="Window size N for the first-vs-last mean comparison (default: 5). Must be >= 1.",
    )
    parser.add_argument(
        "--min_steps",
        type=int,
        default=None,
        help="Minimum number of logged loss rows required. Defaults to 2 * window.",
    )
    parser.add_argument(
        "--expected_steps",
        type=int,
        default=None,
        help="If set, assert the run's final _step is >= this value (completion check).",
    )
    return parser.parse_args()
| 62 | + |
| 63 | + |
def main() -> int:
    """Fetch a run's metric history from wandb and run the CI health checks.

    The most recently created run whose display name equals ``--run_name`` is
    checked. Checks run in this order and the first failure ends the run:

    1. The run exists in the project.
    2. (Only with ``--expected_steps``) the summary ``_step`` is at least
       ``--expected_steps``.
    3. Enough metric rows are logged: at least ``--min_steps`` (default
       ``2 * window``) and never fewer than ``window``.
    4. Every logged value is a finite number.
    5. The mean of the last ``window`` values is strictly below the mean of
       the first ``window`` values.

    Returns:
        ``0`` if every enabled check passes, ``1`` otherwise. Failure reasons
        are written to stderr, progress and PASS lines to stdout.
    """
    args = parse_args()
    n = args.window
    min_steps = args.min_steps if args.min_steps is not None else 2 * n
    # The trend check averages `n` rows at each end, so an explicit
    # --min_steps below the window must not let a too-short history through
    # (the means would be sums of fewer than n values divided by n).
    required_rows = max(min_steps, n)
    project_path = f"{args.entity}/{args.project_name}" if args.entity else args.project_name

    api = wandb.Api()
    runs = api.runs(project_path, filters={"display_name": args.run_name}, order="-created_at")
    matched_run = next(iter(runs), None)
    if matched_run is None:
        print(f"FAIL: run '{args.run_name}' not found in project '{project_path}'", file=sys.stderr)
        return 1
    print(f"Matched run: id={matched_run.id} state={matched_run.state} url={matched_run.url}")

    # Pull the full loss history. scan_history streams every logged row (vs.
    # the sampled 500-point default from .history()).
    losses = [row[args.metric] for row in matched_run.scan_history(keys=[args.metric]) if args.metric in row]
    print(f"Pulled {len(losses)} '{args.metric}' rows from wandb history.")

    # ---- Completion check (optional) ----
    if args.expected_steps is not None:
        final_step = matched_run.summary_metrics.get("_step")
        if final_step is None or final_step < args.expected_steps:
            print(
                f"FAIL: run final _step={final_step} < expected_steps={args.expected_steps}",
                file=sys.stderr,
            )
            return 1
        print(f"PASS: run completed (final _step={final_step} >= {args.expected_steps}).")

    # ---- Minimum-rows check ----
    if len(losses) < required_rows:
        print(
            f"FAIL: only {len(losses)} '{args.metric}' rows, need at least {required_rows} "
            f"(min_steps={min_steps}, window={n}) for windowed trend check",
            file=sys.stderr,
        )
        return 1

    # ---- NaN/inf check ----
    # Non-numeric values are reported as non-finite instead of letting
    # math.isfinite raise TypeError.
    # NOTE(review): the wandb API may return NaN/inf as strings such as "NaN",
    # or None for missing values - confirm against the installed wandb version.
    bad = [
        (i, v)
        for i, v in enumerate(losses)
        if not (isinstance(v, (int, float)) and math.isfinite(v))
    ]
    if bad:
        print(f"FAIL: non-finite '{args.metric}' values detected: {bad[:5]}", file=sys.stderr)
        return 1
    print(f"PASS: no NaN/inf in '{args.metric}' history.")

    # ---- Windowed trend check ----
    first_mean = sum(losses[:n]) / n
    last_mean = sum(losses[-n:]) / n
    print(f"Mean of first {n} losses: {first_mean:.6f}; Mean of last {n} losses: {last_mean:.6f}")
    # Written as `not (a < b)` so that an equal pair of means also fails.
    if not (last_mean < first_mean):
        print(
            f"FAIL: mean of last {n} losses ({last_mean:.6f}) is not < " f"mean of first {n} ({first_mean:.6f})",
            file=sys.stderr,
        )
        return 1
    print(f"PASS: loss trend check (mean last {n} < mean first {n}).")

    print("All SFT CI assertions passed.")
    return 0
| 125 | + |
| 126 | + |
# Script entry point: propagate main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
0 commit comments