Skip to content

Commit 763f18d

Browse files
authored
fix scaling of per token loss (#987)
1 parent 4d837ff commit 763f18d

File tree

1 file changed

+7
-3
lines changed
  • slime/backends/megatron_utils

1 file changed

+7
-3
lines changed

slime/backends/megatron_utils/loss.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -702,9 +702,13 @@ def loss_function(
702702
raise ValueError(f"Unknown loss type: {args.loss_type}")
703703

704704
# Here we need to divide by cp_size in order to cancel the multiplication performed in Megatron.
705-
loss = (
706-
loss * num_microbatches / args.global_batch_size * mpu.get_data_parallel_world_size(with_context_parallel=True)
707-
)
705+
if not args.calculate_per_token_loss:
706+
loss = (
707+
loss
708+
* num_microbatches
709+
/ args.global_batch_size
710+
* mpu.get_data_parallel_world_size(with_context_parallel=True)
711+
)
708712

709713
return (
710714
loss,

0 commit comments

Comments
 (0)