Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 102 additions & 100 deletions csrc/xpu/attn/xe_2/chunk_prefill.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,111 +338,113 @@ struct FMHAConfig {
}
};

template <
typename chunk_policy,
typename ElementQ,
typename ElementKV,
typename ElementO,
bool Paged,
bool Causal,
bool Local,
bool Sink>
void policy_dispatch_typed_impl(
sycl::queue& queue, const chunk_prefill_args_t& args) {
const int PipelineStages = 2;
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
ElementQ,
ElementKV,
ElementKV,
ElementO>::kernel_dispatch(queue, args);
}

template <typename chunk_policy, bool Paged, bool Causal, bool Local, bool Sink>
void policy_dispatch_impl(
sycl::queue& queue,
CutlassQKType& cuQKType,
CutlassQKOType& cuQKOType,
const chunk_prefill_args_t& args) {
const int PipelineStages = 2;
if (cuQKType.q_type == CutlassDType::half) {
if (cuQKType.k_type == CutlassDType::half) {
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
half_t,
half_t,
half_t,
half_t>::kernel_dispatch(queue, args);
} else if (cuQKType.k_type == CutlassDType::float8_e4m3) {
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
half_t,
float_e4m3_t,
float_e4m3_t,
half_t>::kernel_dispatch(queue, args);
} else if (cuQKType.k_type == CutlassDType::float8_e5m2) {
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
half_t,
float_e5m2_t,
float_e5m2_t,
half_t>::kernel_dispatch(queue, args);
if (cuQKOType.o_type == CutlassDType::half) {
if (cuQKOType.q_type == CutlassDType::half) {
if (cuQKOType.k_type == CutlassDType::half) {
Comment on lines +375 to +377
return policy_dispatch_typed_impl<
chunk_policy,
half_t,
half_t,
half_t,
Paged,
Causal,
Local,
Sink>(queue, args);
} else if (cuQKOType.k_type == CutlassDType::float8_e4m3) {
return policy_dispatch_typed_impl<
chunk_policy,
half_t,
float_e4m3_t,
half_t,
Paged,
Causal,
Local,
Sink>(queue, args);
} else if (cuQKOType.k_type == CutlassDType::float8_e5m2) {
return policy_dispatch_typed_impl<
chunk_policy,
half_t,
float_e5m2_t,
half_t,
Paged,
Causal,
Local,
Sink>(queue, args);
} else {
TORCH_CHECK(false, "Unsupported KV dtype for chunk prefill dispatch");
}
Comment on lines +408 to +409
}
} else {
if (cuQKType.k_type == CutlassDType::bfloat16) {
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
bfloat16_t,
bfloat16_t,
bfloat16_t,
bfloat16_t>::kernel_dispatch(queue, args);
} else if (cuQKType.k_type == CutlassDType::float8_e4m3) {
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
bfloat16_t,
float_e4m3_t,
float_e4m3_t,
bfloat16_t>::kernel_dispatch(queue, args);
} else if (cuQKType.k_type == CutlassDType::float8_e5m2) {
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
bfloat16_t,
float_e5m2_t,
float_e5m2_t,
bfloat16_t>::kernel_dispatch(queue, args);
} else if (cuQKOType.o_type == CutlassDType::bfloat16) {
if (cuQKOType.q_type == CutlassDType::bfloat16) {
if (cuQKOType.k_type == CutlassDType::bfloat16) {
return policy_dispatch_typed_impl<
chunk_policy,
bfloat16_t,
bfloat16_t,
bfloat16_t,
Paged,
Causal,
Local,
Sink>(queue, args);
} else if (cuQKOType.k_type == CutlassDType::float8_e4m3) {
return policy_dispatch_typed_impl<
chunk_policy,
bfloat16_t,
float_e4m3_t,
bfloat16_t,
Paged,
Causal,
Local,
Sink>(queue, args);
} else if (cuQKOType.k_type == CutlassDType::float8_e5m2) {
return policy_dispatch_typed_impl<
chunk_policy,
bfloat16_t,
float_e5m2_t,
bfloat16_t,
Paged,
Causal,
Local,
Sink>(queue, args);
} else {
TORCH_CHECK(false, "Unsupported KV dtype for chunk prefill dispatch");
}
}
} else {
TORCH_CHECK(false, "Unsupported output dtype for chunk prefill dispatch");
}
}
52 changes: 33 additions & 19 deletions csrc/xpu/attn/xe_2/chunk_prefill_configure.cmake
Original file line number Diff line number Diff line change
@@ -1,33 +1,47 @@
function(fmha_forward_configure FILENAME_SUFFIX)
set(GEN_KERNEL_SRCS) # output
set(L_TYPES "fp16" "bf16")
set(L_BOOLS "false" "true")
set(BOOL_FLAG_false "f")
set(BOOL_FLAG_true "t")
set(policy_list
"chunk_policy_head64" "chunk_policy_head96" "chunk_policy_head128"
"chunk_policy_head192" "chunk_policy_head256" "chunk_policy_head512")

set(IMPL_KV_T "fp16")
# Allowed dtype combinations must match runtime dispatch constraints. Format:
# Q_TYPE|KV_TYPE|O_TYPE|FILE_TAG
set(dtype_combo_list
"half_t|half_t|half_t|h_h_h"
"half_t|float_e4m3_t|half_t|h_e4_h"
"half_t|float_e5m2_t|half_t|h_e5_h"
"bfloat16_t|bfloat16_t|bfloat16_t|b_b_b"
"bfloat16_t|float_e4m3_t|bfloat16_t|b_e4_b"
"bfloat16_t|float_e5m2_t|bfloat16_t|b_e5_b")

foreach(IMPL_POLICY ${policy_list})
# foreach(IMPL_T ${L_TYPES})
foreach(IMPL_KISPAGED ${L_BOOLS})
foreach(IMPL_KISCAUSAL ${L_BOOLS})
foreach(IMPL_KISLOCAL ${L_BOOLS})
foreach(IMPL_KISSINK ${L_BOOLS})
set(FILE_SUFFIX "${IMPL_POLICY}_")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISPAGED}}")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISCAUSAL}}")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISSINK}}")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISLOCAL}}")
configure_file(${FILENAME_SUFFIX}.cpp.in
"${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp")
list(
APPEND
GEN_KERNEL_SRCS
"${CMAKE_CURRENT_BINARY_DIR}/${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp"
)
foreach(dtype_combo ${dtype_combo_list})
string(REPLACE "|" ";" dtype_parts "${dtype_combo}")
list(GET dtype_parts 0 IMPL_Q_T)
list(GET dtype_parts 1 IMPL_KV_T)
list(GET dtype_parts 2 IMPL_O_T)
list(GET dtype_parts 3 DTYPE_TAG)

