|
1 | 1 | import numpy as np |
2 | | -from questplus.qp import QuestPlus |
| 2 | +from questplus.qp import QuestPlus, QuestPlusWeibull |
3 | 3 |
|
4 | 4 |
|
5 | 5 | def test_threshold(): |
@@ -370,9 +370,65 @@ def test_spatial_contrast_sensitivity(): |
370 | 370 | expected_mode_cf) |
371 | 371 |
|
372 | 372 |
|
| 373 | +def test_weibull(): |
| 374 | + threshold = np.arange(-40, 0 + 1) |
| 375 | + slope, guess, lapse = 3.5, 0.5, 0.02 |
| 376 | + contrasts = threshold.copy() |
| 377 | + |
| 378 | + expected_contrasts = [-18, -22, -25, -28, -30, -22, -13, -15, -16, -18, |
| 379 | + -19, -20, -21, -22, -23, -19, -20, -20, -18, -18, |
| 380 | + -19, -17, -17, -18, -18, -18, -19, -19, -19, -19, |
| 381 | + -19, -19] |
| 382 | + |
| 383 | + responses = ['Correct', 'Correct', 'Correct', 'Correct', 'Incorrect', |
| 384 | + 'Incorrect', 'Correct', 'Correct', 'Correct', 'Correct', |
| 385 | + 'Correct', 'Correct', 'Correct', 'Correct', 'Incorrect', |
| 386 | + 'Correct', 'Correct', 'Incorrect', 'Correct', 'Correct', |
| 387 | + 'Incorrect', 'Correct', 'Correct', 'Correct', 'Correct', |
| 388 | + 'Correct', 'Correct', 'Correct', 'Correct', 'Correct', |
| 389 | + 'Correct', 'Correct'] |
| 390 | + |
| 391 | + expected_mode_threshold = -20 |
| 392 | + |
| 393 | + stim_domain = dict(intensity=contrasts) |
| 394 | + param_domain = dict(threshold=threshold, slope=slope, |
| 395 | + lower_asymptote=guess, lapse_rate=lapse) |
| 396 | + outcome_domain = dict(response=['Correct', 'Incorrect']) |
| 397 | + |
| 398 | + f = 'weibull' |
| 399 | + scale = 'dB' |
| 400 | + stim_selection_method = 'min_entropy' |
| 401 | + param_estimation_method = 'mode' |
| 402 | + |
| 403 | + q = QuestPlus(stim_domain=stim_domain, param_domain=param_domain, |
| 404 | + outcome_domain=outcome_domain, func=f, stim_scale=scale, |
| 405 | + stim_selection_method=stim_selection_method, |
| 406 | + param_estimation_method=param_estimation_method) |
| 407 | + |
| 408 | + q_weibull = QuestPlusWeibull(intensities=stim_domain['intensity'], |
| 409 | + thresholds=param_domain['threshold'], |
| 410 | + slopes=param_domain['slope'], |
| 411 | + lower_asymptotes=param_domain['lower_asymptote'], |
| 412 | + lapse_rates=param_domain['lapse_rate'], |
| 413 | + responses=outcome_domain['response'], |
| 414 | + stim_scale=scale) |
| 415 | + |
| 416 | + for expected_contrast, response in zip(expected_contrasts, responses): |
| 417 | + assert q.next_stim['intensity'] == q_weibull.next_intensity |
| 418 | + assert q_weibull.next_intensity == expected_contrast |
| 419 | + q.update(stim=q.next_stim, |
| 420 | + outcome=dict(response=response)) |
| 421 | + q_weibull.update(intensity=q_weibull.next_intensity, |
| 422 | + response=response) |
| 423 | + |
| 424 | + assert np.allclose(q.param_estimate['threshold'], |
| 425 | + expected_mode_threshold) |
| 426 | + |
| 427 | + |
373 | 428 | if __name__ == '__main__': |
374 | | - # test_threshold() |
375 | | - # test_threshold_slope() |
376 | | - # test_threshold_slope_lapse() |
| 429 | + test_threshold() |
| 430 | + test_threshold_slope() |
| 431 | + test_threshold_slope_lapse() |
377 | 432 | test_mean_sd_lapse() |
378 | 433 | test_spatial_contrast_sensitivity() |
| 434 | + test_weibull() |
0 commit comments