Skip to content

Commit 22cd4eb

Browse files
authored
doc attention eot enum value
1 parent 9c7fbe1 commit 22cd4eb

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

open_lm/train.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -149,7 +149,7 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
149 | 149 |   if args.mask_across_documents:
150 | 150 |   # Some input samples contain EOT as the final token. The prediction after that is meaningless, so it
151 | 151 |   # should not contribute to the loss.
152 |     | - ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT, as_tuple=True)
    | 152 | + ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT.value, as_tuple=True)
153 | 153 |   targets[ignore_indices] = loss.ignore_index
154 | 154 |
155 | 155 |   out, _, _ = model(inputs, document_seqlens=document_seqlens)
@@ -175,7 +175,7 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
175 | 175 |   if args.mask_across_documents:
176 | 176 |   # Some input samples contain EOT as the final token. The prediction after that is meaningless, so it
177 | 177 |   # should not contribute to the loss.
178 |     | - ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT, as_tuple=True)
    | 178 | + ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT.value, as_tuple=True)
179 | 179 |   targets[ignore_indices] = loss.ignore_index
180 | 180 |
181 | 181 |   for ii in range(args.accum_freq):

0 commit comments

Comments
 (0)