Skip to content

Refactor simulator RNG handling #6944

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
@@ -379,7 +379,7 @@ def to_numpy(self) -> np.ndarray:
"""An alias for the state vector."""
return self.state_vector()

def apply_op(self, op: Any, axes: Sequence[int], prng: np.random.RandomState):
def apply_op(self, op: Any, axes: Sequence[int], prng: np.random.Generator):
"""Applies a unitary operation, mutating the object to represent the new state.

op:
@@ -481,7 +481,7 @@ def estimation_stats(self):
}

def _measure(
self, axes: Sequence[int], prng: np.random.RandomState, collapse_state_vector=True
self, axes: Sequence[int], prng: np.random.Generator, collapse_state_vector=True
) -> List[int]:
results: List[int] = []

@@ -565,7 +565,7 @@ def __init__(
self,
*,
qubits: Sequence['cirq.Qid'],
prng: np.random.RandomState,
prng: Union[np.random.Generator, np.random.RandomState],
simulation_options: MPSOptions = MPSOptions(),
grouping: Optional[Dict['cirq.Qid', int]] = None,
initial_state: int = 0,
Original file line number Diff line number Diff line change
@@ -184,7 +184,7 @@ def random_rotations_between_two_qubit_circuit(
q1: 'cirq.Qid',
depth: int,
two_qubit_op_factory: Callable[
['cirq.Qid', 'cirq.Qid', 'np.random.RandomState'], 'cirq.OP_TREE'
['cirq.Qid', 'cirq.Qid', 'np.random.Generator'], 'cirq.OP_TREE'
] = lambda a, b, _: ops.CZPowGate()(a, b),
single_qubit_gates: Sequence['cirq.Gate'] = (
ops.X**0.5,
@@ -354,7 +354,7 @@ def _get_random_combinations(

combinations_by_layer = []
for pairs, layer in pair_gen:
combinations = rs.randint(0, n_library_circuits, size=(n_combinations, len(pairs)))
combinations = rs.integers(0, n_library_circuits, size=(n_combinations, len(pairs)))
combinations_by_layer.append(
CircuitLibraryCombination(layer=layer, combinations=combinations, pairs=pairs)
)
@@ -553,7 +553,7 @@ def random_rotations_between_grid_interaction_layers_circuit(
depth: int,
*, # forces keyword arguments
two_qubit_op_factory: Callable[
['cirq.GridQubit', 'cirq.GridQubit', 'np.random.RandomState'], 'cirq.OP_TREE'
['cirq.GridQubit', 'cirq.GridQubit', 'np.random.Generator'], 'cirq.OP_TREE'
] = lambda a, b, _: ops.CZPowGate()(a, b),
pattern: Sequence[GridInteractionLayer] = GRID_STAGGERED_PATTERN,
single_qubit_gates: Sequence['cirq.Gate'] = (
@@ -641,7 +641,7 @@ def __init__(
self,
qubits: Sequence['cirq.Qid'],
single_qubit_gates: Sequence['cirq.Gate'],
prng: 'np.random.RandomState',
prng: 'np.random.Generator',
) -> None:
self.qubits = qubits
self.single_qubit_gates = single_qubit_gates
@@ -651,9 +651,9 @@ def new_layer(self, previous_single_qubit_layer: 'cirq.Moment') -> 'cirq.Moment'
def random_gate(qubit: 'cirq.Qid') -> 'cirq.Gate':
excluded_op = previous_single_qubit_layer.operation_at(qubit)
excluded_gate = excluded_op.gate if excluded_op is not None else None
g = self.single_qubit_gates[self.prng.randint(0, len(self.single_qubit_gates))]
g = self.single_qubit_gates[self.prng.integers(0, len(self.single_qubit_gates))]
while g is excluded_gate:
g = self.single_qubit_gates[self.prng.randint(0, len(self.single_qubit_gates))]
g = self.single_qubit_gates[self.prng.integers(0, len(self.single_qubit_gates))]
return g

return circuits.Moment(random_gate(q).on(q) for q in self.qubits)
@@ -673,7 +673,7 @@ def new_layer(self, previous_single_qubit_layer: 'cirq.Moment') -> 'cirq.Moment'
def _single_qubit_gates_arg_to_factory(
single_qubit_gates: Sequence['cirq.Gate'],
qubits: Sequence['cirq.Qid'],
prng: 'np.random.RandomState',
prng: 'np.random.Generator',
) -> _SingleQubitLayerFactory:
"""Parse the `single_qubit_gates` argument for circuit generation functions.

