Skip to content

Commit 3764de7

Browse files
authored
enable sink and local attn (vllm-project#58)
Signed-off-by: Yizhou Wang <yizhou.wang@intel.com>
1 parent 38815f9 commit 3764de7

8 files changed

Lines changed: 160 additions & 82 deletions

File tree

csrc/flash_attn/flash_api.cpp

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,18 @@
88
namespace FLASH_NAMESPACE {
99

1010
std::vector<at::Tensor> mha_varlen_fwd(
11-
const at::Tensor&
12-
q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
13-
const at::Tensor& k, // total_k x num_heads_k x head_size, total_k :=
14-
// \sum_{i=0}^{b} s_i or num_blocks x page_block_size
15-
// x num_heads_k x head_size if there's a block_table.
16-
const at::Tensor& v, // total_k x num_heads_k x head_size, total_k :=
17-
// \sum_{i=0}^{b} s_i or num_blocks x page_block_size
18-
// x num_heads_k x head_size if there's a block_table.
19-
std::optional<at::Tensor>&
20-
out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
11+
const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
12+
std::optional<at::Tensor>& out_,
2113
const at::Tensor& cu_seqlens_q, // b+1
2214
const at::Tensor& cu_seqlens_k, // b+1
23-
std::optional<at::Tensor>&
24-
seqused_k, // b. If given, only this many elements of each batch
25-
// element's keys are used.
15+
std::optional<at::Tensor>& seqused_k,
2616
std::optional<const at::Tensor>& leftpad_k_, // batch_size
2717
at::Tensor& block_table_, // batch_size x max_num_blocks_per_seq
2818
std::optional<at::Tensor>& alibi_slopes_, // num_heads or b x num_heads
2919
int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale,
30-
const bool zero_tensors, bool is_causal, int window_size_left,
31-
int window_size_right, const float softcap, const bool return_softmax,
20+
std::optional<const at::Tensor>& softmax_sink_, const bool zero_tensors,
21+
bool is_causal, int window_size_left, int window_size_right,
22+
const float softcap, const bool return_softmax,
3223
std::optional<at::Generator> gen_) {
3324
auto& queue = vllm::xpu::vllmGetQueue();
3425

@@ -39,9 +30,13 @@ std::vector<at::Tensor> mha_varlen_fwd(
3930
out = torch::empty_like(q);
4031
}
4132

33+
bool is_local = (window_size_left != -1) | (window_size_right != -1);
34+
bool is_sink = softmax_sink_.has_value();
35+
4236
cutlass_chunk_prefill_impl(queue, q, k, v, out, block_table_, cu_seqlens_q,
4337
cu_seqlens_k, max_seqlen_q, max_seqlen_k,
44-
softmax_scale, is_causal);
38+
softmax_scale, softmax_sink_, window_size_left,
39+
window_size_right, is_causal, is_local, is_sink);
4540

4641
if (return_softmax) {
4742
// FIXME: currently does not support storing softmax_lse out
@@ -61,7 +56,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
6156
"Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor "
6257
"block_table, Tensor? alibi_slopes, "
6358
"int max_seqlen_q, int max_seqlen_k, float p_dropout, float "
64-
"softmax_scale, bool zero_tensors, "
59+
"softmax_scale, Tensor? softmax_sink, bool zero_tensors, "
6560
"bool is_causal, int window_size_left, int window_size_right, float "
6661
"softcap, bool return_softmax, "
6762
"Generator? gen) -> Tensor[]");

csrc/xpu/cutlass_kernels/chunk_prefill.hpp

Lines changed: 55 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,21 @@ struct chunk_prefill_args_t {
3636
int total_seqlen_q;
3737
int total_seqlen_k;
3838
float sm_scale;
39+
void* sm_sink;
3940
int batch_size;
4041
int num_heads_q;
4142
int num_heads_k;
4243
int head_size;
4344
int max_blocks_per_seq;
4445
int block_size;
45-
bool is_causal;
46+
int window_size_left = -1;
47+
int window_size_right = -1;
48+
bool is_causal = false;
49+
bool is_local = false;
50+
bool is_sink = false;
4651
};
4752

48-
template <class FMHAChunkPrefillKernel, bool isVarLen>
53+
template <class FMHAChunkPrefillKernel>
4954
struct KernelLauncher {
5055
using StrideQ = typename FMHAChunkPrefillKernel::StrideQ;
5156
using StrideK = typename FMHAChunkPrefillKernel::StrideK;
@@ -62,6 +67,7 @@ struct KernelLauncher {
6267
using ElementOutput = typename CollectiveEpilogue::ElementOutput;
6368
using ElementCompute = typename CollectiveEpilogue::ElementCompute;
6469
using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;
70+
using ElementSink = typename CollectiveEpilogue::ElementSink;
6571

6672
using ProblemShapeType = typename FMHAChunkPrefillKernel::ProblemShape;
6773

@@ -120,9 +126,11 @@ struct KernelLauncher {
120126
reinterpret_cast<ElementK*>(args.key), stride_K_cache,
121127
reinterpret_cast<ElementV*>(args.value), stride_V_cache,
122128
static_cast<int*>(args.block_table), args.block_size,
123-
args.max_blocks_per_seq, args.total_seqlen_k, -1, -1},
129+
args.max_blocks_per_seq, args.total_seqlen_k, args.window_size_left,
130+
args.window_size_right},
124131
{args.sm_scale},
125-
{reinterpret_cast<ElementOutput*>(args.out), stride_O},
132+
{reinterpret_cast<ElementOutput*>(args.out), stride_O,
133+
reinterpret_cast<ElementSink*>(args.sm_sink)},
126134
hw_info};
127135

128136
// Define device-global scratch memory
@@ -186,28 +194,29 @@ template <typename TileShapeQK, typename TileShapePV, typename TileShapeOutput,
186194
typename ElementComputeEpilogue = float,
187195
typename GmemTiledCopyStore = XE_2D_U16x8x16_ST_N>
188196
struct FMHAKernel {
189-
template <bool isVarLen, bool Causal, bool PagedKV, bool Local,
190-
class Scheduler>
197+
template <class Scheduler, bool Causal, bool Local, bool Sink>
191198
static void run(sycl::queue& queue, const chunk_prefill_args_t& args) {
192199
cutlass::KernelHardwareInfo hw_info;
193200

201+
static constexpr bool PagedKV = true;
194202
using LayoutQ = cutlass::layout::RowMajor;
195203
using LayoutK = cutlass::layout::ColumnMajor;
196204
using LayoutV = cutlass::layout::RowMajor;
197205
using LayoutO = cutlass::layout::RowMajor;
198206

199207
using ElementInputKV = ElementInputQ;
200208
using ElementOutput = ElementInputQ;
209+
using ElementSink = ElementInputQ;
201210

202211
using GEMMDispatchPolicy =
203212
cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
204213
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
205214
using CollectiveEpilogue =
206215
cutlass::flash_attention::collective::FlashChunkPrefillEpilogue<
207-
EpilogueDispatchPolicy, MMAOperation, TileShapeOutput,
216+
Sink, EpilogueDispatchPolicy, MMAOperation, TileShapeOutput,
208217
SubgroupLayout, ElementComputeEpilogue, ElementOutput,
209218
cutlass::gemm::TagToStrideC_t<LayoutO>, ElementOutput,
210-
GmemTiledCopyStore>;
219+
GmemTiledCopyStore, ElementSink>;
211220
using CollectiveSoftmaxEpilogue =
212221
cutlass::flash_attention::collective::FlashChunkPrefillSoftmaxEpilogue<
213222
Causal, Local, EpilogueDispatchPolicy, ElementAccumulator>;
@@ -216,8 +225,7 @@ struct FMHAKernel {
216225
using namespace cutlass::fmha::collective;
217226
using ProblemShapeVarlen =
218227
cute::tuple<int, int, int, VariableLength, VariableLength, int, int>;
219-
using ProblemShapeType =
220-
std::conditional_t<isVarLen, ProblemShapeVarlen, ProblemShapeRegular>;
228+
using ProblemShapeType = ProblemShapeVarlen;
221229

222230
// Mainloop
223231
using CollectiveMainloop =
@@ -237,18 +245,26 @@ struct FMHAKernel {
237245
ProblemShapeType, CollectiveMainloop, CollectiveSoftmaxEpilogue,
238246
CollectiveEpilogue, Scheduler>;
239247

240-
KernelLauncher<FMHAChunkPrefillKernel, isVarLen> launcher;
248+
KernelLauncher<FMHAChunkPrefillKernel> launcher;
241249

242250
launcher.run(queue, args, hw_info);
243251
}
244252

245-
static void dispatch(sycl::queue& queue, const chunk_prefill_args_t& args) {
246-
if (args.is_causal) {
247-
run<true, true, true, false,
248-
cutlass::flash_attention::IndividualScheduler>(queue, args);
253+
template <bool... Bs>
254+
static void kernel_dispatch(sycl::queue& queue,
255+
const chunk_prefill_args_t& args) {
256+
return run<cutlass::flash_attention::IndividualScheduler, Bs...>(queue,
257+
args);
258+
}
259+
260+
template <bool... Bs, typename... Ts>
261+
static void kernel_dispatch(sycl::queue& queue,
262+
const chunk_prefill_args_t& args, bool b,
263+
Ts... ts) {
264+
if (b) {
265+
kernel_dispatch<Bs..., true>(queue, args, ts...);
249266
} else {
250-
run<true, false, true, false,
251-
cutlass::flash_attention::IndividualScheduler>(queue, args);
267+
kernel_dispatch<Bs..., false>(queue, args, ts...);
252268
}
253269
}
254270
};
@@ -261,13 +277,17 @@ void policy_dispatch(sycl::queue& queue, CutlassType cuType,
261277
FMHAKernel<typename chunk_policy::ShapeQK, typename chunk_policy::ShapePV,
262278
typename chunk_policy::ShapeOutPut,
263279
typename chunk_policy::SubgroupLayout, PipelineStages,
264-
cutlass::half_t, XE_8x16x16_F32F16F16F32_TT>::dispatch(queue,
265-
args);
280+
cutlass::half_t,
281+
XE_8x16x16_F32F16F16F32_TT>::kernel_dispatch(queue, args,
282+
args.is_causal,
283+
args.is_local,
284+
args.is_sink);
266285
} else {
267286
FMHAKernel<typename chunk_policy::ShapeQK, typename chunk_policy::ShapePV,
268287
typename chunk_policy::ShapeOutPut,
269288
typename chunk_policy::SubgroupLayout,
270-
PipelineStages>::dispatch(queue, args);
289+
PipelineStages>::kernel_dispatch(queue, args, args.is_causal,
290+
args.is_local, args.is_sink);
271291
}
272292
}
273293

@@ -278,7 +298,9 @@ void cutlass_chunk_prefill_impl(
278298
const at::Tensor& value_cache, at::Tensor& out,
279299
const at::Tensor& block_table, const at::Tensor& cu_seqlens_q,
280300
const at::Tensor& cu_seqlens_k, int max_seqlen_q, int max_seqlen_k,
281-
double sm_scale, bool is_causal) {
301+
double sm_scale, std::optional<const at::Tensor>& sm_sink_,
302+
int window_size_left, int window_size_right, bool is_causal, bool is_local,
303+
bool is_sink) {
282304
int num_block = key_cache.size(0);
283305
int block_size = key_cache.size(1);
284306
int num_heads_q = query.size(1);
@@ -289,6 +311,12 @@ void cutlass_chunk_prefill_impl(
289311
int total_seqlen_q = query.size(0);
290312
int total_seqlen_k = num_block * block_size;
291313

314+
if (is_local) {
315+
window_size_left = window_size_left == -1 ? max_seqlen_k : window_size_left;
316+
window_size_right =
317+
window_size_right == -1 ? max_seqlen_k : window_size_right;
318+
}
319+
292320
chunk_prefill_args_t args = {query.data_ptr(),
293321
key_cache.data_ptr(),
294322
value_cache.data_ptr(),
@@ -301,13 +329,18 @@ void cutlass_chunk_prefill_impl(
301329
total_seqlen_q,
302330
total_seqlen_k,
303331
static_cast<float>(sm_scale),
332+
is_sink ? sm_sink_.value().data_ptr() : nullptr,
304333
batch_size,
305334
num_heads_q,
306335
num_heads_kv,
307336
head_size,
308337
max_blocks_per_seq,
309338
block_size,
310-
is_causal};
339+
window_size_left,
340+
window_size_right,
341+
is_causal,
342+
is_local,
343+
is_sink};
311344
CutlassType cuType = aten_to_Cutlass_dtype(query);
312345

313346
if (args.head_size == HEAD_SIZE_LIMIT_0) {

csrc/xpu/cutlass_kernels/chunk_prefill_kernel.hpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ class FMHAPrefillChunk {
100100
using EpilogueParams = typename CollectiveEpilogue::Params;
101101
using TileShapeOutput = typename CollectiveEpilogue::TileShapeOutput;
102102
using TiledMmaOutput = typename CollectiveEpilogue::TiledMmaOutput;
103+
// sink
104+
using ElementSink = typename CollectiveEpilogue::ElementSink;
105+
static constexpr bool Sink = CollectiveEpilogue::Sink;
103106
104107
static_assert(
105108
cute::is_same_v<ElementAccumulator,
@@ -111,7 +114,8 @@ class FMHAPrefillChunk {
111114
static constexpr bool CausalMask = CollectiveMainloop::CausalMask;
112115
static constexpr bool LocalMask = CollectiveMainloop::LocalMask;
113116
114-
static_assert(!(CausalMask && LocalMask), "Cannot be both causal and local");
117+
// static_assert(!(CausalMask && LocalMask), "Cannot be both causal and
118+
// local");
115119
static constexpr bool PagedKV = CollectiveMainloop::PagedKV;
116120

117121
static constexpr int SubgroupSize =
@@ -455,23 +459,23 @@ class FMHAPrefillChunk {
455459
if constexpr (LocalMask) {
456460
// mask the elements of each tile where j - left > i || j + right < i
457461
const int item_id = thread_idx % SubgroupSize;
458-
int col_idx = item_id;
459-
col_idx += split * cute::min(QK_BLK_N, seq_len_kv_cache);
462+
int col_idx = item_id + split * cute::min(QK_BLK_N, seq_len_kv_cache);
460463

461464
CUTLASS_PRAGMA_UNROLL
462465
for (int n = 0; n < FragsN;
463466
n++, col_idx += get<1>(MmaAtomShape())) { // 4
464467
CUTLASS_PRAGMA_UNROLL
465468
for (int m = 0; m < FragsM; m++) { // 2
466469
int row_idx = m * Vec + seq_coord;
470+
int col_ref = seq_len_kv_cache - seq_len_qo;
467471
CUTLASS_PRAGMA_UNROLL
468472
for (int row = 0; row < Vec; row++) { // 8
469473
bool left_mask =
470-
col_idx < cute::max(0, row + row_idx + seq_len_kv_cache -
474+
col_idx < cute::max(0, row + row_idx + col_ref -
471475
mainloop_params.window_left);
472476
bool right_mask =
473477
col_idx > cute::min(seq_len_kv_cache,
474-
row + row_idx + seq_len_kv_cache +
478+
row + row_idx + col_ref +
475479
mainloop_params.window_right);
476480
if (left_mask || right_mask) {
477481
tSr(row, m, n) = ElementAccumulator{-INFINITY};
@@ -544,8 +548,15 @@ class FMHAPrefillChunk {
544548
batch_coord, q_head_coord);
545549
CollectiveEpilogue epilogue{epilogue_params, shared_storage.epilogue};
546550
auto blk_coord_mnkl = make_coord(blk_m_coord, blk_n_coord, _, 0);
547-
epilogue(params.problem_shape, sequence_length_shape, blk_coord_mnkl,
548-
out_reg, max_reg, sum_reg);
551+
if constexpr (Sink) {
552+
ElementAccumulator max_scale{max_reg * params.softmax.scale};
553+
epilogue(params.problem_shape, sequence_length_shape, blk_coord_mnkl,
554+
out_reg, max_scale, sum_reg,
555+
params.epilogue.ptr_sink[q_head_coord]);
556+
} else {
557+
epilogue(params.problem_shape, sequence_length_shape, blk_coord_mnkl,
558+
out_reg, max_reg, sum_reg, 0);
559+
}
549560
}
550561
}
551562
};

0 commit comments

Comments
 (0)