Skip to content

Commit 74ff15b

Browse files
authored
Merge pull request #139 from scipp/voxelviewer-kwargs
Support figure kwargs in FlatVoxelViewer
2 parents 6f35582 + db36bb5 commit 74ff15b

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

src/ess/dream/diagnostics.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import math
66
from collections.abc import Callable, Iterable, Mapping
77
from functools import reduce
8+
from typing import Any
89

910
import ipywidgets as ipw
1011
import numpy as np
@@ -24,7 +25,11 @@ class FlatVoxelViewer(ipw.VBox):
2425
"""
2526

2627
def __init__(
27-
self, data: Mapping[str, sc.DataArray], *, rasterized: bool = True
28+
self,
29+
data: Mapping[str, sc.DataArray],
30+
*,
31+
rasterized: bool = True,
32+
**kwargs: Any,
2833
) -> None:
2934
"""Create a new viewer.
3035
@@ -35,14 +40,16 @@ def __init__(
3540
rasterized:
3641
If ``True``, the figure is rasterized which improves rendering
3742
speed but reduces resolution.
43+
**kwargs:
44+
Additional arguments passed to the plotting function.
3845
"""
3946
self._data = self._prepare_data(data)
4047
self._bank_selector = _make_bank_selector(data.keys())
4148
self._bank = self._data[self._bank_selector.value]
4249

4350
self._dim_selector = _DimensionSelector(self._bank.dims, self._update_view)
4451

45-
self._fig_kwargs = {'rasterized': rasterized}
52+
self._fig_kwargs = {'rasterized': rasterized} | kwargs
4653
self._figure_box = ipw.HBox([self._make_figure()])
4754
self._bank_selector.observe(self._select_bank, names='value')
4855

@@ -54,12 +61,12 @@ def __init__(
5461
]
5562
)
5663

57-
def _select_bank(self, *_args: object, **_kwargs: object) -> None:
64+
def _select_bank(self, *_args: Any, **_kwargs: Any) -> None:
5865
self._bank = self._data[self._bank_selector.value]
5966
self._dim_selector.set_dims(self._bank.dims)
6067
self._update_view()
6168

62-
def _update_view(self, *_args: object, **_kwargs: object) -> None:
69+
def _update_view(self, *_args: Any, **_kwargs: Any) -> None:
6370
self._figure_box.children = [self._make_figure()]
6471

6572
def _make_figure(self) -> FigureLike:
@@ -149,8 +156,7 @@ def _flat_voxel_figure(
149156
data: sc.DataArray,
150157
horizontal_dim: str,
151158
vertical_dim: str,
152-
*,
153-
rasterized: bool = True,
159+
**kwargs: Any,
154160
) -> FigureLike:
155161
kept_dims = {horizontal_dim, vertical_dim}
156162

@@ -182,7 +188,7 @@ def _flat_voxel_figure(
182188
h_labels = [str(value) for value in h_coord.values]
183189
v_labels = [str(value) for value in v_coord.values]
184190

185-
fig = flat.plot(rasterized=rasterized, cbar=True)
191+
fig = flat.plot(**kwargs)
186192

187193
fig.ax.xaxis.set_ticks(ticks=h_ticks, labels=h_labels)
188194
fig.ax.yaxis.set_ticks(ticks=v_ticks, labels=v_labels)

0 commit comments

Comments
 (0)