@@ -524,3 +524,55 @@ func.func @attention_multi_m_dynamic(%arg0 : tensor<20x8x?x16x64xf16>, %arg1 : t
// CHECK-SAME: #iree_gpu.lowering_config
// CHECK-SAME: reduction = [0, 0, 0, 0, 0, 64, 0]
// CHECK-SAME: workgroup = [1, 4, 1, 16, 0, 0, 64]

// -----

// f32 attention (Q/K/V all 20x4096x128xf32): expects the VectorDistribute
// pipeline and the f32 MFMA intrinsic (MFMA_F32_16x16x4_F32) to be picked for
// both the QK and PV matmuls, with the reduction/workgroup tile sizes below.

// CHECK: #iree_codegen.translation_info<pipeline = #iree_gpu.pipeline<VectorDistribute>

// CHECK-LABEL: func.func @attention_f32_20x4096x128x4096x128()

#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
#pipeline_layout = #hal.pipeline.layout<bindings = [
  #hal.pipeline.binding<storage_buffer>,
  #hal.pipeline.binding<storage_buffer>,
  #hal.pipeline.binding<storage_buffer>,
  #hal.pipeline.binding<storage_buffer>
]>
func.func @attention_f32_20x4096x128x4096x128() {
  // 1/sqrt(d) softmax scale for head dim 128 is folded to 0.125 here.
  %cst = arith.constant 1.250000e-01 : f32
  %c0 = arith.constant 0 : index
  %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<20x4096x128xf32>>
  %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<20x4096x128xf32>>
  %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<20x4096x128xf32>>
  %3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<20x4096x128xf32>>
  %4 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [20, 4096, 128], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<20x4096x128xf32>> -> tensor<20x4096x128xf32>
  %5 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [20, 4096, 128], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<20x4096x128xf32>> -> tensor<20x4096x128xf32>
  %6 = iree_tensor_ext.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [20, 4096, 128], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<20x4096x128xf32>> -> tensor<20x4096x128xf32>
  %7 = tensor.empty() : tensor<20x4096x128xf32>
  %8 = tensor.empty() : tensor<20x4096xf32>
  // Init values for the online-softmax accumulators: output = 0, running
  // max = -FLT_MAX, running sum = 0.
  %cst_0 = arith.constant 0.000000e+00 : f32
  %cst_1 = arith.constant -3.40282347E+38 : f32
  %cst_2 = arith.constant 0.000000e+00 : f32
  %9 = linalg.fill ins(%cst_0 : f32) outs(%7 : tensor<20x4096x128xf32>) -> tensor<20x4096x128xf32>
  %10 = linalg.fill ins(%cst_1 : f32) outs(%8 : tensor<20x4096xf32>) -> tensor<20x4096xf32>
  %11 = linalg.fill ins(%cst_2 : f32) outs(%8 : tensor<20x4096xf32>) -> tensor<20x4096xf32>
  %12:3 = iree_linalg_ext.online_attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5, #map5]} ins(%4, %5, %6, %cst : tensor<20x4096x128xf32>, tensor<20x4096x128xf32>, tensor<20x4096x128xf32>, f32) outs(%9, %10, %11 : tensor<20x4096x128xf32>, tensor<20x4096xf32>, tensor<20x4096xf32>) {
  ^bb0(%arg0: f32):
    iree_linalg_ext.yield %arg0 : f32
  } -> tensor<20x4096x128xf32>, tensor<20x4096xf32>, tensor<20x4096xf32>
  iree_tensor_ext.dispatch.tensor.store %12#0, %3, offsets = [0, 0, 0], sizes = [20, 4096, 128], strides = [1, 1, 1] : tensor<20x4096x128xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<20x4096x128xf32>>
  return
}

// CHECK:      #iree_gpu.lowering_config
// CHECK-SAME:   mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>
// CHECK-SAME: #iree_gpu.lowering_config
// CHECK-SAME:   mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>
// CHECK-SAME: #iree_gpu.lowering_config
// CHECK-SAME:   reduction = [0, 0, 0, 16, 0]
// CHECK-SAME:   workgroup = [1, 64, 0, 0, 64]