from loguru import logger
from ray.util.placement_group import placement_group

-from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch
+from skyrl.backends.skyrl_train.training_batch import (
+    TrainingInputBatch,
+    pad_training_input_batch,
+)
from skyrl.backends.skyrl_train.utils.io import io
from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup
from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch
@@ -467,6 +470,12 @@ def load_dataset(self) -> list:
467470 """Load and tokenize the training dataset."""
468471 return self ._load_and_tokenize (self .sft_cfg .dataset_name , self .sft_cfg .dataset_split )
469472
473+ def load_eval_dataset (self ) -> Optional [list ]:
474+ """Load and tokenize the eval dataset, or return ``None`` if not configured."""
475+ if not self .sft_cfg .eval_dataset_name :
476+ return None
477+ return self ._load_and_tokenize (self .sft_cfg .eval_dataset_name , self .sft_cfg .eval_dataset_split )
478+
470479 def _log_dataset_stats (self , tokenized : list ) -> None :
471480 """Log tokenized sequence length statistics over the training set.
472481
@@ -498,7 +507,7 @@ def pct(p: float) -> int:
498507 f"total={ sum (lengths )} , mean={ mean_len :.1f} , median={ q50 } , q25={ q25 } , q75={ q75 } , min={ min_len } , max={ max_len } "
499508 )
500509
501- def collate_batch (self , examples : list ) -> TrainingInputBatch :
510+ def collate_batch (self , examples : list , batch_size : int ) -> TrainingInputBatch :
502511 """Collate examples into a TrainingInputBatch with loss normalization.
503512
504513 Normalizes the loss_mask so that the sum-reduction in cross_entropy_loss
@@ -510,14 +519,20 @@ def collate_batch(self, examples: list) -> TrainingInputBatch:
        (FSDP) or ``1/num_microbatches`` (Megatron) applied during gradient
        accumulation so that the effective gradient equals
        ``d[sum(-log_probs_on_nonpad) / total_nonpad]``.
+
+        Args:
+            examples: Tokenized examples to collate.
+            batch_size: Global batch dimension used in the loss-mask scaling
+                factor. Required; the train path passes ``sft_cfg.batch_size``
+                and the eval path passes its per-dispatch chunk size.
        """
        batch = collate_sft_batch(examples, self.tokenizer)
        # Loss normalization: divide by the non-pad token count (not the padded sequence length).
        # NOTE (sumanthrh): This specific scaling factor is needed because SkyRL's workers internally
        # normalize by the number of micro batches but otherwise sum-aggregate losses.
        micro_batch_size = self.sft_cfg.micro_train_batch_size_per_gpu
        total_nonpad = max(batch["loss_mask"].sum().item(), 1)
-        batch["loss_mask"] = batch["loss_mask"].float() * (self.sft_cfg.batch_size / (micro_batch_size * total_nonpad))
+        batch["loss_mask"] = batch["loss_mask"].float() * (batch_size / (micro_batch_size * total_nonpad))
        return batch

    # ------------------------------------------------------------------ #
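# --- Illustrative sketch (editor-added, hypothetical numbers; not part of the change above) ---
# Quick numeric check of the loss-mask scaling factor, under the assumptions stated in the
# docstring: each worker sums masked token losses per micro-batch, divides its accumulated
# loss by its number of micro-batches, and data-parallel ranks mean-reduce gradients.
batch_size, micro_batch_size, dp_size = 4, 1, 2
# -log_prob per non-pad token for each of the 4 examples (ragged lengths).
token_losses = [[0.5, 1.0], [2.0], [0.25, 0.75, 1.0], [3.0]]
total_nonpad = sum(len(t) for t in token_losses)                 # 7
scale = batch_size / (micro_batch_size * total_nonpad)           # the loss_mask scaling factor

per_rank = [token_losses[:2], token_losses[2:]]                  # 2 examples per DP rank
micro_batches_per_rank = (batch_size // dp_size) // micro_batch_size
rank_losses = []
for rank_examples in per_rank:
    # micro_batch_size == 1 here, so each example is one micro-batch.
    mb_losses = [scale * sum(example) for example in rank_examples]
    rank_losses.append(sum(mb_losses) / micro_batches_per_rank)
effective = sum(rank_losses) / dp_size                           # DP mean-reduction
expected = sum(sum(t) for t in token_losses) / total_nonpad      # sum(-log_probs) / total_nonpad
assert abs(effective - expected) < 1e-9
# -----------------------------------------------------------------------------------------------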
@@ -603,6 +618,84 @@ def load_checkpoint(self) -> int:
    # Training
    # ------------------------------------------------------------------ #

+    def run_eval(self, eval_tokenized: list) -> tuple[dict, int]:
+        """Compute eval loss over the full eval dataset.
+
+        Iterates the eval dataset in chunks of ``micro_train_batch_size_per_gpu * dp_size``
+        (i.e. exactly one micro-batch per DP rank per dispatch call), calls
+        :meth:`WorkerDispatch.forward` with ``loss_fn="cross_entropy"`` (which
+        runs the model in ``eval()`` mode under ``no_grad``), and aggregates the
+        per-batch losses into a token-weighted mean.
+
+        Each per-batch loss is a per-non-pad-token mean within its batch, so the
+        token-weighted aggregation yields the true per-non-pad-token mean across
+        the eval dataset.
+
+        Args:
+            eval_tokenized: Pre-tokenized eval dataset (output of
+                :meth:`load_eval_dataset`).
+
+        Returns:
+            ``(metrics, num_eval_batches)`` where ``metrics`` contains
+            ``eval_loss`` and ``num_eval_batches`` is bookkeeping for
+            stdout logging (not a wandb metric).
+        """
+        num_eval = len(eval_tokenized)
+        if num_eval == 0:
+            raise ValueError(
+                "Eval dataset is empty. Provide a non-empty eval split or disable eval "
+                "by setting eval_dataset_name=None."
+            )
+
+        # One micro-batch per DP rank per dispatch call keeps memory usage bounded
+        # and removes the need for a separate `eval_batch_size` knob.
+        dp_size = self.dispatch.dp_size("policy")
+        eval_chunk_size = self.sft_cfg.micro_train_batch_size_per_gpu * dp_size
+
+        # A trailing partial batch is padded up to ``eval_chunk_size`` via
+        # ``pad_training_input_batch`` (which zeros ``loss_mask`` on padded rows).
+        # Padded rows contribute 0 to the cross-entropy numerator, and the
+        # pre-padding ``total_nonpad`` scaling in ``collate_batch`` excludes
+        # them from the denominator, so the reported ``eval_loss`` is the
+        # per-real-token mean over the full (non-padded) eval set.
+        num_eval_batches = ceil(num_eval / eval_chunk_size)
+
+        total_loss_weighted = 0.0
+        total_tokens = 0
+        for batch_idx in range(num_eval_batches):
+            start = batch_idx * eval_chunk_size
+            end = min(start + eval_chunk_size, num_eval)
+            batch_examples = eval_tokenized[start:end]
+            batch = self.collate_batch(batch_examples, batch_size=eval_chunk_size)
+            # Pad the last (possibly short) chunk so every dispatch sees exactly
+            # ``eval_chunk_size`` rows. ``pad_training_input_batch`` zeros the
+            # ``loss_mask`` for padding rows.
+            pad_rows = eval_chunk_size - len(batch_examples)
+            if pad_rows > 0:
+                logger.info(
+                    f"Padding final eval batch by {pad_rows} rows "
+                    f"({len(batch_examples)} real -> {eval_chunk_size} total); "
+                    f"padded rows are masked out of the loss."
+                )
+                batch = pad_training_input_batch(batch, pad_rows)
+            # Count non-pad response tokens by counting positive entries of the scaled
+            # loss_mask: the mask was 0/1 before scaling, so positive entries correspond
+            # exactly to real response tokens. Padded rows have loss_mask == 0 and are
+            # excluded here.
+            nonpad_tokens = int((batch["loss_mask"] > 0).sum().item())
+            output = self.dispatch.forward(
+                "policy",
+                batch,
+                loss_fn="cross_entropy",
+                loss_fn_config=None,
+            )
+            batch_loss = float(output.metrics.get("loss", float("nan")))
+            total_loss_weighted += batch_loss * nonpad_tokens
+            total_tokens += nonpad_tokens
+
+        eval_loss = total_loss_weighted / max(total_tokens, 1)
+        return {"eval_loss": eval_loss}, num_eval_batches
+
    def train_step(self, batch: TrainingInputBatch, step: int) -> dict:
        """Execute a single training step: forward_backward + optim_step.

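# --- Illustrative sketch (editor-added, hypothetical numbers; not part of the change above) ---
# The token-weighted aggregation in run_eval, assuming (as the docstring states) that each
# dispatch reports the per-non-pad-token mean loss of its batch. Weighting each batch by its
# non-pad token count recovers the dataset-level per-token mean; padded rows carry
# loss_mask == 0, so they add nothing to either the numerator or the token count.
batch_token_losses = [[0.5, 1.0, 1.5], [2.0]]          # per-token losses of two eval batches
total_loss_weighted, total_tokens = 0.0, 0
for tokens in batch_token_losses:
    batch_loss = sum(tokens) / len(tokens)             # per-batch mean reported by the worker
    total_loss_weighted += batch_loss * len(tokens)    # weight by non-pad token count
    total_tokens += len(tokens)
eval_loss = total_loss_weighted / max(total_tokens, 1)
assert abs(eval_loss - (0.5 + 1.0 + 1.5 + 2.0) / 4) < 1e-9
# -----------------------------------------------------------------------------------------------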
@@ -732,6 +825,24 @@ def train(self):
        # Log tokenized sequence length statistics (once, before training loop)
        self._log_dataset_stats(tokenized)

+        # Load eval dataset (if configured). We load once up-front so the
+        # tokenization cost is amortized across all eval invocations.
+        eval_tokenized = self.load_eval_dataset()
+        if eval_tokenized is not None:
+            logger.info(f"Eval dataset loaded: {len(eval_tokenized)} examples")
+
+        # Baseline eval before training begins (logged at step 0).
+        # Wandb's step counter starts at 0; the training loop's first commit
+        # advances it to >=1, so step=0 here does not conflict with later steps.
+        if self.sft_cfg.eval_before_train and eval_tokenized is not None:
+            eval_metrics, num_eval_batches = self.run_eval(eval_tokenized)
+            self.tracker.log({f"eval/{k}": v for k, v in eval_metrics.items()}, step=0, commit=True)
+            logger.info(
+                f"Baseline eval before training: "
+                f"eval_loss={eval_metrics.get('eval_loss', float('nan')):.4f} "
+                f"over {num_eval_batches} batches"
+            )
+
        batch_size = self.sft_cfg.batch_size

        # Resolve num_steps: explicit num_steps takes precedence; otherwise derive from num_epochs.
@@ -790,7 +901,7 @@ def train(self):
                batch_examples = tokenized[start_idx:] + tokenized[: end_idx - len(tokenized)]
            else:
                batch_examples = tokenized[start_idx:end_idx]
-            batch = self.collate_batch(batch_examples)
+            batch = self.collate_batch(batch_examples, batch_size=batch_size)

            # Training step
            step_result = self.train_step(batch, self.global_step)
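# --- Illustrative sketch (editor-added, hypothetical values; not part of the change above) ---
# The wrap-around slicing above lets a global batch that crosses the end of the dataset pull
# its remaining examples from the front. The index values here are made up for illustration.
tokenized = list(range(10))                 # stand-in for 10 tokenized examples
batch_size = 4
start_idx = 8
end_idx = start_idx + batch_size            # 12 > len(tokenized)
if end_idx > len(tokenized):
    batch_examples = tokenized[start_idx:] + tokenized[: end_idx - len(tokenized)]
else:
    batch_examples = tokenized[start_idx:end_idx]
assert batch_examples == [8, 9, 0, 1]
# -----------------------------------------------------------------------------------------------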
@@ -830,13 +941,39 @@ def train(self):
                self.save_hf_model()
                log_dict["timing/save_hf_model"] = all_timings["save_hf_model"]

+            eval_metrics = None
+            num_eval_batches: int | None = None
+            # Eval fires at step N where N % eval_interval == 0 and N > 0.
+            # The first iteration of this loop runs as global_step=1 (the
+            # initial increment happens before this block on resume), so a
+            # baseline eval at step 0 is never produced inside the training
+            # loop; the step-0 baseline is handled before the loop via
+            # ``eval_before_train``.
+            if (
+                eval_tokenized is not None
+                and self.sft_cfg.eval_interval > 0
+                and self.global_step % self.sft_cfg.eval_interval == 0
+            ):
+                with Timer("eval", all_timings):
+                    eval_metrics, num_eval_batches = self.run_eval(eval_tokenized)
+                if eval_metrics:
+                    log_dict.update({f"eval/{k}": v for k, v in eval_metrics.items()})
+                    log_dict["timing/eval"] = all_timings["eval"]
+
            self.tracker.log(log_dict, step=self.global_step, commit=True)

            if self.global_step % 5 == 0:
                logger.info(
                    f"Step {self.global_step}: loss={step_result['loss']:.4f}, grad_norm={step_result['grad_norm']}"
                )

+            if eval_metrics:
+                logger.info(
+                    f"Step {self.global_step}: eval_loss={eval_metrics.get('eval_loss', float('nan')):.4f} "
+                    f"over {num_eval_batches} batches"
+                )
+
            # Check for epoch boundary and reshuffle
            epoch = (self.global_step * batch_size) // len(tokenized)
            if epoch > current_epoch:
@@ -866,6 +1003,33 @@ def train(self):
            logger.info(f"Saving final HF model at step {final_step}")
            self.save_hf_model()

+        # Final eval pass (skip if the last step already ran eval).
+        # NOTE: The last in-loop tracker.log(..., commit=True) at step=num_steps
+        # advanced wandb's internal step counter to num_steps + 1. Logging the
+        # final eval at step=num_steps would be rejected by wandb with
+        # "step N < current step N+1". We log the final eval at num_steps + 1
+        # (one past the last committed train step) in a single combined
+        # tracker.log() call, preserving wandb step ordering. We use a local
+        # ``final_eval_step`` rather than mutating ``self.global_step``: the
+        # bump is purely a wandb-step accounting concern, not real trainer
+        # state.
+        if eval_tokenized is not None:
+            already_ran = self.sft_cfg.eval_interval > 0 and num_steps % self.sft_cfg.eval_interval == 0
+            if not already_ran:
+                final_eval_step = num_steps + 1
+                eval_timings: dict[str, float] = {}
+                with Timer("eval", eval_timings):
+                    eval_metrics, num_eval_batches = self.run_eval(eval_tokenized)
+                if eval_metrics:
+                    eval_log = {f"eval/{k}": v for k, v in eval_metrics.items()}
+                    eval_log["timing/eval"] = eval_timings["eval"]
+                    self.tracker.log(eval_log, step=final_eval_step, commit=True)
+                    logger.info(
+                        f"Final eval at step {final_eval_step}: "
+                        f"eval_loss={eval_metrics.get('eval_loss', float('nan')):.4f} "
+                        f"over {num_eval_batches} batches"
+                    )
+
        logger.info("SFT training complete!")

    def save_checkpoint(self):
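# --- Illustrative sketch (editor-added, hypothetical values; not part of the change above) ---
# Eval cadence implied by the diff: in-loop eval fires at steps N > 0 with
# N % eval_interval == 0, and a final eval is logged at num_steps + 1 only when step
# num_steps did not already run eval.
num_steps, eval_interval = 10, 4
in_loop_eval_steps = [s for s in range(1, num_steps + 1) if s % eval_interval == 0]
already_ran = eval_interval > 0 and num_steps % eval_interval == 0
final_eval_step = None if already_ran else num_steps + 1
assert in_loop_eval_steps == [4, 8] and final_eval_step == 11
# -----------------------------------------------------------------------------------------------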