@@ -351,3 +351,29 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+#op0 = #ttg.dot_op<{opIdx = 0, parent = #blocked}>
+#op1 = #ttg.dot_op<{opIdx = 1, parent = #blocked}>
+
+// CHECK{LITERAL}: #mma = #ttg.amd_wmma<{version = 3, isTranspose = true, ctaLayout = {warp = [[0, 1], [1, 0]]}, instrShape = [16, 16, 32]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} {
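+  // Small-K case (K = 8): verify the f16 dot is still converted to WMMA v3 with kWidth = 4 operand layouts.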
+  tt.func public @wmma_dot_f16_f32_smallk(
+      %arg0: tensor<32x8x!tt.ptr<f16>, #op0>,
+      %arg1: tensor<8x32x!tt.ptr<f16>, #op1>,
+      %arg2: tensor<32x32x!tt.ptr<f32>, #blocked>
+  ) {
+    %a = tt.load %arg0 : tensor<32x8x!tt.ptr<f16>, #op0>
+    %b = tt.load %arg1 : tensor<8x32x!tt.ptr<f16>, #op1>
+    %c = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
+    // CHECK: %[[OPND0:.*]] = ttg.convert_layout {{.*}} : tensor<32x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<32x8xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+    // CHECK: %[[OPND1:.*]] = ttg.convert_layout {{.*}} : tensor<8x32xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<8x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+    // CHECK: tt.dot %[[OPND0]], %[[OPND1]], %{{.*}} : tensor<32x8xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<8x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<32x32xf32, #mma>
+    %res = tt.dot %a, %b, %c : tensor<32x8xf16, #op0> * tensor<8x32xf16, #op1> -> tensor<32x32xf32, #blocked>
+    tt.store %arg2, %res : tensor<32x32x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}