Skip to content

[MLIR][CAPI][python] expose the python binding for linalgOp.getIndexingMaps #136054

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 17, 2025

Conversation

bangtianliu
Copy link
Contributor

This PR is mainly about exposing the python bindings for linalgOp.getIndexingMaps.

@llvmbot
Copy link
Member

llvmbot commented Apr 16, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Bangtian Liu (bangtianliu)

Changes

This PR is mainly about exposing the python bindings for linalgOp.getIndexingMaps.


Full diff: https://github.com/llvm/llvm-project/pull/136054.diff

4 Files Affected:

  • (modified) mlir/include/mlir-c/Dialect/Linalg.h (+3)
  • (modified) mlir/lib/Bindings/Python/DialectLinalg.cpp (+4)
  • (modified) mlir/lib/CAPI/Dialect/Linalg.cpp (+9)
  • (modified) mlir/test/python/dialects/linalg/utils.py (+47)
diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h
index 4f2ee0d434222..339e63d667c5e 100644
--- a/mlir/include/mlir-c/Dialect/Linalg.h
+++ b/mlir/include/mlir-c/Dialect/Linalg.h
@@ -50,6 +50,9 @@ typedef struct MlirLinalgConvolutionDimensions {
 MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions
 mlirLinalgInferConvolutionDimensions(MlirOperation op);
 
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirLinalgGetIndexingMapsAttribute(MlirOperation op);
+
 MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg);
 
 #ifdef __cplusplus
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index ce1102a3b3498..e18d218868441 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -120,6 +120,10 @@ static void populateDialectLinalgSubmodule(nb::module_ m) {
 
   m.def("infer_convolution_dimensions", &InferConvolutionDimensions,
         "Infers convolution dimensions", nb::arg("op"));
+  
+  m.def("get_indexing_maps_attr", &mlirLinalgGetIndexingMapsAttribute,
+          "Returns the indexing_maps or memoized_indexing_maps attribute for a Linalg op.",
+          nb::arg("op"));
 }
 
 NB_MODULE(_mlirDialectsLinalg, m) {
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 7c456102a2c0c..f940ed857fc9e 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -120,4 +120,13 @@ mlirLinalgInferConvolutionDimensions(MlirOperation op) {
   return result;
 }
 
+MLIR_CAPI_EXPORTED MlirAttribute mlirLinalgGetIndexingMapsAttribute(MlirOperation op) {
+  auto linalgOp = mlir::dyn_cast<mlir::linalg::LinalgOp>(unwrap(op));
+  if (!linalgOp)
+    return MlirAttribute{nullptr};
+
+  ArrayAttr attr = linalgOp.getIndexingMaps();
+  return wrap(attr);
+}
+
 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect)
diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py
index 98157b0e443cf..0bdb52c328778 100644
--- a/mlir/test/python/dialects/linalg/utils.py
+++ b/mlir/test/python/dialects/linalg/utils.py
@@ -159,3 +159,50 @@ def dyn_conv_fn(input, filter, output):
                 assert list(dims.depth) == []
                 assert list(dims.strides) == [1, 1]
                 assert list(dims.dilations) == [1, 1]
+
+
+@run
+def test_get_indexing_maps_attr():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            a_type = RankedTensorType.get((4, 8), f32)
+            b_type = RankedTensorType.get((8, 16), f32)
+            c_type = RankedTensorType.get((4, 16), f32)
+
+            dim_m = AffineDimExpr.get(0)
+            dim_n = AffineDimExpr.get(1)
+            dim_k = AffineDimExpr.get(2)
+
+            a_map = AffineMap.get(3, 0, [dim_m, dim_k])
+            b_map = AffineMap.get(3, 0, [dim_k, dim_n])
+            c_map = AffineMap.get(3, 0, [dim_m, dim_n])
+
+            @func.FuncOp.from_py_func(a_type, b_type, c_type)
+            def matmul_func(a, b, c):
+                zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32)
+                init = linalg.fill(zero, outs=[c])
+                fill_op = init.owner
+
+                fill_maps = linalg.get_indexing_maps_attr(fill_op)
+                assert fill_maps is not None
+                assert len(fill_maps) == 2
+
+                # The fill op should have maps like (d0, d1) -> () and (d0, d1).
+                fill_input_map = fill_maps[0].value
+                print(type(fill_input_map))
+                fill_output_map = fill_maps[1].value
+                assert fill_input_map == AffineMap.get(2, 0, [])
+                assert fill_output_map == AffineMap.get(2, 0, [dim_m, dim_n])
+
+                result = linalg.matmul(a, b, outs=(init,))
+                matmul_op = result.owner
+                matmul_maps = linalg.get_indexing_maps_attr(matmul_op)
+                assert matmul_maps is not None
+                assert len(matmul_maps) == 3
+
+                maps = [map_attr.value for map_attr in matmul_maps]
+                assert maps[0] == a_map
+                assert maps[1] == b_map
+                assert maps[2] == c_map

Copy link

github-actions bot commented Apr 16, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@bangtianliu bangtianliu force-pushed the python_binging_indexmap branch from 19403f3 to abd4d53 Compare April 16, 2025 23:37
@bangtianliu bangtianliu force-pushed the python_binging_indexmap branch from abd4d53 to d3686bd Compare April 17, 2025 00:43
Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Signed-off-by: Bangtian Liu <[email protected]>
Signed-off-by: Bangtian Liu <[email protected]>
@bangtianliu
Copy link
Contributor Author

cc @kuhar @Max191

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the nit change

@kuhar kuhar merged commit 7119b0c into llvm:main Apr 17, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants