Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename Slice5D to Interval5D and disallow undefined (slice(None)) axes #12

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ndstructs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .array5D import Array5D
from .array5D import Array5D, All
from .array5D import Image, ScalarImage, LinearData, ScalarData, ScalarLine, StaticLine
from .point5D import Point5D, Shape5D, Slice5D, KeyMap
from .point5D import Point5D, Shape5D, Interval5D, SPAN, KeyMap

__version__ = "0.0.5dev0"
167 changes: 98 additions & 69 deletions ndstructs/array5D.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import itertools
from typing import Iterator, Tuple, Iterable, Optional, Union, TypeVar, Type, cast, Sequence
from typing import Iterator, Iterable, Optional, Union, TypeVar, Type, cast, Sequence
import numpy as np
from skimage import measure as skmeasure
import skimage.io
Expand All @@ -8,40 +7,35 @@
import uuid
from numbers import Number

from .point5D import Point5D, Slice5D, Shape5D, KeyMap
from .point5D import Point5D, Interval5D, Shape5D, KeyMap, SPAN
from ndstructs.utils import JsonSerializable

Arr = TypeVar("Arr", bound="Array5D")

DTYPE = Union[
Type[np.uint8],
Type[np.uint16],
Type[np.uint32],
Type[np.uint64],
Type[np.int8],
Type[np.int16],
Type[np.int32],
Type[np.int64],
Type[np.float16],
Type[np.float32],
Type[np.float64],
]

class All:
    """Sentinel type for the per-axis overrides of ``Array5D.cut`` / ``Array5D.local_cut``.

    When an *instance* of this class (``All()``) is passed for an axis keyword,
    those methods substitute the array's own span along that axis. The check
    performed there is ``isinstance(value, All)``, so the bare class object
    ``All`` is not treated as the sentinel.
    """


# Type accepted by the per-axis keyword overrides (x, y, z, t, c) of
# Array5D.cut / Array5D.local_cut: either an explicit SPAN or an instance of
# the All sentinel.
SPAN_OVERRIDE = Union[SPAN, All]


class Array5D(JsonSerializable):
"""A wrapper around np.ndarray with labeled axes. Enforces 5D, even if some
dimensions are of size 1. Sliceable with Slice5D's"""
dimensions are of size 1. Sliceable with Interval5D's"""

LINEAR_RAW_AXISKEYS = "txyzc"

def __init__(self, arr: np.ndarray, axiskeys: str, location: Point5D = Point5D.zero()):
assert len(arr.shape) == len(axiskeys)
missing_keys = [key for key in Point5D.LABELS if key not in axiskeys]
self._axiskeys = "".join(missing_keys) + axiskeys
assert sorted(self._axiskeys) == sorted(Point5D.LABELS)
self.axiskeys = "".join(missing_keys) + axiskeys
assert sorted(self.axiskeys) == sorted(Point5D.LABELS)
slices = tuple([np.newaxis for key in missing_keys] + [...])
self._data = arr[slices]
self.location = location
self.shape = Shape5D(**{key: value for key, value in zip(self.axiskeys, self._data.shape)})
self.dtype = arr.dtype

def relabeled(self: Arr, keymap: KeyMap) -> Arr:
new_location = self.location.relabeled(keymap)
Expand Down Expand Up @@ -76,45 +70,33 @@ def from_file(cls: Type[Arr], filelike: io.IOBase, location: Point5D = Point5D.z
return cls(data, "yxc"[: len(data.shape)], location=location)

def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self.to_slice_5d()}>"
return f"<{self.__class__.__name__} {self.interval}>"

@classmethod
def allocate(
cls: Type[Arr], slc: Union[Slice5D, Shape5D], dtype: DTYPE, axiskeys: str = Point5D.LABELS, value: int = None
cls: Type[Arr],
interval: Union[Interval5D, Shape5D],
dtype: np.dtype,
axiskeys: str = Point5D.LABELS,
value: int = None,
) -> Arr:
slc = slc.to_slice_5d() if isinstance(slc, Shape5D) else slc
interval = interval.to_interval5d() if isinstance(interval, Shape5D) else interval
assert sorted(axiskeys) == sorted(Point5D.LABELS)
assert slc.is_defined() # FIXME: Create DefinedSlice class?
arr = np.empty(slc.shape.to_tuple(axiskeys), dtype=dtype)
arr = cls(arr, axiskeys, location=slc.start)
assert interval.shape.hypervolume != float("inf")
arr = np.empty(interval.shape.to_tuple(axiskeys), dtype=dtype)
arr = cls(arr, axiskeys, location=interval.start)
if value is not None:
arr._data[...] = value
return arr

