Skip to content

Commit 6140f27

Browse files
committed
Refactor rms_norm transform script to follow mlir-air layernorm prototype
Replace the post-bufferize linalg_promote step (which leaks self-copies that crash transform.air.copy_to_dma) with pre-bufferize bufferize_to_allocation + promote_tensor for L1 staging, mirroring the mlir-air xrt 43_triton_layernorm example. This eliminates the "expected to produce 1 results (actually produced 0)" message on stderr on aie2p, as reported in #64.
1 parent 93c8cc6 commit 6140f27

1 file changed

Lines changed: 130 additions & 98 deletions

File tree

Lines changed: 130 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,150 @@
1-
// RMS Norm transform for AIE2P.
2-
// 2D kernel (BLOCK_M=2 x BLOCK_N=64).
3-
//
4-
// Strategy: bufferize FIRST (no L1 staging), then use linalg_promote
5-
// on the linalg ops inside the forall to promote L2 subviews to L1 allocs.
6-
// This creates memref.copy ops that par_to_herd + copy_to_dma convert to DMAs.
1+
// RMS Norm transform for AIE2P, following mlir-air xrt 43_triton_layernorm/transform_aie2p.mlir.
2+
// Chain (after fuse_elementwise + transpose_reduce): generic_sq -> reduce -> output_generic.
73

