|
23 | 23 | from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables |
24 | 24 | from spatialdata._logging import logger |
25 | 25 | from spatialdata._types import ArrayLike, Raster_T |
26 | | -from spatialdata._utils import _deprecation_alias, _error_message_add_element |
| 26 | +from spatialdata._utils import ( |
| 27 | + _deprecation_alias, |
| 28 | + _error_message_add_element, |
| 29 | +) |
27 | 30 | from spatialdata.models import ( |
28 | 31 | Image2DModel, |
29 | 32 | Image3DModel, |
|
36 | 39 | get_model, |
37 | 40 | get_table_keys, |
38 | 41 | ) |
39 | | -from spatialdata.models._utils import SpatialElement, convert_region_column_to_categorical, get_axes_names |
| 42 | +from spatialdata.models._utils import ( |
| 43 | + SpatialElement, |
| 44 | + convert_region_column_to_categorical, |
| 45 | + get_axes_names, |
| 46 | + set_channel_names, |
| 47 | +) |
40 | 48 |
|
41 | 49 | if TYPE_CHECKING: |
42 | 50 | from spatialdata._core.query.spatial_query import BaseSpatialRequest |
@@ -315,6 +323,26 @@ def get_instance_key_column(table: AnnData) -> pd.Series: |
315 | 323 | return table.obs[instance_key] |
316 | 324 | raise KeyError(f"{instance_key} is set as instance key column. However the column is not found in table.obs.") |
317 | 325 |
|
| 326 | + def set_channel_names(self, element_name: str, channel_names: str | list[str], write: bool = False) -> None: |
| 327 | + """Set the channel names for a image `SpatialElement` in the `SpatialData` object. |
| 328 | +
|
| 329 | + This method assumes that the `SpatialData` object and the element are already stored on disk as it will |
| 330 | + also overwrite the channel names metadata on disk. In case either the `SpatialData` object or the |
| 331 | + element are not stored on disk, please use `SpatialData.set_image_channel_names` instead. |
| 332 | +
|
| 333 | + Parameters |
| 334 | + ---------- |
| 335 | + element_name |
| 336 | + Name of the image `SpatialElement`. |
| 337 | + channel_names |
| 338 | + The channel names to be assigned to the c dimension of the image `SpatialElement`. |
| 339 | + write |
| 340 | + Whether to overwrite the channel metadata on disk. |
| 341 | + """ |
| 342 | + self.images[element_name] = set_channel_names(self.images[element_name], channel_names) |
| 343 | + if write: |
| 344 | + self.write_channel_names(element_name) |
| 345 | + |
318 | 346 | @staticmethod |
319 | 347 | def _set_table_annotation_target( |
320 | 348 | table: AnnData, |
@@ -1441,6 +1469,45 @@ def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[st |
1441 | 1469 | ) |
1442 | 1470 | return element_type, element |
1443 | 1471 |
|
| 1472 | + def write_channel_names(self, element_name: str | None = None) -> None: |
| 1473 | + """ |
| 1474 | + Write channel names to disk for a single image element, or for all image elements, without rewriting the data. |
| 1475 | +
|
| 1476 | + Parameters |
| 1477 | + ---------- |
| 1478 | + element_name |
| 1479 | + The name of the element to write the channel names of. If None, write the channel names of all image |
| 1480 | + elements. |
| 1481 | + """ |
| 1482 | + from spatialdata._core._elements import Elements |
| 1483 | + |
| 1484 | + if element_name is not None: |
| 1485 | + Elements._check_valid_name(element_name) |
| 1486 | + |
| 1487 | + # recursively write the transformation for all the SpatialElement |
| 1488 | + if element_name is None: |
| 1489 | + for element_name in list(self.images.keys()): |
| 1490 | + self.write_channel_names(element_name) |
| 1491 | + return |
| 1492 | + |
| 1493 | + validation_result = self._validate_can_write_metadata_on_element(element_name) |
| 1494 | + if validation_result is None: |
| 1495 | + return |
| 1496 | + |
| 1497 | + element_type, element = validation_result |
| 1498 | + |
| 1499 | + # Mypy does not understand that path is not None so we have the check in the conditional |
| 1500 | + if element_type == "images" and self.path is not None: |
| 1501 | + _, _, element_group = self._get_groups_for_element( |
| 1502 | + zarr_path=Path(self.path), element_type=element_type, element_name=element_name |
| 1503 | + ) |
| 1504 | + |
| 1505 | + from spatialdata._io._utils import overwrite_channel_names |
| 1506 | + |
| 1507 | + overwrite_channel_names(element_group, element) |
| 1508 | + else: |
| 1509 | + raise ValueError(f"Can't set channel names for element of type '{element_type}'.") |
| 1510 | + |
1444 | 1511 | def write_transformations(self, element_name: str | None = None) -> None: |
1445 | 1512 | """ |
1446 | 1513 | Write transformations to disk for a single element, or for all elements, without rewriting the data. |
@@ -1471,6 +1538,7 @@ def write_transformations(self, element_name: str | None = None) -> None: |
1471 | 1538 | transformations = get_transformation(element, get_all=True) |
1472 | 1539 | assert isinstance(transformations, dict) |
1473 | 1540 |
|
| 1541 | + # Mypy does not understand that path is not None so we have a conditional |
1474 | 1542 | assert self.path is not None |
1475 | 1543 | _, _, element_group = self._get_groups_for_element( |
1476 | 1544 | zarr_path=Path(self.path), element_type=element_type, element_name=element_name |
@@ -1546,9 +1614,9 @@ def write_metadata(self, element_name: str | None = None, consolidate_metadata: |
1546 | 1614 | Elements._check_valid_name(element_name) |
1547 | 1615 |
|
1548 | 1616 | self.write_transformations(element_name) |
| 1617 | + self.write_channel_names(element_name) |
1549 | 1618 | # TODO: write .uns['spatialdata_attrs'] metadata for AnnData. |
1550 | 1619 | # TODO: write .attrs['spatialdata_attrs'] metadata for DaskDataFrame. |
1551 | | - # TODO: write omero metadata for the channel name of images. |
1552 | 1620 |
|
1553 | 1621 | if consolidate_metadata is None and self.has_consolidated_metadata(): |
1554 | 1622 | consolidate_metadata = True |
|
0 commit comments