Skip to content

Commit c9eedc2

Browse files
authored
Add chunk_prefill with cutlass backend (vllm-project#38)
* add chunk_prefill with cutlass backend — Signed-off-by: Yizhou Wang <yizhou.wang@intel.com>
* fix pre-commit errors — Signed-off-by: Yizhou Wang <yizhou.wang@intel.com>

Signed-off-by: Yizhou Wang <yizhou.wang@intel.com>
1 parent 2ad9a6e commit c9eedc2

19 files changed

Lines changed: 3004 additions & 9 deletions

CMakeLists.txt

Lines changed: 92 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ set(SYCL_SUPPORTED_ARCHS "intel_gpu_pvc;intel_gpu_bmg_g21")
4747
#
4848
set(TORCH_SUPPORTED_VERSION_XPU "2.8.0")
4949

50+
set(FA2_ENABLED ON)
51+
5052
#
5153
# Try to find python package with an executable that exactly matches
5254
# `VLLM_PYTHON_EXECUTABLE` and is one of the supported versions.
@@ -155,12 +157,60 @@ if(VLLM_GPU_LANG STREQUAL "SYCL")
155157
"csrc/quantization/fp8/fp8_quant.cpp"
156158
)
157159
include_directories("/usr/include")
158-
set(CMPLR_ROOT $ENV{CMPLR_ROOT})
159-
set(CMAKE_CXX_COMPILER icpx)
160+
list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/)
161+
list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/sycl/)
162+
list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/syclcompat/)
163+
message(STATUS "VLLM_INCLUDE_DIR: ${VLLM_INCLUDE_DIR}")
160164
set(VLLM_EXTRA_INCLUDE_DIRECTORIES ${CMPLR_ROOT}/include/sycl)
161165
list(APPEND VLLM_GPU_FLAGS "-DVLLM_BUILD_XPU_OPS" )
162-
list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64")
166+
list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64" "-Xspirv-translator" "-spirv-ext=+SPV_INTEL_split_barrier")
163167
list(APPEND VLLM_LINK_LIBRARIES "sycl" "OpenCL" "pthread" "m" "dl" "torch" )
168+
169+
170+
# add cutlass dependency
171+
set(CUTLASS_ENABLE_HEADERS_ONLY "ON" CACHE BOOL "Enable only the header library")
172+
173+
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
174+
set(CUTLASS_REVISION "main" CACHE STRING "CUTLASS revision to use")
175+
176+
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
177+
FetchContent_Declare(
178+
cutlass-sycl
179+
GIT_REPOSITORY https://github.com/intel/cutlass-sycl
180+
# Please keep this in sync with CUTLASS_REVISION line above.
181+
GIT_TAG ${CUTLASS_REVISION}
182+
GIT_PROGRESS TRUE
183+
184+
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
185+
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
186+
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
187+
GIT_SHALLOW TRUE
188+
)
189+
190+
# cutlass compilation flags
191+
set(CUTLASS_ENABLE_SYCL "ON")
192+
# set(DPCPP_SYCL_TARGET "intel_gpu_pvc;intel_gpu_bmg_g21" CACHE STRING "DPC++ SYCL target architectures")
193+
set(CMAKE_EXPORT_COMPILE_COMMANDS "ON")
194+
set(CUTLASS_ENABLE_BENCHMARKS "OFF")
195+
# disable cuda
196+
set(CUTLASS_ENABLE_GDC_FOR_SM100_DEFAULT OFF CACHE BOOL "DISABLE CUDA")
197+
# list(APPEND CMAKE_CXX_FLAGS "-ftemplate-backtrace-limit=0 " )
198+
# list(APPEND CMAKE_CXX_FLAGS "-fdiagnostics-color=always " )
199+
200+
201+
FetchContent_MakeAvailable(cutlass-sycl)
202+
set(CUTLASS_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/include CACHE PATH "CUTLASS Header Library")
203+
set(CUTLASS_TOOLS_UTIL_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/tools/util/include CACHE INTERNAL "")
204+
set(CUTLASS_APP_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/applications CACHE INTERNAL "")
205+
message(STATUS "cutlass dir: ${CUTLASS_INCLUDE_DIR} and ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} and ${CUTLASS_APP_INCLUDE_DIR}")
206+
207+
# header only library
208+
list(APPEND VLLM_GPU_FLAGS "-DCUTLASS_ENABLE_SYCL")
209+
list(APPEND VLLM_GPU_FLAGS "-DSYCL_INTEL_TARGET")
210+
list(APPEND VLLM_GPU_FLAGS "-DCUTLASS_VERSIONS_GENERATED")
211+
list(APPEND VLLM_GPU_FLAGS "-ftemplate-backtrace-limit=0")
212+
list(APPEND VLLM_GPU_FLAGS "-fdiagnostics-color=always")
213+
164214
endif()
165215

