Skip to content

Commit 4614349

Browse files
authored
Merge pull request #38 from funkelab/v0.2-dev
V0.2 dev
2 parents 81925e1 + 60c94ce commit 4614349

File tree

12 files changed

+464
-372
lines changed

12 files changed

+464
-372
lines changed

finn/layers/tracks/_tests/test_tracks.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,30 @@
99
validate_kwargs_sorted,
1010
)
1111

12-
# def test_empty_tracks():
13-
# """Test instantiating Tracks layer without data."""
14-
# pts = Tracks()
15-
# assert pts.data.shape == (0, 4)
16-
17-
1812
data_array_2dt = np.zeros((1, 4))
1913
data_list_2dt = list(data_array_2dt)
2014
dataframe_2dt = pd.DataFrame(data=data_array_2dt, columns=["track_id", "t", "y", "x"])
2115

2216

17+
@pytest.mark.parametrize("ndim", [3, 4])
18+
def test_empty_tracks(ndim):
19+
"""Test instantiating Tracks layer without data."""
20+
pts = Tracks(data=None, ndim=ndim)
21+
assert pts.data.shape == (0, ndim + 1)
22+
23+
2324
@pytest.mark.parametrize("data", [data_array_2dt, data_list_2dt, dataframe_2dt])
2425
def test_tracks_layer_2dt_ndim(data):
2526
"""Test instantiating Tracks layer, check 2D+t dimensionality."""
2627
layer = Tracks(data)
2728
assert layer.ndim == 3
2829

30+
layer = Tracks(data, ndim=3)
31+
assert layer.ndim == 3
32+
33+
with pytest.raises(ValueError, match="Provided ndim 4 and data ndim 3 do not match."):
34+
layer = Tracks(data, ndim=4)
35+
2936

3037
data_array_3dt = np.zeros((1, 5))
3138
data_list_3dt = list(data_array_3dt)
@@ -282,7 +289,8 @@ def test_track_connex_validity() -> None:
282289
# number of tracks
283290
n_tracks = 6
284291

285-
# the number of 'False' in the track_connex array should be equal to the number of tracks
292+
# the number of 'False' in the track_connex array should be equal to the number
293+
# of tracks
286294
assert np.sum(~layer._manager.track_connex) == n_tracks
287295

288296

finn/layers/tracks/_track_utils.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import numpy.typing as npt
33
import pandas as pd
4-
from scipy.sparse import coo_matrix
4+
from scipy.sparse import coo_matrix, csr_matrix
55
from scipy.spatial import cKDTree
66

77
from finn.layers.utils.layer_utils import _FeatureTable
@@ -81,6 +81,8 @@ def __init__(self, data: np.ndarray) -> None:
8181
@staticmethod
8282
def _fast_points_lookup(sorted_time: np.ndarray) -> dict[int, slice]:
8383
"""Computes a fast lookup table from time to their respective points slicing."""
84+
if sorted_time.shape[0] == 0:
85+
return {}
8486

8587
# finds where t transitions to t + 1
8688
transitions = np.nonzero(sorted_time[:-1] - sorted_time[1:])[0] + 1
@@ -131,12 +133,16 @@ def data(self, data: list | np.ndarray) -> None:
131133

132134
# make a second lookup table using a sparse matrix to convert track id
133135
# to the vertex indices
134-
self._id2idxs = coo_matrix(
135-
(
136-
np.broadcast_to(1, self.track_ids.size), # just dummy ones
137-
(self.track_ids, np.arange(self.track_ids.size)),
138-
)
139-
).tocsr()
136+
if self.data.shape[0] == 0:
137+
# make an empty csr matrix
138+
self._id2idxs = csr_matrix(np.array([[]]))
139+
else:
140+
self._id2idxs = coo_matrix(
141+
(
142+
np.broadcast_to(1, self.track_ids.size), # just dummy ones
143+
(self.track_ids, np.arange(self.track_ids.size)),
144+
)
145+
).tocsr()
140146

141147
@property
142148
def features(self) -> pd.DataFrame:
@@ -279,7 +285,8 @@ def build_tracks(self) -> None:
279285
track_connex = np.ones(self.data.shape[0], dtype=bool)
280286
track_connex[indices_new_id] = False
281287
# Add 'False' for the last entry too (end of the last track)
282-
track_connex[-1] = False
288+
if self.data.shape[0] != 0:
289+
track_connex[-1] = False
283290