@classmethod
def allocate_like(
cls: Type[Arr], arr: "Array5D", dtype: Optional[DTYPE], axiskeys: str = "", value: int = None
cls: Type[Arr], arr: "Array5D", dtype: Optional[np.dtype], axiskeys: str = "", value: int = None
) -> Arr:
return cls.allocate(arr.roi, dtype=dtype or arr.dtype, axiskeys=axiskeys or arr.axiskeys, value=value)

@property
def dtype(self) -> Type:
return self._data.dtype

@property
def axiskeys(self) -> str:
return self._axiskeys

@property
def _shape(self) -> Tuple:
return self._data.shape

@property
def shape(self) -> Shape5D:
return Shape5D(**{key: value for key, value in zip(self.axiskeys, self._shape)})
return cls.allocate(arr.interval, dtype=dtype or arr.dtype, axiskeys=axiskeys or arr.axiskeys, value=value)

def split(self: Arr, shape: Shape5D) -> Iterator[Arr]:
for slc in self.roi.split(shape):
for slc in self.interval.split(shape):
yield self.cut(slc)

def as_mask(self) -> "Array5D":
Expand All @@ -127,7 +109,7 @@ def sample_channels(self, mask: "ScalarData") -> "LinearData":
(N, c)
where N is the number of True-valued elements in 'mask', and c is the number
of channels in self."""
assert self.shape.with_coord(c=1) == mask.shape
assert self.shape.updated(c=1) == mask.shape
assert mask.dtype == bool # FIXME: create "Mask" type?

# mask has singleton channel axis, so 'c' must be in the end to index self.raw
Expand Down Expand Up @@ -159,7 +141,7 @@ def setflags(self, *, write: bool) -> None:
self._data.setflags(write=write)

def normalized(self: Arr, step: Optional[Shape5D] = None) -> Arr:
step = step if step is not None else self.roi.with_coord(c=1, t=1).defined_with(self.shape).shape
step = step if step is not None else self.interval.updated(c=1, t=1).clamped(self.shape).shape
normalized = self.allocate(self.shape, self.dtype, self.axiskeys)
for source, dest in zip(normalized.split(step), self.split(step)):
source_raw = source.raw(self.axiskeys)
Expand All @@ -176,7 +158,7 @@ def rebuild(self: Arr, arr: np.ndarray, *, axiskeys: str, location: Point5D = No
return self.__class__(arr, axiskeys, location)

def translated(self: Arr, offset: Point5D) -> Arr:
return self.rebuild(self._data, axiskeys=self._axiskeys, location=self.location + offset)
return self.rebuild(self._data, axiskeys=self.axiskeys, location=self.location + offset)

def raw(self, axiskeys: str) -> np.ndarray:
"""Returns a raw view of the underlying np.ndarray, containing only the axes
Expand Down Expand Up @@ -214,36 +196,82 @@ def reordered(self: Arr, axiskeys: str) -> Arr:

return self.rebuild(moved_arr, axiskeys=new_axes)

def local_cut(self: Arr, roi: Slice5D, *, copy: bool = False) -> Arr:
defined_roi = roi.defined_with(self.shape)
slices = defined_roi.to_slices(self.axiskeys)
def local_cut(
self: Arr,
interval: Interval5D = None,
*,
x: Optional[SPAN_OVERRIDE] = None,
y: Optional[SPAN_OVERRIDE] = None,
z: Optional[SPAN_OVERRIDE] = None,
t: Optional[SPAN_OVERRIDE] = None,
c: Optional[SPAN_OVERRIDE] = None,
copy: bool = False,
) -> Arr:
local_interval = self.shape.to_interval5d()
interval = (interval or local_interval).updated(
x=local_interval.x if isinstance(x, All) else x,
y=local_interval.y if isinstance(y, All) else y,
z=local_interval.z if isinstance(z, All) else z,
t=local_interval.t if isinstance(t, All) else t,
c=local_interval.c if isinstance(c, All) else c,
)
slices = interval.to_slices(self.axiskeys)
if any(slc.start < 0 for slc in slices):
raise ValueError(f"Cant't cut locally with negative indices: {interval}")
if copy:
cut_data = np.copy(self._data[slices])
else:
cut_data = self._data[slices]
return self.rebuild(cut_data, axiskeys=self.axiskeys, location=self.location + defined_roi.start)

def cut(self: Arr, roi: Slice5D, *, copy: bool = False) -> Arr:
return self.local_cut(roi.translated(-self.location), copy=copy) # TODO: define before translate?
return self.rebuild(cut_data, axiskeys=self.axiskeys, location=self.location + interval.start)

