@@ -71,7 +71,6 @@ class InternalMetricsRecorder:
7171 def __init__ (self , internal_metrics_cfg : InternalMetricsConfig , engine : TrainEngine ):
7272 self .internal_metrics_cfg = internal_metrics_cfg
7373 self .model = engine .model
74- self .intra_layer_micro_batch = engine .intra_layer_micro_batch
7574 self .hooks : list [RemovableHandle ] = []
7675 self ._attn_monitor_type : str | None = None
7776 self .attn_max_lse : dict [str , torch .Tensor ] = {}
@@ -169,25 +168,10 @@ def pop_metrics(self, data_batches: list[ModelItem]):
169168
170169 # do dummy forward to get metrics
171170 if self .need_dummy_forward :
172- for i in range (0 , len (data_batches ), self .intra_layer_micro_batch ):
173- data_batch = data_batches [i : i + self .intra_layer_micro_batch ]
174- seq_ctx_list = []
175- loss_ctx_list = []
176- for data in data_batch :
177- seq_ctx = data ["seq_ctx" ]
178- loss_ctx = data ["loss_ctx" ]
179- seq_ctx_list .append (seq_ctx )
180- loss_ctx_list .append (loss_ctx )
181- if self .intra_layer_micro_batch == 1 :
182- output = self .model (seq_ctx = seq_ctx_list [0 ], loss_ctx = loss_ctx_list [0 ], ** additional_kwargs )
183- else :
184- # although we dont need loss at this point, we still need loss_ctx for micro-batch forward
185- output = self .model (
186- seq_ctx = seq_ctx_list ,
187- loss_ctx = loss_ctx_list ,
188- ** additional_kwargs ,
189- )
190-
171+ for i in range (0 , len (data_batches )):
172+ data_batch = data_batches [i ]
173+ seq_ctx = data_batch ["seq_ctx" ]
174+ output = self .model (seq_ctx = seq_ctx , loss_ctx = None , ** additional_kwargs )
191175 if (
192176 self .internal_metrics_cfg .monitor_moe_load_balance_stats
193177 and (cur_tokens_per_expert := output .get ("tokens_per_expert_global" )) is not None
0 commit comments