@@ -147,6 +147,11 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
         with autocast():
             inputs, targets = sample_chunk(texts, args)
             document_seqlens = get_document_seqlens(inputs, args)
+            if args.mask_across_documents:
+                # Some input samples contain EOT as the final token. The prediction after that is meaningless, so it
+                # should not contribute to the loss.
+                ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT, as_tuple=True)
+                targets[ignore_indices] = loss.ignore_index

             out, _, _ = model(inputs, document_seqlens=document_seqlens)

@@ -168,6 +173,11 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
         per_batch = args.per_gpu_batch_size // args.accum_freq

         inputs, targets = sample_chunk(texts, args)
+        if args.mask_across_documents:
+            # Some input samples contain EOT as the final token. The prediction after that is meaningless, so it
+            # should not contribute to the loss.
+            ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT, as_tuple=True)
+            targets[ignore_indices] = loss.ignore_index

         for ii in range(args.accum_freq):
             maybe_no_sync = nullcontext
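For reference, a minimal standalone sketch of the masking idea this diff applies in both branches. The toy vocabulary, the EOT_ID value, and the random logits below are made up for illustration; only torch.nonzero and CrossEntropyLoss's ignore_index behave as in the patch, and open_lm's SpecialTokens, sample_chunk, and model are not used here.

import torch
import torch.nn.functional as F

EOT_ID = 2            # hypothetical end-of-text token id (stands in for SpecialTokens.END_OF_TEXT)
IGNORE_INDEX = -100   # default ignore_index of torch.nn.CrossEntropyLoss

# Toy next-token setup: targets are inputs shifted left by one position.
inputs = torch.tensor([[5, 7, EOT_ID, 4], [1, EOT_ID, 3, 6]])
targets = torch.tensor([[7, EOT_ID, 4, 9], [EOT_ID, 3, 6, 8]])

# Wherever the *input* token is EOT, the prediction at that position would cross
# a document boundary, so drop the corresponding target from the loss.
ignore_indices = torch.nonzero(inputs == EOT_ID, as_tuple=True)
targets[ignore_indices] = IGNORE_INDEX

logits = torch.randn(2, 4, 10)  # (batch, seq_len, vocab) from a stand-in model
loss = F.cross_entropy(logits.reshape(-1, 10), targets.reshape(-1), ignore_index=IGNORE_INDEX)

Positions set to ignore_index contribute neither to the loss sum nor to its normalization, so the masked predictions after EOT have no effect on the gradient.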