We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 4d837ff commit 763f18dCopy full SHA for 763f18d
slime/backends/megatron_utils/loss.py
@@ -702,9 +702,13 @@ def loss_function(
702
raise ValueError(f"Unknown loss type: {args.loss_type}")
703
704
# Here we need to divide by cp_size because to cancel the multiply in Megatron.
705
- loss = (
706
- loss * num_microbatches / args.global_batch_size * mpu.get_data_parallel_world_size(with_context_parallel=True)
707
- )
+ if not args.calculate_per_token_loss:
+ loss = (
+ loss
708
+ * num_microbatches
709
+ / args.global_batch_size
710
+ * mpu.get_data_parallel_world_size(with_context_parallel=True)
711
+ )
712
713
return (
714
loss,
0 commit comments