@@ -284,6 +284,48 @@ util.func public @fuse_attention_with_broadcast(%arg0: tensor<4x8x128x?xf16>, %a
284284// CHECK-SAME: ins(%[[ARG1]], %[[ARG2]], %[[ARG0]], %[[ARG3]], %[[ARG4]] :
285285// CHECK: util.return %[[ATTENTION]]
286286
287+ // -----
288+
289+ util.func public @dont_fuse_attention_with_broadcasted_away_n_dim (
290+ %q: tensor <4 x32 x64 x16 xf16 >,
291+ %k: tensor <4 x32 x64 x16 xf16 >,
292+ %v: tensor <4 x32 x64 xf16 >,
293+ %scale: f16 ) -> tensor <4 x32 x64 x128 xf16 > {
294+ %empty_v = tensor.empty () : tensor <4 x32 x64 x128 xf16 >
295+ %v_broadcast = linalg.generic {
296+ indexing_maps = [affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d1 , d2 )>,
297+ affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d1 , d2 , d3 )>],
298+ iterator_types = [" parallel" , " parallel" , " parallel" , " parallel" ]}
299+ ins (%v : tensor <4 x32 x64 xf16 >) outs (%empty_v : tensor <4 x32 x64 x128 xf16 >) {
300+ ^bb0 (%in: f16 , %out: f16 ):
301+ linalg.yield %in : f16
302+ } -> tensor <4 x32 x64 x128 xf16 >
303+ %empty_out = tensor.empty () : tensor <4 x32 x64 x128 xf16 >
304+ %attention = iree_linalg_ext.attention {
305+ indexing_maps = [affine_map <(d0 , d1 , d2 , d3 , d4 , d5 ) -> (d0 , d1 , d2 , d3 )>,
306+ affine_map <(d0 , d1 , d2 , d3 , d4 , d5 ) -> (d0 , d1 , d4 , d3 )>,
307+ affine_map <(d0 , d1 , d2 , d3 , d4 , d5 ) -> (d0 , d1 , d4 , d5 )>,
308+ affine_map <(d0 , d1 , d2 , d3 , d4 , d5 ) -> ()>,
309+ affine_map <(d0 , d1 , d2 , d3 , d4 , d5 ) -> (d0 , d1 , d2 , d5 )>]}
310+ ins (%q , %k , %v_broadcast , %scale :
311+ tensor <4 x32 x64 x16 xf16 >, tensor <4 x32 x64 x16 xf16 >,
312+ tensor <4 x32 x64 x128 xf16 >, f16 )
313+ outs (%empty_out : tensor <4 x32 x64 x128 xf16 >) {
314+ ^bb0 (%score: f16 ):
315+ iree_linalg_ext.yield %score : f16
316+ } -> tensor <4 x32 x64 x128 xf16 >
317+ util.return %attention : tensor <4 x32 x64 x128 xf16 >
318+ }
319+ // CHECK-LABEL: func public @dont_fuse_attention_with_broadcasted_away_n_dim
320+ // CHECK-SAME: %[[Q:[a-zA-Z0-9]+]]:
321+ // CHECK-SAME: %[[K:[a-zA-Z0-9]+]]:
322+ // CHECK-SAME: %[[V:[a-zA-Z0-9]+]]:
323+ // CHECK-SAME: %[[SCALE:[a-zA-Z0-9]+]]:
324+ // CHECK: %[[V_BCAST:.+]] = linalg.generic
325+ // CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
326+ // CHECK-SAME: ins(%[[Q]], %[[K]], %[[V_BCAST]], %[[SCALE]] :
327+ // CHECK: util.return %[[ATTENTION]]
328+
287329
288330// -----
289331
0 commit comments