Skip to content

Commit 97f1564

Browse files
nil0x9 and HAOCHENYE
authored and committed
[Enhance] resolve conflicts
1 parent ca3371f commit 97f1564

2 files changed

Lines changed: 0 additions & 9 deletions

File tree

xtuner/v1/model/moe/moe.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,6 @@ def forward(
260260
seq_ctx: list[SequenceContext] | SequenceContext,
261261
loss_ctx: list[CELossContext] | CELossContext | None,
262262
return_router_logits: bool = False,
263-
return_tokens_per_expert_global: bool = False,
264263
):
265264
# TODO: caoweihan: Recover this assertion after the refactor of LossContext
266265
if isinstance(seq_ctx, SequenceContext):
@@ -272,7 +271,6 @@ def forward(
272271
seq_ctx=seq_ctx,
273272
loss_ctx=loss_ctx, # type: ignore
274273
return_router_logits=return_router_logits,
275-
return_tokens_per_expert_global=return_tokens_per_expert_global,
276274
)
277275
else:
278276
assert isinstance(loss_ctx, list) and len(loss_ctx) == len(seq_ctx), (
@@ -285,15 +283,13 @@ def forward(
285283
seq_ctx_list=seq_ctx,
286284
loss_ctx_list=loss_ctx,
287285
return_router_logits=return_router_logits,
288-
return_tokens_per_expert_global=return_tokens_per_expert_global,
289286
)
290287

291288
def _micro_batch_forward(
292289
self,
293290
seq_ctx_list: list[SequenceContext],
294291
loss_ctx_list: list[CELossContext],
295292
return_router_logits: bool = False,
296-
return_tokens_per_expert_global: bool = False,
297293
) -> MoEModelOutputs:
298294
"""Micro-batch forward pass for MoE model.
299295
@@ -486,7 +482,6 @@ def _forward(
486482
seq_ctx: SequenceContext, # todo(@yehaochen): support intra layer micro-batch
487483
loss_ctx: CELossContext | None,
488484
return_router_logits: bool = False,
489-
return_tokens_per_expert_global: bool = False,
490485
) -> MoEModelOutputs:
491486
input_ids = seq_ctx.input_ids
492487
position_ids = seq_ctx.position_ids

xtuner/v1/train/trainer.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,10 +1198,6 @@ def _log_step(
11981198
f"est_global_batch_tokens: {est_global_batch_tokens} "
11991199
f"eta: {eta_hms} "
12001200
)
1201-
if internal_metrics:
1202-
internal_metrics_log_list = [f"{k}: {v:.8f}" for k, v in internal_metrics.items()]
1203-
internal_metrics_log_str = ", ".join(internal_metrics_log_list)
1204-
self.logger.info(f"Step {self.cur_step}/{self.total_step} internal_metrics: {internal_metrics_log_str}")
12051201

12061202
log_scalars = {
12071203
"lr": lr,

0 commit comments

Comments (0)