22import scipy .stats
33import numpy as np
44from questplus .qp import QuestPlus , QuestPlusWeibull
5+ from questplus import _constants
56
67
78def test_threshold ():
@@ -597,6 +598,66 @@ def test_prior_for_parameter_subset():
597598 'lower_asymptote' ]).sum ())
598599
599600
601+ def test_stim_selection_options ():
602+ threshold = np .arange (- 40 , 0 + 1 )
603+ slope , guess , lapse = 3.5 , 0.5 , 0.02
604+ contrasts = threshold .copy ()
605+
606+ stim_domain = dict (intensity = contrasts )
607+ param_domain = dict (threshold = threshold , slope = slope ,
608+ lower_asymptote = guess , lapse_rate = lapse )
609+ outcome_domain = dict (response = ['Correct' , 'Incorrect' ])
610+
611+ f = 'weibull'
612+ scale = 'dB'
613+ stim_selection_method = 'min_n_entropy'
614+ param_estimation_method = 'mode'
615+
616+ common_params = dict (stim_domain = stim_domain , param_domain = param_domain ,
617+ outcome_domain = outcome_domain , func = f ,
618+ stim_scale = scale ,
619+ stim_selection_method = stim_selection_method ,
620+ param_estimation_method = param_estimation_method )
621+
622+ stim_selection_options = None
623+ q = QuestPlus (** common_params ,
624+ stim_selection_options = stim_selection_options )
625+ expected = dict (n = _constants .DEFAULT_N ,
626+ max_consecutive_reps = _constants .DEFAULT_MAX_CONSECUTIVE_REPS ,
627+ random_seed = _constants .DEFAULT_RANDOM_SEED )
628+ assert expected == q .stim_selection_options
629+
630+ stim_selection_options = dict (n = 5 )
631+ q = QuestPlus (** common_params ,
632+ stim_selection_options = stim_selection_options )
633+ expected = dict (n = 5 ,
634+ max_consecutive_reps = _constants .DEFAULT_MAX_CONSECUTIVE_REPS ,
635+ random_seed = _constants .DEFAULT_RANDOM_SEED )
636+ assert expected == q .stim_selection_options
637+
638+ stim_selection_options = dict (max_consecutive_reps = 4 )
639+ q = QuestPlus (** common_params ,
640+ stim_selection_options = stim_selection_options )
641+ expected = dict (n = _constants .DEFAULT_N ,
642+ max_consecutive_reps = 4 ,
643+ random_seed = _constants .DEFAULT_RANDOM_SEED )
644+ assert expected == q .stim_selection_options
645+
646+ stim_selection_options = dict (random_seed = 999 )
647+ q = QuestPlus (** common_params ,
648+ stim_selection_options = stim_selection_options )
649+ expected = dict (n = _constants .DEFAULT_N ,
650+ max_consecutive_reps = _constants .DEFAULT_MAX_CONSECUTIVE_REPS ,
651+ random_seed = 999 )
652+ assert expected == q .stim_selection_options
653+
654+ stim_selection_options = dict (n = 5 , max_consecutive_reps = 4 , random_seed = 999 )
655+ q = QuestPlus (** common_params ,
656+ stim_selection_options = stim_selection_options )
657+ expected = stim_selection_options .copy ()
658+ assert expected == q .stim_selection_options
659+
660+
600661if __name__ == '__main__' :
601662 test_threshold ()
602663 test_threshold_slope ()
@@ -609,3 +670,4 @@ def test_prior_for_parameter_subset():
609670 test_marginal_posterior ()
610671 test_prior_for_unknown_parameter ()
611672 test_prior_for_parameter_subset ()
673+ test_stim_selection_options ()
0 commit comments