Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions examples/schedule/transform_a_payload_according_to_a_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@

# Simply demonstrates applying a schedule to a payload.
# To do so generates a basic payload and a basic schedule, purely as an example.
# The following demonstrates doing so from the cmdline:

# RUN: mkdir -p %t
# RUN: %PYTHON %S/transform_a_payload_according_to_a_schedule.py payload > %t/payload.mlir
# RUN: %PYTHON %S/transform_a_payload_according_to_a_schedule.py schedule > %t/schedule.mlir
# RUN: lh-transform %t/schedule.mlir %t/payload.mlir | FileCheck %S/transform_a_payload_according_to_a_schedule.py


import sys

from mlir.ir import Context, Location, InsertionPoint, Operation, Module
from mlir.ir import RankedTensorType, F32Type, UnitAttr
Expand All @@ -21,13 +30,11 @@ def example_payload() -> Module:
Res = matmul(..., C=X)
"""

print("NOTE: example payload module:")
payload = Module.create()
with InsertionPoint(payload.body):
matrixType = RankedTensorType.get([16, 16], F32Type.get())

# NB: Do the CHECKing on the transformed output:
# CHECK-LABEL: result of applying schedule to payload
# CHECK: func.func @fold_add_on_two_matmuls
# CHECK-SAME: (%[[MATRIX_A:.*]]: {{.*}}, %[[MATRIX_B:.*]]: {{.*}}, %[[WEIGHTS:.*]]: {{.*}})
@func.func(matrixType, matrixType, matrixType)
Expand All @@ -45,14 +52,12 @@ def fold_add_on_two_matmuls(matrixA, matrixB, weights):
# CHECK: return %[[RES]]
return added

print(payload)
return payload


def example_schedule() -> Module:
"""Basic schedule wrapping a single rewrite pattern."""

print("NOTE: example schedule module:")
schedule_module = Module.create()
schedule_module.operation.attributes["transform.with_named_sequence"] = (
UnitAttr.get()
Expand All @@ -75,17 +80,27 @@ def example_schedule() -> Module:
) # TODO: expose dedicated builder upstream
transform.yield_([])

print(schedule_module)
return schedule_module


with Context(), Location.unknown():
payload = example_payload()
schedule_module = example_schedule()
# Actual schedule is defined by the contained transform.named_sequence:
schedule: transform.NamedSequenceOp = schedule_module.body.operations[0]

schedule.apply(payload) # The actual transformation happens here.
if len(sys.argv) > 1 and sys.argv[1] == "schedule":
print(schedule_module)
elif len(sys.argv) > 1 and sys.argv[1] == "payload":
print(payload)
else:
print("// NOTE: example payload module:")
print(payload)
print("// NOTE: example schedule module:")
print(schedule_module)

# Actual schedule is defined by the contained transform.named_sequence:
schedule: transform.NamedSequenceOp = schedule_module.body.operations[0]

schedule.apply(payload) # The actual transformation happens here.

print("NOTE: result of applying schedule to payload:")
print(payload)
print("// NOTE: result of applying schedule to payload:")
print(payload)
5 changes: 5 additions & 0 deletions lit.cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,8 @@ def find_filecheck() -> str:
torch_kernels_dir = project_root + "/third_party/KernelBench/KernelBench"
if os.path.isdir(torch_kernels_dir):
config.available_features.add("kernel_bench")

for tool in os.listdir(project_root + "/tools"):
tool_path = os.path.join(project_root, "tools", tool)
if os.access(tool_path, os.X_OK):
config.substitutions.append((tool, tool_path))
39 changes: 39 additions & 0 deletions tools/lh-transform
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/usr/bin/env python

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: header is skipped everywhere else

# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
import sys

from mlir import ir
from mlir.dialects import transform


if __name__ == "__main__":
arg_parser = argparse.ArgumentParser(
prog="lh-transform", description="Apply a schedule to a payload module"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could also add epilog with extra usage examples and documentation like what's the expected format of the files passed.

)
arg_parser.add_argument(
"schedule", help="path to file containing MLIR schedule module"
)
arg_parser.add_argument(
"payload", help="path to file containing MLIR payload module"
)
args = arg_parser.parse_args(sys.argv[1:])

with ir.Context(), ir.Location.unknown():
with open(args.schedule) as sched_file, open(args.payload) as payload_file:
schedule_module = ir.Module.parse(sched_file.read())
payload_module = ir.Module.parse(payload_file.read())

schedule = schedule_module.body.operations[0]
if not isinstance(schedule, transform.NamedSequenceOp):
sys.exit(
Comment on lines +31 to +33
Copy link

Copilot AI Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

schedule_module.body.operations[0] will raise an IndexError if the schedule module parses successfully but contains no top-level operations (e.g., empty file). Consider checking for an empty body and exiting with a clear error, and/or searching the module body for a transform.named_sequence op instead of assuming it is the first op.

Copilot uses AI. Check for mistakes.
"The following op was expected to be a `transform.named_sequence`, instead got:\n"
+ str(schedule)
)
schedule.apply(payload_module)

print(payload_module)