1+ import inspect
2+ from typing import Optional , Annotated
3+
14from mlir import ir
25from mlir .dialects .transform import loop
36from mlir .dialects .transform import bufferization
47from mlir .dialects .transform import xegpu
58from mlir .dialects .bufferization import LayoutMapOption
6- from mlir .dialects import transform
7- from mlir .dialects .transform import structured
8- from lighthouse .utils .mlir import (
9- apply_registered_pass ,
10- canonicalize ,
11- match ,
9+ from mlir .dialects import transform , smt
10+ from mlir .dialects .transform import (
11+ structured ,
12+ tune as transform_tune ,
13+ smt as transform_smt ,
14+ )
15+ from lighthouse .utils .mlir import apply_registered_pass , canonicalize , match
16+ from lighthouse .tune .annotate import (
17+ check_annotated_constraints ,
18+ NonDet ,
19+ ConstraintCollector ,
1220)
13- from typing import Optional
1421
1522
1623class PipelineInterrupt (Exception ):
@@ -76,7 +83,7 @@ def xegpu_matmul_transform_schedule(
7683 has_bias = has_bias ,
7784 has_relu = has_relu ,
7885 stop_at_stage = stop_at_stage ,
79- params = params ,
86+ ** params ,
8087 )
8188
8289 mod = bundle_xegpu_to_binary (
@@ -89,45 +96,166 @@ def xegpu_matmul_transform_schedule(
8996 transform .yield_ ()
9097
9198
99+ @check_annotated_constraints
92100def bundle_xepu_matmul_schedule (
93101 mod ,
94102 has_bias : bool = False ,
95103 has_relu : bool = False ,
96104 stop_at_stage : str = "" ,
97- params : Optional [dict ] = None ,
105+ * ,
106+ wg_d0 : Annotated [int , lambda _ : 128 <= _ <= 256 and _ % 32 == 0 ] = NonDet ,
107+ wg_d1 : Annotated [int , lambda _ : 128 <= _ <= 256 and _ % 32 == 0 ] = NonDet ,
108+ sg_d0 : Annotated [int , lambda _ : 16 <= _ <= 32 and _ % 8 == 0 ] = NonDet ,
109+ sg_d1 : Annotated [int , lambda _ : 16 <= _ <= 32 and _ % 8 == 0 ] = NonDet ,
110+ k_tile : Annotated [int , lambda _ : 8 <= _ <= 32 and _ % 8 == 0 ] = NonDet ,
111+ load_a_d0 : Annotated [int , lambda _ : 8 <= _ <= 32 and _ % 8 == 0 ] = NonDet ,
112+ load_a_d1 : Annotated [int , lambda _ : 8 <= _ <= 32 and _ % 8 == 0 ] = NonDet ,
113+ load_b_d0 : Annotated [int , lambda _ : 8 <= _ <= 32 and _ % 8 == 0 ] = NonDet ,
114+ load_b_d1 : Annotated [int , lambda _ : 8 <= _ <= 32 and _ % 8 == 0 ] = NonDet ,
115+ prefetch_a_d0 : Annotated [int , lambda _ : 4 <= _ <= 8 ] = NonDet ,
116+ prefetch_a_d1 : Annotated [int , lambda _ : 16 <= _ <= 32 ] = NonDet ,
117+ prefetch_b_d0 : Annotated [int , lambda _ : 4 <= _ <= 8 ] = NonDet ,
118+ prefetch_b_d1 : Annotated [int , lambda _ : 8 <= _ <= 16 ] = NonDet ,
119+ nb_prefetch : Annotated [int , lambda _ : 1 <= _ <= 32 ] = NonDet ,
120+ ** _kwargs : Optional [dict ],
98121) -> ir .Module :
99122 """Schedule for lowering matmul-like payload to xegpu wg level."""
100- if params is None :
101- raise ValueError ("Schedule parameters must be provided." )
102-
103- # tunable parameters
104- wg_tile = [params ["auto_wg_d0" ], params ["auto_wg_d1" ]]
105- sg_tile = [params ["auto_sg_d0" ], params ["auto_sg_d1" ]]
106- k_tile = params ["auto_k" ]
107-
108- load_tile_a = [params ["auto_load_a_d0" ], params ["auto_load_a_d1" ]]
109- load_tile_b = [params ["auto_load_b_d0" ], params ["auto_load_b_d1" ]]
110-
111- prefetch_tile_a = [params ["auto_prefetch_a_d0" ], params ["auto_prefetch_a_d1" ]]
112- prefetch_tile_b = [params ["auto_prefetch_b_d0" ], params ["auto_prefetch_b_d1" ]]
113- nb_prefetch = params ["auto_nb_prefetch" ]
114-
115- # derived parameters
116- sg_layout = [wg_tile [0 ] // sg_tile [0 ], wg_tile [1 ] // sg_tile [1 ]]
117- # number of threads collapsed to 1d layout
118- nb_threads = sg_layout [0 ] * sg_layout [1 ] * nb_workitems
119- prefetch_layout_a = [
120- wg_tile [0 ] // prefetch_tile_a [0 ],
121- k_tile // prefetch_tile_a [1 ],
122- ]
123- prefetch_layout_b = [
124- k_tile // prefetch_tile_b [0 ],
125- wg_tile [1 ] // prefetch_tile_b [1 ],
123+
124+ sig = inspect .signature (bundle_xepu_matmul_schedule )
125+
126+ any_param = transform .AnyParamType .get ()
127+
128+ use_knobs = NonDet in [
129+ wg_d0 ,
130+ wg_d1 ,
131+ prefetch_a_d0 ,
132+ prefetch_a_d1 ,
133+ prefetch_b_d0 ,
134+ prefetch_b_d1 ,
135+ k_tile ,
136+ load_a_d0 ,
137+ load_a_d1 ,
138+ load_b_d0 ,
139+ load_b_d1 ,
140+ prefetch_a_d0 ,
141+ prefetch_a_d1 ,
142+ prefetch_b_d0 ,
143+ prefetch_b_d1 ,
144+ nb_prefetch ,
126145 ]
127146
147+ def as_const_or_as_knob (value , knob_name ):
148+ collector = ConstraintCollector ()
149+ sig .parameters [knob_name ].annotation .__metadata__ [0 ](collector )
150+ if use_knobs :
151+ return transform_tune .knob (
152+ any_param ,
153+ name = knob_name ,
154+ options = collector .to_mlir (),
155+ selected = value if value is not NonDet else None ,
156+ )
157+ return value
158+
159+ wg_d0 = as_const_or_as_knob (wg_d0 , "wg_d0" )
160+ wg_d1 = as_const_or_as_knob (wg_d1 , "wg_d1" )
161+ wg_tile = [wg_d0 , wg_d1 ]
162+ sg_d0 = as_const_or_as_knob (sg_d0 , "sg_d0" )
163+ sg_d1 = as_const_or_as_knob (sg_d1 , "sg_d1" )
164+ sg_tile = [sg_d0 , sg_d1 ]
165+
166+ smt_int = smt .IntType .get ()
167+ c0 = ir .IntegerAttr .get (ir .IntegerType .get_signless (64 ), 0 )
168+ c_nb_workitems = ir .IntegerAttr .get (ir .IntegerType .get_signless (64 ), nb_workitems )
169+
170+ if use_knobs :
171+ constraint1 = transform_smt .constrain_params (
172+ (any_param , any_param , any_param ),
173+ (
174+ wg_d0 ,
175+ wg_d1 ,
176+ sg_d0 ,
177+ sg_d1 ,
178+ ),
179+ [smt_int ] * 4 ,
180+ )
181+ with ir .InsertionPoint (constraint1 .body ):
182+ WGd0 , WGd1 , SGd0 , SGd1 = constraint1 .body .arguments
183+ C0 = smt .int_constant (c0 )
184+ smt .assert_ (smt .eq ((smt .int_mod (WGd0 , SGd0 ), C0 )))
185+ smt .assert_ (smt .eq ((smt .int_mod (WGd1 , SGd1 ), C0 )))
186+ d0_step_smt = smt .int_div (WGd0 , SGd0 )
187+ d1_step_smt = smt .int_div (WGd1 , SGd1 )
188+ nb_threads_smt = smt .int_mul (
189+ (d0_step_smt , d1_step_smt , smt .int_constant (c_nb_workitems ))
190+ )
191+ smt .yield_ ((d0_step_smt , d1_step_smt , nb_threads_smt ))
192+ d0_step , d1_step , nb_threads = constraint1 .results
193+ sg_layout = [d0_step , d1_step ]
194+ else :
195+ # derived parameters
196+ sg_layout = [wg_d0 // sg_d0 , wg_d1 // sg_d1 ]
197+ # number of threads collapsed to 1d layout
198+ nb_threads = sg_layout [0 ] * sg_layout [1 ] * nb_workitems
199+
200+ prefetch_a_d0 = as_const_or_as_knob (prefetch_a_d0 , "prefetch_a_d0" )
201+ prefetch_a_d1 = as_const_or_as_knob (prefetch_a_d1 , "prefetch_a_d1" )
202+ prefetch_tile_a = [prefetch_a_d0 , prefetch_a_d1 ]
203+ prefetch_b_d0 = as_const_or_as_knob (prefetch_b_d0 , "prefetch_b_d0" )
204+ prefetch_b_d1 = as_const_or_as_knob (prefetch_b_d1 , "prefetch_b_d1" )
205+ prefetch_tile_b = [prefetch_b_d0 , prefetch_b_d1 ]
206+ k_tile = as_const_or_as_knob (k_tile , "k_tile" )
207+
208+ if use_knobs :
209+ constraint2 = transform_smt .constrain_params (
210+ (any_param , any_param , any_param , any_param ),
211+ (
212+ wg_d0 ,
213+ wg_d1 ,
214+ k_tile ,
215+ prefetch_a_d0 ,
216+ prefetch_a_d1 ,
217+ prefetch_b_d0 ,
218+ prefetch_b_d1 ,
219+ ),
220+ [smt_int ] * 7 ,
221+ )
222+ with ir .InsertionPoint (constraint2 .body ):
223+ WGd0 , WGd1 , K , PFAd0 , PFAd1 , PFBd0 , PFBd1 = constraint2 .body .arguments
224+ C0 = smt .int_constant (c0 )
225+ smt .assert_ (smt .eq ((smt .int_mod (WGd0 , PFAd0 ), C0 )))
226+ smt .assert_ (smt .eq ((smt .int_mod (K , PFAd1 ), C0 )))
227+ PFAd0_step = smt .int_div (WGd0 , PFAd0 )
228+ PFAd1_step = smt .int_div (K , PFAd1 )
229+
230+ smt .assert_ (smt .eq ((smt .int_mod (K , PFBd0 ), C0 )))
231+ smt .assert_ (smt .eq ((smt .int_mod (WGd1 , PFBd1 ), C0 )))
232+ PFBd0_step = smt .int_div (K , PFBd0 )
233+ PFBd1_step = smt .int_div (WGd1 , PFBd1 )
234+
235+ smt .yield_ ((PFAd0_step , PFAd1_step , PFBd0_step , PFBd1_step ))
236+ prefetch_layout_a = constraint2 .results [0 :2 ]
237+ prefetch_layout_b = constraint2 .results [2 :4 ]
238+ else :
239+ prefetch_layout_a = [
240+ wg_d0 // prefetch_a_d0 ,
241+ k_tile // prefetch_a_d1 ,
242+ ]
243+ prefetch_layout_b = [
244+ k_tile // prefetch_b_d0 ,
245+ wg_d1 // prefetch_b_d1 ,
246+ ]
247+
128248 # matmul matrix shapes
129- sg_tile_a = [sg_tile [0 ], k_tile ]
130- sg_tile_b = [k_tile , sg_tile [1 ]]
249+ sg_tile_a = [sg_d0 , k_tile ]
250+ sg_tile_b = [k_tile , sg_d1 ]
251+
252+ load_a_d0 = as_const_or_as_knob (load_a_d0 , "load_a_d0" )
253+ load_a_d1 = as_const_or_as_knob (load_a_d1 , "load_a_d1" )
254+ load_b_d0 = as_const_or_as_knob (load_b_d0 , "load_b_d0" )
255+ load_b_d1 = as_const_or_as_knob (load_b_d1 , "load_b_d1" )
256+
257+ load_tile_a = [load_a_d0 , load_a_d1 ]
258+ load_tile_b = [load_b_d0 , load_b_d1 ]
131259
132260 if stop_at_stage == "initial" :
133261 raise PipelineInterrupt ()
0 commit comments