Skip to content

Commit 5e3b982

Browse files
Fix points repr not showing length (#1128)
* fix _search_for_backing_files_recursively() to support multifile parquet * Fix repr showing <Delayed> instead of row count for backed points (#1084) Replace broken dask graph introspection (which only worked for single-task graphs with a HighLevelGraph layer API that no longer exists) with get_dask_backing_files() + pyarrow footer metadata reads. This handles all graph shapes including the list-of-piece-dicts case produced by aggregate_files=True. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 05d258e commit 5e3b982

2 files changed

Lines changed: 34 additions & 18 deletions

File tree

src/spatialdata/_core/spatialdata.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from anndata import AnnData
1515
from annsel.core.typing import Predicates
1616
from dask.dataframe import DataFrame as DaskDataFrame
17-
from dask.dataframe import Scalar, read_parquet
17+
from dask.dataframe import Scalar
1818
from geopandas import GeoDataFrame
1919
from shapely import MultiPolygon, Polygon
2020
from upath import UPath
@@ -1979,21 +1979,14 @@ def h(s: str) -> str:
19791979
if attr == "shapes":
19801980
descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} shape: {v.shape} (2D shapes)"
19811981
elif attr == "points":
1982+
import pyarrow.parquet as pq
1983+
1984+
from spatialdata._io._utils import get_dask_backing_files
1985+
19821986
length: int | None = None
1983-
if len(v.dask) == 1:
1984-
name, layer = v.dask.items().__iter__().__next__()
1985-
if "read-parquet" in name:
1986-
t = layer.creation_info["args"]
1987-
assert isinstance(t, tuple)
1988-
assert len(t) == 1
1989-
parquet_file = t[0]
1990-
table = read_parquet(parquet_file)
1991-
length = len(table)
1992-
else:
1993-
# length = len(v)
1994-
length = None
1995-
else:
1996-
length = None
1987+
backing_files = get_dask_backing_files(v)
1988+
if backing_files:
1989+
length = sum(pq.read_metadata(f).num_rows for f in backing_files)
19971990

19981991
n = len(get_axes_names(v))
19991992
dim_string = f"({n}D points)"
@@ -2084,8 +2077,8 @@ def _element_path_to_element_name_with_type(element_path: str) -> str:
20842077
description = self.elements_are_self_contained()
20852078
for _, element_name, element in self.gen_elements():
20862079
if not description[element_name]:
2087-
backing_files = ", ".join(get_dask_backing_files(element))
2088-
descr += f"\n{element_name}: {backing_files}"
2080+
backing_files_str = ", ".join(get_dask_backing_files(element))
2081+
descr += f"\n{element_name}: {backing_files_str}"
20892082

20902083
if self.path is not None:
20912084
elements_only_in_sdata, elements_only_in_zarr = self._symmetric_difference_with_zarr_store()

tests/io/test_utils.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55
from contextlib import nullcontext
66

77
import dask.dataframe as dd
8+
import numpy as np
9+
import pandas as pd
810
import pytest
911
from upath import UPath
1012

11-
from spatialdata import read_zarr
13+
from spatialdata import SpatialData, read_zarr
1214
from spatialdata._io._utils import get_dask_backing_files, handle_read_errors
15+
from spatialdata.models import PointsModel
1316

1417

1518
def test_backing_files_points(points):
@@ -141,3 +144,23 @@ def test_handle_read_errors(on_bad_files: str, actual_error: Exception, expectat
141144
with handle_read_errors(on_bad_files=on_bad_files, location="location", exc_types=KeyError):
142145
if actual_error is not None:
143146
raise actual_error
147+
148+
149+
def test_repr_points_shows_row_count():
150+
"""repr() must show the concrete row count, not <Delayed>, for backed points."""
151+
with tempfile.TemporaryDirectory() as tmp:
152+
parquet_path = os.path.join(tmp, "points.parquet")
153+
n_rows = 400
154+
rng = np.random.default_rng(0)
155+
df = pd.DataFrame({"x": rng.random(n_rows), "y": rng.random(n_rows)})
156+
# aggregate_files=True produces a list-of-piece-dicts graph, the case reported in #1084
157+
dd.from_pandas(df, npartitions=4).to_parquet(parquet_path, write_index=False)
158+
ddf = dd.read_parquet(parquet_path, aggregate_files=True)
159+
160+
points = PointsModel.parse(ddf)
161+
sdata = SpatialData(points={"pts": points})
162+
sdata.write(os.path.join(tmp, "example.zarr"))
163+
164+
r = repr(sdata)
165+
assert f"({n_rows}," in r, f"expected row count {n_rows} in repr, got: {r}"
166+
assert "<Delayed>" not in r, f"repr still contains <Delayed>: {r}"

0 commit comments

Comments
 (0)