@@ -271,22 +271,27 @@ def convert_layout(value, input, target):
271271 tile_a ,
272272 nb_prefetch = nb_prefetch ,
273273 )
274- xegpu .set_desc_layout (
275- desc_prefetch_a ,
276- sg_layout = prefetch_layout_a ,
277- sg_data = prefetch_tile_a ,
278- inst_data = prefetch_inst_data ,
279- )
274+ layout_prefetch_a = {
275+ "sg_layout" : prefetch_layout_a ,
276+ "sg_data" : prefetch_tile_a ,
277+ "inst_data" : prefetch_inst_data ,
278+ }
279+ pf_ops = transform .get_consumers_of_result (anytype , desc_prefetch_a , 0 )
280+ for pf in transform .split_handle ((anytype ,) * (nb_prefetch + 1 ), pf_ops ):
281+ xegpu .set_op_layout_attr (pf , ** layout_prefetch_a )
282+
280283 desc_prefetch_b = xegpu .insert_prefetch (
281284 tile_b ,
282285 nb_prefetch = nb_prefetch ,
283286 )
284- xegpu .set_desc_layout (
285- desc_prefetch_b ,
286- sg_layout = prefetch_layout_b ,
287- sg_data = prefetch_tile_b ,
288- inst_data = prefetch_inst_data ,
289- )
287+ layout_prefetch_b = {
288+ "sg_layout" : prefetch_layout_b ,
289+ "sg_data" : prefetch_tile_b ,
290+ "inst_data" : prefetch_inst_data ,
291+ }
292+ pf_ops = transform .get_consumers_of_result (anytype , desc_prefetch_b , 0 )
293+ for pf in transform .split_handle ((anytype ,) * (nb_prefetch + 1 ), pf_ops ):
294+ xegpu .set_op_layout_attr (pf , ** layout_prefetch_b )
290295
291296 # A tile load layout
292297 layout_load_a = {
@@ -295,10 +300,9 @@ def convert_layout(value, input, target):
295300 "inst_data" : load_tile_a ,
296301 }
297302 desc_op_a = xegpu .get_desc_op (tile_a )
298- desc_op_a = xegpu .set_desc_layout (
299- target = desc_op_a ,
300- ** layout_load_a ,
301- )
303+ # A tile load op anchor layout
304+ load_op_a = transform .get_consumers_of_result (anytype , desc_op_a , 0 )
305+ xegpu .set_op_layout_attr (load_op_a , ** layout_load_a )
302306 # A tile dpas layout
303307 layout_dpas_a = layout_load_a .copy ()
304308 layout_dpas_a ["inst_data" ] = dpas_shape_a
@@ -311,10 +315,9 @@ def convert_layout(value, input, target):
311315 "inst_data" : load_tile_b ,
312316 }
313317 desc_op_b = xegpu .get_desc_op (tile_b )
314- desc_op_b = xegpu .set_desc_layout (
315- target = desc_op_b ,
316- ** layout_load_b ,
317- )
318+ # B tile load op anchor layout
319+ load_op_b = transform .get_consumers_of_result (anytype , desc_op_b , 0 )
320+ xegpu .set_op_layout_attr (load_op_b , ** layout_load_b )
318321 # B tile dpas layout
319322 layout_dpas_b = layout_load_b .copy ()
320323 layout_dpas_b ["inst_data" ] = dpas_shape_b
@@ -327,42 +330,23 @@ def convert_layout(value, input, target):
327330 "inst_data" : dpas_shape_c ,
328331 }
329332 desc_op_c = xegpu .get_desc_op (tile_c )
330- desc_op_c = xegpu .set_desc_layout (desc_op_c , ** output_layout )
331- # C tile dpas layout
332- xegpu .set_op_layout_attr (dpas_op , result = True , index = 0 , ** output_layout )
333+ # C tile load/store op anchor layout
334+ desc_c_users = transform .get_consumers_of_result (anytype , desc_op_c , 0 )
335+ load_op_c , store_op_c = transform .split_handle ((anytype , anytype ), desc_c_users )
336+ xegpu .set_op_layout_attr (load_op_c , ** output_layout )
337+ # dpas op anchor layouts for the A, B and C operands (index 0, 1, 2)
338+ xegpu .set_op_layout_attr (dpas_op , index = 0 , ** layout_dpas_a )
339+ xegpu .set_op_layout_attr (dpas_op , index = 1 , ** layout_dpas_b )
340+ xegpu .set_op_layout_attr (dpas_op , index = 2 , ** output_layout )
333341
334- if has_relu :
335- # for post ops we need to add C layout manually
336- max_op = match (gpu_func , ops = {"arith.maximumf" })
337- xegpu .set_op_layout_attr (max_op , result = True , index = 0 , ** output_layout )
338- # find zero constant buffer and annotate it
339- const_buffer = transform .get_producer_of_operand (anytype , max_op , 1 )
340- xegpu .set_op_layout_attr (const_buffer , result = True , index = 0 , ** output_layout )
341342 if has_bias :
342- # for post ops we need to add C layout manually
343+ # annotate the 1d load of the broadcast op with a slice layout
343344 add_op = match (gpu_func , ops = {"arith.addf" })
344- xegpu .set_op_layout_attr (add_op , result = True , index = 0 , ** output_layout )
345-
346- # annotate broadcast op operands
347345 bcast_op = transform .get_producer_of_operand (anytype , add_op , 0 )
348- xegpu .set_op_layout_attr (bcast_op , result = True , index = 0 , ** output_layout )
349346 bcast_load = transform .get_producer_of_operand (anytype , bcast_op , 0 )
350347 xegpu .set_op_layout_attr (
351348 bcast_load , result = True , index = 0 , ** output_layout , slice_dims = [0 ]
352349 )
353- output_layout_dim1 = {
354- "sg_layout" : [sg_layout [1 ]],
355- "sg_data" : [sg_tile [1 ]],
356- "inst_data" : [dpas_shape_c [1 ]],
357- }
358- offset = transform .get_producer_of_operand (anytype , bcast_load , 1 )
359- xegpu .set_op_layout_attr (offset , result = True , index = 0 , ** output_layout_dim1 )
360- aux1 = transform .get_producer_of_operand (anytype , offset , 0 )
361- xegpu .set_op_layout_attr (aux1 , result = True , index = 0 , ** output_layout_dim1 )
362- aux2 = transform .get_producer_of_operand (anytype , offset , 1 )
363- xegpu .set_op_layout_attr (aux2 , result = True , index = 0 , ** output_layout_dim1 )
364- mask = transform .get_producer_of_operand (anytype , bcast_load , 2 )
365- xegpu .set_op_layout_attr (mask , result = True , index = 0 , ** output_layout_dim1 )
366350 raise NotImplementedError ("Bias layout propagation is not supported." )
367351 transform .apply_cse (gpu_func )
368352 canonicalize (gpu_func )
0 commit comments