5
5
import math
6
6
from collections .abc import Callable , Iterable , Mapping
7
7
from functools import reduce
8
+ from typing import Any
8
9
9
10
import ipywidgets as ipw
10
11
import numpy as np
@@ -24,7 +25,11 @@ class FlatVoxelViewer(ipw.VBox):
24
25
"""
25
26
26
27
def __init__ (
27
- self , data : Mapping [str , sc .DataArray ], * , rasterized : bool = True
28
+ self ,
29
+ data : Mapping [str , sc .DataArray ],
30
+ * ,
31
+ rasterized : bool = True ,
32
+ ** kwargs : Any ,
28
33
) -> None :
29
34
"""Create a new viewer.
30
35
@@ -35,14 +40,16 @@ def __init__(
35
40
rasterized:
36
41
If ``True``, the figure is rasterized which improves rendering
37
42
speed but reduces resolution.
43
+ **kwargs:
44
+ Additional arguments passed to the plotting function.
38
45
"""
39
46
self ._data = self ._prepare_data (data )
40
47
self ._bank_selector = _make_bank_selector (data .keys ())
41
48
self ._bank = self ._data [self ._bank_selector .value ]
42
49
43
50
self ._dim_selector = _DimensionSelector (self ._bank .dims , self ._update_view )
44
51
45
- self ._fig_kwargs = {'rasterized' : rasterized }
52
+ self ._fig_kwargs = {'rasterized' : rasterized } | kwargs
46
53
self ._figure_box = ipw .HBox ([self ._make_figure ()])
47
54
self ._bank_selector .observe (self ._select_bank , names = 'value' )
48
55
@@ -54,12 +61,12 @@ def __init__(
54
61
]
55
62
)
56
63
57
- def _select_bank (self , * _args : object , ** _kwargs : object ) -> None :
64
+ def _select_bank (self , * _args : Any , ** _kwargs : Any ) -> None :
58
65
self ._bank = self ._data [self ._bank_selector .value ]
59
66
self ._dim_selector .set_dims (self ._bank .dims )
60
67
self ._update_view ()
61
68
62
- def _update_view (self , * _args : object , ** _kwargs : object ) -> None :
69
+ def _update_view (self , * _args : Any , ** _kwargs : Any ) -> None :
63
70
self ._figure_box .children = [self ._make_figure ()]
64
71
65
72
def _make_figure (self ) -> FigureLike :
@@ -149,8 +156,7 @@ def _flat_voxel_figure(
149
156
data : sc .DataArray ,
150
157
horizontal_dim : str ,
151
158
vertical_dim : str ,
152
- * ,
153
- rasterized : bool = True ,
159
+ ** kwargs : Any ,
154
160
) -> FigureLike :
155
161
kept_dims = {horizontal_dim , vertical_dim }
156
162
@@ -182,7 +188,7 @@ def _flat_voxel_figure(
182
188
h_labels = [str (value ) for value in h_coord .values ]
183
189
v_labels = [str (value ) for value in v_coord .values ]
184
190
185
- fig = flat .plot (rasterized = rasterized , cbar = True )
191
+ fig = flat .plot (** kwargs )
186
192
187
193
fig .ax .xaxis .set_ticks (ticks = h_ticks , labels = h_labels )
188
194
fig .ax .yaxis .set_ticks (ticks = v_ticks , labels = v_labels )
0 commit comments