Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
57a47b0
Let fmha_fwd_v3() compatible with fmha_fwd()
poyenc Nov 3, 2025
b93a0ad
Decouple get_fwd_blobs() and FmhaFwdKernel
poyenc Nov 3, 2025
6e46366
Decouple compatibility checks from get_fwd_blobs()
poyenc Nov 4, 2025
756a1b8
Extract product feature checks out from get_fwd_blobs()
poyenc Nov 4, 2025
4c5a68e
Remove duplicated code in factories and redundant checks
poyenc Nov 4, 2025
41cd25b
Remove FmhaFwdKernel<>::GetName()
poyenc Nov 5, 2025
3e0ad2c
Let FmhaFwdApiPool support pipelines with different mask_impl
poyenc Nov 5, 2025
4e6153b
Add tile setting for fmha fwd v3 pipeline
poyenc Nov 5, 2025
6eaa880
Add fwd v3 instances to tile_example_fmha_fwd manually
poyenc Nov 5, 2025
d6a99c2
Remove unused function import
poyenc Nov 7, 2025
76b2bc0
Undo irrelevant changes
poyenc Nov 7, 2025
260908a
Remove fwd v3 instances from tile_example_fmha_fwd
poyenc Nov 7, 2025
286a24b
Finish fmha fwd v3 kernel instance codegen
poyenc Nov 7, 2025
006692f
Fix formatting
poyenc Nov 10, 2025
051a6be
Remove unused F_idx attribute
poyenc Nov 10, 2025
0b15146
Add is_generic_attention_mask<> traits
poyenc Nov 10, 2025
a176996
Add constraints to the fmha fwd v3 pipeline
poyenc Nov 10, 2025
10ecccc
Unify traits & problem used for fmha fwd v3
poyenc Nov 10, 2025
16d4573
Unify kernel launch code for fmha fwd v2 & v3
poyenc Nov 10, 2025
1810d6f
Unify kernel template selection logic
poyenc Nov 11, 2025
05ffeac
Use same kernel codegen template for both v2 & v3
poyenc Nov 11, 2025
7b9b7ee
Rename api() property as render() method
poyenc Nov 11, 2025
923a97a
Allow specifying filter for fmha fwd api pool
poyenc Nov 11, 2025
be4d123
Allow specifying function name when rendering api pool items
poyenc Nov 11, 2025
b66d3f5
Separate fmha fwd v3 kernel dispatching logic from v2
poyenc Nov 11, 2025
48487b5
Remove lambda assignment
poyenc Nov 11, 2025
fd8312c
Add simple v2/v3 dispatch logic
poyenc Nov 11, 2025
0a3cfe1
Stop generating empty if-clauses
poyenc Nov 11, 2025
9da8cbb
Use "".join() to concatenate fmha fwd api string content
poyenc Nov 11, 2025
6793877
Add more feature checks for fmha fwd v3 pipeline
poyenc Nov 12, 2025
772c30f
Check features before dispatch to fmha_fwd_v3()
poyenc Nov 12, 2025
eebe510
Add more feature checks for fmha_fwd_v3()
poyenc Nov 12, 2025
1730875
Add missing filter call
poyenc Nov 12, 2025
a62afee
Use Tuple to reserve the dtype orders
poyenc Nov 12, 2025
9c89220
Fix wrong pipeline matching logic
poyenc Nov 12, 2025
23c0022
Add fmha fwd v3 group mode instances
poyenc Nov 13, 2025
6526b59
Add functor_transform<>
poyenc Nov 13, 2025
291cea6
Add type constraints to make_tile_window()
poyenc Nov 13, 2025
f4d92f1
Remove fmha fwd v3 example
poyenc Nov 13, 2025
2df5019
Fix wrong product(aiter mha_fwd()) config
poyenc Nov 13, 2025
66a874a
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Nov 14, 2025
1df098d
Fix wrong fmha fwd v2/v3 selection logic
poyenc Nov 16, 2025
0d0a25b
Merge branch 'poyenc/integrate-fmha-fwd-v2-v3-apis' of github.com:poy…
poyenc Nov 16, 2025
68fd415
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Nov 16, 2025
8e0d9dd
Fix formatting
poyenc Nov 17, 2025
51c30ba
Merge branch 'poyenc/integrate-fmha-fwd-v2-v3-apis' of github.com:poy…
poyenc Nov 17, 2025
d33691d
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Nov 18, 2025
72ad9d7
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Nov 19, 2025
d0730ba
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Nov 20, 2025
13aee99
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
illsilin Nov 20, 2025
7ba44fd
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Nov 21, 2025
9c5364d
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Nov 21, 2025
615e4b8
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Nov 22, 2025
4464745
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Nov 23, 2025
f8ae943
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Dec 3, 2025
cf1f135
Add comment to warning v3 kernel users
poyenc Dec 3, 2025
608a253
Fix wrong codegen logics
poyenc Dec 3, 2025
02ed663
Remove unnecessary param
poyenc Dec 3, 2025
0e29033
Fix format
poyenc Dec 3, 2025
5e1f431
Merge branch 'develop' into poyenc/integrate-fmha-fwd-v2-v3-apis
poyenc Dec 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 0 additions & 34 deletions example/ck_tile/01_fmha/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -208,40 +208,6 @@ add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp)
target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES})
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})

# add fmha_fwd_v3 example
set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3")
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}")

add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS
"${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp"
)
target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE
fmha_fwd_v3.cpp
${FMHA_FWD_V3_INSTANCES}
)

set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS)
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
-fgpu-flush-denormals-to-zero
-Wno-undefined-func-template
--save-temps
)
set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS)

check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32)
if(HAS_DISABLE_PACKED_FP32)
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
-mllvm --amdgpu-disable-packed-fp32=1
)
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS
-DCK_TILE_DISABLE_PACKED_FP32=1
)
endif()

target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS})
target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
Expand Down
20 changes: 17 additions & 3 deletions example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,24 @@
}


def get_mask_map(mask: str):
if mask == "generic":
def get_mask_map(mask_impl: str):
if mask_impl == "generic":
return _MASK_MAP
elif mask == "simplified":
elif mask_impl == "simplified":
return _MASK_SIMPLIFIED_MAP
else:
assert False
return None


def get_mask_impl(mask: str) -> str:
return "simplified" if mask.startswith("s_") else "generic"


def get_mask_cpp_type(mask: str) -> str:
return get_mask_map(get_mask_impl(mask))[mask]


_MASK_CHECK_MAP = {
"no": "t.mask_type == mask_enum::no_mask",
"causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
Expand All @@ -62,6 +70,10 @@ def get_mask_check_map(mask: str):
return None


def get_mask_cpp_check_expr(mask: str) -> str:
return get_mask_check_map(get_mask_impl(mask))[mask]


QSCALE_MAP = {
"no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE",
"pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR",
Expand Down Expand Up @@ -122,6 +134,7 @@ def get_mask_check_map(mask: str):
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync",
"qs": "ck_tile::BlockFmhaPipelineQSKSVS",
"qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
"qr_async_trload_v3": "ck_tile::BlockFmhaFwdV3Pipeline",
}

PIPELINE_ENUM_MAP = {
Expand All @@ -131,6 +144,7 @@ def get_mask_check_map(mask: str):
"qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
"qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
"qr_async_trload_v3": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD_V3",
}

BOOL_MAP = {
Expand Down
Loading