Skip to content

Commit bcf0ace

Browse files
committed
ENH: Allow to pass a prior when instantiating QuestPlusWeibull
1 parent 1ef53b3 commit bcf0ace

File tree

3 files changed

+28
-0
lines changed

3 files changed

+28
-0
lines changed

CHANGES.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
v2019.3
2+
-------
3+
* Allow to pass a prior when instantiating `QuestPlusWeibull`
4+
15
v2019.2
26
-------
37
* Allow passing a random seed via `stim_selection_options` keyword

questplus/qp.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,7 @@ def __init__(self, *,
470470
slopes: Sequence,
471471
lower_asymptotes: Sequence,
472472
lapse_rates: Sequence,
473+
prior: Optional[dict] = None,
473474
responses: Sequence = ('Yes', 'No'),
474475
stim_scale: str = 'log10',
475476
stim_selection_method: str = 'min_entropy',
@@ -481,6 +482,7 @@ def __init__(self, *,
481482
lower_asymptote=lower_asymptotes,
482483
lapse_rate=lapse_rates),
483484
outcome_domain=dict(response=responses),
485+
prior=prior,
484486
stim_scale=stim_scale,
485487
stim_selection_method=stim_selection_method,
486488
stim_selection_options=stim_selection_options,

questplus/tests/test_qp.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,27 @@ def test_stim_selection_options():
658658
assert expected == q.stim_selection_options
659659

660660

661+
def test_weibull_prior():
662+
intensities = np.linspace(-10, 0)
663+
thresholds = intensities.copy()
664+
slopes = [3.5]
665+
lower_asymptotes = [0.01]
666+
lapse_rates = [0.01]
667+
668+
prior_val = scipy.stats.norm.pdf(intensities, loc=-5, scale=0.2)
669+
prior_val /= prior_val.sum()
670+
prior = dict(threshold=prior_val)
671+
672+
q = QuestPlusWeibull(intensities=intensities,
673+
thresholds=thresholds,
674+
slopes=slopes,
675+
lower_asymptotes=lower_asymptotes,
676+
lapse_rates=lapse_rates,
677+
prior=prior)
678+
679+
assert np.allclose(q.prior.squeeze().values, prior_val)
680+
681+
661682
if __name__ == '__main__':
662683
test_threshold()
663684
test_threshold_slope()
@@ -671,3 +692,4 @@ def test_stim_selection_options():
671692
test_prior_for_unknown_parameter()
672693
test_prior_for_parameter_subset()
673694
test_stim_selection_options()
695+
test_weibull_prior()

0 commit comments

Comments
 (0)