From 7050c1fc1cc671fc028648af1c9b8812427761e7 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Fri, 25 Oct 2024 12:01:58 -0400 Subject: [PATCH 01/22] save --- .../catalyst/api_extensions/control_flow.py | 36 ++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py index b590c8c0e2..3e1ce83076 100644 --- a/frontend/catalyst/api_extensions/control_flow.py +++ b/frontend/catalyst/api_extensions/control_flow.py @@ -65,7 +65,7 @@ JaxTracingContext, ) - +import pennylane as qml ## API ## def cond(pred: DynamicJaxprTracer): """A :func:`~.qjit` compatible decorator for if-else conditionals in PennyLane/Catalyst. @@ -237,9 +237,23 @@ def conditional_fn(): """ def _decorator(true_fn: Callable): + #breakpoint() + if len(inspect.signature(true_fn).parameters): - raise TypeError("Conditional 'True' function is not allowed to have any arguments") - return CondCallable(pred, true_fn) + #breakpoint() + current_frame = inspect.currentframe() + call_frame = current_frame.f_back + args_info = inspect.getargvalues(call_frame) + #breakpoint() + + #def _true_fn(): + # true_fn(wires=0) + #return CondCallable(pred, _true_fn) + + #raise TypeError("Conditional 'True' function is not allowed to have any arguments") + + #return CondCallable(pred, true_fn) + return CondCallableSingleGateHandler(pred, true_fn) return _decorator @@ -498,6 +512,19 @@ def _decorator(body_fn): ## IMPL ## +class CondCallableSingleGateHandler: + def __init__(self, pred, true_fn): + self.pred = pred + self.true_fn = true_fn + + + def __call__(self, *my_args, **my_kwargs): + def new_true_fn(): + self.true_fn(*my_args, **my_kwargs) + #breakpoint() + return CondCallable(self.pred, new_true_fn) + + class CondCallable: """User-facing wrapper provoding "else_if" and "otherwise" public methods. Some code in this class has been adapted from the cond implementation in the JAX project at @@ -614,7 +641,7 @@ def _convert_predicate_to_bool(self, pred): if isinstance(pred, jax.Array) and pred.shape not in ((), (1,)): raise TypeError("Array with multiple elements is not a valid predicate") - + #breakpoint() if not self._is_any_boolean(pred): try: pred = jnp.astype(pred, bool, copy=False) @@ -623,6 +650,7 @@ def _convert_predicate_to_bool(self, pred): "Conditional predicates are required to be of bool, integer or float type" ) from e + #breakpoint() return pred def _is_any_boolean(self, pred): From fd31386cff242a0e7bb389b44937a3298a7ca142 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Fri, 25 Oct 2024 13:26:31 -0400 Subject: [PATCH 02/22] init commit --- .../catalyst/api_extensions/control_flow.py | 57 +++++++++---------- frontend/test/pytest/test_conditionals.py | 14 +++++ 2 files changed, 41 insertions(+), 30 deletions(-) diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py index 3e1ce83076..3e0b0cbde0 100644 --- a/frontend/catalyst/api_extensions/control_flow.py +++ b/frontend/catalyst/api_extensions/control_flow.py @@ -24,6 +24,7 @@ import jax import jax.numpy as jnp +import pennylane as qml from jax._src.tree_util import PyTreeDef, tree_unflatten, treedef_is_leaf from jax.core import AbstractValue from pennylane import QueuingManager @@ -65,7 +66,12 @@ JaxTracingContext, ) -import pennylane as qml + +def is_pennylane_gate(callable): + # TODO + return True + + ## API ## def cond(pred: DynamicJaxprTracer): """A :func:`~.qjit` compatible decorator for if-else conditionals in PennyLane/Catalyst. @@ -237,22 +243,12 @@ def conditional_fn(): """ def _decorator(true_fn: Callable): - #breakpoint() - - if len(inspect.signature(true_fn).parameters): - #breakpoint() - current_frame = inspect.currentframe() - call_frame = current_frame.f_back - args_info = inspect.getargvalues(call_frame) - #breakpoint() - - #def _true_fn(): - # true_fn(wires=0) - #return CondCallable(pred, _true_fn) - #raise TypeError("Conditional 'True' function is not allowed to have any arguments") + # if len(inspect.signature(true_fn).parameters): + # raise TypeError("Conditional 'True' function is not allowed to have any arguments") - #return CondCallable(pred, true_fn) + # return CondCallable(pred, true_fn) + # if true_fn is a plain gate: return CondCallableSingleGateHandler(pred, true_fn) return _decorator @@ -512,19 +508,6 @@ def _decorator(body_fn): ## IMPL ## -class CondCallableSingleGateHandler: - def __init__(self, pred, true_fn): - self.pred = pred - self.true_fn = true_fn - - - def __call__(self, *my_args, **my_kwargs): - def new_true_fn(): - self.true_fn(*my_args, **my_kwargs) - #breakpoint() - return CondCallable(self.pred, new_true_fn) - - class CondCallable: """User-facing wrapper provoding "else_if" and "otherwise" public methods. Some code in this class has been adapted from the cond implementation in the JAX project at @@ -641,7 +624,7 @@ def _convert_predicate_to_bool(self, pred): if isinstance(pred, jax.Array) and pred.shape not in ((), (1,)): raise TypeError("Array with multiple elements is not a valid predicate") - #breakpoint() + # breakpoint() if not self._is_any_boolean(pred): try: pred = jnp.astype(pred, bool, copy=False) @@ -650,7 +633,7 @@ def _convert_predicate_to_bool(self, pred): "Conditional predicates are required to be of bool, integer or float type" ) from e - #breakpoint() + # breakpoint() return pred def _is_any_boolean(self, pred): @@ -770,6 +753,20 @@ def __call__(self): return self._call_during_interpretation() +class CondCallableSingleGateHandler(CondCallable): + + def __init__(self, pred, true_fn): + self.pred = pred + self.true_fn = true_fn + + def __call__(self, *my_args, **my_kwargs): + def new_true_fn(): + self.true_fn(*my_args, **my_kwargs) + + super().__init__(self.pred, new_true_fn) + return super().__call__() + + class ForLoopCallable: """ Wrapping for_loop decorator into a class so that the actual "ForLoop" operation object, which diff --git a/frontend/test/pytest/test_conditionals.py b/frontend/test/pytest/test_conditionals.py index 4a0edbb624..ec43f26ffe 100644 --- a/frontend/test/pytest/test_conditionals.py +++ b/frontend/test/pytest/test_conditionals.py @@ -676,6 +676,20 @@ def branch_f(): assert func(True) == 1 assert func(False) == 0 + def test_cond_single_gate(self, backend): + """Test standard pennylane qml.cond usage on single quantum gates.""" + """Fixes https://github.com/PennyLaneAI/catalyst/issues/449""" + + @qjit + @qml.qnode(qml.device(backend, wires=1)) + def func(x, y): + qml.cond(x == 42, qml.Hadamard)(wires=0) + qml.cond(y == 37, qml.Hadamard)(wires=0) + return qml.probs() + + assert np.allclose(func(42, 37), [1, 0]) + assert np.allclose(func(0, 37), [0.5, 0.5]) + class TestCondPredicateConversion: """Test suite for checking predicate conversion to bool.""" From dcb05ef19bb6d19b18e24301367b5547561bcbcc Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Fri, 25 Oct 2024 13:54:46 -0400 Subject: [PATCH 03/22] add checks for whether the callable is a pennylane gate --- frontend/catalyst/api_extensions/control_flow.py | 16 ++++++---------- frontend/test/pytest/test_conditionals.py | 5 ++--- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py index 3e0b0cbde0..9ffc2e5e5f 100644 --- a/frontend/catalyst/api_extensions/control_flow.py +++ b/frontend/catalyst/api_extensions/control_flow.py @@ -67,11 +67,6 @@ ) -def is_pennylane_gate(callable): - # TODO - return True - - ## API ## def cond(pred: DynamicJaxprTracer): """A :func:`~.qjit` compatible decorator for if-else conditionals in PennyLane/Catalyst. @@ -244,12 +239,13 @@ def conditional_fn(): def _decorator(true_fn: Callable): - # if len(inspect.signature(true_fn).parameters): - # raise TypeError("Conditional 'True' function is not allowed to have any arguments") + if len(inspect.signature(true_fn).parameters): + if isinstance(true_fn, type) and issubclass(true_fn, qml.operation.Operation): + return CondCallableSingleGateHandler(pred, true_fn) + else: + raise TypeError("Conditional 'True' function is not allowed to have any arguments") - # return CondCallable(pred, true_fn) - # if true_fn is a plain gate: - return CondCallableSingleGateHandler(pred, true_fn) + return CondCallable(pred, true_fn) return _decorator diff --git a/frontend/test/pytest/test_conditionals.py b/frontend/test/pytest/test_conditionals.py index ec43f26ffe..48492473ee 100644 --- a/frontend/test/pytest/test_conditionals.py +++ b/frontend/test/pytest/test_conditionals.py @@ -418,12 +418,11 @@ def conditional_flip(): def test_argument_error_with_callables(self): """Test for the error when arguments are supplied and the target is not a function.""" - @qml.qnode(qml.device("lightning.qubit", wires=1)) def f(x: int): - qml.cond(x < 5, qml.Hadamard)(0) + res = qml.cond(x < 5, lambda z: z + 1)(0) - return qml.probs() + return res with pytest.raises(TypeError, match="not allowed to have any arguments"): qjit(f) From fb1d3a789c9b7c36bb88945139eb295d5201db6a Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Fri, 25 Oct 2024 13:56:25 -0400 Subject: [PATCH 04/22] remove debugs --- frontend/catalyst/api_extensions/control_flow.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py index 9ffc2e5e5f..b5b91cc5a6 100644 --- a/frontend/catalyst/api_extensions/control_flow.py +++ b/frontend/catalyst/api_extensions/control_flow.py @@ -620,7 +620,7 @@ def _convert_predicate_to_bool(self, pred): if isinstance(pred, jax.Array) and pred.shape not in ((), (1,)): raise TypeError("Array with multiple elements is not a valid predicate") - # breakpoint() + if not self._is_any_boolean(pred): try: pred = jnp.astype(pred, bool, copy=False) @@ -629,7 +629,6 @@ def _convert_predicate_to_bool(self, pred): "Conditional predicates are required to be of bool, integer or float type" ) from e - # breakpoint() return pred def _is_any_boolean(self, pred): From 45a82400a82ae03521f0e29bdb14796920abe663 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Fri, 25 Oct 2024 14:46:03 -0400 Subject: [PATCH 05/22] add more gates in test --- frontend/test/pytest/test_conditionals.py | 25 ++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/frontend/test/pytest/test_conditionals.py b/frontend/test/pytest/test_conditionals.py index 48492473ee..a7d4eb4727 100644 --- a/frontend/test/pytest/test_conditionals.py +++ b/frontend/test/pytest/test_conditionals.py @@ -679,15 +679,30 @@ def test_cond_single_gate(self, backend): """Test standard pennylane qml.cond usage on single quantum gates.""" """Fixes https://github.com/PennyLaneAI/catalyst/issues/449""" - @qjit - @qml.qnode(qml.device(backend, wires=1)) + @qml.qnode(qml.device(backend, wires=2)) def func(x, y): qml.cond(x == 42, qml.Hadamard)(wires=0) - qml.cond(y == 37, qml.Hadamard)(wires=0) + qml.cond(x == 42, qml.RY)(1.5, wires=0) + qml.cond(x == 42, qml.CNOT)(wires=[1, 0]) + qml.cond(y == 37, qml.PauliX)(wires=1) + qml.cond(y == 37, qml.RZ)(5.1, wires=0) + qml.cond(y == 37, qml.Rot)(1.2, 3.4, 5.6, wires=1) + return qml.probs() - assert np.allclose(func(42, 37), [1, 0]) - assert np.allclose(func(0, 37), [0.5, 0.5]) + expected_0 = func(42, 37) + expected_1 = func(0, 37) + expected_2 = func(42, 0) + + jitted_func = qjit(func) + + observed_0 = jitted_func(42, 37) + observed_1 = jitted_func(0, 37) + observed_2 = jitted_func(42, 0) + + assert np.allclose(expected_0, observed_0) + assert np.allclose(expected_1, observed_1) + assert np.allclose(expected_2, observed_2) class TestCondPredicateConversion: From dbfb6898feb4d341e2e28ee2a2621e18e5ae2946 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Fri, 25 Oct 2024 14:47:44 -0400 Subject: [PATCH 06/22] codefactor pointless string --- frontend/test/pytest/test_conditionals.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/frontend/test/pytest/test_conditionals.py b/frontend/test/pytest/test_conditionals.py index a7d4eb4727..366d0d4fbf 100644 --- a/frontend/test/pytest/test_conditionals.py +++ b/frontend/test/pytest/test_conditionals.py @@ -676,8 +676,10 @@ def branch_f(): assert func(False) == 0 def test_cond_single_gate(self, backend): - """Test standard pennylane qml.cond usage on single quantum gates.""" - """Fixes https://github.com/PennyLaneAI/catalyst/issues/449""" + """ + Test standard pennylane qml.cond usage on single quantum gates. + Fixes https://github.com/PennyLaneAI/catalyst/issues/449 + """ @qml.qnode(qml.device(backend, wires=2)) def func(x, y): From 6e5dba116de767f50ebb5e73eb7a6c085e337cbd Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Fri, 25 Oct 2024 14:56:29 -0400 Subject: [PATCH 07/22] code factor new class docstring --- .../catalyst/api_extensions/control_flow.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py index b5b91cc5a6..90395ab3ac 100644 --- a/frontend/catalyst/api_extensions/control_flow.py +++ b/frontend/catalyst/api_extensions/control_flow.py @@ -241,6 +241,10 @@ def _decorator(true_fn: Callable): if len(inspect.signature(true_fn).parameters): if isinstance(true_fn, type) and issubclass(true_fn, qml.operation.Operation): + # Special treatment if conditional function body is a single pennylane gate + # The qml.operation.Operation base class represents things that + # can reasonably be considered as a gate, + # e.g. qml.Hadamard, qml.RX, etc. return CondCallableSingleGateHandler(pred, true_fn) else: raise TypeError("Conditional 'True' function is not allowed to have any arguments") @@ -749,6 +753,19 @@ def __call__(self): class CondCallableSingleGateHandler(CondCallable): + """ + Special CondCallable when the conditional body function is a single pennylane gate. + + A usual pennylane conditional call for a gate looks like + `qml.cond(x == 42, qml.RX)(theta, wires=0)` + + Since gates are guaranteed to take in arguments (at the very least the wire argument), + the usual CondCallable class, which expects the conditional body function to have no arguments, + cannot be used. + This class inherits from base CondCallable, but wraps the gate in a function with no arguments, + and send that function to CondCallable. + This allows us to perform the conditional branch gate function with arguments. + """ def __init__(self, pred, true_fn): self.pred = pred From 61611367d9320be977be7ed889c7b764f33d18ae Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Fri, 25 Oct 2024 14:58:29 -0400 Subject: [PATCH 08/22] codefactor super init not called --- frontend/catalyst/api_extensions/control_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py index 90395ab3ac..3fde7ba960 100644 --- a/frontend/catalyst/api_extensions/control_flow.py +++ b/frontend/catalyst/api_extensions/control_flow.py @@ -767,7 +767,7 @@ class CondCallableSingleGateHandler(CondCallable): This allows us to perform the conditional branch gate function with arguments. """ - def __init__(self, pred, true_fn): + def __init__(self, pred, true_fn): # pylint:disable=super-init-not-called self.pred = pred self.true_fn = true_fn From 36279f80b919d6fab05c190d26ccc911129b8962 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Fri, 25 Oct 2024 15:05:28 -0400 Subject: [PATCH 09/22] changelog --- doc/releases/changelog-0.9.0.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/releases/changelog-0.9.0.md b/doc/releases/changelog-0.9.0.md index da459438dc..9030125816 100644 --- a/doc/releases/changelog-0.9.0.md +++ b/doc/releases/changelog-0.9.0.md @@ -363,6 +363,10 @@ lowering of the scatter operation. [(#1214)](https://github.com/PennyLaneAI/catalyst/pull/1214) +* Fixes a bug where conditional-ed single gates cannot be used in qjit, + i.e. `qml.cond(x > 1, qml.Hadamard)(wires=0)`. + [(#1213)](https://github.com/PennyLaneAI/catalyst/pull/1232) +

