Skip to content

Commit 68e314e

Browse files
committed
fix transform points with duplicate indices
1 parent d8bf265 commit 68e314e

2 files changed

Lines changed: 78 additions & 12 deletions

File tree

src/spatialdata/_core/operations/transform.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
from functools import singledispatch
77
from typing import TYPE_CHECKING, Any, cast
88

9+
import dask
910
import dask.array as da
11+
import dask.dataframe as dd
1012
import dask_image.ndinterp
1113
import numpy as np
12-
import pandas as pd
1314
from dask.array.core import Array as DaskArray
1415
from dask.dataframe import DataFrame as DaskDataFrame
1516
from geopandas import GeoDataFrame
@@ -442,29 +443,46 @@ def _(
442443
axes = get_axes_names(data)
443444
arrays = []
444445

445-
# Workaround to prevent partition collaps and missing dependency problem for now.
446+
# Dask's expression optimizer can collapse partitions at compute time, making the partition
447+
# structure inside vs. outside a disable_dask_tune_optimization() context inconsistent. To avoid
448+
# index-alignment failures (e.g. "cannot reindex on an axis with duplicate labels" from parquet
449+
# files that each start their index at 0) and length-mismatch errors, we materialise the non-axis
450+
# columns and compute the axis arrays inside a single context where the partition structure is
451+
# stable, then do plain pandas operations and re-wrap with dd.from_delayed (not dd.from_pandas,
452+
# which sorts by index and would scramble rows for non-monotonic or duplicate indices).
446453
with disable_dask_tune_optimization() if data.npartitions > 1 else contextlib.nullcontext():
454+
lengths = [len(part) for part in data.partitions]
447455
for ax in axes:
448456
# TODO We have to pass on the lengths explicitly as automatic determination with dask graph optimization
449-
# leads to collaps of the partitions. However this causes a missing dependency problem, which for now is
457+
# leads to collapse of the partitions. However this causes a missing dependency problem, which for now is
450458
# prevented by setting the optimization to False when performing this operation.
451-
arrays.append(data[ax].to_dask_array(lengths=[len(part) for part in data.partitions]).reshape(-1, 1))
459+
arrays.append(data[ax].to_dask_array(lengths=lengths).reshape(-1, 1))
452460

453-
xdata = DataArray(da.concatenate(arrays, axis=1), coords={"points": range(len(data)), "dim": list(axes)})
454-
xtransformed = transformation._transform_coordinates(xdata)
455-
transformed = data.drop(columns=list(axes)).copy()
456-
# dummy transformation that will be replaced by _adjust_transformation()
457-
default_cs = {DEFAULT_COORDINATE_SYSTEM: Identity()}
458-
transformed.attrs[TRANSFORM_KEY] = default_cs
461+
xdata = DataArray(da.concatenate(arrays, axis=1), coords={"points": range(sum(lengths)), "dim": list(axes)})
462+
xtransformed = transformation._transform_coordinates(xdata)
463+
464+
# Compute non-axis columns while the partition structure is still stable; preserves original index.
465+
transformed_pd = data.drop(columns=list(axes)).compute()
459466

460467
for ax in axes:
461468
indices = xtransformed["dim"] == ax
462469
new_ax = xtransformed[:, indices]
463470
# TODO: discuss with dask team
464471
# This is not nice, but otherwise there is a problem with the joint graph of new_ax and transformed, causing
465472
# a getattr missing dependency of dependent from_dask_array.
466-
new_col = pd.Series(new_ax.data.flatten().compute(), index=transformed.index)
467-
transformed[ax] = new_col
473+
# Assigning a numpy array is positional (no index alignment), so the original index is preserved.
474+
transformed_pd[ax] = new_ax.data.flatten().compute()
475+
476+
# Reconstruct as a dask DataFrame via delayed partitions so that:
477+
# (a) row order matches the original (dd.from_pandas sorts by index, which scrambles rows for
478+
# non-monotonic or duplicate indices such as those produced by multi-file parquet reads), and
479+
# (b) the original index is preserved exactly.
480+
offsets = np.cumsum([0] + lengths)
481+
delayed_parts = [dask.delayed(transformed_pd.iloc[offsets[i] : offsets[i + 1]]) for i in range(len(lengths))]
482+
transformed = dd.from_delayed(delayed_parts, meta=transformed_pd.iloc[:0])
483+
# dummy transformation that will be replaced by _adjust_transformation()
484+
default_cs = {DEFAULT_COORDINATE_SYSTEM: Identity()}
485+
transformed.attrs[TRANSFORM_KEY] = default_cs
468486

469487
old_transformations = cast(dict[str, Any], get_transformation(data, get_all=True))
470488

tests/core/operations/test_transform.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pathlib import Path
77

88
import numpy as np
9+
import pandas as pd
910
import pytest
1011
from dask import config
1112
from geopandas.testing import geom_almost_equals
@@ -590,6 +591,53 @@ def test_transform_elements_and_entire_spatial_data_object(full_sdata: SpatialDa
590591
_ = full_sdata.transform_to_coordinate_system("my_space", maintain_positioning=maintain_positioning)
591592

592593

594+
def test_transform_points_duplicate_index_gh1105(tmp_path: str):
595+
"""Regression test for https://github.com/scverse/spatialdata/issues/1105.
596+
597+
Points loaded from multiple parquet files (e.g. Xenium transcripts) have a per-file 0-based
598+
index, so the global dask DataFrame index has duplicate labels. The old implementation passed
599+
``index=transformed.index`` to ``pd.Series``, which materialised the duplicate dask Index and
600+
caused ``ValueError: cannot reindex on an axis with duplicate labels`` when assigning back.
601+
"""
602+
import dask.dataframe as dd
603+
604+
n_per_partition = 50
605+
n_partitions = 4
606+
rng = np.random.default_rng(0)
607+
608+
# Simulate multi-file parquet: each partition's index starts at 0
609+
parts = [
610+
pd.DataFrame(
611+
{
612+
"x": rng.random(n_per_partition).astype("float32"),
613+
"y": rng.random(n_per_partition).astype("float32"),
614+
"gene": [f"gene_{j}" for j in range(n_per_partition)],
615+
}
616+
)
617+
for _ in range(n_partitions)
618+
]
619+
# test also the case of non-contiguous indices
620+
for part in parts:
621+
part.index = part.index.to_list()[:-1] + [100]
622+
ddf = dd.from_map(lambda df: df, parts)
623+
assert not ddf.index.compute().is_unique, "test setup: index must have duplicates"
624+
625+
scale_factor = 4
626+
points = PointsModel.parse(ddf)
627+
set_transformation(points, Scale([scale_factor, scale_factor], axes=("x", "y")), to_coordinate_system="global")
628+
629+
result = transform(points, to_coordinate_system="global")
630+
result_pd = result.compute()
631+
632+
# Index must be preserved as-is (duplicate [0..49] × 4)
633+
assert list(result_pd.index) == list(ddf.compute().index)
634+
# Non-axis column must survive unchanged
635+
assert list(result_pd["gene"]) == list(ddf.compute()["gene"])
636+
# Axis values must be correctly scaled
637+
expected_x = ddf.compute()["x"].values * scale_factor
638+
np.testing.assert_allclose(result_pd["x"].values, expected_x, rtol=1e-5)
639+
640+
593641
def test_transform_points_with_multiple_partitions(full_sdata: SpatialData, tmp_path: str):
594642
tmpdir = Path(tmp_path) / "tmp.zarr"
595643
points_memory = full_sdata["points_0"].compute()

0 commit comments

Comments
 (0)