Skip to content

Commit c20e99e

Browse files
committed
Selective build of SYCL TLA kernels
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
1 parent adfc54b commit c20e99e

File tree

5 files changed

+59
-18
lines changed

5 files changed

+59
-18
lines changed

CMakeLists.txt

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ option(VLLM_XPU_ENABLE_XE_DEFAULT "Enable XE Default architecture kernels" ON)
5959
option(BASIC_KERNELS_ENABLED "Build basic kernels (_C extension)" ON)
6060
option(FA2_KERNELS_ENABLED
6161
"Build Flash Attention 2 kernels (_vllm_fa2_C extension)" ON)
62-
option(MOE_KERNELS_ENABLED "Build MoE kernels (_moe_C extension)" ON)
62+
option(MOE_KERNELS_ENABLED
63+
"Build MoE kernels (_moe_C extension + grouped_gemm TLA)" ON)
64+
option(GDN_KERNELS_ENABLED "Build GDN attention kernels (gdn_attn TLA)" ON)
6365
option(XPU_SPECIFIC_KERNELS_ENABLED
6466
"Build XPU-specific kernels (_xpu_C extension)" ON)
6567
option(XPUMEM_ALLOCATOR_ENABLED "Build xpumem_allocator extension" ON)
@@ -72,6 +74,7 @@ message(STATUS " VLLM_XPU_ENABLE_XE_DEFAULT = ${VLLM_XPU_ENABLE_XE_DEFAULT}")
7274
message(STATUS " BASIC_KERNELS_ENABLED = ${BASIC_KERNELS_ENABLED}")
7375
message(STATUS " FA2_KERNELS_ENABLED = ${FA2_KERNELS_ENABLED}")
7476
message(STATUS " MOE_KERNELS_ENABLED = ${MOE_KERNELS_ENABLED}")
77+
message(STATUS " GDN_KERNELS_ENABLED = ${GDN_KERNELS_ENABLED}")
7578
message(
7679
STATUS " XPU_SPECIFIC_KERNELS_ENABLED = ${XPU_SPECIFIC_KERNELS_ENABLED}")
7780
message(STATUS " XPUMEM_ALLOCATOR_ENABLED = ${XPUMEM_ALLOCATOR_ENABLED}")
@@ -331,23 +334,40 @@ if(BUILD_SYCL_TLA_KERNELS)
331334
# extensions shared library
332335
set(SYCL_TLA_COMPILE_OPTIONS "")
333336
if(VLLM_XPU_ENABLE_XE_DEFAULT)
334-
add_subdirectory(csrc/xpu/grouped_gemm/xe_default)
335-
list(APPEND GROUPED_GEMM_LIB_NAME "grouped_gemm_xe_default")
337+
if(MOE_KERNELS_ENABLED)
338+
add_subdirectory(csrc/xpu/grouped_gemm/xe_default)
339+
list(APPEND GROUPED_GEMM_LIB_NAME "grouped_gemm_xe_default")
340+
endif()
336341
list(APPEND SYCL_TLA_COMPILE_OPTIONS -DVLLM_XPU_ENABLE_XE_DEFAULT)
337342
endif()
338343
if(VLLM_XPU_ENABLE_XE2)
339-
add_subdirectory(csrc/xpu/grouped_gemm/xe_2)
340-
add_subdirectory(csrc/xpu/attn/xe_2)
341-
add_subdirectory(csrc/xpu/gdn_attn/xe_2)
342-
list(APPEND GROUPED_GEMM_LIB_NAME "grouped_gemm_xe_2")
343-
list(APPEND ATTN_KERNEL_LIB_NAME "attn_kernels_xe_2")
344-
list(APPEND GDN_ATTN_LIB_NAME "gdn_attn_kernels_xe_2")
344+
if(MOE_KERNELS_ENABLED)
345+
add_subdirectory(csrc/xpu/grouped_gemm/xe_2)
346+
list(APPEND GROUPED_GEMM_LIB_NAME "grouped_gemm_xe_2")
347+
endif()
348+
if(FA2_KERNELS_ENABLED)
349+
add_subdirectory(csrc/xpu/attn/xe_2)
350+
list(APPEND ATTN_KERNEL_LIB_NAME "attn_kernels_xe_2")
351+
endif()
352+
if(GDN_KERNELS_ENABLED)
353+
add_subdirectory(csrc/xpu/gdn_attn/xe_2)
354+
list(APPEND GDN_ATTN_LIB_NAME "gdn_attn_kernels_xe_2")
355+
endif()
345356
list(APPEND SYCL_TLA_COMPILE_OPTIONS -DVLLM_XPU_ENABLE_XE2)
346357
endif()
347358
list(APPEND VLLM_GPU_COMPILE_FLAGS ${SYCL_TLA_COMPILE_OPTIONS})
348359

