diff --git a/csrc/xpu/attn/xe_2/chunk_prefill.hpp b/csrc/xpu/attn/xe_2/chunk_prefill.hpp index 3c212f9f7..464f17983 100644 --- a/csrc/xpu/attn/xe_2/chunk_prefill.hpp +++ b/csrc/xpu/attn/xe_2/chunk_prefill.hpp @@ -338,111 +338,113 @@ struct FMHAConfig { } }; +template < + typename chunk_policy, + typename ElementQ, + typename ElementKV, + typename ElementO, + bool Paged, + bool Causal, + bool Local, + bool Sink> +void policy_dispatch_typed_impl( + sycl::queue& queue, const chunk_prefill_args_t& args) { + const int PipelineStages = 2; + return FMHAConfig< + typename chunk_policy::ShapeQK, + typename chunk_policy::ShapePV, + typename chunk_policy::ShapeOut, + typename chunk_policy::SubgroupLayoutQK, + void, + PipelineStages, + Paged, + Causal, + Local, + Sink, + ElementQ, + ElementKV, + ElementKV, + ElementO>::kernel_dispatch(queue, args); +} + template void policy_dispatch_impl( sycl::queue& queue, - CutlassQKType& cuQKType, + CutlassQKOType& cuQKOType, const chunk_prefill_args_t& args) { - const int PipelineStages = 2; - if (cuQKType.q_type == CutlassDType::half) { - if (cuQKType.k_type == CutlassDType::half) { - return FMHAConfig< - typename chunk_policy::ShapeQK, - typename chunk_policy::ShapePV, - typename chunk_policy::ShapeOut, - typename chunk_policy::SubgroupLayoutQK, - void, - PipelineStages, - Paged, - Causal, - Local, - Sink, - half_t, - half_t, - half_t, - half_t>::kernel_dispatch(queue, args); - } else if (cuQKType.k_type == CutlassDType::float8_e4m3) { - return FMHAConfig< - typename chunk_policy::ShapeQK, - typename chunk_policy::ShapePV, - typename chunk_policy::ShapeOut, - typename chunk_policy::SubgroupLayoutQK, - void, - PipelineStages, - Paged, - Causal, - Local, - Sink, - half_t, - float_e4m3_t, - float_e4m3_t, - half_t>::kernel_dispatch(queue, args); - } else if (cuQKType.k_type == CutlassDType::float8_e5m2) { - return FMHAConfig< - typename chunk_policy::ShapeQK, - typename chunk_policy::ShapePV, - typename chunk_policy::ShapeOut, - typename chunk_policy::SubgroupLayoutQK, - void, - PipelineStages, - Paged, - Causal, - Local, - Sink, - half_t, - float_e5m2_t, - float_e5m2_t, - half_t>::kernel_dispatch(queue, args); + if (cuQKOType.o_type == CutlassDType::half) { + if (cuQKOType.q_type == CutlassDType::half) { + if (cuQKOType.k_type == CutlassDType::half) { + return policy_dispatch_typed_impl< + chunk_policy, + half_t, + half_t, + half_t, + Paged, + Causal, + Local, + Sink>(queue, args); + } else if (cuQKOType.k_type == CutlassDType::float8_e4m3) { + return policy_dispatch_typed_impl< + chunk_policy, + half_t, + float_e4m3_t, + half_t, + Paged, + Causal, + Local, + Sink>(queue, args); + } else if (cuQKOType.k_type == CutlassDType::float8_e5m2) { + return policy_dispatch_typed_impl< + chunk_policy, + half_t, + float_e5m2_t, + half_t, + Paged, + Causal, + Local, + Sink>(queue, args); + } else { + TORCH_CHECK(false, "Unsupported KV dtype for chunk prefill dispatch"); + } } - } else { - if (cuQKType.k_type == CutlassDType::bfloat16) { - return FMHAConfig< - typename chunk_policy::ShapeQK, - typename chunk_policy::ShapePV, - typename chunk_policy::ShapeOut, - typename chunk_policy::SubgroupLayoutQK, - void, - PipelineStages, - Paged, - Causal, - Local, - Sink, - bfloat16_t, - bfloat16_t, - bfloat16_t, - bfloat16_t>::kernel_dispatch(queue, args); - } else if (cuQKType.k_type == CutlassDType::float8_e4m3) { - return FMHAConfig< - typename chunk_policy::ShapeQK, - typename chunk_policy::ShapePV, - typename chunk_policy::ShapeOut, - typename chunk_policy::SubgroupLayoutQK, - void, - PipelineStages, - Paged, - Causal, - Local, - Sink, - bfloat16_t, - float_e4m3_t, - float_e4m3_t, - bfloat16_t>::kernel_dispatch(queue, args); - } else if (cuQKType.k_type == CutlassDType::float8_e5m2) { - return FMHAConfig< - typename chunk_policy::ShapeQK, - typename chunk_policy::ShapePV, - typename chunk_policy::ShapeOut, - typename chunk_policy::SubgroupLayoutQK, - void, - PipelineStages, - Paged, - Causal, - Local, - Sink, - bfloat16_t, - float_e5m2_t, - float_e5m2_t, - bfloat16_t>::kernel_dispatch(queue, args); + } else if (cuQKOType.o_type == CutlassDType::bfloat16) { + if (cuQKOType.q_type == CutlassDType::bfloat16) { + if (cuQKOType.k_type == CutlassDType::bfloat16) { + return policy_dispatch_typed_impl< + chunk_policy, + bfloat16_t, + bfloat16_t, + bfloat16_t, + Paged, + Causal, + Local, + Sink>(queue, args); + } else if (cuQKOType.k_type == CutlassDType::float8_e4m3) { + return policy_dispatch_typed_impl< + chunk_policy, + bfloat16_t, + float_e4m3_t, + bfloat16_t, + Paged, + Causal, + Local, + Sink>(queue, args); + } else if (cuQKOType.k_type == CutlassDType::float8_e5m2) { + return policy_dispatch_typed_impl< + chunk_policy, + bfloat16_t, + float_e5m2_t, + bfloat16_t, + Paged, + Causal, + Local, + Sink>(queue, args); + } else { + TORCH_CHECK(false, "Unsupported KV dtype for chunk prefill dispatch"); + } } + } else { + TORCH_CHECK(false, "Unsupported output dtype for chunk prefill dispatch"); } } diff --git a/csrc/xpu/attn/xe_2/chunk_prefill_configure.cmake b/csrc/xpu/attn/xe_2/chunk_prefill_configure.cmake index 76c977c45..30983b6c5 100644 --- a/csrc/xpu/attn/xe_2/chunk_prefill_configure.cmake +++ b/csrc/xpu/attn/xe_2/chunk_prefill_configure.cmake @@ -1,6 +1,5 @@ function(fmha_forward_configure FILENAME_SUFFIX) set(GEN_KERNEL_SRCS) # output - set(L_TYPES "fp16" "bf16") set(L_BOOLS "false" "true") set(BOOL_FLAG_false "f") set(BOOL_FLAG_true "t") @@ -8,26 +7,41 @@ function(fmha_forward_configure FILENAME_SUFFIX) "chunk_policy_head64" "chunk_policy_head96" "chunk_policy_head128" "chunk_policy_head192" "chunk_policy_head256" "chunk_policy_head512") - set(IMPL_KV_T "fp16") + # Allowed dtype combinations must match runtime dispatch constraints. Format: + # Q_TYPE|KV_TYPE|O_TYPE|FILE_TAG + set(dtype_combo_list + "half_t|half_t|half_t|h_h_h" + "half_t|float_e4m3_t|half_t|h_e4_h" + "half_t|float_e5m2_t|half_t|h_e5_h" + "bfloat16_t|bfloat16_t|bfloat16_t|b_b_b" + "bfloat16_t|float_e4m3_t|bfloat16_t|b_e4_b" + "bfloat16_t|float_e5m2_t|bfloat16_t|b_e5_b") foreach(IMPL_POLICY ${policy_list}) - # foreach(IMPL_T ${L_TYPES}) - foreach(IMPL_KISPAGED ${L_BOOLS}) - foreach(IMPL_KISCAUSAL ${L_BOOLS}) - foreach(IMPL_KISLOCAL ${L_BOOLS}) - foreach(IMPL_KISSINK ${L_BOOLS}) - set(FILE_SUFFIX "${IMPL_POLICY}_") - set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISPAGED}}") - set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISCAUSAL}}") - set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISSINK}}") - set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISLOCAL}}") - configure_file(${FILENAME_SUFFIX}.cpp.in - "${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp") - list( - APPEND - GEN_KERNEL_SRCS - "${CMAKE_CURRENT_BINARY_DIR}/${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp" - ) + foreach(dtype_combo ${dtype_combo_list}) + string(REPLACE "|" ";" dtype_parts "${dtype_combo}") + list(GET dtype_parts 0 IMPL_Q_T) + list(GET dtype_parts 1 IMPL_KV_T) + list(GET dtype_parts 2 IMPL_O_T) + list(GET dtype_parts 3 DTYPE_TAG) + + foreach(IMPL_KISPAGED ${L_BOOLS}) + foreach(IMPL_KISCAUSAL ${L_BOOLS}) + foreach(IMPL_KISLOCAL ${L_BOOLS}) + foreach(IMPL_KISSINK ${L_BOOLS}) + set(FILE_SUFFIX "${IMPL_POLICY}_${DTYPE_TAG}_") + set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISPAGED}}") + set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISCAUSAL}}") + set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISSINK}}") + set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISLOCAL}}") + configure_file(${FILENAME_SUFFIX}.cpp.in + "${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp") + list( + APPEND + GEN_KERNEL_SRCS + "${CMAKE_CURRENT_BINARY_DIR}/${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp" + ) + endforeach() endforeach() endforeach() endforeach() diff --git a/csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp b/csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp index 7e1876aea..fa9c96585 100644 --- a/csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp +++ b/csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp @@ -22,22 +22,59 @@ // that include chunk_prefill.hpp. Each specialization is explicitly // instantiated in its own .cpp file generated by CMake. -// Helper macro to declare a single extern template instantiation -#define DECLARE_POLICY_DISPATCH_EXTERN(POLICY, PAGED, CAUSAL, LOCAL, SINK) \ - extern template void \ - policy_dispatch_impl( \ - sycl::queue & queue, \ - CutlassQKType & cuQKType, \ - const chunk_prefill_args_t& args); +// Helper macro to declare a single extern template instantiation. +// Template order is: policy, ElementQ, ElementKV, ElementO, bools... +#define DECLARE_POLICY_DISPATCH_EXTERN( \ + POLICY, Q_T, KV_T, O_T, PAGED, CAUSAL, LOCAL, SINK) \ + extern template void policy_dispatch_typed_impl< \ + POLICY, \ + Q_T, \ + KV_T, \ + O_T, \ + PAGED, \ + CAUSAL, \ + LOCAL, \ + SINK>(sycl::queue & queue, const chunk_prefill_args_t& args); + +// Allowed dtype combinations (must match runtime dispatch constraints): +// 1) Q=half -> O=half, KV in {half, fp8_e4m3, fp8_e5m2} +// 2) Q=bf16 -> O=bf16, KV in {bf16, fp8_e4m3, fp8_e5m2} +#define DECLARE_ALLOWED_DTYPES(POLICY, PAGED, CAUSAL, LOCAL, SINK) \ + DECLARE_POLICY_DISPATCH_EXTERN( \ + POLICY, half_t, half_t, half_t, PAGED, CAUSAL, LOCAL, SINK) \ + DECLARE_POLICY_DISPATCH_EXTERN( \ + POLICY, half_t, float_e4m3_t, half_t, PAGED, CAUSAL, LOCAL, SINK) \ + DECLARE_POLICY_DISPATCH_EXTERN( \ + POLICY, half_t, float_e5m2_t, half_t, PAGED, CAUSAL, LOCAL, SINK) \ + DECLARE_POLICY_DISPATCH_EXTERN( \ + POLICY, bfloat16_t, bfloat16_t, bfloat16_t, PAGED, CAUSAL, LOCAL, SINK) \ + DECLARE_POLICY_DISPATCH_EXTERN( \ + POLICY, \ + bfloat16_t, \ + float_e4m3_t, \ + bfloat16_t, \ + PAGED, \ + CAUSAL, \ + LOCAL, \ + SINK) \ + DECLARE_POLICY_DISPATCH_EXTERN( \ + POLICY, \ + bfloat16_t, \ + float_e5m2_t, \ + bfloat16_t, \ + PAGED, \ + CAUSAL, \ + LOCAL, \ + SINK) // Generate all 16 bool combinations for a given policy using nested macros // Pattern: Paged, Causal, Local, Sink (all permutations of 4 bools = 2^4 = 16) // This hierarchical approach makes it easy to extend to more bool parameters // Level 4: Iterate over Sink values (innermost) -#define DECLARE_FOR_SINK(POLICY, PAGED, CAUSAL, LOCAL) \ - DECLARE_POLICY_DISPATCH_EXTERN(POLICY, PAGED, CAUSAL, LOCAL, false) \ - DECLARE_POLICY_DISPATCH_EXTERN(POLICY, PAGED, CAUSAL, LOCAL, true) +#define DECLARE_FOR_SINK(POLICY, PAGED, CAUSAL, LOCAL) \ + DECLARE_ALLOWED_DTYPES(POLICY, PAGED, CAUSAL, LOCAL, false) \ + DECLARE_ALLOWED_DTYPES(POLICY, PAGED, CAUSAL, LOCAL, true) // Level 3: Iterate over Local values #define DECLARE_FOR_LOCAL(POLICY, PAGED, CAUSAL) \ @@ -62,4 +99,5 @@ CHUNK_POLICY_LIST(DECLARE_ALL_BOOL_COMBINATIONS) #undef DECLARE_FOR_CAUSAL #undef DECLARE_FOR_LOCAL #undef DECLARE_FOR_SINK +#undef DECLARE_ALLOWED_DTYPES #undef DECLARE_POLICY_DISPATCH_EXTERN diff --git a/csrc/xpu/attn/xe_2/chunk_prefill_kernel_template.cpp.in b/csrc/xpu/attn/xe_2/chunk_prefill_kernel_template.cpp.in index a9a575c37..950a76fa0 100644 --- a/csrc/xpu/attn/xe_2/chunk_prefill_kernel_template.cpp.in +++ b/csrc/xpu/attn/xe_2/chunk_prefill_kernel_template.cpp.in @@ -4,8 +4,9 @@ using namespace cute; // clang-format off // macros to be filled in CMake -#define IMPL_T ${IMPL_T} +#define IMPL_Q_T ${IMPL_Q_T} #define IMPL_KV_T ${IMPL_KV_T} +#define IMPL_O_T ${IMPL_O_T} #define IMPL_POLICY ${IMPL_POLICY} #cmakedefine01 IMPL_KISPAGED #cmakedefine01 IMPL_KISCAUSAL @@ -13,15 +14,17 @@ using namespace cute; #cmakedefine01 IMPL_KISLOCAL // clang-format on -#define INSTANTIATE_KERNEL() \ - template void policy_dispatch_impl< \ - IMPL_POLICY, \ - static_cast(IMPL_KISPAGED), \ - static_cast(IMPL_KISCAUSAL), \ - static_cast(IMPL_KISLOCAL), \ - static_cast(IMPL_KISSINK)>( \ - sycl::queue & queue, \ - CutlassQKType& cuQKType, \ +#define INSTANTIATE_KERNEL() \ + template void policy_dispatch_typed_impl< \ + IMPL_POLICY, \ + IMPL_Q_T, \ + IMPL_KV_T, \ + IMPL_O_T, \ + static_cast(IMPL_KISPAGED), \ + static_cast(IMPL_KISCAUSAL), \ + static_cast(IMPL_KISLOCAL), \ + static_cast(IMPL_KISSINK)>( \ + sycl::queue & queue, \ const chunk_prefill_args_t& args); INSTANTIATE_KERNEL() diff --git a/csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp b/csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp index 2fb0dd193..d0601d987 100644 --- a/csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp +++ b/csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp @@ -5,24 +5,24 @@ using namespace cute; template void policy_dispatch_func( sycl::queue& queue, - CutlassQKType& cuQKType, + CutlassQKOType& cuQKOType, const chunk_prefill_args_t& args) { - policy_dispatch_impl(queue, cuQKType, args); + policy_dispatch_impl(queue, cuQKOType, args); } template void policy_dispatch_func( sycl::queue& queue, - CutlassQKType& cuQKType, + CutlassQKOType& cuQKOType, const chunk_prefill_args_t& args, bool b, Ts... ts) { if (b) { policy_dispatch_func( - queue, cuQKType, args, ts...); + queue, cuQKOType, args, ts...); } else { policy_dispatch_func( - queue, cuQKType, args, ts...); + queue, cuQKOType, args, ts...); } } diff --git a/csrc/xpu/attn/xe_2/fmha_utils.hpp b/csrc/xpu/attn/xe_2/fmha_utils.hpp index 1533001c9..678ac9f13 100644 --- a/csrc/xpu/attn/xe_2/fmha_utils.hpp +++ b/csrc/xpu/attn/xe_2/fmha_utils.hpp @@ -11,15 +11,16 @@ enum class CutlassDType { half, bfloat16, float8_e4m3, float8_e5m2 }; -// Struct to carry separate Q and K dtypes without breaking existing API -struct CutlassQKType { +// Struct to carry separate Q, K, and O dtypes without breaking existing API +struct CutlassQKOType { CutlassDType q_type; CutlassDType k_type; + CutlassDType o_type; // Convenience: construct with identical types - explicit CutlassQKType(CutlassDType t) : q_type(t), k_type(t) {} - CutlassQKType(CutlassDType q_t, CutlassDType k_t) - : q_type(q_t), k_type(k_t) {} + explicit CutlassQKOType(CutlassDType t) : q_type(t), k_type(t), o_type(t) {} + CutlassQKOType(CutlassDType q_t, CutlassDType k_t, CutlassDType o_t) + : q_type(q_t), k_type(k_t), o_type(o_t) {} }; inline CutlassDType aten_to_dtype(const at::ScalarType st) { @@ -35,17 +36,17 @@ inline CutlassDType aten_to_dtype(const at::ScalarType st) { TORCH_INTERNAL_ASSERT( false, "Unsupported dtype: only half/bfloat16/float8_e4m3/float8_e5m2 supported " - "for Q/K."); + "for Q/K/O."); } inline CutlassDType aten_to_dtype(const at::Tensor& t) { return aten_to_dtype(t.scalar_type()); } -// Helper to build Q/K dtype pair from tensors -inline CutlassQKType -aten_to_Cutlass_qk_dtype(const at::Tensor& q, const at::Tensor& k) { - return CutlassQKType(aten_to_dtype(q), aten_to_dtype(k)); +// Helper to build Q/K/O dtype triplet from tensors +inline CutlassQKOType aten_to_Cutlass_qko_dtype( + const at::Tensor& q, const at::Tensor& k, const at::Tensor& o) { + return CutlassQKOType(aten_to_dtype(q), aten_to_dtype(k), aten_to_dtype(o)); } using namespace cute; diff --git a/csrc/xpu/attn/xe_2/fmha_xe2.cpp b/csrc/xpu/attn/xe_2/fmha_xe2.cpp index b72ea0bbd..4f8e93301 100644 --- a/csrc/xpu/attn/xe_2/fmha_xe2.cpp +++ b/csrc/xpu/attn/xe_2/fmha_xe2.cpp @@ -1,6 +1,7 @@ #include "fmha_xe2.h" #include "chunk_prefill_utils.hpp" #include "chunk_prefill_extern.hpp" +#include "fmha_utils.hpp" using namespace cute; @@ -199,7 +200,7 @@ void cutlass_chunk_prefill_impl( } } - CutlassQKType cuQKType = aten_to_Cutlass_qk_dtype(query, key_cache); + CutlassQKOType cuQKOType = aten_to_Cutlass_qko_dtype(query, key_cache, out); static constexpr int max_head_size = 512; TORCH_CHECK( @@ -209,22 +210,22 @@ void cutlass_chunk_prefill_impl( if (args.head_size <= HEAD_SIZE_LIMIT_0) { policy_dispatch_func( - queue, cuQKType, args, is_paged, is_causal, is_local, is_sink); + queue, cuQKOType, args, is_paged, is_causal, is_local, is_sink); } else if (args.head_size <= HEAD_SIZE_LIMIT_1) { policy_dispatch_func( - queue, cuQKType, args, is_paged, is_causal, is_local, is_sink); + queue, cuQKOType, args, is_paged, is_causal, is_local, is_sink); } else if (args.head_size <= HEAD_SIZE_LIMIT_2) { policy_dispatch_func( - queue, cuQKType, args, is_paged, is_causal, is_local, is_sink); + queue, cuQKOType, args, is_paged, is_causal, is_local, is_sink); } else if (args.head_size <= HEAD_SIZE_LIMIT_3) { policy_dispatch_func( - queue, cuQKType, args, is_paged, is_causal, is_local, is_sink); + queue, cuQKOType, args, is_paged, is_causal, is_local, is_sink); } else if (args.head_size <= HEAD_SIZE_LIMIT_4) { policy_dispatch_func( - queue, cuQKType, args, is_paged, is_causal, is_local, is_sink); + queue, cuQKOType, args, is_paged, is_causal, is_local, is_sink); } else if (args.head_size <= HEAD_SIZE_LIMIT_5) { policy_dispatch_func( - queue, cuQKType, args, is_paged, is_causal, is_local, is_sink); + queue, cuQKOType, args, is_paged, is_causal, is_local, is_sink); } else { TORCH_CHECK(false, "Unsupported head size for fmha"); } diff --git a/csrc/xpu/attn/xe_2/paged_decode.hpp b/csrc/xpu/attn/xe_2/paged_decode.hpp index 219290d17..8734b0b27 100644 --- a/csrc/xpu/attn/xe_2/paged_decode.hpp +++ b/csrc/xpu/attn/xe_2/paged_decode.hpp @@ -490,111 +490,136 @@ struct PagedDecodeConfig { }; // Template function for explicit instantiation +template < + typename decode_policy, + typename ElementQ, + typename ElementKV, + typename ElementO, + bool Causal, + bool Local, + bool Sink> +void decode_policy_dispatch_typed_impl( + sycl::queue& queue, const paged_decode_args_t& args) { + const int PipelineStages = 1; + return PagedDecodeConfig< + typename decode_policy::ShapeQK, + typename decode_policy::ShapePV, + typename decode_policy::ShapeOut, + typename decode_policy::SubgroupLayoutQK, + void, + PipelineStages, + Causal, + Local, + Sink, + ElementQ, + ElementKV, + ElementKV, + ElementO>::kernel_dispatch(queue, args); +} + template void decode_policy_dispatch_impl( sycl::queue& queue, - CutlassQKType& cuQKType, + CutlassQKOType& cuQKOType, const paged_decode_args_t& args) { - const int PipelineStages = 1; - if (cuQKType.q_type == CutlassDType::half) { - if (cuQKType.k_type == CutlassDType::half) { - return PagedDecodeConfig< - typename decode_policy::ShapeQK, - typename decode_policy::ShapePV, - typename decode_policy::ShapeOut, - typename decode_policy::SubgroupLayoutQK, - void, - PipelineStages, - Causal, - Local, - Sink, - half_t, - half_t, - half_t, - half_t>::kernel_dispatch(queue, args); - } else if (cuQKType.k_type == CutlassDType::float8_e4m3) { - return PagedDecodeConfig< - typename decode_policy::ShapeQK, - typename decode_policy::ShapePV, - typename decode_policy::ShapeOut, - typename decode_policy::SubgroupLayoutQK, - void, - PipelineStages, - Causal, - Local, - Sink, - half_t, - float_e4m3_t, - float_e4m3_t, - half_t>::kernel_dispatch(queue, args); - } else if (cuQKType.k_type == CutlassDType::float8_e5m2) { - return PagedDecodeConfig< - typename decode_policy::ShapeQK, - typename decode_policy::ShapePV, - typename decode_policy::ShapeOut, - typename decode_policy::SubgroupLayoutQK, - void, - PipelineStages, - Causal, - Local, - Sink, - half_t, - float_e5m2_t, - float_e5m2_t, - half_t>::kernel_dispatch(queue, args); + if (cuQKOType.o_type == CutlassDType::half) { + if (cuQKOType.q_type == CutlassDType::half) { + if (cuQKOType.k_type == CutlassDType::half) { + return decode_policy_dispatch_typed_impl< + decode_policy, + half_t, + half_t, + half_t, + Causal, + Local, + Sink>(queue, args); + } else if (cuQKOType.k_type == CutlassDType::float8_e4m3) { + return decode_policy_dispatch_typed_impl< + decode_policy, + half_t, + float_e4m3_t, + half_t, + Causal, + Local, + Sink>(queue, args); + } else if (cuQKOType.k_type == CutlassDType::float8_e5m2) { + return decode_policy_dispatch_typed_impl< + decode_policy, + half_t, + float_e5m2_t, + half_t, + Causal, + Local, + Sink>(queue, args); + } else { + TORCH_CHECK( + false, + "Unsupported Q/KV dtype combination for paged_decode kernel: " + "q_type=", + static_cast(cuQKOType.q_type), + " k_type=", + static_cast(cuQKOType.k_type), + " o_type=", + static_cast(cuQKOType.o_type)); + } + } else { + TORCH_CHECK( + false, + "Unsupported Q/KV dtype combination for paged_decode kernel: q_type=", + static_cast(cuQKOType.q_type), + " k_type=", + static_cast(cuQKOType.k_type), + " o_type=", + static_cast(cuQKOType.o_type)); } - } else if (cuQKType.q_type == CutlassDType::bfloat16) { - if (cuQKType.k_type == CutlassDType::bfloat16) { - return PagedDecodeConfig< - typename decode_policy::ShapeQK, - typename decode_policy::ShapePV, - typename decode_policy::ShapeOut, - typename decode_policy::SubgroupLayoutQK, - void, - PipelineStages, - Causal, - Local, - Sink, - bfloat16_t, - bfloat16_t, - bfloat16_t, - bfloat16_t>::kernel_dispatch(queue, args); - } else if (cuQKType.k_type == CutlassDType::float8_e4m3) { - return PagedDecodeConfig< - typename decode_policy::ShapeQK, - typename decode_policy::ShapePV, - typename decode_policy::ShapeOut, - typename decode_policy::SubgroupLayoutQK, - void, - PipelineStages, - Causal, - Local, - Sink, - bfloat16_t, - float_e4m3_t, - float_e4m3_t, - bfloat16_t>::kernel_dispatch(queue, args); - } else if (cuQKType.k_type == CutlassDType::float8_e5m2) { - return PagedDecodeConfig< - typename decode_policy::ShapeQK, - typename decode_policy::ShapePV, - typename decode_policy::ShapeOut, - typename decode_policy::SubgroupLayoutQK, - void, - PipelineStages, - Causal, - Local, - Sink, - bfloat16_t, - float_e5m2_t, - float_e5m2_t, - bfloat16_t>::kernel_dispatch(queue, args); + } else if (cuQKOType.o_type == CutlassDType::bfloat16) { + if (cuQKOType.q_type == CutlassDType::bfloat16) { + if (cuQKOType.k_type == CutlassDType::bfloat16) { + return decode_policy_dispatch_typed_impl< + decode_policy, + bfloat16_t, + bfloat16_t, + bfloat16_t, + Causal, + Local, + Sink>(queue, args); + } else if (cuQKOType.k_type == CutlassDType::float8_e4m3) { + return decode_policy_dispatch_typed_impl< + decode_policy, + bfloat16_t, + float_e4m3_t, + bfloat16_t, + Causal, + Local, + Sink>(queue, args); + } else if (cuQKOType.k_type == CutlassDType::float8_e5m2) { + return decode_policy_dispatch_typed_impl< + decode_policy, + bfloat16_t, + float_e5m2_t, + bfloat16_t, + Causal, + Local, + Sink>(queue, args); + } + } else { + TORCH_CHECK( + false, + "Unsupported Q/KV dtype combination for paged_decode kernel: q_type=", + static_cast(cuQKOType.q_type), + " k_type=", + static_cast(cuQKOType.k_type), + " o_type=", + static_cast(cuQKOType.o_type)); } + } else { + TORCH_CHECK( + false, + "Unsupported Q/KV dtype combination for paged_decode kernel: q_type=", + static_cast(cuQKOType.q_type), + " k_type=", + static_cast(cuQKOType.k_type), + " o_type=", + static_cast(cuQKOType.o_type)); } - TORCH_CHECK( - false, - "Unsupported Q/KV dtype combination for paged_decode kernel: q_type=", - static_cast(cuQKType.q_type), - " k_type=", - static_cast(cuQKType.k_type)); } diff --git a/csrc/xpu/attn/xe_2/paged_decode_configure.cmake b/csrc/xpu/attn/xe_2/paged_decode_configure.cmake index eee965bff..acf3c9fcb 100644 --- a/csrc/xpu/attn/xe_2/paged_decode_configure.cmake +++ b/csrc/xpu/attn/xe_2/paged_decode_configure.cmake @@ -66,6 +66,16 @@ function(paged_decode_configure FILENAME_SUFFIX) set(headsize_list "64" "96" "128" "192" "256" "512") set(pagesize_list "64" "128") + # Allowed dtype combinations must match runtime dispatch constraints. Format: + # Q_TYPE|KV_TYPE|O_TYPE|FILE_TAG + set(dtype_combo_list + "half_t|half_t|half_t|h_h_h" + "half_t|float_e4m3_t|half_t|h_e4_h" + "half_t|float_e5m2_t|half_t|h_e5_h" + "bfloat16_t|bfloat16_t|bfloat16_t|b_b_b" + "bfloat16_t|float_e4m3_t|bfloat16_t|b_e4_b" + "bfloat16_t|float_e5m2_t|bfloat16_t|b_e5_b") + # ============================================================================= # Generate Kernel Sources # ============================================================================= @@ -78,26 +88,35 @@ function(paged_decode_configure FILENAME_SUFFIX) set(IMPL_POLICY ${policy_${IMPL_QGROUP}_${IMPL_HEADSIZE}_${IMPL_PAGESIZE}}) - foreach(IMPL_KISCAUSAL ${L_BOOLS}) - foreach(IMPL_KISLOCAL ${L_BOOLS}) - foreach(IMPL_KISSINK ${L_BOOLS}) - # Construct unique filename suffix: e.g., _q8_h64_fff - set(FILE_SUFFIX - "_q${IMPL_QGROUP}_h${IMPL_HEADSIZE}_p${IMPL_PAGESIZE}_") - set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISCAUSAL}}") - set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISLOCAL}}") - set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISSINK}}") - - # Generate .cpp file from template - configure_file(${FILENAME_SUFFIX}.cpp.in - "${FILENAME_SUFFIX}${FILE_SUFFIX}.cpp") - - # Add to output list - list( - APPEND - GEN_KERNEL_SRCS - "${CMAKE_CURRENT_BINARY_DIR}/${FILENAME_SUFFIX}${FILE_SUFFIX}.cpp" - ) + foreach(dtype_combo ${dtype_combo_list}) + string(REPLACE "|" ";" dtype_parts "${dtype_combo}") + list(GET dtype_parts 0 IMPL_Q_T) + list(GET dtype_parts 1 IMPL_KV_T) + list(GET dtype_parts 2 IMPL_O_T) + list(GET dtype_parts 3 DTYPE_TAG) + + foreach(IMPL_KISCAUSAL ${L_BOOLS}) + foreach(IMPL_KISLOCAL ${L_BOOLS}) + foreach(IMPL_KISSINK ${L_BOOLS}) + # Construct unique filename suffix: e.g., _q8_h64_fff + set(FILE_SUFFIX + "_q${IMPL_QGROUP}_h${IMPL_HEADSIZE}_p${IMPL_PAGESIZE}_") + set(FILE_SUFFIX "${FILE_SUFFIX}${DTYPE_TAG}_") + set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISCAUSAL}}") + set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISLOCAL}}") + set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISSINK}}") + + # Generate .cpp file from template + configure_file(${FILENAME_SUFFIX}.cpp.in + "${FILENAME_SUFFIX}${FILE_SUFFIX}.cpp") + + # Add to output list + list( + APPEND + GEN_KERNEL_SRCS + "${CMAKE_CURRENT_BINARY_DIR}/${FILENAME_SUFFIX}${FILE_SUFFIX}.cpp" + ) + endforeach() endforeach() endforeach() endforeach() diff --git a/csrc/xpu/attn/xe_2/paged_decode_extern.hpp b/csrc/xpu/attn/xe_2/paged_decode_extern.hpp index 8a9e1b7d2..31941a289 100644 --- a/csrc/xpu/attn/xe_2/paged_decode_extern.hpp +++ b/csrc/xpu/attn/xe_2/paged_decode_extern.hpp @@ -42,21 +42,42 @@ // instantiated in its own .cpp file generated by CMake. // Helper macro to declare a single extern template instantiation -#define DECLARE_DECODE_DISPATCH_EXTERN(POLICY, CAUSAL, LOCAL, SINK) \ - extern template void \ - decode_policy_dispatch_impl( \ - sycl::queue & queue, \ - CutlassQKType & cuQKType, \ - const paged_decode_args_t& args); +#define DECLARE_DECODE_DISPATCH_EXTERN( \ + POLICY, Q_T, KV_T, O_T, CAUSAL, LOCAL, SINK) \ + extern template void decode_policy_dispatch_typed_impl< \ + POLICY, \ + Q_T, \ + KV_T, \ + O_T, \ + CAUSAL, \ + LOCAL, \ + SINK>(sycl::queue & queue, const paged_decode_args_t& args); + +// Allowed dtype combinations (must match runtime dispatch constraints): +// 1) Q=half -> O=half, KV in {half, fp8_e4m3, fp8_e5m2} +// 2) Q=bf16 -> O=bf16, KV in {bf16, fp8_e4m3, fp8_e5m2} +#define DECLARE_ALLOWED_DTYPES(POLICY, CAUSAL, LOCAL, SINK) \ + DECLARE_DECODE_DISPATCH_EXTERN( \ + POLICY, half_t, half_t, half_t, CAUSAL, LOCAL, SINK) \ + DECLARE_DECODE_DISPATCH_EXTERN( \ + POLICY, half_t, float_e4m3_t, half_t, CAUSAL, LOCAL, SINK) \ + DECLARE_DECODE_DISPATCH_EXTERN( \ + POLICY, half_t, float_e5m2_t, half_t, CAUSAL, LOCAL, SINK) \ + DECLARE_DECODE_DISPATCH_EXTERN( \ + POLICY, bfloat16_t, bfloat16_t, bfloat16_t, CAUSAL, LOCAL, SINK) \ + DECLARE_DECODE_DISPATCH_EXTERN( \ + POLICY, bfloat16_t, float_e4m3_t, bfloat16_t, CAUSAL, LOCAL, SINK) \ + DECLARE_DECODE_DISPATCH_EXTERN( \ + POLICY, bfloat16_t, float_e5m2_t, bfloat16_t, CAUSAL, LOCAL, SINK) // Generate all 8 bool combinations for a given policy using nested macros // Pattern: Causal, Local, Sink (all permutations of 3 bools = 2^3 = 8) // This hierarchical approach makes it easy to extend to more bool parameters // Level 3: Iterate over Sink values (innermost) -#define DECLARE_FOR_SINK(POLICY, CAUSAL, LOCAL) \ - DECLARE_DECODE_DISPATCH_EXTERN(POLICY, CAUSAL, LOCAL, false) \ - DECLARE_DECODE_DISPATCH_EXTERN(POLICY, CAUSAL, LOCAL, true) +#define DECLARE_FOR_SINK(POLICY, CAUSAL, LOCAL) \ + DECLARE_ALLOWED_DTYPES(POLICY, CAUSAL, LOCAL, false) \ + DECLARE_ALLOWED_DTYPES(POLICY, CAUSAL, LOCAL, true) // Level 2: Iterate over Local values #define DECLARE_FOR_LOCAL(POLICY, CAUSAL) \ @@ -75,4 +96,5 @@ PAGED_DECODE_POLICY_LIST(DECLARE_ALL_BOOL_COMBINATIONS) #undef DECLARE_ALL_BOOL_COMBINATIONS #undef DECLARE_FOR_LOCAL #undef DECLARE_FOR_SINK +#undef DECLARE_ALLOWED_DTYPES #undef DECLARE_DECODE_DISPATCH_EXTERN diff --git a/csrc/xpu/attn/xe_2/paged_decode_kernel_template.cpp.in b/csrc/xpu/attn/xe_2/paged_decode_kernel_template.cpp.in index 1487fa976..b5e659fa0 100644 --- a/csrc/xpu/attn/xe_2/paged_decode_kernel_template.cpp.in +++ b/csrc/xpu/attn/xe_2/paged_decode_kernel_template.cpp.in @@ -9,6 +9,9 @@ using namespace cute; // to generate specific kernel instantiations for each policy combination. // clang-format off +#define IMPL_Q_T ${IMPL_Q_T} +#define IMPL_KV_T ${IMPL_KV_T} +#define IMPL_O_T ${IMPL_O_T} #define IMPL_POLICY ${IMPL_POLICY} #cmakedefine01 IMPL_KISCAUSAL #cmakedefine01 IMPL_KISLOCAL @@ -18,18 +21,20 @@ using namespace cute; // ============================================================================= // Explicit Template Instantiation // ============================================================================= -// Instantiate the decode_policy_dispatch_impl function template with the +// Instantiate the decode_policy_dispatch_typed_impl function template with the // specific policy and boolean flag combinations provided by CMake. This // produces one compiled kernel per source file. -#define INSTANTIATE_KERNEL() \ - template void decode_policy_dispatch_impl< \ - IMPL_POLICY, \ - static_cast(IMPL_KISCAUSAL), \ - static_cast(IMPL_KISLOCAL), \ - static_cast(IMPL_KISSINK)>( \ - sycl::queue & queue, \ - CutlassQKType & cuQKType, \ +#define INSTANTIATE_KERNEL() \ + template void decode_policy_dispatch_typed_impl< \ + IMPL_POLICY, \ + IMPL_Q_T, \ + IMPL_KV_T, \ + IMPL_O_T, \ + static_cast(IMPL_KISCAUSAL), \ + static_cast(IMPL_KISLOCAL), \ + static_cast(IMPL_KISSINK)>( \ + sycl::queue & queue, \ const paged_decode_args_t& args); INSTANTIATE_KERNEL() diff --git a/csrc/xpu/attn/xe_2/paged_decode_utils.hpp b/csrc/xpu/attn/xe_2/paged_decode_utils.hpp index 88cc8c930..dae7c9eeb 100644 --- a/csrc/xpu/attn/xe_2/paged_decode_utils.hpp +++ b/csrc/xpu/attn/xe_2/paged_decode_utils.hpp @@ -6,24 +6,24 @@ using namespace cute; template void decode_policy_dispatch_func( sycl::queue& queue, - CutlassQKType& cuQKType, + CutlassQKOType& cuQKOType, const paged_decode_args_t& args) { - decode_policy_dispatch_impl(queue, cuQKType, args); + decode_policy_dispatch_impl(queue, cuQKOType, args); } template void decode_policy_dispatch_func( sycl::queue& queue, - CutlassQKType& cuQKType, + CutlassQKOType& cuQKOType, const paged_decode_args_t& args, bool b, Ts... ts) { if (b) { decode_policy_dispatch_func( - queue, cuQKType, args, ts...); + queue, cuQKOType, args, ts...); } else { decode_policy_dispatch_func( - queue, cuQKType, args, ts...); + queue, cuQKOType, args, ts...); } } @@ -31,38 +31,38 @@ template inline void dispatch_by_head_size( const int head_case, sycl::queue& queue, - CutlassQKType& cuQKType, + CutlassQKOType& cuQKOType, const paged_decode_args_t& args) { switch (head_case) { case 0: decode_policy_dispatch_func< decode_policy_qpacked_head>( - queue, cuQKType, args, args.is_causal, args.is_local, args.is_sink); + queue, cuQKOType, args, args.is_causal, args.is_local, args.is_sink); break; case 1: decode_policy_dispatch_func< decode_policy_qpacked_head>( - queue, cuQKType, args, args.is_causal, args.is_local, args.is_sink); + queue, cuQKOType, args, args.is_causal, args.is_local, args.is_sink); break; case 2: decode_policy_dispatch_func< decode_policy_qpacked_head>( - queue, cuQKType, args, args.is_causal, args.is_local, args.is_sink); + queue, cuQKOType, args, args.is_causal, args.is_local, args.is_sink); break; case 3: decode_policy_dispatch_func< decode_policy_qpacked_head>( - queue, cuQKType, args, args.is_causal, args.is_local, args.is_sink); + queue, cuQKOType, args, args.is_causal, args.is_local, args.is_sink); break; case 4: decode_policy_dispatch_func< decode_policy_qpacked_head>( - queue, cuQKType, args, args.is_causal, args.is_local, args.is_sink); + queue, cuQKOType, args, args.is_causal, args.is_local, args.is_sink); break; case 5: decode_policy_dispatch_func< decode_policy_qpacked_head>( - queue, cuQKType, args, args.is_causal, args.is_local, args.is_sink); + queue, cuQKOType, args, args.is_causal, args.is_local, args.is_sink); break; default: TORCH_CHECK(false, "Unsupported head size for fmha"); @@ -74,14 +74,14 @@ inline void dispatch_by_page_size( const int page_size, const int head_case, sycl::queue& queue, - CutlassQKType& cuQKType, + CutlassQKOType& cuQKOType, const paged_decode_args_t& args) { switch (page_size) { case 64: - dispatch_by_head_size(head_case, queue, cuQKType, args); + dispatch_by_head_size(head_case, queue, cuQKOType, args); break; case 128: - dispatch_by_head_size(head_case, queue, cuQKType, args); + dispatch_by_head_size(head_case, queue, cuQKOType, args); break; default: TORCH_CHECK(false, "Unsupported page size for fmha"); diff --git a/csrc/xpu/attn/xe_2/paged_decode_xe2.cpp b/csrc/xpu/attn/xe_2/paged_decode_xe2.cpp index c3020e656..e53b4fb85 100644 --- a/csrc/xpu/attn/xe_2/paged_decode_xe2.cpp +++ b/csrc/xpu/attn/xe_2/paged_decode_xe2.cpp @@ -208,7 +208,7 @@ void cutlass_paged_decode_impl( "(head_dim), got stride=", value_cache.stride(-1)); - CutlassQKType cuQKType = aten_to_Cutlass_qk_dtype(query, key_cache); + CutlassQKOType cuQKOType = aten_to_Cutlass_qko_dtype(query, key_cache, out); static constexpr int max_head_size = 512; TORCH_CHECK( @@ -230,9 +230,9 @@ void cutlass_paged_decode_impl( int num_q_group_size = num_heads_q / num_heads_kv; if (num_q_group_size <= 8) { - dispatch_by_page_size<_8>(block_size, head_case, queue, cuQKType, args); + dispatch_by_page_size<_8>(block_size, head_case, queue, cuQKOType, args); } else if (num_q_group_size <= 16) { - dispatch_by_page_size<_16>(block_size, head_case, queue, cuQKType, args); + dispatch_by_page_size<_16>(block_size, head_case, queue, cuQKOType, args); } else { TORCH_CHECK(false, "Unsupported num_heads_q / num_heads_kv for fmha"); }