-
Notifications
You must be signed in to change notification settings - Fork 47
Allow standard qml.cond
usage of qml.cond(pred, qml.some_gate)(*args, **kwargs)
in qjit
#1232
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7050c1f
fd31386
dcb05ef
fb1d3a7
45a8240
dbfb689
6e5dba1
6161136
36279f8
fd291db
a3688e1
9a78577
cef0296
25ab728
22e7c52
5fae6b3
90edf44
070885b
1624221
ce24fcf
3415dfc
e339f97
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ | |
|
||
import jax | ||
import jax.numpy as jnp | ||
import pennylane as qml | ||
from jax._src.tree_util import PyTreeDef, tree_unflatten, treedef_is_leaf | ||
from jax.core import AbstractValue | ||
from pennylane import QueuingManager | ||
|
@@ -237,8 +238,17 @@ def conditional_fn(): | |
""" | ||
|
||
def _decorator(true_fn: Callable): | ||
|
||
if len(inspect.signature(true_fn).parameters): | ||
raise TypeError("Conditional 'True' function is not allowed to have any arguments") | ||
if isinstance(true_fn, type) and issubclass(true_fn, qml.operation.Operation): | ||
# Special treatment if conditional function body is a single pennylane gate | ||
# The qml.operation.Operation base class represents things that | ||
# can reasonably be considered as a gate, | ||
# e.g. qml.Hadamard, qml.RX, etc. | ||
return CondCallableSingleGateHandler(pred, true_fn) | ||
else: | ||
raise TypeError("Conditional 'True' function is not allowed to have any arguments") | ||
|
||
return CondCallable(pred, true_fn) | ||
|
||
return _decorator | ||
|
@@ -557,6 +567,15 @@ def __init__(self, pred, true_fn): | |
self._operation = None | ||
self.expansion_strategy = cond_expansion_strategy() | ||
|
||
def set_otherwise_fn(self, otherwise_fn): # pylint:disable=missing-function-docstring | ||
self.otherwise_fn = otherwise_fn | ||
|
||
def add_pred(self, _pred): | ||
self.preds.append(self._convert_predicate_to_bool(_pred)) | ||
|
||
def add_branch_fn(self, _branch_fn): | ||
self.branch_fns.append(_branch_fn) | ||
|
||
@property | ||
def operation(self): | ||
""" | ||
|
@@ -742,6 +761,78 @@ def __call__(self): | |
return self._call_during_interpretation() | ||
|
||
|
||
class CondCallableSingleGateHandler(CondCallable): | ||
""" | ||
Special CondCallable when the conditional body function is a single pennylane gate. | ||
|
||
A usual pennylane conditional call for a gate looks like | ||
`qml.cond(x == 42, qml.RX)(theta, wires=0)` | ||
|
||
Since gates are guaranteed to take in arguments (at the very least the wire argument), | ||
the usual CondCallable class, which expects the conditional body function to have no arguments, | ||
cannot be used. | ||
This class inherits from base CondCallable, but wraps the gate in a function with no arguments, | ||
and sends that function to CondCallable. | ||
Comment on lines
+774
to
+775
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we not just use this strategy all the time, instead of only when the target is a gate class ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I'm trying it right now. I feel like we can just disable those checks and it should work for any callable. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
This allows us to perform the conditional branch gate function with arguments. | ||
""" | ||
|
||
def __init__(self, pred, true_fn): # pylint:disable=super-init-not-called | ||
self.sgh_preds = [pred] | ||
self.sgh_branch_fns = [true_fn] | ||
self.sgh_otherwise_fn = None | ||
paul0403 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def __call__(self, *args, **kwargs): | ||
def argless_true_fn(): | ||
self.sgh_branch_fns[0](*args, **kwargs) | ||
|
||
super().__init__(self.sgh_preds[0], argless_true_fn) | ||
|
||
if self.sgh_otherwise_fn is not None: | ||
|
||
def argless_otherwise_fn(): | ||
self.sgh_otherwise_fn(*args, **kwargs) | ||
|
||
super().set_otherwise_fn(argless_otherwise_fn) | ||
|
||
for i in range(1, len(self.sgh_branch_fns)): | ||
|
||
def argless_elseif_fn(i=i): # i=i to work around late binding | ||
self.sgh_branch_fns[i](*args, **kwargs) | ||
|
||
super().add_pred(self.sgh_preds[i]) | ||
super().add_branch_fn(argless_elseif_fn) | ||
|
||
return super().__call__() | ||
|
||
def else_if(self, _pred): | ||
""" | ||
Override the "can't have arguments" check in the original CondCallable's `else_if` | ||
""" | ||
|
||
def decorator(branch_fn): | ||
if isinstance(branch_fn, type) and issubclass(branch_fn, qml.operation.Operation): | ||
self.sgh_preds.append(_pred) | ||
self.sgh_branch_fns.append(branch_fn) | ||
return self | ||
else: | ||
raise TypeError( | ||
"Conditional 'else if' function can have arguments only if it is a PennyLane gate." | ||
) | ||
|
||
return decorator | ||
|
||
def otherwise(self, otherwise_fn): | ||
""" | ||
Override the "can't have arguments" check in the original CondCallable's `otherwise` | ||
""" | ||
if isinstance(otherwise_fn, type) and issubclass(otherwise_fn, qml.operation.Operation): | ||
self.sgh_otherwise_fn = otherwise_fn | ||
else: | ||
raise TypeError( | ||
"Conditional 'False' function can have arguments only if it is a PennyLane gate." | ||
) | ||
|
||
|
||
class ForLoopCallable: | ||
""" | ||
Wrapping for_loop decorator into a class so that the actual "ForLoop" operation object, which | ||
|
Uh oh!
There was an error while loading. Please reload this page.