@@ -22,11 +22,13 @@
 wandb = None
 
 from open_lm.data import sample_chunk
+from open_lm.datapreprocess.ray.tokenize_shuffle import SpecialTokens
 from open_lm.distributed import is_master
 from open_lm.precision import get_autocast
 from open_lm.meters import AverageMeter
 
 
+
 def unwrap_model(model):
     if hasattr(model, "module"):
         return model.module
@@ -109,13 +111,34 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler |
 
         (texts,) = batch
         texts = torch.LongTensor(texts).to(device)
+
         data_time_m.update(time.time() - end)
         optimizer.zero_grad()
 
         if args.accum_freq == 1:
             with autocast():
                 inputs, targets = sample_chunk(texts, args)
-                out, _, _ = model(inputs)
+
+                if args.mask_across_documents:
+                    document_seqlens = []
+                    for idx in range(inputs.shape[0]):
+                        eot_idx = torch.nonzero(inputs[idx] == SpecialTokens.END_OF_TEXT.value).flatten()
+                        if eot_idx.shape[0] == 0:
+                            # No EOT token in this sample: treat the whole chunk as one document.
+                            document_seqlens.append([args.seq_len])
+                        else:
+                            start_idx = 0
+                            seqlens = []
+                            for eidx in eot_idx.tolist():
+                                seqlens.append(eidx - start_idx + 1)
+                                start_idx = eidx + 1
+                            if start_idx < args.seq_len:
+                                seqlens.append(args.seq_len - start_idx)
+                            document_seqlens.append(seqlens)
+                else:
+                    document_seqlens = None
+
+                out, _, _ = model(inputs, document_seqlens=document_seqlens)
 
                 if args.log_logit_mean:
                     logit_m.update(torch.mean(out).item())
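For reference, below is a self-contained sketch of the per-sample bookkeeping added in the hunk above. The helper name `compute_document_seqlens` and the end-of-text id of 0 are illustrative assumptions; the real code reads the id from `SpecialTokens.END_OF_TEXT`. Per-document lengths like these are what block-diagonal attention masks (e.g. xformers' `BlockDiagonalMask.from_seqlens`) typically consume to prevent attention across document boundaries.

import torch

EOT_ID = 0  # assumption: stand-in for SpecialTokens.END_OF_TEXT.value


def compute_document_seqlens(inputs: torch.Tensor, seq_len: int) -> list:
    # Split each row of `inputs` into per-document lengths at EOT tokens,
    # mirroring the logic in the diff above (hypothetical helper).
    document_seqlens = []
    for row in inputs:
        eot_positions = torch.nonzero(row == EOT_ID).flatten().tolist()
        if not eot_positions:
            # No EOT in this sample: the whole chunk is one document.
            document_seqlens.append([seq_len])
        else:
            start, seqlens = 0, []
            for pos in eot_positions:
                seqlens.append(pos - start + 1)  # include the EOT token itself
                start = pos + 1
            if start < seq_len:
                seqlens.append(seq_len - start)  # trailing partial document
            document_seqlens.append(seqlens)
    return document_seqlens


# Toy check: EOT at positions 2 and 5 in a length-8 chunk -> documents of length 3, 3, 2.
x = torch.tensor([[5, 7, EOT_ID, 9, 4, EOT_ID, 8, 6]])
print(compute_document_seqlens(x, seq_len=8))  # [[3, 3, 2]]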