@@ -16,7 +16,6 @@
 import os
 import time
 from functools import wraps
-from pprint import pprint
 from typing import Any, Callable, List, Optional, Protocol, Union
 
 import torch
@@ -40,10 +39,6 @@
 logger = logging.getLogger(__name__)
 
 
-# =============================================================================
-# Utility Functions
-# =============================================================================
-
 def get_last_pp_rank():
     """Check if current rank is the last pipeline parallel rank."""
     is_last_pp = mpu.is_pipeline_last_stage(ignore_virtual=True)
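Note: `get_last_pp_rank` and the `broadcast_loss` hunk below implement a standard Megatron pattern: the reduced loss is only materialized on the last pipeline-parallel stage, so it must be broadcast before any rank can log it. A minimal sketch of that broadcast step using plain `torch.distributed` (the real code uses Megatron's `mpu` groups; the helper name and choice of source rank here are illustrative):

    import torch
    import torch.distributed as dist

    def broadcast_scalar(value: float, src: int) -> float:
        # Pack the scalar into a tensor; non-source ranks pass a placeholder.
        t = torch.tensor([value], dtype=torch.float32)
        dist.broadcast(t, src=src)  # every rank now holds the value from src
        return t.item()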
@@ -69,22 +64,9 @@ def broadcast_loss(loss_reduced):
     return loss_synced.item()
 
 
-def get_rank():
-    """Get current process rank for warmup callback."""
-    return int(os.getenv("SLURM_PROCID", 0))
-
-
-# =============================================================================
-# MLPerf Logger
-# =============================================================================
-
 mllogger = MLLoggerWrapper(PyTCommunicationHandler())
 
 
-# =============================================================================
-# Timer Utility
-# =============================================================================
-
 class DeltaTimer:
     """Timer for measuring time deltas."""
 
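The body of `DeltaTimer` is unchanged by this commit and collapsed out of the diff. For orientation, a timer matching this docstring typically reduces to the following sketch (not the file's actual implementation; the method name is assumed):

    import time

    class DeltaTimerSketch:
        """Timer for measuring time deltas."""

        def __init__(self):
            self._last = time.perf_counter()

        def get_delta(self) -> float:
            # Seconds since the previous call; resets the reference point.
            now = time.perf_counter()
            delta, self._last = now - self._last, now
            return delta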
@@ -378,10 +360,6 @@ def install_callbacks() -> None:
     pretrain_module.train = train_module.train
 
 
-# =============================================================================
-# MLPerf Logging Callback
-# =============================================================================
-
 class MLPerfLoggingCallback:
     """MLPerf logging callback for compliance logging."""
 
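`mllogger` above is built from NVIDIA's mlperf-common helpers (`MLLoggerWrapper` around a `PyTCommunicationHandler`), which rank-guard and format MLPerf compliance events. The calls inside `MLPerfLoggingCallback` are not shown in this hunk; an assumed mllog-style usage pattern, to be checked against the installed mlperf-common version:

    # Hypothetical keys/values; the wrapper's event/start/end methods
    # mirror mlperf_logging's mllog interface.
    mllogger.start(key="run_start")
    mllogger.event(key="global_batch_size", value=512)
    mllogger.end(key="run_stop", metadata={"status": "success"})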
@@ -575,10 +553,6 @@ def _get_samples_count(self, global_state):
         return self._get_step(global_state) * self.global_batch_size
 
 
-# =============================================================================
-# Delta Timing Callback
-# =============================================================================
-
 class DeltaTimingCallback:
     """Callback for tracking training step timing."""
 
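`_get_samples_count` above is simple bookkeeping: samples consumed = step count times global batch size, so for example at step 100 with a global batch size of 512 it reports 100 x 512 = 51,200 samples.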
@@ -633,157 +607,3 @@ def on_validation_end(
     ):
         """Reset timer after validation to avoid including validation time in first train step."""
         self.t0 = time.time()
-
-
-# =============================================================================
-# Warmup Callback
-# =============================================================================
-
-def get_mock_data(config):
-    """Get mock data configuration for warmup."""
-    from megatron.bridge.training.config import MockGPTDatasetConfig
-
-    return MockGPTDatasetConfig(
-        sequence_length=config.model.encoder_seq_length,
-        random_seed=config.model.seed,
-        dataloader_type="single",
-        num_workers=config.model.data.num_workers,
-        reset_position_ids=False,
-        reset_attention_mask=False,
-        eod_mask_loss=False,
-        path_to_cache=None,
-        split="900,50,50",
-    )
-
-
-class WarmupCallback:
-    """Callback for performing training and validation warmup."""
-
-    def __init__(self, cfg):
-        self.cfg = cfg
-        self.train_steps = cfg.model.custom.warmup_train_steps
-        self.val_steps = cfg.model.custom.warmup_validation_steps
-
-    def on_train_start(
-        self,
-        global_state,
-        forward_step_func,
-        model,
-        optimizer,
-        scheduler,
-    ):
-        if get_rank() == 0:
-            print("\nMCore config:", flush=True)
-            pprint(model[0].config)
-
-        torch.distributed.barrier()
-        enable_cuda_graph = int(os.getenv("MCORE_CUDA_GRAPH", "0")) == 1
-        cuda_graph_scope = self.cfg.model.overwritten_attributes.cuda_graph_scope
-
-        forward_backward_func = get_forward_backward_func()
-        if enable_cuda_graph and cuda_graph_scope == "full_iteration":
-            forward_backward_func = FullCudaGraphWrapper(
-                forward_backward_func,
-                cuda_graph_warmup_steps=1,
-            )
-
-        skip_val_warmup = enable_cuda_graph and cuda_graph_scope != "full_iteration"
-
-        warmup_mock_cfg = get_mock_data(self.cfg)
-
-        train_dataloader, val_dataloader, _ = warmup_mock_cfg.build_dataloaders()
-        data_iterator = iter(train_dataloader)
-        eval_data_iterator = iter(val_dataloader)
-
-        pp_group = mpu.get_pipeline_model_parallel_group()
-        torch.distributed.barrier(pp_group)
-
-        for group in optimizer.param_groups:
-            group["betas_"] = group["betas"]
-            group["bias_correction_"] = group["bias_correction"]
-            group["betas"] = [1.0, 1.0]
-            group["bias_correction"] = False
-
-        if torch.distributed.get_rank() == 0:
-            logger.info("Starting training warmup")
-        start = time.time()
-        for step_idx in range(self.train_steps):
-            if torch.distributed.get_rank() == 0:
-                logger.info(f"Starting warmup step {step_idx}")
-            step_timer = time.time()
-            torch.cuda.synchronize()
-            torch.distributed.barrier()
-            forward_backward_func(
-                forward_step_func=forward_step_func,
-                data_iterator=data_iterator,
-                model=model,
-                num_microbatches=get_num_microbatches(),
-                seq_length=self.cfg.model.encoder_seq_length,
-                micro_batch_size=self.cfg.model.micro_batch_size,
-                decoder_seq_length=self.cfg.model.encoder_seq_length,
-                forward_only=False,
-            )
-            optimizer.zero_grad()
-            optimizer.step()
-            optimizer.zero_grad()
-            torch.cuda.synchronize()
-
-            for chunk in model:
-                chunk.module.zero_grad_buffer()
-                chunk.module.zero_grad()
-
-            if torch.distributed.get_rank() == 0:
-                logger.info(f"Finished warmup step {step_idx}, takes {time.time() - step_timer}s")
-
-        torch.cuda.synchronize()
-        torch.distributed.barrier()
-        if torch.distributed.get_rank() == 0:
-            logger.info(f"Finished training warmup: {time.time() - start}s.")
-
-        for group in optimizer.param_groups:
-            group["betas"] = group["betas_"]
-            group["bias_correction"] = group["bias_correction_"]
-            del group["betas_"]
-            del group["bias_correction_"]
-            if "step" in group:
-                if isinstance(group["step"], torch.Tensor):
-                    group["step"].fill_(1)
-                else:
-                    group["step"] = 1
-
-        if not skip_val_warmup:
-            start = time.time()
-            if torch.distributed.get_rank() == 0:
-                logger.info("Starting validation warmups")
-            for model_module in model:
-                model_module.eval()
-            with torch.no_grad():
-                for _ in range(self.val_steps):
-                    torch.cuda.synchronize()
-                    torch.distributed.barrier()
-
-                    forward_backward_func(
-                        forward_step_func=forward_step_func,
-                        data_iterator=eval_data_iterator,
-                        model=model,
-                        num_microbatches=get_num_microbatches(),
-                        seq_length=self.cfg.model.encoder_seq_length,
-                        micro_batch_size=self.cfg.model.micro_batch_size,
-                        forward_only=True,
-                    )
-                    torch.cuda.synchronize()
-
-            torch.distributed.barrier()
-            if torch.distributed.get_rank() == 0:
-                logger.info(f"Finished validation warmup: {time.time() - start}s.")
-
-        for chunk in model:
-            chunk.module.zero_grad_buffer()
-            chunk.module.zero_grad()
-        if torch.distributed.get_rank() == 0:
-            logger.info(f"Finished training warmup: {time.time() - start}s.")
-
-        torch.cuda.synchronize()
-        torch.distributed.barrier()
-        if torch.distributed.get_rank() == 0:
-            logger.info(f"Time spent in run_training_warmup: {time.time() - start}s")
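The deleted `WarmupCallback` relied on a trick worth recording: with Adam's betas forced to `[1.0, 1.0]`, the moment updates `m <- beta1*m + (1-beta1)*g` and `v <- beta2*v + (1-beta2)*g**2` become no-ops, so the zero-initialized moments stay zero during warmup and the warmup optimizer steps leave the weights effectively untouched; `bias_correction=False` avoids dividing by `1 - beta**t == 0`, and resetting `step` to 1 afterwards hides the warmup steps from the schedule. A standalone sketch of that freeze/restore pair (helper names are mine; the `bias_correction` group key assumes an Apex/Megatron-style fused Adam, as in the deleted code):

    import torch

    def freeze_adam(optimizer):
        # Disable moment accumulation: with betas == (1.0, 1.0),
        # m <- 1*m + 0*g and v <- 1*v + 0*g**2, so fresh (zero) moments
        # stay zero and warmup steps do not move the weights.
        for group in optimizer.param_groups:
            group["betas_"] = group["betas"]
            group["betas"] = (1.0, 1.0)
            if "bias_correction" in group:        # fused-Adam style flag
                group["bias_correction_"] = group["bias_correction"]
                group["bias_correction"] = False  # avoid 1 - beta**t == 0

    def unfreeze_adam(optimizer):
        # Restore the real hyperparameters and rewind the step counter so
        # training starts as if the warmup steps never happened.
        for group in optimizer.param_groups:
            group["betas"] = group.pop("betas_")
            if "bias_correction_" in group:
                group["bias_correction"] = group.pop("bias_correction_")
            if "step" in group:
                if isinstance(group["step"], torch.Tensor):
                    group["step"].fill_(1)
                else:
                    group["step"] = 1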