Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 130 additions & 98 deletions examples/rms_norm/transform_aie2p.mlir
Original file line number Diff line number Diff line change
@@ -1,118 +1,150 @@
// RMS Norm transform for AIE2P.
// 2D kernel (BLOCK_M=2 x BLOCK_N=64).
//
// Strategy: bufferize FIRST (no L1 staging), then use linalg_promote
// on the linalg ops inside the forall to promote L2 subviews to L1 allocs.
// This creates memref.copy ops that par_to_herd + copy_to_dma convert to DMAs.
// RMS Norm transform for AIE2P, following mlir-air xrt 43_triton_layernorm/transform_aie2p.mlir.
// Chain (after fuse_elementwise + transpose_reduce): generic_sq -> reduce -> output_generic.

// Transform-dialect script. The module carries transform.with_named_sequence and
// defines the entry point @__transform_main, which receives a read-only handle
// (%arg1) and applies a sequence of match / fuse / tile / bufferize / vectorize
// transform ops to the payload reachable from it.
//
// NOTE(review): this listing appears to interleave two revisions of the script:
//   * %func1, %fa, %ob, %vh, %vh2 and %vector_muls are each defined twice;
//   * two `transform.apply_patterns ... {` regions (the "Lower vector reductions"
//     one and the one on %func_s1_done) are opened but never closed with
//     `} : !transform.any_op`;
//   * one bufferize_to_allocation (on %fills3) has no attributes or type.
// Confirm which lines belong to the current revision before running this file.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {

// PHASE 1: canonicalize + fold unit extent
%func0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// NOTE(review): two apply_patterns ops target %func0 back to back; the second
// lists every pattern of the first plus two more, so the first looks like a
// leftover from the earlier revision — confirm.
transform.apply_patterns to %func0 { transform.apply_patterns.canonicalization
transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes } : !transform.any_op
transform.apply_patterns to %func0 {
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.scf.for_loop_canonicalization
transform.apply_patterns.canonicalization
transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes
} : !transform.any_op
transform.apply_cse to %func0 : !transform.any_op

// PHASE 2: fuse elementwise + transpose reduce + canonicalize
%func1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%fused_func = transform.air.fuse_elementwise_linalg %func1 : (!transform.any_op) -> !transform.any_op
%reduces = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%tr = transform.air.transpose_reduce %reduces : (!transform.any_op) -> !transform.any_op
// func.func is re-matched from %arg1 here instead of reusing %func1.
%func1a = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func1a { transform.apply_patterns.canonicalization } : !transform.any_op
transform.apply_cse to %func1a : !transform.any_op
// NOTE(review): %func1 is defined a second time below, and this variant of the
// fusion step has no transpose_reduce — presumably the earlier revision; confirm.
%func1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%f = transform.air.fuse_elementwise_linalg %func1 : (!transform.any_op) -> !transform.any_op
%fa = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %fa { transform.apply_patterns.canonicalization } : !transform.any_op
transform.apply_cse to %fa : !transform.any_op

// Position-based selection: the linalg.generic handle is split into two results
// (%sq, %out), so exactly two matched generics are presumably expected here.
%ag = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%sq, %out = transform.split_handle %ag : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op

transform.apply_patterns to %fused_func {
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.scf.for_loop_canonicalization
transform.apply_patterns.canonicalization
} : !transform.any_op
transform.apply_cse to %fused_func : !transform.any_op

