Skip to content

Commit 6cf0168

Browse files
committed
Add CSWAP support to ClassicalStateSimulator
1 parent f066e63 commit 6cf0168

File tree

3 files changed

+40
-6
lines changed

3 files changed

+40
-6
lines changed

cirq-core/cirq/sim/classical_simulator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,13 @@ def _act_on_fallback_(self, action, qubits: Sequence[cirq.Qid], allow_decompose:
154154
elif gate == ops.SWAP:
155155
a, b = mapped_qubits
156156
self._state.basis[a], self._state.basis[b] = self._state.basis[b], self._state.basis[a]
157+
elif gate == ops.CSWAP:
158+
c, a, b = mapped_qubits
159+
if self._state.basis[c]:
160+
self._state.basis[a], self._state.basis[b] = (
161+
self._state.basis[b],
162+
self._state.basis[a],
163+
)
157164
elif gate == ops.TOFFOLI:
158165
c1, c2, q = mapped_qubits
159166
self._state.basis[q] ^= self._state.basis[c1] & self._state.basis[c2]

cirq-core/cirq/sim/classical_simulator_test.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525

2626
def test_x_gate() -> None:
27+
"""Tests the X gate."""
2728
q0, q1 = cirq.LineQubit.range(2)
2829
circuit = cirq.Circuit()
2930
circuit.append(cirq.X(q0))
@@ -37,7 +38,8 @@ def test_x_gate() -> None:
3738
np.testing.assert_equal(results, expected_results)
3839

3940

40-
def test_CNOT() -> None:
41+
def test_cnot() -> None:
42+
"""Tests the CNOT gate."""
4143
q0, q1 = cirq.LineQubit.range(2)
4244
circuit = cirq.Circuit()
4345
circuit.append(cirq.X(q0))
@@ -50,7 +52,8 @@ def test_CNOT() -> None:
5052
np.testing.assert_equal(results, expected_results)
5153

5254

53-
def test_Swap() -> None:
55+
def test_swap() -> None:
56+
"""Tests the SWAP gate."""
5457
q0, q1 = cirq.LineQubit.range(2)
5558
circuit = cirq.Circuit()
5659
circuit.append(cirq.X(q0))
@@ -71,6 +74,7 @@ def test_Swap() -> None:
7174
],
7275
)
7376
def test_qubit_permutation_gate(n, perm, state) -> None:
77+
"""Tests the QubitPermutationGate."""
7478
qubits = cirq.LineQubit.range(n)
7579
perm_gate = cirq.QubitPermutationGate(perm)
7680
circuit = cirq.Circuit(perm_gate(*qubits), cirq.measure(*qubits, key='key'))
@@ -83,7 +87,8 @@ def test_qubit_permutation_gate(n, perm, state) -> None:
8387
np.testing.assert_equal(result.measurements['key'], expected)
8488

8589

86-
def test_CCNOT() -> None:
90+
def test_ccnot() -> None:
91+
"""Tests the CCNOT gate."""
8792
q0, q1, q2 = cirq.LineQubit.range(3)
8893
circuit = cirq.Circuit()
8994
circuit.append(cirq.CCNOT(q0, q1, q2))
@@ -108,7 +113,8 @@ def test_CCNOT() -> None:
108113

109114

110115
@pytest.mark.parametrize(['initial_state'], [(list(x),) for x in product([0, 1], repeat=4)])
111-
def test_CCCX(initial_state) -> None:
116+
def test_cccx(initial_state) -> None:
117+
"""Tests the CCCX gate."""
112118
CCCX = cirq.CCNOT.controlled()
113119
qubits = cirq.LineQubit.range(4)
114120

@@ -126,7 +132,8 @@ def test_CCCX(initial_state) -> None:
126132

127133

128134
@pytest.mark.parametrize(['initial_state'], [(list(x),) for x in product([0, 1], repeat=3)])
129-
def test_CSWAP(initial_state) -> None:
135+
def test_controlled_swap(initial_state) -> None:
136+
"""Tests the controlled SWAP gate."""
130137
CSWAP = cirq.SWAP.controlled()
131138
qubits = cirq.LineQubit.range(3)
132139
circuit = cirq.Circuit()
@@ -146,6 +153,25 @@ def test_CSWAP(initial_state) -> None:
146153
np.testing.assert_equal(results, final_state)
147154

148155

156+
def test_cswap() -> None:
157+
"""Tests the CSWAP gate."""
158+
# Specifically test named CSWAP gate, not just controlled(SWAP)
159+
q0, q1, q2 = cirq.LineQubit.range(3)
160+
circuit = cirq.Circuit()
161+
# Control q0=1, so swap q1, q2
162+
circuit.append(cirq.X(q0))
163+
circuit.append(cirq.X(q1))
164+
circuit.append(cirq.CSWAP(q0, q1, q2)) # q0=1 -> swap q1=1, q2=0 -> q1=0, q2=1
165+
circuit.append(cirq.measure((q0, q1, q2), key='key'))
166+
167+
sim: cirq.ClassicalStateSimulator
168+
sim = cirq.ClassicalStateSimulator()
169+
result = sim.run(circuit, repetitions=1)
170+
# Expected: 1, 0, 1
171+
expected = np.array([[[1, 0, 1]]], dtype=np.uint8)
172+
np.testing.assert_equal(result.records['key'], expected)
173+
174+
149175
def test_measurement_gate() -> None:
150176
q0, q1 = cirq.LineQubit.range(2)
151177
circuit = cirq.Circuit()
@@ -324,7 +350,7 @@ def test_create_partial_simulation_state_from_int_with_no_qubits() -> None:
324350
with pytest.raises(ValueError):
325351
sim._create_partial_simulation_state(
326352
initial_state=initial_state,
327-
qubits=qs, # type: ignore[arg-type]
353+
qubits=qs,
328354
classical_data=classical_data,
329355
)
330356

dev_tools/snippets_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
import cirq
6363

6464
DOCS_FOLDER = pathlib.Path(__file__).parent.parent / 'docs'
65+
6566
DEFAULT_STATE: dict[str, Any] = {}
6667

6768

0 commit comments

Comments
 (0)