foreach(IMPL_KISPAGED ${L_BOOLS})
foreach(IMPL_KISCAUSAL ${L_BOOLS})
foreach(IMPL_KISLOCAL ${L_BOOLS})
foreach(IMPL_KISSINK ${L_BOOLS})
set(FILE_SUFFIX "${IMPL_POLICY}_${DTYPE_TAG}_")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISPAGED}}")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISCAUSAL}}")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISSINK}}")
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISLOCAL}}")
configure_file(${FILENAME_SUFFIX}.cpp.in
"${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp")
list(
APPEND
GEN_KERNEL_SRCS
"${CMAKE_CURRENT_BINARY_DIR}/${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp"
)
endforeach()
endforeach()
endforeach()
endforeach()
Expand Down
58 changes: 48 additions & 10 deletions csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,59 @@
// that include chunk_prefill.hpp. Each specialization is explicitly
// instantiated in its own .cpp file generated by CMake.

// Helper macro to declare a single extern template instantiation
#define DECLARE_POLICY_DISPATCH_EXTERN(POLICY, PAGED, CAUSAL, LOCAL, SINK) \
extern template void \
policy_dispatch_impl<POLICY, PAGED, CAUSAL, LOCAL, SINK>( \
sycl::queue & queue, \
CutlassQKType & cuQKType, \
const chunk_prefill_args_t& args);
// Helper macro to declare a single extern template instantiation.
// Template order is: policy, ElementQ, ElementKV, ElementO, bools...
#define DECLARE_POLICY_DISPATCH_EXTERN( \
POLICY, Q_T, KV_T, O_T, PAGED, CAUSAL, LOCAL, SINK) \
extern template void policy_dispatch_typed_impl< \
POLICY, \
Q_T, \
KV_T, \
O_T, \
PAGED, \
CAUSAL, \
LOCAL, \
SINK>(sycl::queue & queue, const chunk_prefill_args_t& args);

