Skip to content

Commit c321005

Browse files
jpjustinpan0
jp
authored andcommitted
Refactor simulator RNG handling
1 parent c1465c1 commit c321005

File tree

2 files changed

+49
-6
lines changed

2 files changed

+49
-6
lines changed

cirq-core/cirq/sim/simulator_base.py

+23-6
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
Type,
3030
TypeVar,
3131
TYPE_CHECKING,
32+
Union,
3233
)
3334

3435
import numpy as np
@@ -93,21 +94,27 @@ def __init__(
9394
*,
9495
dtype: Type[np.complexfloating] = np.complex64,
9596
noise: 'cirq.NOISE_MODEL_LIKE' = None,
96-
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
97+
seed: Optional[Union[int, np.random.Generator, np.random.RandomState]] = None,
9798
split_untangled_states: bool = False,
9899
):
99100
"""Initializes the simulator.
100101
101102
Args:
102103
dtype: The `numpy.dtype` used by the simulation.
103104
noise: A noise model to apply while simulating.
104-
seed: The random seed to use for this simulator.
105+
seed: The random seed or generator to use for this simulator.
105106
split_untangled_states: If True, optimizes simulation by running
106107
unentangled qubit sets independently and merging those states
107108
at the end.
108109
"""
109110
self._dtype = dtype
110-
self._prng = value.parse_random_state(seed)
111+
if isinstance(seed, np.random.RandomState):
112+
# Convert RandomState to Generator for backward compatibility
113+
self._prng = np.random.default_rng(seed.get_state()[1][0])
114+
elif isinstance(seed, np.random.Generator):
115+
self._prng = seed
116+
else:
117+
self._prng = np.random.default_rng(seed)
111118
self._noise = devices.NoiseModel.from_noise_model_like(noise)
112119
self._split_untangled_states = split_untangled_states
113120

@@ -228,6 +235,7 @@ def _run(
228235
circuit: 'cirq.AbstractCircuit',
229236
param_resolver: 'cirq.ParamResolver',
230237
repetitions: int,
238+
rng: Optional[np.random.Generator] = None,
231239
) -> Dict[str, np.ndarray]:
232240
"""See definition in `cirq.SimulatesSamples`."""
233241
param_resolver = param_resolver or study.ParamResolver({})
@@ -254,7 +262,10 @@ def _run(
254262
assert step_result is not None
255263
measurement_ops = [cast(ops.GateOperation, op) for op in general_ops]
256264
return step_result.sample_measurement_ops(
257-
measurement_ops, repetitions, seed=self._prng, _allow_repeated=True
265+
measurement_ops,
266+
repetitions,
267+
seed=rng if rng is not None else self._prng,
268+
_allow_repeated=True,
258269
)
259270

260271
records: Dict['cirq.MeasurementKey', List[Sequence[Sequence[int]]]] = {}
@@ -395,9 +406,15 @@ def sample(
395406
self,
396407
qubits: List['cirq.Qid'],
397408
repetitions: int = 1,
398-
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
409+
seed: Optional[Union[int, np.random.Generator, np.random.RandomState]] = None,
399410
) -> np.ndarray:
400-
return self._sim_state.sample(qubits, repetitions, seed)
411+
if isinstance(seed, np.random.RandomState):
412+
rng = np.random.default_rng(seed.get_state()[1][0])
413+
elif isinstance(seed, np.random.Generator):
414+
rng = seed
415+
else:
416+
rng = np.random.default_rng(seed)
417+
return self._sim_state.sample(qubits, repetitions, rng)
401418

402419

403420
class SimulationTrialResultBase(

cirq-core/cirq/sim/simulator_base_test.py

+26
Original file line numberDiff line numberDiff line change
@@ -434,3 +434,29 @@ def test_inhomogeneous_measurement_count_padding():
434434
results = sim.run(c, repetitions=10)
435435
for i in range(10):
436436
assert np.sum(results.records['m'][i, :, :]) == 1
437+
438+
439+
def test_run_with_custom_rng():
440+
sim = cirq.Simulator()
441+
circuit = cirq.Circuit(cirq.H(q0), cirq.measure(q0))
442+
rng1 = np.random.default_rng(seed=1234)
443+
rng2 = np.random.default_rng(seed=1234)
444+
445+
result1 = sim._run(circuit, param_resolver=cirq.ParamResolver({}), repetitions=10, rng=rng1)
446+
result2 = sim._run(circuit, param_resolver=cirq.ParamResolver({}), repetitions=10, rng=rng2)
447+
assert np.array_equal(result1['q(0)'], result2['q(0)'])
448+
449+
rng3 = np.random.default_rng(seed=5678)
450+
result3 = sim._run(circuit, param_resolver=cirq.ParamResolver({}), repetitions=10, rng=rng3)
451+
assert not np.array_equal(result1['q(0)'], result3['q(0)'])
452+
453+
454+
def test_run_with_explicit_rng_override():
455+
sim1 = cirq.Simulator(seed=1234)
456+
sim2 = cirq.Simulator(seed=5678)
457+
circuit = cirq.Circuit(cirq.H(q0), cirq.measure(q0))
458+
rng = np.random.default_rng(1234)
459+
460+
result1 = sim1._run(circuit, cirq.ParamResolver({}), repetitions=10)
461+
result2 = sim2._run(circuit, cirq.ParamResolver({}), repetitions=10, rng=rng)
462+
assert np.array_equal(result1['q(0)'], result2['q(0)'])

0 commit comments

Comments
 (0)