Skip to content

Commit 09d24cd

Browse files
authored
[eudsl-python-extras] fix memref.reinterpret_cast to handle mixed static/dynamic offsets/sizes/strides (#355)
Add tests which weren't added when first implemented
1 parent da33110 commit 09d24cd

File tree

2 files changed

+148
-15
lines changed

2 files changed

+148
-15
lines changed

projects/eudsl-python-extras/mlir/extras/dialects/memref.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -636,27 +636,35 @@ def reinterpret_cast(
636636
sizes_, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes)
637637
strides_, _packed_strides, static_strides = _dispatch_mixed_values(strides)
638638

639-
if offsets_ or sizes_ or strides_:
640-
raise NotImplementedError("only static offsets and sizes and strides supported")
639+
sizes_list = list(static_sizes)
640+
strides_list = list(static_strides)
641+
offsets_list = list(static_offsets)
641642

643+
# Compute default (row-major) strides when all sizes and strides are static
642644
default_strides = None
643-
if not static_strides and all(_is_static_int_like(s) for s in static_sizes):
644-
default_strides = list(accumulate(list(static_sizes)[1:][::-1], operator.mul))[
645-
::-1
646-
] + [1]
647-
static_strides = default_strides
645+
if not sizes_ and not strides_ and sizes_list and all(s != S for s in sizes_list):
646+
default_strides = list(accumulate(sizes_list[1:][::-1], operator.mul))[::-1] + [
647+
1
648+
]
649+
650+
if not strides_list and default_strides is not None:
651+
strides_list = default_strides
648652

649-
target_offset = 0
650-
for offset, target_stride in zip(static_offsets, static_strides):
651-
target_offset += offset * target_stride
653+
# The layout offset is the flat element offset (single value in offsets)
654+
target_offset = offsets_list[0] if offsets_list else 0
652655

653-
if static_strides == default_strides and target_offset == 0:
656+
# Omit layout when strides are default row-major and offset is 0
657+
if (
658+
default_strides is not None
659+
and strides_list == default_strides
660+
and target_offset == 0
661+
):
654662
layout = None
655663
else:
656-
layout = StridedLayoutAttr.get(target_offset, static_strides)
664+
layout = StridedLayoutAttr.get(target_offset, strides_list)
657665

658666
result = MemRefType.get(
659-
static_sizes, source.type.element_type, layout, source.type.memory_space
667+
sizes_list, source.type.element_type, layout, source.type.memory_space
660668
)
661669
return ReinterpretCastOp(
662670
result=result,
@@ -666,7 +674,7 @@ def reinterpret_cast(
666674
strides=strides_,
667675
static_offsets=static_offsets,
668676
static_sizes=static_sizes,
669-
static_strides=static_strides,
677+
static_strides=strides_list,
670678
loc=loc,
671679
ip=ip,
672680
).result

projects/eudsl-python-extras/tests/dialect/test_memref.py

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
import numpy as np
1010
import pytest
1111
from mlir.dialects.memref import subview
12-
from mlir.ir import MLIRError, Type, Value
12+
from mlir.ir import (
13+
MLIRError,
14+
Type,
15+
UnrankedMemRefType,
16+
Value,
17+
)
1318

1419
from mlir.extras.ast.canonicalize import canonicalize
1520
from mlir.extras.dialects import memref, arith
@@ -21,6 +26,7 @@
2126
alloca_scope_return,
2227
global_,
2328
rank_reduce,
29+
reinterpret_cast,
2430
S,
2531
)
2632
from mlir.extras.dialects.scf import (
@@ -721,3 +727,122 @@ def test_dim(ctx: MLIRContext):
721727

722728
dims = mem_dynamic.dims()
723729
assert isinstance(dims[1], Value) and isinstance(dims[1].owner.opview, memref.DimOp)
730+
731+
732+
def test_cast_ranked_memref_to_static_shape(ctx: MLIRContext):
733+
input = alloc((2, 3), T.f32())
734+
reinterpret_cast(input, offsets=[0], sizes=[6, 1], strides=[1, 1])
735+
736+
# CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
737+
# CHECK: %[[OUT:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [6, 1], strides: [1, 1] : memref<2x3xf32> to memref<6x1xf32>
738+
739+
filecheck_with_comments(ctx.module)
740+
741+
742+
def test_cast_ranked_memref_to_dynamic_shape(ctx: MLIRContext):
743+
input = alloc((2, 3), T.f32())
744+
c0 = constant(0, index=True)
745+
c1 = constant(1, index=True)
746+
c6 = constant(6, index=True)
747+
reinterpret_cast(input, offsets=[c0], sizes=[c1, c6], strides=[c6, c1])
748+
749+
# CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
750+
# CHECK: %[[C0:.*]] = arith.constant 0 : index
751+
# CHECK: %[[C1:.*]] = arith.constant 1 : index
752+
# CHECK: %[[C6:.*]] = arith.constant 6 : index
753+
# CHECK: %[[OUT:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [%[[C0]]], sizes: [%[[C1]], %[[C6]]], strides: [%[[C6]], %[[C1]]] : memref<2x3xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
754+
755+
filecheck_with_comments(ctx.module)
756+
757+
758+
def test_cast_unranked_memref_to_static_shape(ctx: MLIRContext):
759+
f32 = T.f32()
760+
input = alloc((2, 3), f32)
761+
unranked = memref.CastOp(UnrankedMemRefType.get(f32, None), input).result
762+
reinterpret_cast(unranked, offsets=[0], sizes=[6, 1], strides=[1, 1])
763+
764+
# CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
765+
# CHECK: %[[CAST:.*]] = memref.cast %[[ALLOC]] : memref<2x3xf32> to memref<*xf32>
766+
# CHECK: %[[OUT:.*]] = memref.reinterpret_cast %[[CAST]] to offset: [0], sizes: [6, 1], strides: [1, 1] : memref<*xf32> to memref<6x1xf32>
767+
768+
filecheck_with_comments(ctx.module)
769+
770+
771+
def test_cast_unranked_memref_to_dynamic_shape(ctx: MLIRContext):
772+
f32 = T.f32()
773+
input = alloc((2, 3), f32)
774+
unranked = memref.CastOp(UnrankedMemRefType.get(f32, None), input).result
775+
c0 = constant(0, index=True)
776+
c1 = constant(1, index=True)
777+
c6 = constant(6, index=True)
778+
reinterpret_cast(unranked, offsets=[c0], sizes=[c1, c6], strides=[c6, c1])
779+
780+
# CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
781+
# CHECK: %[[CAST:.*]] = memref.cast %[[ALLOC]] : memref<2x3xf32> to memref<*xf32>
782+
# CHECK: %[[C0:.*]] = arith.constant 0 : index
783+
# CHECK: %[[C1:.*]] = arith.constant 1 : index
784+
# CHECK: %[[C6:.*]] = arith.constant 6 : index
785+
# CHECK: %[[OUT:.*]] = memref.reinterpret_cast %[[CAST]] to offset: [%[[C0]]], sizes: [%[[C1]], %[[C6]]], strides: [%[[C6]], %[[C1]]] : memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
786+
787+
filecheck_with_comments(ctx.module)
788+
789+
790+
def test_reinterpret_cast_mixed_sizes(ctx: MLIRContext):
791+
# Static first dim, dynamic second dim; static offset and strides.
792+
input = alloc((2, 3), T.f32())
793+
c1 = constant(1, index=True)
794+
reinterpret_cast(input, offsets=[0], sizes=[6, c1], strides=[1, 1])
795+
796+
# CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
797+
# CHECK: %[[C1:.*]] = arith.constant 1 : index
798+
# CHECK: %[[OUT:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [6, %[[C1]]], strides: [1, 1] : memref<2x3xf32> to memref<6x?xf32, strided<[1, 1]>>
799+
800+
filecheck_with_comments(ctx.module)
801+
802+
803+
def test_reinterpret_cast_mixed_strides(ctx: MLIRContext):
804+
# Static sizes and offset; dynamic first stride, static second stride.
805+
input = alloc((2, 3), T.f32())
806+
c6 = constant(6, index=True)
807+
reinterpret_cast(input, offsets=[0], sizes=[6, 1], strides=[c6, 1])
808+
809+
# CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
810+
# CHECK: %[[C6:.*]] = arith.constant 6 : index
811+
# CHECK: %[[OUT:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [6, 1], strides: [%[[C6]], 1] : memref<2x3xf32> to memref<6x1xf32, strided<[?, 1]>>
812+
813+
filecheck_with_comments(ctx.module)
814+
815+
816+
def test_reinterpret_cast_mixed_offset(ctx: MLIRContext):
817+
# Dynamic offset; static sizes and strides.
818+
input = alloc((2, 3), T.f32())
819+
c0 = constant(0, index=True)
820+
reinterpret_cast(input, offsets=[c0], sizes=[6, 1], strides=[1, 1])
821+
822+
# CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
823+
# CHECK: %[[C0:.*]] = arith.constant 0 : index
824+
# CHECK: %[[OUT:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [%[[C0]]], sizes: [6, 1], strides: [1, 1] : memref<2x3xf32> to memref<6x1xf32, strided<[1, 1], offset: ?>>
825+
826+
filecheck_with_comments(ctx.module)
827+
828+
829+
def test_reinterpret_cast_nonzero_static_offset(ctx: MLIRContext):
830+
input = alloc((2, 3), T.f32())
831+
reinterpret_cast(input, offsets=[3], sizes=[6, 1], strides=[1, 1])
832+
833+
# CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
834+
# CHECK: %[[OUT:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [3], sizes: [6, 1], strides: [1, 1] : memref<2x3xf32> to memref<6x1xf32, strided<[1, 1], offset: 3>>
835+
836+
filecheck_with_comments(ctx.module)
837+
838+
839+
def test_reinterpret_cast_nonzero_dynamic_offset(ctx: MLIRContext):
840+
input = alloc((2, 3), T.f32())
841+
c3 = constant(3, index=True)
842+
reinterpret_cast(input, offsets=[c3], sizes=[6, 1], strides=[1, 1])
843+
844+
# CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<2x3xf32>
845+
# CHECK: %[[C3:.*]] = arith.constant 3 : index
846+
# CHECK: %[[OUT:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [%[[C3]]], sizes: [6, 1], strides: [1, 1] : memref<2x3xf32> to memref<6x1xf32, strided<[1, 1], offset: ?>>
847+
848+
filecheck_with_comments(ctx.module)

0 commit comments

Comments
 (0)