Skip to content

Commit e338190

Browse files
committed
mypy fixes
1 parent db766d0 commit e338190

File tree

3 files changed

+62
-48
lines changed

3 files changed

+62
-48
lines changed

napari_ome_zarr/_reader.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55

66

77
import warnings
8+
from typing import Callable
89

910
import zarr
11+
1012
from .ome_zarr_reader import read_ome_zarr
1113

1214

13-
def napari_get_reader(path):
15+
def napari_get_reader(path: str | list) -> Callable | None:
1416
"""Returns a reader for supported paths that include IDR ID.
1517
1618
- URL of the form: https://uk1s3.embassy.ebi.ac.uk/idr/zarr/v0.1/ID.zarr/

napari_ome_zarr/ome_zarr_reader.py

Lines changed: 49 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# zarr v3
22

3-
from typing import Any, Dict, List, Tuple, Union
3+
from abc import ABC
4+
from typing import Any, Callable, Dict, Iterable, List, Tuple
45
from xml.etree import ElementTree as ET
56

67
import dask.array as da
@@ -12,40 +13,42 @@
1213

1314
from .plate import get_first_field_path, get_first_well, get_pyramid_lazy
1415

15-
LayerData = Union[Tuple[Any], Tuple[Any, Dict], Tuple[Any, Dict, str]]
16+
# StrDict = Dict[str, Any]
17+
# LayerData = Union[Tuple[Any], Tuple[Any, StrDict], Tuple[Any, StrDict, str]]
18+
LayerData = Tuple[List[da.core.Array], Dict[str, Any], str]
1619

1720

18-
class Spec:
19-
def __init__(self, group: Group):
21+
class Spec(ABC):
22+
def __init__(self, group: Group) -> None:
2023
self.group = group
2124

2225
@staticmethod
2326
def matches(group: Group) -> bool:
2427
return False
2528

26-
def data(self) -> List[da.core.Array] | None:
27-
return None
29+
def data(self) -> List[da.core.Array]:
30+
return []
2831

29-
def metadata(self) -> Dict[str, Any] | None:
32+
def metadata(self) -> Dict[str, Any]:
3033
# napari layer metadata
3134
return {}
3235

33-
def children(self):
36+
def children(self) -> list["Spec"]:
3437
return []
3538

36-
def iter_nodes(self):
39+
def iter_nodes(self) -> Iterable["Spec"]:
3740
yield self
3841
for child in self.children():
3942
yield from child.iter_nodes()
4043

41-
def iter_data(self):
44+
def iter_data(self) -> Iterable[da.core.Array]:
4245
for node in self.iter_nodes():
4346
data = node.data()
4447
if data:
4548
yield data
4649

4750
@staticmethod
48-
def get_attrs(group: Group):
51+
def get_attrs(group: Group) -> dict:
4952
if "ome" in group.attrs:
5053
return group.attrs["ome"]
5154
return group.attrs
@@ -56,8 +59,8 @@ class Multiscales(Spec):
5659
def matches(group: Group) -> bool:
5760
return "multiscales" in Spec.get_attrs(group)
5861

59-
def children(self):
60-
ch = []
62+
def children(self) -> list[Spec]:
63+
ch: list[Spec] = []
6164
# test for child "labels"
6265
try:
6366
grp = self.group["labels"]
@@ -71,13 +74,13 @@ def children(self):
7174
pass
7275
return ch
7376

74-
def data(self):
77+
def data(self) -> list[da.core.Array]:
7578
attrs = Spec.get_attrs(self.group)
7679
paths = [ds["path"] for ds in attrs["multiscales"][0]["datasets"]]
7780
return [da.from_zarr(self.group[path]) for path in paths]
7881

79-
def metadata(self):
80-
rsp = {}
82+
def metadata(self) -> Dict[str, Any]:
83+
rsp: dict = {}
8184
attrs = Spec.get_attrs(self.group)
8285
axes = attrs["multiscales"][0]["axes"]
8386
atypes = [axis["type"] for axis in axes]
@@ -88,36 +91,37 @@ def metadata(self):
8891
colormaps = []
8992
ch_names = []
9093
visibles = []
91-
contrast_limits = []
94+
contrast_limits: list[tuple[int, int]] = []
9295

9396
for index, ch in enumerate(attrs["omero"]["channels"]):
9497
color = ch.get("color", None)
9598
if color is not None:
9699
rgb = [(int(color[i : i + 2], 16) / 255) for i in range(0, 6, 2)]
97100
# colormap is range: black -> rgb color
98101
colormaps.append(Colormap([[0, 0, 0], rgb]))
99-
ch_names.append(ch.get("label", f'channel_{index}'))
102+
ch_names.append(ch.get("label", f"channel_{index}"))
100103
visibles.append(ch.get("active", True))
101104

102105
window = ch.get("window", None)
103106
if window is not None:
104107
start = window.get("start", None)
105108
end = window.get("end", None)
106-
if start is None or end is None:
107-
# Disable contrast limits settings if one is missing
108-
contrast_limits = None
109-
elif contrast_limits is not None:
110-
contrast_limits.append([start, end])
109+
if start is not None and end is not None:
110+
# skip if None. Otherwise check no previous skip
111+
if len(contrast_limits) == index:
112+
contrast_limits.append((start, end))
111113

112114
if rsp.get("channel_axis") is not None:
113115
rsp["colormap"] = colormaps
114116
rsp["name"] = ch_names
115-
rsp["contrast_limits"] = contrast_limits
117+
if len(contrast_limits) > 0:
118+
rsp["contrast_limits"] = contrast_limits
116119
rsp["visible"] = visibles
117120
else:
118121
rsp["colormap"] = colormaps[0]
119122
rsp["name"] = ch_names[0]
120-
rsp["contrast_limits"] = contrast_limits[0]
123+
if len(contrast_limits) > 0:
124+
rsp["contrast_limits"] = contrast_limits[0]
121125
rsp["visible"] = visibles[0]
122126

123127
return rsp
@@ -130,7 +134,7 @@ def matches(group: Group) -> bool:
130134
# Don't consider "plate" as a Bioformats2raw layout
131135
return "bioformats2raw.layout" in attrs and "plate" not in attrs
132136

133-
def children(self):
137+
def children(self) -> list[Spec]:
134138
# lookup children from series of OME/METADATA.xml
135139
xml_data = SyncMixin()._sync(
136140
self.group.store.get(
@@ -139,7 +143,7 @@ def children(self):
139143
)
140144
# print("xml_data", xml_data.to_bytes())
141145
root = ET.fromstring(xml_data.to_bytes())
142-
rv = []
146+
rv: list[Spec] = []
143147
for child in root:
144148
# {http://www.openmicroscopy.org/Schemas/OME/2016-06}Image
145149
print(child.tag)
@@ -153,7 +157,7 @@ def children(self):
153157
return rv
154158

155159
# override to NOT yield self since node has no data
156-
def iter_nodes(self):
160+
def iter_nodes(self) -> Iterable[Spec]:
157161
for child in self.children():
158162
yield from child.iter_nodes()
159163

@@ -163,17 +167,17 @@ class Plate(Spec):
163167
def matches(group: Group) -> bool:
164168
return "plate" in Spec.get_attrs(group)
165169

166-
def data(self):
170+
def data(self) -> list[da.core.Array]:
167171
# we want to return a dask pyramid...
168172
return get_pyramid_lazy(self.group)
169173

170-
def metadata(self):
174+
def metadata(self) -> dict:
171175
well_group = get_first_well(self.group)
172176
first_field_path = get_first_field_path(well_group)
173177
image_group = well_group[first_field_path]
174178
return Multiscales(image_group).metadata()
175179

176-
def children(self):
180+
def children(self) -> list[Spec]:
177181
# Plate has children If it has labels - check one Well...
178182
# Child is PlateLabels
179183
well_group = get_first_well(self.group)
@@ -183,27 +187,28 @@ def children(self):
183187
if labels_group is not None:
184188
labels_attrs = Spec.get_attrs(labels_group)
185189
if "labels" in labels_attrs:
186-
ch = []
190+
ch: list[Spec] = []
187191
for labels_path in labels_attrs["labels"]:
188192
print("labels_path", labels_path)
189193
ch.append(PlateLabels(self.group, labels_path=labels_path))
190194
return ch
195+
return []
191196

192197

193198
class PlateLabels(Plate):
194199
def __init__(self, group: Group, labels_path: str):
195200
super().__init__(group)
196201
self.labels_path = labels_path
197202

198-
def data(self):
203+
def data(self) -> list[da.core.Array]:
199204
# return a dask pyramid...
200205
return get_pyramid_lazy(self.group, self.labels_path)
201206

202-
def children(self):
207+
def children(self) -> list[Spec]:
203208
# Need to override Plate.children()
204209
return []
205210

206-
def metadata(self) -> Dict[str, Any] | None:
211+
def metadata(self) -> dict:
207212
# override Plate metadata (no channel-axis etc)
208213
# TODO: read image-label metadata, colors etc
209214
return {
@@ -217,7 +222,7 @@ def matches(group: Group) -> bool:
217222
return "labels" in Spec.get_attrs(group)
218223

219224
# override to NOT yield self since node has no data
220-
def iter_nodes(self):
225+
def iter_nodes(self) -> Iterable[Spec]:
221226
attrs = Spec.get_attrs(self.group)
222227
for name in attrs["labels"]:
223228
g = self.group[name]
@@ -233,7 +238,7 @@ def matches(group: Group) -> bool:
233238
return False
234239
return "image-label" in Spec.get_attrs(group)
235240

236-
def metadata(self) -> Dict[str, Any] | None:
241+
def metadata(self) -> Dict[str, Any]:
237242
# override Multiscales metadata
238243
# call super
239244
ms_data = super().metadata()
@@ -242,12 +247,12 @@ def metadata(self) -> Dict[str, Any] | None:
242247
ms_data = {}
243248
return {
244249
"name": f"labels{self.group.name}",
245-
"visible": False, # labels not visible initially
250+
"visible": False, # labels not visible initially
246251
**ms_data,
247252
}
248253

249254

250-
def read_ome_zarr(root_group):
255+
def read_ome_zarr(root_group: Group) -> Callable:
251256
def f(*args: Any, **kwargs: Any) -> List[LayerData]:
252257
results: List[LayerData] = list()
253258

@@ -256,6 +261,8 @@ def f(*args: Any, **kwargs: Any) -> List[LayerData]:
256261

257262
print("Root group", root_group.attrs.asdict())
258263

264+
spec: Spec | None = None
265+
259266
if Labels.matches(root_group):
260267
# Try starting at parent Image
261268
parent_path = root_group.store.root.parent
@@ -290,11 +297,11 @@ def f(*args: Any, **kwargs: Any) -> List[LayerData]:
290297
for node in nodes:
291298
node_data = node.data()
292299
metadata = node.metadata()
300+
layer_type = "image"
293301
# print(Spec.get_attrs(node.group))
294302
if Label.matches(node.group) or isinstance(node, PlateLabels):
295-
rv: LayerData = (node_data, metadata, "labels")
296-
else:
297-
rv: LayerData = (node_data, metadata, "image")
303+
layer_type = "labels"
304+
rv: LayerData = (node_data, metadata, layer_type)
298305
results.append(rv)
299306

300307
return results

napari_ome_zarr/plate.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import dask.array as da
22
import numpy as np
3+
from numpy._typing import DTypeLike
34
from zarr import Group
45

56

6-
def get_attrs(group: Group):
7+
def get_attrs(group: Group) -> dict:
78
if "ome" in group.attrs:
89
return group.attrs["ome"]
910
return group.attrs
1011

1112

12-
def get_pyramid_lazy(plate_group, labels_path=None) -> None:
13+
def get_pyramid_lazy(plate_group: Group, labels_path: str | None = None) -> list:
1314
"""
1415
Return a pyramid of dask data, where the highest resolution is the
1516
stitched full-resolution images.
@@ -51,7 +52,11 @@ def get_pyramid_lazy(plate_group, labels_path=None) -> None:
5152

5253

5354
def get_stitched_grid(
54-
plate_group, level: int, tile_shape: tuple, numpy_type, first_field_path
55+
plate_group: Group,
56+
level: int,
57+
tile_shape: tuple,
58+
numpy_type: DTypeLike,
59+
first_field_path: str,
5560
) -> da.core.Array:
5661
plate_data = get_attrs(plate_group)["plate"]
5762
rows = plate_data.get("rows")
@@ -91,7 +96,7 @@ def get_tile(row: int, col: int) -> da.core.Array:
9196
return da.concatenate(lazy_rows, axis=len(lazy_rows[0].shape) - 2)
9297

9398

94-
def get_first_well(plate_group):
99+
def get_first_well(plate_group: Group) -> Group:
95100
plate_data = get_attrs(plate_group)["plate"]
96101
well_paths = [well["path"] for well in plate_data.get("wells")]
97102
well_paths.sort()
@@ -103,7 +108,7 @@ def get_first_well(plate_group):
103108
return well_group
104109

105110

106-
def get_first_field_path(well_group):
111+
def get_first_field_path(well_group: Group) -> str:
107112
well_data = get_attrs(well_group)["well"]
108113
if well_data is None:
109114
raise Exception("Could not find well data")

0 commit comments

Comments
 (0)