Skip to content

Commit 42b5801

Browse files
committed
Documentation
1 parent 7b912bd commit 42b5801

File tree

4 files changed

+124
-95
lines changed

4 files changed

+124
-95
lines changed

lighthouse/dialects/smt_ext.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77

88

99
def register_and_load(context=None):
10+
"""Register and load the SMTIntValue caster."""
11+
1012
SMTIntValue.register_value_caster()
1113

1214

1315
def assert_(predicate: ir.Value[smt.BoolType] | bool):
1416
"""Assert normally if a bool else produce an SMT assertion op."""
17+
1518
if isinstance(predicate, bool):
1619
assert predicate
1720
else:
@@ -32,6 +35,8 @@ def swapped(
3235

3336

3437
class SMTIntValue(ir.Value[smt.IntType]):
38+
"""A Value caster for `!smt.int` that supports Pythonic arithmetic and comparison operations."""
39+
3540
def __init__(self, v):
3641
super().__init__(v)
3742

lighthouse/dialects/transform_smt_ext.py

Lines changed: 79 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,16 @@
1212
"register_and_load",
1313
]
1414

15+
1516
def register_and_load(context=None):
17+
"""Register and load the TransformSMTDialectExtension and its operations."""
18+
1619
TransformSMTDialectExtension.load()
1720

1821

1922
class TransformSMTDialectExtension(ext.Dialect, name="transform_smt_ext"):
23+
"""A Transform Dialect extension for SMT-related operations."""
24+
2025
@classmethod
2126
def load(cls, *args, **kwargs):
2227
super(TransformSMTDialectExtension, cls).load(*args, **kwargs)
@@ -29,6 +34,13 @@ def load(cls, *args, **kwargs):
2934
class ConstrainParamsOp(
3035
TransformSMTDialectExtension.Operation, name="constrain_params"
3136
):
37+
"""Constrain transform params by SMT ops while also producing new params.
38+
39+
In effect applies a predicate defined by the SMT ops in the body, which can
40+
reference the parameters as block arguments as !smt.int. The result params
41+
are defined by the !smt.int values yielded from the body.
42+
"""
43+
3244
results_: Sequence[ext.Result[transform.AnyParamType]]
3345
params: Sequence[ext.Operand[transform.AnyParamType]]
3446
body_: ext.Region
@@ -49,24 +61,33 @@ def attach_interfaces(cls, ctx=None):
4961
setattr(cls, "_interfaces_attached", True)
5062

5163
class ConstrainParamsTransformOpInterfaceModel(transform.TransformOpInterface):
64+
"""TransformOpInterface impl for evaluating the SMT constraints and producing new params."""
65+
5266
@staticmethod
5367
def apply(
5468
op: "ConstrainParamsOp",
5569
_rewriter: transform.TransformRewriter,
5670
results: transform.TransformResults,
5771
state: transform.TransformState,
5872
) -> transform.DiagnosedSilenceableFailure:
73+
# Set up the tracing environment by obtaining the transform params
74+
# and mapping them to Node constants, so that the traced Node
75+
# representation will refer to the params as just constants.
5976
env = dict()
6077
for operand in op.params:
6178
params = state.get_params(operand)
6279
assert len(params) == 1 and isinstance(params[0].value, int)
6380
env[operand] = trace.Constant(params[0].value)
6481

82+
# Obtained traced representation of the body of the op.
6583
env = trace.trace_tune_and_smt_ops(op.operation, env)
6684

67-
if not env[op].evaluate(env): # evaluate the conjoined predicate
85+
# Evaluate the predicate that represents the successful execution of the body.
86+
if not env[op].evaluate(env):
6887
return transform.DiagnosedSilenceableFailure.DefiniteFailure
6988

89+
# If the predicate is satisfied, we can extract the values of the result params
90+
# from the environment and set them as the results of the transformation.
7091
for result in op.results:
7192
res_value = env[result].evaluate(env)
7293
i64 = ir.IntegerType.get_signless(64)
@@ -88,6 +109,16 @@ def get_effects(op: "ConstrainParamsOp", effects):
88109

89110

