Skip to content

Commit ef757d1

Browse files
Fix write validation against empty transformations (#1118)
* Fix write validation not catching empty (rather than None) transformations After remove_transformation(element, remove_all=True) the transformations dict is set to {} rather than None, bypassing the is-None guard in all three IO writers. Changed the check to `not transformations` so both None and empty dicts are caught, and added a parametrized regression test covering images, multiscale images, labels, multiscale labels, shapes, and points. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Move empty-transformation validation into model validate() methods Instead of guarding in each IO writer, call get_model(element) (which already dispatches to the right schema and runs validate()) at the start of _write_element() for all non-table spatial elements. Also fix the is-None guards in all three validate() methods to use `not transformations` / `not data.attrs.get(key)` so that an empty dict {} is caught in addition to None. The IO-level guards added in the previous commit are removed since they are now superseded by the model-level check; assert statements are kept to narrow the type for mypy. The regression test is updated to reflect the correct production scenario: element is already inside a SpatialData object when its transformations are removed in-place, so the error fires during write() not at construction. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Move empty-transformation validation into model validate() methods Instead of guarding in each IO writer, call validate_element(element) (a new public helper in spatialdata.models that delegates to get_model) at the start of _write_element() for all non-table spatial elements. Validation changes in models.py: - RasterSchema._check_transforms_present: two explicit checks — one for None (key absent) and one for empty dict, with separate messages - ShapesModel.validate / PointsModel.validate: same split into two checks - asserts in IO files are kept solely for mypy type-narrowing, each annotated with a comment explaining that validate_element() guarantees the invariant at runtime New public API: - spatialdata.models.validate_element(e) raises ValueError if the element fails schema validation; documented in docs/api/models_utils.md Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 5ead6cc commit ef757d1

8 files changed

Lines changed: 70 additions & 8 deletions

File tree

docs/api/models_utils.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
.. currentmodule:: spatialdata.models
55
66
.. autofunction:: get_model
7+
.. autofunction:: validate_element
78
.. autodata:: SpatialElement
89
.. autofunction:: get_axes_names
910
.. autofunction:: get_spatial_axes

src/spatialdata/_core/spatialdata.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,6 +1225,11 @@ def _write_element(
12251225
if parsed_formats is None:
12261226
parsed_formats = _parse_formats(formats=parsed_formats)
12271227

1228+
if element_type != "tables":
1229+
from spatialdata.models import validate_element
1230+
1231+
validate_element(element)
1232+
12281233
if element_type == "images":
12291234
write_image(
12301235
image=element,

src/spatialdata/_io/io_points.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def write_points(
6767
"""
6868
axes = get_axes_names(points)
6969
transformations = _get_transformations(points)
70+
assert transformations is not None # mypy: validate_element() in _write_element guarantees this
7071

7172
store_root = group.store_path.store.root
7273
path = store_root / group.path / "points.parquet"
@@ -95,6 +96,4 @@ def write_points(
9596
axes=list(axes),
9697
attrs=attrs,
9798
)
98-
if transformations is None:
99-
raise ValueError(f"No transformations specified for element '{group.basename}'. Cannot write.")
10099
overwrite_coordinate_transformations_non_raster(group=group, axes=axes, transformations=transformations)

src/spatialdata/_io/io_raster.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,7 @@ def _write_raster_dataarray(
369369

370370
data = raster_data.data
371371
transformations = _get_transformations(raster_data)
372-
if transformations is None:
373-
raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.")
372+
assert transformations is not None # mypy: validate_element() in _write_element guarantees this
374373
input_axes: tuple[str, ...] = tuple(raster_data.dims)
375374
parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format)
376375
storage_options = _prepare_storage_options(storage_options)
@@ -439,8 +438,7 @@ def _write_raster_datatree(
439438
assert len(d) == 1
440439
xdata = d.values().__iter__().__next__()
441440
transformations = _get_transformations_xarray(xdata)
442-
if transformations is None:
443-
raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.")
441+
assert transformations is not None # mypy: validate_element() in _write_element guarantees this
444442

445443
parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format)
446444
storage_options = _prepare_storage_options(storage_options)

src/spatialdata/_io/io_shapes.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,7 @@ def write_shapes(
100100

101101
axes = get_axes_names(shapes)
102102
transformations = _get_transformations(shapes)
103-
if transformations is None:
104-
raise ValueError(f"{group.basename} does not have any transformations and can therefore not be written.")
103+
assert transformations is not None # mypy: validate_element() in _write_element guarantees this
105104
if isinstance(element_format, ShapesFormatV01):
106105
attrs = _write_shapes_v01(shapes, group, element_format)
107106
elif isinstance(element_format, ShapesFormatV02 | ShapesFormatV03):

src/spatialdata/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
TableModel,
2828
get_model,
2929
get_table_keys,
30+
validate_element,
3031
)
3132

3233
__all__ = [
@@ -51,6 +52,7 @@
5152
"points_dask_dataframe_to_geopandas",
5253
"check_target_region_column_symmetry",
5354
"get_table_keys",
55+
"validate_element",
5456
"get_channel_names",
5557
"set_channel_names",
5658
"force_2d",

src/spatialdata/models/models.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,11 @@ def _check_transforms_present(cls, data: DataArray | DataTree) -> None:
347347
f"No transformation found for `{data}`. At least one transformation is required for "
348348
f"raster elements, e.g. images, labels."
349349
)
350+
if len(parsed_transform) == 0:
351+
raise ValueError(
352+
f"The transformations dict for `{data}` is empty. At least one transformation is required for "
353+
f"raster elements, e.g. images, labels."
354+
)
350355

351356
@classmethod
352357
def _check_chunk_size_not_too_large(cls, data: DataArray | DataTree) -> None:
@@ -479,6 +484,11 @@ def validate(cls, data: GeoDataFrame) -> None:
479484
)
480485
if cls.TRANSFORM_KEY not in data.attrs:
481486
raise ValueError(f":class:`geopandas.GeoDataFrame` does not contain `{TRANSFORM_KEY}`." + SUGGESTION)
487+
if not data.attrs[cls.TRANSFORM_KEY]:
488+
raise ValueError(
489+
f":class:`geopandas.GeoDataFrame` has an empty `{TRANSFORM_KEY}` dict. "
490+
f"At least one transformation is required." + SUGGESTION
491+
)
482492
if len(data) > 0:
483493
n = data.geometry.iloc[0]._ndim
484494
if n != 2:
@@ -672,6 +682,11 @@ def validate(cls, data: DaskDataFrame) -> None:
672682
raise ValueError(
673683
f":attr:`dask.dataframe.core.DataFrame.attrs` does not contain `{cls.TRANSFORM_KEY}`." + SUGGESTION
674684
)
685+
if not data.attrs[cls.TRANSFORM_KEY]:
686+
raise ValueError(
687+
f":attr:`dask.dataframe.core.DataFrame.attrs` has an empty `{cls.TRANSFORM_KEY}` dict. "
688+
f"At least one transformation is required." + SUGGESTION
689+
)
675690
if ATTRS_KEY in data.attrs and "feature_key" in data.attrs[ATTRS_KEY]:
676691
feature_key = data.attrs[ATTRS_KEY][cls.FEATURE_KEY]
677692
if feature_key not in data.columns:
@@ -1293,6 +1308,23 @@ def _validate_and_return(
12931308
raise TypeError(f"Unsupported type {type(e)}")
12941309

12951310

1311+
def validate_element(e: SpatialElement) -> None:
1312+
"""
1313+
Validate a spatial element against its model schema.
1314+
1315+
Parameters
1316+
----------
1317+
e
1318+
The spatial element to validate.
1319+
1320+
Raises
1321+
------
1322+
ValueError
1323+
If the element is invalid (e.g. missing or empty transformations, wrong dtypes).
1324+
"""
1325+
get_model(e, validate=True)
1326+
1327+
12961328
def get_table_keys(table: AnnData) -> tuple[str | list[str], str, str]:
12971329
"""
12981330
Get the table keys giving information about what spatial element is annotated.

tests/core/operations/test_transform.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -796,3 +796,29 @@ def test_transform_until_0_0_15(points):
796796

797797
transform(points, transformation=t0, maintain_positioning=True)
798798
transform(points, to_coordinate_system="global", maintain_positioning=True)
799+
800+
801+
@pytest.mark.parametrize(
802+
"element_fixture,kwargs",
803+
[
804+
("image2d", {"images": {}}),
805+
("image2d_multiscale", {"images": {}}),
806+
("labels2d", {"labels": {}}),
807+
("labels2d_multiscale", {"labels": {}}),
808+
("circles", {"shapes": {}}),
809+
("points_0", {"points": {}}),
810+
],
811+
)
812+
def test_write_fails_after_removing_all_transformations(
813+
full_sdata: SpatialData, tmp_path: Path, element_fixture: str, kwargs: dict
814+
) -> None:
815+
"""Writing should fail when all transformations are removed from an element already in a SpatialData."""
816+
# Build a valid SpatialData first (passes __setitem__ validation)
817+
container_key = next(iter(kwargs))
818+
sdata = SpatialData(**{container_key: {element_fixture: full_sdata[element_fixture]}})
819+
820+
# Mutate in-place after construction, bypassing __setitem__ validation
821+
remove_transformation(sdata[element_fixture], remove_all=True)
822+
823+
with pytest.raises(ValueError, match="transform"):
824+
sdata.write(tmp_path / "sdata.zarr")

0 commit comments

Comments
 (0)