Skip to content

Commit 9abd28b

Browse files
Merge pull request #40 from hoechenberger/rng
ENH: Allow JSON serialization of RNG
2 parents a0fb0e6 + 02a51e9 commit 9abd28b

File tree

3 files changed

+45
-0
lines changed

3 files changed

+45
-0
lines changed

CHANGES.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
v2019.4
2+
-------
3+
* Allow JSON serialization of random number generator
4+
15
v2019.3
26
-------
37
* Allow to pass a prior when instantiating `QuestPlusWeibull`

questplus/qp.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,10 @@ def to_json(self) -> str:
386386
self_copy.prior = self_copy.prior.to_dict()
387387
self_copy.posterior = self_copy.posterior.to_dict()
388388
self_copy.likelihoods = self_copy.likelihoods.to_dict()
389+
390+
if self_copy._rng is not None: # NumPy RandomState cannot be serialized.
391+
self_copy._rng = self_copy._rng.get_state()
392+
389393
return json_tricks.dumps(self_copy, allow_nan=True)
390394

391395
@staticmethod
@@ -412,6 +416,12 @@ def from_json(data: str):
412416
loaded.prior = xr.DataArray.from_dict(loaded.prior)
413417
loaded.posterior = xr.DataArray.from_dict(loaded.posterior)
414418
loaded.likelihoods = xr.DataArray.from_dict(loaded.likelihoods)
419+
420+
if loaded._rng is not None:
421+
state = deepcopy(loaded._rng)
422+
loaded._rng = np.random.RandomState()
423+
loaded._rng.set_state(state)
424+
415425
return loaded
416426

417427
def __eq__(self, other):

questplus/tests/test_qp.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,36 @@ def test_json():
494494
q_loaded.update(stim=q_loaded.next_stim, outcome=dict(response='Correct'))
495495

496496

497+
def test_json_rng():
498+
threshold = np.arange(-40, 0 + 1)
499+
slope, guess, lapse = 3.5, 0.5, 0.02
500+
contrasts = threshold.copy()
501+
502+
stim_domain = dict(intensity=contrasts)
503+
param_domain = dict(threshold=threshold, slope=slope,
504+
lower_asymptote=guess, lapse_rate=lapse)
505+
outcome_domain = dict(response=['Correct', 'Incorrect'])
506+
f = 'weibull'
507+
scale = 'dB'
508+
stim_selection_method = 'min_n_entropy'
509+
param_estimation_method = 'mode'
510+
random_seed = 5
511+
stim_selection_options = dict(n=3, random_seed=random_seed)
512+
513+
q = QuestPlus(stim_domain=stim_domain, param_domain=param_domain,
514+
outcome_domain=outcome_domain, func=f, stim_scale=scale,
515+
stim_selection_method=stim_selection_method,
516+
param_estimation_method=param_estimation_method,
517+
stim_selection_options=stim_selection_options)
518+
519+
q2 = QuestPlus.from_json(q.to_json())
520+
521+
rand = q._rng.random_sample(10)
522+
rand2 = q2._rng.random_sample(10)
523+
524+
assert np.allclose(rand, rand2)
525+
526+
497527
def test_marginal_posterior():
498528
contrasts = np.arange(-40, 0 + 1)
499529
slope = np.arange(2, 5 + 1)
@@ -688,6 +718,7 @@ def test_weibull_prior():
688718
test_weibull()
689719
test_eq()
690720
test_json()
721+
test_json_rng()
691722
test_marginal_posterior()
692723
test_prior_for_unknown_parameter()
693724
test_prior_for_parameter_subset()

0 commit comments

Comments
 (0)