From d3686bd8ac27debcb8defe11543ab80b1634c0bf Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Wed, 16 Apr 2025 16:29:00 -0700 Subject: [PATCH 1/3] [MLIR][CAPI][python] expose the python binding for linalgOp.getIndexingMap Signed-off-by: Bangtian Liu --- mlir/include/mlir-c/Dialect/Linalg.h | 3 ++ mlir/lib/Bindings/Python/DialectLinalg.cpp | 5 +++ mlir/lib/CAPI/Dialect/Linalg.cpp | 10 +++++ mlir/test/python/dialects/linalg/utils.py | 47 ++++++++++++++++++++++ 4 files changed, 65 insertions(+) diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h index 4f2ee0d434222..339e63d667c5e 100644 --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -50,6 +50,9 @@ typedef struct MlirLinalgConvolutionDimensions { MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions mlirLinalgInferConvolutionDimensions(MlirOperation op); +MLIR_CAPI_EXPORTED MlirAttribute +mlirLinalgGetIndexingMapsAttribute(MlirOperation op); + MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg); #ifdef __cplusplus diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp index ce1102a3b3498..8f8519a8b9950 100644 --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -120,6 +120,11 @@ static void populateDialectLinalgSubmodule(nb::module_ m) { m.def("infer_convolution_dimensions", &InferConvolutionDimensions, "Infers convolution dimensions", nb::arg("op")); + + m.def("get_indexing_maps_attr", &mlirLinalgGetIndexingMapsAttribute, + "Returns the indexing_maps or memoized_indexing_maps attribute for a " + "Linalg op.", + nb::arg("op")); } NB_MODULE(_mlirDialectsLinalg, m) { diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 7c456102a2c0c..0c4f6e88e7078 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -120,4 +120,14 @@ mlirLinalgInferConvolutionDimensions(MlirOperation op) { return result; } +MLIR_CAPI_EXPORTED MlirAttribute +mlirLinalgGetIndexingMapsAttribute(MlirOperation op) { + auto linalgOp = llvm::dyn_cast(unwrap(op)); + if (!linalgOp) + return MlirAttribute{nullptr}; + + ArrayAttr attr = linalgOp.getIndexingMaps(); + return wrap(attr); +} + MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect) diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py index 98157b0e443cf..0bdb52c328778 100644 --- a/mlir/test/python/dialects/linalg/utils.py +++ b/mlir/test/python/dialects/linalg/utils.py @@ -159,3 +159,50 @@ def dyn_conv_fn(input, filter, output): assert list(dims.depth) == [] assert list(dims.strides) == [1, 1] assert list(dims.dilations) == [1, 1] + + +@run +def test_get_indexing_maps_attr(): + with Context(), Location.unknown(): + module = Module.create() + f32 = F32Type.get() + with InsertionPoint(module.body): + a_type = RankedTensorType.get((4, 8), f32) + b_type = RankedTensorType.get((8, 16), f32) + c_type = RankedTensorType.get((4, 16), f32) + + dim_m = AffineDimExpr.get(0) + dim_n = AffineDimExpr.get(1) + dim_k = AffineDimExpr.get(2) + + a_map = AffineMap.get(3, 0, [dim_m, dim_k]) + b_map = AffineMap.get(3, 0, [dim_k, dim_n]) + c_map = AffineMap.get(3, 0, [dim_m, dim_n]) + + @func.FuncOp.from_py_func(a_type, b_type, c_type) + def matmul_func(a, b, c): + zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32) + init = linalg.fill(zero, outs=[c]) + fill_op = init.owner + + fill_maps = linalg.get_indexing_maps_attr(fill_op) + assert fill_maps is not None + assert len(fill_maps) == 2 + + # The fill op should have maps like (d0, d1) -> () and (d0, d1). + fill_input_map = fill_maps[0].value + print(type(fill_input_map)) + fill_output_map = fill_maps[1].value + assert fill_input_map == AffineMap.get(2, 0, []) + assert fill_output_map == AffineMap.get(2, 0, [dim_m, dim_n]) + + result = linalg.matmul(a, b, outs=(init,)) + matmul_op = result.owner + matmul_maps = linalg.get_indexing_maps_attr(matmul_op) + assert matmul_maps is not None + assert len(matmul_maps) == 3 + + maps = [map_attr.value for map_attr in matmul_maps] + assert maps[0] == a_map + assert maps[1] == b_map + assert maps[2] == c_map From ff09577c6ab748234c9ebfbb4df5830e3d0aaca3 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Wed, 16 Apr 2025 22:20:30 -0700 Subject: [PATCH 2/3] add test for non-linalg op Signed-off-by: Bangtian Liu --- mlir/lib/Bindings/Python/DialectLinalg.cpp | 13 +++++++++---- mlir/test/python/dialects/linalg/utils.py | 6 ++++-- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp index 8f8519a8b9950..b99deb05e96a8 100644 --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -121,10 +121,15 @@ static void populateDialectLinalgSubmodule(nb::module_ m) { m.def("infer_convolution_dimensions", &InferConvolutionDimensions, "Infers convolution dimensions", nb::arg("op")); - m.def("get_indexing_maps_attr", &mlirLinalgGetIndexingMapsAttribute, - "Returns the indexing_maps or memoized_indexing_maps attribute for a " - "Linalg op.", - nb::arg("op")); + m.def( + "get_indexing_maps_attr", + [](MlirOperation op) -> std::optional { + MlirAttribute attr = mlirLinalgGetIndexingMapsAttribute(op); + if (mlirAttributeIsNull(attr)) + return std::nullopt; + return attr; + }, + "Returns the indexing_maps attribute for a linalg op."); } NB_MODULE(_mlirDialectsLinalg, m) { diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py index 0bdb52c328778..48cbef1cf1d85 100644 --- a/mlir/test/python/dialects/linalg/utils.py +++ b/mlir/test/python/dialects/linalg/utils.py @@ -182,16 +182,18 @@ def test_get_indexing_maps_attr(): @func.FuncOp.from_py_func(a_type, b_type, c_type) def matmul_func(a, b, c): zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32) + assert not linalg.get_indexing_maps_attr( + zero.operation + ), "Expected no indexing_maps on non-linalg op" + init = linalg.fill(zero, outs=[c]) fill_op = init.owner - fill_maps = linalg.get_indexing_maps_attr(fill_op) assert fill_maps is not None assert len(fill_maps) == 2 # The fill op should have maps like (d0, d1) -> () and (d0, d1). fill_input_map = fill_maps[0].value - print(type(fill_input_map)) fill_output_map = fill_maps[1].value assert fill_input_map == AffineMap.get(2, 0, []) assert fill_output_map == AffineMap.get(2, 0, [dim_m, dim_n]) From 3abc7fdd5805d50af1a0cf882cf402f2efd508ba Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Thu, 17 Apr 2025 08:55:55 -0700 Subject: [PATCH 3/3] address reviewer comments Signed-off-by: Bangtian Liu --- mlir/lib/Bindings/Python/DialectLinalg.cpp | 2 +- mlir/test/python/dialects/linalg/utils.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp index b99deb05e96a8..015502371c65b 100644 --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -122,7 +122,7 @@ static void populateDialectLinalgSubmodule(nb::module_ m) { "Infers convolution dimensions", nb::arg("op")); m.def( - "get_indexing_maps_attr", + "get_indexing_maps", [](MlirOperation op) -> std::optional { MlirAttribute attr = mlirLinalgGetIndexingMapsAttribute(op); if (mlirAttributeIsNull(attr)) diff --git a/mlir/test/python/dialects/linalg/utils.py b/mlir/test/python/dialects/linalg/utils.py index 48cbef1cf1d85..5f7cb6a6c83cb 100644 --- a/mlir/test/python/dialects/linalg/utils.py +++ b/mlir/test/python/dialects/linalg/utils.py @@ -182,13 +182,13 @@ def test_get_indexing_maps_attr(): @func.FuncOp.from_py_func(a_type, b_type, c_type) def matmul_func(a, b, c): zero = arith.ConstantOp(value=FloatAttr.get(f32, 0.0), result=f32) - assert not linalg.get_indexing_maps_attr( + assert not linalg.get_indexing_maps( zero.operation ), "Expected no indexing_maps on non-linalg op" init = linalg.fill(zero, outs=[c]) fill_op = init.owner - fill_maps = linalg.get_indexing_maps_attr(fill_op) + fill_maps = linalg.get_indexing_maps(fill_op) assert fill_maps is not None assert len(fill_maps) == 2 @@ -200,7 +200,7 @@ def matmul_func(a, b, c): result = linalg.matmul(a, b, outs=(init,)) matmul_op = result.owner - matmul_maps = linalg.get_indexing_maps_attr(matmul_op) + matmul_maps = linalg.get_indexing_maps(matmul_op) assert matmul_maps is not None assert len(matmul_maps) == 3