55from mlir .dialects .transform import tensor
66from mlir .dialects .transform import vector
77from mlir .dialects .transform import x86
8-
9-
10- def _cleanup (target ):
11- func = structured .MatchOp .match_op_names (target , ["func.func" ]).result
12- transform .apply_cse (func )
13- with ir .InsertionPoint (transform .ApplyPatternsOp (func ).patterns ):
14- transform .apply_patterns_canonicalization ()
8+ from lighthouse .pipeline .helper import cleanup_func
159
1610
1711def create (tile_size = 64 ) -> ir .Module :
@@ -57,7 +51,7 @@ def create(tile_size=64) -> ir.Module:
5751 ):
5852 structured .apply_patterns_linalg_fold_unit_extent_dims_via_slices ()
5953 structured .apply_patterns_linalg_fold_pack_unpack_into_empty ()
60- _cleanup (named_seq .bodyTarget )
54+ cleanup_func (named_seq .bodyTarget )
6155
6256 # Register tiling.
6357 reg_tile_m = 8
@@ -87,7 +81,7 @@ def create(tile_size=64) -> ir.Module:
8781 peel_front = False ,
8882 fail_if_already_divisible = False ,
8983 )
90- _cleanup (named_seq .bodyTarget )
84+ cleanup_func (named_seq .bodyTarget )
9185
9286 # Register unroll.
9387 gemms = structured .MatchOp .match_op_names (named_seq .bodyTarget , [gemm_name ])
@@ -100,7 +94,7 @@ def create(tile_size=64) -> ir.Module:
10094 loop .loop_unroll (loops [2 ], reg_tile_k )
10195 loop .loop_unroll (loops [0 ], reg_tile_m )
10296 transform .yield_ ()
103- _cleanup (named_seq .bodyTarget )
97+ cleanup_func (named_seq .bodyTarget )
10498
10599 # Vectorize operations.
106100 gemms = structured .MatchOp .match_op_names (named_seq .bodyTarget , [gemm_name ])
@@ -115,7 +109,7 @@ def create(tile_size=64) -> ir.Module:
115109 ):
116110 vector .apply_patterns_vector_reduction_to_contract ()
117111 vector .apply_patterns_vector_transfer_permutation_patterns ()
118- _cleanup (named_seq .bodyTarget )
112+ cleanup_func (named_seq .bodyTarget )
119113
120114 # Loop hoisting.
121115 all_loops = structured .MatchOp (
@@ -148,7 +142,7 @@ def create(tile_size=64) -> ir.Module:
148142 lowering_strategy = vector .VectorContractLowering .OuterProduct
149143 )
150144 vector .apply_patterns_vector_lower_outerproduct ()
151- _cleanup (named_seq .bodyTarget )
145+ cleanup_func (named_seq .bodyTarget )
152146
153147 transform .yield_ ()
154148 return schedule
0 commit comments