Skip to content

Commit 064fc1b

Browse files
SumanthRH and claude authored
[feat] Add support for evaluation dataset in SFTTrainer (NovaSky-AI#1668)
# What does this PR do? ## Summary Adds evaluation-dataset support to `SFTTrainer`. During training, eval is run every `eval_interval` steps on a held-out dataset, and `eval_loss` is logged to wandb alongside `train_loss`. Built on top of [PR 1 (`pr1-forward-loss`)](../pr1-forward-loss/pr.md), which provides the Tinker-compatible `forward(loss_fn="cross_entropy", ...)` API + `WorkerOutput` dataclass that this PR consumes. ## Changes ### `SFTConfig` fields ```python eval_dataset_name: Optional[str] = None # HF dataset name/path; None disables eval eval_dataset_split: str = "validation" eval_interval: int = 0 # Run eval every N training steps; 0 disables eval_before_train: bool = False # Run evaluation at step 0 before training starts ``` Removes `eval_batch_size` — eval automatically uses `micro_train_batch_size_per_gpu * dp_size` per dispatch (via the new public `WorkerDispatch.dp_size(model)` accessor). Side effect: `eval_loss` is now the true per-non-pad-token NLL independent of batch size. `collate_batch.batch_size` is now a required positional arg (was optional with a default fall-through to `sft_cfg.batch_size`). ### Eval loop in `SFTTrainer` New `SFTTrainer.run_eval(eval_tokenized)`: - Iterates the eval dataset in chunks of `micro_train_batch_size_per_gpu * dp_size` - Calls `dispatch.forward("policy", batch, loss_fn="cross_entropy", loss_fn_config=None)` per chunk - Aggregates a token-weighted mean loss across batches: `sum(batch_loss * nonpad_tokens) / sum(nonpad_tokens)` - Returns `{"eval_loss": ...}` ### Example script `examples/train/sft/run_sft_megatron_tulu3_eval.sh` — end-to-end SFT run on the tulu-3-sft-mixture dataset with the Megatron backend and eval enabled. 
## Test plan - [x] `tests/train/test_sft_tokenization.py` — 23 passed - [x] End-to-end tulu Megatron run (Qwen2.5-0.5B-Instruct, 200 steps, eval every 50 steps) — eval_loss decreases monotonically: 2.1463 → 2.1419 → 2.1407 → 2.1400 ([wandb run 8t8opq7z](https://wandb.ai/sky-posttraining-uc-berkeley/skyrl-sft-eval/runs/8t8opq7z)) - [x] Verified wandb step-ordering bug is fixed: zero "Tried to log to step N that is less than the current step" warnings in the run log ## Dependencies - Built on top of `pr1-forward-loss` — that PR must merge first. --------- Signed-off-by: SumanthRH <sumanthrh99@gmail.com> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent c3cdb2b commit 064fc1b

5 files changed

Lines changed: 196 additions & 6 deletions

File tree

examples/train/sft/run_sft_megatron_tulu3_50k.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ uv run --isolated --extra megatron \
2121
model.path=Qwen/Qwen2.5-0.5B-Instruct \
2222
dataset_name=allenai/tulu-3-sft-mixture \
2323
dataset_split="train[:50000]" \
24+
eval_dataset_name=allenai/tulu-3-sft-mixture \
25+
eval_dataset_split="train[-500:]" \
2426
messages_key=messages \
2527
max_length=4096 \
2628
num_steps=4166 \

skyrl/backends/skyrl_train/workers/worker_dispatch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ def get_lcm_dp_size(self) -> int:
114114
dp_size = math.lcm(dp_size, self._actor_groups["ref"].actor_infos[0].rank.dp_size)
115115
return dp_size
116116

117+
def dp_size(self, model: str) -> int:
118+
"""Return the data-parallel size for ``model`` (e.g. "policy")."""
119+
return self._actor_groups[model].actor_infos[0].rank.dp_size
120+
117121
def _should_manage_offload(self, model: str) -> bool:
118122
"""Check if we need to manage offload for this model."""
119123
if self.colocate_all:

skyrl/train/config/sft_config.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,18 @@ def from_cli_overrides(cls, args: Union[List[str], dict]) -> "SFTConfig":
128128
dataset_name: str = "yahma/alpaca-cleaned"
129129
dataset_split: str = "train[:100]"
130130
messages_key: str = "messages" # column name for chat-format datasets
131+
132+
# ---- Evaluation dataset ----
133+
eval_dataset_name: Optional[str] = None
134+
"""HuggingFace dataset name (or path) used to compute eval loss during training.
135+
When ``None`` (default), eval is disabled."""
136+
eval_dataset_split: str = "validation"
137+
"""Split of the eval dataset to load (e.g. ``"validation"``, ``"test[:500]"``)."""
138+
eval_interval: int = 0
139+
"""Run eval every N training steps. Eval also runs once at the end of training
140+
when an eval dataset is configured. ``0`` disables periodic eval."""
141+
eval_before_train: bool = False
142+
"""If True, run a baseline eval pass before training begins (logged at step 0)."""
131143
max_length: Optional[int] = None
132144
"""Maximum length of tokenized sequences. If specified, all sequences will be truncated to this value
133145
By default, no truncation is performed"""
@@ -207,6 +219,14 @@ def validate_sft_cfg(cfg: SFTConfig) -> None:
207219
if cfg.dummy_run_full_ctx and cfg.dummy_run_max_steps <= 0:
208220
raise ValueError(f"dummy_run_max_steps must be > 0, got {cfg.dummy_run_max_steps}")
209221

222+
# Eval config
223+
if cfg.eval_interval < 0:
224+
raise ValueError(f"eval_interval must be >= 0, got {cfg.eval_interval}")
225+
if cfg.eval_interval > 0 and not cfg.eval_dataset_name:
226+
raise ValueError("eval_interval > 0 requires eval_dataset_name to be set")
227+
if cfg.eval_before_train and cfg.eval_dataset_name is None:
228+
raise ValueError("eval_before_train=True requires eval_dataset_name to be set")
229+
210230
# checks for megatron
211231
if cfg.strategy == "megatron":
212232
tp = cfg.megatron_config.tensor_model_parallel_size

skyrl/train/sft_trainer.py

Lines changed: 168 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@
3232
from loguru import logger
3333
from ray.util.placement_group import placement_group
3434

35-
from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch
35+
from skyrl.backends.skyrl_train.training_batch import (
36+
TrainingInputBatch,
37+
pad_training_input_batch,
38+
)
3639
from skyrl.backends.skyrl_train.utils.io import io
3740
from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup
3841
from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch
@@ -467,6 +470,12 @@ def load_dataset(self) -> list:
467470
"""Load and tokenize the training dataset."""
468471
return self._load_and_tokenize(self.sft_cfg.dataset_name, self.sft_cfg.dataset_split)
469472

473+
def load_eval_dataset(self) -> Optional[list]:
474+
"""Load and tokenize the eval dataset, or return ``None`` if not configured."""
475+
if not self.sft_cfg.eval_dataset_name:
476+
return None
477+
return self._load_and_tokenize(self.sft_cfg.eval_dataset_name, self.sft_cfg.eval_dataset_split)
478+
470479
def _log_dataset_stats(self, tokenized: list) -> None:
471480
"""Log tokenized sequence length statistics over the training set.
472481
@@ -498,7 +507,7 @@ def pct(p: float) -> int:
498507
f"total={sum(lengths)}, mean={mean_len:.1f}, median={q50}, q25={q25}, q75={q75}, min={min_len}, max={max_len}"
499508
)
500509

501-
def collate_batch(self, examples: list) -> TrainingInputBatch:
510+
def collate_batch(self, examples: list, batch_size: int) -> TrainingInputBatch:
502511
"""Collate examples into a TrainingInputBatch with loss normalization.
503512
504513
Normalizes the loss_mask so that the sum-reduction in cross_entropy_loss
@@ -510,14 +519,20 @@ def collate_batch(self, examples: list) -> TrainingInputBatch:
510519
(FSDP) or ``1/num_microbatches`` (Megatron) applied during gradient
511520
accumulation so that the effective gradient equals
512521
``d[sum(-log_probs_on_nonpad) / total_nonpad]``.
522+
523+
Args:
524+
examples: Tokenized examples to collate.
525+
batch_size: Global batch dimension used in the loss-mask scaling
526+
factor. Required; the train path passes ``sft_cfg.batch_size``
527+
and the eval path passes its per-dispatch chunk size.
513528
"""
514529
batch = collate_sft_batch(examples, self.tokenizer)
515530
# Loss normalization: divide by non-pad token count (not padded seq length)
516531
# NOTE (sumanthrh): This specific scaling factor is because SkyRL's workers internally normalize
517532
# by number of micro batches, but aggregate otherwise
518533
micro_batch_size = self.sft_cfg.micro_train_batch_size_per_gpu
519534
total_nonpad = max(batch["loss_mask"].sum().item(), 1)
520-
batch["loss_mask"] = batch["loss_mask"].float() * (self.sft_cfg.batch_size / (micro_batch_size * total_nonpad))
535+
batch["loss_mask"] = batch["loss_mask"].float() * (batch_size / (micro_batch_size * total_nonpad))
521536
return batch
522537

523538
# ------------------------------------------------------------------ #
@@ -603,6 +618,84 @@ def load_checkpoint(self) -> int:
603618
# Training
604619
# ------------------------------------------------------------------ #
605620

621+
def run_eval(self, eval_tokenized: list) -> tuple[dict, int]:
622+
"""Compute eval loss over the full eval dataset.
623+
624+
Iterates the eval dataset in chunks of ``micro_train_batch_size_per_gpu * dp_size``
625+
(i.e. exactly one micro-batch per DP rank per dispatch call), calls
626+
:meth:`WorkerDispatch.forward` with ``loss_fn="cross_entropy"`` (which
627+
runs the model in ``eval()`` mode under ``no_grad``), and aggregates the
628+
per-batch losses into a token-weighted mean.
629+
630+
The aggregated loss is a token-weighted mean of the per-batch losses,
631+
which are themselves per-non-pad-token means within each batch. This
632+
yields the true per-non-pad-token mean across the eval dataset.
633+
634+
Args:
635+
eval_tokenized: Pre-tokenized eval dataset (output of
636+
:meth:`load_eval_dataset`).
637+
638+
Returns:
639+
``(metrics, num_eval_batches)`` where ``metrics`` contains
640+
``eval_loss`` and ``num_eval_batches`` is bookkeeping for
641+
stdout logging (not a wandb metric).
642+
"""
643+
num_eval = len(eval_tokenized)
644+
if num_eval == 0:
645+
raise ValueError(
646+
"Eval dataset is empty. Provide a non-empty eval split or disable eval "
647+
"by setting eval_dataset_name=None."
648+
)
649+
650+
# One micro-batch per DP rank per dispatch call — keeps memory usage bounded
651+
# and removes the need for a separate `eval_batch_size` knob.
652+
dp_size = self.dispatch.dp_size("policy")
653+
eval_chunk_size = self.sft_cfg.micro_train_batch_size_per_gpu * dp_size
654+
655+
# Pad a trailing partial batch up to ``eval_chunk_size`` via
656+
# ``pad_training_input_batch`` (which zeros ``loss_mask`` on padded rows).
657+
# Padded rows contribute 0 to the cross-entropy numerator, and the
658+
# pre-padding ``total_nonpad`` scaling in ``collate_batch`` excludes
659+
# them from the denominator, so the reported ``eval_loss`` is the
660+
# per-real-token mean over the full (non-padded) eval set.
661+
num_eval_batches = ceil(num_eval / eval_chunk_size)
662+
663+
total_loss_weighted = 0.0
664+
total_tokens = 0
665+
for batch_idx in range(num_eval_batches):
666+
start = batch_idx * eval_chunk_size
667+
end = min(start + eval_chunk_size, num_eval)
668+
batch_examples = eval_tokenized[start:end]
669+
batch = self.collate_batch(batch_examples, batch_size=eval_chunk_size)
670+
# Pad the last (possibly-short) chunk so every dispatch sees exactly
671+
# ``eval_chunk_size`` rows. ``pad_training_input_batch`` zeros the
672+
# ``loss_mask`` for padding rows; with ``pad_size=0`` it is a no-op.
673+
pad_rows = eval_chunk_size - len(batch_examples)
674+
if pad_rows > 0:
675+
logger.info(
676+
f"Padding final eval batch by {pad_rows} rows "
677+
f"({len(batch_examples)} real -> {eval_chunk_size} total); "
678+
f"padded rows are masked out of the loss."
679+
)
680+
batch = pad_training_input_batch(batch, pad_rows)
681+
# Count non-pad response tokens (from the unscaled mask, recovered from the batch)
682+
# We use the attention_mask response window via collate_sft_batch's loss_mask which
683+
# was 0/1 before scaling. Recover the count from the batch by counting positive entries.
684+
# Padded rows have loss_mask=0 so they are excluded here.
685+
nonpad_tokens = int((batch["loss_mask"] > 0).sum().item())
686+
output = self.dispatch.forward(
687+
"policy",
688+
batch,
689+
loss_fn="cross_entropy",
690+
loss_fn_config=None,
691+
)
692+
batch_loss = float(output.metrics.get("loss", float("nan")))
693+
total_loss_weighted += batch_loss * nonpad_tokens
694+
total_tokens += nonpad_tokens
695+
696+
eval_loss = total_loss_weighted / max(total_tokens, 1)
697+
return {"eval_loss": eval_loss}, num_eval_batches
698+
606699
def train_step(self, batch: TrainingInputBatch, step: int) -> dict:
607700
"""Execute a single training step: forward_backward + optim_step.
608701
@@ -732,6 +825,24 @@ def train(self):
732825
# Log tokenized sequence length statistics (once, before training loop)
733826
self._log_dataset_stats(tokenized)
734827

828+
# Load eval dataset (if configured). We load once up-front so the
829+
# tokenization cost is amortized across all eval invocations.
830+
eval_tokenized = self.load_eval_dataset()
831+
if eval_tokenized is not None:
832+
logger.info(f"Eval dataset loaded: {len(eval_tokenized)} examples")
833+
834+
# Baseline eval before training begins (logged at step 0).
835+
# Wandb's step counter starts at 0; the training loop's first commit
836+
# advances it to >=1, so step=0 here does not conflict with later steps.
837+
if self.sft_cfg.eval_before_train and eval_tokenized is not None:
838+
eval_metrics, num_eval_batches = self.run_eval(eval_tokenized)
839+
self.tracker.log({f"eval/{k}": v for k, v in eval_metrics.items()}, step=0, commit=True)
840+
logger.info(
841+
f"Baseline eval before training: "
842+
f"eval_loss={eval_metrics.get('eval_loss', float('nan')):.4f} "
843+
f"over {num_eval_batches} batches"
844+
)
845+
735846
batch_size = self.sft_cfg.batch_size
736847

737848
# Resolve num_steps: explicit num_steps takes precedence; otherwise derive from num_epochs.
@@ -790,7 +901,7 @@ def train(self):
790901
batch_examples = tokenized[start_idx:] + tokenized[: end_idx - len(tokenized)]
791902
else:
792903
batch_examples = tokenized[start_idx:end_idx]
793-
batch = self.collate_batch(batch_examples)
904+
batch = self.collate_batch(batch_examples, batch_size=batch_size)
794905

795906
# Training step
796907
step_result = self.train_step(batch, self.global_step)
@@ -830,13 +941,39 @@ def train(self):
830941
self.save_hf_model()
831942
log_dict["timing/save_hf_model"] = all_timings["save_hf_model"]
832943

944+
eval_metrics = None
945+
num_eval_batches: int | None = None
946+
# Eval fires at step N where N % eval_interval == 0 and N > 0.
947+
# The first iteration of this loop runs as global_step=1 (the
948+
# initial increment happens before this block on resume), so a
949+
# baseline eval at step 0 is not currently produced by the
950+
# training loop. If a step-0 baseline is needed, it would have to
951+
# be evaluated before entering the training loop and logged
952+
# separately.
953+
if (
954+
eval_tokenized is not None
955+
and self.sft_cfg.eval_interval > 0
956+
and self.global_step % self.sft_cfg.eval_interval == 0
957+
):
958+
with Timer("eval", all_timings):
959+
eval_metrics, num_eval_batches = self.run_eval(eval_tokenized)
960+
if eval_metrics:
961+
log_dict.update({f"eval/{k}": v for k, v in eval_metrics.items()})
962+
log_dict["timing/eval"] = all_timings["eval"]
963+
833964
self.tracker.log(log_dict, step=self.global_step, commit=True)
834965

835966
if self.global_step % 5 == 0:
836967
logger.info(
837968
f"Step {self.global_step}: loss={step_result['loss']:.4f}, " f"grad_norm={step_result['grad_norm']}"
838969
)
839970

971+
if eval_metrics:
972+
logger.info(
973+
f"Step {self.global_step}: eval_loss={eval_metrics.get('eval_loss', float('nan')):.4f} "
974+
f"over {num_eval_batches} batches"
975+
)
976+
840977
# Check for epoch boundary and reshuffle
841978
epoch = (self.global_step * batch_size) // len(tokenized)
842979
if epoch > current_epoch:
@@ -866,6 +1003,33 @@ def train(self):
8661003
logger.info(f"Saving final HF model at step {final_step}")
8671004
self.save_hf_model()
8681005

1006+
# Final eval pass (skip if the last step already ran eval).
1007+
# NOTE: The last in-loop tracker.log(..., commit=True) at step=num_steps
1008+
# advanced wandb's internal step counter to num_steps+1. Logging the
1009+
# final eval at step=num_steps would be rejected by wandb with
1010+
# "step N < current step N+1". We log the final eval at num_steps+1
1011+
# (one past the last committed train step) in a single combined
1012+
# tracker.log() call, preserving wandb step ordering. We use a local
1013+
# ``final_eval_step`` rather than mutating ``self.global_step``: the
1014+
# bump is purely a wandb-step accounting concern, not real trainer
1015+
# state.
1016+
if eval_tokenized is not None:
1017+
already_ran = self.sft_cfg.eval_interval > 0 and num_steps % self.sft_cfg.eval_interval == 0
1018+
if not already_ran:
1019+
final_eval_step = num_steps + 1
1020+
eval_timings: dict[str, float] = {}
1021+
with Timer("eval", eval_timings):
1022+
eval_metrics, num_eval_batches = self.run_eval(eval_tokenized)
1023+
if eval_metrics:
1024+
eval_log = {f"eval/{k}": v for k, v in eval_metrics.items()}
1025+
eval_log["timing/eval"] = eval_timings["eval"]
1026+
self.tracker.log(eval_log, step=final_eval_step, commit=True)
1027+
logger.info(
1028+
f"Final eval at step {final_eval_step}: "
1029+
f"eval_loss={eval_metrics.get('eval_loss', float('nan')):.4f} "
1030+
f"over {num_eval_batches} batches"
1031+
)
1032+
8691033
logger.info("SFT training complete!")
8701034

8711035
def save_checkpoint(self):

tests/train/test_sft_tokenization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def test_loss_norm_sums_to_expected(tokenizer):
356356
trainer.sft_cfg = cfg
357357
trainer.tokenizer = tokenizer
358358

359-
batch = trainer.collate_batch(examples)
359+
batch = trainer.collate_batch(examples, batch_size=cfg.batch_size)
360360

361361
total_nonpad = 2 + 4 + 1 + 3 # = 10
362362
expected_scaling = cfg.batch_size / (cfg.micro_train_batch_size_per_gpu * total_nonpad)
@@ -388,7 +388,7 @@ def test_loss_norm_all_nonpad(tokenizer):
388388
trainer.sft_cfg = cfg
389389
trainer.tokenizer = tokenizer
390390

391-
batch = trainer.collate_batch(examples)
391+
batch = trainer.collate_batch(examples, batch_size=cfg.batch_size)
392392

393393
total_nonpad = 4 # 2 + 2
394394
expected_scaling = cfg.batch_size / (cfg.micro_train_batch_size_per_gpu * total_nonpad)

0 commit comments

Comments
 (0)