// Data-flow navigation. Chain: generic_sq -> reduce -> output_generic
// Ops are located through producer/consumer links rather than match order:
// %generic_sq is the producer of operand 0 of the reduce, and %output_generic is
// the producer of operand 0 of bufferization.materialize_in_destination.
%r = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%generic_sq = transform.get_producer_of_operand %r[0] : (!transform.any_op) -> !transform.any_op
%materialize = transform.structured.match ops{["bufferization.materialize_in_destination"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%output_generic = transform.get_producer_of_operand %materialize[0] : (!transform.any_op) -> !transform.any_op
%fill = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op

// L2 output alloc
%ob, %nb = transform.structured.bufferize_to_allocation %out {memory_space = 1, bufferize_destination_only, emit_dealloc} : !transform.any_op
// Tile at [1] on row dim
%t, %fl = transform.structured.tile_using_forall %out tile_sizes [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Fuse all into forall
// Each fusion returns an updated handle to the containing op (%fl1, %fl2, %fl3),
// which is what the next fusion targets.
%f1, %fl1 = transform.structured.fuse_into_containing_op %reduce into %fl : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%f2, %fl2 = transform.structured.fuse_into_containing_op %sq into %fl1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%f3, %fl3 = transform.structured.fuse_into_containing_op %fill into %fl2 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)

// Fuse sq into reduce
%reduce3 = transform.structured.match ops{["linalg.reduce"]} in %fl3 : (!transform.any_op) -> !transform.any_op
%sq3 = transform.structured.match ops{["linalg.generic"]} in %fl3 : (!transform.any_op) -> !transform.any_op
// overflow_result = 1 lets the second result absorb any generics beyond the first.
%sq_only, %out_only = transform.split_handle %sq3 {overflow_result = 1} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%fused_sr = transform.air.fuse_multi_op_linalg %sq_only, %reduce3 : (!transform.any_op, !transform.any_op) -> !transform.any_op

// L1 for fills only (destination-only)
%fills3 = transform.structured.match ops{["linalg.fill"]} in %fl3 : (!transform.any_op) -> !transform.any_op
// NOTE(review): the op below has no attribute dictionary or type signature before
// the next op starts — it looks truncated; confirm.
%fill_buf, %fill_new = transform.structured.bufferize_to_allocation %fills3
// PHASE 3: L2 alloc for output, tile, fuse backward
%ob, %on = transform.structured.bufferize_to_allocation %output_generic
{memory_space = 1, bufferize_destination_only, emit_dealloc} : !transform.any_op
%tiled_output, %forall = transform.structured.tile_using_forall %output_generic tile_sizes [1]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)

// Producers are fused into the forall in reverse chain order: reduce, then
// generic_sq, then fill, each time targeting the handle returned by the previous fusion.
%fr, %fl_r = transform.structured.fuse_into_containing_op %r into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%fg, %fl_g = transform.structured.fuse_into_containing_op %generic_sq into %fl_r : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
%ff, %fl_f = transform.structured.fuse_into_containing_op %fill into %fl_g : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)

// PHASE 4: canonicalize
%func2 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func2 {
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.scf.for_loop_canonicalization
transform.apply_patterns.canonicalization
} : !transform.any_op
transform.apply_cse to %func2 : !transform.any_op

// PHASE 5: L1 alloc for fills + intermediate ops
%fills_2 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%fb, %fn = transform.structured.bufferize_to_allocation %fills_2
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op

// Re-match: 2 generics (sq, output) + 1 reduce after tiling.
// NOTE(review): the handle types on the split below are spelled
// `!transform.any_op<"linalg.generic">`, while every other use in this file is
// plain `!transform.any_op` — confirm this spelling is intended and parses.
%generics2 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%tiled_generic1, %tiled_generic2 = transform.split_handle %generics2 : (!transform.any_op<"linalg.generic">) -> (!transform.any_op<"linalg.generic">, !transform.any_op<"linalg.generic">)
%reduces2 = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op

// Promote input tensor to L1
// Operand 0 of the first generic is fetched as a value handle and promoted to
// memory space 2 (the space the surrounding comments call L1).
%op0 = transform.get_operand %tiled_generic1[0] : (!transform.any_op) -> !transform.any_value
transform.structured.promote_tensor to 2 %op0 : !transform.any_value

// L1 alloc for intermediate outputs
%g1b, %g1n = transform.structured.bufferize_to_allocation %tiled_generic1
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
%rb, %rn = transform.structured.bufferize_to_allocation %reduces2
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
%g2b, %g2n = transform.structured.bufferize_to_allocation %tiled_generic2
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op

// PHASE 6: canonicalize
%func5 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func5 {
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.scf.for_loop_canonicalization
transform.apply_patterns.canonicalization
} : !transform.any_op
transform.apply_cse to %func5 : !transform.any_op

