@@ -230,7 +230,7 @@ def forward_only(
230230 config = get_model_config (model [0 ])
231231
232232 def forward_step (
233- data_iterator : DataIterator , model : GPTModel
233+ data_iterator : DataIterator , model : GPTModel , return_schedule_plan : bool = False
234234 ) -> tuple [torch .Tensor , Callable [[torch .Tensor ], dict [str , list [torch .Tensor ]]]]:
235235 """Forward step used by Megatron's pipeline engine.
236236
@@ -244,6 +244,8 @@ def forward_step(
244244 to be collected by the engine.
245245 """
246246
247+ assert not return_schedule_plan , "forward_only step should never return schedule plan"
248+
247249 # Get the batch.
248250 batch = get_batch (data_iterator , ["tokens" , "total_lengths" , "response_lengths" ])
249251 unconcat_tokens = batch ["unconcat_tokens" ]
@@ -364,7 +366,7 @@ def train_one_step(
364366 custom_before_train_step_hook = load_function (args .custom_megatron_before_train_step_hook_path )
365367 custom_before_train_step_hook (args , rollout_id , step_id , model , optimizer , opt_param_scheduler )
366368
367- def forward_step (data_iterator : DataIterator , model : GPTModel ) -> tuple [
369+ def forward_step (data_iterator : DataIterator , model : GPTModel , return_schedule_plan : bool = False ) -> tuple [
368370 torch .Tensor ,
369371 Callable [[torch .Tensor ], tuple [torch .Tensor , int , dict [str , torch .Tensor | list [str ]]]],
370372 ]:
@@ -402,13 +404,25 @@ def forward_step(data_iterator: DataIterator, model: GPTModel) -> tuple[
402404 old_stage = os .environ ["ROUTING_REPLAY_STAGE" ]
403405 os .environ ["ROUTING_REPLAY_STAGE" ] = "replay_forward"
404406
405- output_tensor = model (
406- input_ids = batch ["tokens" ],
407- position_ids = None ,
408- attention_mask = None ,
409- labels = None ,
410- packed_seq_params = batch ["packed_seq_params" ],
411- )
407+ if return_schedule_plan :
408+ assert (
409+ args .overlap_moe_expert_parallel_comm
410+ ), "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan"
411+ output_tensor = model .build_schedule_plan (
412+ input_ids = batch ["tokens" ],
413+ position_ids = None ,
414+ attention_mask = None ,
415+ labels = None ,
416+ packed_seq_params = batch ["packed_seq_params" ],
417+ )
418+ else :
419+ output_tensor = model (
420+ input_ids = batch ["tokens" ],
421+ position_ids = None ,
422+ attention_mask = None ,
423+ labels = None ,
424+ packed_seq_params = batch ["packed_seq_params" ],
425+ )
412426
413427 if os .environ .get ("ENABLE_ROUTING_REPLAY" , "0" ) == "1" :
414428 os .environ ["ROUTING_REPLAY_STAGE" ] = old_stage
0 commit comments