Skip to content

Commit 3dfb2c0

Browse files
asarigunpre-commit-ci[bot]LucaMarconato
authored
Support 3D visualization for points and 2.5D shapes (#393)
* fix the Issue #31 * fix Issue #31 * fix Issue #31 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review feedback PR#393 --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Luca Marconato <m.lucalmer@gmail.com> Co-authored-by: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com>
1 parent 88edca3 commit 3dfb2c0

5 files changed

Lines changed: 454 additions & 66 deletions

File tree

src/napari_spatialdata/_sdata_widgets.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,20 @@
1515
from packaging.version import parse as parse_version
1616
from qtpy.QtCore import QThread, Signal
1717
from qtpy.QtGui import QIcon
18-
from qtpy.QtWidgets import QLabel, QListWidget, QListWidgetItem, QProgressBar, QVBoxLayout, QWidget
18+
from qtpy.QtWidgets import (
19+
QCheckBox,
20+
QLabel,
21+
QListWidget,
22+
QListWidgetItem,
23+
QProgressBar,
24+
QVBoxLayout,
25+
QWidget,
26+
)
1927
from spatialdata import SpatialData
2028
from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM
2129

2230
from napari_spatialdata._viewer import SpatialDataViewer
31+
from napari_spatialdata.constants import config
2332
from napari_spatialdata.constants.config import N_CIRCLES_WARNING_THRESHOLD, N_SHAPES_WARNING_THRESHOLD
2433
from napari_spatialdata.utils._utils import _get_sdata_key, get_duplicate_element_names, get_elements_meta_mapping
2534

@@ -174,11 +183,39 @@ def __init__(self, viewer: Viewer, sdata: EventedList):
174183
self.slider.setRange(0, 0)
175184
self.slider.setVisible(False)
176185

186+
self.discard_z_points = QCheckBox("Discard z for 3D points")
187+
self.discard_z_points.setChecked(config.PROJECT_3D_POINTS_TO_2D)
188+
self.discard_z_points.setToolTip(
189+
"When checked, the z coordinate of new points layers is discarded so they are loaded in 2D. "
190+
"Only applies to new layers; layers already displayed are not affected."
191+
)
192+
self.discard_z_points.toggled.connect(self._on_discard_z_points_toggled)
193+
194+
self.discard_z_shapes = QCheckBox("Discard z for 2.5D shapes")
195+
self.discard_z_shapes.setChecked(config.PROJECT_2_5D_SHAPES_TO_2D)
196+
self.discard_z_shapes.setToolTip(
197+
"When checked, the z coordinate of new shapes layers is discarded so they are loaded in 2D. "
198+
"Only applies to new layers; layers already displayed are not affected."
199+
)
200+
self.discard_z_shapes.toggled.connect(self._on_discard_z_shapes_toggled)
201+
202+
# The 3D toggles only matter when at least one element across the loaded
203+
# SpatialData objects has a z axis. Otherwise we hide them to save screen
204+
# real estate for users working with 2D-only data.
205+
self._has_z_data = self._sdatas_have_z_axis(self._sdata)
206+
self._three_d_settings_label = QLabel("3D Settings:")
207+
self._three_d_settings_label.setVisible(self._has_z_data)
208+
self.discard_z_points.setVisible(self._has_z_data)
209+
self.discard_z_shapes.setVisible(self._has_z_data)
210+
177211
self.layout().addWidget(self.slider)
178212
self.layout().addWidget(QLabel("Coordinate System:"))
179213
self.layout().addWidget(self.coordinate_system_widget)
180214
self.layout().addWidget(QLabel("Elements:"))
181215
self.layout().addWidget(self.elements_widget)
216+
self.layout().addWidget(self._three_d_settings_label)
217+
self.layout().addWidget(self.discard_z_points)
218+
self.layout().addWidget(self.discard_z_shapes)
182219
self.elements_widget.itemDoubleClicked.connect(self._on_click_item)
183220
self.coordinate_system_widget.currentItemChanged.connect(
184221
lambda item: self.elements_widget._onItemChange(item.text())
@@ -256,6 +293,24 @@ def _update_layers_visibility(self) -> None:
256293
layer.metadata["_active_in_cs"].add(coordinate_system)
257294
layer.metadata["_current_cs"] = coordinate_system
258295

296+
def _on_discard_z_points_toggled(self, checked: bool) -> None:
297+
config.PROJECT_3D_POINTS_TO_2D = checked
298+
299+
def _on_discard_z_shapes_toggled(self, checked: bool) -> None:
300+
config.PROJECT_2_5D_SHAPES_TO_2D = checked
301+
302+
@staticmethod
303+
def _sdatas_have_z_axis(sdatas: EventedList) -> bool:
304+
"""Return ``True`` if any element across the given ``SpatialData`` objects has a z axis.
305+
306+
Used to decide whether to expose the 3D / 2.5D projection toggles in the widget.
307+
"""
308+
for sdata in sdatas:
309+
for _, _, element in sdata._gen_elements():
310+
if SpatialDataViewer._has_z_axis(element):
311+
return True
312+
return False
313+
259314
def _get_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi: bool) -> Shapes | Points:
260315
original_name = key[: key.rfind("_")] if multi else key
261316

src/napari_spatialdata/_viewer.py

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,6 @@ def _save_points_to_sdata(
189189
raise ValueError("Cannot export a points element with no points")
190190
transformed_data = np.array([layer_to_save.data_to_world(xy) for xy in layer_to_save.data])
191191
swap_data = np.fliplr(transformed_data)
192-
# ignore z axis if present
193-
if swap_data.shape[1] == 3:
194-
swap_data = swap_data[:, :2]
195192
parsed = PointsModel.parse(swap_data, transformations=transformation)
196193

197194
# saving to disk of points temporarily disabled until the interface update that will unify the view widget,
@@ -261,14 +258,21 @@ def _save_shapes_to_sdata(
261258
for shape in layer_to_save._data_view.shapes
262259
]
263260

264-
def _fix_coords(coords: ArrayLike) -> ArrayLike:
265-
remove_z = coords.shape[1] == 3
266-
first_index = 1 if remove_z else 0
267-
coords = coords[:, first_index::]
268-
return np.fliplr(coords)
261+
has_z = coords[0].shape[1] == 3
269262

270-
polygons: list[Polygon] = [Polygon(_fix_coords(p)) for p in coords]
271-
gdf = GeoDataFrame({"geometry": polygons})
263+
def _fix_coords(coords: ArrayLike) -> tuple[ArrayLike, float | None]:
264+
if coords.shape[1] == 3:
265+
z_val = float(coords[0, 0])
266+
yx = coords[:, 1:]
267+
return np.fliplr(yx), z_val
268+
return np.fliplr(coords), None
269+
270+
fixed = [_fix_coords(p) for p in coords]
271+
polygons: list[Polygon] = [Polygon(xy) for xy, _ in fixed]
272+
gdf_dict: dict[str, Any] = {"geometry": polygons}
273+
if has_z:
274+
gdf_dict["z"] = [z_val for _, z_val in fixed]
275+
gdf = GeoDataFrame(gdf_dict)
272276

273277
force_2d(gdf)
274278
parsed = ShapesModel.parse(gdf, transformations=transformation)
@@ -514,11 +518,15 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult
514518
original_name = original_name[: original_name.rfind("_")]
515519

516520
df = sdata.shapes[original_name]
517-
affine = _get_transform(sdata.shapes[original_name], selected_cs)
521+
axes = get_axes_names(df)
522+
include_z = "z" in axes and not config.PROJECT_2_5D_SHAPES_TO_2D
523+
affine = _get_transform(sdata.shapes[original_name], selected_cs, include_z=include_z)
518524

519-
# 2.5D circles not supported yet
520525
xy = np.array([df.geometry.x, df.geometry.y]).T
521526
yx = np.fliplr(xy)
527+
if include_z:
528+
z_vals = df["z"].to_numpy()
529+
yx = np.column_stack([z_vals, yx])
522530
radii = df.radius.to_numpy()
523531

524532
adata, table_name, table_names = self._get_table_data(sdata, original_name)
@@ -561,7 +569,7 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult
561569
else:
562570
kwargs |= {"border_color": "white"}
563571
# useful code to have readily available to debug the correct radius of circles when represented as points
564-
ellipses = _get_ellipses_from_circles(yx=yx, radii=radii)
572+
ellipses = _get_ellipses_from_circles(coords=yx, radii=radii)
565573
layer = Shapes(
566574
ellipses,
567575
shape_type="ellipse",
@@ -804,8 +812,43 @@ def _affine_transform_layers(self, coordinate_system: str) -> None:
804812
sdata = metadata["sdata"]
805813
element_name = metadata["name"]
806814
element_data = sdata[element_name]
807-
affine = _get_transform(element_data, coordinate_system)
815+
include_z = self._should_include_z(element_data)
816+
affine = _get_transform(element_data, coordinate_system, include_z=include_z)
808817
if affine is not None:
809818
layer.affine = affine
810819
if layer._type_string == "points":
811820
self._adjust_radii_of_points_layer(layer, affine)
821+
822+
@staticmethod
823+
def _has_z_axis(element: Any) -> bool:
824+
"""Return ``True`` if ``element`` exposes a ``z`` axis.
825+
826+
For raster elements (images / labels) the ``z`` axis is reported by
827+
:func:`spatialdata.models.get_axes_names`. For vector elements (points
828+
as :class:`~dask.dataframe.DataFrame`, shapes as
829+
:class:`~geopandas.GeoDataFrame`) the same helper is used.
830+
"""
831+
from xarray import DataArray, DataTree
832+
833+
if not isinstance(element, DataArray | DataTree | DaskDataFrame | GeoDataFrame):
834+
return False
835+
return "z" in get_axes_names(element)
836+
837+
@staticmethod
838+
def _should_include_z(element: DaskDataFrame | GeoDataFrame) -> bool:
839+
"""Determine whether to include the z axis for a given spatial element.
840+
841+
For raster data (images, labels) z is always included when present.
842+
For vector data (points, shapes) z inclusion depends on the user-facing
843+
projection config flags.
844+
"""
845+
from xarray import DataArray, DataTree
846+
847+
if isinstance(element, DataArray | DataTree):
848+
return True
849+
axes = get_axes_names(element)
850+
if "z" not in axes:
851+
return False
852+
if isinstance(element, DaskDataFrame):
853+
return not config.PROJECT_3D_POINTS_TO_2D
854+
return not config.PROJECT_2_5D_SHAPES_TO_2D

src/napari_spatialdata/utils/_utils.py

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,32 @@ def _transform_coordinates(data: list[Any], f: Callable[..., Any]) -> list[Any]:
184184
def _get_transform(
185185
element: SpatialElement, coordinate_system_name: str | None = None, include_z: bool | None = None
186186
) -> None | ArrayLike:
187+
"""Return the affine matrix for ``element`` in the given coordinate system.
188+
189+
The z axis is included in the returned affine when **both**:
190+
191+
* ``include_z`` is truthy, **and**
192+
* the element (and therefore its underlying transformation) has a ``z`` axis,
193+
as reported by :func:`spatialdata.models.get_axes_names`.
194+
195+
If ``include_z`` is requested but the element / transformation does not expose a
196+
``z`` axis, the flag is silently ignored and a 2D ``(y, x)`` affine is returned.
197+
198+
Parameters
199+
----------
200+
element
201+
The :class:`spatialdata.models.SpatialElement` for which to compute the affine.
202+
coordinate_system_name
203+
Coordinate system to use. If ``None``, the first available is selected.
204+
include_z
205+
Whether to include the z axis in the affine. The z is only included when the
206+
element / transformation also has a z axis; otherwise this flag is ignored.
207+
208+
Returns
209+
-------
210+
The affine matrix as an ``ArrayLike`` (``(3, 3)`` for 2D and ``(4, 4)`` for 2.5D/3D),
211+
or ``None`` if no transformation is defined for the requested coordinate system.
212+
"""
187213
if not isinstance(element, DataArray | DataTree | DaskDataFrame | GeoDataFrame):
188214
raise RuntimeError("Cannot get transform for {type(element)}")
189215

@@ -459,13 +485,17 @@ def generate_random_color_hex() -> str:
459485
return f"#{randint(0, 255):02x}{randint(0, 255):02x}{randint(0, 255):02x}ff"
460486

461487

462-
def _get_ellipses_from_circles(yx: ArrayLike, radii: ArrayLike) -> ArrayLike:
488+
def _get_ellipses_from_circles(coords: ArrayLike, radii: ArrayLike) -> ArrayLike:
463489
"""Convert circles to ellipses.
464490
491+
Supports both 2D and 2.5D centroids. For 2.5D input the radius is applied only to
492+
y and x while z is kept constant across the four corner vertices.
493+
465494
Parameters
466495
----------
467-
yx
468-
Centroids of the circles.
496+
coords
497+
Centroids of the circles with shape ``(N, 2)`` in ``(y, x)`` order or ``(N, 3)``
498+
in ``(z, y, x)`` order.
469499
radii
470500
Radii of the circles.
471501
@@ -474,14 +504,29 @@ def _get_ellipses_from_circles(yx: ArrayLike, radii: ArrayLike) -> ArrayLike:
474504
ArrayLike
475505
Ellipses.
476506
"""
477-
ndim = yx.shape[1]
478-
assert ndim == 2
479-
r = np.stack([radii] * ndim, axis=1)
480-
lower_left = yx - r
481-
upper_right = yx + r
507+
ndim = coords.shape[1]
508+
if ndim not in (2, 3):
509+
raise ValueError(f"Expected centroids with 2 or 3 columns (yx or zyx), got shape {coords.shape}.")
510+
511+
if ndim == 3:
512+
z = coords[:, :1]
513+
yx_2d = coords[:, 1:]
514+
else:
515+
yx_2d = coords
516+
517+
r = np.stack([radii, radii], axis=1)
518+
lower_left = yx_2d - r
519+
upper_right = yx_2d + r
482520
r[:, 0] = -r[:, 0]
483-
lower_right = yx - r
484-
upper_left = yx + r
521+
lower_right = yx_2d - r
522+
upper_left = yx_2d + r
523+
524+
if ndim == 3:
525+
lower_left = np.column_stack([z, lower_left])
526+
lower_right = np.column_stack([z, lower_right])
527+
upper_right = np.column_stack([z, upper_right])
528+
upper_left = np.column_stack([z, upper_left])
529+
485530
ellipses = np.stack([lower_left, lower_right, upper_right, upper_left], axis=1)
486531
assert isinstance(ellipses, np.ndarray)
487532
return ellipses

tests/conftest.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def _safe_get_max_texture_sizes(): # type: ignore[no-untyped-def]
119119
from spatialdata._types import ArrayLike
120120
from spatialdata.datasets import blobs
121121
from spatialdata.models import PointsModel, ShapesModel, TableModel
122-
from spatialdata.transformations import Identity, set_transformation
122+
from spatialdata.transformations import Affine, Identity, set_transformation
123123

124124
from napari_spatialdata.utils._test_utils import export_figure, save_image
125125

@@ -415,3 +415,59 @@ def sdata_2_5d_shapes() -> SpatialData:
415415
shapes["shapes_2.5d"] = shape_element
416416

417417
return SpatialData(shapes=shapes)
418+
419+
420+
@pytest.fixture
421+
def sdata_2_5d_circles() -> SpatialData:
422+
"""Create a SpatialData object with 2.5D circles (circles at different z levels)."""
423+
n_circles = 10
424+
rng = np.random.default_rng(SEED)
425+
gdf = gpd.GeoDataFrame(
426+
{
427+
"geometry": gpd.points_from_xy(
428+
rng.uniform(0, 100, n_circles),
429+
rng.uniform(0, 100, n_circles),
430+
),
431+
"radius": rng.uniform(5, 15, n_circles),
432+
"z": rng.uniform(0, 50, n_circles),
433+
}
434+
)
435+
circles = ShapesModel.parse(gdf)
436+
set_transformation(circles, {"global": Identity()}, set_all=True)
437+
438+
return SpatialData(shapes={"circles_2.5d": circles})
439+
440+
441+
@pytest.fixture
442+
def sdata_3d_points_two_cs() -> SpatialData:
443+
"""Create a SpatialData with 3D points registered to two coordinate systems.
444+
445+
The element lives in ``global`` (identity) and in ``scaled`` (2x scale
446+
with a 10-unit z-translation). This is useful for testing that
447+
``_affine_transform_layers`` produces a correctly-sized affine matrix
448+
when switching between coordinate systems.
449+
"""
450+
n_points = 5
451+
rng = np.random.default_rng(SEED)
452+
df = pd.DataFrame(
453+
{
454+
"x": rng.uniform(0, 100, n_points),
455+
"y": rng.uniform(0, 100, n_points),
456+
"z": rng.uniform(0, 50, n_points),
457+
}
458+
)
459+
dask_df = from_pandas(df, npartitions=1)
460+
points = PointsModel.parse(dask_df)
461+
462+
affine_matrix = np.array(
463+
[
464+
[2.0, 0.0, 0.0, 0.0],
465+
[0.0, 2.0, 0.0, 0.0],
466+
[0.0, 0.0, 2.0, 10.0],
467+
[0.0, 0.0, 0.0, 1.0],
468+
]
469+
)
470+
scaled_affine = Affine(affine_matrix, input_axes=("x", "y", "z"), output_axes=("x", "y", "z"))
471+
set_transformation(points, {"global": Identity(), "scaled": scaled_affine}, set_all=True)
472+
473+
return SpatialData(points={"points_3d": points})

0 commit comments

Comments
 (0)