@@ -21,6 +21,44 @@ std::vector<at::Tensor> mha_varlen_fwd(
2121 bool is_causal, int window_size_left, int window_size_right,
2222 const float softcap, const bool return_softmax,
2323 std::optional<at::Generator> gen_) {
24+ auto q_type = q.scalar_type ();
25+ TORCH_CHECK (
26+ q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
27+ " VLLM Kernel XPU only supports fp16 and bf16 type" );
28+
29+ TORCH_CHECK (k.scalar_type () == q_type,
30+ " query and key must have the same dtype" );
31+ TORCH_CHECK (v.scalar_type () == q_type,
32+ " query and value must have the same dtype" );
33+
34+ CHECK_DEVICE (q);
35+ CHECK_DEVICE (k);
36+ CHECK_DEVICE (v);
37+
38+ TORCH_CHECK (q.stride (-1 ) == 1 ,
39+ " Input tensor must have contiguous last dimension" );
40+ TORCH_CHECK (k.stride (-1 ) == 1 ,
41+ " Input tensor must have contiguous last dimension" );
42+ TORCH_CHECK (v.stride (-1 ) == 1 ,
43+ " Input tensor must have contiguous last dimension" );
44+ TORCH_CHECK (q.dim () == 3 , " query must be in ragged format" );
45+
46+ CHECK_DEVICE (block_table_);
47+ TORCH_CHECK (block_table_.dtype () == torch::kInt32 ,
48+ " page_table must have dtype torch.int32" );
49+ TORCH_CHECK (block_table_.stride (-1 ) == 1 ,
50+ " page_table must have contiguous last dimension" );
51+
52+ CHECK_DEVICE (cu_seqlens_q);
53+ CHECK_CONTIGUOUS (cu_seqlens_q);
54+ TORCH_CHECK (cu_seqlens_q.dtype () == torch::kInt32 ,
55+ " cu_seqlens_q must have dtype torch.int32" );
56+
57+ CHECK_DEVICE (cu_seqlens_k);
58+ CHECK_CONTIGUOUS (cu_seqlens_k);
59+ TORCH_CHECK (cu_seqlens_k.dtype () == torch::kInt32 ,
60+ " cu_seqlens_k must have dtype torch.int32" );
61+
2462 auto & queue = vllm::xpu::vllmGetQueue (q.device ().index ());
2563
2664 at::Tensor out;
0 commit comments