diff --git a/mpl_interactions/_widget_backfill.py b/mpl_interactions/_widget_backfill.py new file mode 100644 index 0000000..5e0e5ff --- /dev/null +++ b/mpl_interactions/_widget_backfill.py @@ -0,0 +1,340 @@ +""" +Implementing matplotlib widgets for back compat +""" +from matplotlib.widgets import AxesWidget +from matplotlib import cbook, ticker +import numpy as np + +# slider widgets are taken almost verbatim from https://github.com/matplotlib/matplotlib/pull/18829/files +# which was written by me - but incorporates much of the existing matplotlib slider infrastructure +class SliderBase(AxesWidget): + def __init__( + self, ax, orientation, closedmin, closedmax, valmin, valmax, valfmt, dragging, valstep + ): + if ax.name == "3d": + raise ValueError("Sliders cannot be added to 3D Axes") + + super().__init__(ax) + + self.orientation = orientation + self.closedmin = closedmin + self.closedmax = closedmax + self.valmin = valmin + self.valmax = valmax + self.valstep = valstep + self.drag_active = False + self.valfmt = valfmt + + if orientation == "vertical": + ax.set_ylim((valmin, valmax)) + axis = ax.yaxis + else: + ax.set_xlim((valmin, valmax)) + axis = ax.xaxis + + self._fmt = axis.get_major_formatter() + if not isinstance(self._fmt, ticker.ScalarFormatter): + self._fmt = ticker.ScalarFormatter() + self._fmt.set_axis(axis) + self._fmt.set_useOffset(False) # No additive offset. + self._fmt.set_useMathText(True) # x sign before multiplicative offset. + + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_navigate(False) + self.connect_event("button_press_event", self._update) + self.connect_event("button_release_event", self._update) + if dragging: + self.connect_event("motion_notify_event", self._update) + self._observers = cbook.CallbackRegistry() + + def _stepped_value(self, val): + if self.valstep: + return self.valmin + round((val - self.valmin) / self.valstep) * self.valstep + return val + + def disconnect(self, cid): + """ + Remove the observer with connection id *cid* + + Parameters + ---------- + cid : int + Connection id of the observer to be removed + """ + self._observers.disconnect(cid) + + def reset(self): + """Reset the slider to the initial value""" + if self.val != self.valinit: + self.set_val(self.valinit) + + +class RangeSlider(SliderBase): + """ + A slider representing a floating point range. + + Create a slider from *valmin* to *valmax* in axes *ax*. For the slider to + remain responsive you must maintain a reference to it. Call + :meth:`on_changed` to connect to the slider event. + + Attributes + ---------- + val : tuple of float + Slider value. + """ + + def __init__( + self, + ax, + label, + valmin, + valmax, + valinit=None, + valfmt=None, + closedmin=True, + closedmax=True, + dragging=True, + valstep=None, + orientation="horizontal", + **kwargs, + ): + """ + Parameters + ---------- + ax : Axes + The Axes to put the slider in. + label : str + Slider label. + valmin : float + The minimum value of the slider. + valmax : float + The maximum value of the slider. + valinit : tuple of float or None, default: None + The initial positions of the slider. If None the initial positions + will be at the 25th and 75th percentiles of the range. + valfmt : str, default: None + %-format string used to format the slider values. If None, a + `.ScalarFormatter` is used instead. + closedmin : bool, default: True + Whether the slider interval is closed on the bottom. + closedmax : bool, default: True + Whether the slider interval is closed on the top. + dragging : bool, default: True + If True the slider can be dragged by the mouse. + valstep : float, default: None + If given, the slider will snap to multiples of *valstep*. + orientation : {'horizontal', 'vertical'}, default: 'horizontal' + The orientation of the slider. + + Notes + ----- + Additional kwargs are passed on to ``self.poly`` which is the + `~matplotlib.patches.Rectangle` that draws the slider knob. See the + `.Rectangle` documentation for valid property names (``facecolor``, + ``edgecolor``, ``alpha``, etc.). + """ + super().__init__( + ax, orientation, closedmin, closedmax, valmin, valmax, valfmt, dragging, valstep + ) + + self.val = valinit + if valinit is None: + valinit = np.array([valmin + 0.25 * valmax, valmin + 0.75 * valmax]) + else: + valinit = self._value_in_bounds(valinit) + self.val = valinit + self.valinit = valinit + if orientation == "vertical": + self.poly = ax.axhspan(valinit[0], valinit[1], 0, 1, **kwargs) + else: + self.poly = ax.axvspan(valinit[0], valinit[1], 0, 1, **kwargs) + + if orientation == "vertical": + self.label = ax.text( + 0.5, + 1.02, + label, + transform=ax.transAxes, + verticalalignment="bottom", + horizontalalignment="center", + ) + + self.valtext = ax.text( + 0.5, + -0.02, + self._format(valinit), + transform=ax.transAxes, + verticalalignment="top", + horizontalalignment="center", + ) + else: + self.label = ax.text( + -0.02, + 0.5, + label, + transform=ax.transAxes, + verticalalignment="center", + horizontalalignment="right", + ) + + self.valtext = ax.text( + 1.02, + 0.5, + self._format(valinit), + transform=ax.transAxes, + verticalalignment="center", + horizontalalignment="left", + ) + + self.set_val(valinit) + + def _min_in_bounds(self, min): + """ + Ensure the new min value is between valmin and self.val[1] + """ + if min <= self.valmin: + if not self.closedmin: + return self.val[0] + min = self.valmin + + if min > self.val[1]: + min = self.val[1] + return self._stepped_value(min) + + def _max_in_bounds(self, max): + """ + Ensure the new max value is between valmax and self.val[0] + """ + if max >= self.valmax: + if not self.closedmax: + return self.val[1] + max = self.valmax + + if max <= self.val[0]: + max = self.val[0] + return self._stepped_value(max) + + def _value_in_bounds(self, val): + return (self._min_in_bounds(val[0]), self._max_in_bounds(val[1])) + + def _update_val_from_pos(self, pos): + """ + Given a position update the *val* + """ + idx = np.argmin(np.abs(self.val - pos)) + if idx == 0: + val = self._min_in_bounds(pos) + self.set_min(val) + else: + val = self._max_in_bounds(pos) + self.set_max(val) + + def _update(self, event): + """Update the slider position.""" + if self.ignore(event) or event.button != 1: + return + + if event.name == "button_press_event" and event.inaxes == self.ax: + self.drag_active = True + event.canvas.grab_mouse(self.ax) + + if not self.drag_active: + return + + elif (event.name == "button_release_event") or ( + event.name == "button_press_event" and event.inaxes != self.ax + ): + self.drag_active = False + event.canvas.release_mouse(self.ax) + return + if self.orientation == "vertical": + self._update_val_from_pos(event.ydata) + else: + self._update_val_from_pos(event.xdata) + + def _format(self, val): + """Pretty-print *val*.""" + if self.valfmt is not None: + return (self.valfmt % val[0], self.valfmt % val[1]) + else: + # fmt.get_offset is actually the multiplicative factor, if any. + _, s1, s2, _ = self._fmt.format_ticks([self.valmin, *val, self.valmax]) + # fmt.get_offset is actually the multiplicative factor, if any. + s1 += self._fmt.get_offset() + s2 += self._fmt.get_offset() + # use raw string to avoid issues with backslashes from + return rf"({s1}, {s2})" + + def set_min(self, min): + """ + Set the lower value of the slider to *min* + + Parameters + ---------- + min : float + """ + self.set_val((min, self.val[1])) + + def set_max(self, max): + """ + Set the lower value of the slider to *max* + + Parameters + ---------- + max : float + """ + self.set_val((self.val[0], max)) + + def set_val(self, val): + """ + Set slider value to *val* + + Parameters + ---------- + val : tuple or arraylike of float + """ + val = np.sort(np.asanyarray(val)) + if val.shape != (2,): + raise ValueError(f"val must have shape (2,) but has shape {val.shape}") + val[0] = self._min_in_bounds(val[0]) + val[1] = self._max_in_bounds(val[1]) + xy = self.poly.xy + if self.orientation == "vertical": + xy[0] = 0, val[0] + xy[1] = 0, val[1] + xy[2] = 1, val[1] + xy[3] = 1, val[0] + xy[4] = 0, val[0] + else: + xy[0] = val[0], 0 + xy[1] = val[0], 1 + xy[2] = val[1], 1 + xy[3] = val[1], 0 + xy[4] = val[0], 0 + self.poly.xy = xy + self.valtext.set_text(self._format(val)) + if self.drawon: + self.ax.figure.canvas.draw_idle() + self.val = val + if self.eventson: + self._observers.process("changed", val) + + def on_changed(self, func): + """ + When the slider value is changed call *func* with the new + slider value + + Parameters + ---------- + func : callable + Function to call when slider is changed. + The function must accept a numpy array with shape (2,) float + as its argument. + + Returns + ------- + int + Connection id (which can be used to disconnect *func*) + """ + return self._observers.connect("changed", func) diff --git a/mpl_interactions/controller.py b/mpl_interactions/controller.py index 11d41b2..7e2fbef 100644 --- a/mpl_interactions/controller.py +++ b/mpl_interactions/controller.py @@ -7,13 +7,14 @@ _not_ipython = True pass from collections import defaultdict +from mpl_interactions.widgets import IndexSlider, SliderWrapper + from .helpers import ( create_slider_format_dict, - kwarg_to_ipywidget, - kwarg_to_mpl_widget, - create_mpl_controls_fig, + maybe_create_mpl_controls_axes, + kwarg_to_widget, + maybe_get_widget_for_display, notebook_backend, - process_mpl_widget, ) from functools import partial from collections.abc import Iterable @@ -30,6 +31,7 @@ def __init__( play_button_pos="right", use_ipywidgets=None, use_cache=True, + index_kwargs=[], **kwargs ): # it might make sense to also accept kwargs as a straight up arg @@ -49,7 +51,8 @@ def __init__( self.vbox = widgets.VBox([]) else: self.control_figures = [] # storage for figures made of matplotlib sliders - + if widgets: + self.vbox = widgets.VBox([]) self.use_cache = use_cache self.kwargs = kwargs self.slider_format_strings = create_slider_format_dict(slider_formats) @@ -59,16 +62,30 @@ def __init__( self.indices = defaultdict(lambda: 0) self._update_funcs = defaultdict(list) self._user_callbacks = defaultdict(list) - self.add_kwargs(kwargs, slider_formats, play_buttons) + self.add_kwargs(kwargs, slider_formats, play_buttons, index_kwargs=index_kwargs) - def add_kwargs(self, kwargs, slider_formats=None, play_buttons=None, allow_duplicates=False): + def add_kwargs( + self, + kwargs, + slider_formats=None, + play_buttons=None, + allow_duplicates=False, + index_kwargs=None, + ): """ If you pass a redundant kwarg it will just be overwritten maybe should only raise a warning rather than an error? need to implement matplotlib widgets also a big question is how to dynamically update the display of matplotlib widgets. + + Parameters + ---------- + index_kwargs : list of str or None + A list of which sliders should use an index for their callbacks. """ + if not index_kwargs: + index_kwargs = [] if isinstance(play_buttons, bool) or isinstance(play_buttons, str) or play_buttons is None: _play_buttons = defaultdict(lambda: play_buttons) elif isinstance(play_buttons, defaultdict): @@ -85,76 +102,68 @@ def add_kwargs(self, kwargs, slider_formats=None, play_buttons=None, allow_dupli slider_formats = create_slider_format_dict(slider_formats) for k, v in slider_formats.items(): self.slider_format_strings[k] = v - if self.use_ipywidgets: - for k, v in kwargs.items(): - if k in self.params: - if allow_duplicates: - continue - else: - raise ValueError("can't overwrite an existing param in the controller") - if isinstance(v, AxesWidget): - self.params[k], self.controls[k], _ = process_mpl_widget( - v, partial(self.slider_updated, key=k) - ) - else: - self.params[k], control = kwarg_to_ipywidget( - k, - v, - partial(self.slider_updated, key=k), - self.slider_format_strings[k], - play_button=_play_buttons[k], - ) - if control: - self.controls[k] = control - self.vbox.children = list(self.vbox.children) + [control] - if k == "vmin_vmax": - self.params["vmin"] = self.params["vmin_vmax"][0] - self.params["vmax"] = self.params["vmin_vmax"][1] + + if not self.use_ipywidgets: + axes, fig = maybe_create_mpl_controls_axes(kwargs) + if fig is not None: + self.control_figures.append((fig)) else: - if len(kwargs) > 0: - mpl_layout = create_mpl_controls_fig(kwargs) - self.control_figures.append(mpl_layout[0]) - widget_y = 0.05 - for k, v in kwargs.items(): - if k in self.params: - if allow_duplicates: - continue - else: - raise ValueError("Can't overwrite an existing param in the controller") - self.params[k], control, cb, widget_y = kwarg_to_mpl_widget( - mpl_layout[0], - mpl_layout[1:], - widget_y, - k, - v, - partial(self.slider_updated, key=k), - self.slider_format_strings[k], - ) - if control: - self.controls[k] = control - if k == "vmin_vmax": - self.params["vmin"] = self.params["vmin_vmax"][0] - self.params["vmax"] = self.params["vmin_vmax"][1] - - def _slider_updated(self, change, key, values): + axes = [None] * len(kwargs) + + for k, v in kwargs.items(): + if k in self.params: + if allow_duplicates: + continue + else: + raise ValueError("can't overwrite an existing param in the controller") + # TODO: accept existing mpl widget + # if isinstance(v, AxesWidget): + # self.params[k], self.controls[k], _ = process_mpl_widget( + # v, partial(self.slider_updated, key=k) + # ) + # else: + ax = axes.pop() + control = kwarg_to_widget(k, v, ax, play_button=_play_buttons[k]) + # TODO: make the try except silliness less ugly + # the complexity of hiding away the val vs value vs whatever needs to + # be hidden away somewhere - but probably not here + if k in index_kwargs: + self.params[k] = control.index + try: + control.observe(partial(self._slider_updated, key=k), names="index") + except AttributeError: + self._setup_mpl_widget_callback(control, k) + else: + self.params[k] = control.value + try: + control.observe(partial(self._slider_updated, key=k), names="value") + except AttributeError: + self._setup_mpl_widget_callback(control, k) + + if control: + self.controls[k] = control + if ax is None: + disp = maybe_get_widget_for_display(control) + if disp is not None: + self.vbox.children = list(self.vbox.children) + [disp] + if k == "vmin_vmax": + self.params["vmin"] = self.params["vmin_vmax"][0] + self.params["vmax"] = self.params["vmin_vmax"][1] + + def _setup_mpl_widget_callback(self, widget, key): + def on_changed(val): + self._slider_updated({"new": val}, key=key) + + widget.on_changed(on_changed) + + def _slider_updated(self, change, key): """ gotta also give the indices in order to support hyperslicer without horrifying contortions """ - if values is None: - self.params[key] = change["new"] - else: - c = change["new"] - if isinstance(c, tuple): - # This is for range sliders which return 2 indices - self.params[key] = values[[*change["new"]]] - if key == "vmin_vmax": - self.params["vmin"] = self.params[key][0] - self.params["vmax"] = self.params[key][1] - else: - # int casting due to a bug in numpy < 1.19 - # see https://github.com/ianhi/mpl-interactions/pull/155 - self.params[key] = values[int(change["new"])] - self.indices[key] = change["new"] + self.params[key] = change["new"] + if key == "vmin_vmax": + self.params["vmin"] = self.params[key][0] + self.params["vmax"] = self.params[key][1] if self.use_cache: cache = {} else: @@ -162,14 +171,12 @@ def _slider_updated(self, change, key, values): for f, params in self._update_funcs[key]: ps = {} - idxs = {} for k in params: ps[k] = self.params[k] - idxs[k] = self.indices[k] - f(params=ps, indices=idxs, cache=cache) + f(params=ps, cache=cache) + # TODO: see if can combine these with update_funcs for only one loop for f, params in self._user_callbacks[key]: f(**{key: self.params[key] for key in params}) - for f in self.figs[key]: f.canvas.draw_idle() @@ -255,7 +262,7 @@ def save_animation( fig : figure param : str the name of the kwarg to use to animate - interval : int, default: 2o + interval : int, default: 20 interval between frames in ms func_anim_kwargs : dict kwargs to pass the creation of the underlying FuncAnimation @@ -272,40 +279,33 @@ def save_animation( anim : matplotlib.animation.FuncAniation """ slider = self.controls[param] - ipywidgets_slider = False - if "Box" in str(slider.__class__): - for obj in slider.children: - if "Slider" in str(obj.__class__): - slider = obj - - if isinstance(slider, mSlider): - min_ = slider.valmin - max_ = slider.valmax - if slider.valstep is None: + # at this point every slider should be wrapped by at least a .widgets.WidgetWrapper + if isinstance(slider, IndexSlider): + N = len(slider.values) + + def f(i): + slider.index = i + return [] + + elif isinstance(slider, SliderWrapper): + min = slider.min + max = slider.max + if slider.step is None: n_steps = N_frames if N_frames else 200 - step = (max_ - min_) / n_steps + step = (max - min) / n_steps else: step = slider.valstep - elif "Slider" in str(slider.__class__): - ipywidgets_slider = True - min_ = slider.min - max_ = slider.max - step = slider.step + N = int((max - min) / step) + + def f(i): + slider.value = min + step * i + return [] + else: raise NotImplementedError( - "Cannot save animation for slider of type %s".format(slider.__class__.__name__) + "Cannot save animation for param of type %s".format(type(slider)) ) - N = int((max_ - min_) / step) - - def f(i): - val = min_ + step * i - if ipywidgets_slider: - slider.value = val - else: - slider.set_val(val) - return [] - repeat = func_anim_kwargs.pop("repeat", False) anim = FuncAnimation(fig, f, frames=N, interval=interval, repeat=repeat, **func_anim_kwargs) # draw then stop necessary to prevent an extra loop after finished saving @@ -344,6 +344,7 @@ def gogogo_controls( play_buttons, extra_controls=None, allow_dupes=False, + index_kwargs=[], ): if controls or (extra_controls and not all([e is None for e in extra_controls])): if extra_controls is not None: @@ -358,7 +359,13 @@ def gogogo_controls( # it was indexed by the user when passed in extra_keys = controls[1] controls = controls[0] - controls.add_kwargs(kwargs, slider_formats, play_buttons, allow_duplicates=allow_dupes) + controls.add_kwargs( + kwargs, + slider_formats, + play_buttons, + allow_duplicates=allow_dupes, + index_kwargs=index_kwargs, + ) params = {k: controls.params[k] for k in list(kwargs.keys()) + list(extra_keys)} elif isinstance(controls, list): # collected from extra controls @@ -377,14 +384,31 @@ def gogogo_controls( raise ValueError("Only one controls object may be used per function") # now we are garunteed to only have a single entry in controls, so it's ok to pop controls = controls.pop() - controls.add_kwargs(kwargs, slider_formats, play_buttons, allow_duplicates=allow_dupes) + controls.add_kwargs( + kwargs, + slider_formats, + play_buttons, + allow_duplicates=allow_dupes, + index_kwargs=index_kwargs, + ) params = {k: controls.params[k] for k in list(kwargs.keys()) + list(extra_keys)} else: - controls.add_kwargs(kwargs, slider_formats, play_buttons, allow_duplicates=allow_dupes) + controls.add_kwargs( + kwargs, + slider_formats, + play_buttons, + allow_duplicates=allow_dupes, + index_kwargs=index_kwargs, + ) params = controls.params return controls, params else: - controls = Controls(slider_formats=slider_formats, play_buttons=play_buttons, **kwargs) + controls = Controls( + slider_formats=slider_formats, + play_buttons=play_buttons, + index_kwargs=index_kwargs, + **kwargs + ) params = controls.params if display_controls: controls.display() diff --git a/mpl_interactions/generic.py b/mpl_interactions/generic.py index af91cc8..0909386 100644 --- a/mpl_interactions/generic.py +++ b/mpl_interactions/generic.py @@ -688,6 +688,7 @@ def hyperslicer( play_buttons, extra_ctrls, allow_dupes=True, + index_kwargs=list(kwargs.keys()), ) if vmin_vmax is not None: params.pop("vmin_vmax") @@ -700,7 +701,8 @@ def vmin(**kwargs): def vmax(**kwargs): return kwargs["vmax"] - def update(params, indices, cache): + def update(params, cache): + indices = params if title is not None: ax.set_title(title.format(**params)) diff --git a/mpl_interactions/helpers.py b/mpl_interactions/helpers.py index 80b15ce..7defc13 100644 --- a/mpl_interactions/helpers.py +++ b/mpl_interactions/helpers.py @@ -1,6 +1,9 @@ from collections import defaultdict from collections.abc import Callable, Iterable from functools import partial + +from ipywidgets.widgets.widget_float import FloatLogSlider +from .widgets import CategoricalWrapper, IndexSlider, WidgetWrapper, scatter_selector from numbers import Number import matplotlib.widgets as mwidgets @@ -10,14 +13,29 @@ import ipywidgets as widgets from IPython.display import display as ipy_display except ImportError: - pass + widgets = None from matplotlib import get_backend -from matplotlib.pyplot import axes, gca, gcf, figure +from matplotlib.pyplot import gca, gcf, figure from numpy.distutils.misc_util import is_sequence -from .widgets import RangeSlider +try: + from matplotlib.widgets import RangeSlider, SliderBase +except ImportError: + from ._widget_backfill import RangeSlider, SliderBase +from .widgets import RangeSlider, fixed, SliderWrapper from .utils import ioff +if widgets: + _slider_types = ( + mwidgets.Slider, + widgets.IntSlider, + widgets.FloatSlider, + widgets.FloatLogSlider, + ) + # _categorical_types = (mwidgets.RadioButtons, widgets.RadioButtons, widgets.FloatSlider, widgets.FloatLogSlider) +else: + _slider_types = mwidgets.Slider + __all__ = [ "decompose_bbox", "update_datalim_from_xy", @@ -37,9 +55,11 @@ "create_slider_format_dict", "gogogo_figure", "gogogo_display", - "create_mpl_controls_fig", + "maybe_create_mpl_controls_axes", + "maybe_get_widget_for_display", "eval_xy", "choose_fmt_str", + "kwarg_to_widget", ] @@ -256,6 +276,130 @@ def eval_xy(x_, y_, params, cache=None): return np.asanyarray(x), np.asanyarray(y) +def kwarg_to_widget(key, val, mpl_widget_ax=None, play_button=False): + """ + Parameters + ---------- + key : str + val : slider value specification + The value to be interpreted and possibly transformed into an ipywidget + mpl_widget_ax : matplotlib axis, optional + If given then create a matplotlib widget instead of an ipywidget. + play_button : bool or "left" or "right" + Whether to create a play button and where to put it. + + Returns + ------- + widget : + A widget that can be `observed` and will have a `.value` attribute + and a `.index` attribute if applicable. + """ + init_val = 0 + control = None + if isinstance(val, set): + if len(val) == 1: + val = val.pop() + if isinstance(val, tuple): + # want the categories to be ordered + pass + else: + # fixed parameter + # TODO: for mpl as well + return fixed(val) + else: + val = list(val) + + return CategoricalWrapper(val, mpl_widget_ax) + # # TODO: categorical - Make wrappers here! + # if len(val) <= 3: + # selector = widgets.RadioButtons(options=val) + # else: + # selector = widgets.Select(options=val) + # selector.observe(partial(update, values=val), names="index") + # return val[0], selector + if isinstance(val, WidgetWrapper): + return val + elif isinstance(val, scatter_selector): + return val + elif isinstance(val, _slider_types): + return SliderWrapper(val) + # TODO: categorical types + # elif isinstance(val, _categorical_types): + # return CategoricalWrapper(val) + # TODO: add a _range_slider_types + elif widgets and isinstance(val, (widgets.Widget, widgets.fixed, fixed)): + if not hasattr(val, "value"): + raise TypeError( + "widgets passed as parameters must have the `value` trait." + "But the widget passed for {key} does not have a `.value` attribute" + ) + return val + # if isinstance(val, widgets.fixed): + # return val + # TODO: elif ( + # isinstance(val, widgets.Select) + # or isinstance(val, widgets.SelectionSlider) + # or isinstance(val, widgets.RadioButtons) + # ): + # # all the selection widget inherit a private _Selection :( + # # it looks unlikely to change but still would be nice to just check + # # if its a subclass + # return val + # # val.observe(partial(update, values=val.options), names="index") + # else: + # # set values to None and hope for the best + # val.observe(partial(update, values=None), names="value") + # return val.value, val + # # val.observe(partial(update, key=key, label=None), names=["value"]) + else: + # TODO: Range sliders + # if isinstance(val, tuple) and val[0] in ["r", "range", "rang", "rage"]: + # # also check for some reasonably easy mispellings + # if isinstance(val[1], (np.ndarray, list)): + # vals = val[1] + # else: + # vals = np.linspace(*val[1:]) + # label = widgets.Label(value=str(vals[0])) + # slider = widgets.IntRangeSlider( + # value=(0, vals.size - 1), min=0, max=vals.size - 1, readout=False, description=key + # ) + # widgets.dlink( + # (slider, "value"), + # (label, "value"), + # transform=lambda x: slider_format_string.format(vals[x[0]]) + # + " - " + # + slider_format_string.format(vals[x[1]]), + # ) + # slider.observe(partial(update, values=vals), names="value") + # controls = widgets.HBox([slider, label]) + # return vals[[0, -1]], controls + + if isinstance(val, tuple) and len(val) in [2, 3]: + # treat as an argument to linspace + # idk if it's acceptable to overwrite kwargs like this + # but I think at this point kwargs is just a dict like any other + val = np.linspace(*val) + val = np.atleast_1d(val) + if val.ndim > 1: + raise ValueError(f"{key} is {val.ndim}D but can only be 1D or a scalar") + if len(val) == 1: + return fixed(val) + else: + return IndexSlider(val, key, mpl_widget_ax, play_button=play_button) + + +def maybe_get_widget_for_display(w): + """ + Check if an object can be included in an ipywidgets HBox and if so return + the approriate object + """ + if isinstance(w, WidgetWrapper): + return w._get_widget_for_display() + elif widgets and isinstance(w, widgets.Widget): + return w + return None + + def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None): """ Parameters @@ -360,24 +504,25 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None) return val, None else: # params[key] = val[0] - label = widgets.Label(value=slider_format_string.format(val[0])) - slider = widgets.IntSlider(min=0, max=val.size - 1, readout=False, description=key) - widgets.dlink( - (slider, "value"), - (label, "value"), - transform=lambda x: slider_format_string.format(val[x]), - ) - slider.observe(partial(update, values=val), names="value") - if play_button is not None and play_button is not False: - play = widgets.Play(min=0, max=val.size - 1, step=1) - widgets.jslink((play, "value"), (slider, "value")) - if isinstance(play_button, str) and play_button.lower() == "right": - control = widgets.HBox([slider, label, play]) - else: - control = widgets.HBox([play, slider, label]) - else: - control = widgets.HBox([slider, label]) - return val[0], control + slider = IndexSlider(val, key) + # label = widgets.Label(value=slider_format_string.format(val[0])) + # slider = widgets.IntSlider(min=0, max=val.size - 1, readout=False, description=key) + # widgets.dlink( + # (slider, "value"), + # (label, "value"), + # transform=lambda x: slider_format_string.format(val[x]), + # ) + slider.observe(partial(update, values=val), names="index") + # if play_button is not None and play_button is not False: + # play = widgets.Play(min=0, max=val.size - 1, step=1) + # widgets.jslink((play, "value"), (slider, "value")) + # if isinstance(play_button, str) and play_button.lower() == "right": + # control = widgets.HBox([slider, label, play]) + # else: + # control = widgets.HBox([play, slider, label]) + # else: + # control = widgets.HBox([slider, label]) + return val[0], slider._get_widget_for_display() def extract_num_options(val): @@ -420,7 +565,7 @@ def changeify_radio(val, labels, update): update({"new": labels.index(value)}) -def create_mpl_controls_fig(kwargs): +def maybe_create_mpl_controls_axes(kwargs): """ Returns ------- @@ -441,16 +586,20 @@ def create_mpl_controls_fig(kwargs): I think maybe the correct approach is to use transforms and actually specify things in inches - Ian 2020-09-27 """ - init_fig = gcf() n_opts = 0 n_radio = 0 n_sliders = 0 + order = [] + radio_info = [] for key, val in kwargs.items(): if isinstance(val, set): new_opts = extract_num_options(val) if new_opts > 0: n_radio += 1 n_opts += new_opts + order.append("radio") + longest_len = max(list(map(lambda x: len(list(x)), map(str, val)))) + radio_info.append((new_opts, longest_len)) elif ( not isinstance(val, mwidgets.AxesWidget) and not "ipywidgets" in str(val.__class__) # do this to avoid depending on ipywidgets @@ -458,7 +607,16 @@ def create_mpl_controls_fig(kwargs): and len(val) > 1 ): n_sliders += 1 + order.append("slider") + else: + order.append(None) + + if n_sliders == 0 and n_radio == 0: + # do we need to make anything? + # if no just return None for all the axes + return order, None + init_fig = gcf() # These are roughly the sizes used in the matplotlib widget tutorial # https://matplotlib.org/3.2.2/gallery/widgets/slider_demo.html#sphx-glr-gallery-widgets-slider-demo-py slider_in = 0.15 @@ -486,6 +644,23 @@ def create_mpl_controls_fig(kwargs): # reset the active figure - necessary to make legends behave as expected # maybe this should really be handled via axes? idk figure(init_fig.number) + widget_y = 0.05 + axes = [] + for i, o in enumerate(order): + if o == "slider": + axes.append(fig.add_axes([0.2, 0.9 - widget_y - gap_height, 0.65, slider_height])) + widget_y += slider_height + gap_height + elif o == "radio": + n, longest_len = radio_info.pop() + width = max(0.15, 0.015 * longest_len) + axes.append( + fig.add_axes([0.2, 0.9 - widget_y - radio_height * n, width, radio_height * n]) + ) + widget_y += radio_height * n + gap_height + else: + axes.append(None) + return axes, fig + return fig, slider_height, radio_height, gap_height diff --git a/mpl_interactions/pyplot.py b/mpl_interactions/pyplot.py index c033c42..df91e70 100644 --- a/mpl_interactions/pyplot.py +++ b/mpl_interactions/pyplot.py @@ -156,7 +156,7 @@ def f(x, tau): kwargs, controls, display_controls, slider_formats, play_buttons ) - def update(params, indices, cache): + def update(params, cache): if x_and_y: x_, y_ = eval_xy(x, y, params, cache) # broadcast so that we can always index @@ -374,7 +374,7 @@ def f(loc, scale): pc = PatchCollection([]) ax.add_collection(pc, autolim=True) - def update(params, indices, cache): + def update(params, cache): arr_ = callable_else_value(arr, params, cache) new_x, new_y, new_patches = simple_hist(arr_, density=density, bins=bins, weights=weights) stretch(ax, new_x, new_y) @@ -501,7 +501,7 @@ def interactive_scatter( kwargs, controls, display_controls, slider_formats, play_buttons, extra_ctrls ) - def update(params, indices, cache): + def update(params, cache): if parametric: out = callable_else_value_no_cast(x, params) if not isinstance(out, tuple): @@ -702,7 +702,7 @@ def vmin(**kwargs): def vmax(**kwargs): return kwargs["vmax"] - def update(params, indices, cache): + def update(params, cache): if isinstance(X, Callable): # check this here to avoid setting the data if we don't need to # use the callable_else_value fxn to make use of easy caching @@ -822,7 +822,7 @@ def interactive_axhline( kwargs, controls, display_controls, slider_formats, play_buttons, extra_ctrls ) - def update(params, indices, cache): + def update(params, cache): y_ = callable_else_value(y, params, cache).item() line.set_ydata([y_, y_]) xmin_ = callable_else_value(xmin, params, cache).item() @@ -919,7 +919,7 @@ def interactive_axvline( kwargs, controls, display_controls, slider_formats, play_buttons, extra_ctrls ) - def update(params, indices, cache): + def update(params, cache): x_ = callable_else_value(x, params, cache).item() line.set_xdata([x_, x_]) ymin_ = callable_else_value(ymin, params, cache).item() @@ -1007,7 +1007,7 @@ def interactive_title( kwargs, controls, display_controls, slider_formats, play_buttons ) - def update(params, indices, cache): + def update(params, cache): ax.set_title( callable_else_value_no_cast(title, params, cache).format(**params), fontdict=fontdict, @@ -1094,7 +1094,7 @@ def interactive_xlabel( kwargs, controls, display_controls, slider_formats, play_buttons ) - def update(params, indices, cache): + def update(params, cache): ax.set_xlabel( callable_else_value_no_cast(xlabel, params, cache).format(**params), fontdict=fontdict, @@ -1179,7 +1179,7 @@ def interactive_ylabel( kwargs, controls, display_controls, slider_formats, play_buttons ) - def update(params, indices, cache): + def update(params, cache): ax.set_ylabel( callable_else_value_no_cast(ylabel, params, cache).format(**params), fontdict=fontdict, diff --git a/mpl_interactions/widgets.py b/mpl_interactions/widgets.py index 044a7ee..63e0316 100644 --- a/mpl_interactions/widgets.py +++ b/mpl_interactions/widgets.py @@ -1,13 +1,48 @@ +import numpy as np +from numbers import Number + +from traitlets import ( + HasTraits, + Int, + Float, + Union, + observe, + dlink, + link, + Tuple, + Unicode, + validate, + TraitError, + Any, +) +from traittypes import Array + +try: + import ipywidgets as widgets + from ipywidgets.widgets.widget_link import jslink + from IPython.display import display +except ImportError: + widgets = None +from matplotlib import widgets as mwidgets from matplotlib.cbook import CallbackRegistry from matplotlib.widgets import AxesWidget -from matplotlib import cbook, ticker -import numpy as np +from matplotlib.ticker import ScalarFormatter + +try: + from matplotlib.widgets import RangeSlider, SliderBase +except ImportError: + from ._widget_backfill import RangeSlider, SliderBase +import matplotlib.widgets as mwidgets __all__ = [ "scatter_selector", "scatter_selector_index", "scatter_selector_value", "RangeSlider", + "SliderWrapper", + "IntSlider", + "IndexSlider", + "CategoricalWrapper", ] @@ -49,6 +84,7 @@ def __init__(self, ax, x, y, pickradius=5, which_button=1, **kwargs): def _init_val(self): self.val = (0, (self._x[0], self._y[0])) + self.value = (0, (self._x[0], self._y[0])) def _on_pick(self, event): if event.mouseevent.button == self._button: @@ -57,7 +93,7 @@ def _on_pick(self, event): y = self._y[idx] self._process(idx, (x, y)) - def _process(idx, val): + def _process(self, idx, val): self._observers.process("picked", idx, val) def on_changed(self, func): @@ -86,6 +122,7 @@ class scatter_selector_index(scatter_selector): def _init_val(self): self.val = 0 + self.value = 0 def _process(self, idx, val): self._observers.process("picked", idx) @@ -117,6 +154,7 @@ class scatter_selector_value(scatter_selector): def _init_val(self): self.val = (self._x[0], self._y[0]) + self.value = (self._x[0], self._y[0]) def _process(self, idx, val): self._observers.process("picked", val) @@ -140,336 +178,266 @@ def on_changed(self, func): return self._observers.connect("picked", lambda val: func(val)) -# slider widgets are taken almost verbatim from https://github.com/matplotlib/matplotlib/pull/18829/files -# which was written by me - but incorporates much of the existing matplotlib slider infrastructure -class SliderBase(AxesWidget): - def __init__( - self, ax, orientation, closedmin, closedmax, valmin, valmax, valfmt, dragging, valstep - ): - if ax.name == "3d": - raise ValueError("Sliders cannot be added to 3D Axes") - - super().__init__(ax) - - self.orientation = orientation - self.closedmin = closedmin - self.closedmax = closedmax - self.valmin = valmin - self.valmax = valmax - self.valstep = valstep - self.drag_active = False - self.valfmt = valfmt - - if orientation == "vertical": - ax.set_ylim((valmin, valmax)) - axis = ax.yaxis - else: - ax.set_xlim((valmin, valmax)) - axis = ax.xaxis - - self._fmt = axis.get_major_formatter() - if not isinstance(self._fmt, ticker.ScalarFormatter): - self._fmt = ticker.ScalarFormatter() - self._fmt.set_axis(axis) - self._fmt.set_useOffset(False) # No additive offset. - self._fmt.set_useMathText(True) # x sign before multiplicative offset. - - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_navigate(False) - self.connect_event("button_press_event", self._update) - self.connect_event("button_release_event", self._update) - if dragging: - self.connect_event("motion_notify_event", self._update) - self._observers = cbook.CallbackRegistry() - - def _stepped_value(self, val): - if self.valstep: - return self.valmin + round((val - self.valmin) / self.valstep) * self.valstep - return val - - def disconnect(self, cid): - """ - Remove the observer with connection id *cid* - - Parameters - ---------- - cid : int - Connection id of the observer to be removed - """ - self._observers.disconnect(cid) +_gross_traits = [ + "add_traits", + "class_own_trait_events", + "class_own_traits", + "class_trait_names", + "class_traits", + "cross_validation_lock", + "has_trait", + "hold_trait_notifications", + "notify_change", + "on_trait_change", + "set_trait", + "setup_instance", + "trait_defaults", + "trait_events", + "trait_has_value", + "trait_metadata", + "trait_names", + "trait_values", + "traits", +] - def reset(self): - """Reset the slider to the initial value""" - if self.val != self.valinit: - self.set_val(self.valinit) +class HasTraitsSmallShiftTab(HasTraits): + def __dir__(self): + # hide all the cruft from traitlets for shift+Tab + return [i for i in super().__dir__() if i not in _gross_traits] -class RangeSlider(SliderBase): - """ - A slider representing a floating point range. - Create a slider from *valmin* to *valmax* in axes *ax*. For the slider to - remain responsive you must maintain a reference to it. Call - :meth:`on_changed` to connect to the slider event. +class WidgetWrapper(HasTraitsSmallShiftTab): + value = Any() - Attributes - ---------- - val : tuple of float - Slider value. - """ + def __init__(self, mpl_widget, **kwargs) -> None: + super().__init__(self) + self._mpl = mpl_widget + self._callbacks = [] - def __init__( - self, - ax, - label, - valmin, - valmax, - valinit=None, - valfmt=None, - closedmin=True, - closedmax=True, - dragging=True, - valstep=None, - orientation="horizontal", - **kwargs, - ): - """ - Parameters - ---------- - ax : Axes - The Axes to put the slider in. - label : str - Slider label. - valmin : float - The minimum value of the slider. - valmax : float - The maximum value of the slider. - valinit : tuple of float or None, default: None - The initial positions of the slider. If None the initial positions - will be at the 25th and 75th percentiles of the range. - valfmt : str, default: None - %-format string used to format the slider values. If None, a - `.ScalarFormatter` is used instead. - closedmin : bool, default: True - Whether the slider interval is closed on the bottom. - closedmax : bool, default: True - Whether the slider interval is closed on the top. - dragging : bool, default: True - If True the slider can be dragged by the mouse. - valstep : float, default: None - If given, the slider will snap to multiples of *valstep*. - orientation : {'horizontal', 'vertical'}, default: 'horizontal' - The orientation of the slider. - - Notes - ----- - Additional kwargs are passed on to ``self.poly`` which is the - `~matplotlib.patches.Rectangle` that draws the slider knob. See the - `.Rectangle` documentation for valid property names (``facecolor``, - ``edgecolor``, ``alpha``, etc.). - """ - super().__init__( - ax, orientation, closedmin, closedmax, valmin, valmax, valfmt, dragging, valstep - ) + def on_changed(self, callback): + # callback registry? + self._callbacks.append(callback) - self.val = valinit - if valinit is None: - valinit = np.array([valmin + 0.25 * valmax, valmin + 0.75 * valmax]) - else: - valinit = self._value_in_bounds(valinit) - self.val = valinit - self.valinit = valinit - if orientation == "vertical": - self.poly = ax.axhspan(valinit[0], valinit[1], 0, 1, **kwargs) + def _get_widget_for_display(self): + if self._mpl: + return None else: - self.poly = ax.axvspan(valinit[0], valinit[1], 0, 1, **kwargs) - - if orientation == "vertical": - self.label = ax.text( - 0.5, - 1.02, - label, - transform=ax.transAxes, - verticalalignment="bottom", - horizontalalignment="center", - ) + return self._raw_widget - self.valtext = ax.text( - 0.5, - -0.02, - self._format(valinit), - transform=ax.transAxes, - verticalalignment="top", - horizontalalignment="center", - ) + def _ipython_display_(self): + if self._mpl: + pass else: - self.label = ax.text( - -0.02, - 0.5, - label, - transform=ax.transAxes, - verticalalignment="center", - horizontalalignment="right", - ) + display(self._get_widget_for_display()) - self.valtext = ax.text( - 1.02, - 0.5, - self._format(valinit), - transform=ax.transAxes, - verticalalignment="center", - horizontalalignment="left", - ) + @observe("value") + def _on_changed(self, change): + for c in self._callbacks: + c(change["new"]) - self.set_val(valinit) - def _min_in_bounds(self, min): - """ - Ensure the new min value is between valmin and self.val[1] - """ - if min <= self.valmin: - if not self.closedmin: - return self.val[0] - min = self.valmin +class SliderWrapper(WidgetWrapper): + """ + A warpper class that provides a consistent interface for both + ipywidgets and matplotlib sliders. + """ - if min > self.val[1]: - min = self.val[1] - return self._stepped_value(min) + min = Union([Int(), Float(), Tuple([Int(), Int()]), Tuple(Float(), Float())]) + max = Union([Int(), Float(), Tuple([Int(), Int()]), Tuple(Float(), Float())]) + value = Union([Float(), Int(), Tuple([Int(), Int()]), Tuple(Float(), Float())]) + step = Union([Int(), Float(allow_none=True)]) + label = Unicode() + + def __init__(self, slider, readout_format=None, setup_value_callbacks=True, **kwargs): + self._mpl = isinstance(slider, (mwidgets.Slider, SliderBase)) + super().__init__(self, **kwargs) + self._raw_widget = slider + + # eventually we can just rely on SliderBase here + # for now keep both for compatibility with mpl < 3.4 + if self._mpl: + self.observe(lambda change: setattr(self._raw_widget, "valmin", change["new"]), "min") + self.observe(lambda change: setattr(self._raw_widget, "valmax", change["new"]), "max") + self.observe(lambda change: self._raw_widget.label.set_text(change["new"]), "label") + if setup_value_callbacks: + self.observe(lambda change: self._raw_widget.set_val(change["new"]), "value") + self._raw_widget.on_changed(lambda val: setattr(self, "value", val)) + self.value = self._raw_widget.val + self.min = self._raw_widget.valmin + self.max = self._raw_widget.valmax + self.step = self._raw_widget.valstep + self.label = self._raw_widget.label.get_text() + else: + if setup_value_callbacks: + link((slider, "value"), (self, "value")) + link((slider, "min"), (self, "min")) + link((slider, "max"), (self, "max")) + link((slider, "step"), (self, "step")) + link((slider, "description"), (self, "label")) + + +class IntSlider(SliderWrapper): + min = Int() + max = Int() + value = Int() + + +class SelectionWrapper(WidgetWrapper): + index = Int() + values = Array() + max_index = Int() + + def __init__(self, values, mpl_ax=None, **kwargs) -> None: + super().__init__(mpl_ax is not None, **kwargs) + self.values = values + self.value = self.values[self.index] + + @validate("value") + def _validate_value(self, proposal): + if not proposal["value"] in self.values: + raise TraitError( + f"{proposal['value']} is not in the set of values for this index slider." + " To see or change the set of valid values use the `.values` attribute" + ) + # call `int` because traitlets can't handle np int64 + self.index = int(np.where(self.values == proposal["value"])[0][0]) - def _max_in_bounds(self, max): - """ - Ensure the new max value is between valmax and self.val[0] - """ - if max >= self.valmax: - if not self.closedmax: - return self.val[1] - max = self.valmax + return proposal["value"] - if max <= self.val[0]: - max = self.val[0] - return self._stepped_value(max) + @observe("index") + def _obs_index(self, change): + # call .item because traitlets is unhappy with numpy types + self.value = self.values[change["new"]].item() - def _value_in_bounds(self, val): - return (self._min_in_bounds(val[0]), self._max_in_bounds(val[1])) + @validate("values") + def _validate_values(self, proposal): + values = proposal["value"] + if values.ndim > 1: + raise TraitError("Expected 1d array but got an array with shape %s" % (values.shape)) + self.max_index = values.shape[0] + return values - def _update_val_from_pos(self, pos): - """ - Given a position update the *val* - """ - idx = np.argmin(np.abs(self.val - pos)) - if idx == 0: - val = self._min_in_bounds(pos) - self.set_min(val) - else: - val = self._max_in_bounds(pos) - self.set_max(val) - - def _update(self, event): - """Update the slider position.""" - if self.ignore(event) or event.button != 1: - return - - if event.name == "button_press_event" and event.inaxes == self.ax: - self.drag_active = True - event.canvas.grab_mouse(self.ax) - - if not self.drag_active: - return - - elif (event.name == "button_release_event") or ( - event.name == "button_press_event" and event.inaxes != self.ax - ): - self.drag_active = False - event.canvas.release_mouse(self.ax) - return - if self.orientation == "vertical": - self._update_val_from_pos(event.ydata) - else: - self._update_val_from_pos(event.xdata) - def _format(self, val): - """Pretty-print *val*.""" - if self.valfmt is not None: - return (self.valfmt % val[0], self.valfmt % val[1]) - else: - # fmt.get_offset is actually the multiplicative factor, if any. - _, s1, s2, _ = self._fmt.format_ticks([self.valmin, *val, self.valmax]) - # fmt.get_offset is actually the multiplicative factor, if any. - s1 += self._fmt.get_offset() - s2 += self._fmt.get_offset() - # use raw string to avoid issues with backslashes from - return rf"({s1}, {s2})" - - def set_min(self, min): - """ - Set the lower value of the slider to *min* +class IndexSlider(SelectionWrapper): + """ + A slider class to index through an array of values. + """ + def __init__( + self, values, label="", mpl_slider_ax=None, readout_format=None, play_button=False + ): + """ Parameters ---------- - min : float + values : 1D arraylike + The values to index over + label : str + The slider label + mpl_slider_ax : matplotlib.axes or None + If *None* an ipywidgets slider will be created """ - self.set_val((min, self.val[1])) + super().__init__(values, mpl_ax=mpl_slider_ax) + self.values = np.atleast_1d(values) + self.readout_format = readout_format + self._scalar_formatter = ScalarFormatter(useOffset=False) + self._scalar_formatter.create_dummy_axis() + if mpl_slider_ax is not None: + # make mpl_slider + if play_button: + raise ValueError( + "Play Buttons not yet available for matplotlib sliders " + "see https://github.com/ianhi/mpl-interactions/issues/144" + ) + slider = mwidgets.Slider( + mpl_slider_ax, + label=label, + valinit=0, + valmin=0, + valmax=self.values.shape[0] - 1, + valstep=1, + ) - def set_max(self, max): - """ - Set the lower value of the slider to *max* + def onchange(val): + self.index = int(val) + slider.valtext.set_text(self._format_value(self.values[int(val)])) - Parameters - ---------- - max : float + slider.on_changed(onchange) + elif widgets: + # i've basically recreated the ipywidgets.SelectionSlider here. + slider = widgets.IntSlider( + 0, 0, self.values.shape[0] - 1, step=1, readout=False, description=label + ) + self._readout = widgets.Label(value=str(self.values[0])) + widgets.dlink( + (slider, "value"), + (self._readout, "value"), + transform=lambda x: self._format_value(self.values[x]), + ) + self._play_button = None + if play_button: + self._play_button = widgets.Play(step=1) + self._play_button_on_left = not ( + isinstance(play_button, str) and play_button == "right" + ) + jslink((slider, "value"), (self._play_button, "value")) + jslink((slider, "min"), (self._play_button, "min")) + jslink((slider, "max"), (self._play_button, "max")) + link((slider, "value"), (self, "index")) + link((slider, "max"), (self, "max_index")) + else: + raise ValueError("mpl_slider_ax cannot be None if ipywidgets is not available") + self._raw_widget = slider + + def _format_value(self, value): + if self.readout_format is None: + if isinstance(value, Number): + return self._scalar_formatter.format_data_short(value) + else: + return str(value) + return self.readout_format.format(value) + + def _get_widget_for_display(self): + if self._mpl: + return None + if self._play_button: + if self._play_button_on_left: + return widgets.HBox([self._play_button, self._raw_widget, self._readout]) + else: + return widgets.HBox([self._raw_widget, self._readout, self._play_button]) + return widgets.HBox([self._raw_widget, self._readout]) + + +# A vendored version of ipywidgets.fixed - included so don't need to depend on ipywidgets +# https://github.com/jupyter-widgets/ipywidgets/blob/e0d41f6f02324596a282bc9e4650fd7ba63c0004/ipywidgets/widgets/interaction.py#L546 +class fixed(HasTraitsSmallShiftTab): + """A pseudo-widget whose value is fixed and never synced to the client.""" + + value = Any(help="Any Python object") + description = Unicode("", help="Any Python object") + + def __init__(self, value, **kwargs): + super().__init__(value=value, **kwargs) + + def get_interact_value(self): + """Return the value for this widget which should be passed to + interactive functions. Custom widgets can change this method + to process the raw value ``self.value``. """ - self.set_val((self.val[0], max)) + return self.value - def set_val(self, val): - """ - Set slider value to *val* - Parameters - ---------- - val : tuple or arraylike of float - """ - val = np.sort(np.asanyarray(val)) - if val.shape != (2,): - raise ValueError(f"val must have shape (2,) but has shape {val.shape}") - val[0] = self._min_in_bounds(val[0]) - val[1] = self._max_in_bounds(val[1]) - xy = self.poly.xy - if self.orientation == "vertical": - xy[0] = 0, val[0] - xy[1] = 0, val[1] - xy[2] = 1, val[1] - xy[3] = 1, val[0] - xy[4] = 0, val[0] - else: - xy[0] = val[0], 0 - xy[1] = val[0], 1 - xy[2] = val[1], 1 - xy[3] = val[1], 0 - xy[4] = val[0], 0 - self.poly.xy = xy - self.valtext.set_text(self._format(val)) - if self.drawon: - self.ax.figure.canvas.draw_idle() - self.val = val - if self.eventson: - self._observers.process("changed", val) +class CategoricalWrapper(SelectionWrapper): + def __init__(self, values, mpl_ax=None, **kwargs): + super().__init__(values, mpl_ax=mpl_ax, **kwargs) - def on_changed(self, func): - """ - When the slider value is changed call *func* with the new - slider value + if mpl_ax is not None: + self._raw_widget = mwidgets.RadioButtons(mpl_ax, values) - Parameters - ---------- - func : callable - Function to call when slider is changed. - The function must accept a numpy array with shape (2,) float - as its argument. + def on_changed(label): + self.index = self._raw_widget.active - Returns - ------- - int - Connection id (which can be used to disconnect *func*) - """ - return self._observers.connect("changed", func) + self._raw_widget.on_changed(on_changed) + else: + self._raw_widget = widgets.Select(options=values) + link((self._raw_widget, "index"), (self, "index")) diff --git a/setup.cfg b/setup.cfg index c86472b..477b5fd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,6 +35,8 @@ platforms = Linux, Mac OS X, Windows python_requires = >=3.6, <3.10 install_requires = matplotlib >= 3.3 + traitlets + traittypes packages = find: [options.extras_require]