diff --git a/setup.cfg b/setup.cfg index 9097cf86..fea284df 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,7 +37,7 @@ install_requires = anndata click cycler - dask>=2024.4.1,<=2024.11.2 + dask>=2025.2.0 geopandas loguru matplotlib diff --git a/src/napari_spatialdata/_viewer.py b/src/napari_spatialdata/_viewer.py index c6af7139..8d5e3d71 100644 --- a/src/napari_spatialdata/_viewer.py +++ b/src/napari_spatialdata/_viewer.py @@ -18,12 +18,11 @@ from spatialdata import get_element_annotators, get_element_instances from spatialdata._core.query.relational_query import _left_join_spatialelement_table from spatialdata._types import ArrayLike -from spatialdata.models import PointsModel, ShapesModel, TableModel, force_2d, get_channel_names +from spatialdata.models import PointsModel, ShapesModel, TableModel, force_2d, get_axes_names, get_channel_names from spatialdata.transformations import Affine, Identity from napari_spatialdata._model import DataModel from napari_spatialdata.constants import config -from napari_spatialdata.constants.config import CIRCLES_AS_POINTS from napari_spatialdata.utils._utils import ( _adjust_channels_order, _get_ellipses_from_circles, @@ -470,7 +469,7 @@ def get_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi: if multi: original_name = original_name[: original_name.rfind("_")] - affine = _get_transform(sdata.images[original_name], selected_cs) + affine = _get_transform(sdata.images[original_name], selected_cs, include_z=True) rgb_image, rgb = _adjust_channels_order(element=sdata.images[original_name]) channels = ("RGB(A)",) if rgb else get_channel_names(sdata.images[original_name]) @@ -517,6 +516,7 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult df = sdata.shapes[original_name] affine = _get_transform(sdata.shapes[original_name], selected_cs) + # 2.5D circles not supported yet xy = np.array([df.geometry.x, df.geometry.y]).T yx = np.fliplr(xy) radii = df.radius.to_numpy() @@ -541,10 +541,10 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult version = get_napari_version() kwargs: dict[str, Any] = ( {"edge_width": 0.0} - if version <= packaging.version.parse("0.4.20") or not CIRCLES_AS_POINTS + if version <= packaging.version.parse("0.4.20") or not config.CIRCLES_AS_POINTS else {"border_width": 0.0} ) - if CIRCLES_AS_POINTS: + if config.CIRCLES_AS_POINTS: layer = Points( yx, name=key, @@ -556,7 +556,7 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult assert affine is not None self._adjust_radii_of_points_layer(layer=layer, affine=affine) else: - if version <= packaging.version.parse("0.4.20") or not CIRCLES_AS_POINTS: + if version <= packaging.version.parse("0.4.20") or not config.CIRCLES_AS_POINTS: kwargs |= {"edge_color": "white"} else: kwargs |= {"border_color": "white"} @@ -597,7 +597,8 @@ def get_sdata_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi original_name = original_name[: original_name.rfind("_")] df = sdata.shapes[original_name] - affine = _get_transform(sdata.shapes[original_name], selected_cs) + include_z = not config.PROJECT_2_5D_SHAPES_TO_2D + affine = _get_transform(sdata.shapes[original_name], selected_cs, include_z=include_z) # when mulitpolygons are present, we select the largest ones if "MultiPolygon" in np.unique(df.geometry.type): @@ -609,7 +610,7 @@ def get_sdata_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi df = df.sort_index() # reset the index to the first order simplify = len(df) > config.POLYGON_THRESHOLD - polygons, indices = _get_polygons_properties(df, simplify) + polygons, indices = _get_polygons_properties(df, simplify, include_z=include_z) # this will only work for polygons and not for multipolygons polygons = _transform_coordinates(polygons, f=lambda x: x[::-1]) @@ -662,7 +663,7 @@ def get_sdata_labels(self, sdata: SpatialData, key: str, selected_cs: str, multi original_name = original_name[: original_name.rfind("_")] indices = get_element_instances(sdata.labels[original_name]) - affine = _get_transform(sdata.labels[original_name], selected_cs) + affine = _get_transform(sdata.labels[original_name], selected_cs, include_z=True) rgb_labels, _ = _adjust_channels_order(element=sdata.labels[original_name]) adata, table_name, table_names = self._get_table_data(sdata, original_name) @@ -706,8 +707,10 @@ def get_sdata_points(self, sdata: SpatialData, key: str, selected_cs: str, multi if multi: original_name = original_name[: original_name.rfind("_")] + axes = get_axes_names(sdata.points[original_name]) points = sdata.points[original_name].compute() - affine = _get_transform(sdata.points[original_name], selected_cs) + include_z = "z" in axes and not config.PROJECT_3D_POINTS_TO_2D + affine = _get_transform(sdata.points[original_name], selected_cs, include_z=include_z) adata, table_name, table_names = self._get_table_data(sdata, original_name) if len(points) < config.POINT_THRESHOLD: @@ -727,14 +730,16 @@ def get_sdata_points(self, sdata: SpatialData, key: str, selected_cs: str, multi _, adata = _left_join_spatialelement_table( {"points": {original_name: subsample_points}}, sdata[table_name], match_rows="left" ) - xy = subsample_points[["y", "x"]].values - np.fliplr(xy) + axes = sorted(axes, reverse=True) + if not include_z and "z" in axes: + axes.remove("z") + coords = subsample_points[axes].values # radii_size = _calc_default_radii(self.viewer, sdata, selected_cs) radii_size = 3 version = get_napari_version() kwargs = {"edge_width": 0.0} if version <= packaging.version.parse("0.4.20") else {"border_width": 0.0} layer = Points( - xy, + coords, name=key, size=radii_size * 2, affine=affine, diff --git a/src/napari_spatialdata/constants/config.py b/src/napari_spatialdata/constants/config.py index 14054767..ddcebc23 100644 --- a/src/napari_spatialdata/constants/config.py +++ b/src/napari_spatialdata/constants/config.py @@ -4,3 +4,5 @@ N_SHAPES_WARNING_THRESHOLD = 10000 POINT_SIZE_SCATTERPLOT_WIDGET = 6 CIRCLES_AS_POINTS = True +PROJECT_3D_POINTS_TO_2D = True +PROJECT_2_5D_SHAPES_TO_2D = True diff --git a/src/napari_spatialdata/utils/_utils.py b/src/napari_spatialdata/utils/_utils.py index 03399977..cfe37e23 100644 --- a/src/napari_spatialdata/utils/_utils.py +++ b/src/napari_spatialdata/utils/_utils.py @@ -181,7 +181,9 @@ def _transform_coordinates(data: list[Any], f: Callable[..., Any]) -> list[Any]: return [[f(xy) for xy in sublist] for sublist in data] -def _get_transform(element: SpatialElement, coordinate_system_name: str | None = None) -> None | ArrayLike: +def _get_transform( + element: SpatialElement, coordinate_system_name: str | None = None, include_z: bool | None = None +) -> None | ArrayLike: if not isinstance(element, DataArray | DataTree | DaskDataFrame | GeoDataFrame): raise RuntimeError("Cannot get transform for {type(element)}") @@ -189,7 +191,10 @@ def _get_transform(element: SpatialElement, coordinate_system_name: str | None = cs = transformations.keys().__iter__().__next__() if coordinate_system_name is None else coordinate_system_name ct = transformations.get(cs) if ct: - return ct.to_affine_matrix(input_axes=("y", "x"), output_axes=("y", "x")) + axes_element = get_axes_names(element) + include_z = include_z and "z" in axes_element + axes_transformation = ("z", "y", "x") if include_z else ("y", "x") + return ct.to_affine_matrix(input_axes=axes_transformation, output_axes=axes_transformation) return None diff --git a/src/napari_spatialdata/utils/_viewer_utils.py b/src/napari_spatialdata/utils/_viewer_utils.py index 17a724f8..3b186228 100644 --- a/src/napari_spatialdata/utils/_viewer_utils.py +++ b/src/napari_spatialdata/utils/_viewer_utils.py @@ -1,18 +1,59 @@ +from typing import cast + from geopandas import GeoDataFrame +from spatialdata.models import get_axes_names + + +def add_z_to_list_of_xy_tuples(xy: list[tuple[float, float]], z: float) -> list[tuple[float, float, float]]: + """ + Add z coordinates to a list of (x, y) tuples. + + Parameters + ---------- + xy + List of (x, y) tuples. + z + z coordinate to add to each tuple. + + Returns + ------- + list[tuple[float, float, float]] + List of (x, y, z) tuples. + """ + return [(x, y, z) for x, y in xy] + + +# type aliases, only used in this module +Coord2D = tuple[float, float] +Coord3D = tuple[float, float, float] +Polygon2D = list[Coord2D] +Polygon3D = list[Coord3D] +Polygon = Polygon2D | Polygon3D + + +def _get_polygons_properties(df: GeoDataFrame, simplify: bool, include_z: bool) -> tuple[list[Polygon], list[int]]: + # assumes no "Polygon Z": z is in separate column if present + indices: list[int] = [] + polygons: list[Polygon] = [] + + axes = get_axes_names(df) + add_z = include_z and "z" in axes + + for i in range(len(df)): + indices.append(int(df.index[i])) + if simplify: + xy = cast(list[Coord2D], list(df.geometry.iloc[i].exterior.simplify(tolerance=2).coords)) + else: + xy = cast(list[Coord2D], list(df.geometry.iloc[i].exterior.coords)) -def _get_polygons_properties(df: GeoDataFrame, simplify: bool) -> tuple[list[list[tuple[float, float]]], list[int]]: - indices = [] - polygons = [] + coords: Polygon2D | Polygon3D + if add_z: + z_val = float(df.iloc[i].z.item() if hasattr(df.iloc[i].z, "item") else df.iloc[i].z) + coords = add_z_to_list_of_xy_tuples(xy=xy, z=z_val) + else: + coords = xy - if simplify: - for i in range(0, len(df)): - indices.append(df.iloc[i].name) - # This can be removed once napari is sped up in the plotting. It changes the shapes only very slightly - polygons.append(list(df.geometry.iloc[i].exterior.simplify(tolerance=2).coords)) - else: - for i in range(0, len(df)): - indices.append(df.iloc[i].name) - polygons.append(list(df.geometry.iloc[i].exterior.coords)) + polygons.append(coords) return polygons, indices