84
module attributes {transform.with_named_sequence} {
95
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
6+
7+
// PHASE 1: canonicalize + fold unit extent
108
%func0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
11-
transform.apply_patterns to %func0 { transform.apply_patterns.canonicalization
12-
transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes } : !transform.any_op
9+
transform.apply_patterns to %func0 {
10+
transform.apply_patterns.linalg.tiling_canonicalization
11+
transform.apply_patterns.scf.for_loop_canonicalization
12+
transform.apply_patterns.canonicalization
13+
transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes
14+
} : !transform.any_op
1315
transform.apply_cse to %func0 : !transform.any_op
16+
17+
// PHASE 2: fuse elementwise + transpose reduce + canonicalize
18+
%func1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
19+
%fused_func = transform.air.fuse_elementwise_linalg %func1 : (!transform.any_op) -> !transform.any_op
1420
%reduces = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
1521
%tr = transform.air.transpose_reduce %reduces : (!transform.any_op) -> !transform.any_op
16-
%func1a = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
17-
transform.apply_patterns to %func1a { transform.apply_patterns.canonicalization } : !transform.any_op
18-
transform.apply_cse to %func1a : !transform.any_op
19-
%func1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
20-
%f = transform.air.fuse_elementwise_linalg %func1 : (!transform.any_op) -> !transform.any_op
21-
%fa = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
22-
transform.apply_patterns to %fa { transform.apply_patterns.canonicalization } : !transform.any_op
23-
transform.apply_cse to %fa : !transform.any_op
24-
25-
%ag = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
26-
%sq, %out = transform.split_handle %ag : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
27-
%reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
22+
23+
transform.apply_patterns to %fused_func {
24+
transform.apply_patterns.linalg.tiling_canonicalization
25+
transform.apply_patterns.scf.for_loop_canonicalization
26+
transform.apply_patterns.canonicalization
27+
} : !transform.any_op
28+
transform.apply_cse to %fused_func : !transform.any_op
29+
30+
// Data-flow navigation. Chain: generic_sq -> reduce -> output_generic
31+
%r = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
32+
%generic_sq = transform.get_producer_of_operand %r[0] : (!transform.any_op) -> !transform.any_op
33+
%materialize = transform.structured.match ops{["bufferization.materialize_in_destination"]} in %arg1 : (!transform.any_op) -> !transform.any_op
34+
%output_generic = transform.get_producer_of_operand %materialize[0] : (!transform.any_op) -> !transform.any_op
2835
%fill = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
2936

30-
// L2 output alloc
31-
%ob, %nb = transform.structured.bufferize_to_allocation %out {memory_space = 1, bufferize_destination_only, emit_dealloc} : !transform.any_op
32-
// Tile at [1] on row dim
33-
%t, %fl = transform.structured.tile_using_forall %out tile_sizes [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
34-
// Fuse all into forall
35-
%f1, %fl1 = transform.structured.fuse_into_containing_op %reduce into %fl : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
36-
%f2, %fl2 = transform.structured.fuse_into_containing_op %sq into %fl1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
37-
%f3, %fl3 = transform.structured.fuse_into_containing_op %fill into %fl2 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
38-
39-
// Fuse sq into reduce
40-
%reduce3 = transform.structured.match ops{["linalg.reduce"]} in %fl3 : (!transform.any_op) -> !transform.any_op
41-
%sq3 = transform.structured.match ops{["linalg.generic"]} in %fl3 : (!transform.any_op) -> !transform.any_op
42-
%sq_only, %out_only = transform.split_handle %sq3 {overflow_result = 1} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
43-
%fused_sr = transform.air.fuse_multi_op_linalg %sq_only, %reduce3 : (!transform.any_op, !transform.any_op) -> !transform.any_op
44-
45-
// L1 for fills only (destination-only)
46-
%fills3 = transform.structured.match ops{["linalg.fill"]} in %fl3 : (!transform.any_op) -> !transform.any_op
47-
%fill_buf, %fill_new = transform.structured.bufferize_to_allocation %fills3
37+
// PHASE 3: L2 alloc for output, tile, fuse backward
38+
%ob, %on = transform.structured.bufferize_to_allocation %output_generic
39+
{memory_space = 1, bufferize_destination_only, emit_dealloc} : !transform.any_op
40+
%tiled_output, %forall = transform.structured.tile_using_forall %output_generic tile_sizes [1]
41+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
42+
43+
%fr, %fl_r = transform.structured.fuse_into_containing_op %r into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
44+
%fg, %fl_g = transform.structured.fuse_into_containing_op %generic_sq into %fl_r : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
45+
%ff, %fl_f = transform.structured.fuse_into_containing_op %fill into %fl_g : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
46+
47+
// PHASE 4: canonicalize
48+
%func2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
49+
transform.apply_patterns to %func2 {
50+
transform.apply_patterns.linalg.tiling_canonicalization
51+
transform.apply_patterns.scf.for_loop_canonicalization
52+
transform.apply_patterns.canonicalization
53+
} : !transform.any_op
54+
transform.apply_cse to %func2 : !transform.any_op
55+
56+
// PHASE 5: L1 alloc for fills + intermediate ops
57+
%fills_2 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
58+
%fb, %fn = transform.structured.bufferize_to_allocation %fills_2
59+
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
60+
61+
// Re-match: 2 generics (sq, output) + 1 reduce after tiling.
62+
%generics2 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
63+
%tiled_generic1, %tiled_generic2 = transform.split_handle %generics2 : (!transform.any_op<"linalg.generic">) -> (!transform.any_op<"linalg.generic">, !transform.any_op<"linalg.generic">)
64+
%reduces2 = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
65+
66+
// Promote input tensor to L1
67+
%op0 = transform.get_operand %tiled_generic1[0] : (!transform.any_op) -> !transform.any_value
68+
transform.structured.promote_tensor to 2 %op0 : !transform.any_value
69+
70+
// L1 alloc for intermediate outputs
71+
%g1b, %g1n = transform.structured.bufferize_to_allocation %tiled_generic1
4872
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
73+
%rb, %rn = transform.structured.bufferize_to_allocation %reduces2
74+
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
75+
%g2b, %g2n = transform.structured.bufferize_to_allocation %tiled_generic2
76+
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
77+
78+
// PHASE 6: canonicalize
79+
%func5 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
80+
transform.apply_patterns to %func5 {
81+
transform.apply_patterns.linalg.tiling_canonicalization
82+
transform.apply_patterns.scf.for_loop_canonicalization
83+
transform.apply_patterns.canonicalization
84+
} : !transform.any_op
85+
transform.apply_cse to %func5 : !transform.any_op
4986

50-
// Canonicalize + bufferize (no L1 staging for reduce/generic inputs)
51-
%f2c = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
52-
transform.apply_patterns to %f2c { transform.apply_patterns.canonicalization } : !transform.any_op
53-
transform.apply_cse to %f2c : !transform.any_op
87+
// PHASE 7: one_shot_bufferize
5488
transform.include @one_shot_bufferize failures(propagate) (%arg1) : (!transform.any_op) -> ()
55-
%f6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
56-
transform.apply_patterns to %f6 { transform.apply_patterns.canonicalization } : !transform.any_op
57-
transform.apply_cse to %f6 : !transform.any_op
58-
%lc = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
59-
%mc = transform.structured.linalg_copy_to_memref %lc : (!transform.any_op) -> !transform.any_op
60-
%fu = transform.air.remove_uninitialized_copy %f6 : (!transform.any_op) -> (!transform.any_op)
61-
%fu2 = transform.air.eliminate_cascade_memcpy %fu : (!transform.any_op) -> (!transform.any_op)
62-
63-
// NOW promote linalg ops inside forall to L1 (BEFORE herd creation)
64-
// This creates memref.copy from L2 subviews to L1 allocs
65-
%forall_op = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
66-
%gens_f = transform.structured.match ops{["linalg.generic"]} in %forall_op : (!transform.any_op) -> !transform.any_op
67-
%reds_f = transform.structured.match ops{["linalg.reduce"]} in %forall_op : (!transform.any_op) -> !transform.any_op
68-
%all_linalg_f = transform.merge_handles %reds_f, %gens_f { deduplicate } : !transform.any_op
69-
%promoted = transform.air.linalg_promote %all_linalg_f {memory_space = "L1"} : (!transform.any_op) -> !transform.any_op
70-
71-
// Herd + DMA
72-
%fh = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
73-
%pa = transform.loop.forall_to_parallel %fh : (!transform.any_op) -> !transform.any_op
74-
%h = transform.air.par_to_herd %pa : (!transform.any_op) -> !transform.any_op
75-
%lc2 = transform.structured.match ops{["linalg.copy"]} in %h : (!transform.any_op) -> !transform.any_op
76-
%mc2 = transform.structured.match ops{["memref.copy"]} in %h : (!transform.any_op) -> !transform.any_op
77-
%mc3 = transform.structured.linalg_copy_to_memref %lc2 : (!transform.any_op) -> !transform.any_op
78-
%ac = transform.merge_handles %mc2, %mc3 { deduplicate } : !transform.any_op
79-
%dm = transform.air.copy_to_dma %ac : (!transform.any_op) -> !transform.any_op
80-
81-
// Re-match the herd since handles may be stale after promote/dma
82-
%h2 = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
83-
// Inner vectorization tiling
84-
%gens_h = transform.structured.match ops{["linalg.generic"]} in %h2 : (!transform.any_op) -> !transform.any_op
85-
%inner_g, %inner_gl:1 = transform.structured.tile_using_for %gens_h tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
86-
%reds_h = transform.structured.match ops{["linalg.reduce"]} in %h2 : (!transform.any_op) -> !transform.any_op
87-
%inner_r, %inner_rl:1 = transform.structured.tile_using_for %reds_h tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
88-
%fills_h = transform.structured.match ops{["linalg.fill"]} in %h2 : (!transform.any_op) -> !transform.any_op
89-
%fill_scl = transform.structured.convert_to_loops %fills_h : (!transform.any_op) -> !transform.any_op
90-
%vh = transform.air.herd_vectorize %h2 : (!transform.any_op) -> !transform.any_op
91-
92-
// Lower vector reductions FIRST (creates arith.mulf/addf from vector.multi_reduction)
93-
%func_final = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
94-
transform.apply_patterns to %func_final {
95-
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
96-
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
97-
transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
98-
transform.apply_patterns.vector.lower_contraction
99-
transform.apply_patterns.vector.lower_transfer
89+
90+
// PHASE 8: canonicalize + remove uninitialized copy
91+
%func6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
92+
transform.apply_patterns to %func6 {
93+
transform.apply_patterns.linalg.tiling_canonicalization
94+
transform.apply_patterns.scf.for_loop_canonicalization
95+
transform.apply_patterns.canonicalization
96+
} : !transform.any_op
97+
transform.apply_cse to %func6 : !transform.any_op
98+
transform.apply_patterns to %func6 {
99+
transform.apply_patterns.canonicalization
100100
} : !transform.any_op
101-
transform.apply_cse to %func_final : !transform.any_op
101+
%func_op_updated = transform.air.remove_uninitialized_copy %func6 : (!transform.any_op) -> !transform.any_op
102102

103-
// AIE2P type casts AFTER lowering: mulf and addf are bf16-only, divf and rsqrt are f32-only
104-
%vh2 = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
105-
%vector_muls = transform.structured.match ops{["arith.mulf"]} in %vh2 : (!transform.any_op) -> !transform.any_op
106-
%mul_cast = transform.air.vector_type_cast %vector_muls {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op
107-
%vector_adds = transform.structured.match ops{["arith.addf"]} in %vh2 : (!transform.any_op) -> !transform.any_op
108-
%add_cast = transform.air.vector_type_cast %vector_adds {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op
109-
%func_s1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
110-
%func_s1_done = transform.air.convert_size1_vector_to_scalar %func_s1 : (!transform.any_op) -> !transform.any_op
111-
transform.apply_patterns to %func_s1_done {
103+
// PHASE 9: generalize remaining linalg.reduce, tile for vectorization, divf-sqrt -> rsqrt
104+
%remaining_reduces = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
105+
%generalized = transform.structured.generalize %remaining_reduces : (!transform.any_op) -> !transform.any_op
106+
107+
%lg = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
108+
%inner, %vl:1 = transform.structured.tile_using_for %lg tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
109+
110+
%fou1 = transform.air.convert_divf_sqrt_to_rsqrt %func_op_updated : (!transform.any_op) -> !transform.any_op
111+
112+
// PHASE 10: par_to_herd, copy_to_dma, herd_vectorize, casts
113+
%fa = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
114+
%parallel = transform.loop.forall_to_parallel %fa : (!transform.any_op) -> !transform.any_op
115+
%herd = transform.air.par_to_herd %parallel : (!transform.any_op) -> !transform.any_op
116+
117+
%copies_in_herd = transform.structured.match ops{["memref.copy", "linalg.copy"]} in %herd : (!transform.any_op) -> !transform.any_op
118+
%dmas = transform.air.copy_to_dma %copies_in_herd : (!transform.any_op) -> !transform.any_op
119+
120+
%vh = transform.air.herd_vectorize %herd : (!transform.any_op) -> !transform.any_op
121+
122+
%func4 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
123+
transform.apply_patterns to %func4 {
124+
transform.apply_patterns.canonicalization
112125
transform.apply_patterns.vector.cast_away_vector_leading_one_dim
126+
} : !transform.any_op
127+
128+
%vh2 = transform.air.broadcast_before_unary %func4 {op_name = "math.rsqrt"} : (!transform.any_op) -> !transform.any_op
129+
130+
%vector_reductions = transform.structured.match ops{["vector.multi_reduction"]} in %vh2 : (!transform.any_op) -> !transform.any_op
131+
%r1 = transform.air.vector_type_cast %vector_reductions {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op
132+
133+
%vector_muls = transform.structured.match ops{["arith.mulf"]} in %vh2 : (!transform.any_op) -> !transform.any_op
134+
%r2 = transform.air.vector_type_cast %vector_muls {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op
135+
136+
%func7 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
137+
%func7t = transform.air.convert_size1_vector_to_scalar %func7 : (!transform.any_op) -> !transform.any_op
138+
transform.apply_patterns to %func7t {
139+
transform.apply_patterns.linalg.tiling_canonicalization
140+
transform.apply_patterns.scf.for_loop_canonicalization
113141
transform.apply_patterns.canonicalization
142+
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
143+
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
144+
transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
114145
} : !transform.any_op
115-
transform.apply_cse to %func_s1_done : !transform.any_op
146+
transform.apply_cse to %func7t : !transform.any_op
147+
116148
transform.yield
117149
}
118150
}

0 commit comments

Comments
 (0)