284291
self._points_id = points_id
285292
self._track_vertices = track_vertices

finn/layers/tracks/tracks.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class Tracks(Layer):
5050
Optional dictionary mapping each property to a colormap for that
5151
property. This allows each property to be assigned a specific colormap,
5252
rather than having a global colormap for everything.
53-
experimental_clipping_planes : list of dicts, list of ClippingPlane, or ClippingPlaneList
53+
experimental_clipping_planes : list of dicts, list of ClippingPlane, ClippingPlaneList
5454
Each dict defines a clipping plane in 3D in data coordinates.
5555
Valid dictionary keys are {'position', 'normal', and 'enabled'}.
5656
Values on the negative side of the normal are discarded if the plane is enabled.
@@ -70,6 +70,9 @@ class Tracks(Layer):
7070
Layer metadata.
7171
name : str
7272
Name of the layer.
73+
ndim : int, optional
74+
The number of spatio-temporal dimensions of the tracks. Necessary if data is None
75+
to initialize an empty tracks layer.
7376
opacity : float
7477
Opacity of the layer visual, between 0.0 and 1.0.
7578
projection_mode : str
@@ -123,6 +126,7 @@ def __init__(
123126
head_length: int = 0,
124127
metadata=None,
125128
name=None,
129+
ndim=None,
126130
opacity=1.0,
127131
projection_mode="none",
128132
properties=None,
@@ -135,12 +139,20 @@ def __init__(
135139
units=None,
136140
visible=True,
137141
) -> None:
142+
if ndim is None and data is None:
143+
raise ValueError("Must provide ndim or data to tracks layer")
144+
138145
# if not provided with any data, set up an empty layer in 2D+t
139146
# otherwise convert the data to an np.ndarray
140-
data = np.empty((0, 4)) if data is None else np.asarray(data)
147+
data = np.empty((0, ndim + 1)) if data is None else np.asarray(data)
141148

142149
# set the track data dimensions (remove ID from data)
143-
ndim = data.shape[1] - 1
150+
if ndim is None:
151+
ndim = data.shape[1] - 1
152+
elif ndim != data.shape[1] - 1:
153+
raise ValueError(
154+
f"Provided ndim {ndim} and data ndim {data.shape[1] - 1} do not match."
155+
)
144156

145157
super().__init__(
146158
data,
@@ -225,10 +237,10 @@ def _extent_data(self) -> np.ndarray:
225237
"""
226238
if len(self.data) == 0:
227239
extrema = np.full((2, self.ndim), np.nan)
228-
else:
229-
maxs = np.max(self.data, axis=0)
230-
mins = np.min(self.data, axis=0)
231-
extrema = np.vstack([mins, maxs])
240+
return extrema
241+
maxs = np.max(self.data, axis=0)
242+
mins = np.min(self.data, axis=0)
243+
extrema = np.vstack([mins, maxs])
232244
return extrema[:, 1:]
233245

234246
def _get_ndim(self) -> int:
@@ -247,6 +259,7 @@ def _get_state(self) -> dict[str, Any]:
247259
state.update(
248260
{
249261
"data": self.data,
262+
"ndim": self.ndim,
250263
"properties": self.properties,
251264
"graph": self.graph,
252265
"color_by": self.color_by,
@@ -299,7 +312,11 @@ def _update_thumbnail(self) -> None:
299312
colormapped = np.zeros(self._thumbnail_shape)
300313
colormapped[..., 3] = 1
301314

302-
if self._view_data is not None and self.track_colors is not None:
315+
if (
316+
self._view_data is not None
317+
and self.track_colors is not None
318+
and self.data.shape[0] > 0
319+
):
303320
de = self._extent_data
304321
min_vals = [de[0, i] for i in self._slice_input.displayed]
305322
shape = np.ceil(
@@ -599,7 +616,8 @@ def _norm(p):
599616
else:
600617
# if we don't have a colormap, get one and scale the properties
601618
colormap = AVAILABLE_COLORMAPS[self.colormap]
602-
vertex_properties = _norm(vertex_properties)
619+
if vertex_properties.size > 0:
620+
vertex_properties = _norm(vertex_properties)
603621

604622
# actually set the vertex colors
605623
self._track_colors = colormap.map(vertex_properties)
@@ -648,7 +666,8 @@ def _check_color_by_in_features(self) -> None:
648666
warn(
649667
(
650668
trans._(
651-
"Previous color_by key {key!r} not present in features. Falling back to track_id",
669+
"Previous color_by key {key!r} not present in features. "
670+
"Falling back to track_id",
652671
deferred=True,
653672
key=self._color_by,
654673
)

finn/track_application_menus/main_app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ class MainApp(QWidget):
1414
def __init__(self, viewer: finn.Viewer):
1515
super().__init__()
1616

17-
menu_widget = MenuWidget(viewer)
17+
self.menu_widget = MenuWidget(viewer)
1818
tree_widget = TreeWidget(viewer)
1919

2020
viewer.window.add_dock_widget(tree_widget, area="bottom", name="Tree View")
2121

2222
layout = QVBoxLayout()
23-
layout.addWidget(menu_widget)
23+
layout.addWidget(self.menu_widget)
2424

2525
self.setLayout(layout)

finn/track_application_menus/menu_widget.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,16 @@ def __init__(self, viewer: finn.Viewer):
1818
# motile_widget = MotileWidget(viewer)
1919
editing_widget = EditingMenu(viewer)
2020

21-
tabwidget = QTabWidget()
21+
self.tabwidget = QTabWidget()
2222

2323
# tabwidget.addTab(motile_widget, "Track with Motile")
24-
tabwidget.addTab(tracks_viewer.tracks_list, "Tracks List")
25-
tabwidget.addTab(editing_widget, "Edit Tracks")
24+
self.tabwidget.addTab(tracks_viewer.tracks_list, "Tracks List")
25+
self.tabwidget.addTab(editing_widget, "Edit Tracks")
2626

2727
layout = QVBoxLayout()
28-
layout.addWidget(tabwidget)
28+
layout.addWidget(self.tabwidget)
2929

30-
self.setWidget(tabwidget)
30+
self.setWidget(self.tabwidget)
3131
self.setWidgetResizable(True)
3232

3333
self.setLayout(layout)

finn_builtins/_ndims_balls.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

finn_builtins/_skimage_data.py

Lines changed: 0 additions & 58 deletions
This file was deleted.
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import numpy as np
2+
import pytest
3+
4+
from finn_builtins.example_data import (
5+
Fluo_N2DL_HeLa,
6+
Fluo_N2DL_HeLa_crop,
7+
Mouse_Embryo_Membrane,
8+
delete_all,
9+
)
10+
11+
12+
@pytest.mark.slow
13+
def test_sample_data():
14+
delete_all()
15+
raw_layer_data, seg_layer_data = Mouse_Embryo_Membrane()
16+
raw_data = raw_layer_data[0]
17+
seg_data = seg_layer_data[0]
18+
shape = (117, 123, 127, 127)
19+
assert raw_data.shape == shape
20+
assert seg_data.shape == shape
21+
assert raw_data.dtype == np.uint16
22+
assert seg_data.dtype == np.uint16
23+
24+
25+
@pytest.mark.slow
26+
@pytest.mark.parametrize(
27+
("ds_function", "img_shape", "point_shape"),
28+
[
29+
(Fluo_N2DL_HeLa, (92, 700, 1100), (8602, 3)),
30+
(Fluo_N2DL_HeLa_crop, (92, 210, 340), (1266, 3)),
31+
],
32+
)
33+
def test_Fluo_N2DL_Hela(ds_function, img_shape, point_shape):
34+
delete_all()
35+
raw_layer_data, seg_layer_data, points_layer_data = ds_function()
36+
raw_data = raw_layer_data[0]
37+
seg_data = seg_layer_data[0]
38+
points = points_layer_data[0]
39+
assert raw_data.shape == img_shape
40+
assert seg_data.shape == img_shape
41+
assert raw_data.dtype == np.uint16
42+
assert seg_data.dtype == np.uint16
43+
assert points.shape == point_shape

0 commit comments

Comments
 (0)