def cut(
    self: Arr,
    interval: Interval5D = None,
    *,
    x: Optional[SPAN_OVERRIDE] = None,
    y: Optional[SPAN_OVERRIDE] = None,
    z: Optional[SPAN_OVERRIDE] = None,
    t: Optional[SPAN_OVERRIDE] = None,
    c: Optional[SPAN_OVERRIDE] = None,
    copy: bool = False,
) -> Arr:
    """Return the piece of this array covered by ``interval``, in global coordinates.

    ``interval`` is interpreted in the same frame as ``self.interval``: it is
    shifted by ``-self.location`` before being handed to ``local_cut``. When
    ``interval`` is falsy (e.g. ``None``), ``self.interval`` is used instead.

    Args:
        interval: Region to cut out; defaults to the whole array.
        x, y, z, t, c: Per-axis overrides forwarded to ``Interval5D.updated``.
            An ``All`` instance is replaced by this array's own span on that
            axis. Any other value, including ``None``, is forwarded untouched
            (presumably ``None`` means "leave this axis as is" -- TODO confirm
            against ``Interval5D.updated``).
        copy: Forwarded unchanged to ``local_cut``.

    Returns:
        Whatever ``local_cut`` returns for the translated interval.
    """

    def resolve(axis, override):
        # Only an All *instance* is expanded to this array's span on `axis`.
        if isinstance(override, All):
            return getattr(self.interval, axis)
        return override

    base_interval = interval or self.interval
    global_interval = base_interval.updated(
        x=resolve("x", x),
        y=resolve("y", y),
        z=resolve("z", z),
        t=resolve("t", t),
        c=resolve("c", c),
    )
    # local_cut works in array-local coordinates, so drop this array's offset.
    local_interval = global_interval.translated(-self.location)
    return self.local_cut(local_interval, copy=copy)

def duplicate(self: Arr) -> Arr:
return self.cut(self.roi, copy=True)

def clamped(self: Arr, roi: Slice5D) -> Arr:
return self.cut(self.roi.clamped(roi))

def to_slice_5d(self) -> Slice5D:
return self.shape.to_slice_5d().translated(self.location)
return self.cut(self.interval, copy=True)

def clamped(
    self: Arr,
    limits: Union[Shape5D, Interval5D, None] = None,
    *,
    x: Optional[SPAN] = None,
    y: Optional[SPAN] = None,
    z: Optional[SPAN] = None,
    t: Optional[SPAN] = None,
    c: Optional[SPAN] = None,
) -> Arr:
    """Cut this array down to its own interval clamped by the given limits.

    All arguments are passed straight through to ``self.interval.clamped``;
    the resulting interval is then extracted with ``self.cut``.

    Args:
        limits: Overall limits handed to ``Interval5D.clamped``.
        x, y, z, t, c: Per-axis limits handed to ``Interval5D.clamped``.

    Returns:
        The result of ``self.cut`` on the clamped interval.
    """
    axis_limits = {"x": x, "y": y, "z": z, "t": t, "c": c}
    clamped_interval = self.interval.clamped(limits, **axis_limits)
    return self.cut(clamped_interval)

@property
def roi(self) -> Slice5D:
return self.to_slice_5d()
def interval(self) -> Interval5D:
return self.shape.to_interval5d().translated(self.location)

def set(self, value: "Array5D", autocrop: bool = False, mask_value: Optional[Number] = None) -> None:
if autocrop:
value_slc = value.roi.clamped(self.roi)
value_slc = value.interval.clamped(self.interval)
value = value.cut(value_slc)
self.cut(value.roi).localSet(value.translated(-self.location), mask_value=mask_value)
self.cut(value.interval).localSet(value.translated(-self.location), mask_value=mask_value)

def localSet(self, value: "Array5D", mask_value: Optional[Number] = None) -> None:
self_raw = self.raw(Point5D.LABELS)
Expand All @@ -264,7 +292,7 @@ def as_uint8(self, normalized: bool = True) -> "Array5D":
return Array5D((self._data * multi).astype(np.uint8), axiskeys=self.axiskeys)

def get_borders(self: Arr, thickness: Shape5D) -> Iterable[Arr]:
for border_slc in self.roi.get_borders(thickness):
for border_slc in self.interval.get_borders(thickness):
yield self.cut(border_slc)

def unique_border_colors(self, border_thickness: Optional[Shape5D] = None) -> "StaticLine":
Expand All @@ -284,7 +312,7 @@ def threshold(self: Arr, threshold: float) -> Arr:
return out

