
Commit 7234b31

Trying to debug.
1 parent: e4a5bac

3 files changed (+9, -8 lines)

open_lm/attention.py (1 addition, 1 deletion)

@@ -22,7 +22,7 @@ def xformers_attn(queries, keys, values, is_causal, document_seqlens=None):
     # see (https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.LowerTriangularFromBottomRightMask)
     # we would like to replace the mask generation with: mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask()
     # sadly we cannot us this because it needs xformers>=0.0.23 and this is not compatible with torch<2.1.1 while llm-foundry requires torch<2.1.1
-
+    print("attention called")
     if document_seqlens is None or all(len(ds) == 1 for ds in document_seqlens):
         # In this case, all the tokens inside the sequence (are considered to) come from the same document.
         # The attention mask is constructed as a simple causal mask
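A note on the version constraint mentioned in the comment above: LowerTriangularFromBottomRightMask aligns the causal diagonal to the bottom-right corner, which is what you want when cached keys make the key sequence longer than the query sequence, whereas the plain LowerTriangularMask is aligned to the top-left. A minimal sketch of the difference, assuming xformers>=0.0.23 and a CUDA device; the tensor shapes and names are illustrative, not taken from open_lm:

import torch
import xformers.ops as xops

# Illustrative: 8 new query tokens attending over 8 cached + 8 new keys/values,
# in xformers' (batch, seq_len, heads, head_dim) layout.
q = torch.randn(1, 8, 4, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 16, 4, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 16, 4, 64, device="cuda", dtype=torch.float16)

# Top-left aligned: query i only sees keys 0..i, so most cached keys are masked out.
out_top_left = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())

# Bottom-right aligned (needs xformers>=0.0.23): query i sees all cached keys
# plus the new keys up to its own position.
out_bottom_right = xops.memory_efficient_attention(
    q, k, v, attn_bias=xops.fmha.attn_bias.LowerTriangularFromBottomRightMask()
)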

open_lm/model.py (2 additions, 1 deletion)

@@ -155,6 +155,7 @@ def reset_parameters(self):
         torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std)

     def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cache=False, document_seqlens=None):
+        print("attention called")
         batchsize, q_len, _ = x.shape
         queries, keys, vals = self.in_proj(x).chunk(3, dim=-1)

@@ -247,6 +248,7 @@ def reset_parameters(self):
         torch.nn.init.trunc_normal_(self._ff_w2.weight, std=std, a=-3 * std, b=3 * std)

     def forward(self, x, past_key_value=None, use_cache=False, document_seqlens=None):
+        print("block called")
         h, past_key_value = self.attention(
             self.attention_norm(x),
             is_causal=True,

@@ -320,7 +322,6 @@ def set_grad_checkpointing(self, enable=True):
     def forward(self, input, past_key_values=None, use_cache=False, document_seqlens=None):
         x = self.tok_embeddings(input)
         x = self.post_embed_norm(x)
-
         if past_key_values is None:
             past_key_values = [None] * self.n_layers
         elif isinstance(past_key_values, tuple):
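The two prints added above fire on every forward pass of every instrumented module. A less invasive way to confirm the same thing, not part of this commit, is to attach temporary PyTorch forward hooks from the debugging script instead of editing model.py; a minimal, self-contained sketch (the toy model below is a stand-in, not open_lm's Transformer):

import torch
import torch.nn as nn

def log_forward(tag):
    """Build a hook that prints when a module's forward has run."""
    def hook(module, inputs, output):
        print(f"{tag}: {module.__class__.__name__}.forward called")
    return hook

def attach_debug_hooks(model: nn.Module, needle: str = "attention"):
    """Attach a print hook to every submodule whose name contains `needle`.

    Returns the handles so the hooks can be removed once debugging is done.
    """
    handles = []
    for name, module in model.named_modules():
        if needle in name.lower():
            handles.append(module.register_forward_hook(log_forward(name)))
    return handles

# Usage against any nn.Module; `model` here is a stand-in, not open_lm's Transformer.
model = nn.Sequential()
model.add_module("attention", nn.Linear(8, 8))
handles = attach_debug_hooks(model)
model(torch.randn(2, 8))   # prints "attention: Linear.forward called"
for h in handles:
    h.remove()             # detach the hooks afterwards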

open_lm/train.py (6 additions, 6 deletions)

@@ -53,16 +53,16 @@ def get_document_seqlens(inputs, args):
         document_seqlens = []
         for idx in range(inputs.shape[0]):
             eot_idx = torch.nonzero(inputs[idx] == SpecialTokens.END_OF_TEXT.value)
-            if len(eot_idx.shape) == 0:
-                # Fallback case - an eot token should appear at the end.
-                document_seqlens.append([args.seq_len + 1])
+            if eot_idx.shape[0] == 0:
+                # All tokens come from the same document.
+                document_seqlens.append([args.seq_len])
             else:
                 start_idx = 0
                 seqlens = []
                 for k in range(eot_idx.shape[0]):
-                    seqlens.append(eot_idx[k] - start_idx + 1)
-                    start_idx = eot_idx[k] + 1
-                if start_idx < args.seq_len + 1:
+                    seqlens.append(eot_idx[k].item() - start_idx + 1)
+                    start_idx = eot_idx[k].item() + 1
+                if start_idx < args.seq_len:
                     seqlens.append(args.seq_len - start_idx)
                 document_seqlens.append(seqlens)
     else:
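To make the behaviour of the patched get_document_seqlens concrete: for each row of inputs, the EOT positions split the row into documents, each document's length includes its terminating EOT token, and any tokens after the last EOT form one final unterminated segment; a row with no EOT at all is treated as a single document of length args.seq_len. A self-contained sketch of the per-row logic, with a made-up EOT id and seq_len standing in for SpecialTokens.END_OF_TEXT.value and args.seq_len:

import torch

EOT = 0          # stand-in for SpecialTokens.END_OF_TEXT.value
SEQ_LEN = 8      # stand-in for args.seq_len

def document_seqlens_for_row(row: torch.Tensor) -> list[int]:
    """Mirror of the patched per-row logic in get_document_seqlens."""
    eot_idx = torch.nonzero(row == EOT)
    if eot_idx.shape[0] == 0:
        # No EOT token: the whole row is one document.
        return [SEQ_LEN]
    seqlens, start_idx = [], 0
    for k in range(eot_idx.shape[0]):
        # Each document's length counts its terminating EOT token.
        seqlens.append(eot_idx[k].item() - start_idx + 1)
        start_idx = eot_idx[k].item() + 1
    if start_idx < SEQ_LEN:
        # Trailing tokens after the last EOT form a final (unterminated) document.
        seqlens.append(SEQ_LEN - start_idx)
    return seqlens

# Row of 8 tokens with EOT at positions 2 and 5 -> documents of length 3, 3, 2.
row = torch.tensor([7, 7, EOT, 9, 9, EOT, 4, 4])
print(document_seqlens_for_row(row))   # [3, 3, 2]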
