Skip to content

Commit 0ecd1b5

Browse files
authored
[examples][xegpu-matmul] Use XeGPU anchor layouts (#54)
Migrate lowering schedule to use XeGPU anchor layouts for load/store/dpas ops.
1 parent 0dadb37 commit 0ecd1b5

File tree

3 files changed

+34
-51
lines changed

3 files changed

+34
-51
lines changed

examples/xegpu_matmul/README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Set `LLVM_INSTALL_DIR` and use the below script to checkout and compile LLVM loc
2020

2121
```bash
2222
export LLVM_INSTALL_DIR=<...>
23-
export LLVM_VERSION=83765f435d1c
23+
export LLVM_VERSION=45bee6efe9d6
2424

2525
git clone https://github.com/llvm/llvm-project.git
2626
cd llvm-project
@@ -34,7 +34,6 @@ cmake ../llvm -G Ninja \
3434
-DLLVM_BUILD_EXAMPLES=OFF \
3535
-DLLVM_TARGETS_TO_BUILD="host" \
3636
-DLLVM_ENABLE_ASSERTIONS=ON \
37-
-DLLVM_ENABLE_RTTI=ON \
3837
-DLLVM_EXPERIMENTAL_TARGETS_TO_BUILD="SPIRV" \
3938
-DLLVM_INSTALL_GTEST=ON \
4039
-DMLIR_ENABLE_LEVELZERO_RUNNER=1 \

examples/xegpu_matmul/schedule.py

Lines changed: 32 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -271,22 +271,27 @@ def convert_layout(value, input, target):
271271
tile_a,
272272
nb_prefetch=nb_prefetch,
273273
)
274-
xegpu.set_desc_layout(
275-
desc_prefetch_a,
276-
sg_layout=prefetch_layout_a,
277-
sg_data=prefetch_tile_a,
278-
inst_data=prefetch_inst_data,
279-
)
274+
layout_prefetch_a = {
275+
"sg_layout": prefetch_layout_a,
276+
"sg_data": prefetch_tile_a,
277+
"inst_data": prefetch_inst_data,
278+
}
279+
pf_ops = transform.get_consumers_of_result(anytype, desc_prefetch_a, 0)
280+
for pf in transform.split_handle((anytype,) * (nb_prefetch + 1), pf_ops):
281+
xegpu.set_op_layout_attr(pf, **layout_prefetch_a)
282+
280283
desc_prefetch_b = xegpu.insert_prefetch(
281284
tile_b,
282285
nb_prefetch=nb_prefetch,
283286
)
284-
xegpu.set_desc_layout(
285-
desc_prefetch_b,
286-
sg_layout=prefetch_layout_b,
287-
sg_data=prefetch_tile_b,
288-
inst_data=prefetch_inst_data,
289-
)
287+
layout_prefetch_b = {
288+
"sg_layout": prefetch_layout_b,
289+
"sg_data": prefetch_tile_b,
290+
"inst_data": prefetch_inst_data,
291+
}
292+
pf_ops = transform.get_consumers_of_result(anytype, desc_prefetch_b, 0)
293+
for pf in transform.split_handle((anytype,) * (nb_prefetch + 1), pf_ops):
294+
xegpu.set_op_layout_attr(pf, **layout_prefetch_b)
290295

291296
# A tile load layout
292297
layout_load_a = {
@@ -295,10 +300,9 @@ def convert_layout(value, input, target):
295300
"inst_data": load_tile_a,
296301
}
297302
desc_op_a = xegpu.get_desc_op(tile_a)
298-
desc_op_a = xegpu.set_desc_layout(
299-
target=desc_op_a,
300-
**layout_load_a,
301-
)
303+
# A tile load op anchor layout
304+
load_op_a = transform.get_consumers_of_result(anytype, desc_op_a, 0)
305+
xegpu.set_op_layout_attr(load_op_a, **layout_load_a)
302306
# A tile dpas layout
303307
layout_dpas_a = layout_load_a.copy()
304308
layout_dpas_a["inst_data"] = dpas_shape_a
@@ -311,10 +315,9 @@ def convert_layout(value, input, target):
311315
"inst_data": load_tile_b,
312316
}
313317
desc_op_b = xegpu.get_desc_op(tile_b)
314-
desc_op_b = xegpu.set_desc_layout(
315-
target=desc_op_b,
316-
**layout_load_b,
317-
)
318+
# B tile load op anchor layout
319+
load_op_b = transform.get_consumers_of_result(anytype, desc_op_b, 0)
320+
xegpu.set_op_layout_attr(load_op_b, **layout_load_b)
318321
# B tile dpas layout
319322
layout_dpas_b = layout_load_b.copy()
320323
layout_dpas_b["inst_data"] = dpas_shape_b
@@ -327,42 +330,23 @@ def convert_layout(value, input, target):
327330
"inst_data": dpas_shape_c,
328331
}
329332
desc_op_c = xegpu.get_desc_op(tile_c)
330-
desc_op_c = xegpu.set_desc_layout(desc_op_c, **output_layout)
331-
# C tile dpas layout
332-
xegpu.set_op_layout_attr(dpas_op, result=True, index=0, **output_layout)
333+
# C tile load/store op anchor layout
334+
desc_c_users = transform.get_consumers_of_result(anytype, desc_op_c, 0)
335+
load_op_c, store_op_c = transform.split_handle((anytype, anytype), desc_c_users)
336+
xegpu.set_op_layout_attr(load_op_c, **output_layout)
337+
# C tile dpas anchor layout
338+
xegpu.set_op_layout_attr(dpas_op, index=0, **layout_dpas_a)
339+
xegpu.set_op_layout_attr(dpas_op, index=1, **layout_dpas_b)
340+
xegpu.set_op_layout_attr(dpas_op, index=2, **output_layout)
333341

334-
if has_relu:
335-
# for post ops we need to add C layout manually
336-
max_op = match(gpu_func, ops={"arith.maximumf"})
337-
xegpu.set_op_layout_attr(max_op, result=True, index=0, **output_layout)
338-
# find zero constant buffer and annotate it
339-
const_buffer = transform.get_producer_of_operand(anytype, max_op, 1)
340-
xegpu.set_op_layout_attr(const_buffer, result=True, index=0, **output_layout)
341342
if has_bias:
342-
# for post ops we need to add C layout manually
343+
# annotate the 1d load of the broadcast op with a slice layout
343344
add_op = match(gpu_func, ops={"arith.addf"})
344-
xegpu.set_op_layout_attr(add_op, result=True, index=0, **output_layout)
345-
346-
# annotate broadcast op operands
347345
bcast_op = transform.get_producer_of_operand(anytype, add_op, 0)
348-
xegpu.set_op_layout_attr(bcast_op, result=True, index=0, **output_layout)
349346
bcast_load = transform.get_producer_of_operand(anytype, bcast_op, 0)
350347
xegpu.set_op_layout_attr(
351348
bcast_load, result=True, index=0, **output_layout, slice_dims=[0]
352349
)
353-
output_layout_dim1 = {
354-
"sg_layout": [sg_layout[1]],
355-
"sg_data": [sg_tile[1]],
356-
"inst_data": [dpas_shape_c[1]],
357-
}
358-
offset = transform.get_producer_of_operand(anytype, bcast_load, 1)
359-
xegpu.set_op_layout_attr(offset, result=True, index=0, **output_layout_dim1)
360-
aux1 = transform.get_producer_of_operand(anytype, offset, 0)
361-
xegpu.set_op_layout_attr(aux1, result=True, index=0, **output_layout_dim1)
362-
aux2 = transform.get_producer_of_operand(anytype, offset, 1)
363-
xegpu.set_op_layout_attr(aux2, result=True, index=0, **output_layout_dim1)
364-
mask = transform.get_producer_of_operand(anytype, bcast_load, 2)
365-
xegpu.set_op_layout_attr(mask, result=True, index=0, **output_layout_dim1)
366350
raise NotImplementedError("Bias layout propagation is not supported.")
367351
transform.apply_cse(gpu_func)
368352
canonicalize(gpu_func)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name = "lighthouse"
33
dynamic = ["version"]
44
requires-python = ">=3.10,<3.13" # Bounds are due to torch-mlir's packaging
55
dependencies = [
6-
"mlir-python-bindings==20260211+f932646bf"
6+
"mlir-python-bindings==20260215+45bee6efe"
77
]
88

99
[dependency-groups]

0 commit comments

Comments (0)