Migrate catalyst dialect to new one-shot bufferization #1708
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

@@           Coverage Diff           @@
##             main    #1708   +/-   ##
=======================================
  Coverage   96.51%   96.51%
=======================================
  Files          82       82
  Lines        9029     9029
  Branches      861      861
=======================================
  Hits         8714     8714
  Misses        258      258
  Partials       57       57

View full report in Codecov by Sentry.
Looks very good! Just a couple of questions for clarity.
"func.func(linalg-bufferize)",
"func.func(tensor-bufferize)",
"one-shot-bufferize{dialect-filter=quantum}",
"func-bufferize",
"func.func(finalizing-bufferize)",
🥳
Why is this removed now?
So this is a bit complicated and is the result of a bunch of things interplaying with each other, but the TL;DR is: I think this is nothing to worry about (and this removal needs to happen eventually anyway).
Ok, here goes my essay.
Consider this program:

@qjit
def my_print(x):
    catalyst.debug.print(x)

my_print(42)
Right before bufferization (i.e. at 3_quantumcompilation.mlir), it looks like
func.func public @jit_my_print(%arg0: tensor<i64>) attributes {llvm.emit_c_interface} {
catalyst.callback_call @callback_closure_124385755099728(%arg0) : (tensor<i64>) -> ()
return
}
catalyst.callback @callback_closure_124385755099728(tensor<i64>) attributes {argc = 1 : i64, id = 124385755099728 : i64, resc = 0 : i64}
Still in abstract tensor land, nothing super fancy.
The part of the pipeline that concerns us is
"one-shot-bufferize{dialect-filter=catalyst}" OR "catalyst-bufferize",
"func-bufferize",
#"func.func(finalizing-bufferize)", # This is the one you were wondering about
"canonicalize",
Now, in the old pipeline, with --catalyst-bufferize, this becomes
func.func public @jit_my_print(%arg0: tensor<i64>) attributes {llvm.emit_c_interface} {
%0 = bufferization.to_memref %arg0 : memref<i64>
catalyst.callback_call @callback_closure_139668113764752(%0) : (memref<i64>) -> ()
return
}
catalyst.callback @callback_closure_139668113764752(memref<i64>) attributes {argc = 1 : i64, id = 139668113764752 : i64, resc = 0 : i64}
Makes sense: we just add a tensor-to-memref op.
(Forgive me if the IDs look different; they came from multiple debug sessions. Just imagine they are the same.)
The next pass is --func-bufferize. This converts the %arg0 argument from a tensor to a memref. However, because the user of %arg0 (the to_memref op) still wants a tensor, this will insert an extra memref-to-tensor op from the now-memref argument:
func.func public @jit_my_print(%arg0: memref<i64>) attributes {llvm.emit_c_interface} {
%0 = bufferization.to_tensor %arg0 : memref<i64>
%1 = bufferization.to_memref %0 : memref<i64>
catalyst.callback_call @callback_closure_139668113764752(%1) : (memref<i64>) -> ()
return
}
catalyst.callback @callback_closure_139668113764752(memref<i64>) attributes {argc = 1 : i64, id = 139668113764752 : i64, resc = 0 : i64}
In the old pipeline, we now hit --finalizing-bufferize. This is one of the old bufferization passes that will be removed. What this pass does is simply cancel inverse to_tensor/to_memref pairs:
func.func public @jit_my_print(%arg0: memref<i64>) attributes {llvm.emit_c_interface} {
catalyst.callback_call @callback_closure_139668113764752(%arg0) : (memref<i64>) -> ()
return
}
catalyst.callback @callback_closure_139668113764752(memref<i64>) attributes {argc = 1 : i64, id = 139668113764752 : i64, resc = 0 : i64}
Note that the pass fails if the ops don't come in pairs:

The removal of those operations is only possible if the operations only exist in pairs, i.e., all uses of bufferization.to_tensor operations are bufferization.to_buffer operations. This pass will fail if not all operations can be removed.
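That pairing requirement can be sketched as a toy model (an illustration only, not MLIR's actual implementation): the pass erases a to_memref(to_tensor(x)) round-trip only when the two memref types match exactly.

```python
# Toy model (illustration only, NOT MLIR's actual implementation) of the
# exact-type pairing that -finalizing-bufferize relies on.
def can_cancel_pair(to_tensor_operand_type: str, to_memref_result_type: str) -> bool:
    # The pass erases a to_memref(to_tensor(x)) pair only when the
    # round-trip reproduces the exact same memref type.
    return to_tensor_operand_type == to_memref_result_type

# Old pipeline: both sides use the plain type, so the pair cancels.
assert can_cancel_pair("memref<i64>", "memref<i64>")

# With a strided layout on one side, the ops no longer look like an
# inverse pair, and the pass cannot remove them.
assert not can_cancel_pair("memref<i64>", "memref<i64, strided<[], offset: ?>>")
```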
What about the new pipeline? We still start in tensor land (3_quantumcompilationpass.mlir), but this time we go through --one-shot-bufferize:
func.func public @jit_my_print(%arg0: tensor<i64>) attributes {llvm.emit_c_interface} {
%0 = bufferization.to_memref %arg0 : memref<i64, strided<[], offset: ?>>
catalyst.callback_call @callback_closure_139668113764752(%0) : (memref<i64, strided<[], offset: ?>>) -> ()
return
}
catalyst.callback @callback_closure_139668113764752(memref<i64>) attributes {argc = 1 : i64, id = 139668113764752 : i64, resc = 0 : i64}
It looks almost the same as the old --catalyst-bufferize, but with a crucial difference: instead of a simple pattern rewrite, we are now using the new getBuffer and replaceOpWithNewBufferizedOp methods. Here's the punchline: the new one-shot bufferization inserts dynamic memory layouts by default!
By default, One-Shot Bufferize chooses the most dynamic memref type wrt. layout maps. When bufferizing the above IR, One-Shot Bufferize inserts a to_memref op with dynamic offset and strides.
This means the result of the to_memref op is not just a simple memref<i64> anymore, but a fully strided memref<i64, strided<[], offset: ?>>!
Now, this itself is not a problem. But as we go through --func-bufferize to insert the memref-to-tensor for the arg:
func.func public @jit_my_print(%arg0: memref<i64>) attributes {llvm.emit_c_interface} {
%0 = bufferization.to_tensor %arg0 : memref<i64>
%1 = bufferization.to_memref %0 : memref<i64, strided<[], offset: ?>>
catalyst.callback_call @callback_closure_139668113764752(%1) : (memref<i64, strided<[], offset: ?>>) -> ()
return
}
catalyst.callback @callback_closure_139668113764752(memref<i64>) attributes {argc = 1 : i64, id = 139668113764752 : i64, resc = 0 : i64}
and attempt to run --finalizing-bufferize:
./my_print/3_QuantumCompilationPass.mlir:4:34: error: failed to materialize conversion for result #0 of operation 'bufferization.to_memref' that remained live after conversion
func.func public @jit_my_print(%arg0: tensor<i64>) attributes {llvm.emit_c_interface} {
^
./my_print/3_QuantumCompilationPass.mlir:4:34: note: see current operation: %1 = "bufferization.to_memref"(%0) : (tensor<i64>) -> memref<i64, strided<[], offset: ?>>
./my_print/3_QuantumCompilationPass.mlir:5:5: note: see existing live user here: catalyst.callback_call @callback_closure_139668113764752(%1) : (memref<i64, strided<[], offset: ?>>) -> ()
catalyst.callback_call @callback_closure_139668113764752(%arg0) : (tensor<i64>) -> ()
The finalizing pass cannot remove it, causing an error! And this is because the types on the to_tensor and to_memref ops are now different (memref<i64> vs memref<i64, strided<[], offset: ?>>), so the finalizing pass no longer identifies them as a cancellable inverse pair!
If we skip the finalization pass and just canonicalize:
func.func public @jit_my_print(%arg0: memref<i64>) attributes {llvm.emit_c_interface} {
%cast = memref.cast %arg0 : memref<i64> to memref<i64, strided<[], offset: ?>>
catalyst.callback_call @callback_closure_139668113764752(%cast) : (memref<i64, strided<[], offset: ?>>) -> ()
return
}
catalyst.callback @callback_closure_139668113764752(memref<i64>) attributes {argc = 1 : i64, id = 139668113764752 : i64, resc = 0 : i64}
We get the correct cast between the unstrided and the strided memrefs, and all is good again.
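The difference between the two cleanups can be modeled with a small sketch (an assumption about the folding behavior, not the real canonicalization pattern): where finalization demands an exact type match, canonicalization can still fold the round-trip into a memref.cast when only the layouts differ.

```python
# Toy model (assumption, NOT the real canonicalization pattern) of folding
# to_memref(to_tensor(x)): an exact type match erases the pair outright,
# while a layout-only mismatch still folds, into a memref.cast.
def fold_roundtrip(src_type: str, dst_type: str) -> str:
    if src_type == dst_type:
        return "erased"
    return f"memref.cast {src_type} to {dst_type}"

assert fold_roundtrip("memref<i64>", "memref<i64>") == "erased"
assert (
    fold_roundtrip("memref<i64>", "memref<i64, strided<[], offset: ?>>")
    == "memref.cast memref<i64> to memref<i64, strided<[], offset: ?>>"
)
```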
I think having the new strided types is ok, if they are returned by the getBuffer methods, which we are supposed to go through in the new one-shot bufferization. Also, the finalizing pass is supposed to be removed anyway. If we can decouple from it here already, at this somewhat early stage, I am more than happy to switch over now.
This means the result of the 2memref op is not just a simple memref anymore, but a fully strided memref<i64, strided<[], offset: ?>> !
🤔 this may be a problem later
From what I remember, when bufferizing functions we will need to use an option called identity-layout-map. Would it be possible to try this option with the current level of progress? Or would that not be possible? @paul0403
From what I remember, when bufferizing functions we will need to use an option called identity-layout-map. Would it be possible to try this option with the current level of progress? Or would that not be possible? @paul0403
This worked, thank you! I can now produce buffers without the unnecessary strides.
I still think we should remove the finalizing-bufferize pass, since it is removed in newer LLVM versions.
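Putting the two outcomes of this thread together, the bufferization stage discussed here would look roughly like the following (a sketch assembled from snippets quoted in this PR, not necessarily the final pipeline):

```python
# Sketch of the bufferization stage as discussed in this thread (assembled
# from the snippets quoted in this PR, not necessarily the final pipeline).
bufferization_stage = [
    # identity-layout-map keeps one-shot bufferization from introducing
    # strided layouts on unknown types.
    "one-shot-bufferize{dialect-filter=catalyst unknown-type-conversion=identity-layout-map}",
    "func-bufferize",
    # "func.func(finalizing-bufferize)",  # dropped: removed in newer LLVM
    "canonicalize",
]

assert "identity-layout-map" in bufferization_stage[0]
assert not any("finalizing" in p for p in bufferization_stage)
```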
Good! I think this answers everything on my side. I'll let David close this.
thus memory write must be true for custom call
Thanks @paul0403, this is good, but you are right that we should be extra careful with this one, given the complexities of passing memory to external functions!
if (isa<MemRefType>(opOperand.get().getType())) {
    return true;
}
return false;
Looks like there were a few different threads discussing whether the custom calls write to memory or not. If they don't, it would be great to keep this as false, although if we are uncertain, the safe choice is to set this to true.
The weird part is that the results are already stored in fresh new buffers, so writing to the inputs and producing the output in new buffers is kind of a waste 🤔
Yes, there were quite a few discussions on this.
We eventually decided to set this to true. The LAPACK kernels do not guarantee that the source arrays will be untouched: https://www.netlib.org/lapack/lug/node112.html
The weird part is that the results are already stored in fresh new buffers, so writing to the inputs and producing the output in new buffers is kind of a waste 🤔
I agree, but I think there might be kernels in LAPACK that both compute a new matrix and perform some action on the input matrix. I don't know exactly which, but if their documentation says so, then I think we should follow that, even if it unfortunately means an extra copy.
Is it correct that, so far, we haven't made an additional copy?
By "so far" do you mean the LAPACK kernels we already support?
I am not a linalg expert, so I can't say for sure. I know that the LAPACK pytests we have all pass, regardless of whether this memory write is true or false. But I would still argue for the side of safety, because:
- The tests might be passing simply because there happen to be no users of the source tensor operand after the LAPACK kernel is called, so even if the source memrefs were wrongly changed, we would not detect it.
- We might want to add new kernels in the future, depending on whichever new JAX linalg function a user might want. We should not restrict ourselves to only supporting kernels that do not write into their source operands.
I just thought of another strategy: we could keep a whitelist of LAPACK kernels that do not write to their source operands, and only create a copy for ops not in the whitelist!
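A minimal sketch of that whitelist idea (the kernel names below are hypothetical placeholders, not an audited list of read-only LAPACK kernels):

```python
# Hypothetical sketch of the proposed whitelist. The entries are
# placeholders, NOT an audited list of read-only LAPACK kernels.
READ_ONLY_KERNELS = {"hypothetical_readonly_kernel"}

def bufferizes_to_memory_write(kernel_name: str) -> bool:
    # Conservative default: assume the kernel may write to its source
    # operands (forcing a defensive copy) unless it is whitelisted.
    return kernel_name not in READ_ONLY_KERNELS

# Any kernel not in the list gets the safe, copying treatment.
assert bufferizes_to_memory_write("some_unknown_kernel")
assert not bufferizes_to_memory_write("hypothetical_readonly_kernel")
```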
// We can safely say false because CallbackCallOp's memrefs
// will be put in a JAX array and JAX arrays are immutable.
//
// Unlike NumPy arrays, JAX arrays are always immutable.
If this is true, why are we allocating result buffers and placing them into the argument list, wouldn't those also be converted to immutable JAX arrays instead of being written to?
At the moment we do not have destination-passing style, so the new result buffers do not enter into the discussion of whether the op reads/writes its operands at all: they don't exist yet when the bufferization pipeline is run. All operands are purely inputs.
I am open to changing things to DPS for optimization, but that is a lot of work, and for now I want to get the basic migration done to unblock upgrading LLVM.
As for your question itself, the callback op runtime does distinguish between the source arguments and the newly allocated destination buffer arguments: https://github.com/PennyLaneAI/catalyst/blob/main/runtime/lib/registry/Registry.cpp#L133
The callback runtime already performs a copy to prevent aliasing. I don't remember how this ties into the JAX array, @erick-xanadu I can't seem to find it in the frontend anymore 😆
If this is true, why are we allocating result buffers and placing them into the argument list, wouldn't those also be converted to immutable JAX arrays instead of being written to?
I'm calling "in-arguments" those arguments which are operands to the operation before bufferization, and "out-arguments" the results of the operation after bufferization. In-arguments and out-arguments will both be operands after bufferization. In-arguments are always just read. The out-arguments are always the contents of JAX arrays copied over to this runtime-allocated memory.
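The in-argument/out-argument convention described above can be modeled with a small Python sketch (an illustration of the described behavior, not Catalyst's actual runtime code):

```python
# Illustrative model (NOT Catalyst's actual runtime code) of the convention
# described above: in-arguments are only read, while out-arguments are
# runtime-allocated buffers that receive a copy of the callback's results.
def callback_call(callback, in_args, out_bufs):
    results = callback(*in_args)  # in-arguments: read only
    for buf, res in zip(out_bufs, results):
        buf[:] = res              # copy results into the out buffers

# The callback returns fresh (conceptually immutable) values; the runtime
# copies their contents into the preallocated destination buffer.
dest = [0, 0, 0]
callback_call(lambda xs: ([x * 2 for x in xs],), ([1, 2, 3],), [dest])
assert dest == [2, 4, 6]
```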
@@ -227,12 +227,12 @@ def get_bufferization_stage(_options: CompileOptions) -> List[str]:
"empty-tensor-to-alloc-tensor",
"func.func(bufferization-bufferize)",
"func.func(tensor-bufferize)",
"catalyst-bufferize", # Must be run before -- func.func(linalg-bufferize)
# Catalyst dialect's bufferization must be run before --func.func(linalg-bufferize)
"one-shot-bufferize{dialect-filter=catalyst unknown-type-conversion=identity-layout-map}",
restore finalizing bufferize pass
Context:
This work is based on #1027.
As part of the MLIR update, the bufferization of the custom Catalyst dialects needs to migrate to the new one-shot bufferization interface, as opposed to the old pattern-rewrite style bufferization passes.
See more context in #1027.
The Quantum dialect was migrated in #1686.

Description of the Change:
Migrate the Catalyst dialect to one-shot bufferization.

Benefits:
Align with MLIR practices; one step closer to updating MLIR.

[sc-71487]