Skip to content

Commit 47c1b38

Browse files
authored
fix: ensure vmin_vmax doesn't get passed to user fxns + add vmin_vmax to scatter (#274)
1 parent b1bbdf1 commit 47c1b38

File tree

2 files changed

+36
-24
lines changed

2 files changed

+36
-24
lines changed

mpl_interactions/controller.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -532,11 +532,7 @@ def prep_scalars(kwarg_dict, **kwargs):
532532
kwargs[name] = _gen_f(arg.keys[0])
533533
extra_ctrls.append(arg)
534534

535-
if len(added_kwargs) == 0:
536-
# shortcircuit options
537-
def param_excluder(params, except_=None):
538-
return params
539-
540-
else:
541-
param_excluder = _gen_param_excluder(added_kwargs)
535+
# always exclude all these - this will always be matplotlib
536+
# arugments and so should never be passed to user supplied functions.
537+
param_excluder = _gen_param_excluder(list(kwargs.keys()))
542538
return kwargs, extra_ctrls, param_excluder

mpl_interactions/pyplot.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,7 @@ def interactive_scatter(
419419
c=None,
420420
vmin=None,
421421
vmax=None,
422+
vmin_vmax=None,
422423
alpha=None,
423424
marker=None,
424425
edgecolors=None,
@@ -452,6 +453,9 @@ def interactive_scatter(
452453
or any slider shorthand to control with a slider, or an indexed controls
453454
object to use an existing slider, or an arbitrary function of the other
454455
parameters.
456+
vmin_vmax : tuple of float
457+
Used to generate a range slider for vmin and vmax. Should be given in range slider
458+
notation: `("r", 0, 1)`.
455459
alpha : float or Callable, optional
456460
Affects all scatter points. This will compound with any alpha introduced by
457461
the ``c`` argument
@@ -516,21 +520,38 @@ def interactive_scatter(
516520
facecolors = kwargs.pop("facecolor", facecolors)
517521
edgecolors = kwargs.pop("edgecolor", edgecolors)
518522

519-
kwargs, collection_kwargs = kwarg_popper(kwargs, collection_kwargs_list)
520-
521523
ipympl = notebook_backend() or force_ipywidgets
522524
fig, ax = gogogo_figure(ipympl, ax)
523525
slider_formats = create_slider_format_dict(slider_formats)
524526

525-
extra_ctrls = []
526-
funcs, extra_ctrls, param_excluder = prep_scalars(kwargs, s=s, alpha=alpha, marker=marker)
527+
kwargs, collection_kwargs = kwarg_popper(kwargs, collection_kwargs_list)
528+
funcs, extra_ctrls, param_excluder = prep_scalars(
529+
kwargs, s=s, alpha=alpha, marker=marker, vmin=vmin, vmax=vmax
530+
)
527531
s = funcs["s"]
532+
vmin = funcs["vmin"]
533+
vmax = funcs["vmax"]
528534
alpha = funcs["alpha"]
529535
marker = funcs["marker"]
530536

537+
if vmin_vmax is not None:
538+
if isinstance(vmin_vmax, tuple) and not isinstance(vmin_vmax[0], str):
539+
vmin_vmax = ("r", *vmin_vmax)
540+
kwargs["vmin_vmax"] = vmin_vmax
541+
531542
controls, params = gogogo_controls(
532543
kwargs, controls, display_controls, slider_formats, play_buttons, extra_ctrls
533544
)
545+
if vmin_vmax is not None:
546+
params.pop("vmin_vmax")
547+
params["vmin"] = controls.params["vmin"]
548+
params["vmax"] = controls.params["vmax"]
549+
550+
def vmin(**kwargs):
551+
return kwargs["vmin"]
552+
553+
def vmax(**kwargs):
554+
return kwargs["vmax"]
534555

535556
def update(params, indices, cache):
536557
if parametric:
@@ -545,7 +566,7 @@ def update(params, indices, cache):
545566
s_ = check_callable_xy(s, x_, y_, param_excluder(params, "s"), cache)
546567
ec_ = check_callable_xy(edgecolors, x_, y_, param_excluder(params), cache)
547568
fc_ = check_callable_xy(facecolors, x_, y_, param_excluder(params), cache)
548-
a_ = check_callable_alpha(alpha, param_excluder(params, "alpha"), cache)
569+
a_ = (callable_else_value_no_cast(alpha, param_excluder(params, "alpha"), cache),)
549570
marker_ = callable_else_value_no_cast(marker, param_excluder(params), cache)
550571

551572
if marker_ is not None:
@@ -576,6 +597,10 @@ def update(params, indices, cache):
576597
scatter.set_sizes(s_)
577598
if a_ is not None:
578599
scatter.set_alpha(a_)
600+
if isinstance(vmin, Callable):
601+
scatter.norm.vmin = callable_else_value(vmin, param_excluder(params, "vmin"), cache)
602+
if isinstance(vmax, Callable):
603+
scatter.norm.vmax = callable_else_value(vmax, param_excluder(params, "vmax"), cache)
579604

580605
update_datalim_from_bbox(
581606
ax, scatter.get_datalim(ax.transData), stretch_x=stretch_x, stretch_y=stretch_y
@@ -592,14 +617,6 @@ def check_callable_xy(arg, x, y, params, cache):
592617
else:
593618
return arg
594619

595-
def check_callable_alpha(alpha_, params, cache):
596-
if isinstance(alpha_, Callable):
597-
if alpha_ not in cache:
598-
cache[alpha_] = alpha_(**param_excluder(params, "alpha"))
599-
return cache[alpha_]
600-
else:
601-
return alpha_
602-
603620
p = param_excluder(params)
604621
if parametric:
605622
out = callable_else_value_no_cast(x, p)
@@ -612,17 +629,16 @@ def check_callable_alpha(alpha_, params, cache):
612629
s_ = check_callable_xy(s, x_, y_, param_excluder(params, "s"), {})
613630
ec_ = check_callable_xy(edgecolors, x_, y_, p, {})
614631
fc_ = check_callable_xy(facecolors, x_, y_, p, {})
615-
a_ = check_callable_alpha(alpha, params, {})
616632
marker_ = callable_else_value_no_cast(marker, p, {})
617633
scatter = ax.scatter(
618634
x_,
619635
y_,
620636
c=c_,
621637
s=s_,
622-
vmin=vmin,
623-
vmax=vmax,
638+
alpha=callable_else_value_no_cast(alpha, param_excluder(params, "alpha")),
639+
vmin=callable_else_value_no_cast(vmin, param_excluder(params, "vmin")),
640+
vmax=callable_else_value_no_cast(vmax, param_excluder(params, "vmax")),
624641
marker=marker_,
625-
alpha=a_,
626642
edgecolors=ec_,
627643
facecolors=fc_,
628644
label=label,

0 commit comments

Comments
 (0)