Skip to content

Commit abd4d53

Browse files
committed
[MLIR][CAPI][python] expose the python binding for linalgOp.getIndexingMap
Signed-off-by: Bangtian Liu <[email protected]>
1 parent 4863d1f commit abd4d53

File tree

4 files changed

+65
-0
lines changed

4 files changed

+65
-0
lines changed

Diff for: mlir/include/mlir-c/Dialect/Linalg.h

+3
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ typedef struct MlirLinalgConvolutionDimensions {
5050
MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions
5151
mlirLinalgInferConvolutionDimensions(MlirOperation op);
5252

53+
MLIR_CAPI_EXPORTED MlirAttribute
54+
mlirLinalgGetIndexingMapsAttribute(MlirOperation op);
55+
5356
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
5457

5558
#ifdef __cplusplus

Diff for: mlir/lib/Bindings/Python/DialectLinalg.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
120120

121121
m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
122122
"Infers convolution dimensions", nb::arg("op"));
123+
124+
m.def("get_indexing_maps_attr", &mlirLinalgGetIndexingMapsAttribute,
125+
"Returns the indexing_maps or memoized_indexing_maps attribute for a "
126+
"Linalg op.",
127+
nb::arg("op"));
123128
}
124129

125130
NB_MODULE(_mlirDialectsLinalg, m) {

Diff for: mlir/lib/CAPI/Dialect/Linalg.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -120,4 +120,14 @@ mlirLinalgInferConvolutionDimensions(MlirOperation op) {
120120
return result;
121121
}
122122

123+
MLIR_CAPI_EXPORTED MlirAttribute
124+
mlirLinalgGetIndexingMapsAttribute(MlirOperation op) {
125+
auto linalgOp = mlir::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
126+
if (!linalgOp)
127+
return MlirAttribute{nullptr};
128+
129+
ArrayAttr attr = linalgOp.getIndexingMaps();
130+
return wrap(attr);
131+
}
132+
123133
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)

Diff for: mlir/test/python/dialects/linalg/utils.py

+47
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,50 @@ def dyn_conv_fn(input, filter, output):
159159
assert list(dims.depth) == []
160160
assert list(dims.strides) == [1, 1]
161161
assert list(dims.dilations) == [1, 1]
162+
163+
164+
@run
165+
def test_get_indexing_maps_attr():
166+
with Context(), Location.unknown():
167+
module = Module.create()
168+
f32 = F32Type.get()
169+
with InsertionPoint(module.body):
170+
a_type = RankedTensorType.get((4, 8), f32)
171+
b_type = RankedTensorType.get((8, 16), f32)
172+
c_type = RankedTensorType.get((4, 16), f32)
173+
174+
dim_m = AffineDimExpr.get(0)
175+
dim_n = AffineDimExpr.get(1)
176+
dim_k = AffineDimExpr.get(2)
177+
178+
a_map = AffineMap.get(3, 0, [dim_m, dim_k])
179+
b_map = AffineMap.get(3, 0, [dim_k, dim_n])
180+
c_map = AffineMap.get(3, 0, [dim_m, dim_n])
181+
182+
@func.FuncOp.from_py_func(a_type, b_type, c_type)
183+
def matmul_func(a, b, c):
184+
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
185+
init = linalg.fill(zero, outs=[c])
186+
fill_op = init.owner
187+
188+
fill_maps = linalg.get_indexing_maps_attr(fill_op)
189+
assert fill_maps is not None
190+
assert len(fill_maps) == 2
191+
192+
# The fill op should have maps like (d0, d1) -> () and (d0, d1).
193+
fill_input_map = fill_maps[0].value
194+
print(type(fill_input_map))
195+
fill_output_map = fill_maps[1].value
196+
assert fill_input_map == AffineMap.get(2, 0, [])
197+
assert fill_output_map == AffineMap.get(2, 0, [dim_m, dim_n])
198+
199+
result = linalg.matmul(a, b, outs=(init,))
200+
matmul_op = result.owner
201+
matmul_maps = linalg.get_indexing_maps_attr(matmul_op)
202+
assert matmul_maps is not None
203+
assert len(matmul_maps) == 3
204+
205+
maps = [map_attr.value for map_attr in matmul_maps]
206+
assert maps[0] == a_map
207+
assert maps[1] == b_map
208+
assert maps[2] == c_map

0 commit comments

Comments
 (0)