Skip to content

Commit 2d03d25

Browse files
authored
[CHUNK_PREFILL] add policy 192 and check conditions (vllm-project#68)
* add policy 192 and check conditions

  Signed-off-by: Yizhou Wang <yizhou.wang@intel.com>

* pre-commit

  Signed-off-by: Yizhou Wang <yizhou.wang@intel.com>

* solve comments

  Signed-off-by: Yizhou Wang <yizhou.wang@intel.com>

---------

Signed-off-by: Yizhou Wang <yizhou.wang@intel.com>
1 parent cc369d2 commit 2d03d25

5 files changed

Lines changed: 62 additions & 3 deletions

File tree

csrc/flash_attn/flash_api.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,44 @@ std::vector<at::Tensor> mha_varlen_fwd(
2121
bool is_causal, int window_size_left, int window_size_right,
2222
const float softcap, const bool return_softmax,
2323
std::optional<at::Generator> gen_) {
24+
auto q_type = q.scalar_type();
25+
TORCH_CHECK(
26+
q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
27+
"VLLM Kernel XPU only supports fp16 and bf16 type");
28+
29+
TORCH_CHECK(k.scalar_type() == q_type,
30+
"query and key must have the same dtype");
31+
TORCH_CHECK(v.scalar_type() == q_type,
32+
"query and value must have the same dtype");
33+
34+
CHECK_DEVICE(q);
35+
CHECK_DEVICE(k);
36+
CHECK_DEVICE(v);
37+
38+
TORCH_CHECK(q.stride(-1) == 1,
39+
"Input tensor must have contiguous last dimension");
40+
TORCH_CHECK(k.stride(-1) == 1,
41+
"Input tensor must have contiguous last dimension");
42+
TORCH_CHECK(v.stride(-1) == 1,
43+
"Input tensor must have contiguous last dimension");
44+
TORCH_CHECK(q.dim() == 3, "query must be in ragged format");
45+
46+
CHECK_DEVICE(block_table_);
47+
TORCH_CHECK(block_table_.dtype() == torch::kInt32,
48+
"page_table must have dtype torch.int32");
49+
TORCH_CHECK(block_table_.stride(-1) == 1,
50+
"page_table must have contiguous last dimension");
51+
52+
CHECK_DEVICE(cu_seqlens_q);
53+
CHECK_CONTIGUOUS(cu_seqlens_q);
54+
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32,
55+
"cu_seqlens_q must have dtype torch.int32");
56+
57+
CHECK_DEVICE(cu_seqlens_k);
58+
CHECK_CONTIGUOUS(cu_seqlens_k);
59+
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32,
60+
"cu_seqlens_k must have dtype torch.int32");
61+
2462
auto& queue = vllm::xpu::vllmGetQueue(q.device().index());
2563

2664
at::Tensor out;

csrc/utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
#include <c10/xpu/XPUStream.h>
55
#include <sycl/sycl.hpp>
66

7+
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_xpu(), #x " must be on XPU")
8+
#define CHECK_CONTIGUOUS(x) \
9+
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10+
711
namespace vllm {
812
namespace xpu {
913

csrc/xpu/cutlass_kernels/chunk_prefill.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,11 +343,20 @@ void cutlass_chunk_prefill_impl(
343343
is_sink};
344344
CutlassType cuType = aten_to_Cutlass_dtype(query);
345345

346+
static constexpr int max_head_size = 256;
347+
TORCH_CHECK(head_size <= max_head_size,
348+
"FMHA forward only supports head dimension at most " +
349+
std::to_string(max_head_size));
350+
346351
if (args.head_size == HEAD_SIZE_LIMIT_0) {
347352
policy_dispatch<chunk_policy_head64>(queue, cuType, args);
348353
} else if (args.head_size == HEAD_SIZE_LIMIT_1) {
349354
policy_dispatch<chunk_policy_head128>(queue, cuType, args);
350355
} else if (args.head_size == HEAD_SIZE_LIMIT_2) {
356+
policy_dispatch<chunk_policy_head192>(queue, cuType, args);
357+
} else if (args.head_size == HEAD_SIZE_LIMIT_3) {
351358
policy_dispatch<chunk_policy_head256>(queue, cuType, args);
359+
} else {
360+
TORCH_CHECK(false, "Unsupported head size for fmha");
352361
}
353362
}

csrc/xpu/cutlass_kernels/fmha_utils.hpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55
#define HEAD_SIZE_LIMIT_0 64
66
#define HEAD_SIZE_LIMIT_1 128
7-
#define HEAD_SIZE_LIMIT_2 256
8-
#define HEAD_SIZE_LIMIT_3 512
7+
#define HEAD_SIZE_LIMIT_2 192
8+
#define HEAD_SIZE_LIMIT_3 256
9+
#define HEAD_SIZE_LIMIT_4 512
910

1011
enum class CutlassType {
1112
half,
@@ -40,6 +41,13 @@ struct chunk_policy_head128 {
4041
using SubgroupLayout = Layout<Shape<_16, _1, _1>, Stride<_1, _1, _1>>;
4142
};
4243

44+
struct chunk_policy_head192 {
45+
using ShapeQK = Shape<_256, _64, _64>;
46+
using ShapePV = Shape<_256, _32, _64>;
47+
using ShapeOutPut = Shape<_256, _192, _64>;
48+
using SubgroupLayout = Layout<Shape<_32, _1, _1>, Stride<_1, _1, _1>>;
49+
};
50+
4351
struct chunk_policy_head256 {
4452
using ShapeQK = Shape<_256, _64, _64>;
4553
using ShapePV = Shape<_256, _32, _64>;

tests/flash_attn/test_flash_attn_varlen_func.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func
1010

1111
NUM_HEADS = [(4, 4), (8, 2)]
12-
HEAD_SIZES = [64, 128, 256]
12+
HEAD_SIZES = [64, 128, 192, 256]
1313
BLOCK_SIZES = [64]
1414
DTYPES = [torch.bfloat16, torch.half]
1515
QDTYPES = [None]

0 commit comments

Comments (0)