// Allowed dtype combinations (must match runtime dispatch constraints):
// 1) Q=half -> O=half, KV in {half, fp8_e4m3, fp8_e5m2}
// 2) Q=bf16 -> O=bf16, KV in {bf16, fp8_e4m3, fp8_e5m2}
#define DECLARE_ALLOWED_DTYPES(POLICY, PAGED, CAUSAL, LOCAL, SINK) \
DECLARE_POLICY_DISPATCH_EXTERN( \
POLICY, half_t, half_t, half_t, PAGED, CAUSAL, LOCAL, SINK) \
DECLARE_POLICY_DISPATCH_EXTERN( \
POLICY, half_t, float_e4m3_t, half_t, PAGED, CAUSAL, LOCAL, SINK) \
DECLARE_POLICY_DISPATCH_EXTERN( \
POLICY, half_t, float_e5m2_t, half_t, PAGED, CAUSAL, LOCAL, SINK) \
DECLARE_POLICY_DISPATCH_EXTERN( \
POLICY, bfloat16_t, bfloat16_t, bfloat16_t, PAGED, CAUSAL, LOCAL, SINK) \
DECLARE_POLICY_DISPATCH_EXTERN( \
POLICY, \
bfloat16_t, \
float_e4m3_t, \
bfloat16_t, \
PAGED, \
CAUSAL, \
LOCAL, \
SINK) \
DECLARE_POLICY_DISPATCH_EXTERN( \
POLICY, \
bfloat16_t, \
float_e5m2_t, \
bfloat16_t, \
PAGED, \
CAUSAL, \
LOCAL, \
SINK)

// Generate all 16 bool combinations for a given policy using nested macros
// Pattern: Paged, Causal, Local, Sink (all permutations of 4 bools = 2^4 = 16)
// This hierarchical approach makes it easy to extend to more bool parameters

// Level 4: Iterate over Sink values (innermost)
#define DECLARE_FOR_SINK(POLICY, PAGED, CAUSAL, LOCAL) \
DECLARE_POLICY_DISPATCH_EXTERN(POLICY, PAGED, CAUSAL, LOCAL, false) \
DECLARE_POLICY_DISPATCH_EXTERN(POLICY, PAGED, CAUSAL, LOCAL, true)
#define DECLARE_FOR_SINK(POLICY, PAGED, CAUSAL, LOCAL) \
DECLARE_ALLOWED_DTYPES(POLICY, PAGED, CAUSAL, LOCAL, false) \
DECLARE_ALLOWED_DTYPES(POLICY, PAGED, CAUSAL, LOCAL, true)

// Level 3: Iterate over Local values
#define DECLARE_FOR_LOCAL(POLICY, PAGED, CAUSAL) \
Expand All @@ -62,4 +99,5 @@ CHUNK_POLICY_LIST(DECLARE_ALL_BOOL_COMBINATIONS)
#undef DECLARE_FOR_CAUSAL
#undef DECLARE_FOR_LOCAL
#undef DECLARE_FOR_SINK
#undef DECLARE_ALLOWED_DTYPES
#undef DECLARE_POLICY_DISPATCH_EXTERN
23 changes: 13 additions & 10 deletions csrc/xpu/attn/xe_2/chunk_prefill_kernel_template.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,27 @@ using namespace cute;

// clang-format off
// macros to be filled in CMake
#define IMPL_T ${IMPL_T}
#define IMPL_Q_T ${IMPL_Q_T}
#define IMPL_KV_T ${IMPL_KV_T}
#define IMPL_O_T ${IMPL_O_T}
#define IMPL_POLICY ${IMPL_POLICY}
#cmakedefine01 IMPL_KISPAGED
#cmakedefine01 IMPL_KISCAUSAL
#cmakedefine01 IMPL_KISSINK
#cmakedefine01 IMPL_KISLOCAL
// clang-format on

#define INSTANTIATE_KERNEL() \
template void policy_dispatch_impl< \
IMPL_POLICY, \
static_cast<bool>(IMPL_KISPAGED), \
static_cast<bool>(IMPL_KISCAUSAL), \
static_cast<bool>(IMPL_KISLOCAL), \
static_cast<bool>(IMPL_KISSINK)>( \
sycl::queue & queue, \
CutlassQKType& cuQKType, \
#define INSTANTIATE_KERNEL() \
template void policy_dispatch_typed_impl< \
IMPL_POLICY, \
IMPL_Q_T, \
IMPL_KV_T, \
IMPL_O_T, \
static_cast<bool>(IMPL_KISPAGED), \
static_cast<bool>(IMPL_KISCAUSAL), \
static_cast<bool>(IMPL_KISLOCAL), \
static_cast<bool>(IMPL_KISSINK)>( \
sycl::queue & queue, \
const chunk_prefill_args_t& args);

INSTANTIATE_KERNEL()
Loading
Loading