diff --git a/src/napari_spatialdata/_sdata_widgets.py b/src/napari_spatialdata/_sdata_widgets.py index 1c400136..549f43ca 100644 --- a/src/napari_spatialdata/_sdata_widgets.py +++ b/src/napari_spatialdata/_sdata_widgets.py @@ -15,11 +15,20 @@ from packaging.version import parse as parse_version from qtpy.QtCore import QThread, Signal from qtpy.QtGui import QIcon -from qtpy.QtWidgets import QLabel, QListWidget, QListWidgetItem, QProgressBar, QVBoxLayout, QWidget +from qtpy.QtWidgets import ( + QCheckBox, + QLabel, + QListWidget, + QListWidgetItem, + QProgressBar, + QVBoxLayout, + QWidget, +) from spatialdata import SpatialData from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM from napari_spatialdata._viewer import SpatialDataViewer +from napari_spatialdata.constants import config from napari_spatialdata.constants.config import N_CIRCLES_WARNING_THRESHOLD, N_SHAPES_WARNING_THRESHOLD from napari_spatialdata.utils._utils import _get_sdata_key, get_duplicate_element_names, get_elements_meta_mapping @@ -174,11 +183,39 @@ def __init__(self, viewer: Viewer, sdata: EventedList): self.slider.setRange(0, 0) self.slider.setVisible(False) + self.discard_z_points = QCheckBox("Discard z for 3D points") + self.discard_z_points.setChecked(config.PROJECT_3D_POINTS_TO_2D) + self.discard_z_points.setToolTip( + "When checked, the z coordinate of new points layers is discarded so they are loaded in 2D. " + "Only applies to new layers; layers already displayed are not affected." + ) + self.discard_z_points.toggled.connect(self._on_discard_z_points_toggled) + + self.discard_z_shapes = QCheckBox("Discard z for 2.5D shapes") + self.discard_z_shapes.setChecked(config.PROJECT_2_5D_SHAPES_TO_2D) + self.discard_z_shapes.setToolTip( + "When checked, the z coordinate of new shapes layers is discarded so they are loaded in 2D. " + "Only applies to new layers; layers already displayed are not affected." + ) + self.discard_z_shapes.toggled.connect(self._on_discard_z_shapes_toggled) + + # The 3D toggles only matter when at least one element across the loaded + # SpatialData objects has a z axis. Otherwise we hide them to save screen + # real estate for users working with 2D-only data. + self._has_z_data = self._sdatas_have_z_axis(self._sdata) + self._three_d_settings_label = QLabel("3D Settings:") + self._three_d_settings_label.setVisible(self._has_z_data) + self.discard_z_points.setVisible(self._has_z_data) + self.discard_z_shapes.setVisible(self._has_z_data) + self.layout().addWidget(self.slider) self.layout().addWidget(QLabel("Coordinate System:")) self.layout().addWidget(self.coordinate_system_widget) self.layout().addWidget(QLabel("Elements:")) self.layout().addWidget(self.elements_widget) + self.layout().addWidget(self._three_d_settings_label) + self.layout().addWidget(self.discard_z_points) + self.layout().addWidget(self.discard_z_shapes) self.elements_widget.itemDoubleClicked.connect(self._on_click_item) self.coordinate_system_widget.currentItemChanged.connect( lambda item: self.elements_widget._onItemChange(item.text()) @@ -256,6 +293,24 @@ def _update_layers_visibility(self) -> None: layer.metadata["_active_in_cs"].add(coordinate_system) layer.metadata["_current_cs"] = coordinate_system + def _on_discard_z_points_toggled(self, checked: bool) -> None: + config.PROJECT_3D_POINTS_TO_2D = checked + + def _on_discard_z_shapes_toggled(self, checked: bool) -> None: + config.PROJECT_2_5D_SHAPES_TO_2D = checked + + @staticmethod + def _sdatas_have_z_axis(sdatas: EventedList) -> bool: + """Return ``True`` if any element across the given ``SpatialData`` objects has a z axis. + + Used to decide whether to expose the 3D / 2.5D projection toggles in the widget. + """ + for sdata in sdatas: + for _, _, element in sdata._gen_elements(): + if SpatialDataViewer._has_z_axis(element): + return True + return False + def _get_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi: bool) -> Shapes | Points: original_name = key[: key.rfind("_")] if multi else key diff --git a/src/napari_spatialdata/_viewer.py b/src/napari_spatialdata/_viewer.py index 8d5e3d71..cca44a92 100644 --- a/src/napari_spatialdata/_viewer.py +++ b/src/napari_spatialdata/_viewer.py @@ -189,9 +189,6 @@ def _save_points_to_sdata( raise ValueError("Cannot export a points element with no points") transformed_data = np.array([layer_to_save.data_to_world(xy) for xy in layer_to_save.data]) swap_data = np.fliplr(transformed_data) - # ignore z axis if present - if swap_data.shape[1] == 3: - swap_data = swap_data[:, :2] parsed = PointsModel.parse(swap_data, transformations=transformation) # saving to disk of points temporarily disabled until the interface update that will unify the view widget, @@ -261,14 +258,21 @@ def _save_shapes_to_sdata( for shape in layer_to_save._data_view.shapes ] - def _fix_coords(coords: ArrayLike) -> ArrayLike: - remove_z = coords.shape[1] == 3 - first_index = 1 if remove_z else 0 - coords = coords[:, first_index::] - return np.fliplr(coords) + has_z = coords[0].shape[1] == 3 - polygons: list[Polygon] = [Polygon(_fix_coords(p)) for p in coords] - gdf = GeoDataFrame({"geometry": polygons}) + def _fix_coords(coords: ArrayLike) -> tuple[ArrayLike, float | None]: + if coords.shape[1] == 3: + z_val = float(coords[0, 0]) + yx = coords[:, 1:] + return np.fliplr(yx), z_val + return np.fliplr(coords), None + + fixed = [_fix_coords(p) for p in coords] + polygons: list[Polygon] = [Polygon(xy) for xy, _ in fixed] + gdf_dict: dict[str, Any] = {"geometry": polygons} + if has_z: + gdf_dict["z"] = [z_val for _, z_val in fixed] + gdf = GeoDataFrame(gdf_dict) force_2d(gdf) parsed = ShapesModel.parse(gdf, transformations=transformation) @@ -514,11 +518,15 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult original_name = original_name[: original_name.rfind("_")] df = sdata.shapes[original_name] - affine = _get_transform(sdata.shapes[original_name], selected_cs) + axes = get_axes_names(df) + include_z = "z" in axes and not config.PROJECT_2_5D_SHAPES_TO_2D + affine = _get_transform(sdata.shapes[original_name], selected_cs, include_z=include_z) - # 2.5D circles not supported yet xy = np.array([df.geometry.x, df.geometry.y]).T yx = np.fliplr(xy) + if include_z: + z_vals = df["z"].to_numpy() + yx = np.column_stack([z_vals, yx]) radii = df.radius.to_numpy() adata, table_name, table_names = self._get_table_data(sdata, original_name) @@ -561,7 +569,7 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult else: kwargs |= {"border_color": "white"} # useful code to have readily available to debug the correct radius of circles when represented as points - ellipses = _get_ellipses_from_circles(yx=yx, radii=radii) + ellipses = _get_ellipses_from_circles(coords=yx, radii=radii) layer = Shapes( ellipses, shape_type="ellipse", @@ -804,8 +812,43 @@ def _affine_transform_layers(self, coordinate_system: str) -> None: sdata = metadata["sdata"] element_name = metadata["name"] element_data = sdata[element_name] - affine = _get_transform(element_data, coordinate_system) + include_z = self._should_include_z(element_data) + affine = _get_transform(element_data, coordinate_system, include_z=include_z) if affine is not None: layer.affine = affine if layer._type_string == "points": self._adjust_radii_of_points_layer(layer, affine) + + @staticmethod + def _has_z_axis(element: Any) -> bool: + """Return ``True`` if ``element`` exposes a ``z`` axis. + + For raster elements (images / labels) the ``z`` axis is reported by + :func:`spatialdata.models.get_axes_names`. For vector elements (points + as :class:`~dask.dataframe.DataFrame`, shapes as + :class:`~geopandas.GeoDataFrame`) the same helper is used. + """ + from xarray import DataArray, DataTree + + if not isinstance(element, DataArray | DataTree | DaskDataFrame | GeoDataFrame): + return False + return "z" in get_axes_names(element) + + @staticmethod + def _should_include_z(element: DaskDataFrame | GeoDataFrame) -> bool: + """Determine whether to include the z axis for a given spatial element. + + For raster data (images, labels) z is always included when present. + For vector data (points, shapes) z inclusion depends on the user-facing + projection config flags. + """ + from xarray import DataArray, DataTree + + if isinstance(element, DataArray | DataTree): + return True + axes = get_axes_names(element) + if "z" not in axes: + return False + if isinstance(element, DaskDataFrame): + return not config.PROJECT_3D_POINTS_TO_2D + return not config.PROJECT_2_5D_SHAPES_TO_2D diff --git a/src/napari_spatialdata/utils/_utils.py b/src/napari_spatialdata/utils/_utils.py index cfe37e23..f46d8cef 100644 --- a/src/napari_spatialdata/utils/_utils.py +++ b/src/napari_spatialdata/utils/_utils.py @@ -184,6 +184,32 @@ def _transform_coordinates(data: list[Any], f: Callable[..., Any]) -> list[Any]: def _get_transform( element: SpatialElement, coordinate_system_name: str | None = None, include_z: bool | None = None ) -> None | ArrayLike: + """Return the affine matrix for ``element`` in the given coordinate system. + + The z axis is included in the returned affine when **both**: + + * ``include_z`` is truthy, **and** + * the element (and therefore its underlying transformation) has a ``z`` axis, + as reported by :func:`spatialdata.models.get_axes_names`. + + If ``include_z`` is requested but the element / transformation does not expose a + ``z`` axis, the flag is silently ignored and a 2D ``(y, x)`` affine is returned. + + Parameters + ---------- + element + The :class:`spatialdata.models.SpatialElement` for which to compute the affine. + coordinate_system_name + Coordinate system to use. If ``None``, the first available is selected. + include_z + Whether to include the z axis in the affine. The z is only included when the + element / transformation also has a z axis; otherwise this flag is ignored. + + Returns + ------- + The affine matrix as an ``ArrayLike`` (``(3, 3)`` for 2D and ``(4, 4)`` for 2.5D/3D), + or ``None`` if no transformation is defined for the requested coordinate system. + """ if not isinstance(element, DataArray | DataTree | DaskDataFrame | GeoDataFrame): raise RuntimeError("Cannot get transform for {type(element)}") @@ -459,13 +485,17 @@ def generate_random_color_hex() -> str: return f"#{randint(0, 255):02x}{randint(0, 255):02x}{randint(0, 255):02x}ff" -def _get_ellipses_from_circles(yx: ArrayLike, radii: ArrayLike) -> ArrayLike: +def _get_ellipses_from_circles(coords: ArrayLike, radii: ArrayLike) -> ArrayLike: """Convert circles to ellipses. + Supports both 2D and 2.5D centroids. For 2.5D input the radius is applied only to + y and x while z is kept constant across the four corner vertices. + Parameters ---------- - yx - Centroids of the circles. + coords + Centroids of the circles with shape ``(N, 2)`` in ``(y, x)`` order or ``(N, 3)`` + in ``(z, y, x)`` order. radii Radii of the circles. @@ -474,14 +504,29 @@ def _get_ellipses_from_circles(yx: ArrayLike, radii: ArrayLike) -> ArrayLike: ArrayLike Ellipses. """ - ndim = yx.shape[1] - assert ndim == 2 - r = np.stack([radii] * ndim, axis=1) - lower_left = yx - r - upper_right = yx + r + ndim = coords.shape[1] + if ndim not in (2, 3): + raise ValueError(f"Expected centroids with 2 or 3 columns (yx or zyx), got shape {coords.shape}.") + + if ndim == 3: + z = coords[:, :1] + yx_2d = coords[:, 1:] + else: + yx_2d = coords + + r = np.stack([radii, radii], axis=1) + lower_left = yx_2d - r + upper_right = yx_2d + r r[:, 0] = -r[:, 0] - lower_right = yx - r - upper_left = yx + r + lower_right = yx_2d - r + upper_left = yx_2d + r + + if ndim == 3: + lower_left = np.column_stack([z, lower_left]) + lower_right = np.column_stack([z, lower_right]) + upper_right = np.column_stack([z, upper_right]) + upper_left = np.column_stack([z, upper_left]) + ellipses = np.stack([lower_left, lower_right, upper_right, upper_left], axis=1) assert isinstance(ellipses, np.ndarray) return ellipses diff --git a/tests/conftest.py b/tests/conftest.py index 5089971a..553f1cea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -119,7 +119,7 @@ def _safe_get_max_texture_sizes(): # type: ignore[no-untyped-def] from spatialdata._types import ArrayLike from spatialdata.datasets import blobs from spatialdata.models import PointsModel, ShapesModel, TableModel -from spatialdata.transformations import Identity, set_transformation +from spatialdata.transformations import Affine, Identity, set_transformation from napari_spatialdata.utils._test_utils import export_figure, save_image @@ -415,3 +415,59 @@ def sdata_2_5d_shapes() -> SpatialData: shapes["shapes_2.5d"] = shape_element return SpatialData(shapes=shapes) + + +@pytest.fixture +def sdata_2_5d_circles() -> SpatialData: + """Create a SpatialData object with 2.5D circles (circles at different z levels).""" + n_circles = 10 + rng = np.random.default_rng(SEED) + gdf = gpd.GeoDataFrame( + { + "geometry": gpd.points_from_xy( + rng.uniform(0, 100, n_circles), + rng.uniform(0, 100, n_circles), + ), + "radius": rng.uniform(5, 15, n_circles), + "z": rng.uniform(0, 50, n_circles), + } + ) + circles = ShapesModel.parse(gdf) + set_transformation(circles, {"global": Identity()}, set_all=True) + + return SpatialData(shapes={"circles_2.5d": circles}) + + +@pytest.fixture +def sdata_3d_points_two_cs() -> SpatialData: + """Create a SpatialData with 3D points registered to two coordinate systems. + + The element lives in ``global`` (identity) and in ``scaled`` (2x scale + with a 10-unit z-translation). This is useful for testing that + ``_affine_transform_layers`` produces a correctly-sized affine matrix + when switching between coordinate systems. + """ + n_points = 5 + rng = np.random.default_rng(SEED) + df = pd.DataFrame( + { + "x": rng.uniform(0, 100, n_points), + "y": rng.uniform(0, 100, n_points), + "z": rng.uniform(0, 50, n_points), + } + ) + dask_df = from_pandas(df, npartitions=1) + points = PointsModel.parse(dask_df) + + affine_matrix = np.array( + [ + [2.0, 0.0, 0.0, 0.0], + [0.0, 2.0, 0.0, 0.0], + [0.0, 0.0, 2.0, 10.0], + [0.0, 0.0, 0.0, 1.0], + ] + ) + scaled_affine = Affine(affine_matrix, input_axes=("x", "y", "z"), output_axes=("x", "y", "z")) + set_transformation(points, {"global": Identity(), "scaled": scaled_affine}, set_all=True) + + return SpatialData(points={"points_3d": points}) diff --git a/tests/test_3d_visualization.py b/tests/test_3d_visualization.py index 41ee43f3..504f08dd 100644 --- a/tests/test_3d_visualization.py +++ b/tests/test_3d_visualization.py @@ -1,14 +1,17 @@ -"""Tests for 3D points and 2.5D shapes visualization. +"""Tests for 3D points, 2.5D shapes and 2.5D circles visualization. For debugging tips on how to visually inspect tests, see docs/contributing.md. """ +from pathlib import Path from typing import Any +import numpy as np import pytest from napari.layers import Points, Shapes from napari.utils.events import EventedList from spatialdata import SpatialData +from spatialdata.models import get_axes_names from napari_spatialdata._sdata_widgets import SdataWidget from napari_spatialdata.constants import config @@ -17,11 +20,22 @@ class Test3DPointsVisualization: """Test 3D points visualization in napari.""" - def test_3d_points_projected_to_2d(self, make_napari_viewer: Any, sdata_3d_points: SpatialData): - """Test that 3D points are projected to 2D when config flag is True.""" + @pytest.mark.parametrize( + ("project_to_2d", "expected_ndim"), + [(True, 2), (False, 3)], + ids=["projected_to_2d", "full_3d"], + ) + def test_3d_points_visualization( + self, + make_napari_viewer: Any, + sdata_3d_points: SpatialData, + project_to_2d: bool, + expected_ndim: int, + ): + """Points dimensionality follows the ``PROJECT_3D_POINTS_TO_2D`` config flag.""" original_value = config.PROJECT_3D_POINTS_TO_2D try: - config.PROJECT_3D_POINTS_TO_2D = True + config.PROJECT_3D_POINTS_TO_2D = project_to_2d viewer = make_napari_viewer() widget = SdataWidget(viewer, EventedList([sdata_3d_points])) @@ -33,84 +47,260 @@ def test_3d_points_projected_to_2d(self, make_napari_viewer: Any, sdata_3d_point assert len(viewer.layers) == 1 assert isinstance(viewer.layers[0], Points) - # 2D projection: points should have 2 coordinates - assert viewer.layers[0].data.shape[1] == 2 + assert viewer.layers[0].data.shape[1] == expected_ndim finally: config.PROJECT_3D_POINTS_TO_2D = original_value - def test_3d_points_full_3d(self, make_napari_viewer: Any, sdata_3d_points: SpatialData): - """Test that 3D points are visualized in 3D when config flag is False.""" + +class Test2_5DShapesVisualization: + """Test 2.5D shapes visualization in napari.""" + + @pytest.mark.parametrize( + ("project_to_2d", "expected_ndim"), + [(True, 2), (False, 3)], + ids=["projected_to_2d", "full_3d"], + ) + def test_2_5d_shapes_visualization( + self, + make_napari_viewer: Any, + sdata_2_5d_shapes: SpatialData, + project_to_2d: bool, + expected_ndim: int, + ): + """Shape vertex dimensionality follows the ``PROJECT_2_5D_SHAPES_TO_2D`` config flag.""" + original_value = config.PROJECT_2_5D_SHAPES_TO_2D + try: + config.PROJECT_2_5D_SHAPES_TO_2D = project_to_2d + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_2_5d_shapes])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + widget._onClick("shapes_2.5d") + + assert len(viewer.layers) == 1 + assert isinstance(viewer.layers[0], Shapes) + for shape_data in viewer.layers[0].data: + assert shape_data.shape[1] == expected_ndim + finally: + config.PROJECT_2_5D_SHAPES_TO_2D = original_value + + +class Test2_5DCirclesVisualization: + """Test 2.5D circles visualization in napari.""" + + @pytest.mark.parametrize( + ("project_to_2d", "expected_ndim"), + [(True, 2), (False, 3)], + ids=["projected_to_2d", "full_3d"], + ) + def test_2_5d_circles_visualization( + self, + make_napari_viewer: Any, + sdata_2_5d_circles: SpatialData, + project_to_2d: bool, + expected_ndim: int, + ): + """Circles dimensionality follows the ``PROJECT_2_5D_SHAPES_TO_2D`` config flag.""" + original_value = config.PROJECT_2_5D_SHAPES_TO_2D + try: + config.PROJECT_2_5D_SHAPES_TO_2D = project_to_2d + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_2_5d_circles])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + widget._onClick("circles_2.5d") + + assert len(viewer.layers) == 1 + assert viewer.layers[0].data.shape[1] == expected_ndim + finally: + config.PROJECT_2_5D_SHAPES_TO_2D = original_value + + +class TestAffineTransformLayers: + """Test that ``_affine_transform_layers`` propagates ``include_z`` correctly.""" + + @pytest.mark.parametrize( + ("project_to_2d", "expected_data_ndim", "expected_affine_shape"), + [(False, 3, (4, 4)), (True, 2, (3, 3))], + ids=["full_3d", "projected_to_2d"], + ) + def test_affine_transform_preserves_dimensionality( + self, + make_napari_viewer: Any, + sdata_3d_points_two_cs: SpatialData, + project_to_2d: bool, + expected_data_ndim: int, + expected_affine_shape: tuple[int, int], + ): + """Switching coordinate system preserves the affine matrix dimensionality.""" original_value = config.PROJECT_3D_POINTS_TO_2D try: - config.PROJECT_3D_POINTS_TO_2D = False + config.PROJECT_3D_POINTS_TO_2D = project_to_2d viewer = make_napari_viewer() - widget = SdataWidget(viewer, EventedList([sdata_3d_points])) + widget = SdataWidget(viewer, EventedList([sdata_3d_points_two_cs])) widget.coordinate_system_widget._select_coord_sys("global") widget.elements_widget._onItemChange("global") widget._onClick("points_3d") - viewer.dims.ndisplay = 3 assert len(viewer.layers) == 1 - assert isinstance(viewer.layers[0], Points) - # Full 3D: points should have 3 coordinates (z, y, x) - assert viewer.layers[0].data.shape[1] == 3 + layer = viewer.layers[0] + assert isinstance(layer, Points) + assert layer.data.shape[1] == expected_data_ndim + + # Identity in "global" -> affine should be the identity of the expected shape + np.testing.assert_array_almost_equal(layer.affine.affine_matrix, np.eye(expected_affine_shape[0])) + + widget.coordinate_system_widget._select_coord_sys("scaled") + widget.viewer_model._affine_transform_layers("scaled") + + # After switching the affine must keep its dimensionality + assert layer.affine.affine_matrix.shape == expected_affine_shape + if not project_to_2d: + assert not np.allclose(layer.affine.affine_matrix, np.eye(expected_affine_shape[0])) finally: config.PROJECT_3D_POINTS_TO_2D = original_value -class Test2_5DShapesVisualization: - """Test 2.5D shapes visualization in napari.""" +class TestSavePointsPreservesZ: + """Test that saving points correctly handles the z coordinate.""" - def test_2_5d_shapes_projected_to_2d(self, make_napari_viewer: Any, sdata_2_5d_shapes: SpatialData): - """Test that 2.5D shapes are projected to 2D when config flag is True.""" - original_value = config.PROJECT_2_5D_SHAPES_TO_2D + @pytest.mark.parametrize( + ("project_to_2d", "expected_data_ndim", "z_in_axes"), + [(False, 3, True), (True, 2, False)], + ids=["preserve_z", "drop_z"], + ) + def test_save_points_z_handling( + self, + tmp_path: Path, + make_napari_viewer: Any, + sdata_3d_points: SpatialData, + project_to_2d: bool, + expected_data_ndim: int, + z_in_axes: bool, + ): + """Saving a 3D points layer must retain or drop the z column based on the config flag.""" + original_value = config.PROJECT_3D_POINTS_TO_2D try: - config.PROJECT_2_5D_SHAPES_TO_2D = True + config.PROJECT_3D_POINTS_TO_2D = project_to_2d + + tmpdir = tmp_path / "sdata.zarr" + sdata_3d_points.write(tmpdir) viewer = make_napari_viewer() - widget = SdataWidget(viewer, EventedList([sdata_2_5d_shapes])) + widget = SdataWidget(viewer, EventedList([sdata_3d_points])) widget.coordinate_system_widget._select_coord_sys("global") widget.elements_widget._onItemChange("global") + widget._onClick("points_3d") - # Add 2.5D shapes - widget._onClick("shapes_2.5d") + layer = viewer.layers[0] + assert isinstance(layer, Points) + assert layer.data.shape[1] == expected_data_ndim - assert len(viewer.layers) == 1 - assert isinstance(viewer.layers[0], Shapes) - # 2D projection: shape coordinates should have 2 values per vertex (y, x) - for shape_data in viewer.layers[0].data: - assert shape_data.shape[1] == 2 + parsed, _ = widget.viewer_model._save_points_to_sdata(layer, "points_3d", overwrite=True) + + saved_axes = get_axes_names(parsed) + assert ("z" in saved_axes) is z_in_axes + if z_in_axes: + original_z = sdata_3d_points.points["points_3d"].compute()["z"].values + saved_z = parsed.compute()["z"].values + np.testing.assert_array_almost_equal(saved_z, original_z) finally: - config.PROJECT_2_5D_SHAPES_TO_2D = original_value + config.PROJECT_3D_POINTS_TO_2D = original_value + + +class TestSaveShapesPreservesZ: + """Test that saving shapes correctly handles the z coordinate.""" - def test_2_5d_shapes_full_3d(self, make_napari_viewer: Any, sdata_2_5d_shapes: SpatialData): - """Test that 2.5D shapes are visualized in 3D when config flag is False.""" + @pytest.mark.parametrize( + ("project_to_2d", "expected_vertex_ndim", "z_in_axes"), + [(False, 3, True), (True, 2, False)], + ids=["preserve_z", "drop_z"], + ) + def test_save_shapes_z_handling( + self, + tmp_path: Path, + make_napari_viewer: Any, + sdata_2_5d_shapes: SpatialData, + project_to_2d: bool, + expected_vertex_ndim: int, + z_in_axes: bool, + ): + """Saving a 2.5D shapes layer must retain or drop the z column based on the config flag.""" original_value = config.PROJECT_2_5D_SHAPES_TO_2D try: - config.PROJECT_2_5D_SHAPES_TO_2D = False + config.PROJECT_2_5D_SHAPES_TO_2D = project_to_2d + + tmpdir = tmp_path / "sdata.zarr" + sdata_2_5d_shapes.write(tmpdir) viewer = make_napari_viewer() widget = SdataWidget(viewer, EventedList([sdata_2_5d_shapes])) widget.coordinate_system_widget._select_coord_sys("global") widget.elements_widget._onItemChange("global") - - # Add 2.5D shapes widget._onClick("shapes_2.5d") - assert len(viewer.layers) == 1 - assert isinstance(viewer.layers[0], Shapes) - # Full 3D: shape coordinates should have 3 values per vertex (z, y, x) - for shape_data in viewer.layers[0].data: - assert shape_data.shape[1] == 3 + layer = viewer.layers[0] + assert isinstance(layer, Shapes) + for shape_data in layer.data: + assert shape_data.shape[1] == expected_vertex_ndim + + parsed, _ = widget.viewer_model._save_shapes_to_sdata(layer, "shapes_2.5d", overwrite=True) + + saved_axes = get_axes_names(parsed) + assert ("z" in saved_axes) is z_in_axes + + if z_in_axes: + saved_z = parsed["z"].values + original_unique_z = np.unique(sdata_2_5d_shapes.shapes["shapes_2.5d"]["z"].values) + np.testing.assert_array_almost_equal(np.unique(saved_z), original_unique_z) finally: config.PROJECT_2_5D_SHAPES_TO_2D = original_value +class TestUIToggle: + """Test the 3D settings checkboxes in SdataWidget.""" + + def test_toggle_affects_loaded_points( + self, + make_napari_viewer: Any, + sdata_3d_points: SpatialData, + ): + """Toggling the checkbox affects the dimensionality of newly loaded layers. + + This implicitly also tests that the checkbox state and the underlying + ``config.PROJECT_3D_POINTS_TO_2D`` flag stay in sync. + """ + original_value = config.PROJECT_3D_POINTS_TO_2D + try: + config.PROJECT_3D_POINTS_TO_2D = True + + viewer = make_napari_viewer() + widget = SdataWidget(viewer, EventedList([sdata_3d_points])) + + widget.coordinate_system_widget._select_coord_sys("global") + widget.elements_widget._onItemChange("global") + + widget._onClick("points_3d") + assert viewer.layers[0].data.shape[1] == 2 + + viewer.layers.clear() + + widget.discard_z_points.setChecked(False) + widget._onClick("points_3d") + assert viewer.layers[0].data.shape[1] == 3 + finally: + config.PROJECT_3D_POINTS_TO_2D = original_value + + class TestMixed2D3DVisualization: """Test mixed 2D and 3D visualization scenarios.""" @@ -137,7 +327,6 @@ def test_mixed_dimension_visualization( config.PROJECT_3D_POINTS_TO_2D = points_dim == 2 config.PROJECT_2_5D_SHAPES_TO_2D = shapes_dim == 2 - # Create a combined SpatialData combined_sdata = SpatialData( points={"points_3d": sdata_3d_points["points_3d"]}, shapes={"shapes_2.5d": sdata_2_5d_shapes["shapes_2.5d"]},