90111
class MixedResultConstrainParamsOp(ConstrainParamsOp):
112+
"""ConstrainParamsOp that supports both integer and SMTIntValues as results.
113+
114+
Used to support `constrain_params` as a decorator on functions that yield a
115+
mix of Python integers and `!smt.int`s (which are either arguments to the
116+
function/block or the result of operations in the body). Upon the body's function
117+
returning, the original ConstrainParamsOp is replaced with this version
118+
that has the same parameters but whose `.results` corresponds to the mix of
119+
integers and SMT values yielded from the body.
120+
"""
121+
91122
def __init__(
92123
self,
93124
*args,
@@ -109,109 +140,62 @@ def results(self) -> Sequence[int | ext.Result[transform.AnyParamType]]:
109140
return self._results
110141

111142

112-
# class ConstrainParamsOpDecorator(ConstrainParamsOp):
113-
# def __init__(
114-
# self,
115-
# *params: transform.AnyParamType | int,
116-
# results: Sequence[int | ext.Result[transform.AnyParamType]] | None = None,
117-
# **kwargs,
118-
# ):
119-
# transform_params = [p for p in params if isinstance(p, ir.Value)]
120-
# super().__init__([], transform_params, **kwargs)
121-
# block_arg_types = [smt.IntType.get()] * len(transform_params)
122-
# self.body_.blocks.append(*block_arg_types)
123-
#
124-
# self._arguments = []
125-
# self._results = results
126-
# smt_arguments = iter(self.body.arguments)
127-
# for param in params:
128-
# if isinstance(param, int):
129-
# self._arguments.append(param)
130-
# else:
131-
# self._arguments.append(next(smt_arguments))
132-
#
133-
# @property
134-
# def results(self) -> Sequence[ext.Result | int]:
135-
# """Returns the yielded results of the decorated function, which are either
136-
# integers or the transform parameters that correspond to the yielded SMT
137-
# int values."""
138-
# assert self._results is not None, (
139-
# "Results are not available until the decorated function is called"
140-
# )
141-
# return self._results
142-
#
143-
# def __call__(self, func):
144-
# with ir.InsertionPoint(self.body):
145-
# yielded_results = func(*self._arguments)
146-
#
147-
# smt.yield_(res for res in yielded_results if isinstance(res, ir.Value))
148-
#
149-
# print(f"{yielded_results=}")
150-
# if len(yielded_results) == 0:
151-
# return self
152-
#
153-
# # In case of yielded results, we need to create a new ConstrainParamsOp with the same parameters and a body that contains the original body of the decorator, but with the yielded results as the results of the new op. We then replace the original op with the new one and return it.
154-
# result_types = [transform.AnyParamType.get()] * sum(
155-
# 1 for res in yielded_results if isinstance(res, ir.Value)
156-
# )
157-
# with ir.InsertionPoint(self):
158-
# self_with_results = ConstrainParamsOp(
159-
# result_types, self.params, loc=self.location
160-
# )
161-
# self.body_.blocks[0].append_to(self_with_results.body_)
162-
# # new_block = self_with_results.body_.blocks.append(
163-
# # *orig_block.arguments.types
164-
# # )
165-
# # arg_mapping = dict(zip(orig_block.arguments, new_block.arguments))
166-
# # lh_utils_rewrite.move_block(orig_block, new_block, arg_mapping)
167-
# # self.erase()
168-
#
169-
# results = []
170-
# op_results = iter(self_with_results.results)
171-
# for yielded_result in yielded_results:
172-
# if isinstance(yielded_result, int):
173-
# results.append(yielded_result)
174-
# elif isinstance(yielded_result, ir.Value):
175-
# results.append(next(op_results))
176-
# else:
177-
# assert False, "Unsupported yielded result type"
178-
# setattr(self_with_results, "_results", results)
179-
# return self_with_results
180-
181-
182143
@overload
183144
def constrain_params(
184145
*params: ir.Value | int, loc=None, ip=None
185-
) -> Callable[..., MixedResultConstrainParamsOp]: ...
146+
) -> Callable[..., MixedResultConstrainParamsOp]:
147+
"""Calls the decorated function with param args converted to !smt.int args.
148+
149+
The decorated function defines the body of the ConstrainParamsOp and handles
150+
args as `!smt.int` or Python integer. The function should yield a mix of
151+
Python integers and `!smt.int`s (the latter can be either block arguments or
152+
results of operations in the body). The original ConstrainParamsOp created
153+
by the decorator will be replaced with a MixedResultConstrainParamsOp that
154+
has the same parameters but whose results correspond to the mix of integers
155+
and SMT values yielded from the body.
156+
"""
157+
158+
...
186159

