@@ -1080,3 +1080,57 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
10801080 tt.return %2 : tensor <64 x32 xf32 , #mma >
10811081 }
10821082}

// -----

#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 32, 16]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true, elementBitWidth = 16}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
  // Test: the B operand of the warp_group_dot (%arg5) is an iter_arg of the
  // OUTER loop, forwarded unchanged each outer iteration. The pass must still
  // pipeline the dot in the inner loop (wait with pendings = 1 inside it) and
  // emit a final wait with pendings = 0 after the loops.
  // CHECK-LABEL: dot_outer_loop_arg
  // CHECK: scf.for
  // CHECK-NEXT: scf.for
  // CHECK-NEXT: ttng.warp_group_dot
  // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 1 : i32}
  // CHECK-NEXT: scf.yield
  // CHECK: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32}
  tt.func public @dot_outer_loop_arg(%arg0: i32, %arg2: !ttg.memdesc<64x32xbf16, #shared, #smem, mutable>, %arg3: !ttg.memdesc<32x32xbf16, #shared1, #smem, mutable>) -> tensor<64x32xf32, #mma> {
    %c0_i32 = arith.constant 0 : i32
    %c32_i32 = arith.constant 32 : i32
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma>
    %outer:2 = scf.for %arg4 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg5 = %arg3, %arg8 = %cst_0) -> (!ttg.memdesc<32x32xbf16, #shared1, #smem, mutable>, tensor<64x32xf32, #mma>) : i32 {
      %0 = scf.for %arg6 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg7 = %arg8) -> (tensor<64x32xf32, #mma>) : i32 {
        %1 = ttng.warp_group_dot %arg2, %arg5, %arg7 {inputPrecision = 0 : i32} : !ttg.memdesc<64x32xbf16, #shared, #smem, mutable> * !ttg.memdesc<32x32xbf16, #shared1, #smem, mutable> -> tensor<64x32xf32, #mma>
        scf.yield %1 : tensor<64x32xf32, #mma>
      }
      scf.yield %arg5, %0 : !ttg.memdesc<32x32xbf16, #shared1, #smem, mutable>, tensor<64x32xf32, #mma>
    }
    tt.return %outer#1 : tensor<64x32xf32, #mma>
  }
}

// -----

#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 32, 16]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = true, elementBitWidth = 16}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
  // Test: the B operand (%arg5) forms a trivial iter_arg cycle — it is yielded
  // back to itself every iteration. The pipeliner must still handle this and
  // emit an in-loop wait with pendings = 1 plus a final wait with pendings = 0.
  // CHECK-LABEL: loop_arg_cycle
  // CHECK: scf.for
  // CHECK-NEXT: ttng.warp_group_dot
  // CHECK-NEXT: ttng.warp_group_dot_wait {{.*}} {pendings = 1 : i32}
  // CHECK-NEXT: scf.yield
  // CHECK: ttng.warp_group_dot_wait {{.*}} {pendings = 0 : i32}
  tt.func public @loop_arg_cycle(%arg0: i32, %arg2: !ttg.memdesc<64x32xbf16, #shared, #smem, mutable>, %arg3: !ttg.memdesc<32x32xbf16, #shared1, #smem, mutable>) -> tensor<64x32xf32, #mma> {
    %c0_i32 = arith.constant 0 : i32
    %c32_i32 = arith.constant 32 : i32
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma>
    %0:2 = scf.for %arg4 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg5 = %arg3, %arg7 = %cst_0) -> (!ttg.memdesc<32x32xbf16, #shared1, #smem, mutable>, tensor<64x32xf32, #mma>) : i32 {
      %1 = ttng.warp_group_dot %arg2, %arg5, %arg7 {inputPrecision = 0 : i32} : !ttg.memdesc<64x32xbf16, #shared, #smem, mutable> * !ttg.memdesc<32x32xbf16, #shared1, #smem, mutable> -> tensor<64x32xf32, #mma>
      scf.yield %arg5, %1 : !ttg.memdesc<32x32xbf16, #shared1, #smem, mutable>, tensor<64x32xf32, #mma>
    }
    tt.return %0#1 : tensor<64x32xf32, #mma>
  }
}