Skip to content

Commit f19556c

Browse files
committed
Change default for loss normalization
1 parent a0f7534 commit f19556c

4 files changed

Lines changed: 6 additions & 5 deletions

File tree

keys_values/finetune/args.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ class TrainArgs:
429429
batch. Otherwise (`False`), we average the sum of loss
430430
values per data case (by the number of non-masked target tokens),
431431
then use the uniform average over the batch.
432+
Defaults to `True`.
432433
"""
433434

434435
save_interval: Optional[int] = 1000
@@ -459,7 +460,7 @@ class TrainArgs:
459460
intermed_save_interval: Optional[int] = None
460461
intermed_save_num: Optional[int] = None
461462
max_grad_norm: Optional[float] = 1.0
462-
average_loss_per_batch: Optional[bool] = False
463+
average_loss_per_batch: Optional[bool] = True
463464

464465
def __post_init__(self) -> None:
465466
if self.lr_warmup_fraction and self.lr_warmup_steps:

keys_values/finetune/longcontext_full.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def setup(
156156
intermed_save_interval=None,
157157
intermed_save_num=None,
158158
max_grad_norm=1.0,
159-
average_loss_per_batch=False,
159+
average_loss_per_batch=True,
160160
),
161161
eval: EvalArgs = EvalArgs(
162162
interval=600,

keys_values/kvcache/gradient/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def __init__(
240240
offload_device: Optional[torch.device] = None,
241241
offload_grad_accum: Optional[CPUOffloadAccumulateGradients] = None,
242242
track_unmatched_annotations: Optional[Callable[[int, int], bool]] = None,
243-
average_loss_per_batch: bool = False,
243+
average_loss_per_batch: bool = True,
244244
debug_gpt_model: Optional[GPT] = None,
245245
debug_intermediates: Optional[DebugIntermediates] = None,
246246
debug_profile_forward: bool = False,
@@ -318,7 +318,7 @@ def __init__(
318318
`track_unmatched_annotations(layer_idx, chunk_idx)` is `True`,
319319
where `chunk_idx` is the first chunk in the cell.
320320
average_loss_per_batch: See :meth:`LongContextInferenceModel.forward`.
321-
Defaults to `False`.
321+
Defaults to `True`.
322322
323323
"""
324324
if head_model is None:

keys_values/long_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,7 @@ def forward(
593593
as the model is to be used for token generations.
594594
595595
Some loss functions are defined over target tokens. For these:
596-
If `average_loss_per_batch == False` (default), each loss value
596+
If `average_loss_per_batch == False`, each loss value
597597
`l[b]` is normalized by the number `nz[b]` of (not ignored)
598598
target tokens: `l[b] = s[b] / nz[b]`, if `s[b]` is the sum of loss
599599
values over target tokens.

0 commit comments

Comments
 (0)