@@ -248,26 +248,27 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
248248
249249// -----
250250
251+ #blockedsrc = #ttg.blocked <{sizePerThread = [1 , 4 ], threadsPerWarp = [1 , 32 ], warpsPerCTA = [4 , 1 ], order = [0 , 1 ]}>
251252#blocked = #ttg.blocked <{sizePerThread = [2 , 2 ], threadsPerWarp = [32 , 1 ], warpsPerCTA = [4 , 1 ], order = [0 , 1 ]}>
252253#blockedtrans = #ttg.blocked <{sizePerThread = [2 , 2 ], threadsPerWarp = [1 , 32 ], warpsPerCTA = [1 , 4 ], order = [1 , 0 ]}>
253- #blocked1 = #ttg.slice <{dim =0 , parent =#blocked }>
254- #blocked2 = #ttg.blocked <{ sizePerThread = [ 2 ], threadsPerWarp = [ 32 ], warpsPerCTA = [ 4 ], order = [ 0 ] }>
254+ #blocked1 = #ttg.slice <{dim =0 , parent =#blockedsrc }>
255+ #blocked2 = #ttg.slice <{ dim = 0 , parent = #blockedtrans }>
255256module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 } {
256257 // COMMON-LABEL: unary_triton_ops_transitive_nonneg
257258 tt.func @unary_triton_ops_transitive_nonneg (%arg0: !tt.ptr <bf16 > {tt.divisibility = 16 : i32 , tt.pointer_range = 32 : i32 }, %arg1: !tt.ptr <bf16 > {tt.divisibility = 16 : i32 , tt.pointer_range = 32 : i32 }) {
258259 %c10_i32 = arith.constant 5 : i32
259260 %0 = tt.make_range {end = 16 : i32 , start = 0 : i32 } : tensor <16 xi32 , #blocked1 >
260- %1 = tt.expand_dims %0 {axis = 0 : i32 } : tensor <16 xi32 , #blocked1 > -> tensor <1 x16 xi32 , #blocked >
261- %2 = tt.reshape %1 allow_reorder : tensor <1 x16 xi32 , #blocked > -> tensor <8 x2 xi32 , #blocked >
262- %3 = tt.reshape %1 allow_reorder : tensor <1 x16 xi32 , #blocked > -> tensor <2 x8 xi32 , #blocked >
263- %4 = tt.trans %3 {order = array<i32 : 1 , 0 >} : tensor <2 x8 xi32 , #blocked > -> tensor <8 x2 xi32 , #blockedtrans >
264- %5 = ttg.convert_layout %4 : tensor <8 x2 xi32 , #blockedtrans > -> tensor <8 x2 xi32 , #blocked >
261+ %1 = tt.expand_dims %0 {axis = 0 : i32 } : tensor <16 xi32 , #blocked1 > -> tensor <1 x16 xi32 , #blockedsrc >
262+ %2 = tt.reshape %1 allow_reorder : tensor <1 x16 xi32 , #blockedsrc > -> tensor <8 x2 xi32 , #blocked >
263+ %3 = tt.reshape %1 allow_reorder : tensor <1 x16 xi32 , #blockedsrc > -> tensor <2 x8 xi32 , #blockedtrans >
264+ %4 = tt.trans %3 {order = array<i32 : 1 , 0 >} : tensor <2 x8 xi32 , #blockedtrans > -> tensor <8 x2 xi32 , #blocked >
265+ %5 = ttg.convert_layout %4 : tensor <8 x2 xi32 , #blocked > -> tensor <8 x2 xi32 , #blocked >
265266 %6 = arith.addi %5 , %2 : tensor <8 x2 xi32 , #blocked >
266267 %7 = tt.make_range {end = 10 : i32 , start = 2 : i32 } : tensor <8 xi32 , #blocked2 >
267- %8 = ttg.convert_layout %7 : tensor <8 xi32 , #blocked2 > -> tensor <8 xi32 , #blocked1 >
268- %9 = tt.expand_dims %8 {axis = 0 : i32 } : tensor <8 xi32 , #blocked1 > -> tensor <1 x8 xi32 , #blocked >
269- %10 = tt.broadcast %9 : tensor <1 x8 xi32 , #blocked > -> tensor <2 x8 xi32 , #blocked >
270- %11 = tt.reshape %10 allow_reorder : tensor <2 x8 xi32 , #blocked > -> tensor <8 x2 xi32 , #blocked >
268+ %8 = ttg.convert_layout %7 : tensor <8 xi32 , #blocked2 > -> tensor <8 xi32 , #blocked2 >
269+ %9 = tt.expand_dims %8 {axis = 0 : i32 } : tensor <8 xi32 , #blocked2 > -> tensor <1 x8 xi32 , #blockedtrans >
270+ %10 = tt.broadcast %9 : tensor <1 x8 xi32 , #blockedtrans > -> tensor <2 x8 xi32 , #blockedtrans >
271+ %11 = tt.reshape %10 allow_reorder : tensor <2 x8 xi32 , #blockedtrans > -> tensor <8 x2 xi32 , #blocked >
271272 %12 = tt.splat %c10_i32 : i32 -> tensor <8 x2 xi32 , #blocked >
272273 %13 = arith.addi %11 , %12 : tensor <8 x2 xi32 , #blocked >
273274 %14 = arith.minsi %13 , %5 : tensor <8 x2 xi32 , #blocked >
@@ -293,7 +294,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
293294// -----
294295
295296
296- #blocked = #ttg.blocked <{sizePerThread = [2 , 2 ], threadsPerWarp = [32 , 1 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
297+ #blocked = #ttg.blocked <{sizePerThread = [2 , 1 ], threadsPerWarp = [32 , 1 ], warpsPerCTA = [4 , 1 ], order = [1 , 0 ]}>
297298#blocked1 = #ttg.blocked <{sizePerThread = [2 ], threadsPerWarp = [32 ], warpsPerCTA = [4 ], order = [0 ]}>
298299module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 } {
299300 // COMMON-LABEL: join_cat_transitive_nonneg
0 commit comments