Skip to content

Commit 7bbec3b

Browse files
anurudhpmpharrigan
andauthored
Tensor from classical action (#1514)
* WIP tensor from classical action * fix errors * fix syntax error? * move test, add timeout * rename * notebooks * cleanup * `my_tensors` from classical * rename and docs * more tests * imports + doc * nits * fix docstring * make files private (matching others) --------- Co-authored-by: Matthew Harrigan <[email protected]>
1 parent d92bb1e commit 7bbec3b

File tree

7 files changed

+212
-3
lines changed

7 files changed

+212
-3
lines changed

qualtran/bloqs/arithmetic/sorting.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@
8686
},
8787
"outputs": [],
8888
"source": [
89-
"comparator = Comparator(7)"
89+
"comparator = Comparator(3)"
9090
]
9191
},
9292
{

qualtran/bloqs/arithmetic/sorting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def build_composite_bloq(
8585

8686
@bloq_example
8787
def _comparator() -> Comparator:
88-
comparator = Comparator(7)
88+
comparator = Comparator(3)
8989
return comparator
9090

9191

qualtran/bloqs/basic_gates/swap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ def _swap_small() -> Swap:
299299
return swap_small
300300

301301

302-
@bloq_example
302+
@bloq_example(generalizer=ignore_split_join)
303303
def _swap_large() -> Swap:
304304
swap_large = Swap(bitsize=64)
305305
return swap_large

qualtran/bloqs/basic_gates/swap_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
_cswap_large,
3636
_cswap_small,
3737
_swap,
38+
_swap_large,
3839
_swap_matrix,
3940
_swap_small,
4041
Swap,
@@ -224,6 +225,10 @@ def test_swap_small(bloq_autotester):
224225
bloq_autotester(_swap_small)
225226

226227

228+
def test_swap_large(bloq_autotester):
229+
bloq_autotester(_swap_large)
230+
231+
227232
def test_swap_symb(bloq_autotester):
228233
if bloq_autotester.check_name == 'serialize':
229234
pytest.skip("Sympy equality with assumptions.")

qualtran/simulation/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,7 @@
2424
tensor_out_inp_shape_from_signature,
2525
tensor_shape_from_signature,
2626
)
27+
from ._tensor_from_classical import (
28+
bloq_to_dense_via_classical_action,
29+
my_tensors_from_classical_action,
30+
)
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import itertools
15+
from typing import Iterable, TYPE_CHECKING
16+
17+
import numpy as np
18+
from numpy.typing import NDArray
19+
20+
if TYPE_CHECKING:
21+
import quimb.tensor as qtn
22+
23+
from qualtran import Bloq, ConnectionT, Register
24+
from qualtran.simulation.classical_sim import ClassicalValT
25+
26+
27+
def _bits_to_classical_reg_data(reg: 'Register', bits: NDArray[np.uint8]) -> 'ClassicalValT':
28+
if reg.shape == ():
29+
return reg.dtype.from_bits([*bits.flat])
30+
return reg.dtype.from_bits_array(np.reshape(bits, reg.shape + (reg.dtype.num_qubits,)))
31+
32+
33+
def _bloq_to_dense_via_classical_action(bloq: 'Bloq') -> NDArray:
34+
"""Internal method to compute the tensor of a bloq using its classical action.
35+
36+
Args:
37+
bloq: the Bloq
38+
39+
Returns:
40+
an NDArray of shape (2, 2, ...) indexed by the output bits followed by input bits.
41+
"""
42+
left_qubit_counts = tuple(reg.total_bits() for reg in bloq.signature.lefts())
43+
left_qubit_splits = np.cumsum(left_qubit_counts)
44+
45+
n_qubits_left = sum(left_qubit_counts)
46+
n_qubits_right = sum(reg.total_bits() for reg in bloq.signature.rights())
47+
48+
if n_qubits_left + n_qubits_right > 40:
49+
raise ValueError(f"tensor is too large: {n_qubits_left + n_qubits_right} total qubits")
50+
51+
matrix = np.zeros((2,) * (n_qubits_right + n_qubits_left))
52+
53+
for input_t in itertools.product((0, 1), repeat=n_qubits_left):
54+
*inputs_t, last = np.split(input_t, left_qubit_splits)
55+
assert np.size(last) == 0
56+
57+
input_kwargs = {
58+
reg.name: _bits_to_classical_reg_data(reg, bits)
59+
for reg, bits in zip(bloq.signature.lefts(), inputs_t)
60+
}
61+
output_args = bloq.call_classically(**input_kwargs)
62+
63+
if output_args:
64+
output_t = np.concatenate(
65+
[
66+
reg.dtype.to_bits_array(np.asarray(vals)).flat
67+
for reg, vals in zip(bloq.signature.rights(), output_args)
68+
]
69+
)
70+
else:
71+
output_t = np.array([])
72+
73+
matrix[tuple([*np.atleast_1d(output_t), *np.atleast_1d(input_t)])] = 1
74+
75+
return matrix
76+
77+
78+
def bloq_to_dense_via_classical_action(bloq: 'Bloq') -> NDArray:
79+
"""Return a contracted, dense ndarray representing the bloq, using its classical action.
80+
81+
Args:
82+
bloq: The bloq
83+
84+
Raises:
85+
ValueError: if the bloq does not have a classical action.
86+
"""
87+
try:
88+
matrix = _bloq_to_dense_via_classical_action(bloq)
89+
except ValueError as e:
90+
raise ValueError(f"cannot compute tensor for {bloq}: {str(e)}") from e
91+
92+
n_qubits_left = sum(reg.total_bits() for reg in bloq.signature.lefts())
93+
n_qubits_right = sum(reg.total_bits() for reg in bloq.signature.rights())
94+
95+
shape: tuple[int, ...]
96+
if n_qubits_left == 0 and n_qubits_right == 0:
97+
shape = ()
98+
elif n_qubits_left == 0 or n_qubits_right == 0:
99+
shape = (2 ** max(n_qubits_left, n_qubits_right),)
100+
else:
101+
shape = (2**n_qubits_right, 2**n_qubits_left)
102+
103+
return matrix.reshape(shape)
104+
105+
106+
def my_tensors_from_classical_action(
107+
bloq: 'Bloq', incoming: dict[str, 'ConnectionT'], outgoing: dict[str, 'ConnectionT']
108+
) -> list['qtn.Tensor']:
109+
"""Returns the quimb tensors for the bloq derived from its `on_classical_vals` method.
110+
111+
This function has the same signature as `bloq.my_tensors`, and can be used as a
112+
replacement for it when the bloq has a known classical action.
113+
For example:
114+
115+
```py
116+
class ClassicalBloq(Bloq):
117+
...
118+
119+
def on_classical_vals(...):
120+
...
121+
122+
def my_tensors(self, incoming, outgoing):
123+
return my_tensors_from_classical_action(self, incoming, outgoing)
124+
```
125+
"""
126+
import quimb.tensor as qtn
127+
128+
def _signature_to_inds(registers: Iterable['Register'], cxns: dict[str, 'ConnectionT']):
129+
for reg in registers:
130+
for cxn in np.asarray(cxns[reg.name]).flat:
131+
for j in range(reg.dtype.num_qubits):
132+
yield cxn, j
133+
134+
data = _bloq_to_dense_via_classical_action(bloq)
135+
incoming_inds = _signature_to_inds(bloq.signature.lefts(), incoming)
136+
outgoing_inds = _signature_to_inds(bloq.signature.rights(), outgoing)
137+
inds = [*outgoing_inds, *incoming_inds]
138+
139+
return [qtn.Tensor(data=data, inds=inds)]
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pytest
16+
import quimb.tensor as qtn
17+
18+
from qualtran import Bloq, ConnectionT, QAny, QUInt, Signature
19+
from qualtran.bloqs.arithmetic import Add, Xor
20+
from qualtran.bloqs.basic_gates import Toffoli, TwoBitCSwap, XGate
21+
from qualtran.simulation.classical_sim import ClassicalValT
22+
from qualtran.simulation.tensor._tensor_from_classical import (
23+
bloq_to_dense_via_classical_action,
24+
my_tensors_from_classical_action,
25+
)
26+
27+
28+
@pytest.mark.parametrize(
29+
"bloq", [XGate(), TwoBitCSwap(), Toffoli(), Add(QUInt(3)), Xor(QAny(3))], ids=str
30+
)
31+
def test_tensor_consistent_with_classical(bloq: Bloq):
32+
from_classical = bloq_to_dense_via_classical_action(bloq)
33+
from_tensor = bloq.tensor_contract()
34+
35+
np.testing.assert_allclose(from_classical, from_tensor)
36+
37+
38+
class TestClassicalBloq(Bloq):
39+
@property
40+
def signature(self) -> 'Signature':
41+
return Signature.build(a=1, b=1, c=1)
42+
43+
def on_classical_vals(
44+
self, a: 'ClassicalValT', b: 'ClassicalValT', c: 'ClassicalValT'
45+
) -> dict[str, 'ClassicalValT']:
46+
if a == 1 and b == 1:
47+
c = c ^ 1
48+
return {'a': a, 'b': b, 'c': c}
49+
50+
def my_tensors(
51+
self, incoming: dict[str, 'ConnectionT'], outgoing: dict[str, 'ConnectionT']
52+
) -> list['qtn.Tensor']:
53+
return my_tensors_from_classical_action(self, incoming, outgoing)
54+
55+
56+
def test_my_tensors_from_classical_action():
57+
bloq = TestClassicalBloq()
58+
59+
expected_tensor = Toffoli().tensor_contract()
60+
actual_tensor = bloq.tensor_contract()
61+
np.testing.assert_allclose(actual_tensor, expected_tensor)

0 commit comments

Comments
 (0)