Skip to content

Commit 14527e6

Browse files
Merge pull request #16 from hoechenberger/subclasses
NF: Add QuestPlusWeibull convenience class
2 parents 73f20db + ce6486f commit 14527e6

File tree

3 files changed

+121
-6
lines changed

3 files changed

+121
-6
lines changed

questplus/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .qp import QuestPlus
1+
from .qp import QuestPlus, QuestPlusWeibull
22

33
from ._version import get_versions
44
__version__ = get_versions()['version']

questplus/qp.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Optional, Sequence
22
import xarray as xr
33
import numpy as np
44
from copy import deepcopy
@@ -293,3 +293,62 @@ def param_estimate(self) -> dict:
293293
raise ValueError('Unknown method parameter.')
294294

295295
return param_estimates
296+
297+
298+
class QuestPlusWeibull(QuestPlus):
299+
def __init__(self, *,
300+
intensities: Sequence,
301+
thresholds: Sequence,
302+
slopes: Sequence,
303+
lower_asymptotes: Sequence,
304+
lapse_rates: Sequence,
305+
responses: Sequence = ('Yes', 'No'),
306+
stim_scale: str = 'log10',
307+
stim_selection_method: str = 'min_entropy',
308+
stim_selection_options: Optional[dict] = None,
309+
param_estimation_method: str = 'mean'):
310+
super().__init__(stim_domain=dict(intensity=intensities),
311+
param_domain=dict(threshold=thresholds,
312+
slope=slopes,
313+
lower_asymptote=lower_asymptotes,
314+
lapse_rate=lapse_rates),
315+
outcome_domain=dict(response=responses),
316+
stim_scale=stim_scale,
317+
stim_selection_method=stim_selection_method,
318+
stim_selection_options=stim_selection_options,
319+
param_estimation_method=param_estimation_method,
320+
func='weibull')
321+
322+
@property
323+
def intensities(self):
324+
return self.stim_domain['intensity']
325+
326+
@property
327+
def thresholds(self):
328+
return self.param_domain['threshold']
329+
330+
@property
331+
def slopes(self):
332+
return self.param_domain['slope']
333+
334+
@property
335+
def lower_asymptotes(self):
336+
return self.param_domain['lower_asymptote']
337+
338+
@property
339+
def lapse_rates(self):
340+
return self.param_domain['lapse_rate']
341+
342+
@property
343+
def responses(self):
344+
return self.outcome_domain['response']
345+
346+
@property
347+
def next_intensity(self):
348+
return super().next_stim['intensity']
349+
350+
def update(self, *,
351+
intensity: float,
352+
response: str) -> None:
353+
super().update(stim=dict(intensity=intensity),
354+
outcome=dict(response=response))

questplus/tests/test_qp.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from questplus.qp import QuestPlus
2+
from questplus.qp import QuestPlus, QuestPlusWeibull
33

44

55
def test_threshold():
@@ -370,9 +370,65 @@ def test_spatial_contrast_sensitivity():
370370
expected_mode_cf)
371371

372372

373+
def test_weibull():
374+
threshold = np.arange(-40, 0 + 1)
375+
slope, guess, lapse = 3.5, 0.5, 0.02
376+
contrasts = threshold.copy()
377+
378+
expected_contrasts = [-18, -22, -25, -28, -30, -22, -13, -15, -16, -18,
379+
-19, -20, -21, -22, -23, -19, -20, -20, -18, -18,
380+
-19, -17, -17, -18, -18, -18, -19, -19, -19, -19,
381+
-19, -19]
382+
383+
responses = ['Correct', 'Correct', 'Correct', 'Correct', 'Incorrect',
384+
'Incorrect', 'Correct', 'Correct', 'Correct', 'Correct',
385+
'Correct', 'Correct', 'Correct', 'Correct', 'Incorrect',
386+
'Correct', 'Correct', 'Incorrect', 'Correct', 'Correct',
387+
'Incorrect', 'Correct', 'Correct', 'Correct', 'Correct',
388+
'Correct', 'Correct', 'Correct', 'Correct', 'Correct',
389+
'Correct', 'Correct']
390+
391+
expected_mode_threshold = -20
392+
393+
stim_domain = dict(intensity=contrasts)
394+
param_domain = dict(threshold=threshold, slope=slope,
395+
lower_asymptote=guess, lapse_rate=lapse)
396+
outcome_domain = dict(response=['Correct', 'Incorrect'])
397+
398+
f = 'weibull'
399+
scale = 'dB'
400+
stim_selection_method = 'min_entropy'
401+
param_estimation_method = 'mode'
402+
403+
q = QuestPlus(stim_domain=stim_domain, param_domain=param_domain,
404+
outcome_domain=outcome_domain, func=f, stim_scale=scale,
405+
stim_selection_method=stim_selection_method,
406+
param_estimation_method=param_estimation_method)
407+
408+
q_weibull = QuestPlusWeibull(intensities=stim_domain['intensity'],
409+
thresholds=param_domain['threshold'],
410+
slopes=param_domain['slope'],
411+
lower_asymptotes=param_domain['lower_asymptote'],
412+
lapse_rates=param_domain['lapse_rate'],
413+
responses=outcome_domain['response'],
414+
stim_scale=scale)
415+
416+
for expected_contrast, response in zip(expected_contrasts, responses):
417+
assert q.next_stim['intensity'] == q_weibull.next_intensity
418+
assert q_weibull.next_intensity == expected_contrast
419+
q.update(stim=q.next_stim,
420+
outcome=dict(response=response))
421+
q_weibull.update(intensity=q_weibull.next_intensity,
422+
response=response)
423+
424+
assert np.allclose(q.param_estimate['threshold'],
425+
expected_mode_threshold)
426+
427+
373428
if __name__ == '__main__':
374-
# test_threshold()
375-
# test_threshold_slope()
376-
# test_threshold_slope_lapse()
429+
test_threshold()
430+
test_threshold_slope()
431+
test_threshold_slope_lapse()
377432
test_mean_sd_lapse()
378433
test_spatial_contrast_sensitivity()
434+
test_weibull()

0 commit comments

Comments
 (0)