-
Notifications
You must be signed in to change notification settings - Fork 66
Expand file tree
/
Copy pathaxes.py
More file actions
117 lines (99 loc) · 4.06 KB
/
axes.py
File metadata and controls
117 lines (99 loc) · 4.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Axes class for validating and transforming axes"""
from typing import Any
from .format import CurrentFormat, Format
# Maps each well-known axis name to its axis type. Axes given as plain
# names are expanded to dicts using this table.
KNOWN_AXES = {
    "x": "space",
    "y": "space",
    "z": "space",
    "c": "channel",
    "t": "time",
}
class Axes:
    """Holds the axes of an image as a list of axis dicts and validates them.

    Axes may be given as plain names or as dicts. Names are expanded to
    ``{"name": ...}`` dicts, with a ``"type"`` added when the name appears
    in ``KNOWN_AXES``. Which validation rules apply depends on
    ``fmt.version``.
    """

    def __init__(
        self,
        axes: list[str] | list[dict[str, str]] | None,
        fmt: Format = CurrentFormat(),
    ) -> None:
        """
        Constructor, transforms axes and validates

        ``axes`` may be None only for versions "0.1" and "0.2", where the
        axes are always ("t", "c", "z", "y", "x").

        Raises ValueError if not valid
        """
        if axes is not None:
            self.axes = self._axes_to_dicts(axes)
        elif fmt.version in ("0.1", "0.2"):
            # strictly 5D
            self.axes = self._axes_to_dicts(["t", "c", "z", "y", "x"])
        else:
            # Without this, self.axes would stay unset and validate() would
            # fail with an AttributeError instead of the promised ValueError.
            raise ValueError(f"axes must be provided for version {fmt.version}")
        self.fmt = fmt
        self.validate()

    def validate(self) -> None:
        """Raises ValueError if not valid"""
        if self.fmt.version in ("0.1", "0.2"):
            # nothing to check: these versions do not validate axes
            return
        # check names (only enforced for version 0.3)
        if self.fmt.version == "0.3":
            self._validate_03()
            return
        # any other version: validate by axis type
        self._validate_axes_types()

    def to_list(
        self, fmt: Format = CurrentFormat()
    ) -> list[str] | list[dict[str, str]]:
        """Returns the axes in the form used by the given format.

        For version "0.3" this is the list of axis names; otherwise it is
        the internal list of axis dicts (the same list object, not a copy).
        Note that the ``fmt`` argument is used here, not ``self.fmt``.

        Raises ValueError (version "0.3" only) if an axis dict has no 'name'
        """
        if fmt.version == "0.3":
            return self._get_names()
        return self.axes

    @staticmethod
    def _axes_to_dicts(axes: list[str] | list[dict[str, str]]) -> list[dict[str, str]]:
        """Returns a list of axis dicts with name and type

        String entries become ``{"name": axis}``, plus a ``"type"`` when the
        name is in ``KNOWN_AXES``. Non-string entries are passed through
        unchanged (the same object is appended, not a copy).
        """
        axes_dicts = []
        for axis in axes:
            if isinstance(axis, str):
                axis_dict = {"name": axis}
                if axis in KNOWN_AXES:
                    axis_dict["type"] = KNOWN_AXES[axis]
                axes_dicts.append(axis_dict)
            else:
                axes_dicts.append(axis)
        return axes_dicts

    def _validate_axes_types(self) -> None:
        """
        Validate the axes types according to the spec, version 0.4+

        Rules enforced:
        - at most one axis whose type is missing or not a known type
        - a 'time' axis may only be the first axis
        - at most one 'channel' axis
        - a 'channel' axis must come before the first 'space' axis

        Raises ValueError if any rule is broken
        """
        # Axes without a "type" key contribute None, which counts as unknown.
        axes_types = [axis.get("type") for axis in self.axes]
        known_types = list(KNOWN_AXES.values())
        unknown_types = [atype for atype in axes_types if atype not in known_types]
        if len(unknown_types) > 1:
            raise ValueError(
                f"Too many unknown axes types. 1 allowed, found: {unknown_types}"
            )

        def _last_index(item: str, item_list: list[Any]) -> int:
            # Position of the last occurrence; callers check membership first.
            return max(loc for loc, val in enumerate(item_list) if val == item)

        # Using the *last* 'time' index also rejects a second time axis.
        if "time" in axes_types and _last_index("time", axes_types) > 0:
            raise ValueError("'time' axis must be first dimension only")
        if axes_types.count("channel") > 1:
            raise ValueError("Only 1 axis can be type 'channel'")
        if "channel" in axes_types:
            if "space" not in axes_types:
                # This case used to surface as an opaque
                # "'space' is not in list" ValueError from list.index();
                # keep the exception type but say what is actually wrong.
                raise ValueError(
                    "'channel' axis must be followed by 'space' axes,"
                    " but no 'space' axes were found"
                )
            if _last_index("channel", axes_types) > axes_types.index("space"):
                raise ValueError("'space' axes must come after 'channel'")

    def _get_names(self) -> list[str]:
        """Returns a list of axis names

        Raises ValueError if an axis dict has no 'name'
        """
        axes_names = []
        for axis in self.axes:
            if "name" not in axis:
                raise ValueError(f"Axis Dict {axis} has no 'name'")
            axes_names.append(axis["name"])
        return axes_names

    def _validate_03(self) -> None:
        """Validate axis names against the combinations allowed for 0.3

        Only fixed name tuples are accepted, chosen by the number of axes.

        Raises ValueError if the names are not an allowed combination
        """
        val_axes = tuple(self._get_names())
        if len(val_axes) == 2:
            if val_axes != ("y", "x"):
                raise ValueError(f"2D data must have axes ('y', 'x') {val_axes}")
        elif len(val_axes) == 3:
            if val_axes not in [("z", "y", "x"), ("c", "y", "x"), ("t", "y", "x")]:
                raise ValueError(
                    "3D data must have axes ('z', 'y', 'x') or ('c', 'y', 'x')"
                    f" or ('t', 'y', 'x'), not {val_axes}"
                )
        elif len(val_axes) == 4:
            if val_axes not in [
                ("t", "z", "y", "x"),
                ("c", "z", "y", "x"),
                ("t", "c", "y", "x"),
            ]:
                raise ValueError("4D data must have axes tzyx or czyx or tcyx")
        else:
            # Every other length (not only 5) lands here and must match the
            # full 5D tuple, so e.g. 1 or 6 axes are rejected by this check.
            if val_axes != ("t", "c", "z", "y", "x"):
                raise ValueError("5D data must have axes ('t', 'c', 'z', 'y', 'x')")