Skip to content

Commit 1ef53b3

Browse files
Merge pull request #38 from hoechenberger/rng
ENH: Allow passing a random seed via stim_selection_options
2 parents 146e721 + 15e63bf commit 1ef53b3

File tree

6 files changed

+105
-9
lines changed

6 files changed

+105
-9
lines changed

.appveyor.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ install:
2121
- python setup.py build
2222

2323
# Build & install sdist.
24-
- python setup.py sdist
24+
- python setup.py sdist --formats=zip
2525
# - pip install --no-deps dist/questplus-*.zip
2626
# - pip uninstall --yes questplus
2727

@@ -30,7 +30,8 @@ install:
3030
# - pip install --no-deps dist/questplus-*.whl
3131
# - pip uninstall --yes questplus
3232

33-
- pip install .
33+
- ps: Remove-Item –path dist, build –recurse
34+
- pip install --no-deps .
3435

3536
test_script:
3637
- py.test

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ install:
5050

5151
- rm -rf dist/ build/
5252

53-
- pip install .
53+
- pip install --no-deps .
5454

5555
script:
5656
- py.test

CHANGES.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
v2019.2
2+
-------
3+
* Allow passing a random seed via `stim_selection_options` keyword
4+
argument
5+
* Better handling of `stim_selection_options` defaults (now allows
6+
to supply only a subset of options)
7+
18
v2019.1
29
-------
310
* Allow to pass priors for only some parameters

questplus/_constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
DEFAULT_N = 4
2+
DEFAULT_MAX_CONSECUTIVE_REPS = 2
3+
DEFAULT_RANDOM_SEED = None

questplus/qp.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,9 @@ def __init__(self, *,
6666
method specified via `stim_selection_method`. Currently, this can
6767
be used to specify the number of `n` stimuli that will yield the
6868
`n` smallest entropies if `stim_selection_method=min_n_entropy`,
69-
and`max_consecutive_reps`, the number of times the same stimulus
70-
can be presented consecutively.
69+
and `max_consecutive_reps`, the number of times the same stimulus
70+
can be presented consecutively. A random number generator seed
71+
may be passed via `random_seed=12345`.
7172
7273
param_estimation_method
7374
The method to use when deriving the final parameter estimate.
@@ -88,11 +89,33 @@ def __init__(self, *,
8889

8990
self.stim_selection = stim_selection_method
9091

91-
if (self.stim_selection == 'min_n_entropy' and
92-
stim_selection_options is None):
93-
self.stim_selection_options = dict(n=4, max_consecutive_reps=2)
92+
if self.stim_selection == 'min_n_entropy':
93+
from ._constants import (DEFAULT_N, DEFAULT_RANDOM_SEED,
94+
DEFAULT_MAX_CONSECUTIVE_REPS)
95+
96+
if stim_selection_options is None:
97+
self.stim_selection_options = dict(
98+
n=DEFAULT_N,
99+
max_consecutive_reps=DEFAULT_MAX_CONSECUTIVE_REPS,
100+
random_seed=DEFAULT_RANDOM_SEED)
101+
else:
102+
self.stim_selection_options = stim_selection_options.copy()
103+
104+
if 'n' not in stim_selection_options:
105+
self.stim_selection_options['n'] = DEFAULT_N
106+
if 'max_consecutive_reps' not in stim_selection_options:
107+
self.stim_selection_options['max_consecutive_reps'] = DEFAULT_MAX_CONSECUTIVE_REPS
108+
if 'random_seed' not in stim_selection_options:
109+
self.stim_selection_options['random_seed'] = DEFAULT_RANDOM_SEED
110+
111+
del DEFAULT_N, DEFAULT_MAX_CONSECUTIVE_REPS, DEFAULT_RANDOM_SEED
112+
113+
seed = self.stim_selection_options['random_seed']
114+
self._rng = np.random.RandomState(seed=seed)
115+
del seed
94116
else:
95117
self.stim_selection_options = stim_selection_options
118+
self._rng = None
96119

97120
self.param_estimation_method = param_estimation_method
98121

@@ -271,7 +294,7 @@ def next_stim(self) -> dict:
271294
while True:
272295
# Randomly pick one index and retrieve its coordinates
273296
# (stimulus parameters).
274-
candidate_index = np.random.choice(indices)
297+
candidate_index = self._rng.choice(indices)
275298
coords = EH[candidate_index].coords
276299
stim = {stim_property: stim_val.item()
277300
for stim_property, stim_val in coords.items()}

questplus/tests/test_qp.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import scipy.stats
33
import numpy as np
44
from questplus.qp import QuestPlus, QuestPlusWeibull
5+
from questplus import _constants
56

67

78
def test_threshold():
@@ -597,6 +598,66 @@ def test_prior_for_parameter_subset():
597598
'lower_asymptote']).sum())
598599

