Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions cirq-core/cirq/sim/classical_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,24 @@ class ClassicalBasisSimState(SimulationState[ClassicalBasisState]):

def __init__(
self,
initial_state: int | list[int] = 0,
initial_state: int | list[int] | tuple[int, ...] = 0,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
initial_state: int | list[int] | tuple[int, ...] = 0,
initial_state: int | Sequence[int] = 0,

qubits: Sequence[cirq.Qid] | None = None,
classical_data: cirq.ClassicalDataStore | None = None,
):
"""Initializes the ClassicalBasisSimState object.

Args:
qubits: The qubits to simulate.
initial_state: The initial state for the simulation.
initial_state: The initial state for the simulation. Accepts int, list[int],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
initial_state: The initial state for the simulation. Accepts int, list[int],
initial_state: The initial state for the simulation. Accepts int or a sequence of int.

tuple[int], or np.ndarray.
classical_data: The classical data container for the simulation.

Raises:
ValueError: If qubits not provided and initial_state is int.
If initial_state is not an int, list[int], or np.ndarray.
If initial_state is not an int, list[int], tuple[int], or np.ndarray.
If initial_state is a np.ndarray and its shape is not 1-dimensional.
If gate is not one of X, SWAP, QubitPermutationGate, a controlled version
of X or SWAP, or a measurement.

An initial_state value of type integer is parsed in big endian order.
"""
Expand All @@ -100,10 +104,16 @@ def __init__(
state = ClassicalBasisState(
big_endian_int_to_bits(initial_state, bit_count=len(qubits))
)
elif isinstance(initial_state, (list, np.ndarray)):
state = ClassicalBasisState(initial_state)
elif isinstance(initial_state, np.ndarray):
if initial_state.ndim != 1:
raise ValueError(
f'initial_state must be 1-dimensional, got shape {initial_state.shape}'
)
state = ClassicalBasisState(list(initial_state))
elif isinstance(initial_state, (list, tuple)):
state = ClassicalBasisState(list(initial_state))
else:
raise ValueError('initial_state must be an int or list[int] or np.ndarray')
raise ValueError('initial_state must be an int, list[int], tuple[int], or np.ndarray')
super().__init__(state=state, qubits=qubits, classical_data=classical_data)

def _act_on_fallback_(self, action, qubits: Sequence[cirq.Qid], allow_decompose: bool = True):
Expand All @@ -116,14 +126,7 @@ def _act_on_fallback_(self, action, qubits: Sequence[cirq.Qid], allow_decompose:

Returns:
True if the operation was applied successfully.

Raises:
ValueError: If initial_state shape for type np.ndarray is not equal to 1.
If gate is not one of X, SWAP, a controlled version of X or SWAP,
or a measurement.
"""
if isinstance(self._state.basis, np.ndarray) and len(self._state.basis.shape) != 1:
raise ValueError('initial_state shape for type np.ndarray is not equal to 1')
gate = action.gate if isinstance(action, ops.Operation) else action
mapped_qubits = [self.qubit_map[i] for i in qubits]

Expand Down Expand Up @@ -152,9 +155,15 @@ def _act_on_fallback_(self, action, qubits: Sequence[cirq.Qid], allow_decompose:
elif gate == ops.TOFFOLI:
c1, c2, q = mapped_qubits
self._state.basis[q] ^= self._state.basis[c1] & self._state.basis[c2]
elif isinstance(gate, ops.QubitPermutationGate):
perm = gate.permutation
basis = self._state.basis
original_values = [basis[mapped_qubits[i]] for i in range(len(mapped_qubits))]
for i, q in enumerate(mapped_qubits):
basis[q] = original_values[perm[i]]
else:
raise ValueError(
f'{gate} is not one of X, SWAP; a controlled version '
f'{gate} is not one of X, SWAP, QubitPermutationGate; a controlled version '
'of X or SWAP; or a measurement'
)
return True
Expand Down
25 changes: 21 additions & 4 deletions cirq-core/cirq/sim/classical_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ def test_Swap():
np.testing.assert_equal(results, expected_results)


def test_qubit_permutation_gate():
q0, q1, q2 = cirq.LineQubit.range(3)
perm_gate = cirq.QubitPermutationGate([2, 0, 1])
circuit = cirq.Circuit(perm_gate(q0, q1, q2), cirq.measure(q0, q1, q2, key='key'))
sim = cirq.ClassicalStateSimulator()
result = sim.simulate(circuit, initial_state=[1, 0, 1])
np.testing.assert_equal(result.measurements['key'], [1, 1, 0])


def test_CCNOT():
q0, q1, q2 = cirq.LineQubit.range(3)
circuit = cirq.Circuit()
Expand Down Expand Up @@ -209,6 +218,14 @@ def test_multiple_gates_order():
np.testing.assert_equal(results, expected_results)


def test_tuple_initial_state():
q0, q1, q2 = cirq.LineQubit.range(3)
circuit = cirq.Circuit(cirq.X(q0), cirq.measure(q0, q1, q2, key='key'))
sim = cirq.ClassicalStateSimulator()
result = sim.simulate(circuit, initial_state=(0, 1, 0))
np.testing.assert_equal(result.measurements['key'], [1, 1, 0])


def test_param_resolver():
gate = cirq.CNOT ** sympy.Symbol('t')
q0, q1 = cirq.LineQubit.range(2)
Expand Down Expand Up @@ -333,11 +350,11 @@ def test_create_invalid_partial_simulation_state_from_np():
qs = cirq.LineQubit.range(2)
classical_data = cirq.value.ClassicalDataDictionaryStore()
sim = cirq.ClassicalStateSimulator()
sim_state = sim._create_partial_simulation_state(
initial_state=initial_state, qubits=qs, classical_data=classical_data
)

with pytest.raises(ValueError):
sim_state._act_on_fallback_(action=cirq.CX, qubits=qs)
sim._create_partial_simulation_state(
initial_state=initial_state, qubits=qs, classical_data=classical_data
)


def test_noise_model():
Expand Down