Skip to content

Commit 3cbf180

Browse files
committed
using pipeline helper
1 parent 7b68e85 commit 3cbf180

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

lighthouse/pipeline/helper.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,9 @@ def match(*args, **kwargs):
5050
def canonicalize(op):
5151
with ir.InsertionPoint(transform.apply_patterns(op).patterns):
5252
transform.apply_patterns_canonicalization()
53+
54+
55+
def cleanup_func(target):
56+
func = structured.MatchOp.match_op_names(target, ["func.func"]).result
57+
transform.apply_cse(func)
58+
canonicalize(func)

lighthouse/schedule/x86/tile_and_vector_matmul.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,7 @@
55
from mlir.dialects.transform import tensor
66
from mlir.dialects.transform import vector
77
from mlir.dialects.transform import x86
8-
9-
10-
def _cleanup(target):
11-
func = structured.MatchOp.match_op_names(target, ["func.func"]).result
12-
transform.apply_cse(func)
13-
with ir.InsertionPoint(transform.ApplyPatternsOp(func).patterns):
14-
transform.apply_patterns_canonicalization()
8+
from lighthouse.pipeline.helper import cleanup_func
159

1610

1711
def create(tile_size=64) -> ir.Module:
@@ -57,7 +51,7 @@ def create(tile_size=64) -> ir.Module:
5751
):
5852
structured.apply_patterns_linalg_fold_unit_extent_dims_via_slices()
5953
structured.apply_patterns_linalg_fold_pack_unpack_into_empty()
60-
_cleanup(named_seq.bodyTarget)
54+
cleanup_func(named_seq.bodyTarget)
6155

6256
# Register tiling.
6357
reg_tile_m = 8
@@ -87,7 +81,7 @@ def create(tile_size=64) -> ir.Module:
8781
peel_front=False,
8882
fail_if_already_divisible=False,
8983
)
90-
_cleanup(named_seq.bodyTarget)
84+
cleanup_func(named_seq.bodyTarget)
9185

9286
# Register unroll.
9387
gemms = structured.MatchOp.match_op_names(named_seq.bodyTarget, [gemm_name])
@@ -100,7 +94,7 @@ def create(tile_size=64) -> ir.Module:
10094
loop.loop_unroll(loops[2], reg_tile_k)
10195
loop.loop_unroll(loops[0], reg_tile_m)
10296
transform.yield_()
103-
_cleanup(named_seq.bodyTarget)
97+
cleanup_func(named_seq.bodyTarget)
10498

10599
# Vectorize operations.
106100
gemms = structured.MatchOp.match_op_names(named_seq.bodyTarget, [gemm_name])
@@ -115,7 +109,7 @@ def create(tile_size=64) -> ir.Module:
115109
):
116110
vector.apply_patterns_vector_reduction_to_contract()
117111
vector.apply_patterns_vector_transfer_permutation_patterns()
118-
_cleanup(named_seq.bodyTarget)
112+
cleanup_func(named_seq.bodyTarget)
119113

120114
# Loop hoisting.
121115
all_loops = structured.MatchOp(
@@ -148,7 +142,7 @@ def create(tile_size=64) -> ir.Module:
148142
lowering_strategy=vector.VectorContractLowering.OuterProduct
149143
)
150144
vector.apply_patterns_vector_lower_outerproduct()
151-
_cleanup(named_seq.bodyTarget)
145+
cleanup_func(named_seq.bodyTarget)
152146

153147
transform.yield_()
154148
return schedule

0 commit comments

Comments
 (0)