599600

601+
def test_stim_selection_options():
602+
threshold = np.arange(-40, 0 + 1)
603+
slope, guess, lapse = 3.5, 0.5, 0.02
604+
contrasts = threshold.copy()
605+
606+
stim_domain = dict(intensity=contrasts)
607+
param_domain = dict(threshold=threshold, slope=slope,
608+
lower_asymptote=guess, lapse_rate=lapse)
609+
outcome_domain = dict(response=['Correct', 'Incorrect'])
610+
611+
f = 'weibull'
612+
scale = 'dB'
613+
stim_selection_method = 'min_n_entropy'
614+
param_estimation_method = 'mode'
615+
616+
common_params = dict(stim_domain=stim_domain, param_domain=param_domain,
617+
outcome_domain=outcome_domain, func=f,
618+
stim_scale=scale,
619+
stim_selection_method=stim_selection_method,
620+
param_estimation_method=param_estimation_method)
621+
622+
stim_selection_options = None
623+
q = QuestPlus(**common_params,
624+
stim_selection_options=stim_selection_options)
625+
expected = dict(n=_constants.DEFAULT_N,
626+
max_consecutive_reps=_constants.DEFAULT_MAX_CONSECUTIVE_REPS,
627+
random_seed=_constants.DEFAULT_RANDOM_SEED)
628+
assert expected == q.stim_selection_options
629+
630+
stim_selection_options = dict(n=5)
631+
q = QuestPlus(**common_params,
632+
stim_selection_options=stim_selection_options)
633+
expected = dict(n=5,
634+
max_consecutive_reps=_constants.DEFAULT_MAX_CONSECUTIVE_REPS,
635+
random_seed=_constants.DEFAULT_RANDOM_SEED)
636+
assert expected == q.stim_selection_options
637+
638+
stim_selection_options = dict(max_consecutive_reps=4)
639+
q = QuestPlus(**common_params,
640+
stim_selection_options=stim_selection_options)
641+
expected = dict(n=_constants.DEFAULT_N,
642+
max_consecutive_reps=4,
643+
random_seed=_constants.DEFAULT_RANDOM_SEED)
644+
assert expected == q.stim_selection_options
645+
646+
stim_selection_options = dict(random_seed=999)
647+
q = QuestPlus(**common_params,
648+
stim_selection_options=stim_selection_options)
649+
expected = dict(n=_constants.DEFAULT_N,
650+
max_consecutive_reps=_constants.DEFAULT_MAX_CONSECUTIVE_REPS,
651+
random_seed=999)
652+
assert expected == q.stim_selection_options
653+
654+
stim_selection_options = dict(n=5, max_consecutive_reps=4, random_seed=999)
655+
q = QuestPlus(**common_params,
656+
stim_selection_options=stim_selection_options)
657+
expected = stim_selection_options.copy()
658+
assert expected == q.stim_selection_options
659+
660+
600661
if __name__ == '__main__':
601662
test_threshold()
602663
test_threshold_slope()
@@ -609,3 +670,4 @@ def test_prior_for_parameter_subset():
609670
test_marginal_posterior()
610671
test_prior_for_unknown_parameter()
611672
test_prior_for_parameter_subset()
673+
test_stim_selection_options()

0 commit comments

Comments
 (0)