Skip to content

Commit 6cbd477

Browse files
committed
Infrastructure for auto-tuning
1 parent 8922c63 commit 6cbd477

File tree

10 files changed

+732
-56
lines changed

10 files changed

+732
-56
lines changed

examples/xegpu_matmul/matmul.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s
1+
# RUN: %PYTHON %s --dump-payload=xegpu-wg | FileCheck %s
22
# CHECK: module attributes {gpu.container_module} {
33

44
"""
@@ -315,7 +315,7 @@ def parse_cli():
315315
help="Check the result of the matrix multiplication.",
316316
)
317317
parser.add_argument(
318-
"--dump-kernel",
318+
"--dump-payload",
319319
type=str,
320320
choices=[
321321
"initial",
@@ -328,13 +328,18 @@ def parse_cli():
328328
"xegpu-inst",
329329
"final",
330330
],
331-
help="Dump kernel IR at different stages of lowering.",
331+
help="Dump payload IR at different stages of lowering.",
332332
)
333333
parser.add_argument(
334334
"--dump-schedule",
335335
action="store_true",
336336
help="Dump transform schedule.",
337337
)
338+
parser.add_argument(
339+
"--non-det",
340+
action="store_true",
341+
help="Generate schedule with knob values left non-determined.",
342+
)
338343
args = parser.parse_args()
339344

340345
return args
@@ -344,21 +349,23 @@ def parse_cli():
344349
args = parse_cli()
345350

346351
params = {
347-
"auto_wg_d0": args.wg_tile[0],
348-
"auto_wg_d1": args.wg_tile[1],
349-
"auto_sg_d0": args.sg_tile[0],
350-
"auto_sg_d1": args.sg_tile[1],
351-
"auto_k": args.k_tile,
352-
"auto_load_a_d0": args.load_tile_a[0],
353-
"auto_load_a_d1": args.load_tile_a[1],
354-
"auto_load_b_d0": args.load_tile_b[0],
355-
"auto_load_b_d1": args.load_tile_b[1],
356-
"auto_prefetch_a_d0": args.prefetch_tile_a[0],
357-
"auto_prefetch_a_d1": args.prefetch_tile_a[1],
358-
"auto_prefetch_b_d0": args.prefetch_tile_b[0],
359-
"auto_prefetch_b_d1": args.prefetch_tile_b[1],
360-
"auto_nb_prefetch": args.nb_prefetch,
352+
"wg_d0": args.wg_tile[0],
353+
"wg_d1": args.wg_tile[1],
354+
"sg_d0": args.sg_tile[0],
355+
"sg_d1": args.sg_tile[1],
356+
"k_tile": args.k_tile,
357+
"load_a_d0": args.load_tile_a[0],
358+
"load_a_d1": args.load_tile_a[1],
359+
"load_b_d0": args.load_tile_b[0],
360+
"load_b_d1": args.load_tile_b[1],
361+
"prefetch_a_d0": args.prefetch_tile_a[0],
362+
"prefetch_a_d1": args.prefetch_tile_a[1],
363+
"prefetch_b_d0": args.prefetch_tile_b[0],
364+
"prefetch_b_d1": args.prefetch_tile_b[1],
365+
"nb_prefetch": args.nb_prefetch,
361366
}
367+
if args.non_det:
368+
params = {}
362369

363370
M, N, K = args.sizes
364371
ab_type = "f16"
@@ -375,9 +382,14 @@ def parse_cli():
375382
has_relu=args.relu,
376383
)
377384

378-
if args.dump_kernel or args.dump_schedule:
385+
# Dump mode: either print the transform schedule itself, or lower the payload
# and dump its IR at the requested stage.
if args.dump_schedule:
    schedule_module = wload.schedule_module(
        stop_at_stage=args.dump_payload, parameters=params
    )
    print(schedule_module)
# The CLI option is `--dump-payload`, which argparse stores as
# `args.dump_payload`; the namespace has no `dump_kernel` attribute anymore,
# so testing it would raise AttributeError.
elif args.dump_payload:
    wload.lower_payload(
        dump_payload=args.dump_payload,
        dump_schedule=args.dump_schedule,
        schedule_parameters=params,
    )

examples/xegpu_matmul/schedule.py

Lines changed: 165 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
1+
import inspect
2+
from typing import Optional, Annotated
3+
14
from mlir import ir
25
from mlir.dialects.transform import loop
36
from mlir.dialects.transform import bufferization
47
from mlir.dialects.transform import xegpu
58
from mlir.dialects.bufferization import LayoutMapOption
6-
from mlir.dialects import transform
7-
from mlir.dialects.transform import structured
8-
from lighthouse.utils.mlir import (
9-
apply_registered_pass,
10-
canonicalize,
11-
match,
9+
from mlir.dialects import transform, smt
10+
from mlir.dialects.transform import (
11+
structured,
12+
tune as transform_tune,
13+
smt as transform_smt,
14+
)
15+
from lighthouse.utils.mlir import apply_registered_pass, canonicalize, match
16+
from lighthouse.tune.annotate import (
17+
check_annotated_constraints,
18+
NonDet,
19+
ConstraintCollector,
1220
)
13-
from typing import Optional
1421

1522

1623
class PipelineInterrupt(Exception):
@@ -76,7 +83,7 @@ def xegpu_matmul_transform_schedule(
7683
has_bias=has_bias,
7784
has_relu=has_relu,
7885
stop_at_stage=stop_at_stage,
79-
params=params,
86+
**params,
8087
)
8188

8289
mod = bundle_xegpu_to_binary(
@@ -89,45 +96,166 @@ def xegpu_matmul_transform_schedule(
8996
transform.yield_()
9097

9198

99+
@check_annotated_constraints
92100
def bundle_xepu_matmul_schedule(
93101
mod,
94102
has_bias: bool = False,
95103
has_relu: bool = False,
96104
stop_at_stage: str = "",
97-
params: Optional[dict] = None,
105+
*,
106+
wg_d0: Annotated[int, lambda _: 128 <= _ <= 256 and _ % 32 == 0] = NonDet,
107+
wg_d1: Annotated[int, lambda _: 128 <= _ <= 256 and _ % 32 == 0] = NonDet,
108+
sg_d0: Annotated[int, lambda _: 16 <= _ <= 32 and _ % 8 == 0] = NonDet,
109+
sg_d1: Annotated[int, lambda _: 16 <= _ <= 32 and _ % 8 == 0] = NonDet,
110+
k_tile: Annotated[int, lambda _: 8 <= _ <= 32 and _ % 8 == 0] = NonDet,
111+
load_a_d0: Annotated[int, lambda _: 8 <= _ <= 32 and _ % 8 == 0] = NonDet,
112+
load_a_d1: Annotated[int, lambda _: 8 <= _ <= 32 and _ % 8 == 0] = NonDet,
113+
load_b_d0: Annotated[int, lambda _: 8 <= _ <= 32 and _ % 8 == 0] = NonDet,
114+
load_b_d1: Annotated[int, lambda _: 8 <= _ <= 32 and _ % 8 == 0] = NonDet,
115+
prefetch_a_d0: Annotated[int, lambda _: 4 <= _ <= 8] = NonDet,
116+
prefetch_a_d1: Annotated[int, lambda _: 16 <= _ <= 32] = NonDet,
117+
prefetch_b_d0: Annotated[int, lambda _: 4 <= _ <= 8] = NonDet,
118+
prefetch_b_d1: Annotated[int, lambda _: 8 <= _ <= 16] = NonDet,
119+
nb_prefetch: Annotated[int, lambda _: 1 <= _ <= 32] = NonDet,
120+
**_kwargs: Optional[dict],
98121
) -> ir.Module:
99122
"""Schedule for lowering matmul-like payload to xegpu wg level."""
100-
if params is None:
101-
raise ValueError("Schedule parameters must be provided.")
102-
103-
# tunable parameters
104-
wg_tile = [params["auto_wg_d0"], params["auto_wg_d1"]]
105-
sg_tile = [params["auto_sg_d0"], params["auto_sg_d1"]]
106-
k_tile = params["auto_k"]
107-
108-
load_tile_a = [params["auto_load_a_d0"], params["auto_load_a_d1"]]
109-
load_tile_b = [params["auto_load_b_d0"], params["auto_load_b_d1"]]
110-
111-
prefetch_tile_a = [params["auto_prefetch_a_d0"], params["auto_prefetch_a_d1"]]
112-
prefetch_tile_b = [params["auto_prefetch_b_d0"], params["auto_prefetch_b_d1"]]
113-
nb_prefetch = params["auto_nb_prefetch"]
114-
115-
# derived parameters
116-
sg_layout = [wg_tile[0] // sg_tile[0], wg_tile[1] // sg_tile[1]]
117-
# number of threads collapsed to 1d layout
118-
nb_threads = sg_layout[0] * sg_layout[1] * nb_workitems
119-
prefetch_layout_a = [
120-
wg_tile[0] // prefetch_tile_a[0],
121-
k_tile // prefetch_tile_a[1],
122-
]
123-
prefetch_layout_b = [
124-
k_tile // prefetch_tile_b[0],
125-
wg_tile[1] // prefetch_tile_b[1],
123+
124+
sig = inspect.signature(bundle_xepu_matmul_schedule)
125+
126+
any_param = transform.AnyParamType.get()
127+
128+
# Switch to knob/constraint emission as soon as ANY tunable parameter is left
# undetermined. Every tunable keyword must appear exactly once: the constant
# (non-knob) path below does integer arithmetic such as `wg_d0 // sg_d0`, so
# an unlisted parameter left as NonDet would reach that arithmetic.
use_knobs = NonDet in [
    wg_d0,
    wg_d1,
    sg_d0,
    sg_d1,
    k_tile,
    load_a_d0,
    load_a_d1,
    load_b_d0,
    load_b_d1,
    prefetch_a_d0,
    prefetch_a_d1,
    prefetch_b_d0,
    prefetch_b_d1,
    nb_prefetch,
]
127146

147+
def as_const_or_as_knob(value, knob_name):
    """Return `value` as-is, or a tuning knob op standing in for it.

    The first `Annotated` metadata entry of the enclosing schedule function's
    parameter named `knob_name` is always invoked on a fresh
    `ConstraintCollector`. When `use_knobs` is false the given `value` is
    returned unchanged; otherwise a `transform_tune.knob` is created whose
    options come from the collector and whose selection is `value` (or None
    when `value` is `NonDet`).
    """
    collector = ConstraintCollector()
    # presumably records the admissible values for this knob; verify against
    # lighthouse.tune.annotate.ConstraintCollector
    constraint = sig.parameters[knob_name].annotation.__metadata__[0]
    constraint(collector)

    if not use_knobs:
        return value

    selected = None if value is NonDet else value
    return transform_tune.knob(
        any_param,
        name=knob_name,
        options=collector.to_mlir(),
        selected=selected,
    )
158+
159+
wg_d0 = as_const_or_as_knob(wg_d0, "wg_d0")
160+
wg_d1 = as_const_or_as_knob(wg_d1, "wg_d1")
161+
wg_tile = [wg_d0, wg_d1]
162+
sg_d0 = as_const_or_as_knob(sg_d0, "sg_d0")
163+
sg_d1 = as_const_or_as_knob(sg_d1, "sg_d1")
164+
sg_tile = [sg_d0, sg_d1]
165+
166+
smt_int = smt.IntType.get()
167+
c0 = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 0)
168+
c_nb_workitems = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), nb_workitems)
169+
170+
if use_knobs:
171+
constraint1 = transform_smt.constrain_params(
172+
(any_param, any_param, any_param),
173+
(
174+
wg_d0,
175+
wg_d1,
176+
sg_d0,
177+
sg_d1,
178+
),
179+
[smt_int] * 4,
180+
)
181+
with ir.InsertionPoint(constraint1.body):
182+
WGd0, WGd1, SGd0, SGd1 = constraint1.body.arguments
183+
C0 = smt.int_constant(c0)
184+
smt.assert_(smt.eq((smt.int_mod(WGd0, SGd0), C0)))
185+
smt.assert_(smt.eq((smt.int_mod(WGd1, SGd1), C0)))
186+
d0_step_smt = smt.int_div(WGd0, SGd0)
187+
d1_step_smt = smt.int_div(WGd1, SGd1)
188+
nb_threads_smt = smt.int_mul(
189+
(d0_step_smt, d1_step_smt, smt.int_constant(c_nb_workitems))
190+
)
191+
smt.yield_((d0_step_smt, d1_step_smt, nb_threads_smt))
192+
d0_step, d1_step, nb_threads = constraint1.results
193+
sg_layout = [d0_step, d1_step]
194+
else:
195+
# derived parameters
196+
sg_layout = [wg_d0 // sg_d0, wg_d1 // sg_d1]
197+
# number of threads collapsed to 1d layout
198+
nb_threads = sg_layout[0] * sg_layout[1] * nb_workitems
199+
200+
prefetch_a_d0 = as_const_or_as_knob(prefetch_a_d0, "prefetch_a_d0")
201+
prefetch_a_d1 = as_const_or_as_knob(prefetch_a_d1, "prefetch_a_d1")
202+
prefetch_tile_a = [prefetch_a_d0, prefetch_a_d1]
203+
prefetch_b_d0 = as_const_or_as_knob(prefetch_b_d0, "prefetch_b_d0")
204+
prefetch_b_d1 = as_const_or_as_knob(prefetch_b_d1, "prefetch_b_d1")
205+
prefetch_tile_b = [prefetch_b_d0, prefetch_b_d1]
206+
k_tile = as_const_or_as_knob(k_tile, "k_tile")
207+
208+
if use_knobs:
209+
constraint2 = transform_smt.constrain_params(
210+
(any_param, any_param, any_param, any_param),
211+
(
212+
wg_d0,
213+
wg_d1,
214+
k_tile,
215+
prefetch_a_d0,
216+
prefetch_a_d1,
217+
prefetch_b_d0,
218+
prefetch_b_d1,
219+
),
220+
[smt_int] * 7,
221+
)
222+
with ir.InsertionPoint(constraint2.body):
223+
WGd0, WGd1, K, PFAd0, PFAd1, PFBd0, PFBd1 = constraint2.body.arguments
224+
C0 = smt.int_constant(c0)
225+
smt.assert_(smt.eq((smt.int_mod(WGd0, PFAd0), C0)))
226+
smt.assert_(smt.eq((smt.int_mod(K, PFAd1), C0)))
227+
PFAd0_step = smt.int_div(WGd0, PFAd0)
228+
PFAd1_step = smt.int_div(K, PFAd1)
229+
230+
smt.assert_(smt.eq((smt.int_mod(K, PFBd0), C0)))
231+
smt.assert_(smt.eq((smt.int_mod(WGd1, PFBd1), C0)))
232+
PFBd0_step = smt.int_div(K, PFBd0)
233+
PFBd1_step = smt.int_div(WGd1, PFBd1)
234+
235+
smt.yield_((PFAd0_step, PFAd1_step, PFBd0_step, PFBd1_step))
236+
prefetch_layout_a = constraint2.results[0:2]
237+
prefetch_layout_b = constraint2.results[2:4]
238+
else:
239+
prefetch_layout_a = [
240+
wg_d0 // prefetch_a_d0,
241+
k_tile // prefetch_a_d1,
242+
]
243+
prefetch_layout_b = [
244+
k_tile // prefetch_b_d0,
245+
wg_d1 // prefetch_b_d1,
246+
]
247+
128248
# matmul matrix shapes
129-
sg_tile_a = [sg_tile[0], k_tile]
130-
sg_tile_b = [k_tile, sg_tile[1]]
249+
sg_tile_a = [sg_d0, k_tile]
250+
sg_tile_b = [k_tile, sg_d1]
251+
252+
load_a_d0 = as_const_or_as_knob(load_a_d0, "load_a_d0")
253+
load_a_d1 = as_const_or_as_knob(load_a_d1, "load_a_d1")
254+
load_b_d0 = as_const_or_as_knob(load_b_d0, "load_b_d0")
255+
load_b_d1 = as_const_or_as_knob(load_b_d1, "load_b_d1")
256+
257+
load_tile_a = [load_a_d0, load_a_d1]
258+
load_tile_b = [load_b_d0, load_b_d1]
131259

132260
if stop_at_stage == "initial":
133261
raise PipelineInterrupt()

lighthouse/tune/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
__all__ = ["smt", "rewrite"]
2+
3+
import sys
4+
import importlib
5+
6+
7+
def __getattr__(name):
    """Enable lazy loading of submodules.

    Enables `import lighthouse.tune as lh_tune; lh_tune.<submodule>` with
    loading of (the submodule's heavy) dependencies only upon being needed.

    Args:
        name: Attribute requested on this package. As a module-level
            `__getattr__` hook, this is only consulted when normal attribute
            lookup on the module fails.

    Returns:
        The imported submodule, when `name` is listed in `__all__`.

    Raises:
        AttributeError: If `name` is not one of the names in `__all__`.
    """
    if name not in __all__:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

    # Build the dotted path from this package's own name rather than a
    # hard-coded string, so it stays in sync with `sys.modules[__name__]`
    # below and with the error message above.
    submodule = importlib.import_module(f"{__name__}.{name}")

    # Cache the submodule on the current module. That is, upon the next
    # access __getattr__ will not be called.
    setattr(sys.modules[__name__], name, submodule)
    return submodule

0 commit comments

Comments
 (0)