187160

188161
@overload
189162
def constrain_params(
190163
results: Sequence[ir.Type],
191164
params: Sequence[transform.AnyParamType],
192-
arg_types: Sequence[ir.Type],
193165
loc=None,
194166
ip=None,
195-
) -> ConstrainParamsOp: ...
167+
) -> ConstrainParamsOp:
168+
"""Creates a ConstrainParamsOp where the body is defined by the caller."""
169+
170+
...
196171

197172

198173
def constrain_params(
199174
*args, **kwargs
200175
) -> ConstrainParamsOp | Callable[..., MixedResultConstrainParamsOp]:
176+
"""Creates a ConstrainParamsOp or a decorator for a function that yields mixed results."""
177+
201178
# The second overload:
202-
if len(args) == 0 or isinstance(args[0], ir.Type):
203-
arg_types = kwargs.pop("arg_types")
179+
if len(args) == 0 or not (
180+
isinstance(args[0], ir.Value) or isinstance(args[0], int)
181+
):
182+
params = kwargs.get("params") or args[1]
183+
arg_types = [smt.IntType.get()] * len(params)
204184
op = ConstrainParamsOp(*args, **kwargs)
205185
op.body_.blocks.append(*arg_types)
206186
return op
207187

208188
# The first overload:
209-
# return ConstrainParamsOpDecorator(*args, **kwargs)
210189
def wrapper(func):
190+
# Create a ConstrainParamsOp with just the transform parameters as block arguments.
211191
param_args = [p for p in args if isinstance(p, ir.Value)]
212192
constrain_params = ConstrainParamsOp([], param_args, **kwargs)
213193
constrain_params.body_.blocks.append(*[smt.IntType.get()] * len(param_args))
214194

195+
# Call `func` with !smt.int block arguments for corresponding transform params,
196+
# and just normal ints for those passed via `args`. The body of `func` will be
197+
# the body of the op, and it can yield a mix of Python integers and `!smt.int`s.
198+
# A corresponding `smt.yield` will be generated as the terminator.
215199
block_args_iter = iter(constrain_params.body_.blocks[0].arguments)
216200
with ir.InsertionPoint(constrain_params.body):
217201
yielded_results = func(
@@ -224,21 +208,25 @@ def wrapper(func):
224208
yielded_results = [yielded_results]
225209
smt.yield_(res for res in yielded_results if isinstance(res, ir.Value))
226210

227-
if len(yielded_results) == 0:
228-
return constrain_params
211+
# In case no results are returned, the current ConstrainParamsOp is sufficient.
212+
if len(yielded_results) == 0:
213+
return constrain_params
229214

230-
result_values_or_types = [
231-
transform.AnyParamType.get() if isinstance(res, ir.Value) else res
232-
for res in yielded_results
233-
]
215+
# Create a new version of the ConstrainParamsOp that has the same
216+
# parameters but whose results correspond to the mix of integers and
217+
# SMT values yielded from the body.
218+
result_values_or_types = [
219+
transform.AnyParamType.get() if isinstance(res, ir.Value) else res
220+
for res in yielded_results
221+
]
234222

235-
mixed_result_op = MixedResultConstrainParamsOp(
236-
params=param_args, result_values_or_types=result_values_or_types, **kwargs
237-
)
238-
# Move the body of the original op to the version with (mixed) results.
239-
constrain_params.body_.blocks[0].append_to(mixed_result_op.body_)
240-
# Safe to remove as the op doesn't have results, so no users either.
241-
constrain_params.erase()
242-
return mixed_result_op
223+
mixed_result_op = MixedResultConstrainParamsOp(
224+
params=param_args, result_values_or_types=result_values_or_types, **kwargs
225+
)
226+
# Move the body of the original op to the version with (mixed) results.
227+
constrain_params.body_.blocks[0].append_to(mixed_result_op.body_)
228+
# Safe to remove as the op doesn't have results, so no users either.
229+
constrain_params.erase()
230+
return mixed_result_op
243231

244232
return wrapper

0 commit comments

Comments
 (0)