Hi,
Congratulations on the great work. I'm very interested in Flash MoBA and have been running some tests on it. My main goal was to evaluate the impact of `moba_chunk_size` and `moba_topk` on GPU memory usage and computation time, and to compare the results against PyTorch's `scaled_dot_product_attention`.
I focused on the non-causal (causal=False) case and ran experiments on an H100 with:
batch_size = 1
nheads = 4
headdim = 128
causal = False
The test settings were:
seqlens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1_000_000]
chunk_sizes = [64, 128, 256, 512, 1024]
topks = [2, 4, 8, 16, 32]

The results are summarized as follows:
- For a fixed sequence length, `moba_chunk_size` and `moba_topk` appear to have no impact on peak memory usage.
- For a fixed sequence length, larger `moba_chunk_size` and `moba_topk` lead to longer computation time.
- Flash MoBA is faster than SDPA.
- However, Flash MoBA consumes more memory than SDPA.
The test code is as follows:
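A simplified version of the measurement loop, showing only the SDPA baseline (the Flash MoBA side is analogous; its exact call signature is not shown here, so substituting it in is left as an assumption), looks like this. On the H100 you would pass `device="cuda"`; the CUDA-specific peak-memory calls are guarded so the sketch also runs on CPU:

```python
import time
import torch
import torch.nn.functional as F

def bench_sdpa(seqlen, batch_size=1, nheads=4, headdim=128, device="cpu"):
    # Random Q/K/V in the layout SDPA expects: (batch, heads, seqlen, headdim).
    q = torch.randn(batch_size, nheads, seqlen, headdim, device=device)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    if device.startswith("cuda"):
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.synchronize(device)
    t0 = time.perf_counter()
    out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
    if device.startswith("cuda"):
        torch.cuda.synchronize(device)
    elapsed_ms = (time.perf_counter() - t0) * 1e3

    # Peak allocator usage is only tracked on CUDA devices.
    peak_mb = (torch.cuda.max_memory_allocated(device) / 2**20
               if device.startswith("cuda") else float("nan"))
    return out, elapsed_ms, peak_mb
```

Each `(seqlen, chunk_size, topk)` configuration from the lists above is then benchmarked in a loop, recording `elapsed_ms` and `peak_mb` for both kernels.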
In addition, could you please help clarify the following questions:
- In the original [MoBA](https://github.com/MoonshotAI/MoBA/blob/master/moba/moba_efficient.py) implementation, the final output is obtained by combining sparse attention and self-attention via online softmax: self-attention performs local attention within each chunk (each token attends to previous tokens inside its own chunk), while sparse attention computes top-k cross-chunk attention (selected tokens attend to the top-k most relevant chunks). Does Flash MoBA include only the top-k cross-chunk attention, without the local self-attention component?
- When `seqlen_k` is not divisible by `moba_chunk_size`, how is the tail chunk handled?
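For context on the first question, the online-softmax combination I mean can be sketched as merging two partial attention outputs via their log-sum-exp statistics. The function below is illustrative only, not MoBA's actual API:

```python
import torch

def merge_attention(out1, lse1, out2, lse2):
    # Merge two partial attention results computed over disjoint key sets.
    # out*: (..., seqlen_q, headdim), each already normalized by its own
    #       softmax denominator.
    # lse*: (..., seqlen_q), the log-sum-exp of the corresponding logits.
    lse = torch.logaddexp(lse1, lse2)          # combined denominator, in log space
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)   # weight of partial result 1
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)   # weight of partial result 2
    return w1 * out1 + w2 * out2
```

Because each partial output is rescaled by its share of the combined denominator, the merged result is mathematically identical to running softmax attention over the union of the two key sets.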
Thank you very much!