Skip to content

Commit 2790975

Browse files
authored
fix ring attn w/ native backend in torch 2.10 (#750)
1 parent ffd5914 commit 2790975

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

src/cache_dit/parallelism/attention/_attention_dispatch.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,12 @@ def _native_attention_forward_op(
196196
is_causal=is_causal,
197197
scale=scale,
198198
)[:2]
199-
out = out.transpose(1, 2)
200-
lse = lse.transpose(1, 2)
199+
# [B, H, N, D] -> [B, N, H, D]
200+
out = out.transpose(1, 2) # type: torch.Tensor
201+
lse = lse.transpose(1, 2) # type: torch.Tensor
202+
if lse.dim() == 3:
203+
# [B, N, H] -> [B, N, H, 1]
204+
lse = lse.unsqueeze(-1)
201205
return out, lse
202206

203207
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))

0 commit comments

Comments
 (0)