Hello,
Thank you very much for open-sourcing your work!
I tried to install flash-moba in an environment with CUDA 12.6, and it seems to run fine with this small test:
```python
import torch
import flash_moba

B, L, H, D = 2, 512, 16, 128
q = torch.randn(B * L, H, D).cuda()
k = torch.randn(B * L, H, D).cuda()
v = torch.randn(B * L, H, D).cuda()

# cu_seqlens, max_seqlen
seqlens = torch.randint(L // 2, L + 1, (B,), device="cuda", dtype=torch.int32)
cu_seqlens = torch.zeros(B + 1, device="cuda", dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
max_seqlen = int(seqlens.max())

o = flash_moba.flash_moba_varlen_func(
    q.bfloat16(),
    k.bfloat16(),
    v.bfloat16(),
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen,
    max_seqlen_k=max_seqlen,
    moba_chunk_size=128,
    moba_topk=2,
)
```
The README states that CUDA 12.9 is required. Do you think it is strictly necessary?
Thank you
(For context: I'm on an HPC cluster, so upgrading the CUDA version is a bit of a mess; that's why I'm asking.)
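In case it helps triage: a minimal sketch of how I compare my toolkit version against the README minimum. The helper below is hypothetical (not part of flash-moba), and the version literals just mirror the numbers discussed in this issue; in a live environment you would feed it `torch.version.cuda`, which reports the CUDA version the PyTorch build uses.

```python
def cuda_version_at_least(installed: str, required: str) -> bool:
    """Compare dotted CUDA version strings numerically, not lexically,
    so that e.g. "12.10" correctly counts as newer than "12.9"."""
    parse = lambda s: tuple(int(p) for p in s.split("."))
    return parse(installed) >= parse(required)

# In a real check, installed would come from torch.version.cuda.
print(cuda_version_at_least("12.6", "12.9"))  # my environment vs the README minimum
print(cuda_version_at_least("12.9", "12.9"))
```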