@@ -260,7 +260,6 @@ def forward(
260260 seq_ctx : list [SequenceContext ] | SequenceContext ,
261261 loss_ctx : list [CELossContext ] | CELossContext | None ,
262262 return_router_logits : bool = False ,
263- return_tokens_per_expert_global : bool = False ,
264263 ):
265264 # TODO: caoweihan: Recover this assertion after the refactor of LossContext
266265 if isinstance (seq_ctx , SequenceContext ):
@@ -272,7 +271,6 @@ def forward(
272271 seq_ctx = seq_ctx ,
273272 loss_ctx = loss_ctx , # type: ignore
274273 return_router_logits = return_router_logits ,
275- return_tokens_per_expert_global = return_tokens_per_expert_global ,
276274 )
277275 else :
278276 assert isinstance (loss_ctx , list ) and len (loss_ctx ) == len (seq_ctx ), (
@@ -285,15 +283,13 @@ def forward(
285283 seq_ctx_list = seq_ctx ,
286284 loss_ctx_list = loss_ctx ,
287285 return_router_logits = return_router_logits ,
288- return_tokens_per_expert_global = return_tokens_per_expert_global ,
289286 )
290287
291288 def _micro_batch_forward (
292289 self ,
293290 seq_ctx_list : list [SequenceContext ],
294291 loss_ctx_list : list [CELossContext ],
295292 return_router_logits : bool = False ,
296- return_tokens_per_expert_global : bool = False ,
297293 ) -> MoEModelOutputs :
298294 """Micro-batch forward pass for MoE model.
299295
@@ -486,7 +482,6 @@ def _forward(
486482 seq_ctx : SequenceContext , # todo(@yehaochen): support intra layer micro-batch
487483 loss_ctx : CELossContext | None ,
488484 return_router_logits : bool = False ,
489- return_tokens_per_expert_global : bool = False ,
490485 ) -> MoEModelOutputs :
491486 input_ids = seq_ctx .input_ids
492487 position_ids = seq_ctx .position_ids
0 commit comments