Skip to content

Commit fc2065f

Browse files
authored
[CHUNK_PREFILL] new api refactor phase3 (vllm-project#90)
* add local attn (sliding window)
* solve template issues
* add asm barrier and fence

---------

Signed-off-by: Yizhou Wang <yizhou.wang@intel.com>
1 parent 29c5ee0 commit fc2065f

5 files changed

Lines changed: 110 additions & 58 deletions

File tree

csrc/xpu/cutlass_kernels/chunk_prefill.hpp

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,9 @@ struct KernelLauncher {
148148
static_cast<int*>(args.block_table),
149149
args.block_size,
150150
args.max_blocks_per_seq,
151-
args.total_seqlen_k},
151+
args.total_seqlen_k,
152+
args.window_size_left,
153+
args.window_size_right},
152154
{},
153155
hw_info};
154156

@@ -231,14 +233,10 @@ struct FMHAConfig {
231233
decltype(cutlass::fmha::collective::get_sg_layout_pv(SubgroupLayoutQK{})),
232234
SubgroupLayoutPV_>;
233235

234-
template <
235-
class Scheduler,
236-
bool VarLen,
237-
bool Paged,
238-
bool Causal,
239-
bool Local,
240-
bool Sink>
236+
template <class Scheduler, bool Causal, bool Local, bool Sink>
241237
static void run(sycl::queue& queue, const chunk_prefill_args_t& args) {
238+
constexpr bool VarLen = true;
239+
constexpr bool Paged = true;
242240
cutlass::KernelHardwareInfo hw_info;
243241

244242
using ProblemShapeType = cutlass::fmha::kernel::FMHAProblemShape<VarLen>;
@@ -273,6 +271,7 @@ struct FMHAConfig {
273271
using CollectiveMainloop = cutlass::fmha::collective::FMHAFwdMainloop<
274272
MainloopDispatchPolicy,
275273
Causal,
274+
Local,
276275
Paged,
277276
TiledMMAQK,
278277
TiledMMAPV,
@@ -338,13 +337,7 @@ void policy_dispatch(
338337
half_t,
339338
half_t>::
340339
kernel_dispatch(
341-
queue,
342-
args,
343-
true, // args.is_varlen,
344-
true, // args.is_paged,
345-
args.is_causal,
346-
false, // args.is_local,
347-
args.is_sink);
340+
queue, args, args.is_causal, args.is_local, args.is_sink);
348341
} else {
349342
return FMHAConfig<
350343
typename chunk_policy::ShapeQK,
@@ -354,13 +347,7 @@ void policy_dispatch(
354347
void,
355348
PipelineStages>::
356349
kernel_dispatch(
357-
queue,
358-
args,
359-
true, // args.is_varlen,
360-
true, // args.is_paged,
361-
args.is_causal,
362-
false, // args.is_local,
363-
args.is_sink);
350+
queue, args, args.is_causal, args.is_local, args.is_sink);
364351
}
365352
}
366353

@@ -418,6 +405,10 @@ void cutlass_chunk_prefill_impl(
418405
window_size_left = window_size_left == -1 ? max_seqlen_k : window_size_left;
419406
window_size_right =
420407
window_size_right == -1 ? max_seqlen_k : window_size_right;
408+
if (is_causal) {
409+
window_size_right = 0;
410+
is_causal = false;
411+
}
421412
}
422413

423414
chunk_prefill_args_t args = {

csrc/xpu/cutlass_kernels/chunk_prefill_kernel.hpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ class XeFMHAFwdKernel {
115115
// Template Features
116116
static constexpr bool PagedKV = CollectiveMainloop::PagedKV;
117117
static constexpr bool CausalMask = CollectiveMainloop::CausalMask;
118+
static constexpr bool LocalMask = CollectiveMainloop::LocalMask;
118119
static constexpr bool Sink = CollectiveEpilogue::Sink;
119120
using ElementSink = typename CollectiveEpilogue::ElementSink;
120121

@@ -246,7 +247,7 @@ class XeFMHAFwdKernel {
246247
auto [seq_len_qo, seq_len_kv] = sequence_length_shape;
247248
if (blk_q * get<0>(TileShapeQK{}) >= seq_len_qo) continue;
248249

249-
auto offset = cute::min(seq_len_qo, seq_len_kv);
250+
auto offset = seq_len_qo;
250251
auto discard_seq_coord = seq_len_qo - offset;
251252
auto full_tile_offset = seq_len_kv - offset;
252253
int seq_coord =
@@ -256,13 +257,25 @@ class XeFMHAFwdKernel {
256257
// calc sg level seq_len_kv
257258
const int seq_len =
258259
CausalMask
259-
? full_tile_offset +
260-
cute::min(seq_len_kv, seq_coord - discard_seq_coord) +
261-
q_sg_tile
260+
? LocalMask
261+
? cute::min(
262+
seq_len_kv,
263+
full_tile_offset + seq_coord + q_sg_tile +
264+
params.mainloop.local_right)
265+
: cute::min(
266+
seq_len_kv, full_tile_offset + seq_coord + q_sg_tile)
262267
: seq_len_kv;
268+
const int k_block0 =
269+
LocalMask
270+
? cute::max(
271+
seq_coord + full_tile_offset - params.mainloop.local_left,
272+
0) /
273+
get<1>(TileShapeQK{})
274+
: 0;
263275
const int k_blocks = cute::ceil_div(seq_len, get<1>(TileShapeQK{}));
264-
const int k_causal_blocks =
265-
CausalMask ? (seq_len - q_sg_tile) / get<1>(TileShapeQK{}) : 0;
276+
const int k_blocks_causal =
277+
CausalMask ? (seq_coord + full_tile_offset) / get<1>(TileShapeQK{})
278+
: 0;
266279

267280
int offset_q = 0, offset_k = 0, offset_v = 0, offset_o = 0;
268281
if constexpr (is_var_len) {
@@ -330,9 +343,9 @@ class XeFMHAFwdKernel {
330343
tA_sum,
331344
blk_qv,
332345
idx_b,
333-
0,
346+
k_block0,
334347
k_blocks,
335-
k_causal_blocks,
348+
k_blocks_causal,
336349
thr_id,
337350
seq_len,
338351
full_tile_offset,

csrc/xpu/cutlass_kernels/collective/chunk_prefill_mainloop.hpp

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,25 @@ class XeDefault {}; // Default FMHA mainloop, P in registers.
5050

5151
namespace cutlass::fmha::collective {
5252

53+
static inline void sbarrier_wait() { asm volatile("sbarrier.wait\n"); }
54+
55+
static inline void sbarrier_signal() { asm volatile("sbarrier.signal\n"); }
56+
57+
static inline void gfence() { asm volatile("lsc_fence.ugm.none.group\n"); }
58+
59+
static inline void barrier() {
60+
asm volatile("lsc_fence.ugm.none.group\n");
61+
asm volatile("barrier\n");
62+
}
63+
5364
using namespace cute;
5465

5566
/////////////////////////////////////////////////////////////////////////////////////////////////
5667

5768
template <
5869
class DispatchPolicy_,
5970
bool CausalMask_,
71+
bool LocalMask_,
6072
bool PagedKV_,
6173
class TiledMMAQK_, // Tiling for Q*K GEMM
6274
class TiledMMAPV_, // Tiling for P*V GEMM
@@ -78,6 +90,7 @@ struct FMHAFwdMainloop {
7890
template <
7991
int Stages,
8092
bool CausalMask_,
93+
bool LocalMask_,
8194
bool PagedKV_,
8295
class TiledMMAQK_,
8396
class TiledMMAPV_,
@@ -91,6 +104,7 @@ template <
91104
struct FMHAFwdMainloop<
92105
XeDefault<Stages>,
93106
CausalMask_,
107+
LocalMask_,
94108
PagedKV_,
95109
TiledMMAQK_,
96110
TiledMMAPV_,
@@ -165,6 +179,7 @@ struct FMHAFwdMainloop<
165179
using ElementA = typename TiledMMAPV::ValTypeD;
166180

167181
static constexpr bool CausalMask = CausalMask_;
182+
static constexpr bool LocalMask = LocalMask_;
168183
static constexpr bool PagedKV = PagedKV_;
169184

170185
// User-facing arguments
@@ -176,6 +191,8 @@ struct FMHAFwdMainloop<
176191
int page_size;
177192
int max_pages_per_seq;
178193
int total_seqlen_kv;
194+
// Local Mask
195+
int local_left, local_right;
179196
};
180197

181198
// Kernel-facing parameters
@@ -201,7 +218,9 @@ struct FMHAFwdMainloop<
201218
args.ptr_page_table,
202219
args.page_size,
203220
args.max_pages_per_seq,
204-
args.total_seqlen_kv};
221+
args.total_seqlen_kv,
222+
args.local_left,
223+
args.local_right};
205224
}
206225

207226
CUTLASS_HOST_DEVICE static bool can_implement(Arguments const&) {
@@ -312,28 +331,29 @@ struct FMHAFwdMainloop<
312331

313332
// PagedKV
314333
int tiles_per_page = params.page_size / get<1>(TileShapeQK{});
315-
int page_idx, next_page_idx;
334+
int page_idx = blk_k0, next_page_idx;
316335
int b_offset = idx_b * params.max_pages_per_seq;
317336
if constexpr (PagedKV) {
318-
page_idx = params.ptr_page_table[b_offset] * tiles_per_page;
337+
int page_local_idx = page_idx * get<1>(TileShapeQK{}) / params.page_size;
338+
page_idx =
339+
params.ptr_page_table[b_offset + page_local_idx] * tiles_per_page +
340+
page_idx % tiles_per_page;
319341
}
320342

321343
/* Initialization steps for first block: Q/K prefetch, O init */
322344
/* TODO: limit D prefetch for large head size, and reorder K prefetches */
323-
if (blk_k0 == 0) {
324-
for (int D = 0; D < size<3>(pQgQ); D++) {
325-
prefetch(prefetch_q, pQgQ(_, _, _, D));
326-
}
327-
328-
for (int D = 0; D < size<4>(pKgK); D++) {
329-
prefetch(prefetch_k, pKgK(_, _, _, page_idx, D));
330-
}
345+
for (int D = 0; D < size<3>(pQgQ); D++) {
346+
prefetch(prefetch_q, pQgQ(_, _, _, D));
347+
}
331348

332-
clear(tArA);
333-
fill(tA_max, cutlass::platform::numeric_limits<ElementA>::lowest());
334-
clear(tA_sum);
349+
for (int D = 0; D < size<4>(pKgK); D++) {
350+
prefetch(prefetch_k, pKgK(_, _, _, page_idx, D));
335351
}
336352

353+
clear(tArA);
354+
fill(tA_max, cutlass::platform::numeric_limits<ElementA>::lowest());
355+
clear(tA_sum);
356+
337357
/* Check if */
338358
bool check_remainder_k = (seq_len % get<1>(TileShapeQK{}) != 0);
339359

@@ -379,6 +399,23 @@ struct FMHAFwdMainloop<
379399
}
380400
}
381401
}
402+
/* Local masking */
403+
if constexpr (LocalMask) {
404+
Tensor cPgP = make_identity_tensor(make_shape(seq_len, seq_len));
405+
Tensor gP = local_tile(
406+
cPgP, take<0, 2>(TileShapeQK{}), make_coord(get<0>(blk_qv), K));
407+
auto cS_thread = thr_mma_qk.partition_C(gP);
408+
CUTLASS_PRAGMA_UNROLL
409+
for (int i = 0; i < tSrS.size(); ++i) {
410+
int row_idx = get<0>(cS_thread(i)) - discard_seq_coord;
411+
int col_idx = get<1>(cS_thread(i)) - full_tile_offset;
412+
bool left_mask = col_idx < row_idx - params.local_left;
413+
bool right_mask = col_idx > row_idx + params.local_right;
414+
if (left_mask || right_mask) {
415+
tSrS(i) = ElementS(-INFINITY);
416+
}
417+
}
418+
}
382419
/* k masking for remainder tiles */
383420
if (check_remainder_k && K == blk_k1 - 1) {
384421
FragSCol k_rem_mask;
@@ -406,7 +443,8 @@ struct FMHAFwdMainloop<
406443
cute::gemm(mma_pv, tArP, tArV, tArA(_, _, _, VV));
407444
}
408445

409-
sycl::group_barrier(compat::get_nd_item<1>().get_group());
446+
// sycl::group_barrier(compat::get_nd_item<1>().get_group());
447+
barrier();
410448

411449
// next paged_idx
412450
next_page_idx = K + 1;
@@ -456,9 +494,10 @@ struct FMHAFwdMainloop<
456494

457495
/* Scale S and subtract maxima, then exponentiate */
458496
CUTLASS_PRAGMA_UNROLL
459-
for (int i = 0; i < tS.size(); i++)
497+
for (int i = 0; i < tS.size(); i++) {
460498
tS(i) = sycl::native::exp2(
461499
params.scale * tS(i) - broadcast<0>(tS_max, tS, i));
500+
}
462501

463502
/* Rescale existing S sums and O accumulator */
464503
if (!first_block) {

csrc/xpu/cutlass_kernels/fmha_utils.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ inline CutlassType aten_to_Cutlass_dtype(const at::Tensor& input) {
2828

2929
using namespace cute;
3030
struct chunk_policy_head64 {
31-
using ShapeQK = Shape<_128, _64, _32>;
32-
using ShapePV = Shape<_128, _32, _64>;
31+
using ShapeQK = Shape<_128, _32, _32>;
32+
using ShapePV = Shape<_128, _32, _32>;
3333
using ShapeOut = Shape<_128, _64>;
3434
using SubgroupLayoutQK = Layout<Shape<_8, _1, _1>>;
3535
};
@@ -49,15 +49,15 @@ struct chunk_policy_head128 {
4949
};
5050

5151
struct chunk_policy_head192 {
52-
using ShapeQK = Shape<_256, _64, _32>;
53-
using ShapePV = Shape<_256, _32, _64>;
52+
using ShapeQK = Shape<_256, _32, _32>;
53+
using ShapePV = Shape<_256, _32, _32>;
5454
using ShapeOut = Shape<_256, _192>;
5555
using SubgroupLayoutQK = Layout<Shape<_32, _1, _1>>;
5656
};
5757

5858
struct chunk_policy_head256 {
59-
using ShapeQK = Shape<_256, _64, _32>;
60-
using ShapePV = Shape<_256, _32, _64>;
59+
using ShapeQK = Shape<_256, _32, _32>;
60+
using ShapePV = Shape<_256, _32, _32>;
6161
using ShapeOut = Shape<_256, _256>;
6262
using SubgroupLayoutQK = Layout<Shape<_32, _1, _1>>;
6363
};

tests/flash_attn/test_flash_attn_varlen_func.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
# one value small enough to test the schema op check
1919
NUM_BLOCKS = [32768, 2048]
2020
SOFT_CAPS = [None]
21-
SLIDING_WINDOWS = [(-1, 2), (2, -1), (11, 3), (-1, -1)]
21+
SLIDING_WINDOWS = [(-1, 127), (127, -1), (127, 127), (-1, -1)]
2222
SINK = [False, True]
2323
CASUAL = [False, True]
2424

@@ -56,8 +56,10 @@ def ref_paged_attn(query: torch.Tensor,
5656
v = v[:kv_len]
5757

5858
if q.shape[1] != k.shape[1]:
59-
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
60-
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
59+
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1],
60+
dim=1).contiguous()
61+
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1],
62+
dim=1).contiguous()
6163
attn = torch.einsum("qhd,khd->hqk", q, k).float()
6264
empty_mask = torch.ones(query_len, kv_len)
6365
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
@@ -111,7 +113,7 @@ def ref_paged_attn(query: torch.Tensor,
111113
@pytest.mark.parametrize("num_heads", NUM_HEADS)
112114
@pytest.mark.parametrize("head_size", HEAD_SIZES)
113115
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
114-
@pytest.mark.parametrize("window_size", [(-1, -1)])
116+
@pytest.mark.parametrize("window_size", SLIDING_WINDOWS)
115117
@pytest.mark.parametrize("dtype", DTYPES)
116118
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
117119
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@@ -135,15 +137,20 @@ def test_varlen_with_paged_kv(
135137
is_casual: bool,
136138
) -> None:
137139
torch.set_default_device("xpu")
140+
torch.xpu.set_device("xpu:0")
138141
# # FIXME: remove skip
139142
if (is_casual and seq_lens[1][0]
140143
== 5) and (os.getenv("SKIP_HANG_KERNEL") is not None
141144
and os.getenv("SKIP_HANG_KERNEL") == "1"):
142145
pytest.skip("skip casual for seqlen0 to avoid runtime hang on CI.")
146+
if (window_size[0] != -1 or window_size[1]
147+
!= -1) and (os.getenv("SKIP_HANG_KERNEL") is not None
148+
and os.getenv("SKIP_HANG_KERNEL") == "1"):
149+
pytest.skip("skip local attn to avoid runtime hang on CI.")
143150
# if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
144151
# pytest.skip("Flash attention with quantized inputs is only "
145152
# "supported on version 3 with bfloat16 base type")
146-
torch.manual_seed(0)
153+
torch.manual_seed(42)
147154
num_seqs = len(seq_lens)
148155
query_lens = [x[0] for x in seq_lens]
149156
kv_lens = [x[1] for x in seq_lens]
@@ -221,8 +228,10 @@ def test_varlen_with_paged_kv(
221228
sink=sink,
222229
window_size_left=window_size[0],
223230
window_size_right=window_size[1])
224-
atol, rtol = 1.5e-2, 1e-2
231+
atol, rtol = 1e-2, 1e-2
225232
if q_dtype is not None:
226233
atol, rtol = 1.5e-1, 1.5e-1
234+
if window_size[0] != -1 or window_size[1] != -1:
235+
atol, rtol = 1.5e-2, 1.5e-2
227236
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
228237
f"{torch.max(torch.abs(output - ref_output))}"

0 commit comments

Comments
 (0)