@@ -690,10 +690,10 @@ def _single_qubit_gates_arg_to_factory(
def _two_qubit_layer(
coupled_qubit_pairs: List[GridQubitPairT],
two_qubit_op_factory: Callable[
['cirq.GridQubit', 'cirq.GridQubit', 'np.random.RandomState'], 'cirq.OP_TREE'
['cirq.GridQubit', 'cirq.GridQubit', 'np.random.Generator'], 'cirq.OP_TREE'
],
layer: GridInteractionLayer,
prng: 'np.random.RandomState',
prng: 'np.random.Generator',
) -> Iterator['cirq.OP_TREE']:
for a, b in coupled_qubit_pairs:
if (a, b) in layer or (b, a) in layer:
Original file line number Diff line number Diff line change
@@ -52,15 +52,15 @@ def test_random_rotation_between_two_qubit_circuit():
"""\
0 1
│ │
Y^0.5 X^0.5
PhX(0.25)^0.5 Y^0.5
│ │
@─────────────@
│ │
PhX(0.25)^0.5 Y^0.5
X^0.5 PhX(0.25)^0.5
│ │
@─────────────@
│ │
Y^0.5 X^0.5
Y^0.5 Y^0.5
│ │
@─────────────@
│ │
@@ -361,7 +361,7 @@ def test_random_rotations_between_grid_interaction_layers(
qubits: Iterable[cirq.GridQubit],
depth: int,
two_qubit_op_factory: Callable[
[cirq.GridQubit, cirq.GridQubit, np.random.RandomState], cirq.OP_TREE
[cirq.GridQubit, cirq.GridQubit, np.random.Generator], cirq.OP_TREE
],
pattern: Sequence[GridInteractionLayer],
single_qubit_gates: Sequence[cirq.Gate],
2 changes: 1 addition & 1 deletion cirq-core/cirq/linalg/decompositions_test.py
Original file line number Diff line number Diff line change
@@ -597,7 +597,7 @@ def _random_two_qubit_unitaries(num_samples: int, random_state: 'cirq.RANDOM_STA

prng = value.parse_random_state(random_state)
# Generate the non-local part by explict matrix exponentiation.
kak_vecs = prng.rand(num_samples, 3) * np.pi
kak_vecs = prng.random((num_samples, 3)) * np.pi
gens = np.einsum('...a,abc->...bc', kak_vecs, _kak_gens)
evals, evecs = np.linalg.eigh(gens)
A = np.einsum('...ab,...b,...cb', evecs, np.exp(1j * evals), evecs.conj())
2 changes: 1 addition & 1 deletion cirq-core/cirq/protocols/act_on_protocol_test.py
Original file line number Diff line number Diff line change
@@ -30,7 +30,7 @@ def measure(self, axes, seed=None):

class ExampleSimulationState(cirq.SimulationState):
def __init__(self, fallback_result: Any = NotImplemented):
super().__init__(prng=np.random.RandomState(), state=ExampleQuantumState())
super().__init__(prng=np.random.default_rng(), state=ExampleQuantumState())
self.fallback_result = fallback_result

def _act_on_fallback_(
4 changes: 2 additions & 2 deletions cirq-core/cirq/qis/clifford_tableau.py
Original file line number Diff line number Diff line change
@@ -509,7 +509,7 @@ def destabilizers(self) -> List['cirq.DensePauliString']:
generators above generate the full Pauli group on n qubits."""
return [self._row_to_dense_pauli(i) for i in range(self.n)]

def _measure(self, q, prng: np.random.RandomState) -> int:
def _measure(self, q, prng: np.random.Generator) -> int:
"""Performs a projective measurement on the q'th qubit.

Returns: the result (0 or 1) of the measurement.
@@ -544,7 +544,7 @@ def _measure(self, q, prng: np.random.RandomState) -> int:

self.zs[p, q] = True

self.rs[p] = bool(prng.randint(2))
self.rs[p] = bool(prng.integers(2))

return int(self.rs[p])

4 changes: 2 additions & 2 deletions cirq-core/cirq/sim/clifford/clifford_simulator.py
Original file line number Diff line number Diff line change
@@ -242,7 +242,7 @@ def state_vector(self):

def apply_unitary(self, op: 'cirq.Operation'):
ch_form_args = clifford.StabilizerChFormSimulationState(
prng=np.random.RandomState(), qubits=self.qubit_map.keys(), initial_state=self.ch_form
prng=np.random.default_rng(), qubits=self.qubit_map.keys(), initial_state=self.ch_form
)
try:
act_on(op, ch_form_args)
@@ -254,7 +254,7 @@ def apply_measurement(
self,
op: 'cirq.Operation',
measurements: Dict[str, List[int]],
prng: np.random.RandomState,
prng: Union[np.random.Generator, np.random.RandomState],
collapse_state_vector=True,
):
if not isinstance(op.gate, cirq.MeasurementGate):
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@
"""A protocol for implementing high performance clifford tableau evolutions
for Clifford Simulator."""

from typing import Optional, Sequence, TYPE_CHECKING
from typing import Optional, Sequence, TYPE_CHECKING, Union

import numpy as np

@@ -31,7 +31,7 @@ class CliffordTableauSimulationState(StabilizerSimulationState[clifford_tableau.
def __init__(
self,
tableau: 'cirq.CliffordTableau',
prng: Optional[np.random.RandomState] = None,
prng: Optional[Union[np.random.Generator, np.random.RandomState]] = None,
qubits: Optional[Sequence['cirq.Qid']] = None,
classical_data: Optional['cirq.ClassicalDataStore'] = None,
):
Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@ class StabilizerChFormSimulationState(
def __init__(
self,
*,
prng: Optional[np.random.RandomState] = None,
prng: Optional[Union[np.random.Generator, np.random.RandomState]] = None,
qubits: Optional[Sequence['cirq.Qid']] = None,
initial_state: Union[int, 'cirq.StabilizerStateChForm'] = 0,
classical_data: Optional['cirq.ClassicalDataStore'] = None,
Original file line number Diff line number Diff line change
@@ -41,7 +41,7 @@ def __init__(
self,
*,
state: TStabilizerState,
prng: Optional[np.random.RandomState] = None,
prng: Optional[Union[np.random.Generator, np.random.RandomState]] = None,
qubits: Optional[Sequence['cirq.Qid']] = None,
classical_data: Optional['cirq.ClassicalDataStore'] = None,
):
4 changes: 2 additions & 2 deletions cirq-core/cirq/sim/clifford/stabilizer_state_ch_form.py
Original file line number Diff line number Diff line change
@@ -236,7 +236,7 @@ def to_state_vector(self) -> np.ndarray:

return arr

def _measure(self, q, prng: np.random.RandomState) -> int:
def _measure(self, q, prng: np.random.Generator) -> int:
"""Measures the q'th qubit.

Reference: Section 4.1 "Simulating measurements"
@@ -246,7 +246,7 @@ def _measure(self, q, prng: np.random.RandomState) -> int:
w = self.s.copy()
for i, v_i in enumerate(self.v):
if v_i == 1:
w[i] = bool(prng.randint(2))
w[i] = bool(prng.integers(2))
x_i = sum(w & self.G[q, :]) % 2
# Project the state to the above measurement outcome.
self.project_Z(q, x_i)
2 changes: 1 addition & 1 deletion cirq-core/cirq/sim/density_matrix_simulation_state.py
Original file line number Diff line number Diff line change
@@ -247,7 +247,7 @@ def __init__(
self,
*,
available_buffer: Optional[List[np.ndarray]] = None,
prng: Optional[np.random.RandomState] = None,
prng: Optional[Union[np.random.Generator, np.random.RandomState]] = None,
qubits: Optional[Sequence['cirq.Qid']] = None,
initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0,
dtype: Type[np.complexfloating] = np.complex64,
9 changes: 6 additions & 3 deletions cirq-core/cirq/sim/simulation_state.py
Original file line number Diff line number Diff line change
@@ -27,6 +27,7 @@
TypeVar,
TYPE_CHECKING,
Tuple,
Union,
)
from typing_extensions import Self

@@ -49,7 +50,7 @@ def __init__(
self,
*,
state: TState,
prng: Optional[np.random.RandomState] = None,
prng: Optional[Union[np.random.Generator, np.random.RandomState]] = None,
qubits: Optional[Sequence['cirq.Qid']] = None,
classical_data: Optional['cirq.ClassicalDataStore'] = None,
):
@@ -70,12 +71,14 @@ def __init__(
classical_data = classical_data or value.ClassicalDataDictionaryStore()
super().__init__(qubits=qubits, classical_data=classical_data)
if prng is None:
prng = cast(np.random.RandomState, np.random)
prng = np.random.default_rng()
elif isinstance(prng, np.random.RandomState):
prng = np.random.default_rng(prng._bit_generator)
self._prng = prng
self._state = state

@property
def prng(self) -> np.random.RandomState:
def prng(self) -> np.random.Generator:
return self._prng

def measure(
29 changes: 23 additions & 6 deletions cirq-core/cirq/sim/simulator_base.py
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@
Type,
TypeVar,
TYPE_CHECKING,
Union,
)

import numpy as np
@@ -93,21 +94,27 @@ def __init__(
*,
dtype: Type[np.complexfloating] = np.complex64,
noise: 'cirq.NOISE_MODEL_LIKE' = None,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
seed: Optional[Union[int, np.random.Generator, np.random.RandomState]] = None,
split_untangled_states: bool = False,
):
"""Initializes the simulator.

Args:
dtype: The `numpy.dtype` used by the simulation.
noise: A noise model to apply while simulating.
seed: The random seed to use for this simulator.
seed: The random seed or generator to use for this simulator.
split_untangled_states: If True, optimizes simulation by running
unentangled qubit sets independently and merging those states
at the end.
"""
self._dtype = dtype
self._prng = value.parse_random_state(seed)
if isinstance(seed, np.random.RandomState):
# Convert RandomState to Generator for backward compatibility
self._prng = np.random.Generator(seed._bit_generator)
elif isinstance(seed, np.random.Generator):
self._prng = seed
else:
self._prng = np.random.default_rng(seed)
self._noise = devices.NoiseModel.from_noise_model_like(noise)
self._split_untangled_states = split_untangled_states

@@ -228,6 +235,7 @@ def _run(
circuit: 'cirq.AbstractCircuit',
param_resolver: 'cirq.ParamResolver',
repetitions: int,
rng: Optional[np.random.Generator] = None,
) -> Dict[str, np.ndarray]:
"""See definition in `cirq.SimulatesSamples`."""
param_resolver = param_resolver or study.ParamResolver({})
@@ -254,7 +262,10 @@ def _run(
assert step_result is not None
measurement_ops = [cast(ops.GateOperation, op) for op in general_ops]
return step_result.sample_measurement_ops(
measurement_ops, repetitions, seed=self._prng, _allow_repeated=True
measurement_ops,
repetitions,
seed=rng if rng is not None else self._prng,
_allow_repeated=True,
)

records: Dict['cirq.MeasurementKey', List[Sequence[Sequence[int]]]] = {}
@@ -395,9 +406,15 @@ def sample(
self,
qubits: List['cirq.Qid'],
repetitions: int = 1,
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
seed: Optional[Union[int, np.random.Generator, np.random.RandomState]] = None,
) -> np.ndarray:
return self._sim_state.sample(qubits, repetitions, seed)
if isinstance(seed, np.random.RandomState):
rng = np.random.Generator(seed._bit_generator)
elif isinstance(seed, np.random.Generator):
rng = seed
else:
rng = np.random.default_rng(seed)
return self._sim_state.sample(qubits, repetitions, rng)


class SimulationTrialResultBase(
26 changes: 26 additions & 0 deletions cirq-core/cirq/sim/simulator_base_test.py
Original file line number Diff line number Diff line change
@@ -434,3 +434,29 @@ def test_inhomogeneous_measurement_count_padding():
results = sim.run(c, repetitions=10)
for i in range(10):
assert np.sum(results.records['m'][i, :, :]) == 1


def test_run_with_custom_rng():
sim = cirq.Simulator()
circuit = cirq.Circuit(cirq.H(q0), cirq.measure(q0))
rng1 = np.random.default_rng(seed=1234)
rng2 = np.random.default_rng(seed=1234)

result1 = sim._run(circuit, param_resolver=cirq.ParamResolver({}), repetitions=10, rng=rng1)
result2 = sim._run(circuit, param_resolver=cirq.ParamResolver({}), repetitions=10, rng=rng2)
assert np.array_equal(result1['q(0)'], result2['q(0)'])

rng3 = np.random.default_rng(seed=5678)
result3 = sim._run(circuit, param_resolver=cirq.ParamResolver({}), repetitions=10, rng=rng3)
assert not np.array_equal(result1['q(0)'], result3['q(0)'])


def test_run_with_explicit_rng_override():
sim1 = cirq.Simulator(seed=1234)
sim2 = cirq.Simulator(seed=5678)
circuit = cirq.Circuit(cirq.H(q0), cirq.measure(q0))
rng = np.random.default_rng(1234)

result1 = sim1._run(circuit, cirq.ParamResolver({}), repetitions=10)
result2 = sim2._run(circuit, cirq.ParamResolver({}), repetitions=10, rng=rng)
assert np.array_equal(result1['q(0)'], result2['q(0)'])
2 changes: 1 addition & 1 deletion cirq-core/cirq/sim/state_vector_simulation_state.py
Original file line number Diff line number Diff line change
@@ -321,7 +321,7 @@ def __init__(
self,
*,
available_buffer: Optional[np.ndarray] = None,
prng: Optional[np.random.RandomState] = None,
prng: Optional[Union[np.random.Generator, np.random.RandomState]] = None,
qubits: Optional[Sequence['cirq.Qid']] = None,
initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0,
dtype: Type[np.complexfloating] = np.complex64,
6 changes: 3 additions & 3 deletions cirq-core/cirq/testing/consistent_act_on.py
Original file line number Diff line number Diff line change
@@ -52,7 +52,7 @@ def state_vector_has_stabilizer(state_vector: np.ndarray, stabilizer: DensePauli
args = state_vector_simulation_state.StateVectorSimulationState(
available_buffer=np.empty_like(state_vector),
qubits=qubits,
prng=np.random.RandomState(),
prng=np.random.default_rng(),
initial_state=state_vector.copy(),
dtype=complex_dtype,
)
@@ -163,7 +163,7 @@ def _final_clifford_tableau(

tableau = clifford_tableau.CliffordTableau(len(qubit_map))
args = clifford_tableau_simulation_state.CliffordTableauSimulationState(
tableau=tableau, qubits=list(qubit_map.keys()), prng=np.random.RandomState()
tableau=tableau, qubits=list(qubit_map.keys()), prng=np.random.default_rng()
)
for op in circuit.all_operations():
try:
@@ -192,7 +192,7 @@ def _final_stabilizer_state_ch_form(
stabilizer_ch_form = stabilizer_state_ch_form.StabilizerStateChForm(len(qubit_map))
args = stabilizer_ch_form_simulation_state.StabilizerChFormSimulationState(
qubits=list(qubit_map.keys()),
prng=np.random.RandomState(),
prng=np.random.default_rng(),
initial_state=stabilizer_ch_form,
)
for op in circuit.all_operations():
14 changes: 7 additions & 7 deletions cirq-core/cirq/testing/lin_alg_utils.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@
# limitations under the License.
"""A testing class with utilities for checking linear algebra."""

from typing import Optional, TYPE_CHECKING
from typing import Optional, TYPE_CHECKING, Union

import numpy as np

@@ -39,8 +39,8 @@ def random_superposition(
"""
random_state = value.parse_random_state(random_state)

state_vector = random_state.randn(dim).astype(complex)
state_vector += 1j * random_state.randn(dim)
state_vector = random_state.random(dim).astype(complex)
state_vector += 1j * random_state.random(dim)
state_vector /= np.linalg.norm(state_vector)
return state_vector

@@ -63,7 +63,7 @@ def random_density_matrix(
"""
random_state = value.parse_random_state(random_state)

mat = random_state.randn(dim, dim) + 1j * random_state.randn(dim, dim)
mat = random_state.random((dim, dim)) + 1j * random_state.random((dim, dim))
mat = mat @ mat.T.conj()
return mat / np.trace(mat)

@@ -86,7 +86,7 @@ def random_unitary(
"""
random_state = value.parse_random_state(random_state)

z = random_state.randn(dim, dim) + 1j * random_state.randn(dim, dim)
z = random_state.random((dim, dim)) + 1j * random_state.random((dim, dim))
q, r = np.linalg.qr(z)
d = np.diag(r)
return q * (d / abs(d))
@@ -112,14 +112,14 @@ def random_orthogonal(
"""
random_state = value.parse_random_state(random_state)

m = random_state.randn(dim, dim)
m = random_state.random((dim, dim))
q, r = np.linalg.qr(m)
d = np.diag(r)
return q * (d / abs(d))


def random_special_unitary(
dim: int, *, random_state: Optional[np.random.RandomState] = None
dim: int, *, random_state: Optional[Union[np.random.Generator, np.random.RandomState]] = None
) -> np.ndarray:
"""Returns a random special unitary distributed with Haar measure.
6 changes: 3 additions & 3 deletions cirq-core/cirq/testing/random_circuit.py
Original file line number Diff line number Diff line change
@@ -113,10 +113,10 @@ def random_circuit(
operations = []
free_qubits = set(qubits)
while len(free_qubits) >= max_arity:
gate, arity = gate_arity_pairs[prng.randint(num_gates)]
gate, arity = gate_arity_pairs[prng.integers(num_gates)]
op_qubits = prng.choice(sorted(free_qubits), size=arity, replace=False)
free_qubits.difference_update(op_qubits)
if prng.rand() <= op_density:
if prng.random() <= op_density:
operations.append(gate(*op_qubits))
moments.append(circuits.Moment(operations))

@@ -147,7 +147,7 @@ def random_two_qubit_circuit_with_czs(
q1 = ops.NamedQubit('q1') if q1 is None else q1

def random_one_qubit_gate():
return ops.PhasedXPowGate(phase_exponent=prng.rand(), exponent=prng.rand())
return ops.PhasedXPowGate(phase_exponent=prng.random(), exponent=prng.random())

def one_cz():
return [ops.CZ.on(q0, q1), random_one_qubit_gate().on(q0), random_one_qubit_gate().on(q1)]
Original file line number Diff line number Diff line change
@@ -115,7 +115,7 @@ def decompose_clifford_tableau_to_operations(
t: qis.CliffordTableau = clifford_tableau.copy()
operations: List[ops.Operation] = []
args = sim.CliffordTableauSimulationState(
tableau=t, qubits=qubits, prng=np.random.RandomState()
tableau=t, qubits=qubits, prng=np.random.default_rng()
)

_X_with_ops = functools.partial(_X, args=args, operations=operations, qubits=qubits)
Original file line number Diff line number Diff line change
@@ -44,7 +44,7 @@ def _single_qubit_unitary(
def random_qubit_unitary(
shape: Sequence[int] = (),
randomize_global_phase: bool = False,
rng: Optional[np.random.RandomState] = None,
rng: Optional[Union[np.random.Generator, np.random.RandomState]] = None,
) -> np.ndarray:
"""Random qubit unitary distributed over the Haar measure.
@@ -61,15 +61,15 @@ def random_qubit_unitary(
"""
real_rng = random_state.parse_random_state(rng)

theta = np.arcsin(np.sqrt(real_rng.rand(*shape)))
phi_d = real_rng.rand(*shape) * np.pi * 2
phi_o = real_rng.rand(*shape) * np.pi * 2
theta = np.arcsin(np.sqrt(real_rng.random(*shape)))
phi_d = real_rng.random(*shape) * np.pi * 2
phi_o = real_rng.random(*shape) * np.pi * 2

out = _single_qubit_unitary(theta, phi_d, phi_o)

if randomize_global_phase:
out = np.moveaxis(out, (-2, -1), (0, 1))
out *= np.exp(1j * np.pi * 2 * real_rng.rand(*shape))
out *= np.exp(1j * np.pi * 2 * real_rng.random(*shape))
out = np.moveaxis(out, (0, 1), (-2, -1))
return out

12 changes: 8 additions & 4 deletions cirq-core/cirq/value/random_state.py
Original file line number Diff line number Diff line change
@@ -37,7 +37,7 @@
)


def parse_random_state(random_state: RANDOM_STATE_OR_SEED_LIKE) -> np.random.RandomState:
def parse_random_state(random_state: RANDOM_STATE_OR_SEED_LIKE) -> np.random.Generator:
"""Interpret an object as a pseudorandom number generator.
If `random_state` is None, returns the module `np.random`.
@@ -53,8 +53,12 @@ def parse_random_state(random_state: RANDOM_STATE_OR_SEED_LIKE) -> np.random.Ran
The pseudorandom number generator object.
"""
if random_state is None:
return cast(np.random.RandomState, np.random)
return np.random.default_rng()
elif isinstance(random_state, int):
return np.random.RandomState(random_state)
return np.random.default_rng(random_state)
elif isinstance(random_state, np.random.RandomState):
return np.random.default_rng(random_state.get_state()[1][0]) # type: ignore[index]
elif isinstance(random_state, np.random.Generator):
return random_state
else:
return cast(np.random.RandomState, random_state)
return np.random.default_rng()
19 changes: 2 additions & 17 deletions cirq-core/cirq/value/random_state_test.py
Original file line number Diff line number Diff line change
@@ -18,27 +18,12 @@


def test_parse_random_state():
global_state = np.random.get_state()

def rand(prng):
np.random.set_state(global_state)
return prng.rand()

prngs = [
np.random,
cirq.value.parse_random_state(np.random),
cirq.value.parse_random_state(None),
]
vals = [rand(prng) for prng in prngs]
eq = cirq.testing.EqualsTester()
eq.add_equality_group(*vals)

seed = np.random.randint(2**31)
prngs = [
np.random.RandomState(seed),
np.random.default_rng(seed),
cirq.value.parse_random_state(np.random.RandomState(seed)),
cirq.value.parse_random_state(seed),
]
vals = [prng.rand() for prng in prngs]
vals = [prng.random() for prng in prngs]
eq = cirq.testing.EqualsTester()
eq.add_equality_group(*vals)