Commit 3e0036b

Formatting.
1 parent af1a4cd commit 3e0036b

4 files changed: +20 -19 lines changed


open_lm/attention.py

Lines changed: 13 additions & 7 deletions
@@ -16,7 +16,7 @@ def get_rectangular_mask(shape, q_seq_len, k_seq_len, device, dtype):
     )[:, :, :, :k_seq_len]
 
 
-def xformers_attn(queries, keys, values, is_causal, document_seqlens = None):
+def xformers_attn(queries, keys, values, is_causal, document_seqlens=None):
     # xformers assumes q, k, v are [batch, seq_len, heads, embed_dim]
     # We assume that queries match the last part of the key / value sequences
     # see (https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.LowerTriangularFromBottomRightMask)
@@ -47,18 +47,24 @@ def xformers_attn(queries, keys, values, is_causal, document_seqlens = None):
         device = queries.device
         for ds in document_seqlens:
             if is_causal and queries.shape[1] == keys.shape[1]:
-                masks.append(xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(ds).materialize(shape=(1, heads, q_seq_len, k_seq_len), device=device, dtype=dtype))
+                masks.append(
+                    xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(ds).materialize(
+                        shape=(1, heads, q_seq_len, k_seq_len), device=device, dtype=dtype
+                    )
+                )
             elif is_causal and queries.shape[1] > 1:
-                masks.append(xops.fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask.from_seqlens(ds).materialize(shape=(1, heads, q_seq_len, k_seq_len), device=device, dtype=dtype))
+                masks.append(
+                    xops.fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask.from_seqlens(ds).materialize(
+                        shape=(1, heads, q_seq_len, k_seq_len), device=device, dtype=dtype
+                    )
+                )
         mask = torch.cat(masks, dim=0)
 
     return xops.memory_efficient_attention(queries, keys, values, attn_bias=mask)
 
 
-def torch_attn(queries, keys, values, is_causal, document_seqlens = None):
-
+def torch_attn(queries, keys, values, is_causal, document_seqlens=None):
     if document_seqlens is None or len(document_seqlens) == 1:
-
         # Need to call contiguous in torch >=2.1, otherwise later calls to .view() fail.
         # Possibly related: https://github.com/pytorch/pytorch/issues/110213 - behavior of scaled_dot_product_attention
         # changed between 2.0 and 2.1
@@ -89,7 +95,7 @@ def torch_attn(queries, keys, values, is_causal, document_seqlens = None):
             .transpose(1, 2)
            .contiguous()
        )
-
+
    else:
        raise NotImplementedError("Currently supporting --mask-across-documents only with xformers attention.")
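
For context, the masks built in the loop above are block-diagonal causal biases: each document packed into a sequence attends only to itself, and only causally. The sketch below is a standalone illustration (not part of this commit; the toy sequence lengths and the 1x1x5x5 shape are made up) that materializes such a bias the same way xformers_attn does, so the effect of masking across documents can be inspected directly.

import torch
from xformers import ops as xops

# Two documents of lengths 3 and 2 packed into one sequence of length 5.
seqlens = [3, 2]
bias = xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(seqlens)

# Materialize the bias to a dense tensor, mirroring the call in xformers_attn
# (here with batch=1, heads=1, q_seq_len=k_seq_len=5 for readability).
dense = bias.materialize(shape=(1, 1, 5, 5), device="cpu", dtype=torch.float32)

# Entries are 0 where attention is allowed and -inf where it is masked:
# positions 3-4 (second document) cannot attend to positions 0-2 (first document),
# and within each document a position only sees itself and earlier positions.
print(dense[0, 0])

In the function above, one such mask is materialized per batch element and the results are concatenated along dim=0 before being passed as attn_bias to memory_efficient_attention.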

open_lm/model.py

Lines changed: 5 additions & 9 deletions
@@ -175,13 +175,7 @@ def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cach
         if use_cache:
             past_key_value = [keys, vals]
 
-        output = self.attn_fn(
-            queries,
-            keys,
-            vals,
-            is_causal=is_causal,
-            document_seqlens=document_seqlens
-        )
+        output = self.attn_fn(queries, keys, vals, is_causal=is_causal, document_seqlens=document_seqlens)
 
         output = output.view(batchsize, q_len, -1)
 
@@ -258,7 +252,7 @@ def forward(self, x, past_key_value=None, use_cache=False, document_seqlens=None
             is_causal=True,
             past_key_value=past_key_value,
             use_cache=use_cache,
-            document_seqlens=document_seqlens
+            document_seqlens=document_seqlens,
         )
         h = x + h
         if self._ffn_type == "moe":
@@ -335,7 +329,9 @@ def forward(self, input, past_key_values=None, use_cache=False, document_seqlens
             if self.grad_checkpointing:
                 x, past_key_values[i] = checkpoint(layer, x, past_key_values[i], use_cache, document_seqlens)
             else:
-                x, past_key_values[i] = layer(x, past_key_values[i], use_cache=use_cache, document_seqlens=document_seqlens)
+                x, past_key_values[i] = layer(
+                    x, past_key_values[i], use_cache=use_cache, document_seqlens=document_seqlens
+                )
         if past_key_values[0] is None:
             past_key_values = None
         x = self.norm(x)

open_lm/params.py

Lines changed: 1 addition & 1 deletion
@@ -745,7 +745,7 @@ def parse_args(args):
     parser.add_argument(
         "--mask-across-documents",
         action="store_true",
-        help="If set, then tokens in the same sequence will be masked across EOT."
+        help="If set, then tokens in the same sequence will be masked across EOT.",
     )
 
     add_model_args(parser)
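
Since the change here only adds a trailing comma, the flag's behavior is unchanged: a store_true argument defaults to False and flips to True when passed. A tiny self-contained sketch (the parser below is hypothetical and only mirrors the snippet above):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--mask-across-documents",
    action="store_true",
    help="If set, then tokens in the same sequence will be masked across EOT.",
)

# Dashes in the flag name become underscores on the parsed namespace.
print(parser.parse_args([]).mask_across_documents)                           # False
print(parser.parse_args(["--mask-across-documents"]).mask_across_documents)  # True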

open_lm/train.py

Lines changed: 1 addition & 2 deletions
@@ -28,7 +28,6 @@
 from open_lm.meters import AverageMeter
 
 
-
 def unwrap_model(model):
     if hasattr(model, "module"):
         return model.module
@@ -45,7 +44,7 @@ def backward(total_loss, scaler):
 
 def get_document_seqlens(inputs, args):
     """Get list of document sequence lengths.
-
+
     Return a list of lists. The length of the outer list is equal to the batch size, while the length of the inner list
     is equal to the the number of distinct documents (recognized by EOT tokens). Each element of the inner lists is the
     length of that corresponding document
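
The docstring above describes the expected output: one list of per-document lengths per batch row, with documents delimited by EOT tokens. Below is a hedged, standalone sketch of that behavior (the helper name, the EOT id of 0, and the choice to count the EOT token with the preceding document are illustrative assumptions, not the commit's implementation):

import torch

EOT_TOKEN_ID = 0  # illustrative; open_lm takes the real EOT id from its tokenizer/args

def document_seqlens_from_eot(inputs: torch.Tensor, eot_token_id: int = EOT_TOKEN_ID):
    """Return one list of document lengths per batch row, splitting at EOT tokens."""
    seqlens = []
    for row in inputs:
        eot_positions = (row == eot_token_id).nonzero(as_tuple=True)[0].tolist()
        lengths, prev = [], 0
        for pos in eot_positions:
            lengths.append(pos + 1 - prev)  # count the EOT token with its document
            prev = pos + 1
        if prev < row.numel():
            lengths.append(row.numel() - prev)  # trailing document without an EOT
        seqlens.append(lengths)
    return seqlens

# One sample packing a 3-token document (ending in EOT) and a 2-token tail:
print(document_seqlens_from_eot(torch.tensor([[5, 7, 0, 9, 4]])))  # [[3, 2]]

These per-document lengths are what xformers_attn consumes as document_seqlens when --mask-across-documents is enabled.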