166216
message(STATUS "Enabling C extension.")
@@ -174,9 +224,48 @@ define_gpu_extension_target(
174224
ARCHITECTURES ${VLLM_GPU_ARCHES}
175225
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
176226
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
227+
INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR}
228+
INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR}
177229
USE_SABI 3
178230
WITH_SOABI)
179231

232+
#
# flash attention _C extension
#

if(FA2_ENABLED)
  message(STATUS "Enabling fa2 extension.")

  # Anchor the glob to this source dir and use CONFIGURE_DEPENDS so newly
  # added kernel files trigger a re-configure. The glob already matches
  # flash_api.cpp, so it must not also be listed explicitly in SOURCES
  # (the original listed it twice).
  file(GLOB FA2_GEN_SRCS CONFIGURE_DEPENDS
       "${CMAKE_CURRENT_SOURCE_DIR}/csrc/flash_attn/*.cpp")

  # Start from the common XPU flags and specialize for the cutlass FA2 build.
  set(CUTLASS_GPU_FLAGS ${VLLM_GPU_FLAGS})
  set(CUTLASS_LINK_FLAGS ${VLLM_GPU_LINK_FLAGS})

  # XPU compile flags: optimized, line-tables-only debug info, and AOT SYCL
  # compilation targeting spir64_gen.
  list(APPEND CUTLASS_GPU_FLAGS "-O3" "-DNDEBUG")
  list(APPEND CUTLASS_GPU_FLAGS "-gline-tables-only")
  list(APPEND CUTLASS_GPU_FLAGS "-fsycl" "-fsycl-targets=spir64_gen"
       "-ftemplate-backtrace-limit=10")

  # Link flags: AOT device build for bmg-g21-a0 with the 256-GRF-per-thread
  # backend option.
  list(APPEND CUTLASS_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64_gen")
  list(APPEND CUTLASS_LINK_FLAGS -Xsycl-target-backend=spir64_gen
       "-device bmg-g21-a0 -internal_options -cl-intel-256-GRF-per-thread")

  define_gpu_extension_target(
    _vllm_fa2_C
    DESTINATION vllm_xpu_kernels
    LANGUAGE ${VLLM_GPU_LANG}
    SOURCES
    ${FA2_GEN_SRCS}
    COMPILE_FLAGS ${CUTLASS_GPU_FLAGS}
    LINK_FLAGS ${CUTLASS_LINK_FLAGS}
    ARCHITECTURES ${VLLM_GPU_ARCHES}
    INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
    INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
    INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR}
    INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR}
    USE_SABI 3
    WITH_SOABI)
endif()
268+
180269
#
181270
# xpu only ops/kernels, implemented with cutlass/onednn/sycl.
182271
#

cmake/toolchain.cmake

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
# use this file to set the compiler and flags for SYCL
#
# CMPLR_ROOT must point at the Intel oneAPI compiler install (normally
# exported by `setvars.sh`). It is read from the environment at *configure*
# time only. Fail fast when it is missing instead of silently producing a
# bogus compiler path like "/bin/icpx".

if(NOT DEFINED ENV{CMPLR_ROOT})
  message(FATAL_ERROR
          "CMPLR_ROOT is not set; source the oneAPI environment "
          "(setvars.sh) before configuring.")
endif()

set(CMPLR_ROOT "$ENV{CMPLR_ROOT}")
message(STATUS "CMPLR_ROOT: ${CMPLR_ROOT}")

