Skip to content

Commit ab94da1

Browse files
Merge pull request #17 from hoechenberger/json
NF: Add JSON dump and load, and support "equals" operator
2 parents 14527e6 + 4ee190b commit ab94da1

File tree

3 files changed

+133
-0
lines changed

3 files changed

+133
-0
lines changed

questplus/qp.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Optional, Sequence
22
import xarray as xr
33
import numpy as np
4+
import json_tricks
45
from copy import deepcopy
56

67
from questplus import psychometric_function
@@ -294,6 +295,69 @@ def param_estimate(self) -> dict:
294295

295296
return param_estimates
296297

298+
def to_json(self) -> str:
299+
self_copy = deepcopy(self)
300+
self_copy.prior = self_copy.prior.to_dict()
301+
self_copy.posterior = self_copy.posterior.to_dict()
302+
self_copy.likelihoods = self_copy.likelihoods.to_dict()
303+
return json_tricks.dumps(self_copy)
304+
305+
@staticmethod
306+
def from_json(data: str):
307+
loaded = json_tricks.loads(data)
308+
loaded.prior = xr.DataArray.from_dict(loaded.prior)
309+
loaded.posterior = xr.DataArray.from_dict(loaded.posterior)
310+
loaded.likelihoods = xr.DataArray.from_dict(loaded.likelihoods)
311+
return loaded
312+
313+
def __eq__(self, other):
314+
if not self.likelihoods.equals(other.likelihoods):
315+
return False
316+
317+
if not self.prior.equals(other.prior):
318+
return False
319+
320+
if not self.posterior.equals(other.posterior):
321+
return False
322+
323+
for param_name in self.param_domain.keys():
324+
if not np.array_equal(self.param_domain[param_name],
325+
other.param_domain[param_name]):
326+
return False
327+
328+
for stim_property in self.stim_domain.keys():
329+
if not np.array_equal(self.stim_domain[stim_property],
330+
other.stim_domain[stim_property]):
331+
return False
332+
333+
for outcome_name in self.outcome_domain.keys():
334+
if not np.array_equal(self.outcome_domain[outcome_name],
335+
other.outcome_domain[outcome_name]):
336+
return False
337+
338+
if self.stim_selection != other.stim_selection:
339+
return False
340+
341+
if self.stim_selection_options != other.stim_selection_options:
342+
return False
343+
344+
if self.stim_scale != other.stim_scale:
345+
return False
346+
347+
if self.stim_history != other.stim_history:
348+
return False
349+
350+
if self.resp_history != other.resp_history:
351+
return False
352+
353+
if self.param_estimation_method != other.param_estimation_method:
354+
return False
355+
356+
if self.func != other.func:
357+
return False
358+
359+
return True
360+
297361

298362
class QuestPlusWeibull(QuestPlus):
299363
def __init__(self, *,

questplus/tests/test_qp.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,10 +425,78 @@ def test_weibull():
425425
expected_mode_threshold)
426426

427427

428+
def test_eq():
429+
threshold = np.arange(-40, 0 + 1)
430+
slope, guess, lapse = 3.5, 0.5, 0.02
431+
contrasts = threshold.copy()
432+
433+
stim_domain = dict(intensity=contrasts)
434+
param_domain = dict(threshold=threshold, slope=slope,
435+
lower_asymptote=guess, lapse_rate=lapse)
436+
outcome_domain = dict(response=['Correct', 'Incorrect'])
437+
438+
f = 'weibull'
439+
scale = 'dB'
440+
stim_selection_method = 'min_entropy'
441+
param_estimation_method = 'mode'
442+
443+
q1 = QuestPlus(stim_domain=stim_domain, param_domain=param_domain,
444+
outcome_domain=outcome_domain, func=f, stim_scale=scale,
445+
stim_selection_method=stim_selection_method,
446+
param_estimation_method=param_estimation_method)
447+
448+
q2 = QuestPlus(stim_domain=stim_domain, param_domain=param_domain,
449+
outcome_domain=outcome_domain, func=f, stim_scale=scale,
450+
stim_selection_method=stim_selection_method,
451+
param_estimation_method=param_estimation_method)
452+
453+
# Add some random responses.
454+
q1.update(stim=q1.next_stim, outcome=dict(response='Correct'))
455+
q1.update(stim=q1.next_stim, outcome=dict(response='Incorrect'))
456+
q2.update(stim=q2.next_stim, outcome=dict(response='Correct'))
457+
q2.update(stim=q2.next_stim, outcome=dict(response='Incorrect'))
458+
459+
assert q1 == q2
460+
461+
462+
def test_json():
463+
threshold = np.arange(-40, 0 + 1)
464+
slope, guess, lapse = 3.5, 0.5, 0.02
465+
contrasts = threshold.copy()
466+
467+
stim_domain = dict(intensity=contrasts)
468+
param_domain = dict(threshold=threshold, slope=slope,
469+
lower_asymptote=guess, lapse_rate=lapse)
470+
outcome_domain = dict(response=['Correct', 'Incorrect'])
471+
472+
f = 'weibull'
473+
scale = 'dB'
474+
stim_selection_method = 'min_entropy'
475+
param_estimation_method = 'mode'
476+
477+
q = QuestPlus(stim_domain=stim_domain, param_domain=param_domain,
478+
outcome_domain=outcome_domain, func=f, stim_scale=scale,
479+
stim_selection_method=stim_selection_method,
480+
param_estimation_method=param_estimation_method)
481+
482+
# Add some random responses.
483+
q.update(stim=q.next_stim, outcome=dict(response='Correct'))
484+
q.update(stim=q.next_stim, outcome=dict(response='Incorrect'))
485+
486+
q_dumped = q.to_json()
487+
q_loaded = QuestPlus.from_json(q_dumped)
488+
489+
assert q_loaded == q
490+
491+
q_loaded.update(stim=q_loaded.next_stim, outcome=dict(response='Correct'))
492+
493+
428494
if __name__ == '__main__':
429495
test_threshold()
430496
test_threshold_slope()
431497
test_threshold_slope_lapse()
432498
test_mean_sd_lapse()
433499
test_spatial_contrast_sensitivity()
434500
test_weibull()
501+
test_eq()
502+
test_json()

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ install_requires =
2424
numpy
2525
scipy
2626
xarray
27+
json_tricks
2728

2829
[bdist_wheel]
2930
universal = 1

0 commit comments

Comments
 (0)