Skip to content

Commit ee77b93

Browse files
committed
Swap seqlen_q and nheads for MQA to speed it up (h/t Daniel Haziza)
1 parent 0700580 commit ee77b93

5 files changed

Lines changed: 29 additions & 14 deletions

File tree

csrc/flash_attn/flash_api.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -992,15 +992,15 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads x head_size
992992
}
993993

994994
std::vector<at::Tensor>
995-
mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
995+
mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
996996
const at::Tensor &kcache, // batch_size x seqlen_k x num_heads_k x head_size
997997
const at::Tensor &vcache, // batch_size x seqlen_k x num_heads_k x head_size
998-
c10::optional<const at::Tensor> &k_, // batch_size x seqlen_q x num_heads_k x head_size
999-
c10::optional<const at::Tensor> &v_, // batch_size x seqlen_q x num_heads_k x head_size
998+
c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
999+
c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
10001000
c10::optional<const at::Tensor> &seqlens_k_, // batch_size
10011001
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
10021002
const float softmax_scale,
1003-
const bool is_causal,
1003+
bool is_causal,
10041004
int num_splits
10051005
) {
10061006

@@ -1032,15 +1032,24 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
10321032
const auto sizes = q.sizes();
10331033

10341034
const int batch_size = sizes[0];
1035-
const int seqlen_q = sizes[1];
1036-
const int num_heads = sizes[2];
1035+
int seqlen_q = sizes[1];
1036+
int num_heads = sizes[2];
10371037
const int head_size_og = sizes[3];
10381038
const int seqlen_k = kcache.size(1);
10391039
const int num_heads_k = kcache.size(2);
10401040
TORCH_CHECK(batch_size > 0, "batch size must be postive");
10411041
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
10421042
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
10431043

1044+
if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case
1045+
1046+
// Faster to transpose q from (b, 1, h, d) to (b, h, 1, d) in this case
1047+
const int seqlenq_nheads_swapped = seqlen_q == 1 && num_heads_k == 1 && num_heads > 1;
1048+
if (seqlenq_nheads_swapped) {
1049+
q = q.transpose(1, 2);
1050+
std::swap(seqlen_q, num_heads);
1051+
}
1052+
10441053
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
10451054
CHECK_SHAPE(kcache, batch_size, seqlen_k, num_heads_k, head_size_og);
10461055
CHECK_SHAPE(vcache, batch_size, seqlen_k, num_heads_k, head_size_og);
@@ -1111,15 +1120,17 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
11111120
TORCH_CHECK(v.is_cuda(), "Value tensor must be on CUDA device");
11121121
TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
11131122
TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
1114-
CHECK_SHAPE(k, batch_size, seqlen_q, num_heads_k, head_size_og);
1115-
CHECK_SHAPE(v, batch_size, seqlen_q, num_heads_k, head_size_og);
1123+
int seqlen_knew = k.size(1);
1124+
CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
1125+
CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
11161126
if (head_size_og % 8 != 0) {
11171127
k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
11181128
v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
11191129
} else {
11201130
k_padded = k;
11211131
v_padded = v;
11221132
}
1133+
params.seqlen_knew = seqlen_knew;
11231134
params.knew_ptr = k_padded.data_ptr();
11241135
params.vnew_ptr = v_padded.data_ptr();
11251136
// All strides are in elements, not bytes.
@@ -1175,6 +1186,10 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
11751186
}
11761187
}
11771188

1189+
if (seqlenq_nheads_swapped) {
1190+
out = out.transpose(1, 2);
1191+
softmax_lse = softmax_lse.transpose(1, 2);
1192+
}
11781193
return {out, softmax_lse};
11791194
}
11801195

csrc/flash_attn/src/block_info.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ struct BlockInfo {
1919
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
2020
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
2121
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
22-
, actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_q))
22+
, actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
2323
{
2424
}
2525

csrc/flash_attn/src/flash.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ struct Flash_fwd_params : public Qkv_params {
6868
void * __restrict__ softmax_lseaccum_ptr;
6969

7070
// The dimensions.
71-
int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
71+
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
7272

7373
// The scaling factors for the kernel.
7474
float scale_softmax;

csrc/flash_attn/src/flash_fwd_kernel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
644644

645645
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
646646
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
647-
// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_q = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_q)); }
647+
// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
648648
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
649649

650650
const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;

flash_attn/flash_attn_interface.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -838,9 +838,9 @@ def flash_attn_with_kvcache(
838838
q: (batch_size, seqlen, nheads, headdim)
839839
k_cache: (batch_size, seqlen_cache, nheads_k, headdim)
840840
v_cache: (batch_size, seqlen_cache, nheads_k, headdim)
841-
k [optional]: (batch_size, seqlen, nheads_k, headdim). If not None, we concatenate k with
842-
k_cache, starting at the indices specified by cache_seqlens.
843-
v [optional]: (batch_size, seqlen, nheads_k, headdim). Similar to k.
841+
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
842+
k with k_cache, starting at the indices specified by cache_seqlens.
843+
v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
844844
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
845845
KV cache.
846846
softmax_scale: float. The scaling of QK^T before applying softmax.

0 commit comments

Comments
 (0)