Skip to content

Commit a8018dd

Browse files
authored
[Test] Fix prefetch sizes in SIMT flash attention. (#1127)
1 parent 708e131 commit a8018dd

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

test/Integration/Dialect/XeGPU/SG/flash_attention_fwd.mlir

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,18 @@ module @flash_attention attributes {gpu.container_module} {
7676
%prefetch_offset_x = arith.addi %wg_q_x_offset, %prefetch_offset_x_t0 : index
7777
%prefetch_offset_y = arith.muli %sg_layout_y, %c32 : index
7878

79-
%k_prefetch_tile = xegpu.create_nd_tdesc %K , shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
80-
xegpu.prefetch_nd %k_prefetch_tile[%prefetch_offset_x, %prefetch_offset_y] {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>} : !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
79+
%k_prefetch_tile = xegpu.create_nd_tdesc %K , shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
80+
xegpu.prefetch_nd %k_prefetch_tile[%prefetch_offset_x, %prefetch_offset_y] {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>} : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
8181
%prefetch_offset_x_plus_BLOCK_N = arith.addi %prefetch_offset_x, %BLOCK_N : index
82-
xegpu.prefetch_nd %k_prefetch_tile[%prefetch_offset_x_plus_BLOCK_N, %prefetch_offset_y] {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>} : !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
82+
xegpu.prefetch_nd %k_prefetch_tile[%prefetch_offset_x_plus_BLOCK_N, %prefetch_offset_y] {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>} : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
8383
%prefetch_offset_x_plus_2_BLOCK_N = arith.addi %prefetch_offset_x_plus_BLOCK_N, %BLOCK_N : index
84-
xegpu.prefetch_nd %k_prefetch_tile[%prefetch_offset_x_plus_2_BLOCK_N, %prefetch_offset_y] {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>} : !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
84+
xegpu.prefetch_nd %k_prefetch_tile[%prefetch_offset_x_plus_2_BLOCK_N, %prefetch_offset_y] {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>} : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
8585

8686
// V prefetch is similar to K
87-
%v_prefetch_tile = xegpu.create_nd_tdesc %V , shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
88-
xegpu.prefetch_nd %v_prefetch_tile[%prefetch_offset_x, %prefetch_offset_y] {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>} : !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
89-
xegpu.prefetch_nd %v_prefetch_tile[%prefetch_offset_x_plus_BLOCK_N, %prefetch_offset_y] {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>} : !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
90-
xegpu.prefetch_nd %v_prefetch_tile[%prefetch_offset_x_plus_2_BLOCK_N, %prefetch_offset_y] {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>} : !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
87+
%v_prefetch_tile = xegpu.create_nd_tdesc %V , shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
88+
xegpu.prefetch_nd %v_prefetch_tile[%prefetch_offset_x, %prefetch_offset_y] {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>} : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
89+
xegpu.prefetch_nd %v_prefetch_tile[%prefetch_offset_x_plus_BLOCK_N, %prefetch_offset_y] {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>} : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
90+
xegpu.prefetch_nd %v_prefetch_tile[%prefetch_offset_x_plus_2_BLOCK_N, %prefetch_offset_y] {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>} : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
9191
%BLOCK_N_3_t = arith.addi %BLOCK_N, %BLOCK_N : index
9292
%BLOCK_N_3 = arith.addi %BLOCK_N_3_t, %BLOCK_N : index
9393

@@ -149,10 +149,10 @@ module @flash_attention attributes {gpu.container_module} {
149149
// K prefetch
150150
%prefetch_offset_x_running_t = arith.addi %BLOCK_N_3, %k : index
151151
%prefetch_offset_x_running = arith.addi %wg_q_x_offset, %prefetch_offset_x_running_t : index
152-
xegpu.prefetch_nd %k_prefetch_tile[%prefetch_offset_x_running, %prefetch_offset_y] : !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
152+
xegpu.prefetch_nd %k_prefetch_tile[%prefetch_offset_x_running, %prefetch_offset_y] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
153153

154154
// V prefetch
155-
xegpu.prefetch_nd %v_prefetch_tile[%prefetch_offset_x_running, %prefetch_offset_y] : !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
155+
xegpu.prefetch_nd %v_prefetch_tile[%prefetch_offset_x_running, %prefetch_offset_y] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2>>
156156

157157
// Load first 16x64xf16 (i.e. 16x32xf32) K slice.
158158
%wg_x_offset_running = arith.addi %wg_x_offset, %k : index

0 commit comments

Comments (0)