From 7050c1fc1cc671fc028648af1c9b8812427761e7 Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Fri, 25 Oct 2024 12:01:58 -0400
Subject: [PATCH 01/22] save
---
.../catalyst/api_extensions/control_flow.py | 36 ++++++++++++++++---
1 file changed, 32 insertions(+), 4 deletions(-)
diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py
index b590c8c0e2..3e1ce83076 100644
--- a/frontend/catalyst/api_extensions/control_flow.py
+++ b/frontend/catalyst/api_extensions/control_flow.py
@@ -65,7 +65,7 @@
JaxTracingContext,
)
-
+import pennylane as qml
## API ##
def cond(pred: DynamicJaxprTracer):
"""A :func:`~.qjit` compatible decorator for if-else conditionals in PennyLane/Catalyst.
@@ -237,9 +237,23 @@ def conditional_fn():
"""
def _decorator(true_fn: Callable):
+ #breakpoint()
+
if len(inspect.signature(true_fn).parameters):
- raise TypeError("Conditional 'True' function is not allowed to have any arguments")
- return CondCallable(pred, true_fn)
+ #breakpoint()
+ current_frame = inspect.currentframe()
+ call_frame = current_frame.f_back
+ args_info = inspect.getargvalues(call_frame)
+ #breakpoint()
+
+ #def _true_fn():
+ # true_fn(wires=0)
+ #return CondCallable(pred, _true_fn)
+
+ #raise TypeError("Conditional 'True' function is not allowed to have any arguments")
+
+ #return CondCallable(pred, true_fn)
+ return CondCallableSingleGateHandler(pred, true_fn)
return _decorator
@@ -498,6 +512,19 @@ def _decorator(body_fn):
## IMPL ##
+class CondCallableSingleGateHandler:
+ def __init__(self, pred, true_fn):
+ self.pred = pred
+ self.true_fn = true_fn
+
+
+ def __call__(self, *my_args, **my_kwargs):
+ def new_true_fn():
+ self.true_fn(*my_args, **my_kwargs)
+ #breakpoint()
+ return CondCallable(self.pred, new_true_fn)
+
+
class CondCallable:
"""User-facing wrapper provoding "else_if" and "otherwise" public methods.
Some code in this class has been adapted from the cond implementation in the JAX project at
@@ -614,7 +641,7 @@ def _convert_predicate_to_bool(self, pred):
if isinstance(pred, jax.Array) and pred.shape not in ((), (1,)):
raise TypeError("Array with multiple elements is not a valid predicate")
-
+ #breakpoint()
if not self._is_any_boolean(pred):
try:
pred = jnp.astype(pred, bool, copy=False)
@@ -623,6 +650,7 @@ def _convert_predicate_to_bool(self, pred):
"Conditional predicates are required to be of bool, integer or float type"
) from e
+ #breakpoint()
return pred
def _is_any_boolean(self, pred):
From fd31386cff242a0e7bb389b44937a3298a7ca142 Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Fri, 25 Oct 2024 13:26:31 -0400
Subject: [PATCH 02/22] init commit
---
.../catalyst/api_extensions/control_flow.py | 57 +++++++++----------
frontend/test/pytest/test_conditionals.py | 14 +++++
2 files changed, 41 insertions(+), 30 deletions(-)
diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py
index 3e1ce83076..3e0b0cbde0 100644
--- a/frontend/catalyst/api_extensions/control_flow.py
+++ b/frontend/catalyst/api_extensions/control_flow.py
@@ -24,6 +24,7 @@
import jax
import jax.numpy as jnp
+import pennylane as qml
from jax._src.tree_util import PyTreeDef, tree_unflatten, treedef_is_leaf
from jax.core import AbstractValue
from pennylane import QueuingManager
@@ -65,7 +66,12 @@
JaxTracingContext,
)
-import pennylane as qml
+
+def is_pennylane_gate(callable):
+ # TODO
+ return True
+
+
## API ##
def cond(pred: DynamicJaxprTracer):
"""A :func:`~.qjit` compatible decorator for if-else conditionals in PennyLane/Catalyst.
@@ -237,22 +243,12 @@ def conditional_fn():
"""
def _decorator(true_fn: Callable):
- #breakpoint()
-
- if len(inspect.signature(true_fn).parameters):
- #breakpoint()
- current_frame = inspect.currentframe()
- call_frame = current_frame.f_back
- args_info = inspect.getargvalues(call_frame)
- #breakpoint()
-
- #def _true_fn():
- # true_fn(wires=0)
- #return CondCallable(pred, _true_fn)
- #raise TypeError("Conditional 'True' function is not allowed to have any arguments")
+ # if len(inspect.signature(true_fn).parameters):
+ # raise TypeError("Conditional 'True' function is not allowed to have any arguments")
- #return CondCallable(pred, true_fn)
+ # return CondCallable(pred, true_fn)
+ # if true_fn is a plain gate:
return CondCallableSingleGateHandler(pred, true_fn)
return _decorator
@@ -512,19 +508,6 @@ def _decorator(body_fn):
## IMPL ##
-class CondCallableSingleGateHandler:
- def __init__(self, pred, true_fn):
- self.pred = pred
- self.true_fn = true_fn
-
-
- def __call__(self, *my_args, **my_kwargs):
- def new_true_fn():
- self.true_fn(*my_args, **my_kwargs)
- #breakpoint()
- return CondCallable(self.pred, new_true_fn)
-
-
class CondCallable:
"""User-facing wrapper provoding "else_if" and "otherwise" public methods.
Some code in this class has been adapted from the cond implementation in the JAX project at
@@ -641,7 +624,7 @@ def _convert_predicate_to_bool(self, pred):
if isinstance(pred, jax.Array) and pred.shape not in ((), (1,)):
raise TypeError("Array with multiple elements is not a valid predicate")
- #breakpoint()
+ # breakpoint()
if not self._is_any_boolean(pred):
try:
pred = jnp.astype(pred, bool, copy=False)
@@ -650,7 +633,7 @@ def _convert_predicate_to_bool(self, pred):
"Conditional predicates are required to be of bool, integer or float type"
) from e
- #breakpoint()
+ # breakpoint()
return pred
def _is_any_boolean(self, pred):
@@ -770,6 +753,20 @@ def __call__(self):
return self._call_during_interpretation()
+class CondCallableSingleGateHandler(CondCallable):
+
+ def __init__(self, pred, true_fn):
+ self.pred = pred
+ self.true_fn = true_fn
+
+ def __call__(self, *my_args, **my_kwargs):
+ def new_true_fn():
+ self.true_fn(*my_args, **my_kwargs)
+
+ super().__init__(self.pred, new_true_fn)
+ return super().__call__()
+
+
class ForLoopCallable:
"""
Wrapping for_loop decorator into a class so that the actual "ForLoop" operation object, which
diff --git a/frontend/test/pytest/test_conditionals.py b/frontend/test/pytest/test_conditionals.py
index 4a0edbb624..ec43f26ffe 100644
--- a/frontend/test/pytest/test_conditionals.py
+++ b/frontend/test/pytest/test_conditionals.py
@@ -676,6 +676,20 @@ def branch_f():
assert func(True) == 1
assert func(False) == 0
+ def test_cond_single_gate(self, backend):
+ """Test standard pennylane qml.cond usage on single quantum gates."""
+ """Fixes https://github.com/PennyLaneAI/catalyst/issues/449"""
+
+ @qjit
+ @qml.qnode(qml.device(backend, wires=1))
+ def func(x, y):
+ qml.cond(x == 42, qml.Hadamard)(wires=0)
+ qml.cond(y == 37, qml.Hadamard)(wires=0)
+ return qml.probs()
+
+ assert np.allclose(func(42, 37), [1, 0])
+ assert np.allclose(func(0, 37), [0.5, 0.5])
+
class TestCondPredicateConversion:
"""Test suite for checking predicate conversion to bool."""
From dcb05ef19bb6d19b18e24301367b5547561bcbcc Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Fri, 25 Oct 2024 13:54:46 -0400
Subject: [PATCH 03/22] add checks for whether the callable is a pennylane gate
---
frontend/catalyst/api_extensions/control_flow.py | 16 ++++++----------
frontend/test/pytest/test_conditionals.py | 5 ++---
2 files changed, 8 insertions(+), 13 deletions(-)
diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py
index 3e0b0cbde0..9ffc2e5e5f 100644
--- a/frontend/catalyst/api_extensions/control_flow.py
+++ b/frontend/catalyst/api_extensions/control_flow.py
@@ -67,11 +67,6 @@
)
-def is_pennylane_gate(callable):
- # TODO
- return True
-
-
## API ##
def cond(pred: DynamicJaxprTracer):
"""A :func:`~.qjit` compatible decorator for if-else conditionals in PennyLane/Catalyst.
@@ -244,12 +239,13 @@ def conditional_fn():
def _decorator(true_fn: Callable):
- # if len(inspect.signature(true_fn).parameters):
- # raise TypeError("Conditional 'True' function is not allowed to have any arguments")
+ if len(inspect.signature(true_fn).parameters):
+ if isinstance(true_fn, type) and issubclass(true_fn, qml.operation.Operation):
+ return CondCallableSingleGateHandler(pred, true_fn)
+ else:
+ raise TypeError("Conditional 'True' function is not allowed to have any arguments")
- # return CondCallable(pred, true_fn)
- # if true_fn is a plain gate:
- return CondCallableSingleGateHandler(pred, true_fn)
+ return CondCallable(pred, true_fn)
return _decorator
diff --git a/frontend/test/pytest/test_conditionals.py b/frontend/test/pytest/test_conditionals.py
index ec43f26ffe..48492473ee 100644
--- a/frontend/test/pytest/test_conditionals.py
+++ b/frontend/test/pytest/test_conditionals.py
@@ -418,12 +418,11 @@ def conditional_flip():
def test_argument_error_with_callables(self):
"""Test for the error when arguments are supplied and the target is not a function."""
- @qml.qnode(qml.device("lightning.qubit", wires=1))
def f(x: int):
- qml.cond(x < 5, qml.Hadamard)(0)
+ res = qml.cond(x < 5, lambda z: z + 1)(0)
- return qml.probs()
+ return res
with pytest.raises(TypeError, match="not allowed to have any arguments"):
qjit(f)
From fb1d3a789c9b7c36bb88945139eb295d5201db6a Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Fri, 25 Oct 2024 13:56:25 -0400
Subject: [PATCH 04/22] remove debugs
---
frontend/catalyst/api_extensions/control_flow.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py
index 9ffc2e5e5f..b5b91cc5a6 100644
--- a/frontend/catalyst/api_extensions/control_flow.py
+++ b/frontend/catalyst/api_extensions/control_flow.py
@@ -620,7 +620,7 @@ def _convert_predicate_to_bool(self, pred):
if isinstance(pred, jax.Array) and pred.shape not in ((), (1,)):
raise TypeError("Array with multiple elements is not a valid predicate")
- # breakpoint()
+
if not self._is_any_boolean(pred):
try:
pred = jnp.astype(pred, bool, copy=False)
@@ -629,7 +629,6 @@ def _convert_predicate_to_bool(self, pred):
"Conditional predicates are required to be of bool, integer or float type"
) from e
- # breakpoint()
return pred
def _is_any_boolean(self, pred):
From 45a82400a82ae03521f0e29bdb14796920abe663 Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Fri, 25 Oct 2024 14:46:03 -0400
Subject: [PATCH 05/22] add more gates in test
---
frontend/test/pytest/test_conditionals.py | 25 ++++++++++++++++++-----
1 file changed, 20 insertions(+), 5 deletions(-)
diff --git a/frontend/test/pytest/test_conditionals.py b/frontend/test/pytest/test_conditionals.py
index 48492473ee..a7d4eb4727 100644
--- a/frontend/test/pytest/test_conditionals.py
+++ b/frontend/test/pytest/test_conditionals.py
@@ -679,15 +679,30 @@ def test_cond_single_gate(self, backend):
"""Test standard pennylane qml.cond usage on single quantum gates."""
"""Fixes https://github.com/PennyLaneAI/catalyst/issues/449"""
- @qjit
- @qml.qnode(qml.device(backend, wires=1))
+ @qml.qnode(qml.device(backend, wires=2))
def func(x, y):
qml.cond(x == 42, qml.Hadamard)(wires=0)
- qml.cond(y == 37, qml.Hadamard)(wires=0)
+ qml.cond(x == 42, qml.RY)(1.5, wires=0)
+ qml.cond(x == 42, qml.CNOT)(wires=[1, 0])
+ qml.cond(y == 37, qml.PauliX)(wires=1)
+ qml.cond(y == 37, qml.RZ)(5.1, wires=0)
+ qml.cond(y == 37, qml.Rot)(1.2, 3.4, 5.6, wires=1)
+
return qml.probs()
- assert np.allclose(func(42, 37), [1, 0])
- assert np.allclose(func(0, 37), [0.5, 0.5])
+ expected_0 = func(42, 37)
+ expected_1 = func(0, 37)
+ expected_2 = func(42, 0)
+
+ jitted_func = qjit(func)
+
+ observed_0 = jitted_func(42, 37)
+ observed_1 = jitted_func(0, 37)
+ observed_2 = jitted_func(42, 0)
+
+ assert np.allclose(expected_0, observed_0)
+ assert np.allclose(expected_1, observed_1)
+ assert np.allclose(expected_2, observed_2)
class TestCondPredicateConversion:
From dbfb6898feb4d341e2e28ee2a2621e18e5ae2946 Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Fri, 25 Oct 2024 14:47:44 -0400
Subject: [PATCH 06/22] codefactor pointless string
---
frontend/test/pytest/test_conditionals.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/frontend/test/pytest/test_conditionals.py b/frontend/test/pytest/test_conditionals.py
index a7d4eb4727..366d0d4fbf 100644
--- a/frontend/test/pytest/test_conditionals.py
+++ b/frontend/test/pytest/test_conditionals.py
@@ -676,8 +676,10 @@ def branch_f():
assert func(False) == 0
def test_cond_single_gate(self, backend):
- """Test standard pennylane qml.cond usage on single quantum gates."""
- """Fixes https://github.com/PennyLaneAI/catalyst/issues/449"""
+ """
+ Test standard pennylane qml.cond usage on single quantum gates.
+ Fixes https://github.com/PennyLaneAI/catalyst/issues/449
+ """
@qml.qnode(qml.device(backend, wires=2))
def func(x, y):
From 6e5dba116de767f50ebb5e73eb7a6c085e337cbd Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Fri, 25 Oct 2024 14:56:29 -0400
Subject: [PATCH 07/22] code factor new class docstring
---
.../catalyst/api_extensions/control_flow.py | 17 +++++++++++++++++
1 file changed, 17 insertions(+)
diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py
index b5b91cc5a6..90395ab3ac 100644
--- a/frontend/catalyst/api_extensions/control_flow.py
+++ b/frontend/catalyst/api_extensions/control_flow.py
@@ -241,6 +241,10 @@ def _decorator(true_fn: Callable):
if len(inspect.signature(true_fn).parameters):
if isinstance(true_fn, type) and issubclass(true_fn, qml.operation.Operation):
+ # Special treatment if conditional function body is a single pennylane gate
+ # The qml.operation.Operation base class represents things that
+ # can reasonably be considered as a gate,
+ # e.g. qml.Hadamard, qml.RX, etc.
return CondCallableSingleGateHandler(pred, true_fn)
else:
raise TypeError("Conditional 'True' function is not allowed to have any arguments")
@@ -749,6 +753,19 @@ def __call__(self):
class CondCallableSingleGateHandler(CondCallable):
+ """
+ Special CondCallable when the conditional body function is a single pennylane gate.
+
+ A usual pennylane conditional call for a gate looks like
+ `qml.cond(x == 42, qml.RX)(theta, wires=0)`
+
+ Since gates are guaranteed to take in arguments (at the very least the wire argument),
+ the usual CondCallable class, which expects the conditional body function to have no arguments,
+ cannot be used.
+ This class inherits from base CondCallable, but wraps the gate in a function with no arguments,
+ and send that function to CondCallable.
+ This allows us to perform the conditional branch gate function with arguments.
+ """
def __init__(self, pred, true_fn):
self.pred = pred
From 61611367d9320be977be7ed889c7b764f33d18ae Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Fri, 25 Oct 2024 14:58:29 -0400
Subject: [PATCH 08/22] codefactor super init not called
---
frontend/catalyst/api_extensions/control_flow.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py
index 90395ab3ac..3fde7ba960 100644
--- a/frontend/catalyst/api_extensions/control_flow.py
+++ b/frontend/catalyst/api_extensions/control_flow.py
@@ -767,7 +767,7 @@ class CondCallableSingleGateHandler(CondCallable):
This allows us to perform the conditional branch gate function with arguments.
"""
- def __init__(self, pred, true_fn):
+ def __init__(self, pred, true_fn): # pylint:disable=super-init-not-called
self.pred = pred
self.true_fn = true_fn
From 36279f80b919d6fab05c190d26ccc911129b8962 Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Fri, 25 Oct 2024 15:05:28 -0400
Subject: [PATCH 09/22] changelog
---
doc/releases/changelog-0.9.0.md | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/doc/releases/changelog-0.9.0.md b/doc/releases/changelog-0.9.0.md
index da459438dc..9030125816 100644
--- a/doc/releases/changelog-0.9.0.md
+++ b/doc/releases/changelog-0.9.0.md
@@ -363,6 +363,10 @@
lowering of the scatter operation.
[(#1214)](https://github.com/PennyLaneAI/catalyst/pull/1214)
+* Fixes a bug where conditional-ed single gates cannot be used in qjit,
+ i.e. `qml.cond(x > 1, qml.Hadamard)(wires=0)`.
+ [(#1213)](https://github.com/PennyLaneAI/catalyst/pull/1232)
+
Internal changes
* Remove deprecated pennylane code across the frontend.
From fd291db3cf0f311cb82e6ff205241b0f8383b04c Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Fri, 25 Oct 2024 15:06:30 -0400
Subject: [PATCH 10/22] changelog typo
---
doc/releases/changelog-0.9.0.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/doc/releases/changelog-0.9.0.md b/doc/releases/changelog-0.9.0.md
index 9030125816..b8e62e629b 100644
--- a/doc/releases/changelog-0.9.0.md
+++ b/doc/releases/changelog-0.9.0.md
@@ -365,7 +365,7 @@
* Fixes a bug where conditional-ed single gates cannot be used in qjit,
i.e. `qml.cond(x > 1, qml.Hadamard)(wires=0)`.
- [(#1213)](https://github.com/PennyLaneAI/catalyst/pull/1232)
+ [(#1232)](https://github.com/PennyLaneAI/catalyst/pull/1232)
Internal changes
From a3688e194a972634165167ba009925a38a525498 Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Mon, 28 Oct 2024 09:38:10 -0400
Subject: [PATCH 11/22] fixing my latin
---
doc/releases/changelog-0.9.0.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/doc/releases/changelog-0.9.0.md b/doc/releases/changelog-0.9.0.md
index b8e62e629b..5ff1a0286d 100644
--- a/doc/releases/changelog-0.9.0.md
+++ b/doc/releases/changelog-0.9.0.md
@@ -364,7 +364,7 @@
[(#1214)](https://github.com/PennyLaneAI/catalyst/pull/1214)
* Fixes a bug where conditional-ed single gates cannot be used in qjit,
- i.e. `qml.cond(x > 1, qml.Hadamard)(wires=0)`.
+ e.g. `qml.cond(x > 1, qml.Hadamard)(wires=0)`.
[(#1232)](https://github.com/PennyLaneAI/catalyst/pull/1232)
Internal changes
From 9a78577da9754e1476544c2cc004c44a174539b9 Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Mon, 28 Oct 2024 09:38:59 -0400
Subject: [PATCH 12/22] fixing my grammar
---
frontend/catalyst/api_extensions/control_flow.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py
index 3fde7ba960..d61ec22d6e 100644
--- a/frontend/catalyst/api_extensions/control_flow.py
+++ b/frontend/catalyst/api_extensions/control_flow.py
@@ -763,7 +763,7 @@ class CondCallableSingleGateHandler(CondCallable):
the usual CondCallable class, which expects the conditional body function to have no arguments,
cannot be used.
This class inherits from base CondCallable, but wraps the gate in a function with no arguments,
- and send that function to CondCallable.
+ and sends that function to CondCallable.
This allows us to perform the conditional branch gate function with arguments.
"""
From cef02969a334d40514231303a53544722a9fe6f1 Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Mon, 28 Oct 2024 11:46:45 -0400
Subject: [PATCH 13/22] add otherwise
---
.../catalyst/api_extensions/control_flow.py | 24 +++++++++++++++----
frontend/test/pytest/test_conditionals.py | 4 ++--
2 files changed, 22 insertions(+), 6 deletions(-)
diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py
index d61ec22d6e..535d1ca6f1 100644
--- a/frontend/catalyst/api_extensions/control_flow.py
+++ b/frontend/catalyst/api_extensions/control_flow.py
@@ -567,6 +567,9 @@ def __init__(self, pred, true_fn):
self._operation = None
self.expansion_strategy = cond_expansion_strategy()
+ def set_otherwise_fn(self, otherwise_fn):
+ self.otherwise_fn = otherwise_fn
+
@property
def operation(self):
"""
@@ -768,16 +771,29 @@ class CondCallableSingleGateHandler(CondCallable):
"""
def __init__(self, pred, true_fn): # pylint:disable=super-init-not-called
- self.pred = pred
- self.true_fn = true_fn
+ self._pred = pred
+ self._true_fn = true_fn
+ self._otherwise_fn = None
def __call__(self, *my_args, **my_kwargs):
def new_true_fn():
- self.true_fn(*my_args, **my_kwargs)
+ self._true_fn(*my_args, **my_kwargs)
+
+ super().__init__(self._pred, new_true_fn)
+
+ if self._otherwise_fn is not None:
+
+ def new_otherwise_fn():
+ self._otherwise_fn(*my_args, **my_kwargs)
+
+ super().set_otherwise_fn(new_otherwise_fn)
- super().__init__(self.pred, new_true_fn)
return super().__call__()
+ def otherwise(self, otherwise_fn):
+ # Override the "can't have arguments" check in the original CondCallable's `otherwise`
+ self._otherwise_fn = otherwise_fn
+
class ForLoopCallable:
"""
diff --git a/frontend/test/pytest/test_conditionals.py b/frontend/test/pytest/test_conditionals.py
index 366d0d4fbf..c187ce3db0 100644
--- a/frontend/test/pytest/test_conditionals.py
+++ b/frontend/test/pytest/test_conditionals.py
@@ -683,8 +683,8 @@ def test_cond_single_gate(self, backend):
@qml.qnode(qml.device(backend, wires=2))
def func(x, y):
- qml.cond(x == 42, qml.Hadamard)(wires=0)
- qml.cond(x == 42, qml.RY)(1.5, wires=0)
+ qml.cond(x == 42, qml.Hadamard, qml.PauliX)(wires=0)
+ qml.cond(x == 42, qml.RY, qml.RZ)(1.5, wires=0)
qml.cond(x == 42, qml.CNOT)(wires=[1, 0])
qml.cond(y == 37, qml.PauliX)(wires=1)
qml.cond(y == 37, qml.RZ)(5.1, wires=0)
From 25ab7287d80ee5f5e344e2950d3ba47165fdcc8c Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Mon, 28 Oct 2024 11:50:11 -0400
Subject: [PATCH 14/22] rename my_args -> args
---
frontend/catalyst/api_extensions/control_flow.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py
index 535d1ca6f1..4248e1a175 100644
--- a/frontend/catalyst/api_extensions/control_flow.py
+++ b/frontend/catalyst/api_extensions/control_flow.py
@@ -775,16 +775,16 @@ def __init__(self, pred, true_fn): # pylint:disable=super-init-not-called
self._true_fn = true_fn
self._otherwise_fn = None
- def __call__(self, *my_args, **my_kwargs):
+ def __call__(self, *args, **kwargs):
def new_true_fn():
- self._true_fn(*my_args, **my_kwargs)
+ self._true_fn(*args, **kwargs)
super().__init__(self._pred, new_true_fn)
if self._otherwise_fn is not None:
def new_otherwise_fn():
- self._otherwise_fn(*my_args, **my_kwargs)
+ self._otherwise_fn(*args, **kwargs)
super().set_otherwise_fn(new_otherwise_fn)
From 22e7c52b186612a5ee42b1c4056449c1bb6fa8de Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Mon, 28 Oct 2024 11:52:19 -0400
Subject: [PATCH 15/22] more renames
---
frontend/catalyst/api_extensions/control_flow.py | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py
index 4248e1a175..d8fbfac1b7 100644
--- a/frontend/catalyst/api_extensions/control_flow.py
+++ b/frontend/catalyst/api_extensions/control_flow.py
@@ -771,20 +771,20 @@ class CondCallableSingleGateHandler(CondCallable):
"""
def __init__(self, pred, true_fn): # pylint:disable=super-init-not-called
- self._pred = pred
- self._true_fn = true_fn
- self._otherwise_fn = None
+ self.sgh_pred = pred
+ self.sgh_true_fn = true_fn
+ self.sgh_otherwise_fn = None
def __call__(self, *args, **kwargs):
def new_true_fn():
- self._true_fn(*args, **kwargs)
+ self.sgh_true_fn(*args, **kwargs)
- super().__init__(self._pred, new_true_fn)
+ super().__init__(self.sgh_pred, new_true_fn)
- if self._otherwise_fn is not None:
+ if self.sgh_otherwise_fn is not None:
def new_otherwise_fn():
- self._otherwise_fn(*args, **kwargs)
+ self.sgh_otherwise_fn(*args, **kwargs)
super().set_otherwise_fn(new_otherwise_fn)
@@ -792,7 +792,7 @@ def new_otherwise_fn():
def otherwise(self, otherwise_fn):
# Override the "can't have arguments" check in the original CondCallable's `otherwise`
- self._otherwise_fn = otherwise_fn
+ self.sgh_otherwise_fn = otherwise_fn
class ForLoopCallable:
From 5fae6b3f0534ba4c734264a2108953c62121ce05 Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Mon, 28 Oct 2024 11:53:32 -0400
Subject: [PATCH 16/22] more renames
---
frontend/catalyst/api_extensions/control_flow.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py
index d8fbfac1b7..f77e408fc1 100644
--- a/frontend/catalyst/api_extensions/control_flow.py
+++ b/frontend/catalyst/api_extensions/control_flow.py
@@ -776,17 +776,17 @@ def __init__(self, pred, true_fn): # pylint:disable=super-init-not-called
self.sgh_otherwise_fn = None
def __call__(self, *args, **kwargs):
- def new_true_fn():
+ def argless_true_fn():
self.sgh_true_fn(*args, **kwargs)
- super().__init__(self.sgh_pred, new_true_fn)
+ super().__init__(self.sgh_pred, argless_true_fn)
if self.sgh_otherwise_fn is not None:
- def new_otherwise_fn():
+ def argless_otherwise_fn():
self.sgh_otherwise_fn(*args, **kwargs)
- super().set_otherwise_fn(new_otherwise_fn)
+ super().set_otherwise_fn(argless_otherwise_fn)
return super().__call__()
From 90edf445107e80d6c1c9cd7d8cfcc968e964e2dc Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Mon, 28 Oct 2024 12:01:28 -0400
Subject: [PATCH 17/22] add gate checks for otherwise func
---
frontend/catalyst/api_extensions/control_flow.py | 7 ++++++-
frontend/test/pytest/test_conditionals.py | 12 ++++++++++++
2 files changed, 18 insertions(+), 1 deletion(-)
diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py
index f77e408fc1..816380bb3a 100644
--- a/frontend/catalyst/api_extensions/control_flow.py
+++ b/frontend/catalyst/api_extensions/control_flow.py
@@ -792,7 +792,12 @@ def argless_otherwise_fn():
def otherwise(self, otherwise_fn):
# Override the "can't have arguments" check in the original CondCallable's `otherwise`
- self.sgh_otherwise_fn = otherwise_fn
+ if isinstance(otherwise_fn, type) and issubclass(otherwise_fn, qml.operation.Operation):
+ self.sgh_otherwise_fn = otherwise_fn
+ else:
+ raise TypeError(
+ "Conditional 'False' function is allowed to have arguments only if it is a PennyLane gate."
+ )
class ForLoopCallable:
diff --git a/frontend/test/pytest/test_conditionals.py b/frontend/test/pytest/test_conditionals.py
index c187ce3db0..0c36e3df16 100644
--- a/frontend/test/pytest/test_conditionals.py
+++ b/frontend/test/pytest/test_conditionals.py
@@ -427,6 +427,18 @@ def f(x: int):
with pytest.raises(TypeError, match="not allowed to have any arguments"):
qjit(f)
+ def f(x: int):
+
+ res = qml.cond(x < 5, qml.Hadamard, lambda z: z + 1)(0)
+
+ return res
+
+ with pytest.raises(
+ TypeError,
+ match="Conditional 'False' function is allowed to have arguments only if it is a PennyLane gate.",
+ ):
+ qjit(f)
+
class TestInterpretationConditional:
"""Test that the conditional operation's execution is semantically equivalent
From 070885bcd95076aaf23dc89f4a9ecb5e627f42cc Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Mon, 28 Oct 2024 13:11:24 -0400
Subject: [PATCH 18/22] codefactor
---
frontend/catalyst/api_extensions/control_flow.py | 4 ++--
frontend/test/pytest/test_conditionals.py | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py
index 816380bb3a..67ab1ac57b 100644
--- a/frontend/catalyst/api_extensions/control_flow.py
+++ b/frontend/catalyst/api_extensions/control_flow.py
@@ -567,7 +567,7 @@ def __init__(self, pred, true_fn):
self._operation = None
self.expansion_strategy = cond_expansion_strategy()
- def set_otherwise_fn(self, otherwise_fn):
+ def set_otherwise_fn(self, otherwise_fn): # pylint:disable=missing-function-docstring
self.otherwise_fn = otherwise_fn
@property
@@ -796,7 +796,7 @@ def otherwise(self, otherwise_fn):
self.sgh_otherwise_fn = otherwise_fn
else:
raise TypeError(
- "Conditional 'False' function is allowed to have arguments only if it is a PennyLane gate."
+ "Conditional 'False' function can have arguments only if it is a PennyLane gate."
)
diff --git a/frontend/test/pytest/test_conditionals.py b/frontend/test/pytest/test_conditionals.py
index 0c36e3df16..7bc04e0514 100644
--- a/frontend/test/pytest/test_conditionals.py
+++ b/frontend/test/pytest/test_conditionals.py
@@ -435,7 +435,7 @@ def f(x: int):
with pytest.raises(
TypeError,
- match="Conditional 'False' function is allowed to have arguments only if it is a PennyLane gate.",
+ match="Conditional 'False' function can have arguments only if it is a PennyLane gate.",
):
qjit(f)
From 162422197de6eccd81906e1d2aac0b552d841096 Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Mon, 28 Oct 2024 13:28:24 -0400
Subject: [PATCH 19/22] add lit test
---
frontend/test/lit/test_if_else.py | 37 +++++++++++++++++++++++++++++++
1 file changed, 37 insertions(+)
diff --git a/frontend/test/lit/test_if_else.py b/frontend/test/lit/test_if_else.py
index 5b37804b7f..53e472e083 100644
--- a/frontend/test/lit/test_if_else.py
+++ b/frontend/test/lit/test_if_else.py
@@ -58,6 +58,43 @@ def otherwise():
# -----
+# CHECK-LABEL: public @jit_circuit_single_gate
+@qjit(target="mlir")
+@qml.qnode(qml.device("lightning.qubit", wires=1))
+def circuit_single_gate(n: int):
+ # CHECK-DAG: [[c5:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<5> : tensor
+ # CHECK: [[b_t:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c5]], SIGNED : (tensor, tensor) -> tensor
+ # CHECK-DAG: [[qreg_0:%[a-zA-Z0-9_]+]] = quantum.alloc
+ # CHECK: [[b:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t]]
+
+ # CHECK: [[qreg_out:%.+]] = scf.if [[b]]
+ # CHECK-DAG: [[q0:%[a-zA-Z0-9_]+]] = quantum.extract
+ # CHECK-DAG: [[q1:%[a-zA-Z0-9_]+]] = quantum.custom "PauliX"() [[q0]]
+ # pylint: disable=line-too-long
+ # CHECK-DAG: [[qreg_1:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_0]][ {{[%a-zA-Z0-9_]+}}], [[q1]]
+ # CHECK: scf.yield [[qreg_1]]
+
+ # CHECK: else
+ # CHECK-DAG: [[q2:%[a-zA-Z0-9_]+]] = quantum.extract
+ # CHECK-DAG: [[q3:%[a-zA-Z0-9_]+]] = quantum.custom "Hadamard"() [[q2]]
+ # pylint: disable=line-too-long
+ # CHECK-DAG: [[qreg_2:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_0]][ {{[%a-zA-Z0-9_]+}}], [[q3]]
+ # CHECK: scf.yield [[qreg_2]]
+ qml.cond(n <= 5, qml.PauliX, qml.Hadamard)(wires=0)
+
+ # CHECK: [[qreg_3:%.+]] = quantum.extract [[qreg_out]][ 0]
+ # CHECK: [[qobs:%.+]] = quantum.compbasis [[qreg_3]] : !quantum.obs
+ # CHECK: [[ret:%.+]] = quantum.probs [[qobs]]
+ # CHECK: return [[ret]]
+ return qml.probs()
+
+
+print(circuit_single_gate.mlir)
+
+
+# -----
+
+
# CHECK-LABEL: test_convert_element_type
@qjit
def test_convert_element_type(i: int, f: float):
From ce24fcf4457ebe679a35b624ef696bd8519ebd07 Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Mon, 28 Oct 2024 13:41:24 -0400
Subject: [PATCH 20/22] more lit tests
---
frontend/test/lit/test_if_else.py | 26 ++++++++++++++++++++------
1 file changed, 20 insertions(+), 6 deletions(-)
diff --git a/frontend/test/lit/test_if_else.py b/frontend/test/lit/test_if_else.py
index 53e472e083..9097a2af5e 100644
--- a/frontend/test/lit/test_if_else.py
+++ b/frontend/test/lit/test_if_else.py
@@ -63,26 +63,40 @@ def otherwise():
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit_single_gate(n: int):
# CHECK-DAG: [[c5:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<5> : tensor
- # CHECK: [[b_t:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c5]], SIGNED : (tensor, tensor) -> tensor
+ # CHECK-DAG: [[c6:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<6> : tensor
+ # CHECK-DAG: [[b_t5:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c5]], SIGNED : (tensor, tensor) -> tensor
+ # CHECK-DAG: [[b_t6:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c6]], SIGNED : (tensor, tensor) -> tensor
# CHECK-DAG: [[qreg_0:%[a-zA-Z0-9_]+]] = quantum.alloc
- # CHECK: [[b:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t]]
+ # CHECK: [[b5:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t5]]
- # CHECK: [[qreg_out:%.+]] = scf.if [[b]]
- # CHECK-DAG: [[q0:%[a-zA-Z0-9_]+]] = quantum.extract
+ # CHECK: [[qreg_out:%.+]] = scf.if [[b5]]
+ # CHECK-DAG: [[q0:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_0]]
# CHECK-DAG: [[q1:%[a-zA-Z0-9_]+]] = quantum.custom "PauliX"() [[q0]]
# pylint: disable=line-too-long
# CHECK-DAG: [[qreg_1:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_0]][ {{[%a-zA-Z0-9_]+}}], [[q1]]
# CHECK: scf.yield [[qreg_1]]
# CHECK: else
- # CHECK-DAG: [[q2:%[a-zA-Z0-9_]+]] = quantum.extract
+ # CHECK-DAG: [[q2:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_0]]
# CHECK-DAG: [[q3:%[a-zA-Z0-9_]+]] = quantum.custom "Hadamard"() [[q2]]
# pylint: disable=line-too-long
# CHECK-DAG: [[qreg_2:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_0]][ {{[%a-zA-Z0-9_]+}}], [[q3]]
# CHECK: scf.yield [[qreg_2]]
qml.cond(n <= 5, qml.PauliX, qml.Hadamard)(wires=0)
- # CHECK: [[qreg_3:%.+]] = quantum.extract [[qreg_out]][ 0]
+ # CHECK: [[b6:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t6]]
+ # CHECK: [[qreg_out1:%.+]] = scf.if [[b6]]
+ # CHECK-DAG: [[q4:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out]]
+ # CHECK-DAG: [[q5:%[a-zA-Z0-9_]+]] = quantum.custom "RX"({{%.+}}) [[q4]]
+ # pylint: disable=line-too-long
+ # CHECK-DAG: [[qreg_3:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out]][ {{[%a-zA-Z0-9_]+}}], [[q1]]
+ # CHECK: scf.yield [[qreg_3]]
+ # CHECK: else
+ # CHECK: scf.yield [[qreg_out]]
+
+ qml.cond(n <= 6, qml.RX)(3.14, wires=0)
+
+ # CHECK: [[qreg_3:%.+]] = quantum.extract [[qreg_out1]][ 0]
# CHECK: [[qobs:%.+]] = quantum.compbasis [[qreg_3]] : !quantum.obs
# CHECK: [[ret:%.+]] = quantum.probs [[qobs]]
# CHECK: return [[ret]]
From 3415dfc20d2c63a93d96571b3c902a11baf2fb55 Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Mon, 28 Oct 2024 15:20:03 -0400
Subject: [PATCH 21/22] elseif
---
.../catalyst/api_extensions/control_flow.py | 43 ++++++++++++++++---
frontend/test/pytest/test_conditionals.py | 25 ++++++++++-
2 files changed, 62 insertions(+), 6 deletions(-)
diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py
index 67ab1ac57b..1346761528 100644
--- a/frontend/catalyst/api_extensions/control_flow.py
+++ b/frontend/catalyst/api_extensions/control_flow.py
@@ -570,6 +570,12 @@ def __init__(self, pred, true_fn):
def set_otherwise_fn(self, otherwise_fn): # pylint:disable=missing-function-docstring
self.otherwise_fn = otherwise_fn
+ def add_pred(self, _pred):
+ self.preds.append(self._convert_predicate_to_bool(_pred))
+
+ def add_branch_fn(self, _branch_fn):
+ self.branch_fns.append(_branch_fn)
+
@property
def operation(self):
"""
@@ -771,15 +777,15 @@ class CondCallableSingleGateHandler(CondCallable):
"""
def __init__(self, pred, true_fn): # pylint:disable=super-init-not-called
- self.sgh_pred = pred
- self.sgh_true_fn = true_fn
+ self.sgh_preds = [pred]
+ self.sgh_branch_fns = [true_fn]
self.sgh_otherwise_fn = None
def __call__(self, *args, **kwargs):
def argless_true_fn():
- self.sgh_true_fn(*args, **kwargs)
+ self.sgh_branch_fns[0](*args, **kwargs)
- super().__init__(self.sgh_pred, argless_true_fn)
+ super().__init__(self.sgh_preds[0], argless_true_fn)
if self.sgh_otherwise_fn is not None:
@@ -788,10 +794,37 @@ def argless_otherwise_fn():
super().set_otherwise_fn(argless_otherwise_fn)
+ for i in range(1, len(self.sgh_branch_fns)):
+
+ def argless_elseif_fn(i=i): # i=i to work around late binding
+ self.sgh_branch_fns[i](*args, **kwargs)
+
+ super().add_pred(self.sgh_preds[i])
+ super().add_branch_fn(argless_elseif_fn)
+
return super().__call__()
+ def else_if(self, _pred):
+ """
+ Override the "can't have arguments" check in the original CondCallable's `else_if`
+ """
+
+ def decorator(branch_fn):
+ if isinstance(branch_fn, type) and issubclass(branch_fn, qml.operation.Operation):
+ self.sgh_preds.append(_pred)
+ self.sgh_branch_fns.append(branch_fn)
+ return self
+ else:
+ raise TypeError(
+ "Conditional 'else if' function can have arguments only if it is a PennyLane gate."
+ )
+
+ return decorator
+
def otherwise(self, otherwise_fn):
- # Override the "can't have arguments" check in the original CondCallable's `otherwise`
+ """
+ Override the "can't have arguments" check in the original CondCallable's `otherwise`
+ """
if isinstance(otherwise_fn, type) and issubclass(otherwise_fn, qml.operation.Operation):
self.sgh_otherwise_fn = otherwise_fn
else:
diff --git a/frontend/test/pytest/test_conditionals.py b/frontend/test/pytest/test_conditionals.py
index 7bc04e0514..19b5a0eb22 100644
--- a/frontend/test/pytest/test_conditionals.py
+++ b/frontend/test/pytest/test_conditionals.py
@@ -439,6 +439,18 @@ def f(x: int):
):
qjit(f)
+ def f(x: int):
+
+ res = qml.cond(x < 5, qml.Hadamard, qml.Hadamard, ((x < 6, lambda z: z + 1),))(0)
+
+ return res
+
+ with pytest.raises(
+ TypeError,
+ match="Conditional 'else if' function can have arguments only if it is a PennyLane gate.",
+ ):
+ qjit(f)
+
class TestInterpretationConditional:
"""Test that the conditional operation's execution is semantically equivalent
@@ -699,7 +711,15 @@ def func(x, y):
qml.cond(x == 42, qml.RY, qml.RZ)(1.5, wires=0)
qml.cond(x == 42, qml.CNOT)(wires=[1, 0])
qml.cond(y == 37, qml.PauliX)(wires=1)
- qml.cond(y == 37, qml.RZ)(5.1, wires=0)
+ qml.cond(
+ y == 36,
+ qml.RZ,
+ qml.RY,
+ (
+ (x == 42, qml.RX),
+ (x == 41, qml.RZ),
+ ),
+ )(5.1, wires=0)
qml.cond(y == 37, qml.Rot)(1.2, 3.4, 5.6, wires=1)
return qml.probs()
@@ -707,16 +727,19 @@ def func(x, y):
expected_0 = func(42, 37)
expected_1 = func(0, 37)
expected_2 = func(42, 0)
+ expected_3 = func(41, 0)
jitted_func = qjit(func)
observed_0 = jitted_func(42, 37)
observed_1 = jitted_func(0, 37)
observed_2 = jitted_func(42, 0)
+ observed_3 = jitted_func(41, 0)
assert np.allclose(expected_0, observed_0)
assert np.allclose(expected_1, observed_1)
assert np.allclose(expected_2, observed_2)
+ assert np.allclose(expected_3, observed_3)
class TestCondPredicateConversion:
From e339f972d4c0a73f66e49dc9975dec15dd5033d3 Mon Sep 17 00:00:00 2001
From: Haochen Wang
Date: Mon, 28 Oct 2024 16:38:00 -0400
Subject: [PATCH 22/22] elseif lit test
---
frontend/test/lit/test_if_else.py | 56 +++++++++++++++++++++++++++----
1 file changed, 49 insertions(+), 7 deletions(-)
diff --git a/frontend/test/lit/test_if_else.py b/frontend/test/lit/test_if_else.py
index 9097a2af5e..25bd3ebbad 100644
--- a/frontend/test/lit/test_if_else.py
+++ b/frontend/test/lit/test_if_else.py
@@ -62,24 +62,29 @@ def otherwise():
@qjit(target="mlir")
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit_single_gate(n: int):
+ # pylint: disable=line-too-long
# CHECK-DAG: [[c5:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<5> : tensor
# CHECK-DAG: [[c6:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<6> : tensor
+ # CHECK-DAG: [[c7:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<7> : tensor
+ # CHECK-DAG: [[c8:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<8> : tensor
+ # CHECK-DAG: [[c9:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<9> : tensor
# CHECK-DAG: [[b_t5:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c5]], SIGNED : (tensor, tensor) -> tensor
# CHECK-DAG: [[b_t6:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c6]], SIGNED : (tensor, tensor) -> tensor
+ # CHECK-DAG: [[b_t7:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c7]], SIGNED : (tensor, tensor) -> tensor
+ # CHECK-DAG: [[b_t8:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c8]], SIGNED : (tensor, tensor) -> tensor
+ # CHECK-DAG: [[b_t9:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c9]], SIGNED : (tensor, tensor) -> tensor
# CHECK-DAG: [[qreg_0:%[a-zA-Z0-9_]+]] = quantum.alloc
- # CHECK: [[b5:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t5]]
+ # CHECK: [[b5:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t5]]
# CHECK: [[qreg_out:%.+]] = scf.if [[b5]]
# CHECK-DAG: [[q0:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_0]]
# CHECK-DAG: [[q1:%[a-zA-Z0-9_]+]] = quantum.custom "PauliX"() [[q0]]
- # pylint: disable=line-too-long
# CHECK-DAG: [[qreg_1:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_0]][ {{[%a-zA-Z0-9_]+}}], [[q1]]
# CHECK: scf.yield [[qreg_1]]
# CHECK: else
# CHECK-DAG: [[q2:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_0]]
# CHECK-DAG: [[q3:%[a-zA-Z0-9_]+]] = quantum.custom "Hadamard"() [[q2]]
- # pylint: disable=line-too-long
# CHECK-DAG: [[qreg_2:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_0]][ {{[%a-zA-Z0-9_]+}}], [[q3]]
# CHECK: scf.yield [[qreg_2]]
qml.cond(n <= 5, qml.PauliX, qml.Hadamard)(wires=0)
@@ -88,16 +93,53 @@ def circuit_single_gate(n: int):
# CHECK: [[qreg_out1:%.+]] = scf.if [[b6]]
# CHECK-DAG: [[q4:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out]]
# CHECK-DAG: [[q5:%[a-zA-Z0-9_]+]] = quantum.custom "RX"({{%.+}}) [[q4]]
- # pylint: disable=line-too-long
- # CHECK-DAG: [[qreg_3:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out]][ {{[%a-zA-Z0-9_]+}}], [[q1]]
+ # CHECK-DAG: [[qreg_3:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out]][ {{[%a-zA-Z0-9_]+}}], [[q5]]
# CHECK: scf.yield [[qreg_3]]
# CHECK: else
# CHECK: scf.yield [[qreg_out]]
qml.cond(n <= 6, qml.RX)(3.14, wires=0)
- # CHECK: [[qreg_3:%.+]] = quantum.extract [[qreg_out1]][ 0]
- # CHECK: [[qobs:%.+]] = quantum.compbasis [[qreg_3]] : !quantum.obs
+ # CHECK: [[b7:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t7]]
+ # CHECK: [[qreg_out2:%.+]] = scf.if [[b7]]
+ # CHECK-DAG: [[q7:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out1]]
+ # CHECK-DAG: [[q8:%[a-zA-Z0-9_]+]] = quantum.custom "Hadamard"() [[q7]]
+ # pylint: disable=line-too-long
+ # CHECK-DAG: [[qreg_4:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out1]][ {{[%a-zA-Z0-9_]+}}], [[q8]]
+ # CHECK: scf.yield [[qreg_4]]
+ # CHECK: else {
+ # CHECK: [[b8:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t8]]
+ # CHECK: [[qreg_out3:%.+]] = scf.if [[b8]]
+ # CHECK-DAG: [[q9:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out1]]
+ # CHECK-DAG: [[q10:%[a-zA-Z0-9_]+]] = quantum.custom "PauliY"() [[q9]]
+ # CHECK-DAG: [[qreg_5:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out1]][ {{[%a-zA-Z0-9_]+}}], [[q10]]
+ # CHECK: scf.yield [[qreg_5]]
+ # CHECK: else {
+ # CHECK: [[b9:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t9]]
+ # CHECK: [[qreg_out4:%.+]] = scf.if [[b9]]
+ # CHECK-DAG: [[q11:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out1]]
+ # CHECK-DAG: [[q12:%[a-zA-Z0-9_]+]] = quantum.custom "PauliZ"() [[q11]]
+ # CHECK-DAG: [[qreg_6:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out1]][ {{[%a-zA-Z0-9_]+}}], [[q12]]
+ # CHECK: scf.yield [[qreg_6]]
+ # CHECK: else {
+ # CHECK-DAG: [[q13:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out1]]
+ # CHECK-DAG: [[q14:%[a-zA-Z0-9_]+]] = quantum.custom "PauliX"() [[q13]]
+ # CHECK-DAG: [[qreg_7:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out1]][ {{[%a-zA-Z0-9_]+}}], [[q14]]
+ # CHECK: scf.yield [[qreg_7]]
+ # CHECK: scf.yield [[qreg_out4]]
+ # CHECK: scf.yield [[qreg_out3]]
+ qml.cond(
+ n <= 7,
+ qml.Hadamard,
+ qml.PauliX,
+ (
+ (n <= 8, qml.PauliY),
+ (n <= 9, qml.PauliZ),
+ ),
+ )(wires=0)
+
+ # CHECK: [[qreg_ret:%.+]] = quantum.extract [[qreg_out2]][ 0]
+ # CHECK: [[qobs:%.+]] = quantum.compbasis [[qreg_ret]] : !quantum.obs
# CHECK: [[ret:%.+]] = quantum.probs [[qobs]]
# CHECK: return [[ret]]
return qml.probs()