|
6 | 6 | from functools import singledispatch |
7 | 7 | from typing import TYPE_CHECKING, Any, cast |
8 | 8 |
|
| 9 | +import dask |
9 | 10 | import dask.array as da |
| 11 | +import dask.dataframe as dd |
10 | 12 | import dask_image.ndinterp |
11 | 13 | import numpy as np |
12 | | -import pandas as pd |
13 | 14 | from dask.array.core import Array as DaskArray |
14 | 15 | from dask.dataframe import DataFrame as DaskDataFrame |
15 | 16 | from geopandas import GeoDataFrame |
@@ -442,29 +443,51 @@ def _( |
442 | 443 | axes = get_axes_names(data) |
443 | 444 | arrays = [] |
444 | 445 |
|
445 | | - # Workaround to prevent partition collaps and missing dependency problem for now. |
| 446 | + # Dask's expression optimizer can collapse partitions at compute time, making the partition |
| 447 | + # structure inside vs. outside a disable_dask_tune_optimization() context inconsistent. To avoid |
| 448 | + # index-alignment failures (e.g. "cannot reindex on an axis with duplicate labels" from parquet |
| 449 | + # files that each start their index at 0) and length-mismatch errors, we materialise the non-axis |
| 450 | + # columns and compute the axis arrays inside a single context where the partition structure is |
| 451 | + # stable, then do plain pandas operations and re-wrap with dd.from_delayed (not dd.from_pandas, |
| 452 | + # which sorts by index and would scramble rows for non-monotonic or duplicate indices). |
446 | 453 | with disable_dask_tune_optimization() if data.npartitions > 1 else contextlib.nullcontext(): |
| 454 | + lengths = [len(part) for part in data.partitions] |
447 | 455 | for ax in axes: |
448 | 456 | # TODO We have to pass on the lengths explicitly as automatic determination with dask graph optimization |
449 | | - # leads to collaps of the partitions. However this causes a missing dependency problem, which for now is |
| 457 | + # leads to collapse of the partitions. However this causes a missing dependency problem, which for now is |
450 | 458 | # prevented by setting the optimization to False when performing this operation. |
451 | | - arrays.append(data[ax].to_dask_array(lengths=[len(part) for part in data.partitions]).reshape(-1, 1)) |
| 459 | + arrays.append(data[ax].to_dask_array(lengths=lengths).reshape(-1, 1)) |
452 | 460 |
|
453 | | - xdata = DataArray(da.concatenate(arrays, axis=1), coords={"points": range(len(data)), "dim": list(axes)}) |
454 | | - xtransformed = transformation._transform_coordinates(xdata) |
455 | | - transformed = data.drop(columns=list(axes)).copy() |
456 | | - # dummy transformation that will be replaced by _adjust_transformation() |
457 | | - default_cs = {DEFAULT_COORDINATE_SYSTEM: Identity()} |
458 | | - transformed.attrs[TRANSFORM_KEY] = default_cs |
| 461 | + xdata = DataArray(da.concatenate(arrays, axis=1), coords={"points": range(sum(lengths)), "dim": list(axes)}) |
| 462 | + xtransformed = transformation._transform_coordinates(xdata) |
| 463 | + |
| 464 | + # Compute non-axis columns while the partition structure is still stable; preserves original index. |
| 465 | + transformed_pd = data.drop(columns=list(axes)).compute() |
459 | 466 |
|
460 | 467 | for ax in axes: |
461 | 468 | indices = xtransformed["dim"] == ax |
462 | 469 | new_ax = xtransformed[:, indices] |
463 | 470 | # TODO: discuss with dask team |
464 | 471 | # This is not nice, but otherwise there is a problem with the joint graph of new_ax and transformed, causing |
465 | 472 | # a getattr missing dependency of dependent from_dask_array. |
466 | | - new_col = pd.Series(new_ax.data.flatten().compute(), index=transformed.index) |
467 | | - transformed[ax] = new_col |
| 473 | + # Assigning a numpy array is positional (no index alignment), so the original index is preserved. |
| 474 | + transformed_pd[ax] = new_ax.data.flatten().compute() |
| 475 | + |
| 476 | + # Reconstruct as a dask DataFrame via delayed partitions so that: |
| 477 | + # (a) row order matches the original (dd.from_pandas sorts by index, which scrambles rows for |
| 478 | + # non-monotonic or duplicate indices such as those produced by multi-file parquet reads), and |
| 479 | + # (b) the original index is preserved exactly. |
| 480 | + offsets = np.cumsum([0] + lengths) |
| 481 | + delayed_parts = [dask.delayed(transformed_pd.iloc[offsets[i] : offsets[i + 1]]) for i in range(len(lengths))] |
| 482 | + transformed = dd.from_delayed(delayed_parts, meta=transformed_pd.iloc[:0]) |
| 483 | + # Preserve spatialdata_attrs (feature_key, instance_key, …) from the original element; |
| 484 | + # dd.from_delayed starts with empty attrs so we must copy them explicitly. |
| 485 | + for k, v in data.attrs.items(): |
| 486 | + if k != TRANSFORM_KEY: |
| 487 | + transformed.attrs[k] = v |
| 488 | + # dummy transformation that will be replaced by _adjust_transformation() |
| 489 | + default_cs = {DEFAULT_COORDINATE_SYSTEM: Identity()} |
| 490 | + transformed.attrs[TRANSFORM_KEY] = default_cs |
468 | 491 |
|
469 | 492 | old_transformations = cast(dict[str, Any], get_transformation(data, get_all=True)) |
470 | 493 |
|
|
0 commit comments