
Commit d24b533

Running version.
1 parent 3e88907 commit d24b533

File tree: 3 files changed (+39, -23 lines)


open_lm/attention.py

Lines changed: 7 additions & 3 deletions
@@ -23,7 +23,7 @@ def xformers_attn(queries, keys, values, is_causal, document_seqlens = None):
     # we would like to replace the mask generation with: mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask()
     # sadly we cannot use this because it needs xformers>=0.0.23 and this is not compatible with torch<2.1.1 while llm-foundry requires torch<2.1.1
 
-    if document_seqlens is None or all(len(d) == 1 for ds in document_seqlens):
+    if document_seqlens is None or all(len(ds) == 1 for ds in document_seqlens):
         # In this case, all the tokens inside the sequence (are considered to) come from the same document.
         # The attention mask is constructed as a simple causal mask
 
@@ -41,11 +41,15 @@ def xformers_attn(queries, keys, values, is_causal, document_seqlens = None):
 
     else:
         masks = []
+        batch, q_seq_len, heads, _ = queries.shape
+        k_seq_len = keys.shape[1]
+        dtype = queries.dtype
+        device = queries.device
         for ds in document_seqlens:
             if is_causal and queries.shape[1] == keys.shape[1]:
-                masks.append(xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(document_seqlens).materialize(shape=(1, queries.shape[1], queries.shape[1])))
+                masks.append(xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(ds).materialize(shape=(1, heads, q_seq_len, k_seq_len), device=device, dtype=dtype))
             elif is_causal and queries.shape[1] > 1:
-                masks.append(xops.fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask.from_seqlens(document_seqlens).materialize(shape=(1, queries.shape[1], keys.shape[1])))
+                masks.append(xops.fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask.from_seqlens(ds).materialize(shape=(1, heads, q_seq_len, k_seq_len), device=device, dtype=dtype))
         mask = torch.cat(masks, dim=0)
 
     return xops.memory_efficient_attention(queries, keys, values, attn_bias=mask)
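
For reference, a minimal pure-PyTorch sketch (illustrative only, not part of this commit) of the additive bias that BlockDiagonalCausalMask.from_seqlens(ds).materialize(...) produces for a single sequence: each document becomes its own causal block, and attention across documents is disabled with -inf. The diff above builds the same structure per batch element, with head and batch dimensions added, before concatenating and passing it to xops.memory_efficient_attention as attn_bias.

import torch

def block_diagonal_causal_bias(seqlens, dtype=torch.float32):
    # One sequence of sum(seqlens) tokens, split into documents of the given lengths.
    # Allowed query/key pairs get 0.0; everything else gets -inf (additive bias).
    total = sum(seqlens)
    bias = torch.full((total, total), float("-inf"), dtype=dtype)
    start = 0
    for length in seqlens:
        end = start + length
        block = torch.full((length, length), float("-inf"), dtype=dtype)
        # keep -inf strictly above the diagonal, 0.0 on and below: causal inside the block
        bias[start:end, start:end] = torch.triu(block, diagonal=1)
        start = end
    return bias

bias = block_diagonal_causal_bias([3, 2])
# bias[3, 2] == -inf : the first token of document 2 cannot attend into document 1
# bias[4, 3] == 0.0  : later tokens still attend causally within document 2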

open_lm/model.py

Lines changed: 1 addition & 0 deletions
@@ -180,6 +180,7 @@ def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cach
             keys,
             vals,
             is_causal=is_causal,
+            document_seqlens=document_seqlens
         )
 
         output = output.view(batchsize, q_len, -1)

open_lm/train.py

Lines changed: 31 additions & 20 deletions
@@ -43,6 +43,34 @@ def backward(total_loss, scaler):
         total_loss.backward()
 
 
+def get_document_seqlens(inputs, args):
+    """Get list of document sequence lengths.
+
+    Return a list of lists. The length of the outer list is equal to the batch size, while the length of the inner
+    list is equal to the number of distinct documents (recognized by EOT tokens). Each element of the inner lists is
+    the length of the corresponding document.
+    """
+    if args.mask_across_documents:
+        document_seqlens = []
+        for idx in range(inputs.shape[0]):
+            eot_idx = torch.nonzero(inputs[idx] == SpecialTokens.END_OF_TEXT.value)
+            if len(eot_idx.shape) == 0:
+                # Fallback case - an EOT token should appear at the end.
+                document_seqlens.append([args.seq_len + 1])
+            else:
+                start_idx = 0
+                seqlens = []
+                for k in range(eot_idx.shape[0]):
+                    seqlens.append(eot_idx[k] - start_idx + 1)
+                    start_idx = eot_idx[k] + 1
+                if start_idx < args.seq_len + 1:
+                    seqlens.append(args.seq_len - start_idx)
+                document_seqlens.append(seqlens)
+    else:
+        document_seqlens = None
+    return document_seqlens
+
+
 def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler, total_steps, args, tb_writer=None):
     """Trains model for one epoch on the provided data.
 
@@ -118,25 +146,7 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
         if args.accum_freq == 1:
             with autocast():
                 inputs, targets = sample_chunk(texts, args)
-
-                if args.mask_across_documents:
-                    document_seqlens = []
-                    for idx in range(inputs.shape[0]):
-                        eot_idx = torch.nonzero(inputs[idx] == SpecialTokens.END_OF_TEXT.value)
-                        if len(eot_idx.shape) == 0:
-                            # Fallback case - an eot token should appear at the end.
-                            document_seqlens.append([args.seq_len + 1])
-                        else:
-                            start_idx = 0
-                            seqlens = []
-                            for eidx in eot_idx:
-                                seqlens.append(eidx - start_idx + 1)
-                                start_idx = eidx + 1
-                            if start_idx < args.seq_len + 1:
-                                seqlens.append(args.seq_len - start_idx)
-                            document_seqlens.append(seqlens)
-                else:
-                    document_seqlens = None
+                document_seqlens = get_document_seqlens(inputs, args)
 
                 out, _, _ = model(inputs, document_seqlens=document_seqlens)
 
@@ -170,7 +180,8 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
                 if inputs_ii.shape[0] == 0:
                     break
                 targets_ii = targets[ii * per_batch : (ii + 1) * per_batch]
-                out, _, _ = model(inputs_ii)
+                document_seqlens = get_document_seqlens(inputs_ii, args)
+                out, _, _ = model(inputs_ii, document_seqlens=document_seqlens)
 
                 if args.log_logit_mean:
                     logit_m.update(torch.mean(out).item())
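
To make the return shape of get_document_seqlens concrete, here is a standalone toy version of the EOT-splitting logic (illustrative only: the EOT token id, the sequence length, and the exact off-by-one conventions are simplified stand-ins, not the repo's SpecialTokens or args).

import torch

EOT = 0       # hypothetical end-of-text token id (stand-in for SpecialTokens.END_OF_TEXT.value)
SEQ_LEN = 8   # stand-in for the number of tokens per batch row

def toy_document_seqlens(inputs):
    # One list per batch row; each inner list holds the token count of every
    # document in that row, where a document ends at (and includes) an EOT token.
    document_seqlens = []
    for row in inputs:
        eot_idx = torch.nonzero(row == EOT).flatten().tolist()
        if not eot_idx:
            document_seqlens.append([SEQ_LEN])      # no EOT: the whole row is one document
            continue
        start, seqlens = 0, []
        for e in eot_idx:
            seqlens.append(e - start + 1)           # document ends at its EOT token
            start = e + 1
        if start < SEQ_LEN:
            seqlens.append(SEQ_LEN - start)         # trailing partial document
        document_seqlens.append(seqlens)
    return document_seqlens

batch = torch.tensor([
    [5, 6, 0, 7, 8, 9, 0, 4],   # EOT at positions 2 and 6
    [1, 2, 3, 4, 5, 6, 7, 8],   # no EOT at all
])
print(toy_document_seqlens(batch))   # [[3, 4, 1], [8]]

In the commit itself this logic is guarded by args.mask_across_documents and otherwise returns document_seqlens = None, in which case the attention code above keeps the plain causal mask.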
