Skip to content

Commit c99351f

Browse files
committed
adding benchmark function through pattern rewrite schedule
1 parent 3cbf180 commit c99351f

File tree

10 files changed

+186
-84
lines changed

10 files changed

+186
-84
lines changed

examples/feed-forward-mpi/feed-forward-mpi.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
import ctypes
1414
from contextlib import contextmanager
1515
from typing import Optional
16-
1716
import numpy as np
17+
1818
from mlir import ir
1919
from mlir.dialects import transform
2020
from mlir.dialects.transform.bufferization import OneShotBufferizeOp
@@ -25,12 +25,13 @@
2525
make_nd_memref_descriptor,
2626
as_ctype,
2727
)
28+
2829
from lighthouse.utils.memref import (
2930
to_ctype as memref_to_ctype,
3031
deallocate_memrefs_on_exit,
3132
)
3233
from lighthouse.pipeline.helper import apply_registered_pass, match
33-
from lighthouse.workload import Workload, benchmark
34+
from lighthouse.workload import Workload, benchmark, get_bench_wrapper_schedule
3435
from lighthouse.schedule.x86 import tile_and_vector_matmul
3536
from ff_weight_stationary import generate_ff_payload
3637

@@ -103,8 +104,6 @@ class DistFF(Workload):
103104
where A, B, C are (M,K), (K,N), (N,K) matrices respectively.
104105
"""
105106

106-
payload_function_name = "payload"
107-
108107
def __init__(self, args, P: int, R: int):
109108
self.M = args.sizes[0]
110109
self.N = args.sizes[1]
@@ -308,8 +307,9 @@ def schedule_modules(
308307
transform.YieldOp()
309308
func = None
310309

310+
bench_schedule = get_bench_wrapper_schedule(self)
311+
311312
tile_schedule = tile_and_vector_matmul.create()
312-
tile_schedule.body.operations[0].verify()
313313

314314
main_schedule = ir.Module.create()
315315
main_schedule.operation.attributes["transform.with_named_sequence"] = (
@@ -377,7 +377,7 @@ def schedule_modules(
377377
transform.PrintOp(target=mod)
378378
transform.YieldOp()
379379

380-
return [pre_schedule, tile_schedule, main_schedule]
380+
return [pre_schedule, bench_schedule, tile_schedule, main_schedule]
381381

382382

383383
if __name__ == "__main__":
@@ -394,7 +394,7 @@ def schedule_modules(
394394
# execute(wload, verbose=args.verbose)
395395
rprint(" Benchmark".center(60, "-"))
396396
times = benchmark(
397-
wload, nruns=10, nwarmup=1, check_correctness=True, verbose=args.verbose
397+
wload, nruns=100, nwarmup=10, check_correctness=True, verbose=args.verbose
398398
)
399399
# compute statistics
400400
times *= 1e6

examples/workload/example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from mlir.execution_engine import ExecutionEngine
2020

2121
from lighthouse.pipeline.helper import apply_registered_pass, canonicalize, match
22-
from lighthouse.workload import Workload, execute, benchmark
22+
from lighthouse.workload import Workload, execute, benchmark, get_bench_wrapper_schedule
2323

2424

2525
class ElementwiseSum(Workload):
@@ -158,7 +158,7 @@ def schedule_modules(
158158
mod = apply_registered_pass(mod, "reconcile-unrealized-casts")
159159
transform.YieldOp()
160160

161-
return [schedule_module]
161+
return [get_bench_wrapper_schedule(self), schedule_module]
162162

163163

164164
if __name__ == "__main__":

examples/xegpu/matmul.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from mlir import ir
1818
from mlir.execution_engine import ExecutionEngine
1919

20-
from lighthouse.workload import benchmark
20+
from lighthouse.workload import benchmark, get_bench_wrapper_schedule
2121
from lighthouse.utils.memref import to_ctype as memref_to_ctype
2222
from lighthouse.utils.numpy import numpy_to_ctype
2323
from lighthouse.schedule.xegpu.mlp_schedule import get_schedule_module
@@ -195,14 +195,15 @@ def schedule_modules(
195195
self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None
196196
) -> list[ir.Module]:
197197
return [
198+
get_bench_wrapper_schedule(self),
198199
get_schedule_module(
199200
has_bias=self.has_bias,
200201
has_relu=self.has_relu,
201202
has_convert_c=False,
202203
stop_at_stage=stop_at_stage,
203204
nlayers=1,
204205
params={"layer_0": parameters},
205-
)
206+
),
206207
]
207208

208209
def shared_libs(self) -> list[str]:

examples/xegpu/mlp.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from mlir import ir
2222
from mlir.execution_engine import ExecutionEngine
2323

24-
from lighthouse.workload import benchmark
24+
from lighthouse.workload import benchmark, get_bench_wrapper_schedule
2525
from lighthouse.utils.memref import to_ctype as memref_to_ctype
2626
from lighthouse.utils.numpy import numpy_to_ctype
2727
from lighthouse.schedule.xegpu.mlp_schedule import get_schedule_module
@@ -256,14 +256,15 @@ def schedule_modules(
256256
self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None
257257
) -> list[ir.Module]:
258258
return [
259+
get_bench_wrapper_schedule(self),
259260
get_schedule_module(
260261
has_bias=self.has_bias,
261262
has_relu=self.has_relu,
262263
skip_final_layer_relu=True,
263264
stop_at_stage=stop_at_stage,
264265
nlayers=self.nlayers,
265266
params=parameters,
266-
)
267+
),
267268
]
268269

269270
def shared_libs(self) -> list[str]:
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from contextlib import contextmanager
2+
from mlir import rewrite, ir
3+
from mlir.dialects import ext, transform
4+
from mlir.dialects.transform import AnyOpType
5+
6+
7+
@ext.register_dialect
8+
class PatternDialect(ext.Dialect, name="lighthouse"):
9+
pass
10+
11+
12+
def rewrite_pattern(patterns: dict, pname: str):
13+
"""Return a rewrite pattern class that can be registered with MLIR.
14+
The patterns dict should map op names to their corresponding match and rewrite functions."""
15+
16+
@ext.register_operation(PatternDialect, replace=True)
17+
class RewritePattern(PatternDialect.Operation, name=pname):
18+
@classmethod
19+
def attach_interface_impls(cls, ctx=None):
20+
cls.PatternDescriptorOpInterfaceFallbackModel.attach(
21+
cls.OPERATION_NAME, context=ctx
22+
)
23+
24+
class PatternDescriptorOpInterfaceFallbackModel(
25+
transform.PatternDescriptorOpInterface
26+
):
27+
@staticmethod
28+
def populate_patterns(
29+
op: "RewritePattern",
30+
patternset: rewrite.RewritePatternSet,
31+
) -> None:
32+
for op_name, match_and_rewrite in patterns.items():
33+
patternset.add(op_name, match_and_rewrite, benefit=1)
34+
35+
return RewritePattern
36+
37+
38+
@contextmanager
39+
def schedule_boilerplate():
40+
schedule = ir.Module.create()
41+
schedule.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
42+
with ir.InsertionPoint(schedule.body):
43+
named_sequence = transform.NamedSequenceOp(
44+
"__transform_main",
45+
[AnyOpType.get()],
46+
[AnyOpType.get()],
47+
arg_attrs=[{"transform.consumed": ir.UnitAttr.get()}],
48+
)
49+
with ir.InsertionPoint(named_sequence.body):
50+
yield schedule, named_sequence
51+
52+
53+
def pattern_rewrite_schedule(patterns: dict, pname: str = "rewrite_pattern"):
54+
"""Return a transform module that applies the given rewrite patterns.
55+
patterns: dict mapping op names to match-and-rewrite functions.
56+
pname: name for the generated rewrite pattern operation."""
57+
58+
rw_pattern = rewrite_pattern(patterns, pname)
59+
PatternDialect.load(register=False, reload=False)
60+
rw_pattern.attach_interface_impls()
61+
62+
with schedule_boilerplate() as (schedule, named_seq):
63+
apply_patterns_op = transform.ApplyPatternsOp(named_seq.bodyTarget)
64+
with ir.InsertionPoint(apply_patterns_op.patterns):
65+
rw_pattern()
66+
transform.yield_([named_seq.bodyTarget])
67+
named_seq.verify()
68+
69+
schedule.body.operations[0].verify()
70+
return schedule

lighthouse/schedule/x86/tile_and_vector_matmul.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,4 +145,6 @@ def create(tile_size=64) -> ir.Module:
145145
cleanup_func(named_seq.bodyTarget)
146146

147147
transform.yield_()
148+
149+
schedule.body.operations[0].verify()
148150
return schedule

lighthouse/utils/memref.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,17 @@ def get_packed_arg(
5050
return packed_args
5151

5252

53-
def to_packed_args(memref_descs) -> ctypes.Array[ctypes.c_void_p]:
53+
def to_packed_args(args) -> ctypes.Array[ctypes.c_void_p]:
5454
"""
55-
Convert a list of memref descriptors into packed ctype arguments.
55+
Convert a list of memref descriptors and/or integers into packed ctype arguments.
5656
5757
Args:
58-
memref_descs: A list of memref descriptors.
58+
args: A list of memref descriptors or integers.
5959
"""
60-
ctype_args = [to_ctype(memref) for memref in memref_descs]
60+
ctype_args = []
61+
for arg in args:
62+
if isinstance(arg, int):
63+
ctype_args.append(ctypes.pointer(ctypes.c_int64(arg)))
64+
else:
65+
ctype_args.append(to_ctype(arg))
6166
return get_packed_arg(ctype_args)

lighthouse/workload/__init__.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,15 @@
11
from .workload import Workload
2-
from .runner import execute, benchmark
2+
from .runner import (
3+
execute,
4+
benchmark,
5+
bench_wrapper_pattern,
6+
get_bench_wrapper_schedule,
7+
)
38

4-
__all__ = ["Workload", "benchmark", "execute"]
9+
__all__ = [
10+
"Workload",
11+
"bench_wrapper_pattern",
12+
"benchmark",
13+
"execute",
14+
"get_bench_wrapper_schedule",
15+
]

0 commit comments

Comments
 (0)