Skip to content

Commit 3dca5b1

Browse files
committed
split attention template via data types
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
1 parent 7711ef0 commit 3dca5b1

13 files changed

Lines changed: 447 additions & 317 deletions

csrc/xpu/attn/xe_2/chunk_prefill.hpp

Lines changed: 102 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -338,111 +338,113 @@ struct FMHAConfig {
338338
}
339339
};
340340

341+
template <
342+
typename chunk_policy,
343+
typename ElementQ,
344+
typename ElementKV,
345+
typename ElementO,
346+
bool Paged,
347+
bool Causal,
348+
bool Local,
349+
bool Sink>
350+
void policy_dispatch_typed_impl(
351+
sycl::queue& queue, const chunk_prefill_args_t& args) {
352+
const int PipelineStages = 2;
353+
return FMHAConfig<
354+
typename chunk_policy::ShapeQK,
355+
typename chunk_policy::ShapePV,
356+
typename chunk_policy::ShapeOut,
357+
typename chunk_policy::SubgroupLayoutQK,
358+
void,
359+
PipelineStages,
360+
Paged,
361+
Causal,
362+
Local,
363+
Sink,
364+
ElementQ,
365+
ElementKV,
366+
ElementKV,
367+
ElementO>::kernel_dispatch(queue, args);
368+
}
369+
341370
template <typename chunk_policy, bool Paged, bool Causal, bool Local, bool Sink>
342371
void policy_dispatch_impl(
343372
sycl::queue& queue,
344-
CutlassQKType& cuQKType,
373+
CutlassQKOType& cuQKOType,
345374
const chunk_prefill_args_t& args) {
346-
const int PipelineStages = 2;
347-
if (cuQKType.q_type == CutlassDType::half) {
348-
if (cuQKType.k_type == CutlassDType::half) {
349-
return FMHAConfig<
350-
typename chunk_policy::ShapeQK,
351-
typename chunk_policy::ShapePV,
352-
typename chunk_policy::ShapeOut,
353-
typename chunk_policy::SubgroupLayoutQK,
354-
void,
355-
PipelineStages,
356-
Paged,
357-
Causal,
358-
Local,
359-
Sink,
360-
half_t,
361-
half_t,
362-
half_t,
363-
half_t>::kernel_dispatch(queue, args);
364-
} else if (cuQKType.k_type == CutlassDType::float8_e4m3) {
365-
return FMHAConfig<
366-
typename chunk_policy::ShapeQK,
367-
typename chunk_policy::ShapePV,
368-
typename chunk_policy::ShapeOut,
369-
typename chunk_policy::SubgroupLayoutQK,
370-
void,
371-
PipelineStages,
372-
Paged,
373-
Causal,
374-
Local,
375-
Sink,
376-
half_t,
377-
float_e4m3_t,
378-
float_e4m3_t,
379-
half_t>::kernel_dispatch(queue, args);
380-
} else if (cuQKType.k_type == CutlassDType::float8_e5m2) {
381-
return FMHAConfig<
382-
typename chunk_policy::ShapeQK,
383-
typename chunk_policy::ShapePV,
384-
typename chunk_policy::ShapeOut,
385-
typename chunk_policy::SubgroupLayoutQK,
386-
void,
387-
PipelineStages,
388-
Paged,
389-
Causal,
390-
Local,
391-
Sink,
392-
half_t,
393-
float_e5m2_t,
394-
float_e5m2_t,
395-
half_t>::kernel_dispatch(queue, args);
375+
if (cuQKOType.o_type == CutlassDType::half) {
376+
if (cuQKOType.q_type == CutlassDType::half) {
377+
if (cuQKOType.k_type == CutlassDType::half) {
378+
return policy_dispatch_typed_impl<
379+
chunk_policy,
380+
half_t,
381+
half_t,
382+
half_t,
383+
Paged,
384+
Causal,
385+
Local,
386+
Sink>(queue, args);
387+
} else if (cuQKOType.k_type == CutlassDType::float8_e4m3) {
388+
return policy_dispatch_typed_impl<
389+
chunk_policy,
390+
half_t,
391+
float_e4m3_t,
392+
half_t,
393+
Paged,
394+
Causal,
395+
Local,
396+
Sink>(queue, args);
397+
} else if (cuQKOType.k_type == CutlassDType::float8_e5m2) {
398+
return policy_dispatch_typed_impl<
399+
chunk_policy,
400+
half_t,
401+
float_e5m2_t,
402+
half_t,
403+
Paged,
404+
Causal,
405+
Local,
406+
Sink>(queue, args);
407+
} else {
408+
TORCH_CHECK(false, "Unsupported KV dtype for chunk prefill dispatch");
409+
}
396410
}
397-
} else {
398-
if (cuQKType.k_type == CutlassDType::bfloat16) {
399-
return FMHAConfig<
400-
typename chunk_policy::ShapeQK,
401-
typename chunk_policy::ShapePV,
402-
typename chunk_policy::ShapeOut,
403-
typename chunk_policy::SubgroupLayoutQK,
404-
void,
405-
PipelineStages,
406-
Paged,
407-
Causal,
408-
Local,
409-
Sink,
410-
bfloat16_t,
411-
bfloat16_t,
412-
bfloat16_t,
413-
bfloat16_t>::kernel_dispatch(queue, args);
414-
} else if (cuQKType.k_type == CutlassDType::float8_e4m3) {
415-
return FMHAConfig<
416-
typename chunk_policy::ShapeQK,
417-
typename chunk_policy::ShapePV,
418-
typename chunk_policy::ShapeOut,
419-
typename chunk_policy::SubgroupLayoutQK,
420-
void,
421-
PipelineStages,
422-
Paged,
423-
Causal,
424-
Local,
425-
Sink,
426-
bfloat16_t,
427-
float_e4m3_t,
428-
float_e4m3_t,
429-
bfloat16_t>::kernel_dispatch(queue, args);
430-
} else if (cuQKType.k_type == CutlassDType::float8_e5m2) {
431-
return FMHAConfig<
432-
typename chunk_policy::ShapeQK,
433-
typename chunk_policy::ShapePV,
434-
typename chunk_policy::ShapeOut,
435-
typename chunk_policy::SubgroupLayoutQK,
436-
void,
437-
PipelineStages,
438-
Paged,
439-
Causal,
440-
Local,
441-
Sink,
442-
bfloat16_t,
443-
float_e5m2_t,
444-
float_e5m2_t,
445-
bfloat16_t>::kernel_dispatch(queue, args);
411+
} else if (cuQKOType.o_type == CutlassDType::bfloat16) {
412+
if (cuQKOType.q_type == CutlassDType::bfloat16) {
413+
if (cuQKOType.k_type == CutlassDType::bfloat16) {
414+
return policy_dispatch_typed_impl<
415+
chunk_policy,
416+
bfloat16_t,
417+
bfloat16_t,
418+
bfloat16_t,
419+
Paged,
420+
Causal,
421+
Local,
422+
Sink>(queue, args);
423+
} else if (cuQKOType.k_type == CutlassDType::float8_e4m3) {
424+
return policy_dispatch_typed_impl<
425+
chunk_policy,
426+
bfloat16_t,
427+
float_e4m3_t,
428+
bfloat16_t,
429+
Paged,
430+
Causal,
431+
Local,
432+
Sink>(queue, args);
433+
} else if (cuQKOType.k_type == CutlassDType::float8_e5m2) {
434+
return policy_dispatch_typed_impl<
435+
chunk_policy,
436+
bfloat16_t,
437+
float_e5m2_t,
438+
bfloat16_t,
439+
Paged,
440+
Causal,
441+
Local,
442+
Sink>(queue, args);
443+
} else {
444+
TORCH_CHECK(false, "Unsupported KV dtype for chunk prefill dispatch");
445+
}
446446
}
447+
} else {
448+
TORCH_CHECK(false, "Unsupported output dtype for chunk prefill dispatch");
447449
}
448450
}

csrc/xpu/attn/xe_2/chunk_prefill_configure.cmake

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,47 @@
11
function(fmha_forward_configure FILENAME_SUFFIX)
22
set(GEN_KERNEL_SRCS) # output
3-
set(L_TYPES "fp16" "bf16")
43
set(L_BOOLS "false" "true")
54
set(BOOL_FLAG_false "f")
65
set(BOOL_FLAG_true "t")
76
set(policy_list
87
"chunk_policy_head64" "chunk_policy_head96" "chunk_policy_head128"
98
"chunk_policy_head192" "chunk_policy_head256" "chunk_policy_head512")
109

11-
set(IMPL_KV_T "fp16")
10+
# Allowed dtype combinations must match the dtype branches of policy_dispatch_impl (chunk_prefill.hpp) and DECLARE_ALLOWED_DTYPES (chunk_prefill_extern.hpp). Format:
11+
# Q_TYPE|KV_TYPE|O_TYPE|FILE_TAG
12+
set(dtype_combo_list
13+
"half_t|half_t|half_t|h_h_h"
14+
"half_t|float_e4m3_t|half_t|h_e4_h"
15+
"half_t|float_e5m2_t|half_t|h_e5_h"
16+
"bfloat16_t|bfloat16_t|bfloat16_t|b_b_b"
17+
"bfloat16_t|float_e4m3_t|bfloat16_t|b_e4_b"
18+
"bfloat16_t|float_e5m2_t|bfloat16_t|b_e5_b")
1219

1320
foreach(IMPL_POLICY ${policy_list})
14-
# foreach(IMPL_T ${L_TYPES})
15-
foreach(IMPL_KISPAGED ${L_BOOLS})
16-
foreach(IMPL_KISCAUSAL ${L_BOOLS})
17-
foreach(IMPL_KISLOCAL ${L_BOOLS})
18-
foreach(IMPL_KISSINK ${L_BOOLS})
19-
set(FILE_SUFFIX "${IMPL_POLICY}_")
20-
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISPAGED}}")
21-
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISCAUSAL}}")
22-
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISSINK}}")
23-
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISLOCAL}}")
24-
configure_file(${FILENAME_SUFFIX}.cpp.in
25-
"${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp")
26-
list(
27-
APPEND
28-
GEN_KERNEL_SRCS
29-
"${CMAKE_CURRENT_BINARY_DIR}/${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp"
30-
)
21+
foreach(dtype_combo ${dtype_combo_list})
22+
string(REPLACE "|" ";" dtype_parts "${dtype_combo}")
23+
list(GET dtype_parts 0 IMPL_Q_T)
24+
list(GET dtype_parts 1 IMPL_KV_T)
25+
list(GET dtype_parts 2 IMPL_O_T)
26+
list(GET dtype_parts 3 DTYPE_TAG)
27+
28+
foreach(IMPL_KISPAGED ${L_BOOLS})
29+
foreach(IMPL_KISCAUSAL ${L_BOOLS})
30+
foreach(IMPL_KISLOCAL ${L_BOOLS})
31+
foreach(IMPL_KISSINK ${L_BOOLS})
32+
set(FILE_SUFFIX "${IMPL_POLICY}_${DTYPE_TAG}_")
33+
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISPAGED}}")
34+
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISCAUSAL}}")
35+
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISSINK}}")
36+
set(FILE_SUFFIX "${FILE_SUFFIX}${BOOL_FLAG_${IMPL_KISLOCAL}}")
37+
configure_file(${FILENAME_SUFFIX}.cpp.in
38+
"${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp")
39+
list(
40+
APPEND
41+
GEN_KERNEL_SRCS
42+
"${CMAKE_CURRENT_BINARY_DIR}/${FILENAME_SUFFIX}_${FILE_SUFFIX}.cpp"
43+
)
44+
endforeach()
3145
endforeach()
3246
endforeach()
3347
endforeach()

csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,59 @@
2222
// that include chunk_prefill.hpp. Each specialization is explicitly
2323
// instantiated in its own .cpp file generated by CMake.
2424

25-
// Helper macro to declare a single extern template instantiation
26-
#define DECLARE_POLICY_DISPATCH_EXTERN(POLICY, PAGED, CAUSAL, LOCAL, SINK) \
27-
extern template void \
28-
policy_dispatch_impl<POLICY, PAGED, CAUSAL, LOCAL, SINK>( \
29-
sycl::queue & queue, \
30-
CutlassQKType & cuQKType, \
31-
const chunk_prefill_args_t& args);
25+
// Helper macro to declare a single extern template instantiation.
26+
// Template order is: policy, ElementQ, ElementKV, ElementO, bools...
27+
#define DECLARE_POLICY_DISPATCH_EXTERN( \
28+
POLICY, Q_T, KV_T, O_T, PAGED, CAUSAL, LOCAL, SINK) \
29+
extern template void policy_dispatch_typed_impl< \
30+
POLICY, \
31+
Q_T, \
32+
KV_T, \
33+
O_T, \
34+
PAGED, \
35+
CAUSAL, \
36+
LOCAL, \
37+
SINK>(sycl::queue & queue, const chunk_prefill_args_t& args);
38+
39+
// Allowed dtype combinations (must match the dtype branches of policy_dispatch_impl in chunk_prefill.hpp and dtype_combo_list in chunk_prefill_configure.cmake):
40+
// 1) Q=half -> O=half, KV in {half, fp8_e4m3, fp8_e5m2}
41+
// 2) Q=bf16 -> O=bf16, KV in {bf16, fp8_e4m3, fp8_e5m2}
42+
#define DECLARE_ALLOWED_DTYPES(POLICY, PAGED, CAUSAL, LOCAL, SINK) \
43+
DECLARE_POLICY_DISPATCH_EXTERN( \
44+
POLICY, half_t, half_t, half_t, PAGED, CAUSAL, LOCAL, SINK) \
45+
DECLARE_POLICY_DISPATCH_EXTERN( \
46+
POLICY, half_t, float_e4m3_t, half_t, PAGED, CAUSAL, LOCAL, SINK) \
47+
DECLARE_POLICY_DISPATCH_EXTERN( \
48+
POLICY, half_t, float_e5m2_t, half_t, PAGED, CAUSAL, LOCAL, SINK) \
49+
DECLARE_POLICY_DISPATCH_EXTERN( \
50+
POLICY, bfloat16_t, bfloat16_t, bfloat16_t, PAGED, CAUSAL, LOCAL, SINK) \
51+
DECLARE_POLICY_DISPATCH_EXTERN( \
52+
POLICY, \
53+
bfloat16_t, \
54+
float_e4m3_t, \
55+
bfloat16_t, \
56+
PAGED, \
57+
CAUSAL, \
58+
LOCAL, \
59+
SINK) \
60+
DECLARE_POLICY_DISPATCH_EXTERN( \
61+
POLICY, \
62+
bfloat16_t, \
63+
float_e5m2_t, \
64+
bfloat16_t, \
65+
PAGED, \
66+
CAUSAL, \
67+
LOCAL, \
68+
SINK)
3269