// Canonicalize + bufferize (no L1 staging for reduce/generic inputs)
%f2c = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %f2c { transform.apply_patterns.canonicalization } : !transform.any_op
transform.apply_cse to %f2c : !transform.any_op
// PHASE 7: one_shot_bufferize
// NOTE(review): @one_shot_bufferize is not defined in this module; presumably it
// comes from a separately loaded transform library — verify. failures(propagate)
// makes a failure inside the included sequence fail this sequence too.
transform.include @one_shot_bufferize failures(propagate) (%arg1) : (!transform.any_op) -> ()
%f6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %f6 { transform.apply_patterns.canonicalization } : !transform.any_op
transform.apply_cse to %f6 : !transform.any_op
%lc = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%mc = transform.structured.linalg_copy_to_memref %lc : (!transform.any_op) -> !transform.any_op
%fu = transform.air.remove_uninitialized_copy %f6 : (!transform.any_op) -> (!transform.any_op)
%fu2 = transform.air.eliminate_cascade_memcpy %fu : (!transform.any_op) -> (!transform.any_op)

// NOW promote linalg ops inside forall to L1 (BEFORE herd creation)
// This creates memref.copy from L2 subviews to L1 allocs
%forall_op = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%gens_f = transform.structured.match ops{["linalg.generic"]} in %forall_op : (!transform.any_op) -> !transform.any_op
%reds_f = transform.structured.match ops{["linalg.reduce"]} in %forall_op : (!transform.any_op) -> !transform.any_op
%all_linalg_f = transform.merge_handles %reds_f, %gens_f { deduplicate } : !transform.any_op
%promoted = transform.air.linalg_promote %all_linalg_f {memory_space = "L1"} : (!transform.any_op) -> !transform.any_op

// Herd + DMA
// scf.forall -> parallel -> air.herd; then linalg.copy ops inside the herd are
// converted with linalg_copy_to_memref and merged with the existing memref.copy
// ops so a single copy_to_dma call handles all of them.
%fh = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%pa = transform.loop.forall_to_parallel %fh : (!transform.any_op) -> !transform.any_op
%h = transform.air.par_to_herd %pa : (!transform.any_op) -> !transform.any_op
%lc2 = transform.structured.match ops{["linalg.copy"]} in %h : (!transform.any_op) -> !transform.any_op
%mc2 = transform.structured.match ops{["memref.copy"]} in %h : (!transform.any_op) -> !transform.any_op
%mc3 = transform.structured.linalg_copy_to_memref %lc2 : (!transform.any_op) -> !transform.any_op
%ac = transform.merge_handles %mc2, %mc3 { deduplicate } : !transform.any_op
%dm = transform.air.copy_to_dma %ac : (!transform.any_op) -> !transform.any_op

// Re-match the herd since handles may be stale after promote/dma
%h2 = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// Inner vectorization tiling
// tile_sizes [0, 16]: size 0 leaves the first dimension untiled, the second is
// tiled by 16; only one loop handle is returned (the `:1` result).
%gens_h = transform.structured.match ops{["linalg.generic"]} in %h2 : (!transform.any_op) -> !transform.any_op
%inner_g, %inner_gl:1 = transform.structured.tile_using_for %gens_h tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%reds_h = transform.structured.match ops{["linalg.reduce"]} in %h2 : (!transform.any_op) -> !transform.any_op
%inner_r, %inner_rl:1 = transform.structured.tile_using_for %reds_h tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%fills_h = transform.structured.match ops{["linalg.fill"]} in %h2 : (!transform.any_op) -> !transform.any_op
%fill_scl = transform.structured.convert_to_loops %fills_h : (!transform.any_op) -> !transform.any_op
%vh = transform.air.herd_vectorize %h2 : (!transform.any_op) -> !transform.any_op