349360
endif()
350361

362+
# Feature compile defines — these guard op registrations and interface code so
363+
# that disabled features don't pull in unbuilt TLA library symbols.
364+
if(MOE_KERNELS_ENABLED)
365+
list(APPEND VLLM_GPU_COMPILE_FLAGS -DVLLM_MOE_ENABLED)
366+
endif()
367+
if(GDN_KERNELS_ENABLED)
368+
list(APPEND VLLM_GPU_COMPILE_FLAGS -DVLLM_GDN_ENABLED)
369+
endif()
370+
351371
# define vLLM XPU cmake variables
352372

353373
set(VLLM_XPU_INCLUDE_DIR "")
@@ -505,9 +525,14 @@ if(XPU_SPECIFIC_KERNELS_ENABLED)
505525
"csrc/xpu/sampler/topk_topp_sampler.cpp"
506526
"csrc/xpu/sycl/deepseek_scaling_rope.cpp"
507527
"csrc/xpu/rand/exponential.cpp"
508-
"csrc/xpu/grouped_gemm/grouped_gemm_interface.cpp"
509-
"csrc/xpu/utils.cpp"
510-
"csrc/xpu/gdn_attn/gdn_attn_interface.cpp")
528+
"csrc/xpu/utils.cpp")
529+
if(MOE_KERNELS_ENABLED)
530+
list(APPEND VLLM_EXT_XPU_SRC
531+
"csrc/xpu/grouped_gemm/grouped_gemm_interface.cpp")
532+
endif()
533+
if(GDN_KERNELS_ENABLED)
534+
list(APPEND VLLM_EXT_XPU_SRC "csrc/xpu/gdn_attn/gdn_attn_interface.cpp")
535+
endif()
511536
include_directories("/usr/include")
512537
# TODO: check if we need this flags list(APPEND VLLM_GPU_FLAGS
513538
# "-gline-tables-only")

csrc/xpu/ops.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ torch::Tensor int4_gemm_w4a8(
5252
const std::optional<torch::Tensor>& g_idx,
5353
const std::optional<torch::Tensor>& bias);
5454

