@@ -992,15 +992,15 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
992992}
993993
994994std::vector<at::Tensor>
995- mha_fwd_kvcache (const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
995+ mha_fwd_kvcache (at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
996996 const at::Tensor &kcache, // batch_size x seqlen_k x num_heads_k x head_size
997997 const at::Tensor &vcache, // batch_size x seqlen_k x num_heads_k x head_size
998- c10::optional<const at::Tensor> &k_, // batch_size x seqlen_q x num_heads_k x head_size
999- c10::optional<const at::Tensor> &v_, // batch_size x seqlen_q x num_heads_k x head_size
998+ c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
999+ c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
10001000 c10::optional<const at::Tensor> &seqlens_k_, // batch_size
10011001 c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
10021002 const float softmax_scale,
1003- const bool is_causal,
1003+ bool is_causal,
10041004 int num_splits
10051005 ) {
10061006
@@ -1032,15 +1032,24 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
10321032 const auto sizes = q.sizes ();
10331033
10341034 const int batch_size = sizes[0 ];
1035- const int seqlen_q = sizes[1 ];
1036- const int num_heads = sizes[2 ];
1035+ int seqlen_q = sizes[1 ];
1036+ int num_heads = sizes[2 ];
10371037 const int head_size_og = sizes[3 ];
10381038 const int seqlen_k = kcache.size (1 );
10391039 const int num_heads_k = kcache.size (2 );
10401040 TORCH_CHECK (batch_size > 0 , " batch size must be postive" );
10411041 TORCH_CHECK (head_size_og <= 256 , " FlashAttention forward only supports head dimension at most 256" );
10421042 TORCH_CHECK (num_heads % num_heads_k == 0 , " Number of heads in key/value must divide number of heads in query" );
10431043
1044+ if (seqlen_q == 1 ) { is_causal = false ; } // causal=true is the same as causal=false in this case
1045+
1046+ // Faster to transpose q from (b, 1, h, d) to (b, h, 1, d) in this case
1047+ const int seqlenq_nheads_swapped = seqlen_q == 1 && num_heads_k == 1 && num_heads > 1 ;
1048+ if (seqlenq_nheads_swapped) {
1049+ q = q.transpose (1 , 2 );
1050+ std::swap (seqlen_q, num_heads);
1051+ }
1052+
10441053 CHECK_SHAPE (q, batch_size, seqlen_q, num_heads, head_size_og);
10451054 CHECK_SHAPE (kcache, batch_size, seqlen_k, num_heads_k, head_size_og);
10461055 CHECK_SHAPE (vcache, batch_size, seqlen_k, num_heads_k, head_size_og);
@@ -1111,15 +1120,17 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
11111120 TORCH_CHECK (v.is_cuda (), " Value tensor must be on CUDA device" );
11121121 TORCH_CHECK (k.stride (-1 ) == 1 , " Key tensor must have contiguous last dimension" );
11131122 TORCH_CHECK (v.stride (-1 ) == 1 , " Value tensor must have contiguous last dimension" );
1114- CHECK_SHAPE (k, batch_size, seqlen_q, num_heads_k, head_size_og);
1115- CHECK_SHAPE (v, batch_size, seqlen_q, num_heads_k, head_size_og);
1123+ int seqlen_knew = k.size (1 );
1124+ CHECK_SHAPE (k, batch_size, seqlen_knew, num_heads_k, head_size_og);
1125+ CHECK_SHAPE (v, batch_size, seqlen_knew, num_heads_k, head_size_og);
11161126 if (head_size_og % 8 != 0 ) {
11171127 k_padded = torch::nn::functional::pad (k, torch::nn::functional::PadFuncOptions ({0 , 8 - head_size_og % 8 }));
11181128 v_padded = torch::nn::functional::pad (v, torch::nn::functional::PadFuncOptions ({0 , 8 - head_size_og % 8 }));
11191129 } else {
11201130 k_padded = k;
11211131 v_padded = v;
11221132 }
1133+ params.seqlen_knew = seqlen_knew;
11231134 params.knew_ptr = k_padded.data_ptr ();
11241135 params.vnew_ptr = v_padded.data_ptr ();
11251136 // All strides are in elements, not bytes.
@@ -1175,6 +1186,10 @@ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x
11751186 }
11761187 }
11771188
1189+ if (seqlenq_nheads_swapped) {
1190+ out = out.transpose (1 , 2 );
1191+ softmax_lse = softmax_lse.transpose (1 , 2 );
1192+ }
11781193 return {out, softmax_lse};
11791194}
11801195
0 commit comments