Skip to content

Commit 7119b0c

Browse files
authored
[MLIR][CAPI][python] expose the python binding for linalgOp.getIndexingMaps (#136054)
This PR is mainly about exposing the python bindings for `linalgOp.getIndexingMaps`. --------- Signed-off-by: Bangtian Liu <[email protected]>
1 parent b3a53cc commit 7119b0c

File tree

4 files changed

+72
-0
lines changed

4 files changed

+72
-0
lines changed

Diff for: mlir/include/mlir-c/Dialect/Linalg.h

+3
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ typedef struct MlirLinalgConvolutionDimensions {
5050
MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions
5151
mlirLinalgInferConvolutionDimensions(MlirOperation op);
5252

53+
MLIR_CAPI_EXPORTED MlirAttribute
54+
mlirLinalgGetIndexingMapsAttribute(MlirOperation op);
55+
5356
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
5457

5558
#ifdef __cplusplus

Diff for: mlir/lib/Bindings/Python/DialectLinalg.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,16 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
120120

121121
m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
122122
"Infers convolution dimensions", nb::arg("op"));
123+
124+
m.def(
125+
"get_indexing_maps",
126+
[](MlirOperation op) -> std::optional<MlirAttribute> {
127+
MlirAttribute attr = mlirLinalgGetIndexingMapsAttribute(op);
128+
if (mlirAttributeIsNull(attr))
129+
return std::nullopt;
130+
return attr;
131+
},
132+
"Returns the indexing_maps attribute for a linalg op.");
123133
}
124134

125135
NB_MODULE(_mlirDialectsLinalg, m) {

Diff for: mlir/lib/CAPI/Dialect/Linalg.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -120,4 +120,14 @@ mlirLinalgInferConvolutionDimensions(MlirOperation op) {
120120
return result;
121121
}
122122

123+
MLIR_CAPI_EXPORTED MlirAttribute
124+
mlirLinalgGetIndexingMapsAttribute(MlirOperation op) {
125+
auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
126+
if (!linalgOp)
127+
return MlirAttribute{nullptr};
128+
129+
ArrayAttr attr = linalgOp.getIndexingMaps();
130+
return wrap(attr);
131+
}
132+
123133
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)

Diff for: mlir/test/python/dialects/linalg/utils.py

+49
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,52 @@ def dyn_conv_fn(input, filter, output):
159159
assert list(dims.depth) == []
160160
assert list(dims.strides) == [1, 1]
161161
assert list(dims.dilations) == [1, 1]
162+
163+
164+
@run
165+
def test_get_indexing_maps_attr():
166+
with Context(), Location.unknown():
167+
module = Module.create()
168+
f32 = F32Type.get()
169+
with InsertionPoint(module.body):
170+
a_type = RankedTensorType.get((4, 8), f32)
171+
b_type = RankedTensorType.get((8, 16), f32)
172+
c_type = RankedTensorType.get((4, 16), f32)
173+
174+
dim_m = AffineDimExpr.get(0)
175+
dim_n = AffineDimExpr.get(1)
176+
dim_k = AffineDimExpr.get(2)
177+
178+
a_map = AffineMap.get(3, 0, [dim_m, dim_k])
179+
b_map = AffineMap.get(3, 0, [dim_k, dim_n])
180+
c_map = AffineMap.get(3, 0, [dim_m, dim_n])
181+
182+
@func.FuncOp.from_py_func(a_type, b_type, c_type)
183+
def matmul_func(a, b, c):
184+
zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
185+
assert not linalg.get_indexing_maps(
186+
zero.operation
187+
), "Expected no indexing_maps on non-linalg op"
188+
189+
init = linalg.fill(zero, outs=[c])
190+
fill_op = init.owner
191+
fill_maps = linalg.get_indexing_maps(fill_op)
192+
assert fill_maps is not None
193+
assert len(fill_maps) == 2
194+
195+
# The fill op should have maps like (d0, d1) -> () and (d0, d1).
196+
fill_input_map = fill_maps[0].value
197+
fill_output_map = fill_maps[1].value
198+
assert fill_input_map == AffineMap.get(2, 0, [])
199+
assert fill_output_map == AffineMap.get(2, 0, [dim_m, dim_n])
200+
201+
result = linalg.matmul(a, b, outs=(init,))
202+
matmul_op = result.owner
203+
matmul_maps = linalg.get_indexing_maps(matmul_op)
204+
assert matmul_maps is not None
205+
assert len(matmul_maps) == 3
206+
207+
maps = [map_attr.value for map_attr in matmul_maps]
208+
assert maps[0] == a_map
209+
assert maps[1] == b_map
210+
assert maps[2] == c_map

0 commit comments

Comments
 (0)