# Quote the expansions so an install prefix containing spaces still works.
set(CMAKE_CXX_COMPILER "${CMPLR_ROOT}/bin/icpx")
set(CMAKE_C_COMPILER "${CMPLR_ROOT}/bin/icx")

csrc/core/registration.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
//
// Every PyModuleDef field after m_size (m_methods, m_slots, m_traverse,
// m_clear, m_free) is filled with an explicit nullptr so the struct is
// fully and unambiguously initialized.
#define REGISTER_EXTENSION(NAME)                                       \
  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                             \
    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,         \
                                        STRINGIFY(NAME),               \
                                        nullptr,                       \
                                        0,                             \
                                        nullptr,                       \
                                        nullptr,                       \
                                        nullptr,                       \
                                        nullptr,                       \
                                        nullptr};                      \
    return PyModule_Create(&module);                                   \
  }

csrc/flash_attn/flash_api.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
#include "pytorch_shim.h"

#include "core/registration.h"
#include "xpu/cutlass_kernels/chunk_prefill.hpp"
#include "utils.h"

#include <torch/all.h>

namespace FLASH_NAMESPACE {

// Variable-length ("varlen") multi-head attention forward, used for chunked
// prefill on XPU via the cutlass backend.
//
// q: total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
// k/v: total_k x num_heads_k x head_size (packed), or
//      num_blocks x page_block_size x num_heads_k x head_size when paged
//      via block_table_.
// cu_seqlens_q / cu_seqlens_k: b+1 prefix sums of per-batch lengths.
// out_: optional pre-allocated output (total_q x num_heads x head_size).
//
// Returns {out, softmax_lse}. NOTE(review): softmax_lse is a placeholder —
// the kernel does not currently write it (see FIXME below).
std::vector<at::Tensor> mha_varlen_fwd(
    const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
    std::optional<at::Tensor>& out_,
    const at::Tensor& cu_seqlens_q,  // b+1
    const at::Tensor& cu_seqlens_k,  // b+1
    std::optional<at::Tensor>&
        seqused_k,  // b. If given, only this many elements of each batch
                    // element's keys are used.
    std::optional<const at::Tensor>& leftpad_k_,  // batch_size
    at::Tensor& block_table_,  // batch_size x max_num_blocks_per_seq
    std::optional<at::Tensor>& alibi_slopes_,  // num_heads or b x num_heads
    int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale,
    const bool zero_tensors, bool is_causal, int window_size_left,
    int window_size_right, const float softcap, const bool return_softmax,
    std::optional<at::Generator> gen_) {
  auto& queue = vllm::xpu::vllmGetQueue();

  // Reuse the caller-provided buffer when present; otherwise allocate a
  // zero-initialized output shaped like q.
  at::Tensor out = out_.has_value() ? *out_ : torch::zeros_like(q);

  cutlass_chunk_prefill_impl(queue, q, k, v, out, block_table_, cu_seqlens_q,
                             cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                             softmax_scale, is_causal);

  // FIXME: current do not support store softmax_lse out. When the caller
  // asks for it we return an uninitialized tensor of the same shape as
  // `out` purely to keep the flash-attn return arity.
  at::Tensor softmax_lse;
  if (return_softmax) {
    softmax_lse = torch::empty_like(out);
  }
  return {out, softmax_lse};
}

}  // namespace FLASH_NAMESPACE

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def(
      "varlen_fwd(Tensor q, Tensor k, Tensor v, Tensor!? out, Tensor "
      "cu_seqlens_q, "
      "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor "
      "block_table, Tensor? alibi_slopes, "
      "int max_seqlen_q, int max_seqlen_k, float p_dropout, float "
      "softmax_scale, bool zero_tensors, "
      "bool is_causal, int window_size_left, int window_size_right, float "
      "softcap, bool return_softmax, "
      "Generator? gen) -> Tensor[]");
  ops.impl("varlen_fwd", torch::kXPU,
           make_pytorch_shim(&FLASH_NAMESPACE::mha_varlen_fwd));
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

csrc/flash_attn/pytorch_shim.h

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
#pragma once

#include <torch/library.h>

