@@ -36,16 +36,21 @@ struct chunk_prefill_args_t {
3636 int total_seqlen_q;
3737 int total_seqlen_k;
3838 float sm_scale;
39+ void * sm_sink;
3940 int batch_size;
4041 int num_heads_q;
4142 int num_heads_k;
4243 int head_size;
4344 int max_blocks_per_seq;
4445 int block_size;
45- bool is_causal;
46+ int window_size_left = -1 ;
47+ int window_size_right = -1 ;
48+ bool is_causal = false ;
49+ bool is_local = false ;
50+ bool is_sink = false ;
4651};
4752
48- template <class FMHAChunkPrefillKernel , bool isVarLen >
53+ template <class FMHAChunkPrefillKernel >
4954struct KernelLauncher {
5055 using StrideQ = typename FMHAChunkPrefillKernel::StrideQ;
5156 using StrideK = typename FMHAChunkPrefillKernel::StrideK;
@@ -62,6 +67,7 @@ struct KernelLauncher {
6267 using ElementOutput = typename CollectiveEpilogue::ElementOutput;
6368 using ElementCompute = typename CollectiveEpilogue::ElementCompute;
6469 using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;
70+ using ElementSink = typename CollectiveEpilogue::ElementSink;
6571
6672 using ProblemShapeType = typename FMHAChunkPrefillKernel::ProblemShape;
6773
@@ -120,9 +126,11 @@ struct KernelLauncher {
120126 reinterpret_cast <ElementK*>(args.key ), stride_K_cache,
121127 reinterpret_cast <ElementV*>(args.value ), stride_V_cache,
122128 static_cast <int *>(args.block_table ), args.block_size ,
123- args.max_blocks_per_seq , args.total_seqlen_k , -1 , -1 },
129+ args.max_blocks_per_seq , args.total_seqlen_k , args.window_size_left ,
130+ args.window_size_right },
124131 {args.sm_scale },
125- {reinterpret_cast <ElementOutput*>(args.out ), stride_O},
132+ {reinterpret_cast <ElementOutput*>(args.out ), stride_O,
133+ reinterpret_cast <ElementSink*>(args.sm_sink )},
126134 hw_info};
127135
128136 // Define device-global scratch memory
@@ -186,28 +194,29 @@ template <typename TileShapeQK, typename TileShapePV, typename TileShapeOutput,
186194 typename ElementComputeEpilogue = float ,
187195 typename GmemTiledCopyStore = XE_2D_U16x8x16_ST_N>
188196struct FMHAKernel {
189- template <bool isVarLen, bool Causal, bool PagedKV, bool Local,
190- class Scheduler >
197+ template <class Scheduler , bool Causal, bool Local, bool Sink>
191198 static void run (sycl::queue& queue, const chunk_prefill_args_t & args) {
192199 cutlass::KernelHardwareInfo hw_info;
193200
201+ static constexpr bool PagedKV = true ;
194202 using LayoutQ = cutlass::layout::RowMajor;
195203 using LayoutK = cutlass::layout::ColumnMajor;
196204 using LayoutV = cutlass::layout::RowMajor;
197205 using LayoutO = cutlass::layout::RowMajor;
198206
199207 using ElementInputKV = ElementInputQ;
200208 using ElementOutput = ElementInputQ;
209+ using ElementSink = ElementInputQ;
201210
202211 using GEMMDispatchPolicy =
203212 cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
204213 using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
205214 using CollectiveEpilogue =
206215 cutlass::flash_attention::collective::FlashChunkPrefillEpilogue<
207- EpilogueDispatchPolicy, MMAOperation, TileShapeOutput,
216+ Sink, EpilogueDispatchPolicy, MMAOperation, TileShapeOutput,
208217 SubgroupLayout, ElementComputeEpilogue, ElementOutput,
209218 cutlass::gemm::TagToStrideC_t<LayoutO>, ElementOutput,
210- GmemTiledCopyStore>;
219+ GmemTiledCopyStore, ElementSink >;
211220 using CollectiveSoftmaxEpilogue =
212221 cutlass::flash_attention::collective::FlashChunkPrefillSoftmaxEpilogue<
213222 Causal, Local, EpilogueDispatchPolicy, ElementAccumulator>;
@@ -216,8 +225,7 @@ struct FMHAKernel {
216225 using namespace cutlass ::fmha::collective;
217226 using ProblemShapeVarlen =
218227 cute::tuple<int , int , int , VariableLength, VariableLength, int , int >;
219- using ProblemShapeType =
220- std::conditional_t <isVarLen, ProblemShapeVarlen, ProblemShapeRegular>;
228+ using ProblemShapeType = ProblemShapeVarlen;
221229
222230 // Mainloop
223231 using CollectiveMainloop =
@@ -237,18 +245,26 @@ struct FMHAKernel {
237245 ProblemShapeType, CollectiveMainloop, CollectiveSoftmaxEpilogue,
238246 CollectiveEpilogue, Scheduler>;
239247
240- KernelLauncher<FMHAChunkPrefillKernel, isVarLen > launcher;
248+ KernelLauncher<FMHAChunkPrefillKernel> launcher;
241249
242250 launcher.run (queue, args, hw_info);
243251 }
244252
245- static void dispatch (sycl::queue& queue, const chunk_prefill_args_t & args) {
246- if (args.is_causal ) {
247- run<true , true , true , false ,
248- cutlass::flash_attention::IndividualScheduler>(queue, args);
253+ template <bool ... Bs>
254+ static void kernel_dispatch (sycl::queue& queue,
255+ const chunk_prefill_args_t & args) {
256+ return run<cutlass::flash_attention::IndividualScheduler, Bs...>(queue,
257+ args);
258+ }
259+
260+ template <bool ... Bs, typename ... Ts>
261+ static void kernel_dispatch (sycl::queue& queue,
262+ const chunk_prefill_args_t & args, bool b,
263+ Ts... ts) {
264+ if (b) {
265+ kernel_dispatch<Bs..., true >(queue, args, ts...);
249266 } else {
250- run<true , false , true , false ,
251- cutlass::flash_attention::IndividualScheduler>(queue, args);
267+ kernel_dispatch<Bs..., false >(queue, args, ts...);
252268 }
253269 }
254270};
@@ -261,13 +277,17 @@ void policy_dispatch(sycl::queue& queue, CutlassType cuType,
261277 FMHAKernel<typename chunk_policy::ShapeQK, typename chunk_policy::ShapePV,
262278 typename chunk_policy::ShapeOutPut,
263279 typename chunk_policy::SubgroupLayout, PipelineStages,
264- cutlass::half_t , XE_8x16x16_F32F16F16F32_TT>::dispatch (queue,
265- args);
280+ cutlass::half_t ,
281+ XE_8x16x16_F32F16F16F32_TT>::kernel_dispatch (queue, args,
282+ args.is_causal ,
283+ args.is_local ,
284+ args.is_sink );
266285 } else {
267286 FMHAKernel<typename chunk_policy::ShapeQK, typename chunk_policy::ShapePV,
268287 typename chunk_policy::ShapeOutPut,
269288 typename chunk_policy::SubgroupLayout,
270- PipelineStages>::dispatch (queue, args);
289+ PipelineStages>::kernel_dispatch (queue, args, args.is_causal ,
290+ args.is_local , args.is_sink );
271291 }
272292}
273293
@@ -278,7 +298,9 @@ void cutlass_chunk_prefill_impl(
278298 const at::Tensor& value_cache, at::Tensor& out,
279299 const at::Tensor& block_table, const at::Tensor& cu_seqlens_q,
280300 const at::Tensor& cu_seqlens_k, int max_seqlen_q, int max_seqlen_k,
281- double sm_scale, bool is_causal) {
301+ double sm_scale, std::optional<const at::Tensor>& sm_sink_,
302+ int window_size_left, int window_size_right, bool is_causal, bool is_local,
303+ bool is_sink) {
282304 int num_block = key_cache.size (0 );
283305 int block_size = key_cache.size (1 );
284306 int num_heads_q = query.size (1 );
@@ -289,6 +311,12 @@ void cutlass_chunk_prefill_impl(
289311 int total_seqlen_q = query.size (0 );
290312 int total_seqlen_k = num_block * block_size;
291313
314+ if (is_local) {
315+ window_size_left = window_size_left == -1 ? max_seqlen_k : window_size_left;
316+ window_size_right =
317+ window_size_right == -1 ? max_seqlen_k : window_size_right;
318+ }
319+
292320 chunk_prefill_args_t args = {query.data_ptr (),
293321 key_cache.data_ptr (),
294322 value_cache.data_ptr (),
@@ -301,13 +329,18 @@ void cutlass_chunk_prefill_impl(
301329 total_seqlen_q,
302330 total_seqlen_k,
303331 static_cast <float >(sm_scale),
332+ is_sink ? sm_sink_.value ().data_ptr () : nullptr ,
304333 batch_size,
305334 num_heads_q,
306335 num_heads_kv,
307336 head_size,
308337 max_blocks_per_seq,
309338 block_size,
310- is_causal};
339+ window_size_left,
340+ window_size_right,
341+ is_causal,
342+ is_local,
343+ is_sink};
311344 CutlassType cuType = aten_to_Cutlass_dtype (query);
312345
313346 if (args.head_size == HEAD_SIZE_LIMIT_0) {
0 commit comments