# See LICENSE for license information.
#################################################################################

+from functools import partial
+
+import torch
+from megatron.core import mpu
+from megatron.core.models.gpt import GPTModel
+from megatron.core.rerun_state_machine import get_rerun_state_machine
+from megatron.core.utils import StragglerDetector
+from megatron.training import get_args, get_timers
+from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank
+
+stimer = StragglerDetector()

from .trainer import MegatronTrainer

@@ -13,14 +24,106 @@ def __init__(self, *args, **kwargs):
        kwargs["module_name"] = "pre_trainer"
        super().__init__(*args, **kwargs)

-    def get_batch_func(self):
-        raise NotImplementedError
+    def get_batch(self, data_iterator):
+        """Generate a batch."""
+
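+        # Only the first and last pipeline stages consume the raw batch; intermediate
+        # stages operate purely on activations received from the previous stage.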
+        # TODO: this is pretty hacky, find a better way
+        if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
+            return None, None, None, None, None
+
+        # get batches based on the TP rank you are on
+        batch = get_batch_on_this_tp_rank(data_iterator)
+
+        # slice batch along sequence dimension for context parallelism
+        batch = get_batch_on_this_cp_rank(batch)
+
+        return batch.values()
+
+    def loss_func(self, loss_mask: torch.Tensor, output_tensor: torch.Tensor):
+        """Loss function.
+
+        Args:
+            loss_mask (torch.Tensor): Used to mask out some portions of the loss
+            output_tensor (torch.Tensor): The tensor with the losses
+
+        Returns:
+            the loss scalar for this micro-batch
+            the number of non-padded tokens in this micro-batch
+            a dict containing reporting metrics on the loss and number of tokens across
+            the data parallel ranks
+        """
+        args = get_args()
+
+        losses = output_tensor.float()
+        loss_mask = loss_mask.view(-1).float()
+        total_tokens = loss_mask.sum()
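+        # Pack the masked loss sum and the token count into one tensor so a single
+        # all-reduce carries both values.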
+        loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
+
+        if args.context_parallel_size > 1:
+            torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
+
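+        # The rerun state machine validates per-rank results and can replay an iteration
+        # to check whether a failure is reproducible (e.g. a transient hardware fault vs.
+        # a genuine numerical issue).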
+        # Check that individual rank losses are not NaN prior to the DP all-reduce.
+        rerun_state_machine = get_rerun_state_machine()
+        if args.check_for_nan_in_loss_and_grad:
+            rerun_state_machine.validate_result(
+                result=loss[0],
+                rejection_func=torch.isnan,
+                message="found NaN in local forward loss calculation",
+                tolerance=0.0,  # forward pass calculations are deterministic
+                fatal=True,
+            )
+            rerun_state_machine.validate_result(
+                result=loss[0],
+                rejection_func=torch.isinf,
+                message="found Inf in local forward loss calculation",
+                tolerance=0.0,  # forward pass calculations are deterministic
+                fatal=True,
+            )
+        # Check for spiky loss
+        if args.check_for_spiky_loss:
+            # Assumed threshold factor (not defined elsewhere in this diff).
+            SPIKY_LOSS_FACTOR = 10
+            rerun_state_machine.validate_result(
+                result=loss[0],
+                rejection_func=partial(
+                    rerun_state_machine.is_unexpectedly_large,
+                    threshold=SPIKY_LOSS_FACTOR,
+                    context="loss",
+                ),
+                message="Spiky loss",
+                tolerance=0.0,  # forward pass calculations are deterministic
+                fatal=False,
+            )
+        # Reduce loss for logging.
+        reporting_loss = loss.clone().detach()
+        torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
+
+        # loss[0] is a view of loss, so its ._base is not None, which trips the assert in
+        # core/pipeline_parallel/schedules.py::deallocate_output_tensor; calling .clone()
+        # on loss[0] avoids this.
+        local_num_tokens = loss[1].clone().detach().to(torch.int)
+        return (
+            loss[0].clone(),
+            local_num_tokens,
+            {"lm loss": (reporting_loss[0], reporting_loss[1])},
+        )
+
+    def forward_step(self, data_iterator, model: GPTModel):
+        """Forward training step.
+
+        Args:
+            data_iterator: Input data iterator
+            model (GPTModel): The GPT Model
+        """
+        get_args()
+        timers = get_timers()

-    def get_loss_func(self):
-        raise NotImplementedError
+        # Get the batch.
+        timers("batch-generator", log_level=2).start()
+        global stimer
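+        # bdata=True marks the region below as batch-data loading for the
+        # StragglerDetector, separate from the compute timed by `with stimer:` below.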
+        with stimer(bdata=True):
+            tokens, labels, loss_mask, attention_mask, position_ids = self.get_batch(data_iterator)
+        timers("batch-generator").stop()

-    def build_dataset_and_tokenizer(self):
-        raise NotImplementedError
+        with stimer:
+            output_tensor = model(tokens, position_ids, attention_mask, labels=labels)

-    def get_forward_step_func(self):
-        raise NotImplementedError
+        return output_tensor, partial(self.loss_func, loss_mask)