Internal changes

* Remove deprecated pennylane code across the frontend. From fd291db3cf0f311cb82e6ff205241b0f8383b04c Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Fri, 25 Oct 2024 15:06:30 -0400 Subject: [PATCH 10/22] changelog typo --- doc/releases/changelog-0.9.0.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/releases/changelog-0.9.0.md b/doc/releases/changelog-0.9.0.md index 9030125816..b8e62e629b 100644 --- a/doc/releases/changelog-0.9.0.md +++ b/doc/releases/changelog-0.9.0.md @@ -365,7 +365,7 @@ * Fixes a bug where conditional-ed single gates cannot be used in qjit, i.e. `qml.cond(x > 1, qml.Hadamard)(wires=0)`. - [(#1213)](https://github.com/PennyLaneAI/catalyst/pull/1232) + [(#1232)](https://github.com/PennyLaneAI/catalyst/pull/1232)

Internal changes

From a3688e194a972634165167ba009925a38a525498 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Mon, 28 Oct 2024 09:38:10 -0400 Subject: [PATCH 11/22] fixing my latin --- doc/releases/changelog-0.9.0.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/releases/changelog-0.9.0.md b/doc/releases/changelog-0.9.0.md index b8e62e629b..5ff1a0286d 100644 --- a/doc/releases/changelog-0.9.0.md +++ b/doc/releases/changelog-0.9.0.md @@ -364,7 +364,7 @@ [(#1214)](https://github.com/PennyLaneAI/catalyst/pull/1214) * Fixes a bug where conditional-ed single gates cannot be used in qjit, - i.e. `qml.cond(x > 1, qml.Hadamard)(wires=0)`. + e.g. `qml.cond(x > 1, qml.Hadamard)(wires=0)`. [(#1232)](https://github.com/PennyLaneAI/catalyst/pull/1232)

Internal changes

From 9a78577da9754e1476544c2cc004c44a174539b9 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Mon, 28 Oct 2024 09:38:59 -0400 Subject: [PATCH 12/22] fixing my grammar --- frontend/catalyst/api_extensions/control_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py index 3fde7ba960..d61ec22d6e 100644 --- a/frontend/catalyst/api_extensions/control_flow.py +++ b/frontend/catalyst/api_extensions/control_flow.py @@ -763,7 +763,7 @@ class CondCallableSingleGateHandler(CondCallable): the usual CondCallable class, which expects the conditional body function to have no arguments, cannot be used. This class inherits from base CondCallable, but wraps the gate in a function with no arguments, - and send that function to CondCallable. + and sends that function to CondCallable. This allows us to perform the conditional branch gate function with arguments. """ From cef02969a334d40514231303a53544722a9fe6f1 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Mon, 28 Oct 2024 11:46:45 -0400 Subject: [PATCH 13/22] add otherwise --- .../catalyst/api_extensions/control_flow.py | 24 +++++++++++++++---- frontend/test/pytest/test_conditionals.py | 4 ++-- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py index d61ec22d6e..535d1ca6f1 100644 --- a/frontend/catalyst/api_extensions/control_flow.py +++ b/frontend/catalyst/api_extensions/control_flow.py @@ -567,6 +567,9 @@ def __init__(self, pred, true_fn): self._operation = None self.expansion_strategy = cond_expansion_strategy() + def set_otherwise_fn(self, otherwise_fn): + self.otherwise_fn = otherwise_fn + @property def operation(self): """ @@ -768,16 +771,29 @@ class CondCallableSingleGateHandler(CondCallable): """ def __init__(self, pred, true_fn): # pylint:disable=super-init-not-called - self.pred = pred - self.true_fn = true_fn + self._pred = pred + self._true_fn = true_fn + self._otherwise_fn = None def __call__(self, *my_args, **my_kwargs): def new_true_fn(): - self.true_fn(*my_args, **my_kwargs) + self._true_fn(*my_args, **my_kwargs) + + super().__init__(self._pred, new_true_fn) + + if self._otherwise_fn is not None: + + def new_otherwise_fn(): + self._otherwise_fn(*my_args, **my_kwargs) + + super().set_otherwise_fn(new_otherwise_fn) - super().__init__(self.pred, new_true_fn) return super().__call__() + def otherwise(self, otherwise_fn): + # Override the "can't have arguments" check in the original CondCallable's `otherwise` + self._otherwise_fn = otherwise_fn + class ForLoopCallable: """ diff --git a/frontend/test/pytest/test_conditionals.py b/frontend/test/pytest/test_conditionals.py index 366d0d4fbf..c187ce3db0 100644 --- a/frontend/test/pytest/test_conditionals.py +++ b/frontend/test/pytest/test_conditionals.py @@ -683,8 +683,8 @@ def test_cond_single_gate(self, backend): @qml.qnode(qml.device(backend, wires=2)) def func(x, y): - qml.cond(x == 42, qml.Hadamard)(wires=0) - qml.cond(x == 42, qml.RY)(1.5, wires=0) + qml.cond(x == 42, qml.Hadamard, qml.PauliX)(wires=0) + qml.cond(x == 42, qml.RY, qml.RZ)(1.5, wires=0) qml.cond(x == 42, qml.CNOT)(wires=[1, 0]) qml.cond(y == 37, qml.PauliX)(wires=1) qml.cond(y == 37, qml.RZ)(5.1, wires=0) From 25ab7287d80ee5f5e344e2950d3ba47165fdcc8c Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Mon, 28 Oct 2024 11:50:11 -0400 Subject: [PATCH 14/22] rename my_args -> args --- frontend/catalyst/api_extensions/control_flow.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py index 535d1ca6f1..4248e1a175 100644 --- a/frontend/catalyst/api_extensions/control_flow.py +++ b/frontend/catalyst/api_extensions/control_flow.py @@ -775,16 +775,16 @@ def __init__(self, pred, true_fn): # pylint:disable=super-init-not-called self._true_fn = true_fn self._otherwise_fn = None - def __call__(self, *my_args, **my_kwargs): + def __call__(self, *args, **kwargs): def new_true_fn(): - self._true_fn(*my_args, **my_kwargs) + self._true_fn(*args, **kwargs) super().__init__(self._pred, new_true_fn) if self._otherwise_fn is not None: def new_otherwise_fn(): - self._otherwise_fn(*my_args, **my_kwargs) + self._otherwise_fn(*args, **kwargs) super().set_otherwise_fn(new_otherwise_fn) From 22e7c52b186612a5ee42b1c4056449c1bb6fa8de Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Mon, 28 Oct 2024 11:52:19 -0400 Subject: [PATCH 15/22] more renames --- frontend/catalyst/api_extensions/control_flow.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py index 4248e1a175..d8fbfac1b7 100644 --- a/frontend/catalyst/api_extensions/control_flow.py +++ b/frontend/catalyst/api_extensions/control_flow.py @@ -771,20 +771,20 @@ class CondCallableSingleGateHandler(CondCallable): """ def __init__(self, pred, true_fn): # pylint:disable=super-init-not-called - self._pred = pred - self._true_fn = true_fn - self._otherwise_fn = None + self.sgh_pred = pred + self.sgh_true_fn = true_fn + self.sgh_otherwise_fn = None def __call__(self, *args, **kwargs): def new_true_fn(): - self._true_fn(*args, **kwargs) + self.sgh_true_fn(*args, **kwargs) - super().__init__(self._pred, new_true_fn) + super().__init__(self.sgh_pred, new_true_fn) - if self._otherwise_fn is not None: + if self.sgh_otherwise_fn is not None: def new_otherwise_fn(): - self._otherwise_fn(*args, **kwargs) + self.sgh_otherwise_fn(*args, **kwargs) super().set_otherwise_fn(new_otherwise_fn) @@ -792,7 +792,7 @@ def new_otherwise_fn(): def otherwise(self, otherwise_fn): # Override the "can't have arguments" check in the original CondCallable's `otherwise` - self._otherwise_fn = otherwise_fn + self.sgh_otherwise_fn = otherwise_fn class ForLoopCallable: From 5fae6b3f0534ba4c734264a2108953c62121ce05 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Mon, 28 Oct 2024 11:53:32 -0400 Subject: [PATCH 16/22] more renames --- frontend/catalyst/api_extensions/control_flow.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py index d8fbfac1b7..f77e408fc1 100644 --- a/frontend/catalyst/api_extensions/control_flow.py +++ b/frontend/catalyst/api_extensions/control_flow.py @@ -776,17 +776,17 @@ def __init__(self, pred, true_fn): # pylint:disable=super-init-not-called self.sgh_otherwise_fn = None def __call__(self, *args, **kwargs): - def new_true_fn(): + def argless_true_fn(): self.sgh_true_fn(*args, **kwargs) - super().__init__(self.sgh_pred, new_true_fn) + super().__init__(self.sgh_pred, argless_true_fn) if self.sgh_otherwise_fn is not None: - def new_otherwise_fn(): + def argless_otherwise_fn(): self.sgh_otherwise_fn(*args, **kwargs) - super().set_otherwise_fn(new_otherwise_fn) + super().set_otherwise_fn(argless_otherwise_fn) return super().__call__() From 90edf445107e80d6c1c9cd7d8cfcc968e964e2dc Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Mon, 28 Oct 2024 12:01:28 -0400 Subject: [PATCH 17/22] add gate checks for otherwise func --- frontend/catalyst/api_extensions/control_flow.py | 7 ++++++- frontend/test/pytest/test_conditionals.py | 12 ++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py index f77e408fc1..816380bb3a 100644 --- a/frontend/catalyst/api_extensions/control_flow.py +++ b/frontend/catalyst/api_extensions/control_flow.py @@ -792,7 +792,12 @@ def argless_otherwise_fn(): def otherwise(self, otherwise_fn): # Override the "can't have arguments" check in the original CondCallable's `otherwise` - self.sgh_otherwise_fn = otherwise_fn + if isinstance(otherwise_fn, type) and issubclass(otherwise_fn, qml.operation.Operation): + self.sgh_otherwise_fn = otherwise_fn + else: + raise TypeError( + "Conditional 'False' function is allowed to have arguments only if it is a PennyLane gate." + ) class ForLoopCallable: diff --git a/frontend/test/pytest/test_conditionals.py b/frontend/test/pytest/test_conditionals.py index c187ce3db0..0c36e3df16 100644 --- a/frontend/test/pytest/test_conditionals.py +++ b/frontend/test/pytest/test_conditionals.py @@ -427,6 +427,18 @@ def f(x: int): with pytest.raises(TypeError, match="not allowed to have any arguments"): qjit(f) + def f(x: int): + + res = qml.cond(x < 5, qml.Hadamard, lambda z: z + 1)(0) + + return res + + with pytest.raises( + TypeError, + match="Conditional 'False' function is allowed to have arguments only if it is a PennyLane gate.", + ): + qjit(f) + class TestInterpretationConditional: """Test that the conditional operation's execution is semantically equivalent From 070885bcd95076aaf23dc89f4a9ecb5e627f42cc Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Mon, 28 Oct 2024 13:11:24 -0400 Subject: [PATCH 18/22] codefactor --- frontend/catalyst/api_extensions/control_flow.py | 4 ++-- frontend/test/pytest/test_conditionals.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py index 816380bb3a..67ab1ac57b 100644 --- a/frontend/catalyst/api_extensions/control_flow.py +++ b/frontend/catalyst/api_extensions/control_flow.py @@ -567,7 +567,7 @@ def __init__(self, pred, true_fn): self._operation = None self.expansion_strategy = cond_expansion_strategy() - def set_otherwise_fn(self, otherwise_fn): + def set_otherwise_fn(self, otherwise_fn): # pylint:disable=missing-function-docstring self.otherwise_fn = otherwise_fn @property @@ -796,7 +796,7 @@ def otherwise(self, otherwise_fn): self.sgh_otherwise_fn = otherwise_fn else: raise TypeError( - "Conditional 'False' function is allowed to have arguments only if it is a PennyLane gate." + "Conditional 'False' function can have arguments only if it is a PennyLane gate." ) diff --git a/frontend/test/pytest/test_conditionals.py b/frontend/test/pytest/test_conditionals.py index 0c36e3df16..7bc04e0514 100644 --- a/frontend/test/pytest/test_conditionals.py +++ b/frontend/test/pytest/test_conditionals.py @@ -435,7 +435,7 @@ def f(x: int): with pytest.raises( TypeError, - match="Conditional 'False' function is allowed to have arguments only if it is a PennyLane gate.", + match="Conditional 'False' function can have arguments only if it is a PennyLane gate.", ): qjit(f) From 162422197de6eccd81906e1d2aac0b552d841096 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Mon, 28 Oct 2024 13:28:24 -0400 Subject: [PATCH 19/22] add lit test --- frontend/test/lit/test_if_else.py | 37 +++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/frontend/test/lit/test_if_else.py b/frontend/test/lit/test_if_else.py index 5b37804b7f..53e472e083 100644 --- a/frontend/test/lit/test_if_else.py +++ b/frontend/test/lit/test_if_else.py @@ -58,6 +58,43 @@ def otherwise(): # ----- +# CHECK-LABEL: public @jit_circuit_single_gate +@qjit(target="mlir") +@qml.qnode(qml.device("lightning.qubit", wires=1)) +def circuit_single_gate(n: int): + # CHECK-DAG: [[c5:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<5> : tensor + # CHECK: [[b_t:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c5]], SIGNED : (tensor, tensor) -> tensor + # CHECK-DAG: [[qreg_0:%[a-zA-Z0-9_]+]] = quantum.alloc + # CHECK: [[b:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t]] + + # CHECK: [[qreg_out:%.+]] = scf.if [[b]] + # CHECK-DAG: [[q0:%[a-zA-Z0-9_]+]] = quantum.extract + # CHECK-DAG: [[q1:%[a-zA-Z0-9_]+]] = quantum.custom "PauliX"() [[q0]] + # pylint: disable=line-too-long + # CHECK-DAG: [[qreg_1:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_0]][ {{[%a-zA-Z0-9_]+}}], [[q1]] + # CHECK: scf.yield [[qreg_1]] + + # CHECK: else + # CHECK-DAG: [[q2:%[a-zA-Z0-9_]+]] = quantum.extract + # CHECK-DAG: [[q3:%[a-zA-Z0-9_]+]] = quantum.custom "Hadamard"() [[q2]] + # pylint: disable=line-too-long + # CHECK-DAG: [[qreg_2:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_0]][ {{[%a-zA-Z0-9_]+}}], [[q3]] + # CHECK: scf.yield [[qreg_2]] + qml.cond(n <= 5, qml.PauliX, qml.Hadamard)(wires=0) + + # CHECK: [[qreg_3:%.+]] = quantum.extract [[qreg_out]][ 0] + # CHECK: [[qobs:%.+]] = quantum.compbasis [[qreg_3]] : !quantum.obs + # CHECK: [[ret:%.+]] = quantum.probs [[qobs]] + # CHECK: return [[ret]] + return qml.probs() + + +print(circuit_single_gate.mlir) + + +# ----- + + # CHECK-LABEL: test_convert_element_type @qjit def test_convert_element_type(i: int, f: float): From ce24fcf4457ebe679a35b624ef696bd8519ebd07 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Mon, 28 Oct 2024 13:41:24 -0400 Subject: [PATCH 20/22] more lit tests --- frontend/test/lit/test_if_else.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/frontend/test/lit/test_if_else.py b/frontend/test/lit/test_if_else.py index 53e472e083..9097a2af5e 100644 --- a/frontend/test/lit/test_if_else.py +++ b/frontend/test/lit/test_if_else.py @@ -63,26 +63,40 @@ def otherwise(): @qml.qnode(qml.device("lightning.qubit", wires=1)) def circuit_single_gate(n: int): # CHECK-DAG: [[c5:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<5> : tensor - # CHECK: [[b_t:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c5]], SIGNED : (tensor, tensor) -> tensor + # CHECK-DAG: [[c6:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<6> : tensor + # CHECK-DAG: [[b_t5:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c5]], SIGNED : (tensor, tensor) -> tensor + # CHECK-DAG: [[b_t6:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c6]], SIGNED : (tensor, tensor) -> tensor # CHECK-DAG: [[qreg_0:%[a-zA-Z0-9_]+]] = quantum.alloc - # CHECK: [[b:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t]] + # CHECK: [[b5:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t5]] - # CHECK: [[qreg_out:%.+]] = scf.if [[b]] - # CHECK-DAG: [[q0:%[a-zA-Z0-9_]+]] = quantum.extract + # CHECK: [[qreg_out:%.+]] = scf.if [[b5]] + # CHECK-DAG: [[q0:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_0]] # CHECK-DAG: [[q1:%[a-zA-Z0-9_]+]] = quantum.custom "PauliX"() [[q0]] # pylint: disable=line-too-long # CHECK-DAG: [[qreg_1:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_0]][ {{[%a-zA-Z0-9_]+}}], [[q1]] # CHECK: scf.yield [[qreg_1]] # CHECK: else - # CHECK-DAG: [[q2:%[a-zA-Z0-9_]+]] = quantum.extract + # CHECK-DAG: [[q2:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_0]] # CHECK-DAG: [[q3:%[a-zA-Z0-9_]+]] = quantum.custom "Hadamard"() [[q2]] # pylint: disable=line-too-long # CHECK-DAG: [[qreg_2:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_0]][ {{[%a-zA-Z0-9_]+}}], [[q3]] # CHECK: scf.yield [[qreg_2]] qml.cond(n <= 5, qml.PauliX, qml.Hadamard)(wires=0) - # CHECK: [[qreg_3:%.+]] = quantum.extract [[qreg_out]][ 0] + # CHECK: [[b6:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t6]] + # CHECK: [[qreg_out1:%.+]] = scf.if [[b6]] + # CHECK-DAG: [[q4:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out]] + # CHECK-DAG: [[q5:%[a-zA-Z0-9_]+]] = quantum.custom "RX"({{%.+}}) [[q4]] + # pylint: disable=line-too-long + # CHECK-DAG: [[qreg_3:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out]][ {{[%a-zA-Z0-9_]+}}], [[q1]] + # CHECK: scf.yield [[qreg_3]] + # CHECK: else + # CHECK: scf.yield [[qreg_out]] + + qml.cond(n <= 6, qml.RX)(3.14, wires=0) + + # CHECK: [[qreg_3:%.+]] = quantum.extract [[qreg_out1]][ 0] # CHECK: [[qobs:%.+]] = quantum.compbasis [[qreg_3]] : !quantum.obs # CHECK: [[ret:%.+]] = quantum.probs [[qobs]] # CHECK: return [[ret]] From 3415dfc20d2c63a93d96571b3c902a11baf2fb55 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Mon, 28 Oct 2024 15:20:03 -0400 Subject: [PATCH 21/22] elseif --- .../catalyst/api_extensions/control_flow.py | 43 ++++++++++++++++--- frontend/test/pytest/test_conditionals.py | 25 ++++++++++- 2 files changed, 62 insertions(+), 6 deletions(-) diff --git a/frontend/catalyst/api_extensions/control_flow.py b/frontend/catalyst/api_extensions/control_flow.py index 67ab1ac57b..1346761528 100644 --- a/frontend/catalyst/api_extensions/control_flow.py +++ b/frontend/catalyst/api_extensions/control_flow.py @@ -570,6 +570,12 @@ def __init__(self, pred, true_fn): def set_otherwise_fn(self, otherwise_fn): # pylint:disable=missing-function-docstring self.otherwise_fn = otherwise_fn + def add_pred(self, _pred): + self.preds.append(self._convert_predicate_to_bool(_pred)) + + def add_branch_fn(self, _branch_fn): + self.branch_fns.append(_branch_fn) + @property def operation(self): """ @@ -771,15 +777,15 @@ class CondCallableSingleGateHandler(CondCallable): """ def __init__(self, pred, true_fn): # pylint:disable=super-init-not-called - self.sgh_pred = pred - self.sgh_true_fn = true_fn + self.sgh_preds = [pred] + self.sgh_branch_fns = [true_fn] self.sgh_otherwise_fn = None def __call__(self, *args, **kwargs): def argless_true_fn(): - self.sgh_true_fn(*args, **kwargs) + self.sgh_branch_fns[0](*args, **kwargs) - super().__init__(self.sgh_pred, argless_true_fn) + super().__init__(self.sgh_preds[0], argless_true_fn) if self.sgh_otherwise_fn is not None: @@ -788,10 +794,37 @@ def argless_otherwise_fn(): super().set_otherwise_fn(argless_otherwise_fn) + for i in range(1, len(self.sgh_branch_fns)): + + def argless_elseif_fn(i=i): # i=i to work around late binding + self.sgh_branch_fns[i](*args, **kwargs) + + super().add_pred(self.sgh_preds[i]) + super().add_branch_fn(argless_elseif_fn) + return super().__call__() + def else_if(self, _pred): + """ + Override the "can't have arguments" check in the original CondCallable's `else_if` + """ + + def decorator(branch_fn): + if isinstance(branch_fn, type) and issubclass(branch_fn, qml.operation.Operation): + self.sgh_preds.append(_pred) + self.sgh_branch_fns.append(branch_fn) + return self + else: + raise TypeError( + "Conditional 'else if' function can have arguments only if it is a PennyLane gate." + ) + + return decorator + def otherwise(self, otherwise_fn): - # Override the "can't have arguments" check in the original CondCallable's `otherwise` + """ + Override the "can't have arguments" check in the original CondCallable's `otherwise` + """ if isinstance(otherwise_fn, type) and issubclass(otherwise_fn, qml.operation.Operation): self.sgh_otherwise_fn = otherwise_fn else: diff --git a/frontend/test/pytest/test_conditionals.py b/frontend/test/pytest/test_conditionals.py index 7bc04e0514..19b5a0eb22 100644 --- a/frontend/test/pytest/test_conditionals.py +++ b/frontend/test/pytest/test_conditionals.py @@ -439,6 +439,18 @@ def f(x: int): ): qjit(f) + def f(x: int): + + res = qml.cond(x < 5, qml.Hadamard, qml.Hadamard, ((x < 6, lambda z: z + 1),))(0) + + return res + + with pytest.raises( + TypeError, + match="Conditional 'else if' function can have arguments only if it is a PennyLane gate.", + ): + qjit(f) + class TestInterpretationConditional: """Test that the conditional operation's execution is semantically equivalent @@ -699,7 +711,15 @@ def func(x, y): qml.cond(x == 42, qml.RY, qml.RZ)(1.5, wires=0) qml.cond(x == 42, qml.CNOT)(wires=[1, 0]) qml.cond(y == 37, qml.PauliX)(wires=1) - qml.cond(y == 37, qml.RZ)(5.1, wires=0) + qml.cond( + y == 36, + qml.RZ, + qml.RY, + ( + (x == 42, qml.RX), + (x == 41, qml.RZ), + ), + )(5.1, wires=0) qml.cond(y == 37, qml.Rot)(1.2, 3.4, 5.6, wires=1) return qml.probs() @@ -707,16 +727,19 @@ def func(x, y): expected_0 = func(42, 37) expected_1 = func(0, 37) expected_2 = func(42, 0) + expected_3 = func(41, 0) jitted_func = qjit(func) observed_0 = jitted_func(42, 37) observed_1 = jitted_func(0, 37) observed_2 = jitted_func(42, 0) + observed_3 = jitted_func(41, 0) assert np.allclose(expected_0, observed_0) assert np.allclose(expected_1, observed_1) assert np.allclose(expected_2, observed_2) + assert np.allclose(expected_3, observed_3) class TestCondPredicateConversion: From e339f972d4c0a73f66e49dc9975dec15dd5033d3 Mon Sep 17 00:00:00 2001 From: Haochen Wang Date: Mon, 28 Oct 2024 16:38:00 -0400 Subject: [PATCH 22/22] elseif lit test --- frontend/test/lit/test_if_else.py | 56 +++++++++++++++++++++++++++---- 1 file changed, 49 insertions(+), 7 deletions(-) diff --git a/frontend/test/lit/test_if_else.py b/frontend/test/lit/test_if_else.py index 9097a2af5e..25bd3ebbad 100644 --- a/frontend/test/lit/test_if_else.py +++ b/frontend/test/lit/test_if_else.py @@ -62,24 +62,29 @@ def otherwise(): @qjit(target="mlir") @qml.qnode(qml.device("lightning.qubit", wires=1)) def circuit_single_gate(n: int): + # pylint: disable=line-too-long # CHECK-DAG: [[c5:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<5> : tensor # CHECK-DAG: [[c6:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<6> : tensor + # CHECK-DAG: [[c7:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<7> : tensor + # CHECK-DAG: [[c8:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<8> : tensor + # CHECK-DAG: [[c9:%[a-zA-Z0-9_]+]] = stablehlo.constant dense<9> : tensor # CHECK-DAG: [[b_t5:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c5]], SIGNED : (tensor, tensor) -> tensor # CHECK-DAG: [[b_t6:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c6]], SIGNED : (tensor, tensor) -> tensor + # CHECK-DAG: [[b_t7:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c7]], SIGNED : (tensor, tensor) -> tensor + # CHECK-DAG: [[b_t8:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c8]], SIGNED : (tensor, tensor) -> tensor + # CHECK-DAG: [[b_t9:%[a-zA-Z0-9_]+]] = stablehlo.compare LE, %arg0, [[c9]], SIGNED : (tensor, tensor) -> tensor # CHECK-DAG: [[qreg_0:%[a-zA-Z0-9_]+]] = quantum.alloc - # CHECK: [[b5:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t5]] + # CHECK: [[b5:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t5]] # CHECK: [[qreg_out:%.+]] = scf.if [[b5]] # CHECK-DAG: [[q0:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_0]] # CHECK-DAG: [[q1:%[a-zA-Z0-9_]+]] = quantum.custom "PauliX"() [[q0]] - # pylint: disable=line-too-long # CHECK-DAG: [[qreg_1:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_0]][ {{[%a-zA-Z0-9_]+}}], [[q1]] # CHECK: scf.yield [[qreg_1]] # CHECK: else # CHECK-DAG: [[q2:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_0]] # CHECK-DAG: [[q3:%[a-zA-Z0-9_]+]] = quantum.custom "Hadamard"() [[q2]] - # pylint: disable=line-too-long # CHECK-DAG: [[qreg_2:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_0]][ {{[%a-zA-Z0-9_]+}}], [[q3]] # CHECK: scf.yield [[qreg_2]] qml.cond(n <= 5, qml.PauliX, qml.Hadamard)(wires=0) @@ -88,16 +93,53 @@ def circuit_single_gate(n: int): # CHECK: [[qreg_out1:%.+]] = scf.if [[b6]] # CHECK-DAG: [[q4:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out]] # CHECK-DAG: [[q5:%[a-zA-Z0-9_]+]] = quantum.custom "RX"({{%.+}}) [[q4]] - # pylint: disable=line-too-long - # CHECK-DAG: [[qreg_3:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out]][ {{[%a-zA-Z0-9_]+}}], [[q1]] + # CHECK-DAG: [[qreg_3:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out]][ {{[%a-zA-Z0-9_]+}}], [[q5]] # CHECK: scf.yield [[qreg_3]] # CHECK: else # CHECK: scf.yield [[qreg_out]] qml.cond(n <= 6, qml.RX)(3.14, wires=0) - # CHECK: [[qreg_3:%.+]] = quantum.extract [[qreg_out1]][ 0] - # CHECK: [[qobs:%.+]] = quantum.compbasis [[qreg_3]] : !quantum.obs + # CHECK: [[b7:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t7]] + # CHECK: [[qreg_out2:%.+]] = scf.if [[b7]] + # CHECK-DAG: [[q7:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out1]] + # CHECK-DAG: [[q8:%[a-zA-Z0-9_]+]] = quantum.custom "Hadamard"() [[q7]] + # pylint: disable=line-too-long + # CHECK-DAG: [[qreg_4:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out1]][ {{[%a-zA-Z0-9_]+}}], [[q8]] + # CHECK: scf.yield [[qreg_4]] + # CHECK: else { + # CHECK: [[b8:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t8]] + # CHECK: [[qreg_out3:%.+]] = scf.if [[b8]] + # CHECK-DAG: [[q9:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out1]] + # CHECK-DAG: [[q10:%[a-zA-Z0-9_]+]] = quantum.custom "PauliY"() [[q9]] + # CHECK-DAG: [[qreg_5:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out1]][ {{[%a-zA-Z0-9_]+}}], [[q10]] + # CHECK: scf.yield [[qreg_5]] + # CHECK: else { + # CHECK: [[b9:%[a-zA-Z0-9_]+]] = tensor.extract [[b_t9]] + # CHECK: [[qreg_out4:%.+]] = scf.if [[b9]] + # CHECK-DAG: [[q11:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out1]] + # CHECK-DAG: [[q12:%[a-zA-Z0-9_]+]] = quantum.custom "PauliZ"() [[q11]] + # CHECK-DAG: [[qreg_6:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out1]][ {{[%a-zA-Z0-9_]+}}], [[q12]] + # CHECK: scf.yield [[qreg_6]] + # CHECK: else { + # CHECK-DAG: [[q13:%[a-zA-Z0-9_]+]] = quantum.extract [[qreg_out1]] + # CHECK-DAG: [[q14:%[a-zA-Z0-9_]+]] = quantum.custom "PauliX"() [[q13]] + # CHECK-DAG: [[qreg_7:%[a-zA-Z0-9_]+]] = quantum.insert [[qreg_out1]][ {{[%a-zA-Z0-9_]+}}], [[q14]] + # CHECK: scf.yield [[qreg_7]] + # CHECK: scf.yield [[qreg_out4]] + # CHECK: scf.yield [[qreg_out3]] + qml.cond( + n <= 7, + qml.Hadamard, + qml.PauliX, + ( + (n <= 8, qml.PauliY), + (n <= 9, qml.PauliZ), + ), + )(wires=0) + + # CHECK: [[qreg_ret:%.+]] = quantum.extract [[qreg_out2]][ 0] + # CHECK: [[qobs:%.+]] = quantum.compbasis [[qreg_ret]] : !quantum.obs # CHECK: [[ret:%.+]] = quantum.probs [[qobs]] # CHECK: return [[ret]] return qml.probs()