3370
// Generate all 16 bool combinations for a given policy using nested macros; each combination expands to the six allowed dtype triples
3471
// Pattern: Paged, Causal, Local, Sink (all combinations of 4 bools = 2^4 = 16)
3572
// This hierarchical approach makes it easy to extend to more bool parameters
3673

3774
// Level 4: Iterate over Sink values (innermost bool level); each value expands to every allowed dtype combination
38-
#define DECLARE_FOR_SINK(POLICY, PAGED, CAUSAL, LOCAL) \
39-
DECLARE_POLICY_DISPATCH_EXTERN(POLICY, PAGED, CAUSAL, LOCAL, false) \
40-
DECLARE_POLICY_DISPATCH_EXTERN(POLICY, PAGED, CAUSAL, LOCAL, true)
75+
#define DECLARE_FOR_SINK(POLICY, PAGED, CAUSAL, LOCAL) \
76+
DECLARE_ALLOWED_DTYPES(POLICY, PAGED, CAUSAL, LOCAL, false) \
77+
DECLARE_ALLOWED_DTYPES(POLICY, PAGED, CAUSAL, LOCAL, true)
4178

4279
// Level 3: Iterate over Local values
4380
#define DECLARE_FOR_LOCAL(POLICY, PAGED, CAUSAL) \
@@ -62,4 +99,5 @@ CHUNK_POLICY_LIST(DECLARE_ALL_BOOL_COMBINATIONS)
6299
#undef DECLARE_FOR_CAUSAL
63100
#undef DECLARE_FOR_LOCAL
64101
#undef DECLARE_FOR_SINK
102+
#undef DECLARE_ALLOWED_DTYPES
65103
#undef DECLARE_POLICY_DISPATCH_EXTERN