// Lower vector reductions FIRST (creates arith.mulf/addf from vector.multi_reduction)
// NOTE(review): the apply_patterns region opened below has no closing
// `} : !transform.any_op` before the next op — confirm against the real file.
%func_final = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func_final {
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
transform.apply_patterns.vector.lower_contraction
transform.apply_patterns.vector.lower_transfer

// PHASE 8: canonicalize + remove uninitialized copy
%func6 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func6 {
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.scf.for_loop_canonicalization
transform.apply_patterns.canonicalization
} : !transform.any_op
transform.apply_cse to %func6 : !transform.any_op
transform.apply_patterns to %func6 {
transform.apply_patterns.canonicalization
} : !transform.any_op
transform.apply_cse to %func_final : !transform.any_op
// %func_op_updated is consumed later by convert_divf_sqrt_to_rsqrt in PHASE 9.
%func_op_updated = transform.air.remove_uninitialized_copy %func6 : (!transform.any_op) -> !transform.any_op

// AIE2P type casts AFTER lowering: mulf and addf are bf16-only, divf and rsqrt are f32-only
%vh2 = transform.structured.match ops{["air.herd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%vector_muls = transform.structured.match ops{["arith.mulf"]} in %vh2 : (!transform.any_op) -> !transform.any_op
%mul_cast = transform.air.vector_type_cast %vector_muls {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op
%vector_adds = transform.structured.match ops{["arith.addf"]} in %vh2 : (!transform.any_op) -> !transform.any_op
%add_cast = transform.air.vector_type_cast %vector_adds {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op
%func_s1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%func_s1_done = transform.air.convert_size1_vector_to_scalar %func_s1 : (!transform.any_op) -> !transform.any_op
// NOTE(review): the region opened on %func_s1_done below contains no patterns of
// its own before the PHASE 9 ops begin — it looks like the head of a block whose
// body was replaced; confirm.
transform.apply_patterns to %func_s1_done {
// PHASE 9: generalize remaining linalg.reduce, tile for vectorization, divf-sqrt -> rsqrt
%remaining_reduces = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%generalized = transform.structured.generalize %remaining_reduces : (!transform.any_op) -> !transform.any_op

// All linalg.generic ops (including the just-generalized reduces) are tiled with
// [0, 16]: first dimension untiled, second tiled by 16, one loop handle returned.
%lg = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%inner, %vl:1 = transform.structured.tile_using_for %lg tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

%fou1 = transform.air.convert_divf_sqrt_to_rsqrt %func_op_updated : (!transform.any_op) -> !transform.any_op

// PHASE 10: par_to_herd, copy_to_dma, herd_vectorize, casts
%fa = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%parallel = transform.loop.forall_to_parallel %fa : (!transform.any_op) -> !transform.any_op
%herd = transform.air.par_to_herd %parallel : (!transform.any_op) -> !transform.any_op

// Both copy flavours are matched in one query and handed to copy_to_dma together.
%copies_in_herd = transform.structured.match ops{["memref.copy", "linalg.copy"]} in %herd : (!transform.any_op) -> !transform.any_op
%dmas = transform.air.copy_to_dma %copies_in_herd : (!transform.any_op) -> !transform.any_op

%vh = transform.air.herd_vectorize %herd : (!transform.any_op) -> !transform.any_op

%func4 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func4 {
transform.apply_patterns.canonicalization
transform.apply_patterns.vector.cast_away_vector_leading_one_dim
} : !transform.any_op

// %vh2 here is the result of broadcast_before_unary on the func handle; the two
// matches below search inside it for the ops to cast to bf16.
%vh2 = transform.air.broadcast_before_unary %func4 {op_name = "math.rsqrt"} : (!transform.any_op) -> !transform.any_op

%vector_reductions = transform.structured.match ops{["vector.multi_reduction"]} in %vh2 : (!transform.any_op) -> !transform.any_op
%r1 = transform.air.vector_type_cast %vector_reductions {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op

%vector_muls = transform.structured.match ops{["arith.mulf"]} in %vh2 : (!transform.any_op) -> !transform.any_op
%r2 = transform.air.vector_type_cast %vector_muls {target_element_type = bf16} : (!transform.any_op) -> !transform.any_op

// Final cleanup: size-1 vectors are converted to scalars, then canonicalization
// and the multi_reduction lowering patterns (strategy "innerreduction") are
// applied to the resulting handle.
%func7 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%func7t = transform.air.convert_size1_vector_to_scalar %func7 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func7t {
transform.apply_patterns.linalg.tiling_canonicalization
transform.apply_patterns.scf.for_loop_canonicalization
transform.apply_patterns.canonicalization
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
} : !transform.any_op
transform.apply_cse to %func_s1_done : !transform.any_op
transform.apply_cse to %func7t : !transform.any_op

transform.yield
}
}
Loading