@@ -56,8 +56,13 @@ def plot_psychometric_function(result: Result, # noqa: C901, this function is t
5656 x = np .linspace (x_data .min (), x_data .max (), num = 1000 )
5757 x_low = np .linspace (x [0 ] - extrapolate_stimulus * (x [- 1 ] - x [0 ]), x [0 ], num = 100 )
5858 x_high = np .linspace (x [- 1 ], x [- 1 ] + extrapolate_stimulus * (x [- 1 ] - x [0 ]), num = 100 )
59- y = sigmoid (np .r_ [x_low , x , x_high ], params ['threshold' ], params ['width' ])
60- y = (1 - params ['gamma' ] - params ['lambda' ]) * y + params ['gamma' ]
59+ y = sigmoid (
60+ np .r_ [x_low , x , x_high ],
61+ threshold = params ['threshold' ],
62+ width = params ['width' ],
63+ gamma = params ['gamma' ],
64+ lambd = params ['lambda' ],
65+ )
6166 ax .plot (x , y [len (x_low ):- len (x_high )], c = line_color , lw = line_width , clip_on = False )
6267 ax .plot (x_low , y [:len (x_low )], '--' , c = line_color , lw = line_width , clip_on = False )
6368 ax .plot (x_high , y [- len (x_high ):], '--' , c = line_color , lw = line_width , clip_on = False )
@@ -110,8 +115,13 @@ def _plot_residuals(x_values: np.ndarray,
110115 data = result .data
111116 sigmoid = result .configuration .make_sigmoid ()
112117
113- std_model = params ['gamma' ] + (1 - params ['lambda' ] - params ['gamma' ]) * sigmoid (
114- data [:, 0 ], params ['threshold' ], params ['width' ])
118+ std_model = sigmoid (
119+ data [:, 0 ],
120+ threshold = params ['threshold' ],
121+ width = params ['width' ],
122+ gamma = params ['gamma' ],
123+ lambd = params ['lambda' ],
124+ )
115125 deviance = data [:, 1 ] / data [:, 2 ] - std_model
116126 std_model = np .sqrt (std_model * (1 - std_model ))
117127 deviance = deviance / std_model
@@ -329,8 +339,9 @@ def plot_prior(result: Result,
329339 prior_cdf = np .cumsum (prior_val * prior_w )
330340 q25_index = np .argmax (prior_cdf > 0.25 )
331341 q75_index = np .argmax (prior_cdf > 0.75 )
342+ prior_mean = np .sum (prior_x * prior_val )/ np .sum (prior_val )
332343
333- x_percentiles = [estimate [ param ] ,
344+ x_percentiles = [prior_mean ,
334345 min (prior_x ),
335346 prior_x [q25_index ],
336347 prior_x [q75_index ],
@@ -348,8 +359,13 @@ def plot_prior(result: Result,
348359 for param_value , color in zip (x_percentiles , colors ):
349360 this_sigmoid_params = dict (sigmoid_params )
350361 this_sigmoid_params [param ] = param_value
351- y = sigmoid (sigmoid_x , this_sigmoid_params ['threshold' ], this_sigmoid_params ['width' ])
352- y = (1 - estimate ['gamma' ] - this_sigmoid_params ['lambda' ]) * y + estimate ['gamma' ]
362+ y = sigmoid (
363+ sigmoid_x ,
364+ threshold = this_sigmoid_params ['threshold' ],
365+ width = this_sigmoid_params ['width' ],
366+ gamma = estimate ['gamma' ],
367+ lambd = this_sigmoid_params ['lambda' ],
368+ )
353369 plt .plot (sigmoid_x , y , linewidth = line_width , color = color )
354370
355371 plt .scatter (data [:, 0 ], np .zeros (data [:, 0 ].shape ), s = marker_size * .75 , c = 'k' , clip_on = False )
@@ -376,13 +392,41 @@ def plot_2D_margin(result: Result,
376392 other_param_ix = tuple (i for param , i in parameter_indices .items ()
377393 if param != first_param and param != second_param )
378394 marginal_2d = np .sum (result .debug ['posteriors' ], axis = other_param_ix )
395+ extent = [result .parameter_values [second_param ][0 ], result .parameter_values [second_param ][- 1 ],
396+ result .parameter_values [first_param ][- 1 ], result .parameter_values [first_param ][0 ]]
397+
379398 if len (np .squeeze (marginal_2d ).shape ) != 2 or np .any (np .array (marginal_2d .shape ) == 1 ):
380- raise ValueError ('The marginal is not two-dimensional. Were the parameters fixed during optimization? If so, then change the arguments to parametes that were unfixed, or use plot_marginal() to obtain a 1D marginal for a parameter.' )
399+ len_first = len (result .parameter_values [first_param ])
400+ len_second = len (result .parameter_values [second_param ])
401+
402+ # if first_param is singleton, we copy the marginal into a matrix
403+ if len_first == 1 and len_second != 1 :
404+ marginal_2d = np .broadcast_to (marginal_2d ,
405+ (len (result .parameter_values [second_param ]),
406+ len (result .parameter_values [second_param ]))
407+ )
408+ extent [2 ] = 1 # replace range for a mockup range between 0 and 1
409+ extent [3 ] = 0
410+
411+ # if second param is singleton
412+ elif len_first != 1 and len_second == 1 :
413+ marginal_2d = np .broadcast_to (marginal_2d ,
414+ (len (result .parameter_values [first_param ]),
415+ len (result .parameter_values [first_param ]))
416+ )
417+ extent [0 ] = 0
418+ extent [1 ] = 1
419+
420+ # if both params are singletons, we return a matrix full of ones
421+ elif len_first == 1 and len_second == 1 :
422+ marginal_2d = np .ones ((len (result .parameter_values [first_param ]),
423+ len (result .parameter_values [second_param ]))
424+ )
425+ extent = [0 , 1 , 1 , 0 ]
381426
382427 if parameter_indices [first_param ] > parameter_indices [second_param ]:
383428 marginal_2d = np .transpose (marginal_2d )
384- extent = [result .parameter_values [second_param ][0 ], result .parameter_values [second_param ][- 1 ],
385- result .parameter_values [first_param ][- 1 ], result .parameter_values [first_param ][0 ]]
429+
386430 ax .imshow (marginal_2d , extent = extent , cmap = 'Reds_r' , aspect = 'auto' )
387431 ax .set_xlabel (_parameter_label (second_param ))
388432 ax .set_ylabel (_parameter_label (first_param ))
@@ -428,7 +472,7 @@ def plot_bias_analysis(data: np.ndarray, compare_data: np.ndarray,
428472
429473 fig = plt .figure (constrained_layout = True , figsize = (5 , 15 ))
430474 gs = fig .add_gridspec (6 , 1 )
431-
475+
432476 ax1 = fig .add_subplot (gs [0 :2 , 0 ])
433477 plot_psychometric_function (result_combined , ax = ax1 , estimate_type = estimate_type )
434478 plot_psychometric_function (result_data , ax = ax1 , line_color = [1 , 0 , 0 ], data_color = [1 , 0 , 0 ],
@@ -440,25 +484,25 @@ def plot_bias_analysis(data: np.ndarray, compare_data: np.ndarray,
440484 ax3 = fig .add_subplot (gs [3 , 0 ])
441485 ax4 = fig .add_subplot (gs [4 , 0 ])
442486 ax5 = fig .add_subplot (gs [5 , 0 ])
443-
487+
444488 axesmarginals = [ax2 , ax3 , ax4 , ax5 ]
445-
489+
446490 for param , ax in zip (['threshold' , 'width' , 'lambda' , 'gamma' ], axesmarginals ):
447491
448- plot_marginal (result_combined , param , ax = ax , plot_prior = False ,
492+ plot_marginal (result_combined , param , ax = ax , plot_prior = False ,
449493 line_color = [0 , 0 , 0 ], estimate_type = estimate_type ,
450494 plot_ci = False )
451-
495+
452496 plot_marginal (result_data , param , ax = ax , plot_prior = False ,
453- line_color = [1 , 0 , 0 ], estimate_type = estimate_type ,
497+ line_color = [1 , 0 , 0 ], estimate_type = estimate_type ,
454498 plot_ci = False )
455-
456-
499+
500+
457501 plot_marginal (result_compare_data , param , ax = ax , plot_prior = False ,
458502 line_color = [0 , 0 , 1 ], estimate_type = estimate_type ,
459503 plot_ci = False )
460-
504+
461505 for ax in axesmarginals :
462506 ax .autoscale ()
463-
507+
464508
0 commit comments