55+
#ifdef VLLM_MOE_ENABLED
5556
torch::Tensor cutlass_grouped_gemm_interface(
5657
torch::Tensor ptr_A,
5758
torch::Tensor ptr_B,
@@ -64,6 +65,7 @@ torch::Tensor cutlass_grouped_gemm_interface(
6465
int64_t num_experts,
6566
bool is_B_int4,
6667
bool is_B_mxfp4);
68+
#endif
6769

6870
std::tuple<at::Tensor, at::Tensor> deepseek_scaling_rope(
6971
const at::Tensor& positions,
@@ -74,6 +76,7 @@ std::tuple<at::Tensor, at::Tensor> deepseek_scaling_rope(
7476
int64_t rotary_dim,
7577
bool is_neox);
7678

79+
#ifdef VLLM_GDN_ENABLED
7780
void gdn_attention(
7881
torch::Tensor& core_attn_out,
7982
torch::Tensor& z,
@@ -98,6 +101,7 @@ void gdn_attention(
98101
const int64_t num_actual_tokens,
99102
const int64_t tp_size,
100103
const bool reorder_input);
104+
#endif
101105

102106
bool is_bmg(int64_t device_index);
103107

csrc/xpu/torch_bindings.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "core/registration.h"
22
#include "xpu/ops.h"
3-
#include "xpu/grouped_gemm/grouped_gemm_interface.h"
3+
#ifdef VLLM_MOE_ENABLED
4+
#include "xpu/grouped_gemm/grouped_gemm_interface.h"
5+
#endif
46
#include "xpu/lora/lora_ops.h"
57

68
#include <torch/library.h>
@@ -35,6 +37,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, xpu_ops) {
3537
"bias) -> Tensor");
3638
xpu_ops.impl("int4_gemm_w4a8", torch::kXPU, &int4_gemm_w4a8);
3739

40+
#ifdef VLLM_MOE_ENABLED
3841
xpu_ops.def(
3942
"cutlass_grouped_gemm_interface(Tensor ptr_A, Tensor ptr_B, Tensor? "
4043
"ptr_scales, "
@@ -48,6 +51,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, xpu_ops) {
4851
"cutlass_grouped_gemm_interface",
4952
torch::kXPU,
5053
&cutlass_grouped_gemm_interface);
54+
#endif
5155

5256
xpu_ops.def(
5357
"deepseek_scaling_rope(Tensor! positions, Tensor! query, Tensor! key, "
@@ -72,6 +76,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, xpu_ops) {
7276
"-> ()");
7377
xpu_ops.impl("bgmv_expand_slice", torch::kXPU, &bgmv_expand_slice);
7478

79+
#ifdef VLLM_GDN_ENABLED
7580
xpu_ops.def(
7681
"gdn_attention(Tensor! core_attn_out, Tensor! z, Tensor "
7782
"projected_states_qkvz, Tensor projected_states_ba,"
@@ -83,6 +88,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, xpu_ops) {
8388
"Tensor non_spec_state_indices_tensor, int num_actual_tokens, int "
8489
"tp_size, bool reorder_input) -> ()");
8590
xpu_ops.impl("gdn_attention", torch::kXPU, &gdn_attention);
91+
#endif
8692

8793
// for empty tensor functions, we don't need dispatch key like torch::kXPU
8894
xpu_ops.def("is_bmg(int device_index) -> bool");

setup.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def configure(self, ext: CMakeExtension) -> None:
191191
"BASIC_KERNELS_ENABLED",
192192
"FA2_KERNELS_ENABLED",
193193
"MOE_KERNELS_ENABLED",
194+
"GDN_KERNELS_ENABLED",
194195
"XPU_SPECIFIC_KERNELS_ENABLED",
195196
"XPUMEM_ALLOCATOR_ENABLED",
196197
]
@@ -528,11 +529,14 @@ def build_extensions(self) -> None:
528529
if _is_enabled("VLLM_XPU_ENABLE_XE2"):
529530
if _is_enabled("FA2_KERNELS_ENABLED"):
530531
additional_libraries["attn_kernels_xe_2"] = "/csrc/xpu/attn/xe_2"
531-
additional_libraries["gdn_attn_kernels_xe_2"] = (
532-
"/csrc/xpu/gdn_attn/xe_2")
533-
additional_libraries["grouped_gemm_xe_2"] = (
534-
"/csrc/xpu/grouped_gemm/xe_2")
535-
if _is_enabled("VLLM_XPU_ENABLE_XE_DEFAULT"):
532+
if _is_enabled("GDN_KERNELS_ENABLED"):
533+
additional_libraries["gdn_attn_kernels_xe_2"] = (
534+
"/csrc/xpu/gdn_attn/xe_2")
535+
if _is_enabled("MOE_KERNELS_ENABLED"):
536+
additional_libraries["grouped_gemm_xe_2"] = (
537+
"/csrc/xpu/grouped_gemm/xe_2")
538+
if _is_enabled("VLLM_XPU_ENABLE_XE_DEFAULT") and _is_enabled(
539+
"MOE_KERNELS_ENABLED"):
536540
additional_libraries["grouped_gemm_xe_default"] = (
537541
"/csrc/xpu/grouped_gemm/xe_default")
538542

tools/envs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ def get_vllm_port() -> Optional[int]:
121121
lambda: os.getenv("FA2_KERNELS_ENABLED", "ON"),
122122
"MOE_KERNELS_ENABLED":
123123
lambda: os.getenv("MOE_KERNELS_ENABLED", "ON"),
124+
"GDN_KERNELS_ENABLED":
125+
lambda: os.getenv("GDN_KERNELS_ENABLED", "ON"),
124126
"XPU_SPECIFIC_KERNELS_ENABLED":
125127
lambda: os.getenv("XPU_SPECIFIC_KERNELS_ENABLED", "ON"),
126128
"XPUMEM_ALLOCATOR_ENABLED":

0 commit comments

Comments (0)