Commit 3e88907

Add attention calls in training.
1 parent e24738c commit 3e88907

File tree

1 file changed: +24, -1 lines changed


open_lm/train.py

Lines changed: 24 additions & 1 deletion
@@ -22,11 +22,13 @@
     wandb = None
 
 from open_lm.data import sample_chunk
+from open_lm.datapreprocess.ray.tokenize_shuffle import SpecialTokens
 from open_lm.distributed import is_master
 from open_lm.precision import get_autocast
 from open_lm.meters import AverageMeter
 
 
+
 def unwrap_model(model):
     if hasattr(model, "module"):
         return model.module
@@ -109,13 +111,34 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
 
         (texts,) = batch
         texts = torch.LongTensor(texts).to(device)
+
         data_time_m.update(time.time() - end)
         optimizer.zero_grad()
 
         if args.accum_freq == 1:
             with autocast():
                 inputs, targets = sample_chunk(texts, args)
-                out, _, _ = model(inputs)
+
+                if args.mask_across_documents:
+                    document_seqlens = []
+                    for idx in range(inputs.shape[0]):
+                        eot_idx = torch.nonzero(inputs[idx] == SpecialTokens.END_OF_TEXT.value)
+                        if len(eot_idx.shape) == 0:
+                            # Fallback case - an eot token should appear at the end.
+                            document_seqlens.append([args.seq_len + 1])
+                        else:
+                            start_idx = 0
+                            seqlens = []
+                            for eidx in eot_idx:
+                                seqlens.append(eidx - start_idx + 1)
+                                start_idx = eidx + 1
+                            if start_idx < args.seq_len + 1:
+                                seqlens.append(args.seq_len - start_idx)
+                            document_seqlens.append(seqlens)
+                else:
+                    document_seqlens = None
+
+                out, _, _ = model(inputs, document_seqlens=document_seqlens)
 
                 if args.log_logit_mean:
                     logit_m.update(torch.mean(out).item())
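
For readers skimming the diff, a minimal standalone sketch of what the new mask_across_documents branch computes per row may help: it scans one packed sequence for end-of-text tokens and returns the lengths of the document segments, which is what gets passed to the model as document_seqlens. This is an illustration, not open_lm code; the helper name, the eot_token_id and seq_len parameters, and the use of flatten()/numel() to handle the empty case are assumptions made for readability, while the segment arithmetic mirrors the commit above.

import torch

def row_document_seqlens(row, eot_token_id, seq_len):
    # Hypothetical helper, not part of open_lm: compute per-document segment
    # lengths for one packed row, mirroring the logic added in this commit.
    eot_idx = torch.nonzero(row == eot_token_id).flatten()
    if eot_idx.numel() == 0:
        # No end-of-text token found: treat the whole row as a single document.
        return [seq_len + 1]
    seqlens = []
    start_idx = 0
    for eidx in eot_idx.tolist():
        seqlens.append(eidx - start_idx + 1)  # segment length includes its EOT token
        start_idx = eidx + 1
    if start_idx < seq_len + 1:
        seqlens.append(seq_len - start_idx)  # remaining tokens after the last EOT
    return seqlens

# Example: a row of seq_len + 1 = 9 tokens packing three documents, with the
# end-of-text id (0 here) at positions 2, 6 and 8 -> segment lengths [3, 4, 2].
row = torch.tensor([5, 7, 0, 9, 8, 4, 0, 6, 0])
print(row_document_seqlens(row, eot_token_id=0, seq_len=8))

In the commit itself this is done inline for every row of the batch, and the resulting list of per-row length lists is passed to the model call as document_seqlens, presumably so attention does not cross document boundaries within a packed sequence.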
