Skip to content

Migrate catalyst dialect to new one-shot bufferization #1708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 37 commits into from
May 12, 2025
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
a128496
init; some boilerplate
paul0403 Apr 30, 2025
db5dccc
more boilerplate
paul0403 Apr 30, 2025
99b2487
boilerplate for quantum-opt
paul0403 Apr 30, 2025
a7ff16d
boilerplate...
paul0403 Apr 30, 2025
a94e33c
Merge remote-tracking branch 'origin/main' into paul0403/new_bufferiz…
paul0403 May 1, 2025
a30ab8a
(cherry pick) Add CustomCall bufferization
tzunghanjuang Aug 29, 2024
4aeb8df
changelog
paul0403 May 2, 2025
166314d
Merge remote-tracking branch 'origin/main' into paul0403/new_bufferiz…
paul0403 May 2, 2025
7508916
add remove hint in cmakelists
paul0403 May 2, 2025
dbf3ff7
remove pattern header from quantum dialect
paul0403 May 6, 2025
3f7c608
add callback and callbackcall op
paul0403 May 6, 2025
8f59346
update cpp pipeline
paul0403 May 6, 2025
e603605
missed include
paul0403 May 6, 2025
69fcac2
remove old catalyst dialect bufferization pass
paul0403 May 7, 2025
13037be
format
paul0403 May 7, 2025
66f4ab0
codefactor
paul0403 May 7, 2025
41a7e14
Merge remote-tracking branch 'origin/main' into paul0403/new_bufferiz…
paul0403 May 7, 2025
1a3ff7d
custom call also allocates
paul0403 May 7, 2025
045b15b
do not hint memory write for custom op when not in memref land
paul0403 May 7, 2025
ad942d9
Merge remote-tracking branch 'origin/main' into paul0403/new_bufferiz…
paul0403 May 8, 2025
fdb16e0
lapack kernels might write into source array
paul0403 May 8, 2025
4d5f983
Merge remote-tracking branch 'origin/main' into paul0403/new_bufferiz…
paul0403 May 8, 2025
38bf1c4
add bufferization interface doc banner
paul0403 May 8, 2025
fe4e944
add {} to a one-line if block
paul0403 May 8, 2025
3c6dfcd
remove aliasing operand method from callback op: it does not have ten…
paul0403 May 8, 2025
098d580
Merge remote-tracking branch 'origin/main' into paul0403/new_bufferiz…
paul0403 May 8, 2025
b39f28a
(prototype) make a white list of custom calls that won't copy
paul0403 May 9, 2025
4cd1843
Merge remote-tracking branch 'origin/main' into paul0403/new_bufferiz…
paul0403 May 9, 2025
3289d19
Set identity layout map option. This avoids the strides.
paul0403 May 9, 2025
4f48b1b
no copy: jax already does the copy around the lapack kernels
paul0403 May 9, 2025
7c6759f
name
paul0403 May 9, 2025
fb5c421
Merge remote-tracking branch 'origin/main' into paul0403/new_bufferiz…
paul0403 May 12, 2025
c949ce7
add comment about jax shim layer's copy
paul0403 May 12, 2025
0b70112
a bit more comment
paul0403 May 12, 2025
b06ec69
add back finalizing bufferize pass
paul0403 May 12, 2025
d5f4d3d
add back finalizing pass in cpp pipeline
paul0403 May 12, 2025
a752ead
CI
paul0403 May 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@
The new mlir bufferization interface is required by jax 0.4.29 or higher.
[(#1027)](https://github.com/PennyLaneAI/catalyst/pull/1027)
[(#1686)](https://github.com/PennyLaneAI/catalyst/pull/1686)
[(#1708)](https://github.com/PennyLaneAI/catalyst/pull/1708)

<h3>Documentation 📝</h3>

Expand Down
8 changes: 4 additions & 4 deletions frontend/catalyst/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class CompileOptions:

def __post_init__(self):
# Check that async runs must not be seeded
if self.async_qnodes and self.seed != None:
if self.async_qnodes and self.seed is not None:
raise CompileError(
"""
Seeding has no effect on asynchronous QNodes,
Expand All @@ -107,7 +107,7 @@ def __post_init__(self):
)

# Check that seed is 32-bit unsigned int
if (self.seed != None) and (self.seed < 0 or self.seed > 2**32 - 1):
if (self.seed is not None) and (self.seed < 0 or self.seed > 2**32 - 1):
raise ValueError(
"""
Seed must be an unsigned 32-bit integer!
Expand Down Expand Up @@ -227,12 +227,12 @@ def get_bufferization_stage(_options: CompileOptions) -> List[str]:
"empty-tensor-to-alloc-tensor",
"func.func(bufferization-bufferize)",
"func.func(tensor-bufferize)",
"catalyst-bufferize", # Must be run before -- func.func(linalg-bufferize)
# Catalyst dialect's bufferization must be run before --func.func(linalg-bufferize)
"one-shot-bufferize{dialect-filter=catalyst}",
"func.func(linalg-bufferize)",
"func.func(tensor-bufferize)",
"one-shot-bufferize{dialect-filter=quantum}",
"func-bufferize",
"func.func(finalizing-bufferize)",
"canonicalize", # Remove dead memrefToTensorOp's
"gradient-postprocess",
# introduced during gradient-bufferize of callbacks
Expand Down
23 changes: 23 additions & 0 deletions mlir/include/Catalyst/Transforms/BufferizableOpInterfaceImpl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright 2024-2025 Xanadu Quantum Technologies Inc.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

using namespace mlir;

namespace catalyst {

void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry &registry);

} // namespace catalyst
12 changes: 0 additions & 12 deletions mlir/include/Catalyst/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,6 @@ def DetensorizeSCFPass : Pass<"detensorize-scf"> {
let constructor = "catalyst::createDetensorizeSCFPass()";
}

def CatalystBufferizationPass : Pass<"catalyst-bufferize"> {
let summary = "Bufferize tensors in catalyst utility ops.";

let dependentDialects = [
"bufferization::BufferizationDialect",
"memref::MemRefDialect",
"index::IndexDialect"
];

let constructor = "catalyst::createCatalystBufferizationPass()";
}

def ArrayListToMemRefPass : Pass<"convert-arraylist-to-memref"> {
let summary = "Lower array list operations to memref operations.";
let description = [{
Expand Down
2 changes: 0 additions & 2 deletions mlir/include/Catalyst/Transforms/Patterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

namespace catalyst {

void populateBufferizationPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &);

void populateScatterPatterns(mlir::RewritePatternSet &);

void populateHloCustomCallPatterns(mlir::RewritePatternSet &);
Expand Down
2 changes: 0 additions & 2 deletions mlir/include/Quantum/Transforms/Patterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
namespace catalyst {
namespace quantum {

void populateBufferizationLegality(mlir::TypeConverter &, mlir::ConversionTarget &);
void populateBufferizationPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &);
void populateQIRConversionPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &);
void populateAdjointPatterns(mlir::RewritePatternSet &);
void populateSelfInversePatterns(mlir::RewritePatternSet &);
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Catalyst/IR/CatalystDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "Catalyst/IR/CatalystDialect.h"
#include "Catalyst/IR/CatalystOps.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h" // needed for generated type parser
#include "mlir/Interfaces/FunctionImplementation.h"
Expand All @@ -40,6 +41,9 @@ void CatalystDialect::initialize()
#define GET_OP_LIST
#include "Catalyst/IR/CatalystOps.cpp.inc"
>();

declarePromisedInterfaces<bufferization::BufferizableOpInterface, PrintOp, CustomCallOp,
CallbackCallOp, CallbackOp>();
}

//===----------------------------------------------------------------------===//
Expand Down
Loading