1212 "register_and_load" ,
1313]
1414
15+
1516def register_and_load (context = None ):
17+ """Register and load the TransformSMTDialectExtension and its operations."""
18+
1619 TransformSMTDialectExtension .load ()
1720
1821
1922class TransformSMTDialectExtension (ext .Dialect , name = "transform_smt_ext" ):
23+ """A Transform Dialect extension for SMT-related operations."""
24+
2025 @classmethod
2126 def load (cls , * args , ** kwargs ):
2227 super (TransformSMTDialectExtension , cls ).load (* args , ** kwargs )
@@ -29,6 +34,13 @@ def load(cls, *args, **kwargs):
2934class ConstrainParamsOp (
3035 TransformSMTDialectExtension .Operation , name = "constrain_params"
3136):
37+ """Constrain transform params by SMT ops while also producing new params.
38+
39+ In effect applies a predicate defined by the SMT ops in the body, which can
40+ reference the parameters as block arguments as !smt.int. The result params
41+ are defined by the !smt.int values yielded from the body.
42+ """
43+
3244 results_ : Sequence [ext .Result [transform .AnyParamType ]]
3345 params : Sequence [ext .Operand [transform .AnyParamType ]]
3446 body_ : ext .Region
@@ -49,24 +61,33 @@ def attach_interfaces(cls, ctx=None):
4961 setattr (cls , "_interfaces_attached" , True )
5062
5163 class ConstrainParamsTransformOpInterfaceModel (transform .TransformOpInterface ):
64+ """TransformOpInterface impl for evaluating the SMT constraints and producing new params."""
65+
5266 @staticmethod
5367 def apply (
5468 op : "ConstrainParamsOp" ,
5569 _rewriter : transform .TransformRewriter ,
5670 results : transform .TransformResults ,
5771 state : transform .TransformState ,
5872 ) -> transform .DiagnosedSilenceableFailure :
73+ # Set up the tracing environment by obtaining the transform params
74+ # and mapping them to Node constants, so that the traced Node
75+ # representation will refer to the params as just constants.
5976 env = dict ()
6077 for operand in op .params :
6178 params = state .get_params (operand )
6279 assert len (params ) == 1 and isinstance (params [0 ].value , int )
6380 env [operand ] = trace .Constant (params [0 ].value )
6481
82+ # Obtained traced representation of the body of the op.
6583 env = trace .trace_tune_and_smt_ops (op .operation , env )
6684
67- if not env [op ].evaluate (env ): # evaluate the conjoined predicate
85+ # Evaluate the predicate that represents the successful execution of the body.
86+ if not env [op ].evaluate (env ):
6887 return transform .DiagnosedSilenceableFailure .DefiniteFailure
6988
89+ # If the predicate is satisfied, we can extract the values of the result params
90+ # from the environment and set them as the results of the transformation.
7091 for result in op .results :
7192 res_value = env [result ].evaluate (env )
7293 i64 = ir .IntegerType .get_signless (64 )
@@ -88,6 +109,16 @@ def get_effects(op: "ConstrainParamsOp", effects):
88109
89110
90111class MixedResultConstrainParamsOp (ConstrainParamsOp ):
112+ """ConstrainParamsOp that supports both integer and SMTIntValues as results.
113+
114+ Used to support `constrain_params` as a decorator on functions that yield a
115+ mix of Python integers and `!smt.int`s (which are either arguments to the
116+ function/block or the result of operations in the body). Upon the body's function
117+ returning, the original ConstrainParamsOp is replaced with this version
118+ that has the same parameters but whose `.results` corresponds to the mix of
119+ integers and SMT values yielded from the body.
120+ """
121+
91122 def __init__ (
92123 self ,
93124 * args ,
@@ -109,109 +140,62 @@ def results(self) -> Sequence[int | ext.Result[transform.AnyParamType]]:
109140 return self ._results
110141
111142
112- # class ConstrainParamsOpDecorator(ConstrainParamsOp):
113- # def __init__(
114- # self,
115- # *params: transform.AnyParamType | int,
116- # results: Sequence[int | ext.Result[transform.AnyParamType]] | None = None,
117- # **kwargs,
118- # ):
119- # transform_params = [p for p in params if isinstance(p, ir.Value)]
120- # super().__init__([], transform_params, **kwargs)
121- # block_arg_types = [smt.IntType.get()] * len(transform_params)
122- # self.body_.blocks.append(*block_arg_types)
123- #
124- # self._arguments = []
125- # self._results = results
126- # smt_arguments = iter(self.body.arguments)
127- # for param in params:
128- # if isinstance(param, int):
129- # self._arguments.append(param)
130- # else:
131- # self._arguments.append(next(smt_arguments))
132- #
133- # @property
134- # def results(self) -> Sequence[ext.Result | int]:
135- # """Returns the yielded results of the decorated function, which are either
136- # integers or the transform parameters that correspond to the yielded SMT
137- # int values."""
138- # assert self._results is not None, (
139- # "Results are not available until the decorated function is called"
140- # )
141- # return self._results
142- #
143- # def __call__(self, func):
144- # with ir.InsertionPoint(self.body):
145- # yielded_results = func(*self._arguments)
146- #
147- # smt.yield_(res for res in yielded_results if isinstance(res, ir.Value))
148- #
149- # print(f"{yielded_results=}")
150- # if len(yielded_results) == 0:
151- # return self
152- #
153- # # In case of yielded results, we need to create a new ConstrainParamsOp with the same parameters and a body that contains the original body of the decorator, but with the yielded results as the results of the new op. We then replace the original op with the new one and return it.
154- # result_types = [transform.AnyParamType.get()] * sum(
155- # 1 for res in yielded_results if isinstance(res, ir.Value)
156- # )
157- # with ir.InsertionPoint(self):
158- # self_with_results = ConstrainParamsOp(
159- # result_types, self.params, loc=self.location
160- # )
161- # self.body_.blocks[0].append_to(self_with_results.body_)
162- # # new_block = self_with_results.body_.blocks.append(
163- # # *orig_block.arguments.types
164- # # )
165- # # arg_mapping = dict(zip(orig_block.arguments, new_block.arguments))
166- # # lh_utils_rewrite.move_block(orig_block, new_block, arg_mapping)
167- # # self.erase()
168- #
169- # results = []
170- # op_results = iter(self_with_results.results)
171- # for yielded_result in yielded_results:
172- # if isinstance(yielded_result, int):
173- # results.append(yielded_result)
174- # elif isinstance(yielded_result, ir.Value):
175- # results.append(next(op_results))
176- # else:
177- # assert False, "Unsupported yielded result type"
178- # setattr(self_with_results, "_results", results)
179- # return self_with_results
180-
181-
182143@overload
183144def constrain_params (
184145 * params : ir .Value | int , loc = None , ip = None
185- ) -> Callable [..., MixedResultConstrainParamsOp ]: ...
146+ ) -> Callable [..., MixedResultConstrainParamsOp ]:
147+ """Calls the decorated function with param args converted to !smt.int args.
148+
149+ The decorated function defines the body of the ConstrainParamsOp and handles
150+ args as `!smt.int` or Python integer. The function should yield a mix of
151+ Python integers and `!smt.int`s (the latter can be either block arguments or
152+ results of operations in the body). The original ConstrainParamsOp created
153+ by the decorator will be replaced with a MixedResultConstrainParamsOp that
154+ has the same parameters but whose results correspond to the mix of integers
155+ and SMT values yielded from the body.
156+ """
157+
158+ ...
186159
187160
188161@overload
189162def constrain_params (
190163 results : Sequence [ir .Type ],
191164 params : Sequence [transform .AnyParamType ],
192- arg_types : Sequence [ir .Type ],
193165 loc = None ,
194166 ip = None ,
195- ) -> ConstrainParamsOp : ...
167+ ) -> ConstrainParamsOp :
168+ """Creates a ConstrainParamsOp where the body is defined by the caller."""
169+
170+ ...
196171
197172
198173def constrain_params (
199174 * args , ** kwargs
200175) -> ConstrainParamsOp | Callable [..., MixedResultConstrainParamsOp ]:
176+ """Creates a ConstrainParamsOp or a decorator for a function that yields mixed results."""
177+
201178 # The second overload:
202- if len (args ) == 0 or isinstance (args [0 ], ir .Type ):
203- arg_types = kwargs .pop ("arg_types" )
179+ if len (args ) == 0 or not (
180+ isinstance (args [0 ], ir .Value ) or isinstance (args [0 ], int )
181+ ):
182+ params = kwargs .get ("params" ) or args [1 ]
183+ arg_types = [smt .IntType .get ()] * len (params )
204184 op = ConstrainParamsOp (* args , ** kwargs )
205185 op .body_ .blocks .append (* arg_types )
206186 return op
207187
208188 # The first overload:
209- # return ConstrainParamsOpDecorator(*args, **kwargs)
210189 def wrapper (func ):
190+ # Create a ConstrainParamsOp with just the transform parameters as block arguments.
211191 param_args = [p for p in args if isinstance (p , ir .Value )]
212192 constrain_params = ConstrainParamsOp ([], param_args , ** kwargs )
213193 constrain_params .body_ .blocks .append (* [smt .IntType .get ()] * len (param_args ))
214194
195+ # Call `func` with !smt.int block arguments for corresponding transform params,
196+ # and just normal ints for those passed via `args`. The body of `func` will be
197+ # the body of the op, and it can yield a mix of Python integers and `!smt.int`s.
198+ # A corresponding `smt.yield` will be generated as the terminator.
215199 block_args_iter = iter (constrain_params .body_ .blocks [0 ].arguments )
216200 with ir .InsertionPoint (constrain_params .body ):
217201 yielded_results = func (
@@ -224,21 +208,25 @@ def wrapper(func):
224208 yielded_results = [yielded_results ]
225209 smt .yield_ (res for res in yielded_results if isinstance (res , ir .Value ))
226210
227- if len (yielded_results ) == 0 :
228- return constrain_params
211+ # In case no results are returned, the current ConstrainParamsOp is sufficient.
212+ if len (yielded_results ) == 0 :
213+ return constrain_params
229214
230- result_values_or_types = [
231- transform .AnyParamType .get () if isinstance (res , ir .Value ) else res
232- for res in yielded_results
233- ]
215+ # Create a new version of the ConstrainParamsOp that has the same
216+ # parameters but whose results correspond to the mix of integers and
217+ # SMT values yielded from the body.
218+ result_values_or_types = [
219+ transform .AnyParamType .get () if isinstance (res , ir .Value ) else res
220+ for res in yielded_results
221+ ]
234222
235- mixed_result_op = MixedResultConstrainParamsOp (
236- params = param_args , result_values_or_types = result_values_or_types , ** kwargs
237- )
238- # Move the body of the original op to the version with (mixed) results.
239- constrain_params .body_ .blocks [0 ].append_to (mixed_result_op .body_ )
240- # Safe to remove as the op doesn't have results, so no users either.
241- constrain_params .erase ()
242- return mixed_result_op
223+ mixed_result_op = MixedResultConstrainParamsOp (
224+ params = param_args , result_values_or_types = result_values_or_types , ** kwargs
225+ )
226+ # Move the body of the original op to the version with (mixed) results.
227+ constrain_params .body_ .blocks [0 ].append_to (mixed_result_op .body_ )
228+ # Safe to remove as the op doesn't have results, so no users either.
229+ constrain_params .erase ()
230+ return mixed_result_op
243231
244232 return wrapper
0 commit comments