-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[MLIR][CAPI][python] expose the python binding for linalgOp.getIndexingMaps #136054
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Bangtian Liu (bangtianliu) ChangesThis PR is mainly about exposing the python bindings for Full diff: https://github.com/llvm/llvm-project/pull/136054.diff 4 Files Affected:
diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h
index 4f2ee0d434222..339e63d667c5e 100644
--- a/mlir/include/mlir-c/Dialect/Linalg.h
+++ b/mlir/include/mlir-c/Dialect/Linalg.h
@@ -50,6 +50,9 @@ typedef struct MlirLinalgConvolutionDimensions {
MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions
mlirLinalgInferConvolutionDimensions(MlirOperation op);
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirLinalgGetIndexingMapsAttribute(MlirOperation op);
+
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
#ifdef __cplusplus
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index ce1102a3b3498..e18d218868441 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -120,6 +120,10 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
"Infers convolution dimensions", nb::arg("op"));
+
+ m.def("get_indexing_maps_attr", &mlirLinalgGetIndexingMapsAttribute,
+ "Returns the indexing_maps or memoized_indexing_maps attribute for a Linalg op.",
+ nb::arg("op"));
}
NB_MODULE(_mlirDialectsLinalg, m) {
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 7c456102a2c0c..f940ed857fc9e 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -120,4 +120,13 @@ mlirLinalgInferConvolutionDimensions(MlirOperation op) {
return result;
}
+MLIR_CAPI_EXPORTED MlirAttribute mlirLinalgGetIndexingMapsAttribute(MlirOperation op) {
+ auto linalgOp = mlir::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
+ if (!linalgOp)
+ return MlirAttribute{nullptr};
+
+ ArrayAttr attr = linalgOp.getIndexingMaps();
+ return wrap(attr);
+}
+
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py
index 98157b0e443cf..0bdb52c328778 100644
--- a/mlir/test/python/dialects/linalg/utils.py
+++ b/mlir/test/python/dialects/linalg/utils.py
@@ -159,3 +159,50 @@ def dyn_conv_fn(input, filter, output):
assert list(dims.depth) == []
assert list(dims.strides) == [1, 1]
assert list(dims.dilations) == [1, 1]
+
+
+@run
+def test_get_indexing_maps_attr():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ a_type = RankedTensorType.get((4, 8), f32)
+ b_type = RankedTensorType.get((8, 16), f32)
+ c_type = RankedTensorType.get((4, 16), f32)
+
+ dim_m = AffineDimExpr.get(0)
+ dim_n = AffineDimExpr.get(1)
+ dim_k = AffineDimExpr.get(2)
+
+ a_map = AffineMap.get(3, 0, [dim_m, dim_k])
+ b_map = AffineMap.get(3, 0, [dim_k, dim_n])
+ c_map = AffineMap.get(3, 0, [dim_m, dim_n])
+
+ @func.FuncOp.from_py_func(a_type, b_type, c_type)
+ def matmul_func(a, b, c):
+ zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
+ init = linalg.fill(zero, outs=[c])
+ fill_op = init.owner
+
+ fill_maps = linalg.get_indexing_maps_attr(fill_op)
+ assert fill_maps is not None
+ assert len(fill_maps) == 2
+
+ # The fill op should have maps like (d0, d1) -> () and (d0, d1).
+ fill_input_map = fill_maps[0].value
+ print(type(fill_input_map))
+ fill_output_map = fill_maps[1].value
+ assert fill_input_map == AffineMap.get(2, 0, [])
+ assert fill_output_map == AffineMap.get(2, 0, [dim_m, dim_n])
+
+ result = linalg.matmul(a, b, outs=(init,))
+ matmul_op = result.owner
+ matmul_maps = linalg.get_indexing_maps_attr(matmul_op)
+ assert matmul_maps is not None
+ assert len(matmul_maps) == 3
+
+ maps = [map_attr.value for map_attr in matmul_maps]
+ assert maps[0] == a_map
+ assert maps[1] == b_map
+ assert maps[2] == c_map
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
19403f3
to
abd4d53
Compare
…ngMap Signed-off-by: Bangtian Liu <[email protected]>
abd4d53
to
d3686bd
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Signed-off-by: Bangtian Liu <[email protected]>
Signed-off-by: Bangtian Liu <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for the nit change
This PR is mainly about exposing the python bindings for
linalgOp.getIndexingMaps
.