Skip to content

Allow standard qml.cond usage of qml.cond(pred, qml.some_gate)(*args, **kwargs) in qjit #1232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/releases/changelog-0.9.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,10 @@
lowering of the scatter operation.
[(#1214)](https://github.com/PennyLaneAI/catalyst/pull/1214)

* Fixes a bug where conditional-ed single gates cannot be used in qjit,
e.g. `qml.cond(x > 1, qml.Hadamard)(wires=0)`.
[(#1232)](https://github.com/PennyLaneAI/catalyst/pull/1232)

<h3>Internal changes</h3>

* Remove deprecated pennylane code across the frontend.
Expand Down
93 changes: 92 additions & 1 deletion frontend/catalyst/api_extensions/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import jax
import jax.numpy as jnp
import pennylane as qml
from jax._src.tree_util import PyTreeDef, tree_unflatten, treedef_is_leaf
from jax.core import AbstractValue
from pennylane import QueuingManager
Expand Down Expand Up @@ -237,8 +238,17 @@ def conditional_fn():
"""

def _decorator(true_fn: Callable):

if len(inspect.signature(true_fn).parameters):
raise TypeError("Conditional 'True' function is not allowed to have any arguments")
if isinstance(true_fn, type) and issubclass(true_fn, qml.operation.Operation):
# Special treatment if conditional function body is a single pennylane gate
# The qml.operation.Operation base class represents things that
# can reasonably be considered as a gate,
# e.g. qml.Hadamard, qml.RX, etc.
return CondCallableSingleGateHandler(pred, true_fn)
else:
raise TypeError("Conditional 'True' function is not allowed to have any arguments")

return CondCallable(pred, true_fn)

return _decorator
Expand Down Expand Up @@ -557,6 +567,15 @@ def __init__(self, pred, true_fn):
self._operation = None
self.expansion_strategy = cond_expansion_strategy()

def set_otherwise_fn(self, otherwise_fn): # pylint:disable=missing-function-docstring
self.otherwise_fn = otherwise_fn

def add_pred(self, _pred):
self.preds.append(self._convert_predicate_to_bool(_pred))

def add_branch_fn(self, _branch_fn):
self.branch_fns.append(_branch_fn)

@property
def operation(self):
"""
Expand Down Expand Up @@ -742,6 +761,78 @@ def __call__(self):
return self._call_during_interpretation()


class CondCallableSingleGateHandler(CondCallable):
"""
Special CondCallable when the conditional body function is a single pennylane gate.

A usual pennylane conditional call for a gate looks like
`qml.cond(x == 42, qml.RX)(theta, wires=0)`

Since gates are guaranteed to take in arguments (at the very least the wire argument),
the usual CondCallable class, which expects the conditional body function to have no arguments,
cannot be used.
This class inherits from base CondCallable, but wraps the gate in a function with no arguments,
and sends that function to CondCallable.
Comment on lines +774 to +775
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we not just use this strategy all the time, instead of only when the target is a gate class ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'm trying it right now. I feel like we can just disable those checks and it should work for any callable.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This allows us to perform the conditional branch gate function with arguments.
"""

def __init__(self, pred, true_fn): # pylint:disable=super-init-not-called
self.sgh_preds = [pred]
self.sgh_branch_fns = [true_fn]
self.sgh_otherwise_fn = None

def __call__(self, *args, **kwargs):
def argless_true_fn():
self.sgh_branch_fns[0](*args, **kwargs)

super().__init__(self.sgh_preds[0], argless_true_fn)

if self.sgh_otherwise_fn is not None:

def argless_otherwise_fn():
self.sgh_otherwise_fn(*args, **kwargs)

super().set_otherwise_fn(argless_otherwise_fn)

for i in range(1, len(self.sgh_branch_fns)):

def argless_elseif_fn(i=i): # i=i to work around late binding
self.sgh_branch_fns[i](*args, **kwargs)

super().add_pred(self.sgh_preds[i])
super().add_branch_fn(argless_elseif_fn)

return super().__call__()

def else_if(self, _pred):
"""
Override the "can't have arguments" check in the original CondCallable's `else_if`
"""

def decorator(branch_fn):
if isinstance(branch_fn, type) and issubclass(branch_fn, qml.operation.Operation):
self.sgh_preds.append(_pred)
self.sgh_branch_fns.append(branch_fn)
return self
else:
raise TypeError(
"Conditional 'else if' function can have arguments only if it is a PennyLane gate."
)

return decorator

def otherwise(self, otherwise_fn):
"""
Override the "can't have arguments" check in the original CondCallable's `otherwise`
"""
if isinstance(otherwise_fn, type) and issubclass(otherwise_fn, qml.operation.Operation):
self.sgh_otherwise_fn = otherwise_fn
else:
raise TypeError(
"Conditional 'False' function can have arguments only if it is a PennyLane gate."
)


class ForLoopCallable:
"""
Wrapping for_loop decorator into a class so that the actual "ForLoop" operation object, which
Expand Down
93 changes: 93 additions & 0 deletions frontend/test/lit/test_if_else.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,99 @@ def otherwise():
# -----


# CHECK-LABEL: public @jit_circuit_single_gate
@qjit(target="mlir")
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit_single_gate(n: int):
# pylint: disable=line-too-long
# CHECK-DAG: [[c5:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<5> : tensor<i64>
# CHECK-DAG: [[c6:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<6> : tensor<i64>
# CHECK-DAG: [[c7:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<7> : tensor<i64>
# CHECK-DAG: [[c8:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<8> : tensor<i64>
# CHECK-DAG: [[c9:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<9> : tensor<i64>
# CHECK-DAG: [[b_t5:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c5]], SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1>
# CHECK-DAG: [[b_t6:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c6]], SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1>
# CHECK-DAG: [[b_t7:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c7]], SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1>
# CHECK-DAG: [[b_t8:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c8]], SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1>
# CHECK-DAG: [[b_t9:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c9]], SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1>
# CHECK-DAG: [[qreg_0:%[a-zA-Z0-9_]+]] = quantum.alloc

# CHECK: [[b5:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t5]]
# CHECK: [[qreg_out:%.+]] = scf.if [[b5]]
# CHECK-DAG: [[q0:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_0]]
# CHECK-DAG: [[q1:%[a-zA-Z0-9_]+]] = quantum.custom "PauliX"() [[q0]]
# CHECK-DAG: [[qreg_1:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_0]][ {{[%a-zA-Z0-9_]+}}], [[q1]]
# CHECK: scf.yield [[qreg_1]]

# CHECK: else
# CHECK-DAG: [[q2:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_0]]
# CHECK-DAG: [[q3:%[a-zA-Z0-9_]+]] = quantum.custom "Hadamard"() [[q2]]
# CHECK-DAG: [[qreg_2:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_0]][ {{[%a-zA-Z0-9_]+}}], [[q3]]
# CHECK: scf.yield [[qreg_2]]
qml.cond(n <= 5, qml.PauliX, qml.Hadamard)(wires=0)

# CHECK: [[b6:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t6]]
# CHECK: [[qreg_out1:%.+]] = scf.if [[b6]]
# CHECK-DAG: [[q4:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out]]
# CHECK-DAG: [[q5:%[a-zA-Z0-9_]+]] = quantum.custom "RX"({{%.+}}) [[q4]]
# CHECK-DAG: [[qreg_3:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out]][ {{[%a-zA-Z0-9_]+}}], [[q5]]
# CHECK: scf.yield [[qreg_3]]
# CHECK: else
# CHECK: scf.yield [[qreg_out]]

qml.cond(n <= 6, qml.RX)(3.14, wires=0)

# CHECK: [[b7:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t7]]
# CHECK: [[qreg_out2:%.+]] = scf.if [[b7]]
# CHECK-DAG: [[q7:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out1]]
# CHECK-DAG: [[q8:%[a-zA-Z0-9_]+]] = quantum.custom "Hadamard"() [[q7]]
# pylint: disable=line-too-long
# CHECK-DAG: [[qreg_4:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out1]][ {{[%a-zA-Z0-9_]+}}], [[q8]]
# CHECK: scf.yield [[qreg_4]]
# CHECK: else {
# CHECK: [[b8:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t8]]
# CHECK: [[qreg_out3:%.+]] = scf.if [[b8]]
# CHECK-DAG: [[q9:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out1]]
# CHECK-DAG: [[q10:%[a-zA-Z0-9_]+]] = quantum.custom "PauliY"() [[q9]]
# CHECK-DAG: [[qreg_5:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out1]][ {{[%a-zA-Z0-9_]+}}], [[q10]]
# CHECK: scf.yield [[qreg_5]]
# CHECK: else {
# CHECK: [[b9:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t9]]
# CHECK: [[qreg_out4:%.+]] = scf.if [[b9]]
# CHECK-DAG: [[q11:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out1]]
# CHECK-DAG: [[q12:%[a-zA-Z0-9_]+]] = quantum.custom "PauliZ"() [[q11]]
# CHECK-DAG: [[qreg_6:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out1]][ {{[%a-zA-Z0-9_]+}}], [[q12]]
# CHECK: scf.yield [[qreg_6]]
# CHECK: else {
# CHECK-DAG: [[q13:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out1]]
# CHECK-DAG: [[q14:%[a-zA-Z0-9_]+]] = quantum.custom "PauliX"() [[q13]]
# CHECK-DAG: [[qreg_7:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out1]][ {{[%a-zA-Z0-9_]+}}], [[q14]]
# CHECK: scf.yield [[qreg_7]]
# CHECK: scf.yield [[qreg_out4]]
# CHECK: scf.yield [[qreg_out3]]
qml.cond(
n <= 7,
qml.Hadamard,
qml.PauliX,
(
(n <= 8, qml.PauliY),
(n <= 9, qml.PauliZ),
),
)(wires=0)

# CHECK: [[qreg_ret:%.+]] = quantum.extract [[qreg_out2]][ 0]
# CHECK: [[qobs:%.+]] = quantum.compbasis [[qreg_ret]] : !quantum.obs
# CHECK: [[ret:%.+]] = quantum.probs [[qobs]]
# CHECK: return [[ret]]
return qml.probs()


print(circuit_single_gate.mlir)


# -----


# CHECK-LABEL: test_convert_element_type
@qjit
def test_convert_element_type(i: int, f: float):
Expand Down
71 changes: 68 additions & 3 deletions frontend/test/pytest/test_conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,16 +418,39 @@ def conditional_flip():
def test_argument_error_with_callables(self):
"""Test for the error when arguments are supplied and the target is not a function."""

@qml.qnode(qml.device("lightning.qubit", wires=1))
def f(x: int):

qml.cond(x < 5, qml.Hadamard)(0)
res = qml.cond(x < 5, lambda z: z + 1)(0)

return qml.probs()
return res

with pytest.raises(TypeError, match="not allowed to have any arguments"):
qjit(f)

def f(x: int):

res = qml.cond(x < 5, qml.Hadamard, lambda z: z + 1)(0)

return res

with pytest.raises(
TypeError,
match="Conditional 'False' function can have arguments only if it is a PennyLane gate.",
):
qjit(f)

def f(x: int):

res = qml.cond(x < 5, qml.Hadamard, qml.Hadamard, ((x < 6, lambda z: z + 1),))(0)

return res

with pytest.raises(
TypeError,
match="Conditional 'else if' function can have arguments only if it is a PennyLane gate.",
):
qjit(f)


class TestInterpretationConditional:
"""Test that the conditional operation's execution is semantically equivalent
Expand Down Expand Up @@ -676,6 +699,48 @@ def branch_f():
assert func(True) == 1
assert func(False) == 0

def test_cond_single_gate(self, backend):
"""
Test standard pennylane qml.cond usage on single quantum gates.
Fixes https://github.com/PennyLaneAI/catalyst/issues/449
"""

@qml.qnode(qml.device(backend, wires=2))
def func(x, y):
qml.cond(x == 42, qml.Hadamard, qml.PauliX)(wires=0)
qml.cond(x == 42, qml.RY, qml.RZ)(1.5, wires=0)
qml.cond(x == 42, qml.CNOT)(wires=[1, 0])
qml.cond(y == 37, qml.PauliX)(wires=1)
qml.cond(
y == 36,
qml.RZ,
qml.RY,
(
(x == 42, qml.RX),
(x == 41, qml.RZ),
),
)(5.1, wires=0)
qml.cond(y == 37, qml.Rot)(1.2, 3.4, 5.6, wires=1)

return qml.probs()

expected_0 = func(42, 37)
expected_1 = func(0, 37)
expected_2 = func(42, 0)
expected_3 = func(41, 0)

jitted_func = qjit(func)

observed_0 = jitted_func(42, 37)
observed_1 = jitted_func(0, 37)
observed_2 = jitted_func(42, 0)
observed_3 = jitted_func(41, 0)

assert np.allclose(expected_0, observed_0)
assert np.allclose(expected_1, observed_1)
assert np.allclose(expected_2, observed_2)
assert np.allclose(expected_3, observed_3)


class TestCondPredicateConversion:
"""Test suite for checking predicate conversion to bool."""
Expand Down
Loading