Skip to content

Commit 5bd66cd

Browse files
authored
Add Python bindings for OM unknown value (#9868)
1 parent 37cdf9c commit 5bd66cd

File tree

5 files changed

+126
-2
lines changed

5 files changed

+126
-2
lines changed

include/circt-c/Dialect/OM.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,18 @@ omEvaluatorValueIsAReference(OMEvaluatorValue evaluatorValue);
206206
MLIR_CAPI_EXPORTED OMEvaluatorValue
207207
omEvaluatorValueGetReferenceValue(OMEvaluatorValue evaluatorValue);
208208

209+
/// Query if the EvaluatorValue is Unknown.
210+
MLIR_CAPI_EXPORTED bool
211+
omEvaluatorValueIsUnknown(OMEvaluatorValue evaluatorValue);
212+
213+
/// Create an Unknown EvaluatorValue with the given type.
214+
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorUnknownGet(MlirContext context,
215+
MlirType type);
216+
217+
/// Get the type of an EvaluatorValue.
218+
MLIR_CAPI_EXPORTED MlirType
219+
omEvaluatorValueGetType(OMEvaluatorValue evaluatorValue);
220+
209221
//===----------------------------------------------------------------------===//
210222
// ReferenceAttr API
211223
//===----------------------------------------------------------------------===//

integration_test/Bindings/Python/dialects/om.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,62 @@
251251
list_type = Type.parse("!om.list<!om.any>")
252252
assert isinstance(list_type, om.ListType)
253253
assert isinstance(list_type.element_type, om.AnyType)
254+
255+
# Test partial evaluation with multiple inputs and outputs
256+
module = Module.parse("""
257+
module {
258+
om.class @PartialEval(
259+
%in0: !om.integer,
260+
%in1: !om.integer,
261+
%in2: !om.integer,
262+
%in3: i1
263+
) -> (
264+
out0: !om.integer,
265+
out1: !om.integer,
266+
out2: !om.integer,
267+
out3: !om.integer,
268+
out4: i1,
269+
out5: i1
270+
) {
271+
// out0: known constant
272+
%c42 = om.constant #om.integer<42 : i32> : !om.integer
273+
274+
// out1: known input (in0)
275+
276+
// out2: unknown input (in1) propagated through arithmetic
277+
%sum = om.integer.add %in1, %c42 : !om.integer
278+
279+
// out3: depends on unknown input (in2) and known input (in0)
280+
%prod = om.integer.mul %in0, %in2 : !om.integer
281+
282+
// out4: known boolean constant
283+
%true = om.constant true
284+
285+
// out5: unknown boolean input (in3)
286+
287+
om.class.fields %c42, %in0, %sum, %prod, %true, %in3 :
288+
!om.integer, !om.integer, !om.integer, !om.integer, i1, i1
289+
}
290+
}
291+
""")
292+
293+
evaluator = om.Evaluator(module)
294+
om_integer_type = Type.parse("!om.integer")
295+
i1_type = Type.parse("i1")
296+
297+
# Instantiate with mix of known and unknown inputs
298+
obj = evaluator.instantiate("PartialEval", 100, om.Unknown(om_integer_type),
299+
om.Unknown(om_integer_type), om.Unknown(i1_type))
300+
301+
# CHECK: out0 (constant): 42
302+
print(f"out0 (constant): {obj.out0}")
303+
# CHECK: out1 (known input): 100
304+
print(f"out1 (known input): {obj.out1}")
305+
# CHECK: out2 (unknown input): Unknown(!om.integer)
306+
print(f"out2 (unknown input): {obj.out2}")
307+
# CHECK: out3 (depends on unknown): Unknown(!om.integer)
308+
print(f"out3 (depends on unknown): {obj.out3}")
309+
# CHECK: out4 (constant bool): True
310+
print(f"out4 (constant bool): {obj.out4}")
311+
# CHECK: out5 (unknown bool): Unknown(i1)
312+
print(f"out5 (unknown bool): {obj.out5}")

lib/Bindings/Python/OMModule.cpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir-c/BuiltinTypes.h"
1414
#include "mlir-c/IR.h"
1515
#include "mlir/Bindings/Python/NanobindAdaptors.h"
16+
#include "mlir/Bindings/Python/NanobindUtils.h"
1617
#include <nanobind/nanobind.h>
1718
#include <nanobind/stl/variant.h>
1819
#include <nanobind/stl/vector.h>
@@ -29,6 +30,16 @@ struct Object;
2930
struct BasePath;
3031
struct Path;
3132

33+
/// Represents a value that is not known because it is an unsupplied input, or
34+
/// derived from unsupplied inputs.
35+
struct Unknown {
36+
Unknown(MlirType type) : type(type) {}
37+
MlirType getType() const { return type; }
38+
39+
private:
40+
MlirType type;
41+
};
42+
3243
/// These are the Python types that are represented by the different primitive
3344
/// OMEvaluatorValues as Attributes.
3445
using PythonPrimitive = std::variant<nb::int_, nb::float_, nb::str, nb::bool_,
@@ -41,7 +52,7 @@ using PythonPrimitive = std::variant<nb::int_, nb::float_, nb::str, nb::bool_,
4152
/// is tried first, then we can hit an assert inside the MLIR codebase.
4253
struct None {};
4354
using PythonValue =
44-
std::variant<None, Object, List, BasePath, Path, PythonPrimitive>;
55+
std::variant<None, Object, List, BasePath, Path, Unknown, PythonPrimitive>;
4556

4657
/// Map an opaque OMEvaluatorValue into a python value.
4758
PythonValue omEvaluatorValueToPythonValue(OMEvaluatorValue result);
@@ -362,6 +373,9 @@ PythonValue omEvaluatorValueToPythonValue(OMEvaluatorValue result) {
362373
if (omEvaluatorValueIsAPath(result))
363374
return Path(result);
364375

376+
if (omEvaluatorValueIsUnknown(result))
377+
return Unknown(omEvaluatorValueGetType(result));
378+
365379
if (omEvaluatorValueIsAReference(result))
366380
return omEvaluatorValueToPythonValue(
367381
omEvaluatorValueGetReferenceValue(result));
@@ -385,6 +399,9 @@ OMEvaluatorValue pythonValueToOMEvaluatorValue(PythonValue result,
385399
if (auto *object = std::get_if<Object>(&result))
386400
return object->getValue();
387401

402+
if (auto *unknown = std::get_if<Unknown>(&result))
403+
return omEvaluatorUnknownGet(ctx, unknown->getType());
404+
388405
auto primitive = std::get<PythonPrimitive>(result);
389406
return omEvaluatorValueFromPrimitive(
390407
omPythonValueToPrimitive(primitive, ctx));
@@ -421,6 +438,19 @@ void circt::python::populateDialectOMSubmodule(nb::module_ &m) {
421438
.def(nb::init<Path>(), nb::arg("path"))
422439
.def("__str__", &Path::dunderStr);
423440

441+
// Add the Unknown sentinel class definition.
442+
nb::class_<Unknown>(m, "Unknown")
443+
.def(nb::init<MlirType>(), nb::arg("type"))
444+
.def_prop_ro("type", &Unknown::getType)
445+
.def("__repr__", [](const Unknown &u) {
446+
PyPrintAccumulator printAccum;
447+
printAccum.parts.append("Unknown(");
448+
mlirTypePrint(u.getType(), printAccum.getCallback(),
449+
printAccum.getUserData());
450+
printAccum.parts.append(")");
451+
return printAccum.join();
452+
});
453+
424454
// Add the Object class definition.
425455
nb::class_<Object>(m, "Object")
426456
.def(nb::init<Object>(), nb::arg("object"))

lib/Bindings/Python/dialects/om.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from __future__ import annotations
66

77
from ._om_ops_gen import *
8-
from .._mlir_libs._circt._om import AnyType, Evaluator as BaseEvaluator, Object as BaseObject, List as BaseList, BasePath as BaseBasePath, BasePathType, Path, PathType, ClassType, ReferenceAttr, ListAttr, ListType, OMIntegerAttr
8+
from .._mlir_libs._circt._om import AnyType, Evaluator as BaseEvaluator, Object as BaseObject, List as BaseList, BasePath as BaseBasePath, BasePathType, Path, PathType, ClassType, ReferenceAttr, ListAttr, ListType, OMIntegerAttr, Unknown
99

1010
from ..ir import Attribute, Diagnostic, DiagnosticSeverity, Module, StringAttr, IntegerAttr, IntegerType
1111
from ..support import attribute_to_var, var_to_attribute
@@ -25,6 +25,9 @@ def wrap_mlir_object(value):
2525
if isinstance(value, (int, float, str, bool, tuple, list, dict)):
2626
return value
2727

28+
if isinstance(value, Unknown):
29+
return value
30+
2831
if isinstance(value, BaseList):
2932
return List(value)
3033

lib/CAPI/Dialect/OM.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,26 @@ omEvaluatorValueGetReferenceValue(OMEvaluatorValue evaluatorValue) {
321321
return wrap(result.value());
322322
}
323323

324+
/// Query if the EvaluatorValue is an Unknown value.
325+
bool omEvaluatorValueIsUnknown(OMEvaluatorValue evaluatorValue) {
326+
return unwrap(evaluatorValue)->isUnknown();
327+
}
328+
329+
/// Create an Unknown EvaluatorValue.
330+
OMEvaluatorValue omEvaluatorUnknownGet(MlirContext context, MlirType type) {
331+
auto *ctx = unwrap(context);
332+
auto loc = UnknownLoc::get(ctx);
333+
auto unknownValue =
334+
evaluator::AttributeValue::get(unwrap(type), LocationAttr(loc));
335+
unknownValue->markUnknown();
336+
return wrap(unknownValue);
337+
}
338+
339+
/// Get the type of an EvaluatorValue.
340+
MlirType omEvaluatorValueGetType(OMEvaluatorValue evaluatorValue) {
341+
return wrap(unwrap(evaluatorValue)->getType());
342+
}
343+
324344
//===----------------------------------------------------------------------===//
325345
// ReferenceAttr API.
326346
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)