@@ -76,18 +76,18 @@ module @flash_attention attributes {gpu.container_module} {
7676 %prefetch_offset_x = arith.addi %wg_q_x_offset , %prefetch_offset_x_t0 : index
7777 %prefetch_offset_y = arith.muli %sg_layout_y , %c32 : index
7878
79- %k_prefetch_tile = xegpu.create_nd_tdesc %K , shape : [%size_x , %BLOCK_DMODEL ], strides : [%BLOCK_DMODEL , %c1 ] : memref <?x?xf16 > -> !xegpu.tensor_desc <8 x 16 x f16 , #xegpu.block_tdesc_attr <array_length = 2 >>
80- xegpu.prefetch_nd %k_prefetch_tile [%prefetch_offset_x , %prefetch_offset_y ] {l1_hint = #xegpu.cache_hint <cached >, l2_hint = #xegpu.cache_hint <cached >, l3_hint = #xegpu.cache_hint <cached >} : !xegpu.tensor_desc <8 x 16 x f16 , #xegpu.block_tdesc_attr <array_length = 2 >>
79+ %k_prefetch_tile = xegpu.create_nd_tdesc %K , shape : [%size_x , %BLOCK_DMODEL ], strides : [%BLOCK_DMODEL , %c1 ] : memref <?x?xf16 > -> !xegpu.tensor_desc <16 x 16 x f16 , #xegpu.block_tdesc_attr <array_length = 2 >>
80+ xegpu.prefetch_nd %k_prefetch_tile [%prefetch_offset_x , %prefetch_offset_y ] {l1_hint = #xegpu.cache_hint <cached >, l2_hint = #xegpu.cache_hint <cached >, l3_hint = #xegpu.cache_hint <cached >} : !xegpu.tensor_desc <16 x 16 x f16 , #xegpu.block_tdesc_attr <array_length = 2 >>
8181 %prefetch_offset_x_plus_BLOCK_N = arith.addi %prefetch_offset_x , %BLOCK_N : index
82- xegpu.prefetch_nd %k_prefetch_tile [%prefetch_offset_x_plus_BLOCK_N , %prefetch_offset_y ] {l1_hint = #xegpu.cache_hint <cached >, l2_hint = #xegpu.cache_hint <cached >, l3_hint = #xegpu.cache_hint <cached >} : !xegpu.tensor_desc <8 x 16 x f16 , #xegpu.block_tdesc_attr <array_length = 2 >>
82+ xegpu.prefetch_nd %k_prefetch_tile [%prefetch_offset_x_plus_BLOCK_N , %prefetch_offset_y ] {l1_hint = #xegpu.cache_hint <cached >, l2_hint = #xegpu.cache_hint <cached >, l3_hint = #xegpu.cache_hint <cached >} : !xegpu.tensor_desc <16 x 16 x f16 , #xegpu.block_tdesc_attr <array_length = 2 >>
8383 %prefetch_offset_x_plus_2_BLOCK_N = arith.addi %prefetch_offset_x_plus_BLOCK_N , %BLOCK_N : index
84- xegpu.prefetch_nd %k_prefetch_tile [%prefetch_offset_x_plus_2_BLOCK_N , %prefetch_offset_y ] {l1_hint = #xegpu.cache_hint <cached >, l2_hint = #xegpu.cache_hint <cached >, l3_hint = #xegpu.cache_hint <cached >} : !xegpu.tensor_desc <8 x 16 x f16 , #xegpu.block_tdesc_attr <array_length = 2 >>
84+ xegpu.prefetch_nd %k_prefetch_tile [%prefetch_offset_x_plus_2_BLOCK_N , %prefetch_offset_y ] {l1_hint = #xegpu.cache_hint <cached >, l2_hint = #xegpu.cache_hint <cached >, l3_hint = #xegpu.cache_hint <cached >} : !xegpu.tensor_desc <16 x 16 x f16 , #xegpu.block_tdesc_attr <array_length = 2 >>
8585
8686 // V prefetch is similar to K
87- %v_prefetch_tile = xegpu.create_nd_tdesc %V , shape : [%size_x , %BLOCK_DMODEL ], strides : [%BLOCK_DMODEL , %c1 ] : memref <?x?xf16 > -> !xegpu.tensor_desc <8 x 16 x f16 , #xegpu.block_tdesc_attr <array_length = 2 >>
88- xegpu.prefetch_nd %v_prefetch_tile [%prefetch_offset_x , %prefetch_offset_y ] {l1_hint = #xegpu.cache_hint <cached >, l2_hint = #xegpu.cache_hint <cached >, l3_hint = #xegpu.cache_hint <cached >} : !xegpu.tensor_desc <8 x 16 x f16 , #xegpu.block_tdesc_attr <array_length = 2 >>
89- xegpu.prefetch_nd %v_prefetch_tile [%prefetch_offset_x_plus_BLOCK_N , %prefetch_offset_y ] {l1_hint = #xegpu.cache_hint <cached >, l2_hint = #xegpu.cache_hint <cached >, l3_hint = #xegpu.cache_hint <cached >} : !xegpu.tensor_desc <8 x 16 x f16 , #xegpu.block_tdesc_attr <array_length = 2 >>
90- xegpu.prefetch_nd %v_prefetch_tile [%prefetch_offset_x_plus_2_BLOCK_N , %prefetch_offset_y ] {l1_hint = #xegpu.cache_hint <cached >, l2_hint = #xegpu.cache_hint <cached >, l3_hint = #xegpu.cache_hint <cached >} : !xegpu.tensor_desc <8 x 16 x f16 , #xegpu.block_tdesc_attr <array_length = 2 >>
87+ %v_prefetch_tile = xegpu.create_nd_tdesc %V , shape : [%size_x , %BLOCK_DMODEL ], strides : [%BLOCK_DMODEL , %c1 ] : memref <?x?xf16 > -> !xegpu.tensor_desc <16 x 16 x f16 , #xegpu.block_tdesc_attr <array_length = 2 >>
88+ xegpu.prefetch_nd %v_prefetch_tile [%prefetch_offset_x , %prefetch_offset_y ] {l1_hint = #xegpu.cache_hint <cached >, l2_hint = #xegpu.cache_hint <cached >, l3_hint = #xegpu.cache_hint <cached >} : !xegpu.tensor_desc <16 x 16 x f16 , #xegpu.block_tdesc_attr <array_length = 2 >>
89+ xegpu.prefetch_nd %v_prefetch_tile [%prefetch_offset_x_plus_BLOCK_N , %prefetch_offset_y ] {l1_hint = #xegpu.cache_hint <cached >, l2_hint = #xegpu.cache_hint <cached >, l3_hint = #xegpu.cache_hint <cached >} : !xegpu.tensor_desc <16 x 16 x f16 , #xegpu.block_tdesc_attr <array_length = 2 >>
90+ xegpu.prefetch_nd %v_prefetch_tile [%prefetch_offset_x_plus_2_BLOCK_N , %prefetch_offset_y ] {l1_hint = #xegpu.cache_hint <cached >, l2_hint = #xegpu.cache_hint <cached >, l3_hint = #xegpu.cache_hint <cached >} : !xegpu.tensor_desc <16 x 16 x f16 , #xegpu.block_tdesc_attr <array_length = 2 >>
9191 %BLOCK_N_3_t = arith.addi %BLOCK_N , %BLOCK_N : index
9292 %BLOCK_N_3 = arith.addi %BLOCK_N_3_t , %BLOCK_N : index
9393
@@ -149,10 +149,10 @@ module @flash_attention attributes {gpu.container_module} {
149149 // K prefetch
150150 %prefetch_offset_x_running_t = arith.addi %BLOCK_N_3 , %k : index
151151 %prefetch_offset_x_running = arith.addi %wg_q_x_offset , %prefetch_offset_x_running_t : index
152- xegpu.prefetch_nd %k_prefetch_tile [%prefetch_offset_x_running , %prefetch_offset_y ] : !xegpu.tensor_desc <8 x 16 x f16 , #xegpu.block_tdesc_attr <array_length = 2 >>
152+ xegpu.prefetch_nd %k_prefetch_tile [%prefetch_offset_x_running , %prefetch_offset_y ] : !xegpu.tensor_desc <16 x 16 x f16 , #xegpu.block_tdesc_attr <array_length = 2 >>
153153
154154 // V prefetch
155- xegpu.prefetch_nd %v_prefetch_tile [%prefetch_offset_x_running , %prefetch_offset_y ] : !xegpu.tensor_desc <8 x 16 x f16 , #xegpu.block_tdesc_attr <array_length = 2 >>
155+ xegpu.prefetch_nd %v_prefetch_tile [%prefetch_offset_x_running , %prefetch_offset_y ] : !xegpu.tensor_desc <16 x 16 x f16 , #xegpu.block_tdesc_attr <array_length = 2 >>
156156
157157 // Load first 16x64xf16 (i.e. 16x32xf32) K slice.
158158 %wg_x_offset_running = arith.addi %wg_x_offset , %k : index
0 commit comments