diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index e0b462de2c..198b1def28 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -166,6 +166,7 @@
The new mlir bufferization interface is required by jax 0.4.29 or higher.
[(#1027)](https://github.com/PennyLaneAI/catalyst/pull/1027)
[(#1686)](https://github.com/PennyLaneAI/catalyst/pull/1686)
+ [(#1708)](https://github.com/PennyLaneAI/catalyst/pull/1708)
Documentation 📝
diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py
index 49ece83c56..996f6a2c28 100644
--- a/frontend/catalyst/pipelines.py
+++ b/frontend/catalyst/pipelines.py
@@ -97,7 +97,7 @@ class CompileOptions:
def __post_init__(self):
# Check that async runs must not be seeded
- if self.async_qnodes and self.seed != None:
+ if self.async_qnodes and self.seed is not None:
raise CompileError(
"""
Seeding has no effect on asynchronous QNodes,
@@ -107,7 +107,7 @@ def __post_init__(self):
)
# Check that seed is 32-bit unsigned int
- if (self.seed != None) and (self.seed < 0 or self.seed > 2**32 - 1):
+ if (self.seed is not None) and (self.seed < 0 or self.seed > 2**32 - 1):
raise ValueError(
"""
Seed must be an unsigned 32-bit integer!
@@ -227,7 +227,8 @@ def get_bufferization_stage(_options: CompileOptions) -> List[str]:
"empty-tensor-to-alloc-tensor",
"func.func(bufferization-bufferize)",
"func.func(tensor-bufferize)",
- "catalyst-bufferize", # Must be run before -- func.func(linalg-bufferize)
+ # Catalyst dialect's bufferization must be run before --func.func(linalg-bufferize)
+ "one-shot-bufferize{dialect-filter=catalyst unknown-type-conversion=identity-layout-map}",
"func.func(linalg-bufferize)",
"func.func(tensor-bufferize)",
"one-shot-bufferize{dialect-filter=quantum}",
diff --git a/mlir/include/Catalyst/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/Catalyst/Transforms/BufferizableOpInterfaceImpl.h
new file mode 100644
index 0000000000..89eb431f8f
--- /dev/null
+++ b/mlir/include/Catalyst/Transforms/BufferizableOpInterfaceImpl.h
@@ -0,0 +1,23 @@
+// Copyright 2024-2025 Xanadu Quantum Technologies Inc.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "mlir/IR/DialectRegistry.h"
+
+using namespace mlir;
+
+namespace catalyst {
+
+void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry &registry);
+
+} // namespace catalyst
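
A hedged usage sketch for this hook (it mirrors the registrations added to `CompilerDriver.cpp` and `quantum-opt.cpp` below): the external models only become visible to one-shot bufferization once the registry carrying the extension is handed to a context.

```cpp
#include "Catalyst/IR/CatalystDialect.h"
#include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"

// Illustrative setup for a tool embedding Catalyst.
void setupRegistry(mlir::DialectRegistry &registry)
{
    registry.insert<catalyst::CatalystDialect>();
    // Attach the BufferizableOpInterface external models to the Catalyst ops.
    catalyst::registerBufferizableOpInterfaceExternalModels(registry);
}

// Usage:
//   mlir::DialectRegistry registry;
//   setupRegistry(registry);
//   mlir::MLIRContext context(registry);
```
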
diff --git a/mlir/include/Catalyst/Transforms/Passes.td b/mlir/include/Catalyst/Transforms/Passes.td
index 0f50286dac..d22246e3c8 100644
--- a/mlir/include/Catalyst/Transforms/Passes.td
+++ b/mlir/include/Catalyst/Transforms/Passes.td
@@ -27,18 +27,6 @@ def DetensorizeSCFPass : Pass<"detensorize-scf"> {
let constructor = "catalyst::createDetensorizeSCFPass()";
}
-def CatalystBufferizationPass : Pass<"catalyst-bufferize"> {
- let summary = "Bufferize tensors in catalyst utility ops.";
-
- let dependentDialects = [
- "bufferization::BufferizationDialect",
- "memref::MemRefDialect",
- "index::IndexDialect"
- ];
-
- let constructor = "catalyst::createCatalystBufferizationPass()";
-}
-
def ArrayListToMemRefPass : Pass<"convert-arraylist-to-memref"> {
let summary = "Lower array list operations to memref operations.";
let description = [{
diff --git a/mlir/include/Catalyst/Transforms/Patterns.h b/mlir/include/Catalyst/Transforms/Patterns.h
index cdc5157806..6bbf3150ff 100644
--- a/mlir/include/Catalyst/Transforms/Patterns.h
+++ b/mlir/include/Catalyst/Transforms/Patterns.h
@@ -21,8 +21,6 @@
namespace catalyst {
-void populateBufferizationPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &);
-
void populateScatterPatterns(mlir::RewritePatternSet &);
void populateHloCustomCallPatterns(mlir::RewritePatternSet &);
diff --git a/mlir/include/Quantum/Transforms/Patterns.h b/mlir/include/Quantum/Transforms/Patterns.h
index d278f809b7..8edaf0ffe5 100644
--- a/mlir/include/Quantum/Transforms/Patterns.h
+++ b/mlir/include/Quantum/Transforms/Patterns.h
@@ -21,8 +21,6 @@
namespace catalyst {
namespace quantum {
-void populateBufferizationLegality(mlir::TypeConverter &, mlir::ConversionTarget &);
-void populateBufferizationPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &);
void populateQIRConversionPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &);
void populateAdjointPatterns(mlir::RewritePatternSet &);
void populateSelfInversePatterns(mlir::RewritePatternSet &);
diff --git a/mlir/lib/Catalyst/IR/CatalystDialect.cpp b/mlir/lib/Catalyst/IR/CatalystDialect.cpp
index 1dce30c4cc..aae67c64ca 100644
--- a/mlir/lib/Catalyst/IR/CatalystDialect.cpp
+++ b/mlir/lib/Catalyst/IR/CatalystDialect.cpp
@@ -14,6 +14,7 @@
#include "Catalyst/IR/CatalystDialect.h"
#include "Catalyst/IR/CatalystOps.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h" // needed for generated type parser
#include "mlir/Interfaces/FunctionImplementation.h"
@@ -40,6 +41,9 @@ void CatalystDialect::initialize()
#define GET_OP_LIST
#include "Catalyst/IR/CatalystOps.cpp.inc"
>();
+
+ declarePromisedInterfaces<bufferization::BufferizableOpInterface, CustomCallOp, PrintOp,
+ CallbackOp, CallbackCallOp>();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp
new file mode 100644
index 0000000000..e5fb6b9bca
--- /dev/null
+++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -0,0 +1,317 @@
+// Copyright 2024-2025 Xanadu Quantum Technologies Inc.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "Catalyst/IR/CatalystOps.h"
+#include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+using namespace catalyst;
+
+/**
+ * Implementation of the BufferizableOpInterface for use with one-shot bufferization.
+ * For more information on the interface, refer to the documentation below:
+ * https://mlir.llvm.org/docs/Bufferization/#extending-one-shot-bufferize
+ * https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td#L14
+ */
+
+namespace {
+
+/// Bufferization of catalyst.print. Gets the memref of the printed value.
+struct PrintOpInterface
+ : public bufferization::BufferizableOpInterface::ExternalModel<PrintOpInterface, PrintOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+ auto printOp = cast<PrintOp>(op);
+ if (printOp.getVal()) {
+ FailureOr<Value> source = getBuffer(rewriter, printOp.getVal(), options);
+ if (failed(source)) {
+ return failure();
+ }
+ bufferization::replaceOpWithNewBufferizedOp<PrintOp>(
+ rewriter, op, *source, printOp.getConstValAttr(), printOp.getPrintDescriptorAttr());
+ }
+ return success();
+ }
+};
+
+/// Bufferization of catalyst.custom_call. Mainly gets buffers for the arguments
+/// and appends destination buffers for the results.
+struct CustomCallOpInterface
+ : public bufferization::BufferizableOpInterface::ExternalModel<CustomCallOpInterface,
+ CustomCallOp> {
+ bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
+
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ // CustomCallOp always reads its operand memory.
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ // We only use custom call for the JAX LAPACK kernels.
+ // This is hard-guarded: the lowering pattern for custom call checks that
+ // the name of the callee is a JAX symbol for a LAPACK kernel.
+ //
+ // The LAPACK kernels themselves might overwrite some of the input arrays.
+ // However, JAX's shim wrapper layer already performs a memcpy. See
+ // https://github.com/PennyLaneAI/catalyst/blob/main/frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels.cpp
+ //
+ // The JAX wrapper denotes the arguments to the underlying LAPACK kernel
+ // as `data`. The `data` args already contain the output array that the
+ // LAPACK kernel is supposed to write into; the other input arrays are all
+ // marked const. JAX then purifies the function by adding a new argument
+ // `out` to hold the output array.
+ //
+ // In other words, the JAX wrappers we call here through custom call are
+ // already pure, and we won't have side effects on the input tensors.
+
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+ auto customCallOp = cast<CustomCallOp>(op);
+
+ // Add bufferized arguments
+ SmallVector<Value> bufferArgs;
+ ValueRange operands = customCallOp.getOperands();
+ for (Value operand : operands) {
+ FailureOr<Value> opBuffer = getBuffer(rewriter, operand, options);
+ if (failed(opBuffer)) {
+ return failure();
+ }
+ bufferArgs.push_back(*opBuffer);
+ }
+
+ // Add bufferized return values to the arguments
+ ValueRange results = customCallOp.getResults();
+ for (Value result : results) {
+ Type resultType = result.getType();
+ RankedTensorType tensorType = dyn_cast<RankedTensorType>(resultType);
+ if (!tensorType) {
+ return failure();
+ }
+ FailureOr<Value> tensorAlloc = bufferization::allocateTensorForShapedValue(
+ rewriter, op->getLoc(), result, options, false);
+ if (failed(tensorAlloc)) {
+ return failure();
+ }
+ MemRefType memrefType =
+ MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+ auto newBuffer =
+ rewriter.create<bufferization::ToMemrefOp>(op->getLoc(), memrefType, *tensorAlloc);
+ bufferArgs.push_back(newBuffer);
+ }
+
+ // Add the initial number of arguments
+ int32_t numArguments = static_cast<int32_t>(customCallOp.getNumOperands());
+ DenseI32ArrayAttr numArgumentsDenseAttr = rewriter.getDenseI32ArrayAttr({numArguments});
+
+ // Create an updated custom call operation
+ rewriter.create<CustomCallOp>(op->getLoc(), TypeRange{}, bufferArgs,
+ customCallOp.getCallTargetName(), numArgumentsDenseAttr);
+ size_t startIndex = bufferArgs.size() - customCallOp.getNumResults();
+ SmallVector<Value> bufferResults(bufferArgs.begin() + startIndex, bufferArgs.end());
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, bufferResults);
+
+ return success();
+ }
+};
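
The `number_original_arg` attribute recorded above is what lets later consumers split the flat, bufferized operand list back into inputs and appended destination buffers (the updated FileCheck test below shows the attribute in the output IR). A hedged sketch of such a consumer, using only generic attribute APIs; the helper name is illustrative:

```cpp
#include <utility>

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"

// Illustrative helper, not part of this PR: recover the (inputs, outputs)
// split from a bufferized catalyst.custom_call.
std::pair<mlir::OperandRange, mlir::OperandRange> splitCustomCallOperands(mlir::Operation *op)
{
    // Set during bufferization above; the single array element holds the count.
    auto attr = op->getAttrOfType<mlir::DenseI32ArrayAttr>("number_original_arg");
    int32_t numOriginal = attr.asArrayRef()[0];
    // Operands [0, numOriginal) are the original inputs; the remainder are the
    // destination buffers appended for the results.
    return {op->getOperands().take_front(numOriginal),
            op->getOperands().drop_front(numOriginal)};
}
```
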
+
+struct CallbackOpInterface
+ : public bufferization::BufferizableOpInterface::ExternalModel<CallbackOpInterface, CallbackOp> {
+ bool hasTensorSemantics(Operation *op) const
+ {
+ auto isaTensor = llvm::IsaPred<TensorType>;
+
+ // A function has tensor semantics if it has tensor arguments/results.
+ auto callbackOp = cast<CallbackOp>(op);
+ bool hasTensorArg = any_of(callbackOp.getArgumentTypes(), isaTensor);
+ bool hasTensorResult = any_of(callbackOp.getResultTypes(), isaTensor);
+ return hasTensorArg || hasTensorResult;
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+ auto callbackOp = cast<CallbackOp>(op);
+
+ auto argTys = callbackOp.getArgumentTypes();
+ auto retTys = callbackOp.getResultTypes();
+ SmallVector<Type> emptyRets;
+ SmallVector<Type> args(argTys.begin(), argTys.end());
+ args.insert(args.end(), retTys.begin(), retTys.end());
+ SmallVector<Type> bufferArgs;
+ for (Type ty : args) {
+ auto tensorType = dyn_cast<TensorType>(ty);
+ if (!tensorType) {
+ bufferArgs.push_back(ty);
+ }
+ else {
+ bufferArgs.push_back(
+ MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
+ }
+ }
+ auto callbackTy = rewriter.getFunctionType(bufferArgs, emptyRets);
+ rewriter.modifyOpInPlace(op, [&] { callbackOp.setFunctionType(callbackTy); });
+
+ return success();
+ }
+};
+
+void convertTypes(SmallVector<Type> inTypes, SmallVector<Type> &convertedResults)
+{
+ // See https://github.com/llvm/llvm-project/pull/114155/files
+ for (Type inType : inTypes) {
+ if (isa<TensorType>(inType)) {
+ convertedResults.push_back(
+ bufferization::getMemRefTypeWithStaticIdentityLayout(cast<TensorType>(inType)));
+ }
+ else {
+ convertedResults.push_back(inType);
+ }
+ }
+}
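
A small usage sketch of `convertTypes` (hypothetical types; assumes the `using` directives at the top of this file): ranked tensors map to identity-layout memrefs, everything else passes through unchanged.

```cpp
// Usage sketch only; not part of this PR.
void convertTypesExample(MLIRContext *ctx)
{
    SmallVector<Type> inputs = {
        RankedTensorType::get({2, 2}, Float64Type::get(ctx)), // -> memref<2x2xf64>
        IntegerType::get(ctx, 64),                            // -> i64 (unchanged)
    };
    SmallVector<Type> converted;
    convertTypes(inputs, converted);
}
```
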
+
+struct CallbackCallOpInterface
+ : public bufferization::BufferizableOpInterface::ExternalModel<CallbackCallOpInterface,
+ CallbackCallOp> {
+ bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
+
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ // We can safely say false because CallbackCallOp's memrefs
+ // will be put in a JAX array and JAX arrays are immutable.
+ //
+ // Unlike NumPy arrays, JAX arrays are always immutable.
+ //
+ // https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html
+ return false;
+ }
+
+ bufferization::AliasingValueList
+ getAliasingValues(Operation *op, OpOperand &opOperand,
+ const bufferization::AnalysisState &state) const
+ {
+ return {};
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+ const bufferization::BufferizationOptions &options) const
+ {
+ auto callOp = cast<CallbackCallOp>(op);
+
+ SmallVector<Type> convertedResults;
+ convertTypes(SmallVector<Type>(callOp.getResultTypes()), convertedResults);
+ if (callOp->getNumResults() != convertedResults.size()) {
+ return failure();
+ }
+
+ SmallVector<Value> newInputs;
+ auto operands = callOp.getOperands();
+ for (Value operand : operands) {
+ FailureOr<Value> opBuffer = getBuffer(rewriter, operand, options);
+ if (failed(opBuffer)) {
+ return failure();
+ }
+ newInputs.push_back(*opBuffer);
+ }
+
+ auto results = callOp.getResults();
+ auto loc = callOp->getLoc();
+ SmallVector<Value> outmemrefs;
+ for (auto result : results) {
+ FailureOr<Value> tensorAlloc =
+ bufferization::allocateTensorForShapedValue(rewriter, loc, result, options, false);
+ if (failed(tensorAlloc)) {
+ return failure();
+ }
+
+ auto tensor = *tensorAlloc;
+ RankedTensorType tensorTy = cast<RankedTensorType>(tensor.getType());
+ auto shape = tensorTy.getShape();
+ auto elementTy = tensorTy.getElementType();
+ auto memrefType = MemRefType::get(shape, elementTy);
+ auto toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
+ auto memref = toMemrefOp.getResult();
+ outmemrefs.push_back(memref);
+ newInputs.push_back(memref);
+ }
+
+ SmallVector<Type> emptyRets;
+ rewriter.create<CallbackCallOp>(loc, emptyRets, callOp.getCallee(), newInputs);
+ bufferization::replaceOpWithBufferizedValues(rewriter, op, outmemrefs);
+ return success();
+ }
+};
+
+} // namespace
+
+void catalyst::registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
+{
+ registry.addExtension(+[](MLIRContext *ctx, CatalystDialect *dialect) {
+ CustomCallOp::attachInterface<CustomCallOpInterface>(*ctx);
+ PrintOp::attachInterface<PrintOpInterface>(*ctx);
+ CallbackOp::attachInterface<CallbackOpInterface>(*ctx);
+ CallbackCallOp::attachInterface<CallbackCallOpInterface>(*ctx);
+ });
+}
diff --git a/mlir/lib/Catalyst/Transforms/BufferizationPatterns.cpp b/mlir/lib/Catalyst/Transforms/BufferizationPatterns.cpp
deleted file mode 100644
index 1fd2a436ca..0000000000
--- a/mlir/lib/Catalyst/Transforms/BufferizationPatterns.cpp
+++ /dev/null
@@ -1,173 +0,0 @@
-// Copyright 2023 Xanadu Quantum Technologies Inc.
-
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-
-// http://www.apache.org/licenses/LICENSE-2.0
-
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-#include "Catalyst/IR/CatalystDialect.h"
-#include "Catalyst/IR/CatalystOps.h"
-
-using namespace mlir;
-using namespace catalyst;
-
-namespace {
-
-struct BufferizePrintOp : public OpConversionPattern<PrintOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(PrintOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override
- {
- if (op.getVal()) {
- rewriter.replaceOpWithNewOp<PrintOp>(op, adaptor.getVal(), adaptor.getConstValAttr(),
- adaptor.getPrintDescriptorAttr());
- }
- return success();
- }
-};
-
-struct BufferizeCustomCallOp : public OpConversionPattern<CustomCallOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(CustomCallOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override
- {
- // Add bufferized arguments
- SmallVector<Value> bufferArgs;
- ValueRange operands = adaptor.getOperands();
- for (Value operand : operands) {
- bufferArgs.push_back(operand);
- }
-
- // Add bufferized return values to the arguments
- ValueRange results = op.getResults();
- for (Value result : results) {
- Type resultType = result.getType();
- RankedTensorType tensorType = dyn_cast<RankedTensorType>(resultType);
- if (!tensorType) {
- return failure();
- }
- auto options = bufferization::BufferizationOptions();
- FailureOr<Value> tensorAlloc = bufferization::allocateTensorForShapedValue(
- rewriter, op->getLoc(), result, options, false);
- MemRefType memrefType =
- MemRefType::get(tensorType.getShape(), tensorType.getElementType());
- auto newBuffer =
- rewriter.create<bufferization::ToMemrefOp>(op->getLoc(), memrefType, *tensorAlloc);
- bufferArgs.push_back(newBuffer);
- }
- // Add the initial number of arguments
- int32_t numArguments = static_cast<int32_t>(op.getNumOperands());
- DenseI32ArrayAttr numArgumentsDenseAttr = rewriter.getDenseI32ArrayAttr({numArguments});
-
- // Create an updated custom call operation
- rewriter.create<CustomCallOp>(op->getLoc(), TypeRange{}, bufferArgs, op.getCallTargetName(),
- numArgumentsDenseAttr);
- size_t startIndex = bufferArgs.size() - op.getNumResults();
- SmallVector<Value> bufferResults(bufferArgs.begin() + startIndex, bufferArgs.end());
- rewriter.replaceOp(op, bufferResults);
- return success();
- }
-};
-
-struct BufferizeCallbackOp : public OpConversionPattern<CallbackOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult match(CallbackOp op) const override
- {
- // Only match here if we have all memref arguments and return values.
- if (llvm::any_of(op.getArgumentTypes(),
- [](Type argType) { return !isa<MemRefType>(argType); })) {
- return failure();
- }
- if (llvm::any_of(op.getResultTypes(),
- [](Type argType) { return !isa<MemRefType>(argType); })) {
- return failure();
- }
-
- // Only match if we have result types.
- return op.getResultTypes().empty() ? failure() : success();
- }
-
- void rewrite(CallbackOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override
- {
- auto argTys = op.getArgumentTypes();
- auto retTys = op.getResultTypes();
- SmallVector<Type> emptyRets;
- SmallVector<Type> args(argTys.begin(), argTys.end());
- args.insert(args.end(), retTys.begin(), retTys.end());
- auto callbackTy = rewriter.getFunctionType(args, emptyRets);
- rewriter.modifyOpInPlace(op, [&] { op.setFunctionType(callbackTy); });
- }
-};
-
-struct BufferizeCallbackCallOp : public OpConversionPattern<CallbackCallOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult matchAndRewrite(CallbackCallOp callOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override
- {
- SmallVector<Type> convertedResults;
- if (failed(typeConverter->convertTypes(callOp.getResultTypes(), convertedResults)))
- return failure();
-
- if (callOp->getNumResults() != convertedResults.size())
- return failure();
-
- auto operands = adaptor.getOperands();
- SmallVector<Value> newInputs(operands.begin(), operands.end());
- auto results = callOp.getResults();
-
- auto loc = callOp->getLoc();
- auto options = bufferization::BufferizationOptions();
- SmallVector<Value> outmemrefs;
- for (auto result : results) {
- FailureOr<Value> tensorAlloc =
- bufferization::allocateTensorForShapedValue(rewriter, loc, result, options, false);
- if (failed(tensorAlloc))
- return failure();
-
- auto tensor = *tensorAlloc;
- RankedTensorType tensorTy = cast<RankedTensorType>(tensor.getType());
- auto shape = tensorTy.getShape();
- auto elementTy = tensorTy.getElementType();
- auto memrefType = MemRefType::get(shape, elementTy);
- auto toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
- auto memref = toMemrefOp.getResult();
- outmemrefs.push_back(memref);
- newInputs.push_back(memref);
- }
-
- SmallVector<Type> emptyRets;
- rewriter.create<CallbackCallOp>(loc, emptyRets, callOp.getCallee(), newInputs);
- rewriter.replaceOp(callOp, outmemrefs);
- return success();
- }
-};
-
-} // namespace
-
-namespace catalyst {
-
-void populateBufferizationPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns)
-{
- patterns.add<BufferizePrintOp>(typeConverter, patterns.getContext());
- patterns.add<BufferizeCustomCallOp>(typeConverter, patterns.getContext());
- patterns.add<BufferizeCallbackOp>(typeConverter, patterns.getContext());
- patterns.add<BufferizeCallbackCallOp>(typeConverter, patterns.getContext());
-}
-
-} // namespace catalyst
diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt
index 59e5619eca..48d4aa362b 100644
--- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt
+++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt
@@ -4,8 +4,7 @@ file(GLOB SRC
ApplyTransformSequencePass.cpp
ArrayListToMemRefPass.cpp
AsyncUtils.cpp
- BufferizationPatterns.cpp
- catalyst_bufferize.cpp
+ BufferizableOpInterfaceImpl.cpp
catalyst_to_llvm.cpp
DetectQNodes.cpp
DetensorizeSCFPass.cpp
diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp
index cc1379a6d1..3ca100ceee 100644
--- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp
+++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp
@@ -29,7 +29,6 @@ void catalyst::registerAllCatalystPasses()
mlir::registerPass(catalyst::createAnnotateFunctionPass);
mlir::registerPass(catalyst::createApplyTransformSequencePass);
mlir::registerPass(catalyst::createArrayListToMemRefPass);
- mlir::registerPass(catalyst::createCatalystBufferizationPass);
mlir::registerPass(catalyst::createCatalystConversionPass);
mlir::registerPass(catalyst::createCopyGlobalMemRefPass);
mlir::registerPass(catalyst::createDetensorizeSCFPass);
diff --git a/mlir/lib/Catalyst/Transforms/catalyst_bufferize.cpp b/mlir/lib/Catalyst/Transforms/catalyst_bufferize.cpp
deleted file mode 100644
index 8363f4d39b..0000000000
--- a/mlir/lib/Catalyst/Transforms/catalyst_bufferize.cpp
+++ /dev/null
@@ -1,72 +0,0 @@
-// Copyright 2023 Xanadu Quantum Technologies Inc.
-
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-
-// http://www.apache.org/licenses/LICENSE-2.0
-
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
-#include "mlir/Dialect/Index/IR/IndexDialect.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-#include "Catalyst/IR/CatalystOps.h"
-#include "Catalyst/Transforms/Passes.h"
-#include "Catalyst/Transforms/Patterns.h"
-
-using namespace mlir;
-using namespace catalyst;
-
-namespace catalyst {
-
-#define GEN_PASS_DEF_CATALYSTBUFFERIZATIONPASS
-#include "Catalyst/Transforms/Passes.h.inc"
-
-struct CatalystBufferizationPass : impl::CatalystBufferizationPassBase<CatalystBufferizationPass> {
- using CatalystBufferizationPassBase::CatalystBufferizationPassBase;
-
- void runOnOperation() final
- {
- MLIRContext *context = &getContext();
- bufferization::BufferizeTypeConverter typeConverter;
-
- RewritePatternSet patterns(context);
- populateBufferizationPatterns(typeConverter, patterns);
- populateFunctionOpInterfaceTypeConversionPattern<CallbackOp>(patterns, typeConverter);
-
- ConversionTarget target(*context);
- bufferization::populateBufferizeMaterializationLegality(target);
- target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
- target.addDynamicallyLegalOp<PrintOp>(
- [&](PrintOp op) { return typeConverter.isLegal(op); });
- target.addDynamicallyLegalOp<CustomCallOp>(
- [&](CustomCallOp op) { return typeConverter.isLegal(op); });
- target.addDynamicallyLegalOp<CallbackOp>([&](CallbackOp op) {
- return typeConverter.isSignatureLegal(op.getFunctionType()) &&
- typeConverter.isLegal(&op.getBody()) && op.getResultTypes().empty();
- });
- target.addDynamicallyLegalOp<CallbackCallOp>(
- [&](CallbackCallOp op) { return typeConverter.isLegal(op); });
-
- if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) {
- signalPassFailure();
- }
- }
-};
-
-std::unique_ptr<Pass> createCatalystBufferizationPass()
-{
- return std::make_unique<CatalystBufferizationPass>();
-}
-
-} // namespace catalyst
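
The deleted pass above drove bufferization through the dialect-conversion framework: a `BufferizeTypeConverter`, hand-written patterns, and `applyPartialConversion`. With one-shot bufferization there is no pattern-population step; the pass discovers the external models through `BufferizableOpInterface`. A hedged sketch of the replacement driver (this standalone helper is illustrative; the real wiring is in `Pipelines.cpp` below):

```cpp
#include "Catalyst/IR/CatalystDialect.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/IR/BuiltinOps.h"

// Illustrative one-shot replacement for the deleted conversion driver.
mlir::LogicalResult bufferizeCatalystOps(mlir::ModuleOp module)
{
    mlir::bufferization::OneShotBufferizationOptions options;
    options.opFilter.allowDialect<catalyst::CatalystDialect>();
    // The registered external models supply bufferizesToMemoryRead/Write,
    // getAliasingValues, and bufferize for each Catalyst op.
    return mlir::bufferization::runOneShotBufferize(module, options);
}
```
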
diff --git a/mlir/lib/Driver/CompilerDriver.cpp b/mlir/lib/Driver/CompilerDriver.cpp
index 2875ad79a7..1b3256ead7 100644
--- a/mlir/lib/Driver/CompilerDriver.cpp
+++ b/mlir/lib/Driver/CompilerDriver.cpp
@@ -56,6 +56,7 @@
#include "llvm/Transforms/IPO/GlobalDCE.h"
#include "Catalyst/IR/CatalystDialect.h"
+#include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h"
#include "Catalyst/Transforms/Passes.h"
#include "Driver/CatalystLLVMTarget.h"
#include "Driver/CompilerDriver.h"
@@ -964,6 +965,7 @@ int QuantumDriverMainFromCL(int argc, char **argv)
registerLLVMTranslations(registry);
// Register bufferization interfaces
+ catalyst::registerBufferizableOpInterfaceExternalModels(registry);
catalyst::quantum::registerBufferizableOpInterfaceExternalModels(registry);
// Register and parse command line options.
diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp
index 46a3079c69..e3d7f3ad55 100644
--- a/mlir/lib/Driver/Pipelines.cpp
+++ b/mlir/lib/Driver/Pipelines.cpp
@@ -13,12 +13,14 @@
// limitations under the License.
#include "Driver/Pipelines.h"
+#include "Catalyst/IR/CatalystDialect.h"
#include "Catalyst/Transforms/Passes.h"
#include "Gradient/Transforms/Passes.h"
#include "Mitigation/Transforms/Passes.h"
#include "Quantum/IR/QuantumDialect.h"
#include "Quantum/Transforms/Passes.h"
#include "mhlo/transforms/passes.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Pass/PassManager.h"
@@ -77,7 +79,15 @@ void createBufferizationPipeline(OpPassManager &pm)
pm.addPass(mlir::bufferization::createEmptyTensorToAllocTensorPass());
pm.addNestedPass<mlir::func::FuncOp>(mlir::bufferization::createBufferizationBufferizePass());
pm.addNestedPass<mlir::func::FuncOp>(mlir::tensor::createTensorBufferizePass());
- pm.addPass(catalyst::createCatalystBufferizationPass());
+ mlir::bufferization::OneShotBufferizationOptions catalyst_buffer_options;
+ catalyst_buffer_options.opFilter.allowDialect<catalyst::CatalystDialect>();
+ catalyst_buffer_options.unknownTypeConverterFn =
+ [=](Value value, Attribute memorySpace,
+ const mlir::bufferization::BufferizationOptions &options) {
+ auto tensorType = cast<TensorType>(value.getType());
+ return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
+ };
+ pm.addPass(mlir::bufferization::createOneShotBufferizePass(catalyst_buffer_options));
pm.addNestedPass<mlir::func::FuncOp>(mlir::createLinalgBufferizePass());
pm.addNestedPass<mlir::func::FuncOp>(mlir::tensor::createTensorBufferizePass());
mlir::bufferization::OneShotBufferizationOptions quantum_buffer_options;
diff --git a/mlir/test/Catalyst/BufferizationTest.mlir b/mlir/test/Catalyst/BufferizationTest.mlir
index bcff9bff98..50c691fef9 100644
--- a/mlir/test/Catalyst/BufferizationTest.mlir
+++ b/mlir/test/Catalyst/BufferizationTest.mlir
@@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// RUN: quantum-opt --catalyst-bufferize --split-input-file %s | FileCheck %s
+// RUN: quantum-opt --split-input-file \
+// RUN: --pass-pipeline="builtin.module( \
+// RUN: one-shot-bufferize{unknown-type-conversion=identity-layout-map} \
+// RUN: )" %s | FileCheck %s
//////////////////////
// Catalyst PrintOp //
@@ -20,6 +23,7 @@
func.func @dbprint_val(%arg0: tensor<?xi64>) {
+ // CHECK: %0 = bufferization.to_memref %arg0
// CHECK: "catalyst.print"(%0) : (memref<?xi64>) -> ()
"catalyst.print"(%arg0) : (tensor<?xi64>) -> ()
@@ -30,6 +34,7 @@ func.func @dbprint_val(%arg0: tensor<?xi64>) {
func.func @dbprint_memref(%arg0: tensor<?xi64>) {
+ // CHECK: %0 = bufferization.to_memref %arg0
// CHECK: "catalyst.print"(%0) <{print_descriptor}> : (memref<?xi64>) -> ()
"catalyst.print"(%arg0) {print_descriptor} : (tensor<?xi64>) -> ()
@@ -49,11 +54,11 @@ func.func @dbprint_str() {
// -----
func.func @custom_call(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> {
- // CHECK: [[memrefArg:%.+]] = bufferization.to_memref %arg0 : memref<3x3xf64>
- // CHECK: [[alloc:%.+]] = bufferization.alloc_tensor() {{.*}}: tensor<3x3xf64>
- // CHECK: [[allocmemref:%.+]] = bufferization.to_memref [[alloc]] : memref<3x3xf64>
- // CHECK: catalyst.custom_call fn("lapack_dgesdd") ([[memrefArg]], [[allocmemref]]) {number_original_arg = array} : (memref<3x3xf64>, memref<3x3xf64>) -> ()
- // CHECK: [[res:%.+]] = bufferization.to_tensor [[allocmemref]] : memref<3x3xf64>
+ // CHECK: [[sourceAlloc:%.+]] = bufferization.to_memref %arg0
+ // CHECK: [[destAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64>
+ // CHECK: catalyst.custom_call fn("lapack_dgesdd") ([[sourceAlloc]], [[destAlloc]]) {number_original_arg = array} :
+ // CHECK-SAME: (memref<3x3xf64>, memref<3x3xf64>) -> ()
+ // CHECK: [[res:%.+]] = bufferization.to_tensor [[destAlloc]] : memref<3x3xf64>
// CHECK: return [[res]] : tensor<3x3xf64>
%0 = catalyst.custom_call fn("lapack_dgesdd") (%arg0) : (tensor<3x3xf64>) -> (tensor<3x3xf64>)
@@ -77,12 +82,11 @@ module @test1 {
// CHECK-LABEL: @foo(
// CHECK-SAME: [[arg0:%.+]]: tensor<f64>)
func.func private @foo(%arg0: tensor<f64>) -> tensor<f64> {
- // CHECK-DAG: [[memref0:%.+]] = bufferization.to_memref [[arg0]]
- // CHECK-DAG: [[tensor1:%.+]] = bufferization.alloc_tensor
- // CHECK: [[memref1:%.+]] = bufferization.to_memref [[tensor1]]
- // CHECK: catalyst.callback_call @callback_1([[memref0]], [[memref1]])
+ // CHECK-DAG: [[memref0:%.+]] = bufferization.to_memref [[arg0]] : memref<f64>
+ // CHECK-DAG: [[resAlloc:%.+]] = memref.alloc() {{.*}}: memref<f64>
+ // CHECK: catalyst.callback_call @callback_1([[memref0]], [[resAlloc]]) : (memref<f64>, memref<f64>) -> ()
%1 = catalyst.callback_call @callback_1(%arg0) : (tensor<f64>) -> (tensor<f64>)
- // CHECK: [[retval:%.+]] = bufferization.to_tensor [[memref1]]
+ // CHECK: [[retval:%.+]] = bufferization.to_tensor [[resAlloc]]
// CHECK: return [[retval]]
return %1 : tensor<f64>
}
diff --git a/mlir/tools/quantum-opt/quantum-opt.cpp b/mlir/tools/quantum-opt/quantum-opt.cpp
index eda34657df..26c08f78bd 100644
--- a/mlir/tools/quantum-opt/quantum-opt.cpp
+++ b/mlir/tools/quantum-opt/quantum-opt.cpp
@@ -25,6 +25,7 @@
#include "mhlo/IR/hlo_ops.h"
#include "Catalyst/IR/CatalystDialect.h"
+#include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h"
#include "Catalyst/Transforms/Passes.h"
#include "Gradient/IR/GradientDialect.h"
#include "Gradient/Transforms/Passes.h"
@@ -62,6 +63,7 @@ int main(int argc, char **argv)
registry.insert();
registry.insert();
+ catalyst::registerBufferizableOpInterfaceExternalModels(registry);
catalyst::quantum::registerBufferizableOpInterfaceExternalModels(registry);
return mlir::asMainReturnCode(