29
29
Type ,
30
30
TypeVar ,
31
31
TYPE_CHECKING ,
32
+ Union ,
32
33
)
33
34
34
35
import numpy as np
@@ -93,21 +94,27 @@ def __init__(
93
94
* ,
94
95
dtype : Type [np .complexfloating ] = np .complex64 ,
95
96
noise : 'cirq.NOISE_MODEL_LIKE' = None ,
96
- seed : 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None ,
97
+ seed : Optional [ Union [ int , np . random . Generator , np . random . RandomState ]] = None ,
97
98
split_untangled_states : bool = False ,
98
99
):
99
100
"""Initializes the simulator.
100
101
101
102
Args:
102
103
dtype: The `numpy.dtype` used by the simulation.
103
104
noise: A noise model to apply while simulating.
104
- seed: The random seed to use for this simulator.
105
+ seed: The random seed or generator to use for this simulator.
105
106
split_untangled_states: If True, optimizes simulation by running
106
107
unentangled qubit sets independently and merging those states
107
108
at the end.
108
109
"""
109
110
self ._dtype = dtype
110
- self ._prng = value .parse_random_state (seed )
111
+ if isinstance (seed , np .random .RandomState ):
112
+ # Convert RandomState to Generator for backward compatibility
113
+ self ._prng = np .random .default_rng (seed .get_state ()[1 ][0 ])
114
+ elif isinstance (seed , np .random .Generator ):
115
+ self ._prng = seed
116
+ else :
117
+ self ._prng = np .random .default_rng (seed )
111
118
self ._noise = devices .NoiseModel .from_noise_model_like (noise )
112
119
self ._split_untangled_states = split_untangled_states
113
120
@@ -228,6 +235,7 @@ def _run(
228
235
circuit : 'cirq.AbstractCircuit' ,
229
236
param_resolver : 'cirq.ParamResolver' ,
230
237
repetitions : int ,
238
+ rng : Optional [np .random .Generator ] = None ,
231
239
) -> Dict [str , np .ndarray ]:
232
240
"""See definition in `cirq.SimulatesSamples`."""
233
241
param_resolver = param_resolver or study .ParamResolver ({})
@@ -254,7 +262,10 @@ def _run(
254
262
assert step_result is not None
255
263
measurement_ops = [cast (ops .GateOperation , op ) for op in general_ops ]
256
264
return step_result .sample_measurement_ops (
257
- measurement_ops , repetitions , seed = self ._prng , _allow_repeated = True
265
+ measurement_ops ,
266
+ repetitions ,
267
+ seed = rng if rng is not None else self ._prng ,
268
+ _allow_repeated = True ,
258
269
)
259
270
260
271
records : Dict ['cirq.MeasurementKey' , List [Sequence [Sequence [int ]]]] = {}
@@ -395,9 +406,15 @@ def sample(
395
406
self ,
396
407
qubits : List ['cirq.Qid' ],
397
408
repetitions : int = 1 ,
398
- seed : 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None ,
409
+ seed : Optional [ Union [ int , np . random . Generator , np . random . RandomState ]] = None ,
399
410
) -> np .ndarray :
400
- return self ._sim_state .sample (qubits , repetitions , seed )
411
+ if isinstance (seed , np .random .RandomState ):
412
+ rng = np .random .default_rng (seed .get_state ()[1 ][0 ])
413
+ elif isinstance (seed , np .random .Generator ):
414
+ rng = seed
415
+ else :
416
+ rng = np .random .default_rng (seed )
417
+ return self ._sim_state .sample (qubits , repetitions , rng )
401
418
402
419
403
420
class SimulationTrialResultBase (
0 commit comments