5757
5858from .bypass_checkpoint_utils import find_latest_run_dir , load_local_state , save_bypass_checkpoint
5959from .bypass_utils import get_distributed_modules_ownership , set_experiment_dir , set_experiment_id
60- from .data_classes import GlobalRank , IterNum , IterStatistics , LocalTrainingStats , TimeToSaveSignal
60+ from .data_classes import GlobalRank , IterNum , IterStatistics , TimeToSaveSignal
6161from .stitched_model_factory import StitchedModuleDescriptor , StitchedModulesProcessOwnership
6262
6363os .environ ["TOKENIZERS_PARALLELISM" ] = "false"
@@ -118,6 +118,30 @@ def launch_bypass_distillation(hydra_cfg: DictConfig) -> None:
118118 mprint ("Bypass distillation sweep completed" )
119119
120120
121+ def _flush_loss_buffer (
122+ local_buffer : dict [int , dict [str , float ]],
123+ stitched_losses_history : Optional [dict [int , dict [str , float ]]],
124+ ) -> None :
125+ """All-gather buffered per-iter losses and merge into master's history.
126+
127+ Pickle-based ``all_gather_object`` was previously called on every micro-batch;
128+ batching to log-chunk boundaries reduces that cost ~``iters_per_log_chunk``×.
129+ All ranks must call this so the collective doesn't deadlock; only master
130+ actually accumulates into ``stitched_losses_history``.
131+ """
132+ if not local_buffer :
133+ return
134+ gathered : list [Optional [dict [int , dict [str , float ]]]] = [None ] * dist .size ()
135+ torch .distributed .all_gather_object (gathered , local_buffer )
136+ if dist .is_master ():
137+ assert stitched_losses_history is not None
138+ for rank_buf in gathered :
139+ if rank_buf is None :
140+ continue
141+ for it , losses in rank_buf .items ():
142+ stitched_losses_history .setdefault (it , {}).update (losses )
143+
144+
121145def train (
122146 cfg : DictConfig ,
123147 descriptor : ModelDescriptor ,
@@ -126,7 +150,7 @@ def train(
126150 teacher_stitched_model : StitchedModule ,
127151 stitched_module_descriptors : OrderedDict [str , StitchedModuleDescriptor ],
128152 stitched_modules_process_ownership : StitchedModulesProcessOwnership ,
129- train_dataloader : DataLoader ,
153+ train_dataloader : Optional [ DataLoader ] ,
130154 val_dataloader : Optional [DataLoader ],
131155 student_model_config : PretrainedConfig ,
132156 skip_first_batches : int = 0 ,
@@ -211,13 +235,18 @@ def train(
211235 f"Grad scaling status: { 'enabled' if cfg .bypass .training .use_grad_scaling else 'disabled' } "
212236 )
213237
214- train_iterator = iter (train_dataloader )
238+ # Only master consumes the dataloader — `next(train_iterator)` is gated by
239+ # `if dist.is_master()` further down. Building the iterator (or running
240+ # skip_first_batches against it) on non-master ranks wastes startup time
241+ # and memory proportional to the dataset, since each tokenizes the full
242+ # corpus only to throw it away.
243+ train_iterator = iter (train_dataloader ) if dist .is_master () else None
215244
216245 # Advance past the first `skip_first_batches` batches before the training loop
217246 # starts. Used either to skip a known-bad batch range during debugging, or to
218247 # roll the data iterator forward when resuming a run (model + optimizer state
219248 # are restored from the checkpoint, but the dataloader itself starts fresh).
220- if skip_first_batches > 0 :
249+ if dist . is_master () and skip_first_batches > 0 :
221250 mprint (f"Skipping first { skip_first_batches } batches before training" )
222251 for _ in range (skip_first_batches ):
223252 next (train_iterator )
@@ -233,8 +262,21 @@ def train(
233262 best_steps_by_name : dict [str , int ] = dict (cfg .bypass .get ("best_steps_by_name" , {}))
234263 # Anchor for the "Δ from initial" column: per-block loss from the first log chunk.
235264 initial_losses_by_name : dict [str , float ] = dict (cfg .bypass .get ("initial_losses_by_name" , {}))
236- # Buffer variables
237- input_ids = torch .zeros (1 , 1 , dtype = torch .int64 )
265+
266+ # log_interval is in optimizer-step units; multiply by grad_accum to land in
267+ # micro-batch units, which is what the per-iter loss collection counts.
268+ iters_per_log_chunk = (
269+ cfg .bypass .training .log_interval * cfg .bypass .training .grad_accumulation_steps
270+ )
271+ # Per-rank local buffer of {iter_num: {block_name: loss}}. We accumulate
272+ # losses locally on every rank and only combine them via all_gather_object
273+ # at log-chunk boundaries — the object collective is pickle-based and
274+ # was previously the per-iter sync cost. See `_flush_loss_buffer` above.
275+ local_losses_buffer : dict [int , dict [str , float ]] = {}
276+ # Buffer variables. Initialise on the active device so non-master ranks
277+ # never hand a CPU tensor to a downstream GPU op if the master-only-fetch
278+ # invariant is ever relaxed (today only master replaces this in the loop).
279+ input_ids = torch .zeros (1 , 1 , dtype = torch .int64 , device = device )
238280
239281 aprint (
240282 f"previous rank: { str (prev_rank ):<5} next rank: { str (next_rank ):<5} { owned_stitched_module_indices = } "
@@ -247,6 +289,11 @@ def train(
247289 # and incremented at the END of each iteration, so we must use `>` (not `>=`)
248290 # to ensure step `max_steps` itself runs before exiting.
249291 if cfg .bypass .step_num > cfg .bypass .training .max_steps :
292+ # Drain any residual buffered losses (< log-chunk boundary) so the
293+ # final partial chunk's stats reach master and can be logged before
294+ # the function returns. Must run on every rank — collective op.
295+ _flush_loss_buffer (local_losses_buffer , stitched_losses_history )
296+ local_losses_buffer .clear ()
250297 if (
251298 cfg .bypass .model .model_overrides .save_checkpoint_when_done
252299 and not cfg .bypass .disable_checkpoint_save
@@ -386,25 +433,17 @@ def train(
386433 else :
387434 iter_stitched_module_losses = {}
388435
389- # Collect losses from all ranks using all_gather_object
390- local_training_stats = LocalTrainingStats (
391- iter_num = cfg .bypass .iter_num ,
392- stitched_module_losses = iter_stitched_module_losses ,
393- )
394- all_training_stats = [None ] * dist .size ()
395- torch .distributed .all_gather_object (all_training_stats , local_training_stats )
396-
397- if dist .is_master ():
398- if cfg .bypass .iter_num == resumed_iter_num :
399- mprint (f"Starting from iter { cfg .bypass .iter_num } " )
436+ if dist .is_master () and cfg .bypass .iter_num == resumed_iter_num :
437+ mprint (f"Starting from iter { cfg .bypass .iter_num } " )
400438
401- # Merge all stats into the losses history
402- assert stitched_losses_history is not None
403- merged_losses : dict [str , float ] = {}
404- for stats in all_training_stats :
405- if stats is not None :
406- merged_losses .update (stats .stitched_module_losses )
407- stitched_losses_history [cfg .bypass .iter_num ] = merged_losses
439+ # Buffer this rank's per-block losses locally. The cross-rank
440+ # gather happens only at log-chunk boundaries (`_flush_loss_buffer`),
441+ # which cuts the per-iter pickle-based all_gather_object cost down to
442+ # one gather per `iters_per_log_chunk` micro-batches.
443+ local_losses_buffer [cfg .bypass .iter_num ] = iter_stitched_module_losses
444+ if len (local_losses_buffer ) >= iters_per_log_chunk :
445+ _flush_loss_buffer (local_losses_buffer , stitched_losses_history )
446+ local_losses_buffer .clear ()
408447
409448 cfg .bypass .token_count += cfg .bypass .training .tokens_per_iter
410449 iter_t1 = time .time ()
@@ -441,11 +480,9 @@ def train(
441480 # Logging
442481 if dist .is_master ():
443482 assert stitched_losses_history is not None
444- # log_interval is in optimizer-step units; the underlying history is
445- # per-iter (micro-batch), so the chunk window is grad_accum × wider.
446- iters_per_log_chunk = (
447- cfg .bypass .training .log_interval * cfg .bypass .training .grad_accumulation_steps
448- )
483+ # `iters_per_log_chunk` is computed once before the loop (in
484+ # micro-batch units = log_interval × grad_accum) and reused for
485+ # both the gather-batching threshold and this log drain.
449486 while len (stitched_losses_history ) >= iters_per_log_chunk :
450487 lowest_iter = next (iter (stitched_losses_history .keys ()))
451488
@@ -830,23 +867,37 @@ def run_bypassed_training(cfg: DictConfig):
830867 load_streaming_fn if not cfg .bypass .data .load_from_disk else load_from_disk_fn
831868 )
832869
833- train_dataloader = create_train_dataloader (
834- seed = seed ,
835- tokenizer = tokenizer ,
836- block_size = cfg .bypass .data .block_size ,
837- dataset_path = cfg .dataset_path ,
838- content_field = cfg .bypass .data .data_column ,
839- fim_rate = cfg .bypass .data .fim_rate ,
840- fim_spm_rate = cfg .bypass .data .fim_spm_rate ,
841- micro_batch_size = cfg .bypass .training .micro_batch_size ,
842- load_dataset_fn = load_dataset_fn ,
843- keep_in_memory = cfg .bypass .data .keep_in_memory ,
844- source_datasets_to_discard = cfg .bypass .data .get ("source_datasets_to_discard" , tuple ()),
845- bos_rate = cfg .bypass .data .bos_rate ,
846- shuffle_seed = cfg .bypass .data .shuffle_train_data_seed ,
847- )
870+ # Only master ever fetches from the train dataloader (training_loop.train
871+ # gates `next(train_iterator)` on `dist.is_master()`), so skip the
872+ # potentially-large HF dataset load + tokenisation on non-master ranks.
873+ if dist .is_master ():
874+ train_dataloader = create_train_dataloader (
875+ seed = seed ,
876+ tokenizer = tokenizer ,
877+ block_size = cfg .bypass .data .block_size ,
878+ dataset_path = cfg .dataset_path ,
879+ content_field = cfg .bypass .data .data_column ,
880+ fim_rate = cfg .bypass .data .fim_rate ,
881+ fim_spm_rate = cfg .bypass .data .fim_spm_rate ,
882+ micro_batch_size = cfg .bypass .training .micro_batch_size ,
883+ load_dataset_fn = load_dataset_fn ,
884+ keep_in_memory = cfg .bypass .data .keep_in_memory ,
885+ source_datasets_to_discard = cfg .bypass .data .get (
886+ "source_datasets_to_discard" , tuple ()
887+ ),
888+ bos_rate = cfg .bypass .data .bos_rate ,
889+ shuffle_seed = cfg .bypass .data .shuffle_train_data_seed ,
890+ )
891+ else :
892+ train_dataloader = None
848893
849894 val_dataloader = None
895+ # Note: val_dataloader is kept constructed on every rank even though only
896+ # master reads from it inside calculate_losses_pipeline. The validation
897+ # block uses `val_dataloader is not None` as a "validation enabled" gate
898+ # that must agree across ranks — and calculate_losses_pipeline itself is
899+ # pipeline-parallel and requires every rank to enter it. Skipping
900+ # construction on non-master ranks would break those invariants.
850901 if not cfg .bypass .disable_validation :
851902 val_dataloader = create_validation_dataloader (
852903 accelerator = None ,
0 commit comments