diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index e0b462de2c..198b1def28 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -166,6 +166,7 @@ The new mlir bufferization interface is required by jax 0.4.29 or higher. [(#1027)](https://github.com/PennyLaneAI/catalyst/pull/1027) [(#1686)](https://github.com/PennyLaneAI/catalyst/pull/1686) + [(#1708)](https://github.com/PennyLaneAI/catalyst/pull/1708)

Documentation 📝

diff --git a/frontend/catalyst/pipelines.py b/frontend/catalyst/pipelines.py index 49ece83c56..996f6a2c28 100644 --- a/frontend/catalyst/pipelines.py +++ b/frontend/catalyst/pipelines.py @@ -97,7 +97,7 @@ class CompileOptions: def __post_init__(self): # Check that async runs must not be seeded - if self.async_qnodes and self.seed != None: + if self.async_qnodes and self.seed is not None: raise CompileError( """ Seeding has no effect on asynchronous QNodes, @@ -107,7 +107,7 @@ def __post_init__(self): ) # Check that seed is 32-bit unsigned int - if (self.seed != None) and (self.seed < 0 or self.seed > 2**32 - 1): + if (self.seed is not None) and (self.seed < 0 or self.seed > 2**32 - 1): raise ValueError( """ Seed must be an unsigned 32-bit integer! @@ -227,7 +227,8 @@ def get_bufferization_stage(_options: CompileOptions) -> List[str]: "empty-tensor-to-alloc-tensor", "func.func(bufferization-bufferize)", "func.func(tensor-bufferize)", - "catalyst-bufferize", # Must be run before -- func.func(linalg-bufferize) + # Catalyst dialect's bufferization must be run before --func.func(linalg-bufferize) + "one-shot-bufferize{dialect-filter=catalyst unknown-type-conversion=identity-layout-map}", "func.func(linalg-bufferize)", "func.func(tensor-bufferize)", "one-shot-bufferize{dialect-filter=quantum}", diff --git a/mlir/include/Catalyst/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/Catalyst/Transforms/BufferizableOpInterfaceImpl.h new file mode 100644 index 0000000000..89eb431f8f --- /dev/null +++ b/mlir/include/Catalyst/Transforms/BufferizableOpInterfaceImpl.h @@ -0,0 +1,23 @@ +// Copyright 2024-2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +using namespace mlir; + +namespace catalyst { + +void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry ®istry); + +} // namespace catalyst diff --git a/mlir/include/Catalyst/Transforms/Passes.td b/mlir/include/Catalyst/Transforms/Passes.td index 0f50286dac..d22246e3c8 100644 --- a/mlir/include/Catalyst/Transforms/Passes.td +++ b/mlir/include/Catalyst/Transforms/Passes.td @@ -27,18 +27,6 @@ def DetensorizeSCFPass : Pass<"detensorize-scf"> { let constructor = "catalyst::createDetensorizeSCFPass()"; } -def CatalystBufferizationPass : Pass<"catalyst-bufferize"> { - let summary = "Bufferize tensors in catalyst utility ops."; - - let dependentDialects = [ - "bufferization::BufferizationDialect", - "memref::MemRefDialect", - "index::IndexDialect" - ]; - - let constructor = "catalyst::createCatalystBufferizationPass()"; -} - def ArrayListToMemRefPass : Pass<"convert-arraylist-to-memref"> { let summary = "Lower array list operations to memref operations."; let description = [{ diff --git a/mlir/include/Catalyst/Transforms/Patterns.h b/mlir/include/Catalyst/Transforms/Patterns.h index cdc5157806..6bbf3150ff 100644 --- a/mlir/include/Catalyst/Transforms/Patterns.h +++ b/mlir/include/Catalyst/Transforms/Patterns.h @@ -21,8 +21,6 @@ namespace catalyst { -void populateBufferizationPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &); - void populateScatterPatterns(mlir::RewritePatternSet &); void populateHloCustomCallPatterns(mlir::RewritePatternSet &); diff --git a/mlir/include/Quantum/Transforms/Patterns.h b/mlir/include/Quantum/Transforms/Patterns.h index d278f809b7..8edaf0ffe5 100644 --- a/mlir/include/Quantum/Transforms/Patterns.h +++ b/mlir/include/Quantum/Transforms/Patterns.h @@ -21,8 +21,6 @@ namespace catalyst { namespace quantum { -void populateBufferizationLegality(mlir::TypeConverter &, mlir::ConversionTarget &); -void populateBufferizationPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &); void populateQIRConversionPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &); void populateAdjointPatterns(mlir::RewritePatternSet &); void populateSelfInversePatterns(mlir::RewritePatternSet &); diff --git a/mlir/lib/Catalyst/IR/CatalystDialect.cpp b/mlir/lib/Catalyst/IR/CatalystDialect.cpp index 1dce30c4cc..aae67c64ca 100644 --- a/mlir/lib/Catalyst/IR/CatalystDialect.cpp +++ b/mlir/lib/Catalyst/IR/CatalystDialect.cpp @@ -14,6 +14,7 @@ #include "Catalyst/IR/CatalystDialect.h" #include "Catalyst/IR/CatalystOps.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" // needed for generated type parser #include "mlir/Interfaces/FunctionImplementation.h" @@ -40,6 +41,9 @@ void CatalystDialect::initialize() #define GET_OP_LIST #include "Catalyst/IR/CatalystOps.cpp.inc" >(); + + declarePromisedInterfaces(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp new file mode 100644 index 0000000000..e5fb6b9bca --- /dev/null +++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -0,0 +1,317 @@ +// Copyright 2024-2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "Catalyst/IR/CatalystOps.h" +#include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h" + +using namespace mlir; +using namespace mlir::bufferization; +using namespace catalyst; + +/** + * Implementation of the BufferizableOpInterface for use with one-shot bufferization. + * For more information on the interface, refer to the documentation below: + * https://mlir.llvm.org/docs/Bufferization/#extending-one-shot-bufferize + * https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td#L14 + */ + +namespace { + +/// Bufferization of catalyst.print. Get memref of printOp.val. +struct PrintOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto printOp = cast(op); + if (printOp.getVal()) { + FailureOr source = getBuffer(rewriter, printOp.getVal(), options); + if (failed(source)) { + return failure(); + } + bufferization::replaceOpWithNewBufferizedOp( + rewriter, op, *source, printOp.getConstValAttr(), printOp.getPrintDescriptorAttr()); + } + return success(); + } +}; + +/// Bufferization of catalyst.custom_call. Mainly get buffers for arguments. +struct CustomCallOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToAllocation(Operation *op, Value value) const { return true; } + + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + // Custom Call Op always reads the operand memory no matter what. + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + // We only use custom call for the jax lapack kernels. + // This is actually hard-guarded: in the lowering pattern for custom call + // we check that the name of the callee is a jax symbol for a lapack kernel. + // + // The lapack kernels themselves might overwrite some of the input arrays. + // However, in jax's shim wrapper layer, a memcpy is already performed. + // See + // https://github.com/PennyLaneAI/catalyst/blob/main/frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels.cpp + // + // The arguments to the underlying lapack kernel are denoted by the jax wrapper + // function as `data`. The `data` args already contain the output array that + // the lapack kernel is supposed to write into. The other input arrays are all marked const. + // Jax then purifies the function by adding a new argument `out` to hold the + // output array. + // + // In other words, the jax wrappers we call here with custom call op + // are already pure, and we won't have side effects on the input tensors. + + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto customCallOp = cast(op); + + // Add bufferized arguments + SmallVector bufferArgs; + ValueRange operands = customCallOp.getOperands(); + for (Value operand : operands) { + FailureOr opBuffer = getBuffer(rewriter, operand, options); + if (failed(opBuffer)) { + return failure(); + } + bufferArgs.push_back(*opBuffer); + } + + // Add bufferized return values to the arguments + ValueRange results = customCallOp.getResults(); + for (Value result : results) { + Type resultType = result.getType(); + RankedTensorType tensorType = dyn_cast(resultType); + if (!tensorType) { + return failure(); + } + auto options = bufferization::BufferizationOptions(); + FailureOr tensorAlloc = bufferization::allocateTensorForShapedValue( + rewriter, op->getLoc(), result, options, false); + MemRefType memrefType = + MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + auto newBuffer = + rewriter.create(op->getLoc(), memrefType, *tensorAlloc); + bufferArgs.push_back(newBuffer); + } + + // Add the initial number of arguments + int32_t numArguments = static_cast(customCallOp.getNumOperands()); + DenseI32ArrayAttr numArgumentsDenseAttr = rewriter.getDenseI32ArrayAttr({numArguments}); + + // Create an updated custom call operation + rewriter.create(op->getLoc(), TypeRange{}, bufferArgs, + customCallOp.getCallTargetName(), numArgumentsDenseAttr); + size_t startIndex = bufferArgs.size() - customCallOp.getNumResults(); + SmallVector bufferResults(bufferArgs.begin() + startIndex, bufferArgs.end()); + bufferization::replaceOpWithBufferizedValues(rewriter, op, bufferResults); + + return success(); + } +}; + +struct CallbackOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool hasTensorSemantics(Operation *op) const + { + auto isaTensor = llvm::IsaPred; + + // A function has tensor semantics if it has tensor arguments/results. + auto callbackOp = cast(op); + bool hasTensorArg = any_of(callbackOp.getArgumentTypes(), isaTensor); + bool hasTensorResult = any_of(callbackOp.getResultTypes(), isaTensor); + if (hasTensorArg || hasTensorResult) { + return true; + } + + return false; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto callbackOp = cast(op); + + auto argTys = callbackOp.getArgumentTypes(); + auto retTys = callbackOp.getResultTypes(); + SmallVector emptyRets; + SmallVector args(argTys.begin(), argTys.end()); + args.insert(args.end(), retTys.begin(), retTys.end()); + SmallVector bufferArgs; + for (Type ty : args) { + auto tensorType = dyn_cast(ty); + if (!tensorType) { + bufferArgs.push_back(ty); + } + else { + bufferArgs.push_back( + MemRefType::get(tensorType.getShape(), tensorType.getElementType())); + } + } + auto callbackTy = rewriter.getFunctionType(bufferArgs, emptyRets); + rewriter.modifyOpInPlace(op, [&] { callbackOp.setFunctionType(callbackTy); }); + + return success(); + } +}; + +void convertTypes(SmallVector inTypes, SmallVector &convertedResults) +{ + // See https://github.com/llvm/llvm-project/pull/114155/files + for (Type inType : inTypes) { + if (isa(inType)) { + convertedResults.push_back( + bufferization::getMemRefTypeWithStaticIdentityLayout(cast(inType))); + } + else { + convertedResults.push_back(inType); + } + } +} + +struct CallbackCallOpInterface + : public bufferization::BufferizableOpInterface::ExternalModel { + bool bufferizesToAllocation(Operation *op, Value value) const { return true; } + + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + // We can safely say false because CallbackCallOp's memrefs + // will be put in a JAX array and JAX arrays are immutable. + // + // Unlike NumPy arrays, JAX arrays are always immutable. + // + // https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html + return false; + } + + bufferization::AliasingValueList + getAliasingValues(Operation *op, OpOperand &opOperand, + const bufferization::AnalysisState &state) const + { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) const + { + auto callOp = cast(op); + + SmallVector convertedResults; + convertTypes(SmallVector(callOp.getResultTypes()), convertedResults); + if (callOp->getNumResults() != convertedResults.size()) { + return failure(); + } + + SmallVector newInputs; + auto operands = callOp.getOperands(); + for (Value operand : operands) { + FailureOr opBuffer = getBuffer(rewriter, operand, options); + if (failed(opBuffer)) { + return failure(); + } + newInputs.push_back(*opBuffer); + } + + auto results = callOp.getResults(); + auto loc = callOp->getLoc(); + SmallVector outmemrefs; + for (auto result : results) { + FailureOr tensorAlloc = + bufferization::allocateTensorForShapedValue(rewriter, loc, result, options, false); + if (failed(tensorAlloc)) { + return failure(); + } + + auto tensor = *tensorAlloc; + RankedTensorType tensorTy = cast(tensor.getType()); + auto shape = tensorTy.getShape(); + auto elementTy = tensorTy.getElementType(); + auto memrefType = MemRefType::get(shape, elementTy); + auto toMemrefOp = rewriter.create(loc, memrefType, tensor); + auto memref = toMemrefOp.getResult(); + outmemrefs.push_back(memref); + newInputs.push_back(memref); + } + + SmallVector emptyRets; + rewriter.create(loc, emptyRets, callOp.getCallee(), newInputs); + bufferization::replaceOpWithBufferizedValues(rewriter, op, outmemrefs); + return success(); + } +}; + +} // namespace + +void catalyst::registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) +{ + registry.addExtension(+[](MLIRContext *ctx, CatalystDialect *dialect) { + CustomCallOp::attachInterface(*ctx); + PrintOp::attachInterface(*ctx); + CallbackOp::attachInterface(*ctx); + CallbackCallOp::attachInterface(*ctx); + }); +} diff --git a/mlir/lib/Catalyst/Transforms/BufferizationPatterns.cpp b/mlir/lib/Catalyst/Transforms/BufferizationPatterns.cpp deleted file mode 100644 index 1fd2a436ca..0000000000 --- a/mlir/lib/Catalyst/Transforms/BufferizationPatterns.cpp +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright 2023 Xanadu Quantum Technologies Inc. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "Catalyst/IR/CatalystDialect.h" -#include "Catalyst/IR/CatalystOps.h" - -using namespace mlir; -using namespace catalyst; - -namespace { - -struct BufferizePrintOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(PrintOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - if (op.getVal()) { - rewriter.replaceOpWithNewOp(op, adaptor.getVal(), adaptor.getConstValAttr(), - adaptor.getPrintDescriptorAttr()); - } - return success(); - } -}; - -struct BufferizeCustomCallOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(CustomCallOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - // Add bufferized arguments - SmallVector bufferArgs; - ValueRange operands = adaptor.getOperands(); - for (Value operand : operands) { - bufferArgs.push_back(operand); - } - - // Add bufferized return values to the arguments - ValueRange results = op.getResults(); - for (Value result : results) { - Type resultType = result.getType(); - RankedTensorType tensorType = dyn_cast(resultType); - if (!tensorType) { - return failure(); - } - auto options = bufferization::BufferizationOptions(); - FailureOr tensorAlloc = bufferization::allocateTensorForShapedValue( - rewriter, op->getLoc(), result, options, false); - MemRefType memrefType = - MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - auto newBuffer = - rewriter.create(op->getLoc(), memrefType, *tensorAlloc); - bufferArgs.push_back(newBuffer); - } - // Add the initial number of arguments - int32_t numArguments = static_cast(op.getNumOperands()); - DenseI32ArrayAttr numArgumentsDenseAttr = rewriter.getDenseI32ArrayAttr({numArguments}); - - // Create an updated custom call operation - rewriter.create(op->getLoc(), TypeRange{}, bufferArgs, op.getCallTargetName(), - numArgumentsDenseAttr); - size_t startIndex = bufferArgs.size() - op.getNumResults(); - SmallVector bufferResults(bufferArgs.begin() + startIndex, bufferArgs.end()); - rewriter.replaceOp(op, bufferResults); - return success(); - } -}; - -struct BufferizeCallbackOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult match(CallbackOp op) const override - { - // Only match here if we have all memref arguments and return values. - if (llvm::any_of(op.getArgumentTypes(), - [](Type argType) { return !isa(argType); })) { - return failure(); - } - if (llvm::any_of(op.getResultTypes(), - [](Type argType) { return !isa(argType); })) { - return failure(); - } - - // Only match if we have result types. - return op.getResultTypes().empty() ? failure() : success(); - } - - void rewrite(CallbackOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - auto argTys = op.getArgumentTypes(); - auto retTys = op.getResultTypes(); - SmallVector emptyRets; - SmallVector args(argTys.begin(), argTys.end()); - args.insert(args.end(), retTys.begin(), retTys.end()); - auto callbackTy = rewriter.getFunctionType(args, emptyRets); - rewriter.modifyOpInPlace(op, [&] { op.setFunctionType(callbackTy); }); - } -}; - -struct BufferizeCallbackCallOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite(CallbackCallOp callOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - SmallVector convertedResults; - if (failed(typeConverter->convertTypes(callOp.getResultTypes(), convertedResults))) - return failure(); - - if (callOp->getNumResults() != convertedResults.size()) - return failure(); - - auto operands = adaptor.getOperands(); - SmallVector newInputs(operands.begin(), operands.end()); - auto results = callOp.getResults(); - - auto loc = callOp->getLoc(); - auto options = bufferization::BufferizationOptions(); - SmallVector outmemrefs; - for (auto result : results) { - FailureOr tensorAlloc = - bufferization::allocateTensorForShapedValue(rewriter, loc, result, options, false); - if (failed(tensorAlloc)) - return failure(); - - auto tensor = *tensorAlloc; - RankedTensorType tensorTy = cast(tensor.getType()); - auto shape = tensorTy.getShape(); - auto elementTy = tensorTy.getElementType(); - auto memrefType = MemRefType::get(shape, elementTy); - auto toMemrefOp = rewriter.create(loc, memrefType, tensor); - auto memref = toMemrefOp.getResult(); - outmemrefs.push_back(memref); - newInputs.push_back(memref); - } - - SmallVector emptyRets; - rewriter.create(loc, emptyRets, callOp.getCallee(), newInputs); - rewriter.replaceOp(callOp, outmemrefs); - return success(); - } -}; - -} // namespace - -namespace catalyst { - -void populateBufferizationPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns) -{ - patterns.add(typeConverter, patterns.getContext()); - patterns.add(typeConverter, patterns.getContext()); - patterns.add(typeConverter, patterns.getContext()); - patterns.add(typeConverter, patterns.getContext()); -} - -} // namespace catalyst diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt index 59e5619eca..48d4aa362b 100644 --- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt @@ -4,8 +4,7 @@ file(GLOB SRC ApplyTransformSequencePass.cpp ArrayListToMemRefPass.cpp AsyncUtils.cpp - BufferizationPatterns.cpp - catalyst_bufferize.cpp + BufferizableOpInterfaceImpl.cpp catalyst_to_llvm.cpp DetectQNodes.cpp DetensorizeSCFPass.cpp diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index cc1379a6d1..3ca100ceee 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -29,7 +29,6 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createAnnotateFunctionPass); mlir::registerPass(catalyst::createApplyTransformSequencePass); mlir::registerPass(catalyst::createArrayListToMemRefPass); - mlir::registerPass(catalyst::createCatalystBufferizationPass); mlir::registerPass(catalyst::createCatalystConversionPass); mlir::registerPass(catalyst::createCopyGlobalMemRefPass); mlir::registerPass(catalyst::createDetensorizeSCFPass); diff --git a/mlir/lib/Catalyst/Transforms/catalyst_bufferize.cpp b/mlir/lib/Catalyst/Transforms/catalyst_bufferize.cpp deleted file mode 100644 index 8363f4d39b..0000000000 --- a/mlir/lib/Catalyst/Transforms/catalyst_bufferize.cpp +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2023 Xanadu Quantum Technologies Inc. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "Catalyst/IR/CatalystOps.h" -#include "Catalyst/Transforms/Passes.h" -#include "Catalyst/Transforms/Patterns.h" - -using namespace mlir; -using namespace catalyst; - -namespace catalyst { - -#define GEN_PASS_DEF_CATALYSTBUFFERIZATIONPASS -#include "Catalyst/Transforms/Passes.h.inc" - -struct CatalystBufferizationPass : impl::CatalystBufferizationPassBase { - using CatalystBufferizationPassBase::CatalystBufferizationPassBase; - - void runOnOperation() final - { - MLIRContext *context = &getContext(); - bufferization::BufferizeTypeConverter typeConverter; - - RewritePatternSet patterns(context); - populateBufferizationPatterns(typeConverter, patterns); - populateFunctionOpInterfaceTypeConversionPattern(patterns, typeConverter); - - ConversionTarget target(*context); - bufferization::populateBufferizeMaterializationLegality(target); - target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - target.addDynamicallyLegalOp( - [&](PrintOp op) { return typeConverter.isLegal(op); }); - target.addDynamicallyLegalOp( - [&](CustomCallOp op) { return typeConverter.isLegal(op); }); - target.addDynamicallyLegalOp([&](CallbackOp op) { - return typeConverter.isSignatureLegal(op.getFunctionType()) && - typeConverter.isLegal(&op.getBody()) && op.getResultTypes().empty(); - }); - target.addDynamicallyLegalOp( - [&](CallbackCallOp op) { return typeConverter.isLegal(op); }); - - if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { - signalPassFailure(); - } - } -}; - -std::unique_ptr createCatalystBufferizationPass() -{ - return std::make_unique(); -} - -} // namespace catalyst diff --git a/mlir/lib/Driver/CompilerDriver.cpp b/mlir/lib/Driver/CompilerDriver.cpp index 2875ad79a7..1b3256ead7 100644 --- a/mlir/lib/Driver/CompilerDriver.cpp +++ b/mlir/lib/Driver/CompilerDriver.cpp @@ -56,6 +56,7 @@ #include "llvm/Transforms/IPO/GlobalDCE.h" #include "Catalyst/IR/CatalystDialect.h" +#include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h" #include "Catalyst/Transforms/Passes.h" #include "Driver/CatalystLLVMTarget.h" #include "Driver/CompilerDriver.h" @@ -964,6 +965,7 @@ int QuantumDriverMainFromCL(int argc, char **argv) registerLLVMTranslations(registry); // Register bufferization interfaces + catalyst::registerBufferizableOpInterfaceExternalModels(registry); catalyst::quantum::registerBufferizableOpInterfaceExternalModels(registry); // Register and parse command line options. diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index 46a3079c69..e3d7f3ad55 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -13,12 +13,14 @@ // limitations under the License. #include "Driver/Pipelines.h" +#include "Catalyst/IR/CatalystDialect.h" #include "Catalyst/Transforms/Passes.h" #include "Gradient/Transforms/Passes.h" #include "Mitigation/Transforms/Passes.h" #include "Quantum/IR/QuantumDialect.h" #include "Quantum/Transforms/Passes.h" #include "mhlo/transforms/passes.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Pass/PassManager.h" @@ -77,7 +79,15 @@ void createBufferizationPipeline(OpPassManager &pm) pm.addPass(mlir::bufferization::createEmptyTensorToAllocTensorPass()); pm.addNestedPass(mlir::bufferization::createBufferizationBufferizePass()); pm.addNestedPass(mlir::tensor::createTensorBufferizePass()); - pm.addPass(catalyst::createCatalystBufferizationPass()); + mlir::bufferization::OneShotBufferizationOptions catalyst_buffer_options; + catalyst_buffer_options.opFilter.allowDialect(); + catalyst_buffer_options.unknownTypeConverterFn = + [=](Value value, Attribute memorySpace, + const mlir::bufferization::BufferizationOptions &options) { + auto tensorType = cast(value.getType()); + return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace); + }; + pm.addPass(mlir::bufferization::createOneShotBufferizePass(catalyst_buffer_options)); pm.addNestedPass(mlir::createLinalgBufferizePass()); pm.addNestedPass(mlir::tensor::createTensorBufferizePass()); mlir::bufferization::OneShotBufferizationOptions quantum_buffer_options; diff --git a/mlir/test/Catalyst/BufferizationTest.mlir b/mlir/test/Catalyst/BufferizationTest.mlir index bcff9bff98..50c691fef9 100644 --- a/mlir/test/Catalyst/BufferizationTest.mlir +++ b/mlir/test/Catalyst/BufferizationTest.mlir @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -// RUN: quantum-opt --catalyst-bufferize --split-input-file %s | FileCheck %s +// RUN: quantum-opt --split-input-file \ +// RUN: --pass-pipeline="builtin.module( \ +// RUN: one-shot-bufferize{unknown-type-conversion=identity-layout-map} \ +// RUN: )" %s | FileCheck %s ////////////////////// // Catalyst PrintOp // @@ -20,6 +23,7 @@ func.func @dbprint_val(%arg0: tensor) { + // CHECK: %0 = bufferization.to_memref %arg0 // CHECK: "catalyst.print"(%0) : (memref) -> () "catalyst.print"(%arg0) : (tensor) -> () @@ -30,6 +34,7 @@ func.func @dbprint_val(%arg0: tensor) { func.func @dbprint_memref(%arg0: tensor) { + // CHECK: %0 = bufferization.to_memref %arg0 // CHECK: "catalyst.print"(%0) <{print_descriptor}> : (memref) -> () "catalyst.print"(%arg0) {print_descriptor} : (tensor) -> () @@ -49,11 +54,11 @@ func.func @dbprint_str() { // ----- func.func @custom_call(%arg0: tensor<3x3xf64>) -> tensor<3x3xf64> { - // CHECK: [[memrefArg:%.+]] = bufferization.to_memref %arg0 : memref<3x3xf64> - // CHECK: [[alloc:%.+]] = bufferization.alloc_tensor() {{.*}}: tensor<3x3xf64> - // CHECK: [[allocmemref:%.+]] = bufferization.to_memref [[alloc]] : memref<3x3xf64> - // CHECK: catalyst.custom_call fn("lapack_dgesdd") ([[memrefArg]], [[allocmemref]]) {number_original_arg = array} : (memref<3x3xf64>, memref<3x3xf64>) -> () - // CHECK: [[res:%.+]] = bufferization.to_tensor [[allocmemref]] : memref<3x3xf64> + // CHECK: [[sourceAlloc:%.+]] = bufferization.to_memref %arg0 + // CHECK: [[destAlloc:%.+]] = memref.alloc() {{.*}}: memref<3x3xf64> + // CHECK: catalyst.custom_call fn("lapack_dgesdd") ([[sourceAlloc]], [[destAlloc]]) {number_original_arg = array} : + // CHECK-SAME: (memref<3x3xf64>, memref<3x3xf64>) -> () + // CHECK: [[res:%.+]] = bufferization.to_tensor [[destAlloc]] : memref<3x3xf64> // CHECK: return [[res]] : tensor<3x3xf64> %0 = catalyst.custom_call fn("lapack_dgesdd") (%arg0) : (tensor<3x3xf64>) -> (tensor<3x3xf64>) @@ -77,12 +82,11 @@ module @test1 { // CHECK-LABEL: @foo( // CHECK-SAME: [[arg0:%.+]]: tensor) func.func private @foo(%arg0: tensor) -> tensor { - // CHECK-DAG: [[memref0:%.+]] = bufferization.to_memref [[arg0]] - // CHECK-DAG: [[tensor1:%.+]] = bufferization.alloc_tensor - // CHECK: [[memref1:%.+]] = bufferization.to_memref [[tensor1]] - // CHECK: catalyst.callback_call @callback_1([[memref0]], [[memref1]]) + // CHECK-DAG: [[memref0:%.+]] = bufferization.to_memref [[arg0]] : memref + // CHECK-DAG: [[resAlloc:%.+]] = memref.alloc() {{.*}}: memref + // CHECK: catalyst.callback_call @callback_1([[memref0]], [[resAlloc]]) : (memref, memref) -> () %1 = catalyst.callback_call @callback_1(%arg0) : (tensor) -> (tensor) - // CHECK: [[retval:%.+]] = bufferization.to_tensor [[memref1]] + // CHECK: [[retval:%.+]] = bufferization.to_tensor [[resAlloc]] // CHECK: return [[retval]] return %1 : tensor } diff --git a/mlir/tools/quantum-opt/quantum-opt.cpp b/mlir/tools/quantum-opt/quantum-opt.cpp index eda34657df..26c08f78bd 100644 --- a/mlir/tools/quantum-opt/quantum-opt.cpp +++ b/mlir/tools/quantum-opt/quantum-opt.cpp @@ -25,6 +25,7 @@ #include "mhlo/IR/hlo_ops.h" #include "Catalyst/IR/CatalystDialect.h" +#include "Catalyst/Transforms/BufferizableOpInterfaceImpl.h" #include "Catalyst/Transforms/Passes.h" #include "Gradient/IR/GradientDialect.h" #include "Gradient/Transforms/Passes.h" @@ -62,6 +63,7 @@ int main(int argc, char **argv) registry.insert(); registry.insert(); + catalyst::registerBufferizableOpInterfaceExternalModels(registry); catalyst::quantum::registerBufferizableOpInterfaceExternalModels(registry); return mlir::asMainReturnCode(