/**
 * Unfortunately, the type signatures of the flash_attn ops are not compatible
 * with the PyTorch library bindings. To get around that we use
 * `make_pytorch_shim`, which creates a lambda that exposes the API using
 * PyTorch-compatible types and then converts them to the types expected by
 * the flash_attn ops. This shim allows us to make minimal changes to
 * `flash_api.cpp`, making it easier to synchronize with upstream changes.
 *
 * The `pytorch_library_compatible_type` struct maps from the flash_attn op
 * types to PyTorch-library-compatible ones. The main issue is that the
 * following types are not supported by PyTorch library bindings:
 *  - `int`
 *  - `float`
 *  - `std::optional<T> &`
 *  - `std::optional<const at::Tensor> &`
 * So we convert them to (respectively):
 *  - `int64_t`
 *  - `double`
 *  - `const std::optional<T>&`
 *  - `const std::optional<at::Tensor>&`
 */

// Identity mapping: most types pass through unchanged.
template <typename T>
struct pytorch_library_compatible_type {
  using type = T;
  static T convert_from_type(T arg) { return arg; }
};

template <typename T>
using pytorch_library_compatible_type_t =
    typename pytorch_library_compatible_type<T>::type;

// Convert a PyTorch-library-compatible value back to the op's native type.
template <typename T>
T convert_from_pytorch_compatible_type(
    pytorch_library_compatible_type_t<T> arg) {
  return pytorch_library_compatible_type<T>::convert_from_type(arg);
}

// Map `std::optional<T> &` -> `const std::optional<T>&`
// (NOTE: this is a bit unsafe, but none of the ops in flash_attn mutate
// the optional container)
template <typename T>
struct pytorch_library_compatible_type<std::optional<T>&> {
  using type = const std::optional<T>&;
  static std::optional<T>& convert_from_type(const std::optional<T>& arg) {
    return const_cast<std::optional<T>&>(arg);
  }
};

// Map `std::optional<T>` ->
//   `std::optional<pytorch_library_compatible_type_t<T>>`
// (NOTE: tested for `std::optional<int>` -> `std::optional<int64_t>`)
template <typename T>
struct pytorch_library_compatible_type<std::optional<T>> {
  using type = std::optional<pytorch_library_compatible_type_t<T>>;
  static std::optional<pytorch_library_compatible_type_t<T>> convert_from_type(
      std::optional<T> arg) {
    return arg;
  }
};

// Map `std::optional<const at::Tensor>&` -> `const std::optional<at::Tensor>&`
template <>
struct pytorch_library_compatible_type<std::optional<const at::Tensor>&> {
  using type = const std::optional<at::Tensor>&;
  static std::optional<const at::Tensor>& convert_from_type(
      const std::optional<at::Tensor>& arg) {
    return const_cast<std::optional<const at::Tensor>&>(
        reinterpret_cast<const std::optional<const at::Tensor>&>(arg));
  }
};

// Map `int` -> `int64_t`, range-checking the narrowing conversion back.
template <>
struct pytorch_library_compatible_type<int> {
  using type = int64_t;
  static int convert_from_type(int64_t arg) {
    TORCH_CHECK(arg <= std::numeric_limits<int>::max(),
                "int64_t value is too large to be converted to int");
    TORCH_CHECK(arg >= std::numeric_limits<int>::min(),
                "int64_t value is too small to be converted to int");
    return arg;
  }
};

// Map `float` -> `double`, checking the magnitude fits a float.
template <>
struct pytorch_library_compatible_type<float> {
  using type = double;
  static float convert_from_type(double arg) {
    TORCH_CHECK(std::abs(arg) <= std::numeric_limits<float>::max(),
                "double value is too large to be converted to float");
    return arg;
  }
};

//
// Shim Utils
//

// Wrap a flash_attn op in a lambda whose parameters use only
// PyTorch-library-compatible types; each argument is converted back to the
// op's native type before the underlying function is invoked.
template <typename Ret, typename... Args>
auto make_pytorch_shim(Ret (*fun)(Args... args)) {
  return [fun](pytorch_library_compatible_type_t<Args>... args) {
    return fun(convert_from_pytorch_compatible_type<Args>(args)...);
  };
}

0 commit comments

Comments
 (0)