Skip to content

Commit f0c0cd3

Browse files
authored
feat: only disallow duplicate names when values don't match (#275)
1 parent 47c1b38 commit f0c0cd3

File tree

4 files changed

+67
-36
lines changed

4 files changed

+67
-36
lines changed

mpl_interactions/controller.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,10 @@ def __init__(
6565
self.indices = defaultdict(lambda: 0)
6666
self._update_funcs = defaultdict(list)
6767
self._user_callbacks = defaultdict(list)
68+
self._hashes = []
6869
self.add_kwargs(kwargs, slider_formats, play_buttons)
6970

70-
def add_kwargs(self, kwargs, slider_formats=None, play_buttons=None, allow_duplicates=False):
71+
def add_kwargs(self, kwargs, slider_formats=None, play_buttons=None):
7172
"""Add kwargs to the controller.
7273
7374
If you pass a redundant kwarg it will just be overwritten
@@ -94,23 +95,37 @@ def add_kwargs(self, kwargs, slider_formats=None, play_buttons=None, allow_dupli
9495
self.slider_format_strings[k] = v
9596
if self.use_ipywidgets:
9697
for k, v in kwargs.items():
97-
if k in self.params:
98-
if allow_duplicates:
99-
continue
100-
else:
101-
raise ValueError("can't overwrite an existing param in the controller")
10298
if isinstance(v, AxesWidget):
103-
self.params[k], self.controls[k], _ = process_mpl_widget(
99+
# TODO: HASHING behavior
100+
param, control, _, hash_ = process_mpl_widget(
104101
v, partial(self.slider_updated, key=k)
105102
)
103+
if k in self.params:
104+
if hash_ not in self._hashes:
105+
raise ValueError(
106+
f"kwarg {k} already exists and the new values are incompatible."
107+
)
108+
# don't need to add it because it already exists
109+
continue
110+
self.params[k], self.controls[k] = param, control
111+
self._hashes.append(hash)
106112
else:
107-
self.params[k], control = kwarg_to_ipywidget(
113+
param, control, hash_ = kwarg_to_ipywidget(
108114
k,
109115
v,
110116
partial(self.slider_updated, key=k),
111117
self.slider_format_strings[k],
112118
play_button=_play_buttons[k],
113119
)
120+
if k in self.params:
121+
if hash_ not in self._hashes:
122+
raise ValueError(
123+
f"kwarg {k} already exists and the new values are incompatible."
124+
)
125+
# don't need to add it because it already exists
126+
continue
127+
self.params[k] = param
128+
self._hashes.append(hash_)
114129
if control:
115130
self.controls[k] = control
116131
self.vbox.children = [*list(self.vbox.children), control]
@@ -123,12 +138,7 @@ def add_kwargs(self, kwargs, slider_formats=None, play_buttons=None, allow_dupli
123138
self.control_figures.append(mpl_layout[0])
124139
widget_y = 0.05
125140
for k, v in kwargs.items():
126-
if k in self.params:
127-
if allow_duplicates:
128-
continue
129-
else:
130-
raise ValueError("Can't overwrite an existing param in the controller")
131-
self.params[k], control, cb, widget_y = kwarg_to_mpl_widget(
141+
param, control, cb, widget_y, hash_ = kwarg_to_mpl_widget(
132142
mpl_layout[0],
133143
mpl_layout[1:],
134144
widget_y,
@@ -137,6 +147,15 @@ def add_kwargs(self, kwargs, slider_formats=None, play_buttons=None, allow_dupli
137147
partial(self.slider_updated, key=k),
138148
self.slider_format_strings[k],
139149
)
150+
if k in self.params:
151+
if hash_ not in self._hashes:
152+
raise ValueError(
153+
f"kwarg {k} already exists and the new values are incompatible."
154+
)
155+
# don't need to add it because it already exists
156+
continue
157+
self.params[k] = param
158+
self._hashes.append(hash_)
140159
if control:
141160
self.controls[k] = control
142161
if k == "vmin_vmax":
@@ -390,7 +409,6 @@ def gogogo_controls(
390409
slider_formats,
391410
play_buttons,
392411
extra_controls=None,
393-
allow_dupes=False,
394412
):
395413
"""
396414
Create a new controls object.
@@ -446,7 +464,7 @@ def gogogo_controls(
446464
controls.display()
447465
else:
448466
controls = ctrls.pop()
449-
controls.add_kwargs(kwargs, slider_formats, play_buttons, allow_duplicates=allow_dupes)
467+
controls.add_kwargs(kwargs, slider_formats, play_buttons)
450468
params = {k: controls.params[k] for k in keys}
451469
return controls, params
452470

mpl_interactions/generic.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -758,7 +758,6 @@ def hyperslicer(
758758
slider_format_strings,
759759
play_buttons,
760760
extra_ctrls,
761-
allow_dupes=True,
762761
)
763762
if vmin_vmax is not None:
764763
params.pop("vmin_vmax")

mpl_interactions/helpers.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None)
204204
The generated widget. This may be the raw widget or a higher level container
205205
widget (e.g. HBox) depending on what widget was generated. If a fixed value is
206206
returned then control will be *None*
207+
param_hash :
208+
A hash of the possible values, to be used to check duplicates in the future.
207209
"""
208210
control = None
209211
if isinstance(val, set):
@@ -214,7 +216,7 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None)
214216
pass
215217
else:
216218
# fixed parameter
217-
return val, None
219+
return val, None, hash(repr(val))
218220
else:
219221
val = list(val)
220222

@@ -224,15 +226,15 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None)
224226
else:
225227
selector = widgets.Select(options=val)
226228
selector.observe(partial(update, values=val), names="index")
227-
return val[0], selector
229+
return val[0], selector, hash(repr(val))
228230
elif isinstance(val, widgets.Widget) or isinstance(val, widgets.fixed):
229231
if not hasattr(val, "value"):
230232
raise TypeError(
231233
"widgets passed as parameters must have the `value` trait."
232234
"But the widget passed for {key} does not have a `.value` attribute"
233235
)
234236
if isinstance(val, widgets.fixed):
235-
return val, None
237+
return val, None, hash(repr(val))
236238
elif (
237239
isinstance(val, widgets.Select)
238240
or isinstance(val, widgets.SelectionSlider)
@@ -242,10 +244,11 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None)
242244
# it looks unlikely to change but still would be nice to just check
243245
# if its a subclass
244246
val.observe(partial(update, values=val.options), names="index")
247+
return val.value, val, hash(repr(val.options))
245248
else:
246249
# set values to None and hope for the best
247250
val.observe(partial(update, values=None), names="value")
248-
return val.value, val
251+
return val.value, val, hash(repr(val))
249252
# val.observe(partial(update, key=key, label=None), names=["value"])
250253
else:
251254
if isinstance(val, tuple) and val[0] in ["r", "range", "rang", "rage"]:
@@ -267,7 +270,7 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None)
267270
)
268271
slider.observe(partial(update, values=vals), names="value")
269272
controls = widgets.HBox([slider, label])
270-
return vals[[0, -1]], controls
273+
return vals[[0, -1]], controls, hash("r" + repr(vals))
271274

272275
if isinstance(val, tuple) and len(val) in [2, 3]:
273276
# treat as an argument to linspace
@@ -279,7 +282,7 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None)
279282
raise ValueError(f"{key} is {val.ndim}D but can only be 1D or a scalar")
280283
if len(val) == 1:
281284
# don't need to create a slider
282-
return val[0], None
285+
return val[0], None, hash(repr(val))
283286
else:
284287
# params[key] = val[0]
285288
label = widgets.Label(value=slider_format_string.format(val[0]))
@@ -299,7 +302,7 @@ def kwarg_to_ipywidget(key, val, update, slider_format_string, play_button=None)
299302
control = widgets.HBox([play, slider, label])
300303
else:
301304
control = widgets.HBox([slider, label])
302-
return val[0], control
305+
return val[0], control, hash(repr(val))
303306

304307

305308
def extract_num_options(val):
@@ -455,17 +458,19 @@ def process_mpl_widget(val, update):
455458
# oh boy do I ever not want to
456459
val.set_active(0)
457460
cb = val.on_clicked(partial(changeify_radio, labels=val.labels, update=update))
458-
return val.labels[0], val, cb
459-
elif isinstance(val, (mwidgets.Slider, RangeSlider)):
461+
return val.labels[0], val, cb, hash(repr(val.labels))
462+
elif isinstance(val, (mwidgets.Slider, mwidgets.RangeSlider, RangeSlider)):
463+
# TODO: proper inherit matplotlib rand
460464
# potential future improvement:
461465
# check if valstep has been set and then try to infer the values
462466
# but not now, I'm trying to avoid premature optimization lest this
463467
# drag on forever
464468
cb = val.on_changed(partial(changeify, update=partial(update, values=None)))
465-
return val.val, val, cb
469+
hash_ = hash(str(val.valmin) + str(val.valmax) + str(val.valstep))
470+
return val.val, val, cb, hash_
466471
else:
467472
cb = val.on_changed(partial(changeify, update=partial(update, values=None)))
468-
return val.val, val, cb
473+
return val.val, val, cb, hash(repr(val))
469474

470475

471476
def kwarg_to_mpl_widget(
@@ -512,6 +517,7 @@ def kwarg_to_mpl_widget(
512517
the callback id
513518
new_y
514519
The widget_y to use for the next pass.
520+
hash
515521
"""
516522
slider_height, radio_height, gap_height = heights
517523

@@ -525,7 +531,7 @@ def kwarg_to_mpl_widget(
525531
if isinstance(val, tuple):
526532
pass
527533
else:
528-
return val, None, None, widget_y
534+
return val, None, None, widget_y, hash(repr(val))
529535
else:
530536
val = list(val)
531537

@@ -537,10 +543,10 @@ def kwarg_to_mpl_widget(
537543
widget_y += radio_height * n + gap_height
538544
radio_buttons = mwidgets.RadioButtons(radio_ax, val, active=0)
539545
cb = radio_buttons.on_clicked(partial(changeify_radio, labels=val, update=update))
540-
return val[0], radio_buttons, cb, widget_y
546+
return val[0], radio_buttons, cb, widget_y, hash(repr(val))
541547
elif isinstance(val, mwidgets.AxesWidget):
542-
val, widget, cb = process_mpl_widget(val, update)
543-
return val, widget, cb, widget_y
548+
val, widget, cb, hash_ = process_mpl_widget(val, update)
549+
return val, widget, cb, widget_y, hash_
544550
else:
545551
slider = None
546552
if isinstance(val, tuple) and val[0] in ["r", "range", "rang", "rage"]:
@@ -552,7 +558,7 @@ def kwarg_to_mpl_widget(
552558
slider = create_mpl_range_selection_slider(slider_ax, key, vals, slider_format_string)
553559
cb = slider.on_changed(partial(changeify, update=partial(update, values=vals)))
554560
widget_y += slider_height + gap_height
555-
return vals[[0, -1]], slider, cb, widget_y
561+
return vals[[0, -1]], slider, cb, widget_y, hash(repr(vals))
556562

557563
if isinstance(val, tuple):
558564
if len(val) == 2:
@@ -569,7 +575,7 @@ def update_text(val):
569575
slider.on_changed(update_text)
570576
cb = slider.on_changed(partial(changeify, update=partial(update, values=None)))
571577
widget_y += slider_height + gap_height
572-
return min_, slider, cb, widget_y
578+
return min_, slider, cb, widget_y, hash(repr(val))
573579
elif len(val) == 3:
574580
# should warn that that doesn't make sense with matplotlib sliders
575581
min_ = val[0]
@@ -580,13 +586,13 @@ def update_text(val):
580586
raise ValueError(f"{key} is {val.ndim}D but can only be 1D or a scalar")
581587
if len(val) == 1:
582588
# don't need to create a slider
583-
return val[0], None, None, widget_y
589+
return val[0], None, None, widget_y, hash(repr(val))
584590
else:
585591
slider_ax = fig.add_axes([0.2, 0.9 - widget_y - gap_height, 0.65, slider_height])
586592
slider = create_mpl_selection_slider(slider_ax, key, val, slider_format_string)
587593
slider.on_changed(partial(changeify, update=partial(update, values=val)))
588594
widget_y += slider_height + gap_height
589-
return val[0], slider, None, widget_y
595+
return val[0], slider, None, widget_y, hash(repr(val))
590596

591597

592598
def create_slider_format_dict(slider_format_string):

tests/test_generic.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,11 @@ def test_xr_hyperslicer_extents():
5555

5656
assert axs[1, 0].get_xlim() == axs[1, 1].get_xlim()
5757
assert axs[1, 0].get_ylim() == axs[1, 1].get_ylim()
58+
59+
60+
def test_duplicate_axis_names():
61+
plt.subplots()
62+
img_stack = np.random.rand(5, 512, 512)
63+
64+
with hyperslicer(img_stack):
65+
hyperslicer(img_stack)

0 commit comments

Comments
 (0)