csrc/xpu/attn/xe_2/chunk_prefill_kernel_template.cpp.in

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,27 @@ using namespace cute;
44

55
// clang-format off
66
// macros to be filled in CMake
7-
#define IMPL_T ${IMPL_T}
7+
#define IMPL_Q_T ${IMPL_Q_T}
88
#define IMPL_KV_T ${IMPL_KV_T}
9+
#define IMPL_O_T ${IMPL_O_T}
910
#define IMPL_POLICY ${IMPL_POLICY}
1011
#cmakedefine01 IMPL_KISPAGED
1112
#cmakedefine01 IMPL_KISCAUSAL
1213
#cmakedefine01 IMPL_KISSINK
1314
#cmakedefine01 IMPL_KISLOCAL
1415
// clang-format on
1516

16-
#define INSTANTIATE_KERNEL() \
17-
template void policy_dispatch_impl< \
18-
IMPL_POLICY, \
19-
static_cast<bool>(IMPL_KISPAGED), \
20-
static_cast<bool>(IMPL_KISCAUSAL), \
21-
static_cast<bool>(IMPL_KISLOCAL), \
22-
static_cast<bool>(IMPL_KISSINK)>( \
23-
sycl::queue & queue, \
24-
CutlassQKType& cuQKType, \
17+
#define INSTANTIATE_KERNEL() \
18+
template void policy_dispatch_typed_impl< \
19+
IMPL_POLICY, \
20+
IMPL_Q_T, \
21+
IMPL_KV_T, \
22+
IMPL_O_T, \
23+
static_cast<bool>(IMPL_KISPAGED), \
24+
static_cast<bool>(IMPL_KISCAUSAL), \
25+
static_cast<bool>(IMPL_KISLOCAL), \
26+
static_cast<bool>(IMPL_KISSINK)>( \
27+
sycl::queue & queue, \
2528
const chunk_prefill_args_t& args);
2629

2730
INSTANTIATE_KERNEL()

0 commit comments

Comments
 (0)