Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ install_requires =
anndata
click
cycler
dask>=2024.4.1,<=2024.11.2
dask>=2025.2.0
geopandas
loguru
matplotlib
Expand Down
31 changes: 18 additions & 13 deletions src/napari_spatialdata/_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
from spatialdata import get_element_annotators, get_element_instances
from spatialdata._core.query.relational_query import _left_join_spatialelement_table
from spatialdata._types import ArrayLike
from spatialdata.models import PointsModel, ShapesModel, TableModel, force_2d, get_channel_names
from spatialdata.models import PointsModel, ShapesModel, TableModel, force_2d, get_axes_names, get_channel_names
from spatialdata.transformations import Affine, Identity

from napari_spatialdata._model import DataModel
from napari_spatialdata.constants import config
from napari_spatialdata.constants.config import CIRCLES_AS_POINTS
from napari_spatialdata.utils._utils import (
_adjust_channels_order,
_get_ellipses_from_circles,
Expand Down Expand Up @@ -470,7 +469,7 @@ def get_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi:
if multi:
original_name = original_name[: original_name.rfind("_")]

affine = _get_transform(sdata.images[original_name], selected_cs)
affine = _get_transform(sdata.images[original_name], selected_cs, include_z=True)
rgb_image, rgb = _adjust_channels_order(element=sdata.images[original_name])

channels = ("RGB(A)",) if rgb else get_channel_names(sdata.images[original_name])
Expand Down Expand Up @@ -517,6 +516,7 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult
df = sdata.shapes[original_name]
affine = _get_transform(sdata.shapes[original_name], selected_cs)

# 2.5D circles not supported yet
xy = np.array([df.geometry.x, df.geometry.y]).T
yx = np.fliplr(xy)
radii = df.radius.to_numpy()
Expand All @@ -541,10 +541,10 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult
version = get_napari_version()
kwargs: dict[str, Any] = (
{"edge_width": 0.0}
if version <= packaging.version.parse("0.4.20") or not CIRCLES_AS_POINTS
if version <= packaging.version.parse("0.4.20") or not config.CIRCLES_AS_POINTS
else {"border_width": 0.0}
)
if CIRCLES_AS_POINTS:
if config.CIRCLES_AS_POINTS:
layer = Points(
yx,
name=key,
Expand All @@ -556,7 +556,7 @@ def get_sdata_circles(self, sdata: SpatialData, key: str, selected_cs: str, mult
assert affine is not None
self._adjust_radii_of_points_layer(layer=layer, affine=affine)
else:
if version <= packaging.version.parse("0.4.20") or not CIRCLES_AS_POINTS:
if version <= packaging.version.parse("0.4.20") or not config.CIRCLES_AS_POINTS:
kwargs |= {"edge_color": "white"}
else:
kwargs |= {"border_color": "white"}
Expand Down Expand Up @@ -597,7 +597,8 @@ def get_sdata_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi
original_name = original_name[: original_name.rfind("_")]

df = sdata.shapes[original_name]
affine = _get_transform(sdata.shapes[original_name], selected_cs)
include_z = not config.PROJECT_2_5D_SHAPES_TO_2D
affine = _get_transform(sdata.shapes[original_name], selected_cs, include_z=include_z)

# when mulitpolygons are present, we select the largest ones
if "MultiPolygon" in np.unique(df.geometry.type):
Expand All @@ -609,7 +610,7 @@ def get_sdata_shapes(self, sdata: SpatialData, key: str, selected_cs: str, multi
df = df.sort_index() # reset the index to the first order

simplify = len(df) > config.POLYGON_THRESHOLD
polygons, indices = _get_polygons_properties(df, simplify)
polygons, indices = _get_polygons_properties(df, simplify, include_z=include_z)

# this will only work for polygons and not for multipolygons
polygons = _transform_coordinates(polygons, f=lambda x: x[::-1])
Expand Down Expand Up @@ -662,7 +663,7 @@ def get_sdata_labels(self, sdata: SpatialData, key: str, selected_cs: str, multi
original_name = original_name[: original_name.rfind("_")]

indices = get_element_instances(sdata.labels[original_name])
affine = _get_transform(sdata.labels[original_name], selected_cs)
affine = _get_transform(sdata.labels[original_name], selected_cs, include_z=True)
rgb_labels, _ = _adjust_channels_order(element=sdata.labels[original_name])

adata, table_name, table_names = self._get_table_data(sdata, original_name)
Expand Down Expand Up @@ -706,8 +707,10 @@ def get_sdata_points(self, sdata: SpatialData, key: str, selected_cs: str, multi
if multi:
original_name = original_name[: original_name.rfind("_")]

axes = get_axes_names(sdata.points[original_name])
points = sdata.points[original_name].compute()
affine = _get_transform(sdata.points[original_name], selected_cs)
include_z = "z" in axes and not config.PROJECT_3D_POINTS_TO_2D
affine = _get_transform(sdata.points[original_name], selected_cs, include_z=include_z)
adata, table_name, table_names = self._get_table_data(sdata, original_name)

if len(points) < config.POINT_THRESHOLD:
Expand All @@ -727,14 +730,16 @@ def get_sdata_points(self, sdata: SpatialData, key: str, selected_cs: str, multi
_, adata = _left_join_spatialelement_table(
{"points": {original_name: subsample_points}}, sdata[table_name], match_rows="left"
)
xy = subsample_points[["y", "x"]].values
np.fliplr(xy)
axes = sorted(axes, reverse=True)
if not include_z and "z" in axes:
axes.remove("z")
coords = subsample_points[axes].values
# radii_size = _calc_default_radii(self.viewer, sdata, selected_cs)
radii_size = 3
version = get_napari_version()
kwargs = {"edge_width": 0.0} if version <= packaging.version.parse("0.4.20") else {"border_width": 0.0}
layer = Points(
xy,
coords,
name=key,
size=radii_size * 2,
affine=affine,
Expand Down
2 changes: 2 additions & 0 deletions src/napari_spatialdata/constants/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@
N_SHAPES_WARNING_THRESHOLD = 10000
POINT_SIZE_SCATTERPLOT_WIDGET = 6
CIRCLES_AS_POINTS = True
PROJECT_3D_POINTS_TO_2D = True
PROJECT_2_5D_SHAPES_TO_2D = True
9 changes: 7 additions & 2 deletions src/napari_spatialdata/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,15 +181,20 @@ def _transform_coordinates(data: list[Any], f: Callable[..., Any]) -> list[Any]:
return [[f(xy) for xy in sublist] for sublist in data]


def _get_transform(element: SpatialElement, coordinate_system_name: str | None = None) -> None | ArrayLike:
def _get_transform(
element: SpatialElement, coordinate_system_name: str | None = None, include_z: bool | None = None
) -> None | ArrayLike:
if not isinstance(element, DataArray | DataTree | DaskDataFrame | GeoDataFrame):
raise RuntimeError("Cannot get transform for {type(element)}")

transformations = get_transformation(element, get_all=True)
cs = transformations.keys().__iter__().__next__() if coordinate_system_name is None else coordinate_system_name
ct = transformations.get(cs)
if ct:
return ct.to_affine_matrix(input_axes=("y", "x"), output_axes=("y", "x"))
axes_element = get_axes_names(element)
include_z = include_z and "z" in axes_element
axes_transformation = ("z", "y", "x") if include_z else ("y", "x")
return ct.to_affine_matrix(input_axes=axes_transformation, output_axes=axes_transformation)
return None


Expand Down
65 changes: 53 additions & 12 deletions src/napari_spatialdata/utils/_viewer_utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,59 @@
from typing import cast

from geopandas import GeoDataFrame
from spatialdata.models import get_axes_names


def add_z_to_list_of_xy_tuples(xy: list[tuple[float, float]], z: float) -> list[tuple[float, float, float]]:
"""
Add z coordinates to a list of (x, y) tuples.

Parameters
----------
xy
List of (x, y) tuples.
z
z coordinate to add to each tuple.

Returns
-------
list[tuple[float, float, float]]
List of (x, y, z) tuples.
"""
return [(x, y, z) for x, y in xy]


# type aliases, only used in this module
Coord2D = tuple[float, float]
Coord3D = tuple[float, float, float]
Polygon2D = list[Coord2D]
Polygon3D = list[Coord3D]
Polygon = Polygon2D | Polygon3D


def _get_polygons_properties(df: GeoDataFrame, simplify: bool, include_z: bool) -> tuple[list[Polygon], list[int]]:
# assumes no "Polygon Z": z is in separate column if present
indices: list[int] = []
polygons: list[Polygon] = []

axes = get_axes_names(df)
add_z = include_z and "z" in axes

for i in range(len(df)):
indices.append(int(df.index[i]))

if simplify:
xy = cast(list[Coord2D], list(df.geometry.iloc[i].exterior.simplify(tolerance=2).coords))
else:
xy = cast(list[Coord2D], list(df.geometry.iloc[i].exterior.coords))

def _get_polygons_properties(df: GeoDataFrame, simplify: bool) -> tuple[list[list[tuple[float, float]]], list[int]]:
indices = []
polygons = []
coords: Polygon2D | Polygon3D
if add_z:
z_val = float(df.iloc[i].z.item() if hasattr(df.iloc[i].z, "item") else df.iloc[i].z)
coords = add_z_to_list_of_xy_tuples(xy=xy, z=z_val)
else:
coords = xy

if simplify:
for i in range(0, len(df)):
indices.append(df.iloc[i].name)
# This can be removed once napari is sped up in the plotting. It changes the shapes only very slightly
polygons.append(list(df.geometry.iloc[i].exterior.simplify(tolerance=2).coords))
else:
for i in range(0, len(df)):
indices.append(df.iloc[i].name)
polygons.append(list(df.geometry.iloc[i].exterior.coords))
polygons.append(coords)

return polygons, indices
Loading