Skip to content

Commit 806d6e6

Browse files
Migrate catalyst dialect to new one-shot bufferization (#1708)
**Context:** This work is based on #1027 . As part of the mlir update, the bufferization of the custom catalyst dialects need to migrate to the new one-shot bufferization interface, as opposed to the old pattern-rewrite style bufferization passes. See more context in #1027. The `Quantum` dialect was migrated in #1686 . **Description of the Change:** MIgrate `Catalyst` dialect to one-shot bufferization. **Benefits:** Align with mlir practices; one step closer to updating mlir. [sc-71487] --------- Co-authored-by: Tzung-Han Juang <[email protected]>
1 parent 1749d2b commit 806d6e6

File tree

16 files changed

+380
-279
lines changed

16 files changed

+380
-279
lines changed

doc/releases/changelog-dev.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@
166166
The new mlir bufferization interface is required by jax 0.4.29 or higher.
167167
[(#1027)](https://github.com/PennyLaneAI/catalyst/pull/1027)
168168
[(#1686)](https://github.com/PennyLaneAI/catalyst/pull/1686)
169+
[(#1708)](https://github.com/PennyLaneAI/catalyst/pull/1708)
169170

170171
<h3>Documentation 📝</h3>
171172

frontend/catalyst/pipelines.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class CompileOptions:
9797

9898
def __post_init__(self):
9999
# Check that async runs must not be seeded
100-
if self.async_qnodes and self.seed != None:
100+
if self.async_qnodes and self.seed is not None:
101101
raise CompileError(
102102
"""
103103
Seeding has no effect on asynchronous QNodes,
@@ -107,7 +107,7 @@ def __post_init__(self):
107107
)
108108

109109
# Check that seed is 32-bit unsigned int
110-
if (self.seed != None) and (self.seed < 0 or self.seed > 2**32 - 1):
110+
if (self.seed is not None) and (self.seed < 0 or self.seed > 2**32 - 1):
111111
raise ValueError(
112112
"""
113113
Seed must be an unsigned 32-bit integer!
@@ -227,7 +227,8 @@ def get_bufferization_stage(_options: CompileOptions) -> List[str]:
227227
"empty-tensor-to-alloc-tensor",
228228
"func.func(bufferization-bufferize)",
229229
"func.func(tensor-bufferize)",
230-
"catalyst-bufferize", # Must be run before -- func.func(linalg-bufferize)
230+
# Catalyst dialect's bufferization must be run before --func.func(linalg-bufferize)
231+
"one-shot-bufferize{dialect-filter=catalyst unknown-type-conversion=identity-layout-map}",
231232
"func.func(linalg-bufferize)",
232233
"func.func(tensor-bufferize)",
233234
"one-shot-bufferize{dialect-filter=quantum}",
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Copyright 2024-2025 Xanadu Quantum Technologies Inc.
2+
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
using namespace mlir;
18+
19+
namespace catalyst {
20+
21+
void registerBufferizableOpInterfaceExternalModels(mlir::DialectRegistry &registry);
22+
23+
} // namespace catalyst

mlir/include/Catalyst/Transforms/Passes.td

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,6 @@ def DetensorizeSCFPass : Pass<"detensorize-scf"> {
2727
let constructor = "catalyst::createDetensorizeSCFPass()";
2828
}
2929

30-
def CatalystBufferizationPass : Pass<"catalyst-bufferize"> {
31-
let summary = "Bufferize tensors in catalyst utility ops.";
32-
33-
let dependentDialects = [
34-
"bufferization::BufferizationDialect",
35-
"memref::MemRefDialect",
36-
"index::IndexDialect"
37-
];
38-
39-
let constructor = "catalyst::createCatalystBufferizationPass()";
40-
}
41-
4230
def ArrayListToMemRefPass : Pass<"convert-arraylist-to-memref"> {
4331
let summary = "Lower array list operations to memref operations.";
4432
let description = [{

mlir/include/Catalyst/Transforms/Patterns.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121

2222
namespace catalyst {
2323

24-
void populateBufferizationPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &);
25-
2624
void populateScatterPatterns(mlir::RewritePatternSet &);
2725

2826
void populateHloCustomCallPatterns(mlir::RewritePatternSet &);

mlir/include/Quantum/Transforms/Patterns.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
namespace catalyst {
2222
namespace quantum {
2323

24-
void populateBufferizationLegality(mlir::TypeConverter &, mlir::ConversionTarget &);
25-
void populateBufferizationPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &);
2624
void populateQIRConversionPatterns(mlir::TypeConverter &, mlir::RewritePatternSet &);
2725
void populateAdjointPatterns(mlir::RewritePatternSet &);
2826
void populateSelfInversePatterns(mlir::RewritePatternSet &);

mlir/lib/Catalyst/IR/CatalystDialect.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "Catalyst/IR/CatalystDialect.h"
1616
#include "Catalyst/IR/CatalystOps.h"
17+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1718
#include "mlir/IR/Builders.h"
1819
#include "mlir/IR/DialectImplementation.h" // needed for generated type parser
1920
#include "mlir/Interfaces/FunctionImplementation.h"
@@ -40,6 +41,9 @@ void CatalystDialect::initialize()
4041
#define GET_OP_LIST
4142
#include "Catalyst/IR/CatalystOps.cpp.inc"
4243
>();
44+
45+
declarePromisedInterfaces<bufferization::BufferizableOpInterface, PrintOp, CustomCallOp,
46+
CallbackCallOp, CallbackOp>();
4347
}
4448

4549
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)