
Commit be1d208

SumanthRH and claude committed
[ci] add SFT Megatron Tulu3 E2E test
Adds a nightly E2E CI job that exercises the full SFT pipeline against the Megatron backend on the Tulu3 dataset, mirroring the structure of the existing GSM8K GPU E2E jobs.

* Backend: Megatron with TP=1, PP=1 on L4_ci (4 GPUs).
* Workload: Qwen/Qwen2.5-0.5B-Instruct, 100 steps, train[:2000] from allenai/tulu-3-sft-mixture, batch_size=8, lr=1e-4 (bumped from the source script's 1e-6 so the run produces a downward trend in 100 steps), train_on_what=all_assistant_messages.
* Assertions: exit code 0, "SFT training complete!" appears in stdout, no nan/inf loss values, and mean of the last 5 logged losses is less than the mean of the first 5 (lenient windowed trend check, no magnitude thresholds).
* Logger: wandb (project=skyrl_sft_ci), reusing the existing WANDB_API_KEY secret injection pattern.

Files:
- tests/train/gpu_e2e_test/sft_tulu3_megatron.sh: driver + assertions
- ci/gpu_e2e_test_run_sft.sh: anyscale-job entrypoint
- ci/anyscale_gpu_e2e_test_sft.yaml: anyscale job spec (megatron image)
- .github/workflows/gpu_e2e_ci_sft.yaml: nightly GitHub Actions workflow

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
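As a rough sanity check on the workload numbers in the message, the sketch below works out how much of the 2000-example slice a 100-step run touches. It assumes `batch_size` counts examples per optimizer step across all GPUs, which the commit message does not state explicitly; the helper name `samples_consumed` is made up for illustration.

```python
def samples_consumed(num_steps: int, batch_size: int) -> int:
    """Examples seen over a run, assuming batch_size is the global per-step batch."""
    return num_steps * batch_size


# Values taken from the commit message.
num_steps, batch_size, slice_size = 100, 8, 2000

seen = samples_consumed(num_steps, batch_size)
fraction = seen / slice_size
print(f"{seen} of {slice_size} examples ({fraction:.0%} of one pass over the slice)")
```

Under that assumption the run stays well inside a single pass over `train[:2000]`, so no example is repeated during the 100 steps.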
1 parent fb87f35 commit be1d208

5 files changed

Lines changed: 272 additions & 0 deletions

.github/workflows/gpu_e2e_ci_sft.yaml

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
name: SkyRL-GPU-E2E-CI-SFT

on:
  schedule:
    - cron: '5 8 * * *' # Every day at 08:05 UTC (~00:05 PST / ~01:05 PDT)
  workflow_dispatch:

permissions:
  checks: write # for status checks to appear
  contents: read

jobs:

  skyrl_gpu_e2e_test_sft:
    runs-on: ubuntu-latest
    defaults:
      run:
        shell: bash
        working-directory: .

    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        # This is the version of the action for setting up Python, not the Python version.
        uses: actions/setup-python@v5
        with:
          # Semantic version range syntax or exact version of a Python version
          python-version: '3.12'
          cache: 'pip'
      - name: Install the latest version of uv
        uses: astral-sh/setup-uv@v6
        with:
          activate-environment: true
      - name: Install basic dependencies
        run: uv pip install anyscale==0.24.79 typer==0.9.0
      - name: Install envsubst
        run: sudo apt-get update && sudo apt-get install -y gettext-base
      - name: Basic convergence test
        env:
          ANYSCALE_CLI_TOKEN: ${{ secrets.ANYSCALE_CLI_TOKEN }}
          ANYSCALE_HOST: https://console.anyscale.com
          WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
        run: |
          envsubst < ci/anyscale_gpu_e2e_test_sft.yaml > ci/anyscale_gpu_e2e_test_sft_envsubst.yaml
          anyscale job submit -f ci/anyscale_gpu_e2e_test_sft_envsubst.yaml --timeout 4500
          anyscale job wait --cloud sky-anyscale-aws-us-east-1 --name skyrl-train-gpu-e2e-test-sft --timeout 4500
          rm -f ci/anyscale_gpu_e2e_test_sft_envsubst.yaml
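
The local-time equivalents in the cron comment can be checked with a few lines of Python. This is a minimal sketch that uses fixed UTC offsets (UTC-8 for PST, UTC-7 for PDT) instead of a time zone database; the calendar date is arbitrary and only the time of day matters.

```python
from datetime import datetime, timedelta, timezone

# Fixed offsets, so no tzdata lookup is needed.
PST = timezone(timedelta(hours=-8), "PST")
PDT = timezone(timedelta(hours=-7), "PDT")

# '5 8 * * *' fires at 08:05 UTC every day; the date below is arbitrary.
fire_utc = datetime(2025, 1, 15, 8, 5, tzinfo=timezone.utc)

pst_time = fire_utc.astimezone(PST).strftime("%H:%M")
pdt_time = fire_utc.astimezone(PDT).strftime("%H:%M")
print(f"PST: {pst_time}, PDT: {pdt_time}")
```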

ci/anyscale_gpu_e2e_test_sft.yaml

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
name: skyrl-train-gpu-e2e-test-sft
entrypoint: bash ci/gpu_e2e_test_run_sft.sh
image_uri: novaskyai/skyrl-train-ray-2.51.1-py3.12-cu12.8-megatron # (Optional) Exclusive with `containerfile`.
cloud: sky-anyscale-aws-us-east-1
ray_version: "2.51.1"
compute_config: l4_ci
working_dir: . # (Optional) Use current working directory "." as the working_dir. Can be any local path or remote .zip file in cloud storage.
env_vars:
  RAY_OVERRIDE_JOB_RUNTIME_ENV: "1"
  WANDB_API_KEY: $WANDB_API_KEY
max_retries: 1 # (Optional) Maximum number of times the job will be retried before being marked failed. Defaults to `1`.
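
The `$WANDB_API_KEY` placeholder under `env_vars` is filled in by `envsubst` in the workflow before the job is submitted. The sketch below approximates that substitution with Python's `string.Template`. It only illustrates the templating step and is not how the CI performs it; the key value is a dummy.

```python
from string import Template

# Fragment of the job spec containing the placeholder.
job_spec = """\
env_vars:
  RAY_OVERRIDE_JOB_RUNTIME_ENV: "1"
  WANDB_API_KEY: $WANDB_API_KEY
"""

# Dummy environment for illustration; the workflow takes the real value
# from a GitHub Actions secret.
env = {"WANDB_API_KEY": "dummy-key"}

# safe_substitute leaves unknown $VARS untouched, whereas envsubst replaces
# unset variables with an empty string.
rendered = Template(job_spec).safe_substitute(env)
print(rendered)
```

Because `envsubst` silently expands unset variables to an empty string, a missing secret would produce a job spec with a blank `WANDB_API_KEY` instead of a templating error.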

ci/gpu_e2e_test_run_sft.sh

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
#!/usr/bin/env bash
set -euo pipefail

bash tests/train/gpu_e2e_test/sft_tulu3_megatron.sh

tests/train/gpu_e2e_test/check_sft_trend.py

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
"""Trend/health check for an SFT CI run, sourced from wandb.

Replaces the bash stdout-parsing block in ``sft_tulu3_megatron.sh``:
pulls the run's logged ``train/loss`` history from wandb and asserts on it.

Checks performed (any failure exits non-zero):
  * Run exists in the given project (matched by display name; most recent wins).
  * At least ``--min_steps`` ``train/loss`` rows are logged
    (defaults to ``2 * window``, i.e. enough for non-overlapping windows).
  * No NaN/inf in the logged loss history.
  * ``mean(last N losses) < mean(first N losses)`` where N is ``--window``.
  * Optionally: the run's final ``_step`` >= ``--expected_steps`` (skipped if
    ``--expected_steps`` is not provided).

The first 4 checks are CI-critical; the last is opt-in because some callers
don't know the exact step count up front.
"""

import argparse
import math
import sys

import wandb


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("--run_name", type=str, required=True, help="wandb run display name")
    parser.add_argument("--project_name", type=str, required=True, help="wandb project name")
    parser.add_argument(
        "--entity",
        type=str,
        default=None,
        help="wandb entity. If omitted, project_name is passed as-is "
        "(matching the convention used by get_summary.py).",
    )
    parser.add_argument(
        "--metric",
        type=str,
        default="train/loss",
        help="History metric to pull (default: train/loss).",
    )
    parser.add_argument(
        "--window",
        type=int,
        default=5,
        help="Window size N for the first-vs-last mean comparison (default: 5).",
    )
    parser.add_argument(
        "--min_steps",
        type=int,
        default=None,
        help="Minimum number of logged loss rows required. Defaults to 2 * window.",
    )
    parser.add_argument(
        "--expected_steps",
        type=int,
        default=None,
        help="If set, assert the run's final _step is >= this value (completion check).",
    )
    return parser.parse_args()


def main() -> int:
    args = parse_args()
    min_steps = args.min_steps if args.min_steps is not None else 2 * args.window
    project_path = f"{args.entity}/{args.project_name}" if args.entity else args.project_name

    api = wandb.Api()
    runs = api.runs(project_path, filters={"display_name": args.run_name}, order="-created_at")
    matched_run = next(iter(runs), None)
    if matched_run is None:
        print(f"FAIL: run '{args.run_name}' not found in project '{project_path}'", file=sys.stderr)
        return 1
    print(f"Matched run: id={matched_run.id} state={matched_run.state} url={matched_run.url}")

    # Pull the full loss history. scan_history streams every logged row (vs.
    # the sampled 500-point default from .history()).
    rows = list(matched_run.scan_history(keys=[args.metric]))
    losses = [row[args.metric] for row in rows if args.metric in row]
    print(f"Pulled {len(losses)} '{args.metric}' rows from wandb history.")

    # ---- Completion check (optional) ----
    if args.expected_steps is not None:
        final_step = matched_run.summary_metrics.get("_step")
        if final_step is None or final_step < args.expected_steps:
            print(
                f"FAIL: run final _step={final_step} < expected_steps={args.expected_steps}",
                file=sys.stderr,
            )
            return 1
        print(f"PASS: run completed (final _step={final_step} >= {args.expected_steps}).")

    # ---- Minimum-rows check ----
    if len(losses) < min_steps:
        print(
            f"FAIL: only {len(losses)} '{args.metric}' rows, need at least {min_steps} "
            f"(2 * window={args.window}) for windowed trend check",
            file=sys.stderr,
        )
        return 1

    # ---- NaN/inf check ----
    bad = [(i, v) for i, v in enumerate(losses) if not math.isfinite(v)]
    if bad:
        print(f"FAIL: non-finite '{args.metric}' values detected: {bad[:5]}", file=sys.stderr)
        return 1
    print(f"PASS: no NaN/inf in '{args.metric}' history.")

    # ---- Windowed trend check ----
    n = args.window
    first_mean = sum(losses[:n]) / n
    last_mean = sum(losses[-n:]) / n
    print(f"Mean of first {n} losses: {first_mean:.6f}; Mean of last {n} losses: {last_mean:.6f}")
    if not (last_mean < first_mean):
        print(
            f"FAIL: mean of last {n} losses ({last_mean:.6f}) is not < " f"mean of first {n} ({first_mean:.6f})",
            file=sys.stderr,
        )
        return 1
    print(f"PASS: loss trend check (mean last {n} < mean first {n}).")

    print("All SFT CI assertions passed.")
    return 0


if __name__ == "__main__":
    sys.exit(main())
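
The minimum-rows, NaN/inf, and windowed-trend checks do not depend on wandb once the loss list is in hand, so they can be exercised offline. The sketch below restates them as a pure function under one simplifying assumption: it always requires `2 * window` rows (there is no `--min_steps` override), so the two windows never overlap. The name `check_loss_trend` is hypothetical and does not appear in the commit.

```python
import math
from typing import List, Sequence


def check_loss_trend(losses: Sequence[float], window: int = 5) -> List[str]:
    """Return failure messages for a loss history; an empty list means all checks passed."""
    if window < 1:
        return [f"window must be >= 1, got {window}"]
    # Require enough rows for two non-overlapping windows.
    if len(losses) < 2 * window:
        return [f"only {len(losses)} rows, need at least {2 * window}"]
    # Reject NaN and +/-inf anywhere in the history.
    bad = [(i, v) for i, v in enumerate(losses) if not math.isfinite(v)]
    if bad:
        return [f"non-finite values detected: {bad[:5]}"]
    # Compare the mean of the last window against the mean of the first.
    first_mean = sum(losses[:window]) / window
    last_mean = sum(losses[-window:]) / window
    if not last_mean < first_mean:
        return [
            f"mean of last {window} ({last_mean:.6f}) is not < "
            f"mean of first {window} ({first_mean:.6f})"
        ]
    return []


decreasing = [2.0 - 0.1 * i for i in range(12)]
print(check_loss_trend(decreasing))                   # steadily decreasing history
print(check_loss_trend([1.0] * 12))                   # flat history
print(check_loss_trend(decreasing[:6]))               # too few rows for window=5
print(check_loss_trend(decreasing + [float("nan")]))  # history containing NaN
```

Returning messages instead of exiting makes the logic easy to unit test; a CLI wrapper could print them to stderr and exit non-zero when the list is non-empty.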

tests/train/gpu_e2e_test/sft_tulu3_megatron.sh

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
#!/usr/bin/env bash
# E2E CI test for SFT with the Megatron backend on Tulu3.
#
# Runs ``examples/train/sft/run_sft_megatron_tulu3_50k.sh`` with shorter
# overrides (100 steps, train[:2000]) and asserts:
#   * Process exits 0.
#   * "SFT training complete!" appears in stdout.
#   * Via ``check_sft_trend.py`` (sourcing the run's history from wandb):
#       - The run completed all expected steps.
#       - No NaN/inf in the ``train/loss`` history.
#       - Mean of the last 5 logged losses is strictly less than the mean of the
#         first 5 (lenient trend check averaged over windows to absorb step
#         noise; no magnitude thresholds).
#
# Logger is wandb so that the run is visible alongside other CI runs and
# downstream assertions can introspect the run's logged history directly.
set -euo pipefail

SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
RUN_NAME="sft_megatron_run_$(date +%Y%m%d%H)"
PROJECT_NAME="skyrl_sft_ci"
ENTITY="sky-posttraining-uc-berkeley"
NUM_STEPS=100
LOG_FILE="${LOG_FILE:-/tmp/${RUN_NAME}.log}"

# The anyscale job's working_dir is the repo root, so we can use relative paths.
# We pipe through `tee` so the combined stdout/stderr is mirrored to
# ``$LOG_FILE`` for the completion-marker check below (the loss-trend
# assertions are sourced from wandb, not from this log).
#
# Notes on overrides vs the source script:
#   * lr is bumped from 1e-6 to 1e-4 so the model produces a clear downward
#     trend in 100 steps; the source script's 1e-6 is calibrated for 4166 steps.
#   * batch_size=8, micro_train_batch_size_per_gpu=2 are sized for L4_ci (4 GPUs).
bash examples/train/sft/run_sft_megatron_tulu3_50k.sh \
  num_steps="$NUM_STEPS" \
  dataset_split="train[:2000]" \
  batch_size=8 \
  micro_train_batch_size_per_gpu=2 \
  max_length=1024 \
  model.path=Qwen/Qwen2.5-0.5B-Instruct \
  optimizer_config.lr=1e-4 \
  placement.num_nodes=1 \
  placement.num_gpus_per_node=4 \
  megatron_config.tensor_model_parallel_size=1 \
  megatron_config.pipeline_model_parallel_size=1 \
  megatron_config.context_parallel_size=1 \
  train_on_what="all_assistant_messages" \
  logger=wandb \
  project_name="$PROJECT_NAME" \
  run_name="$RUN_NAME" \
  ckpt_path="" \
  ckpt_interval=0 \
  hf_save_interval=0 \
  resume_from="" \
  2>&1 | tee "$LOG_FILE"

# `set -o pipefail` ensures the failure of the training command propagates
# through the `tee` pipeline, so by the time we get here the training run
# itself succeeded (exit code 0).

# ---- Completion marker (stdout-side, cheap sanity check) ----
# Confirms the trainer reached its final print before exiting. The wandb-side
# completion/trend/nan-inf assertions follow.
if ! grep -q "SFT training complete!" "$LOG_FILE"; then
  echo "FAIL: 'SFT training complete!' not found in $LOG_FILE"
  exit 1
fi
echo "PASS: 'SFT training complete!' marker found."

# ---- Wandb-side assertions ----
# Pulls the run's logged ``train/loss`` history and asserts:
#   * final _step >= NUM_STEPS (completion),
#   * no NaN/inf in the history,
#   * mean(last 5) < mean(first 5) (lenient windowed trend).
uv run --isolated --extra fsdp "$SCRIPT_DIR/check_sft_trend.py" \
  --run_name "$RUN_NAME" \
  --project_name "$PROJECT_NAME" \
  --entity "$ENTITY" \
  --window 5 \
  --expected_steps "$NUM_STEPS"

echo "All SFT CI assertions passed."
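
The driver relies on `set -o pipefail` so that a failing training command is not masked by the exit status of `tee`. The sketch below demonstrates that behaviour from Python by running a tiny bash pipeline twice. It assumes `bash` and `tee` are on the PATH, and `false` stands in for a failing training command.

```python
import subprocess


def pipeline_exit_code(use_pipefail: bool) -> int:
    """Run `false | tee /dev/null` in bash and return the pipeline's exit status."""
    prefix = "set -o pipefail; " if use_pipefail else ""
    result = subprocess.run(["bash", "-c", prefix + "false | tee /dev/null"])
    return result.returncode


# Without pipefail the pipeline reports tee's status; with it, the failing
# left-hand command determines the result.
without_pipefail = pipeline_exit_code(False)
with_pipefail = pipeline_exit_code(True)
print(f"without pipefail: {without_pipefail}, with pipefail: {with_pipefail}")
```

Combined with `set -e`, a non-zero pipeline status stops the script before the completion-marker and wandb checks run.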
