Skip to content

Commit c8a9a41

Browse files
authored
[MLIR] [python] A few improvements to the Python bindings (llvm#131686)
* `PyRegionList` is now sliceable. The dialect bindings generator seems to assume it is sliceable already (!), yet accessing e.g. `cases` on `scf.IndexedSwitchOp` raises a `TypeError` at runtime. * `PyBlockList` and `PyOperationList` support negative indexing. It is common for containers to do that in Python, and most container in the MLIR Python bindings already allow the index to be negative.
1 parent 4d5a963 commit c8a9a41

File tree

3 files changed

+55
-19
lines changed

3 files changed

+55
-19
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

+33-16
Original file line numberDiff line numberDiff line change
@@ -361,37 +361,45 @@ class PyRegionIterator {
361361

362362
/// Regions of an op are fixed length and indexed numerically so are represented
363363
/// with a sequence-like container.
364-
class PyRegionList {
364+
class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
365365
public:
366-
PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
366+
static constexpr const char *pyClassName = "RegionSequence";
367+
368+
PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
369+
intptr_t length = -1, intptr_t step = 1)
370+
: Sliceable(startIndex,
371+
length == -1 ? mlirOperationGetNumRegions(operation->get())
372+
: length,
373+
step),
374+
operation(std::move(operation)) {}
367375

368376
PyRegionIterator dunderIter() {
369377
operation->checkValid();
370378
return PyRegionIterator(operation);
371379
}
372380

373-
intptr_t dunderLen() {
381+
static void bindDerived(ClassTy &c) {
382+
c.def("__iter__", &PyRegionList::dunderIter);
383+
}
384+
385+
private:
386+
/// Give the parent CRTP class access to hook implementations below.
387+
friend class Sliceable<PyRegionList, PyRegion>;
388+
389+
intptr_t getRawNumElements() {
374390
operation->checkValid();
375391
return mlirOperationGetNumRegions(operation->get());
376392
}
377393

378-
PyRegion dunderGetItem(intptr_t index) {
379-
// dunderLen checks validity.
380-
if (index < 0 || index >= dunderLen()) {
381-
throw nb::index_error("attempt to access out of bounds region");
382-
}
383-
MlirRegion region = mlirOperationGetRegion(operation->get(), index);
384-
return PyRegion(operation, region);
394+
PyRegion getRawElement(intptr_t pos) {
395+
operation->checkValid();
396+
return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
385397
}
386398

387-
static void bind(nb::module_ &m) {
388-
nb::class_<PyRegionList>(m, "RegionSequence")
389-
.def("__len__", &PyRegionList::dunderLen)
390-
.def("__iter__", &PyRegionList::dunderIter)
391-
.def("__getitem__", &PyRegionList::dunderGetItem);
399+
PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
400+
return PyRegionList(operation, startIndex, length, step);
392401
}
393402

394-
private:
395403
PyOperationRef operation;
396404
};
397405

@@ -450,6 +458,9 @@ class PyBlockList {
450458

451459
PyBlock dunderGetItem(intptr_t index) {
452460
operation->checkValid();
461+
if (index < 0) {
462+
index += dunderLen();
463+
}
453464
if (index < 0) {
454465
throw nb::index_error("attempt to access out of bounds block");
455466
}
@@ -546,6 +557,9 @@ class PyOperationList {
546557

547558
nb::object dunderGetItem(intptr_t index) {
548559
parentOperation->checkValid();
560+
if (index < 0) {
561+
index += dunderLen();
562+
}
549563
if (index < 0) {
550564
throw nb::index_error("attempt to access out of bounds operation");
551565
}
@@ -2629,6 +2643,9 @@ class PyOpAttributeMap {
26292643
}
26302644

26312645
PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2646+
if (index < 0) {
2647+
index += dunderLen();
2648+
}
26322649
if (index < 0 || index >= dunderLen()) {
26332650
throw nb::index_error("attempt to access out of bounds attribute");
26342651
}

mlir/python/mlir/_mlir_libs/_mlir/ir.pyi

+3
Original file line numberDiff line numberDiff line change
@@ -2466,7 +2466,10 @@ class RegionIterator:
24662466
def __next__(self) -> Region: ...
24672467

24682468
class RegionSequence:
2469+
@overload
24692470
def __getitem__(self, arg0: int) -> Region: ...
2471+
@overload
2472+
def __getitem__(self, arg0: slice) -> Sequence[Region]: ...
24702473
def __iter__(self) -> RegionIterator: ...
24712474
def __len__(self) -> int: ...
24722475

mlir/test/python/ir/operation.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def testTraverseOpRegionBlockIterators():
4444
op = module.operation
4545
assert op.context is ctx
4646
# Get the block using iterators off of the named collections.
47-
regions = list(op.regions)
47+
regions = list(op.regions[:])
4848
blocks = list(regions[0].blocks)
4949
# CHECK: MODULE REGIONS=1 BLOCKS=1
5050
print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")
@@ -86,8 +86,24 @@ def walk_operations(indent, op):
8686
# CHECK: Block iter: <mlir.{{.+}}.BlockIterator
8787
# CHECK: Operation iter: <mlir.{{.+}}.OperationIterator
8888
print(" Region iter:", iter(op.regions))
89-
print(" Block iter:", iter(op.regions[0]))
90-
print("Operation iter:", iter(op.regions[0].blocks[0]))
89+
print(" Block iter:", iter(op.regions[-1]))
90+
print("Operation iter:", iter(op.regions[-1].blocks[-1]))
91+
92+
try:
93+
op.regions[-42]
94+
except IndexError as e:
95+
# CHECK: Region OOB: index out of range
96+
print("Region OOB:", e)
97+
try:
98+
op.regions[0].blocks[-42]
99+
except IndexError as e:
100+
# CHECK: attempt to access out of bounds block
101+
print(e)
102+
try:
103+
op.regions[0].blocks[0].operations[-42]
104+
except IndexError as e:
105+
# CHECK: attempt to access out of bounds operation
106+
print(e)
91107

92108

93109
# Verify index based traversal of the op/region/block hierarchy.

0 commit comments

Comments
 (0)