def connected_components(self: Arr, background: int = 0, connectivity: str = "xyz") -> Arr:
piece_shape = self.shape.with_coord(**{axis: 1 for axis in set("xyztc").difference(connectivity)})
piece_shape = self.shape.updated(**{axis: 1 for axis in set("xyztc").difference(connectivity)})
output = Array5D.allocate_like(self, dtype=np.int64)
for piece in self.split(piece_shape):
raw = piece.raw(connectivity)
Expand All @@ -299,8 +327,9 @@ def paint_point(self, point: Point5D, value: Number, local: bool = False):
self._data[np_selection] = value

def combine(self: Arr, others: Sequence[Arr]) -> Arr:
out_roi = Slice5D.enclosing([self.roi] + [o.roi for o in others])
out = self.allocate(slc=out_roi, dtype=self.dtype, axiskeys=self.axiskeys, value=0)
"""Pastes self and others together into a single Array5D"""
out_roi = Interval5D.enclosing([self.interval] + [o.interval for o in others])
out = self.allocate(interval=out_roi, dtype=self.dtype, axiskeys=self.axiskeys, value=0)
out.set(self)
for other in others:
out.set(other)
Expand Down
13 changes: 6 additions & 7 deletions ndstructs/datasink/DataSink.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
from abc import abstractmethod
from typing import Optional

from ndstructs import Point5D, Shape5D, Slice5D, Array5D
from ndstructs import Point5D, Shape5D, Interval5D, Array5D
from ndstructs.datasource import UnsupportedUrlException
from ndstructs.datasource.DataSource import DataSource, AddressMode
from ndstructs.datasource.DataSourceSlice import DataSourceSlice
from ndstructs.datasource.DataRoi import DataRoi


class DataSink:
def __init__(self, *, data_slice: DataSourceSlice, tile_shape: Optional[Shape5D] = None):
def __init__(self, *, data_slice: DataRoi, tile_shape: Optional[Shape5D] = None):
self.data_slice = data_slice
self.tile_shape = tile_shape or data_slice.tile_shape

def process(self, roi: Slice5D = Slice5D.all(), address_mode: AddressMode = AddressMode.BLACK) -> None:
defined_roi = roi.defined_with(self.data_slice)
assert self.data_slice.contains(defined_roi)
for piece in defined_roi.split(self.tile_shape):
def process(self, roi: Interval5D, address_mode: AddressMode = AddressMode.BLACK) -> None:
assert self.data_slice.contains(roi)
for piece in roi.split(self.tile_shape):
source_data = self.data_slice.datasource.retrieve(piece, address_mode=address_mode)
self._process_tile(source_data)

Expand Down
12 changes: 6 additions & 6 deletions ndstructs/datasink/N5DataSink.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from fs.base import FS
from fs.osfs import OSFS

from ndstructs.point5D import Point5D, Slice5D, Shape5D
from ndstructs.point5D import Point5D, Interval5D, Shape5D
from ndstructs.array5D import Array5D
from ndstructs.datasource.DataSource import DataSource, UnsupportedUrlException
from ndstructs.datasource.N5DataSource import N5Block
from ndstructs.datasource.DataSourceSlice import DataSourceSlice
from ndstructs.datasource.DataRoi import DataRoi
from ndstructs.datasink.DataSink import DataSink


Expand All @@ -29,7 +29,7 @@ def __init__(
self,
*,
path: Path, # dataset path, e.g. "mydata.n5/mydataset"
data_slice: DataSourceSlice,
data_slice: DataRoi,
axiskeys: str = "tzyxc",
compression_type: str = "raw",
tile_shape: Optional[Shape5D] = None,
Expand Down Expand Up @@ -77,17 +77,17 @@ def __init__(
self.filesystem.makedirs(dir_path)
created_dirs.add(dir_path)

def get_tile_dataset_path(self, global_roi: Slice5D) -> str:
def get_tile_dataset_path(self, global_roi: Interval5D) -> str:
"Gets the relative path into the n5 dataset where 'tile' should be stored"
local_roi = global_roi.translated(-self.data_slice.start)
slice_address_components = (local_roi.start // self.tile_shape).to_np(self.axiskeys[::-1]).astype(np.uint32)
return "/".join(map(str, slice_address_components))

def get_tile_dir_dataset_path(self, global_roi: Slice5D) -> str:
def get_tile_dir_dataset_path(self, global_roi: Interval5D) -> str:
return "/".join(self.get_tile_dataset_path(global_roi).split("/")[:-1])

def _process_tile(self, tile: Array5D) -> None:
tile = N5Block.fromArray5D(tile)
tile_path = self.get_tile_dataset_path(global_roi=tile.roi)
tile_path = self.get_tile_dataset_path(global_roi=tile.interval)
with self.filesystem.openbin(tile_path, "w") as f:
f.write(tile.to_n5_bytes(axiskeys=self.axiskeys, compression_type=self.compression_type))
Loading