Skip to content

Commit a6ec178

Browse files
committed
Bump to v0.2.6
1 parent 63670fd commit a6ec178

3 files changed

Lines changed: 3 additions & 3 deletions

File tree

flash_attn/modules/mha.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -436,7 +436,7 @@ def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seql
436436
kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
437437
# If we're processing the prompt, causal=None (use self.causal).
438438
# If we're decoding, then causal=False.
439-
causal = False if inference_params.sequence_len_offset == 0 else None
439+
causal = None if inference_params.sequence_len_offset == 0 else False
440440
context = self.inner_cross_attn(q, kv, causal=causal)
441441
else:
442442
if not self.return_residual:

flash_attn/utils/generation.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ def greedy_decode(input_ids, model, max_length):
4040
inference_params.sequence_len_offset = seqlen_og
4141
while True:
4242
position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
43-
dtype=torch.long, device=input_ids.device)
43+
dtype=torch.long, device=input_ids.device)
4444
logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
4545
inference_params=inference_params).logits[:, -1]
4646
scores.append(logits)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -156,7 +156,7 @@ def append_nvcc_threads(nvcc_extra_args):
156156

157157
setup(
158158
name="flash_attn",
159-
version="0.2.5",
159+
version="0.2.6-1",
160160
packages=find_packages(
161161
exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
162162
),

0 commit comments

Comments (0)