Skip to content

Commit 4802d98

Browse files
committed
proof of concept 2
1 parent 1749d2b commit 4802d98

File tree

4 files changed

+31
-1
lines changed

4 files changed

+31
-1
lines changed

frontend/catalyst/from_plxpr.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def handle_qnode(
183183
non_const_args = args[n_consts:]
184184

185185
f = partial(QFuncPlxprInterpreter(device, shots).eval, qfunc_jaxpr, consts)
186+
f = jax.jit(f)
186187

187188
return quantum_kernel_p.bind(
188189
wrap_init(f, debug_info=qfunc_jaxpr.debug_info),

frontend/catalyst/pipelines.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ def get_enforce_runtime_invariants_stage(_options: CompileOptions) -> List[str]:
173173
# keep inlining modules targeting the Catalyst runtime.
174174
# But qnodes targeting other backends may choose to lower
175175
# this into something else.
176+
"builtin.module(inline)",
177+
"split-multiple-tapes",
176178
"inline-nested-module",
177179
]
178180
return enforce_runtime_invariants
@@ -217,7 +219,6 @@ def get_bufferization_stage(_options: CompileOptions) -> List[str]:
217219
"""Returns the list of passes that performs bufferization"""
218220
bufferization = [
219221
"one-shot-bufferize{dialect-filter=memref}",
220-
"inline",
221222
"gradient-preprocess",
222223
"gradient-bufferize",
223224
"scf-bufferize",

mlir/lib/Catalyst/IR/CatalystDialect.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,17 @@ using namespace catalyst;
2828
//===----------------------------------------------------------------------===//
2929
// Catalyst dialect.
3030
//===----------------------------------------------------------------------===//
31+
namespace {
32+
struct CatalystInlinerInterface : public DialectInlinerInterface {
33+
using DialectInlinerInterface::DialectInlinerInterface;
34+
35+
/// Operations in Gradient dialect are always legal to inline.
36+
bool isLegalToInline(Operation *op, Region *, bool, IRMapping &valueMapping) const final
37+
{
38+
return isa<CallbackCallOp>(op);
39+
}
40+
};
41+
}
3142

3243
void CatalystDialect::initialize()
3344
{
@@ -40,6 +51,7 @@ void CatalystDialect::initialize()
4051
#define GET_OP_LIST
4152
#include "Catalyst/IR/CatalystOps.cpp.inc"
4253
>();
54+
addInterface<CatalystInlinerInterface>();
4355
}
4456

4557
//===----------------------------------------------------------------------===//

mlir/lib/Quantum/IR/QuantumDialect.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1616
#include "mlir/IR/DialectImplementation.h" // needed for generated type parser
17+
#include "mlir/Transforms/InliningUtils.h"
1718
#include "llvm/ADT/TypeSwitch.h" // needed for generated type parser
1819

1920
#include "Quantum/IR/QuantumDialect.h"
@@ -28,6 +29,20 @@ using namespace catalyst::quantum;
2829

2930
#include "Quantum/IR/QuantumOpsDialect.cpp.inc"
3031

32+
namespace {
33+
struct QuantumInlinerInterface : public DialectInlinerInterface {
34+
using DialectInlinerInterface::DialectInlinerInterface;
35+
36+
/// Operations in Gradient dialect are always legal to inline.
37+
bool isLegalToInline(Operation *, Region *, bool, IRMapping &valueMapping) const final
38+
{
39+
return true;
40+
}
41+
};
42+
} // namespace
43+
44+
45+
3146
void QuantumDialect::initialize()
3247
{
3348
addTypes<
@@ -48,6 +63,7 @@ void QuantumDialect::initialize()
4863
declarePromisedInterfaces<bufferization::BufferizableOpInterface, QubitUnitaryOp, HermitianOp,
4964
HamiltonianOp, SampleOp, CountsOp, ProbsOp, StateOp, SetStateOp,
5065
SetBasisStateOp>();
66+
addInterface<QuantumInlinerInterface>();
5167
}
5268

5369
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)