@@ -43,6 +43,34 @@ def backward(total_loss, scaler):
     total_loss.backward()
 
 
+def get_document_seqlens(inputs, args):
+    """Get list of document sequence lengths.
+
+    Return a list of lists. The length of the outer list is equal to the batch size, while the length of the
+    inner list is equal to the number of distinct documents (recognized by EOT tokens). Each element of an
+    inner list is the length of the corresponding document.
+    """
+    if args.mask_across_documents:
+        document_seqlens = []
+        for idx in range(inputs.shape[0]):
+            eot_idx = torch.nonzero(inputs[idx] == SpecialTokens.END_OF_TEXT.value)
+            if eot_idx.shape[0] == 0:
+                # Fallback case - no EOT token in the inputs (one should appear at the end of the chunk,
+                # i.e. in the targets), so the entire sequence is a single document.
+                document_seqlens.append([args.seq_len])
+            else:
+                start_idx = 0
+                seqlens = []
+                for k in range(eot_idx.shape[0]):
+                    seqlens.append(eot_idx[k].item() - start_idx + 1)
+                    start_idx = eot_idx[k].item() + 1
+                if start_idx < args.seq_len:
+                    # Tokens after the last EOT form a final, truncated document.
+                    seqlens.append(args.seq_len - start_idx)
+                document_seqlens.append(seqlens)
+    else:
+        document_seqlens = None
+    return document_seqlens
+
+
 def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler, total_steps, args, tb_writer=None):
     """Trains model for one epoch on the provided data.
 
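For concreteness, here is a toy example of what the new helper returns (the token ids below are made up for illustration; the EOT token id is assumed to be 0 and `args.seq_len` is taken to be 8):

```python
import torch

# Hypothetical batch of two rows, seq_len = 8, EOT id assumed to be 0.
inputs = torch.tensor(
    [
        [5, 3, 0, 7, 2, 0, 9, 4],  # EOT at positions 2 and 5
        [6, 1, 8, 2, 5, 7, 3, 9],  # no EOT anywhere
    ]
)

# get_document_seqlens(inputs, args) would return:
#   [[3, 3, 2], [8]]
# Row 0 splits into documents of lengths 3, 3, and a trailing 2;
# row 1 is a single document spanning the whole sequence.
```

Each inner list sums to `args.seq_len`, which is what a block-diagonal attention mask over the inputs requires.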
@@ -118,25 +146,7 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
         if args.accum_freq == 1:
             with autocast():
                 inputs, targets = sample_chunk(texts, args)
-
-                if args.mask_across_documents:
-                    document_seqlens = []
-                    for idx in range(inputs.shape[0]):
-                        eot_idx = torch.nonzero(inputs[idx] == SpecialTokens.END_OF_TEXT.value)
-                        if len(eot_idx.shape) == 0:
-                            # Fallback case - an eot token should appear at the end.
-                            document_seqlens.append([args.seq_len + 1])
-                        else:
-                            start_idx = 0
-                            seqlens = []
-                            for eidx in eot_idx:
-                                seqlens.append(eidx - start_idx + 1)
-                                start_idx = eidx + 1
-                            if start_idx < args.seq_len + 1:
-                                seqlens.append(args.seq_len - start_idx)
-                            document_seqlens.append(seqlens)
-                else:
-                    document_seqlens = None
+                document_seqlens = get_document_seqlens(inputs, args)
 
                 out, _, _ = model(inputs, document_seqlens=document_seqlens)
 
@@ -170,7 +180,8 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
                     if inputs_ii.shape[0] == 0:
                         break
                     targets_ii = targets[ii * per_batch : (ii + 1) * per_batch]
-                    out, _, _ = model(inputs_ii)
+                    document_seqlens = get_document_seqlens(inputs_ii, args)
+                    out, _, _ = model(inputs_ii, document_seqlens=document_seqlens)
 
                     if args.log_logit_mean:
                         logit_m.update(torch.mean(out).item())
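The model presumably consumes `document_seqlens` to restrict attention to causal, within-document positions. Below is a minimal sketch of that masking semantics in plain PyTorch; `block_causal_mask` is a hypothetical helper for illustration, not code from this repository, which would typically hand the per-document lengths to a fused attention kernel rather than materialize a mask:

```python
import torch

def block_causal_mask(seqlens, device=None):
    """Boolean attention mask for one batch row: True = may attend.

    Tokens attend causally within their own document and never across
    document boundaries, which is the structure document_seqlens encodes.
    """
    total = sum(seqlens)
    mask = torch.zeros(total, total, dtype=torch.bool, device=device)
    start = 0
    for n in seqlens:
        # Lower-triangular (causal) block for this document only.
        mask[start : start + n, start : start + n] = torch.ones(
            n, n, dtype=torch.bool, device=device
        ).tril()
        start += n
    return mask

# For document_seqlens[idx] == [3, 3, 2], attention is causal inside each
# 3x3 / 3x3 / 2x2 diagonal block and disallowed everywhere else.
```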