Skip to content

Commit af1a4cd

Browse files
Ignore predictions right after EOT.
1 parent d24b533 commit af1a4cd

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

open_lm/losses.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@ def __init__(
1818
self.eps = eps
1919

2020
def forward(self, input: Tensor, target: Tensor) -> Tensor:
21+
# TODO: should ignore_index be taken into account in the regularization term as well?
2122
return super().forward(input, target) + self.eps * torch.square(torch.logsumexp(input, dim=-1)).mean()

open_lm/train.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,11 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
147147
with autocast():
148148
inputs, targets = sample_chunk(texts, args)
149149
document_seqlens = get_document_seqlens(inputs, args)
150+
if args.mask_across_documents:
151+
# Some input samples contain EOT as the final token. The prediction after that is meaningless, so it
152+
# should not contribute to the loss.
153+
ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT, as_tuple=True)
154+
targets[ignore_indices] = loss.ignore_index
150155

151156
out, _, _ = model(inputs, document_seqlens=document_seqlens)
152157

@@ -168,6 +173,11 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
168173
per_batch = args.per_gpu_batch_size // args.accum_freq
169174

170175
inputs, targets = sample_chunk(texts, args)
176+
if args.mask_across_documents:
177+
# Some input samples contain EOT as the final token. The prediction after that is meaningless, so it
178+
# should not contribute to the loss.
179+
ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT, as_tuple=True)
180+
targets[ignore_indices] = loss.ignore_index
171181

172182
for ii in range(args.accum_freq):
173183
maybe_no_sync = nullcontext

0 commit comments

Comments (0)