Skip to content

Commit f8d0cae

Browse files
authored
[Bugfix] Tensor representation of controlled operations with qubit support order (#40)
* add single dispatch control * add tensor method * merge conflict * docstr tensor * values description * check tensor on a crx gate * add _ * make unitary private * change options selection mkdocs * fix docstr expand operatr ---------
1 parent c61e84d commit f8d0cae

File tree

7 files changed

+218
-16
lines changed

7 files changed

+218
-16
lines changed

horqrux/analog.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class _HamiltonianEvolution(Primitive):
2121
target: QubitSupport
2222
control: QubitSupport
2323

24-
def unitary(self, values: dict[str, Array] = dict()) -> Array:
24+
def _unitary(self, values: dict[str, Array] = dict()) -> Array:
2525
return expm(values["hamiltonian"] * (-1j * values["time_evolution"]))
2626

2727

horqrux/parametric.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __iter__(self) -> Iterable:
6565
def tree_unflatten(cls, aux_data: Any, children: Any) -> Any:
6666
return cls(*children, *aux_data)
6767

68-
def unitary(self, values: dict[str, float] = dict()) -> Array:
68+
def _unitary(self, values: dict[str, float] = dict()) -> Array:
6969
return _unitary(OPERATIONS_DICT[self.generator_name], self.parse_values(values))
7070

7171
def jacobian(self, values: dict[str, float] = dict()) -> Array:
@@ -141,7 +141,7 @@ def RZ(
141141

142142

143143
class _PHASE(Parametric):
144-
def unitary(self, values: dict[str, float] = dict()) -> Array:
144+
def _unitary(self, values: dict[str, float] = dict()) -> Array:
145145
u = jnp.eye(2, 2, dtype=default_dtype)
146146
u = u.at[(1, 1)].set(jnp.exp(1.0j * self.parse_values(values)))
147147
return u

horqrux/primitive.py

+33-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
QubitSupport,
1515
TargetQubits,
1616
_dagger,
17+
controlled,
1718
is_controlled,
1819
none_like,
1920
)
@@ -60,11 +61,41 @@ def tree_flatten(self) -> tuple[tuple, tuple[str, TargetQubits, ControlQubits, N
6061
def tree_unflatten(cls, aux_data: Any, children: Any) -> Any:
6162
return cls(*children, *aux_data)
6263

63-
def unitary(self, values: dict[str, float] = dict()) -> Array:
64+
def _unitary(self, values: dict[str, float] = dict()) -> Array:
65+
"""Obtain the base unitary from `generator_name`.
66+
67+
Args:
68+
values (dict[str, float], optional): Parameter values. Defaults to dict().
69+
70+
Returns:
71+
Array: The base unitary from `generator_name`.
72+
"""
6473
return OPERATIONS_DICT[self.generator_name]
6574

6675
def dagger(self, values: dict[str, float] = dict()) -> Array:
67-
return _dagger(self.unitary(values))
76+
"""Obtain the dagger of the base unitary from `generator_name`.
77+
78+
Args:
79+
values (dict[str, float], optional): Parameter values. Defaults to dict().
80+
81+
Returns:
82+
Array: The base unitary daggered from `generator_name`.
83+
"""
84+
return _dagger(self._unitary(values))
85+
86+
def tensor(self, values: dict[str, float] = dict()) -> Array:
87+
"""Obtain the unitary taking into account the qubit support for controlled operations.
88+
89+
Args:
90+
values (dict[str, float], optional): Parameter values. Defaults to dict().
91+
92+
Returns:
93+
Array: Unitary representation taking into account the qubit support.
94+
"""
95+
base_unitary = self._unitary(values)
96+
if is_controlled(self.control):
97+
return controlled(base_unitary, self.target, self.control)
98+
return base_unitary
6899

69100
@property
70101
def name(self) -> str:

horqrux/shots.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def to_matrix(
2929
observable.control == observable.parse_idx(none_like(observable.target)),
3030
"Controlled gates cannot be promoted from observables to operations on the whole state vector",
3131
)
32-
unitary = observable.unitary(values=values)
32+
unitary = observable._unitary(values=values)
3333
target = observable.target[0][0]
3434
identity = jnp.eye(2, dtype=unitary.dtype)
3535
ops = [identity for _ in range(n_qubits)]

horqrux/utils.py

+101-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from collections import Counter
44
from dataclasses import dataclass
55
from enum import Enum
6-
from functools import singledispatch
6+
from functools import reduce, singledispatch
7+
from math import log
78
from typing import Any, Iterable, Union
89

910
import jax
@@ -15,6 +16,7 @@
1516
from numpy import log2
1617

1718
from ._misc import default_complex_dtype
19+
from .matrices import _I
1820

1921
default_dtype = default_complex_dtype()
2022

@@ -97,7 +99,7 @@ def list(cls) -> list[str]:
9799

98100

99101
class OperationType(StrEnum):
100-
UNITARY = "unitary"
102+
UNITARY = "_unitary"
101103
DAGGER = "dagger"
102104
JACOBIAN = "jacobian"
103105

@@ -151,12 +153,109 @@ def _jacobian(generator: Array, theta: float) -> Array:
151153

152154

153155
def _controlled(operator: Array, n_control: int) -> Array:
156+
"""
157+
Create a controlled quantum operator with specified number of control qubits.
158+
159+
Args:
160+
operator (jnp.ndarray): The base quantum operator to be controlled.
161+
n_control (int): Number of control qubits.
162+
163+
Returns:
164+
jnp.ndarray: The controlled quantum operator matrix
165+
"""
154166
n_qubits = int(log2(operator.shape[0]))
155167
control = jnp.eye(2 ** (n_control + n_qubits), dtype=default_dtype)
156168
control = control.at[-(2**n_qubits) :, -(2**n_qubits) :].set(operator)
157169
return control
158170

159171

172+
def controlled(
173+
operator: jnp.ndarray,
174+
target_qubits: TargetQubits,
175+
control_qubits: ControlQubits,
176+
) -> jnp.ndarray:
177+
"""
178+
Create a controlled quantum operator with specified control and target qubit indices.
179+
180+
Args:
181+
operator (jnp.ndarray): The base quantum operator to be controlled.
182+
Note the operator is defined only on `target_qubits`.
183+
control_qubits (int or tuple of ints): Index or indices of control qubits
184+
target_qubits (int or tuple of ints): Index or indices of target qubits
185+
186+
Returns:
187+
jnp.ndarray: The controlled quantum operator matrix
188+
"""
189+
controls: tuple = tuple()
190+
targets: tuple = tuple()
191+
if isinstance(control_qubits[0], tuple):
192+
controls = control_qubits[0]
193+
if isinstance(target_qubits[0], tuple):
194+
targets = target_qubits[0]
195+
n_qop = int(log(operator.shape[0], 2))
196+
n_targets = len(targets)
197+
if n_qop != n_targets:
198+
raise ValueError("`target_qubits` length should match the shape of operator.")
199+
# Determine the total number of qubits and order of controls
200+
ntotal_qubits = len(controls) + n_targets
201+
qubit_support = sorted(controls + targets)
202+
control_ind_support = tuple(i for i, q in enumerate(qubit_support) if q in controls)
203+
204+
# Create the full Hilbert space dimension
205+
full_dim = 2**ntotal_qubits
206+
207+
# Initialize the controlled operator as an identity matrix
208+
controlled_op = jnp.eye(full_dim, dtype=operator.dtype)
209+
210+
# Compute the control mask using bit manipulation
211+
control_mask = jnp.sum(
212+
jnp.array(
213+
[1 << (ntotal_qubits - control_qubit - 1) for control_qubit in control_ind_support]
214+
)
215+
)
216+
217+
# Create indices for the controlled subspace
218+
indices = jnp.arange(full_dim)
219+
controlled_indices = indices[(indices & control_mask) == control_mask]
220+
221+
# Set the controlled subspace to the operator
222+
controlled_op = controlled_op.at[jnp.ix_(controlled_indices, controlled_indices)].set(operator)
223+
224+
return controlled_op
225+
226+
227+
def expand_operator(
228+
operator: Array, qubit_support: TargetQubits, full_support: TargetQubits
229+
) -> Array:
230+
"""
231+
Expands an operator acting on a given qubit_support to act on a larger full_support
232+
by explicitly filling in identity matrices on all remaining qubits.
233+
234+
Args:
235+
operator (Array): Operator to expand
236+
qubit_support (TargetQubits): Qubit support the operator is initially defined over.
237+
full_support (TargetQubits): Qubit support the operator will be defined over.
238+
239+
Raises:
240+
ValueError: When `full_support` larger than or equal to the `qubit_support`
241+
242+
Returns:
243+
Array: Expanded operator.
244+
"""
245+
full_support = tuple(sorted(full_support))
246+
qubit_support = tuple(sorted(qubit_support))
247+
if not set(qubit_support).issubset(set(full_support)):
248+
raise ValueError(
249+
"Expanding tensor operation requires a `full_support` argument "
250+
"larger than or equal to the `qubit_support`."
251+
)
252+
253+
kron_qubits = set(full_support) - set(qubit_support)
254+
kron_operator = reduce(jnp.kron, [operator] + [_I] * len(kron_qubits))
255+
# TODO: Add permute_basis
256+
return kron_operator
257+
258+
160259
def product_state(bitstring: str) -> Array:
161260
"""Generates a state of shape [2 for _ in range(len(bitstring))].
162261

mkdocs.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ plugins:
4949
default_handler: python
5050
handlers:
5151
python:
52-
selection:
52+
options:
5353
filters:
5454
- "!^_" # exlude all members starting with _
5555
- "^__init__$" # but always include __init__ modules and methods

tests/test_gates.py

+79-7
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from horqrux.apply import apply_gate, apply_operator
1111
from horqrux.parametric import PHASE, RX, RY, RZ
1212
from horqrux.primitive import NOT, SWAP, H, I, S, T, X, Y, Z
13-
from horqrux.utils import density_mat, equivalent_state, product_state, random_state
13+
from horqrux.utils import OperationType, density_mat, equivalent_state, product_state, random_state
1414

1515
MAX_QUBITS = 7
1616
PARAMETRIC_GATES = (RX, RY, RZ, PHASE)
@@ -31,7 +31,7 @@ def test_primitive(gate_fn: Callable) -> None:
3131
# test density matrix is similar to pure state
3232
dm = apply_operator(
3333
density_mat(orig_state),
34-
gate.unitary(),
34+
gate._unitary(),
3535
gate.target[0],
3636
gate.control[0],
3737
)
@@ -54,7 +54,7 @@ def test_controlled_primitive(gate_fn: Callable) -> None:
5454
# test density matrix is similar to pure state
5555
dm = apply_operator(
5656
density_mat(orig_state),
57-
gate.unitary(),
57+
gate._unitary(),
5858
gate.target[0],
5959
gate.control[0],
6060
)
@@ -75,7 +75,7 @@ def test_parametric(gate_fn: Callable) -> None:
7575
# test density matrix is similar to pure state
7676
dm = apply_operator(
7777
density_mat(orig_state),
78-
gate.unitary(values),
78+
gate._unitary(values),
7979
gate.target[0],
8080
gate.control[0],
8181
)
@@ -99,7 +99,7 @@ def test_controlled_parametric(gate_fn: Callable) -> None:
9999
# test density matrix is similar to pure state
100100
dm = apply_operator(
101101
density_mat(orig_state),
102-
gate.unitary(values),
102+
gate._unitary(values),
103103
gate.target[0],
104104
gate.control[0],
105105
)
@@ -149,9 +149,81 @@ def test_merge_gates() -> None:
149149
"c": np.random.uniform(0.1, 2 * np.pi),
150150
}
151151
state_grouped = apply_gate(
152-
product_state("0000"), gates, values, "unitary", group_gates=True, merge_ops=True
152+
product_state("0000"),
153+
gates,
154+
values,
155+
OperationType.UNITARY,
156+
group_gates=True,
157+
merge_ops=True,
153158
)
154159
state = apply_gate(
155-
product_state("0000"), gates, values, "unitary", group_gates=False, merge_ops=False
160+
product_state("0000"),
161+
gates,
162+
values,
163+
OperationType.UNITARY,
164+
group_gates=False,
165+
merge_ops=False,
156166
)
157167
assert jnp.allclose(state_grouped, state)
168+
169+
170+
def flip_bit_wrt_control(bitstring: str, control: int, target: int) -> str:
171+
# Convert bitstring to list for easier manipulation
172+
bits = list(bitstring)
173+
174+
# Flip the bit at the specified index
175+
if bits[control] == "1":
176+
bits[target] = "0" if bits[target] == "1" else "1"
177+
178+
# Convert back to string
179+
return "".join(bits)
180+
181+
182+
@pytest.mark.parametrize(
183+
"bitstring",
184+
[
185+
"00",
186+
"01",
187+
"11",
188+
"10",
189+
],
190+
)
191+
def test_cnot_product_state(bitstring: str):
192+
cnot0 = NOT(target=1, control=0)
193+
state = product_state(bitstring)
194+
state = apply_gate(state, cnot0)
195+
expected_state = product_state(flip_bit_wrt_control(bitstring, 0, 1))
196+
assert jnp.allclose(state, expected_state)
197+
198+
# reverse control and target
199+
cnot1 = NOT(target=0, control=1)
200+
state = product_state(bitstring)
201+
state = apply_gate(state, cnot1)
202+
expected_state = product_state(flip_bit_wrt_control(bitstring, 1, 0))
203+
assert jnp.allclose(state, expected_state)
204+
205+
206+
def test_cnot_tensor() -> None:
207+
cnot0 = NOT(target=1, control=0)
208+
cnot1 = NOT(target=0, control=1)
209+
assert jnp.allclose(
210+
cnot0.tensor(), jnp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]])
211+
)
212+
assert jnp.allclose(
213+
cnot1.tensor(), jnp.array([[1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0]])
214+
)
215+
216+
217+
def test_crx_tensor() -> None:
218+
crx0 = RX(0.2, target=1, control=0)
219+
crx1 = RX(0.2, target=0, control=1)
220+
assert jnp.allclose(
221+
crx0.tensor(),
222+
jnp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0.9950, -0.0998j], [0, 0, -0.0998j, 0.9950]]),
223+
atol=1e-3,
224+
)
225+
assert jnp.allclose(
226+
crx1.tensor(),
227+
jnp.array([[1, 0, 0, 0], [0, 0.9950, 0, -0.0998j], [0, 0, 1, 0], [0, -0.0998j, 0, 0.9950]]),
228+
atol=1e-3,
229+
)

0 commit comments

Comments
 (0)