Skip to content

Commit 94f0a31

Browse files
authored
change channel names (#786)
* change channel names * write channel meta * add tests * add to docstring * adjust to multiscale_spatial_image * refactor * adjust test * correct refactor and test * add get_channel_names * local import get_model * import set_channel_names * update log and api.md * local import
1 parent 62e4699 commit 94f0a31

File tree

15 files changed

+260
-26
lines changed

15 files changed

+260
-26
lines changed

CHANGELOG.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,18 @@ and this project adheres to [Semantic Versioning][].
1010

1111
## [0.2.6] - TBD
1212

13+
### Added
14+
15+
- Added `set_channel_names` method to `SpatialData` to change the channel names of an
16+
image element in `SpatialData`
17+
- Added `write_channel_names` method to `SpatialData` to overwrite channel metadata on disk
18+
without overwriting the image array itself.
19+
20+
### Changed
21+
22+
- `get_channels` is marked for deprecation in `SpatialData` v0.3.0. Function is replaced
23+
by `get_channel_names`
24+
1325
### Fixed
1426

1527
- Updated deprecated default stages of `pre-commit` #771

docs/api.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ The elements (building-blocks) that constitute `SpatialData`.
8989
points_geopandas_to_dask_dataframe
9090
points_dask_dataframe_to_geopandas
9191
get_channels
92+
get_channel_names
93+
set_channel_names
9294
force_2d
9395
```
9496

src/spatialdata/_core/operations/map.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from dask.array.overlap import coerce_depth
99
from xarray import DataArray, DataTree
1010

11-
from spatialdata.models._utils import get_axes_names, get_channels, get_raster_model_from_data_dims
11+
from spatialdata.models._utils import get_axes_names, get_channel_names, get_raster_model_from_data_dims
1212
from spatialdata.transformations import get_transformation
1313

1414
__all__ = ["map_raster"]
@@ -121,7 +121,7 @@ def map_raster(
121121

122122
if "c" in dims:
123123
if c_coords is None:
124-
c_coords = range(arr.shape[0]) if arr.shape[0] != len(get_channels(data)) else get_channels(data)
124+
c_coords = range(arr.shape[0]) if arr.shape[0] != len(get_channel_names(data)) else get_channel_names(data)
125125
else:
126126
c_coords = None
127127
if transformations is None:

src/spatialdata/_core/operations/transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from spatialdata._core.spatialdata import SpatialData
1818
from spatialdata._types import ArrayLike
1919
from spatialdata.models import SpatialElement, get_axes_names, get_model
20-
from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM, get_channels
20+
from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM, get_channel_names
2121
from spatialdata.transformations._utils import _get_scale, compute_coordinates, scale_radii
2222

2323
if TYPE_CHECKING:
@@ -367,7 +367,7 @@ def _(
367367
channel_names = None
368368
elif schema in (Image2DModel, Image3DModel):
369369
kwargs = {}
370-
channel_names = get_channels(data)
370+
channel_names = get_channel_names(data)
371371
else:
372372
raise ValueError(f"DataTree with schema {schema} not supported")
373373

src/spatialdata/_core/spatialdata.py

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables
2424
from spatialdata._logging import logger
2525
from spatialdata._types import ArrayLike, Raster_T
26-
from spatialdata._utils import _deprecation_alias, _error_message_add_element
26+
from spatialdata._utils import (
27+
_deprecation_alias,
28+
_error_message_add_element,
29+
)
2730
from spatialdata.models import (
2831
Image2DModel,
2932
Image3DModel,
@@ -36,7 +39,12 @@
3639
get_model,
3740
get_table_keys,
3841
)
39-
from spatialdata.models._utils import SpatialElement, convert_region_column_to_categorical, get_axes_names
42+
from spatialdata.models._utils import (
43+
SpatialElement,
44+
convert_region_column_to_categorical,
45+
get_axes_names,
46+
set_channel_names,
47+
)
4048

4149
if TYPE_CHECKING:
4250
from spatialdata._core.query.spatial_query import BaseSpatialRequest
@@ -315,6 +323,26 @@ def get_instance_key_column(table: AnnData) -> pd.Series:
315323
return table.obs[instance_key]
316324
raise KeyError(f"{instance_key} is set as instance key column. However the column is not found in table.obs.")
317325

326+
def set_channel_names(self, element_name: str, channel_names: str | list[str], write: bool = False) -> None:
327+
"""Set the channel names for a image `SpatialElement` in the `SpatialData` object.
328+
329+
This method assumes that the `SpatialData` object and the element are already stored on disk as it will
330+
also overwrite the channel names metadata on disk. In case either the `SpatialData` object or the
331+
element are not stored on disk, please use `SpatialData.set_image_channel_names` instead.
332+
333+
Parameters
334+
----------
335+
element_name
336+
Name of the image `SpatialElement`.
337+
channel_names
338+
The channel names to be assigned to the c dimension of the image `SpatialElement`.
339+
write
340+
Whether to overwrite the channel metadata on disk.
341+
"""
342+
self.images[element_name] = set_channel_names(self.images[element_name], channel_names)
343+
if write:
344+
self.write_channel_names(element_name)
345+
318346
@staticmethod
319347
def _set_table_annotation_target(
320348
table: AnnData,
@@ -1441,6 +1469,45 @@ def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[st
14411469
)
14421470
return element_type, element
14431471

1472+
def write_channel_names(self, element_name: str | None = None) -> None:
1473+
"""
1474+
Write channel names to disk for a single image element, or for all image elements, without rewriting the data.
1475+
1476+
Parameters
1477+
----------
1478+
element_name
1479+
The name of the element to write the channel names of. If None, write the channel names of all image
1480+
elements.
1481+
"""
1482+
from spatialdata._core._elements import Elements
1483+
1484+
if element_name is not None:
1485+
Elements._check_valid_name(element_name)
1486+
1487+
# recursively write the transformation for all the SpatialElement
1488+
if element_name is None:
1489+
for element_name in list(self.images.keys()):
1490+
self.write_channel_names(element_name)
1491+
return
1492+
1493+
validation_result = self._validate_can_write_metadata_on_element(element_name)
1494+
if validation_result is None:
1495+
return
1496+
1497+
element_type, element = validation_result
1498+
1499+
# Mypy does not understand that path is not None so we have the check in the conditional
1500+
if element_type == "images" and self.path is not None:
1501+
_, _, element_group = self._get_groups_for_element(
1502+
zarr_path=Path(self.path), element_type=element_type, element_name=element_name
1503+
)
1504+
1505+
from spatialdata._io._utils import overwrite_channel_names
1506+
1507+
overwrite_channel_names(element_group, element)
1508+
else:
1509+
raise ValueError(f"Can't set channel names for element of type '{element_type}'.")
1510+
14441511
def write_transformations(self, element_name: str | None = None) -> None:
14451512
"""
14461513
Write transformations to disk for a single element, or for all elements, without rewriting the data.
@@ -1471,6 +1538,7 @@ def write_transformations(self, element_name: str | None = None) -> None:
14711538
transformations = get_transformation(element, get_all=True)
14721539
assert isinstance(transformations, dict)
14731540

1541+
# Mypy does not understand that path is not None so we have a conditional
14741542
assert self.path is not None
14751543
_, _, element_group = self._get_groups_for_element(
14761544
zarr_path=Path(self.path), element_type=element_type, element_name=element_name
@@ -1546,9 +1614,9 @@ def write_metadata(self, element_name: str | None = None, consolidate_metadata:
15461614
Elements._check_valid_name(element_name)
15471615

15481616
self.write_transformations(element_name)
1617+
self.write_channel_names(element_name)
15491618
# TODO: write .uns['spatialdata_attrs'] metadata for AnnData.
15501619
# TODO: write .attrs['spatialdata_attrs'] metadata for DaskDataFrame.
1551-
# TODO: write omero metadata for the channel name of images.
15521620

15531621
if consolidate_metadata is None and self.has_consolidated_metadata():
15541622
consolidate_metadata = True

src/spatialdata/_io/_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,27 @@ def overwrite_coordinate_transformations_raster(
104104
group.attrs["multiscales"] = multiscales
105105

106106

107+
def overwrite_channel_names(group: zarr.Group, element: DataArray | DataTree) -> None:
108+
"""Write channel metadata to a group."""
109+
if isinstance(element, DataArray):
110+
channel_names = element.coords["c"].data.tolist()
111+
else:
112+
channel_names = element["scale0"]["image"].coords["c"].data.tolist()
113+
114+
channel_metadata = [{"label": name} for name in channel_names]
115+
omero_meta = group.attrs["omero"]
116+
omero_meta["channels"] = channel_metadata
117+
group.attrs["omero"] = omero_meta
118+
multiscales_meta = group.attrs["multiscales"]
119+
if len(multiscales_meta) != 1:
120+
raise ValueError(
121+
f"Multiscale metadata must be of length one but got length {len(multiscales_meta)}. Data might"
122+
f"be corrupted."
123+
)
124+
multiscales_meta[0]["metadata"]["omero"]["channels"] = channel_metadata
125+
group.attrs["multiscales"] = multiscales_meta
126+
127+
107128
def _write_metadata(
108129
group: zarr.Group,
109130
group_type: str,

src/spatialdata/_io/io_raster.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
_parse_version,
2727
)
2828
from spatialdata._utils import get_pyramid_levels
29-
from spatialdata.models._utils import get_channels
29+
from spatialdata.models._utils import get_channel_names
3030
from spatialdata.models.models import ATTRS_KEY
3131
from spatialdata.transformations._utils import (
3232
_get_transformations,
@@ -151,7 +151,7 @@ def _get_group_for_writing_transformations() -> zarr.Group:
151151
# convert channel names to channel metadata in omero
152152
if raster_type == "image":
153153
metadata["metadata"] = {"omero": {"channels": []}}
154-
channels = get_channels(raster_data)
154+
channels = get_channel_names(raster_data)
155155
for c in channels:
156156
metadata["metadata"]["omero"]["channels"].append({"label": c}) # type: ignore[union-attr, index, call-overload]
157157

src/spatialdata/_utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import pandas as pd
1212
from anndata import AnnData
1313
from dask import array as da
14+
from dask.array import Array as DaskArray
1415
from xarray import DataArray, Dataset, DataTree
1516

1617
from spatialdata._types import ArrayLike
@@ -311,3 +312,37 @@ def _error_message_add_element() -> None:
311312
"write_labels(), write_points(), write_shapes() and write_table(). We are going to make these calls more "
312313
"ergonomic in a follow up PR."
313314
)
315+
316+
317+
def _check_match_length_channels_c_dim(
318+
data: DaskArray | DataArray | DataTree, c_coords: str | list[str], dims: tuple[str]
319+
) -> list[str]:
320+
"""
321+
Check whether channel names `c_coords` are of equal length to the `c` dimension of the data.
322+
323+
Parameters
324+
----------
325+
data
326+
The image array
327+
c_coords
328+
The channel names
329+
dims
330+
The axes names in the order that is the same as the `ImageModel` from which it is derived.
331+
332+
Returns
333+
-------
334+
c_coords
335+
The channel names as list
336+
"""
337+
c_index = dims.index("c")
338+
c_length = (
339+
data.shape[c_index] if isinstance(data, DataArray | DaskArray) else data["scale0"]["image"].shape[c_index]
340+
)
341+
if isinstance(c_coords, str):
342+
c_coords = [c_coords]
343+
if c_coords is not None and len(c_coords) != c_length:
344+
raise ValueError(
345+
f"The number of channel names `{len(c_coords)}` does not match the length of dimension 'c'"
346+
f" with length {c_length}."
347+
)
348+
return c_coords

src/spatialdata/models/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
Z,
77
force_2d,
88
get_axes_names,
9+
get_channel_names,
910
get_channels,
1011
get_spatial_axes,
1112
points_dask_dataframe_to_geopandas,
1213
points_geopandas_to_dask_dataframe,
14+
set_channel_names,
1315
validate_axes,
1416
validate_axis_name,
1517
)
@@ -49,6 +51,8 @@
4951
"check_target_region_column_symmetry",
5052
"get_table_keys",
5153
"get_channels",
54+
"get_channel_names",
55+
"set_channel_names",
5256
"force_2d",
5357
"RasterSchema",
5458
]

src/spatialdata/models/_utils.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from xarray import DataArray, DataTree
1414

1515
from spatialdata._logging import logger
16+
from spatialdata._utils import _check_match_length_channels_c_dim
1617
from spatialdata.transformations.transformations import BaseTransformation
1718

1819
SpatialElement: TypeAlias = DataArray | DataTree | GeoDataFrame | DaskDataFrame
@@ -268,7 +269,7 @@ def points_geopandas_to_dask_dataframe(gdf: GeoDataFrame, suppress_z_warning: bo
268269

269270

270271
@singledispatch
271-
def get_channels(data: Any) -> list[Any]:
272+
def get_channel_names(data: Any) -> list[Any]:
272273
"""Get channels from data for an image element (both single and multiscale).
273274
274275
Parameters
@@ -287,12 +288,40 @@ def get_channels(data: Any) -> list[Any]:
287288
raise ValueError(f"Cannot get channels from {type(data)}")
288289

289290

290-
@get_channels.register
291+
def get_channels(data: Any) -> list[Any]:
292+
"""Get channels from data for an image element (both single and multiscale).
293+
294+
[Deprecation] This function will be deprecated in version 0.3.0. Please use
295+
`get_channel_names`.
296+
297+
Parameters
298+
----------
299+
data
300+
data to get channels from
301+
302+
Returns
303+
-------
304+
List of channels
305+
306+
Notes
307+
-----
308+
For multiscale images, the channels are validated to be consistent across scales.
309+
"""
310+
warnings.warn(
311+
"The function 'get_channels' is deprecated and will be removed in version 0.3.0. "
312+
"Please use 'get_channel_names' instead.",
313+
DeprecationWarning,
314+
stacklevel=2, # Adjust the stack level to point to the caller
315+
)
316+
return get_channel_names(data)
317+
318+
319+
@get_channel_names.register
291320
def _(data: DataArray) -> list[Any]:
292321
return data.coords["c"].values.tolist() # type: ignore[no-any-return]
293322

294323

295-
@get_channels.register
324+
@get_channel_names.register
296325
def _(data: DataTree) -> list[Any]:
297326
name = list({list(data[i].data_vars.keys())[0] for i in data})[0]
298327
channels = {tuple(data[i][name].coords["c"].values) for i in data}
@@ -374,3 +403,36 @@ def convert_region_column_to_categorical(table: AnnData) -> AnnData:
374403
)
375404
table.obs[region_key] = pd.Categorical(table.obs[region_key])
376405
return table
406+
407+
408+
def set_channel_names(element: DataArray | DataTree, channel_names: str | list[str]) -> DataArray | DataTree:
409+
"""Set the channel names for a image `SpatialElement` in the `SpatialData` object.
410+
411+
Parameters
412+
----------
413+
element
414+
The image `SpatialElement` or parsed `ImageModel`.
415+
channel_names
416+
The channel names to be assigned to the c dimension of the image `SpatialElement`.
417+
418+
Returns
419+
-------
420+
element
421+
The image `SpatialElement` or parsed `ImageModel` with the channel names set to the `c` dimension.
422+
"""
423+
from spatialdata.models import Image2DModel, Image3DModel, get_model
424+
425+
channel_names = channel_names if isinstance(channel_names, list) else [channel_names]
426+
model = get_model(element)
427+
428+
# get_model cannot be used due to circular import so get_axes_names is used instead
429+
if model in [Image2DModel, Image3DModel]:
430+
channel_names = _check_match_length_channels_c_dim(element, channel_names, model.dims.dims) # type: ignore[union-attr]
431+
if isinstance(element, DataArray):
432+
element = element.assign_coords(c=channel_names)
433+
else:
434+
element = element.msi.assign_coords({"c": channel_names})
435+
else:
436+
raise TypeError("Element model does not support setting channel names, no `c` dimension found.")
437+
438+
return element

0 commit comments

Comments
 (0)