
Commit c057f1e

Authored by omatthew98, iamjustinhsu, and alexeykudinkin
[data] Cherry pick data fixes for 2.49.1 (#56058)
Cherry pick two fixes for ray data (from #55854 and #55926).

---------

Signed-off-by: iamjustinhsu <[email protected]>
Signed-off-by: Alexey Kudinkin <[email protected]>
Signed-off-by: Matthew Owen <[email protected]>
Co-authored-by: iamjustinhsu <[email protected]>
Co-authored-by: Alexey Kudinkin <[email protected]>
1 parent: 7cc0031

File tree

21 files changed: +269 -63 lines changed


python/ray/air/util/object_extensions/arrow.py

Lines changed: 3 additions & 0 deletions
@@ -71,6 +71,9 @@ def __reduce__(self):
             self.__arrow_ext_serialize__(),
         )
 
+    def __hash__(self) -> int:
+        return hash((type(self), self.storage_type.id, self.extension_name))
+
 
 @PublicAPI(stability="alpha")
 class ArrowPythonObjectScalar(pa.ExtensionScalar):
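
For context on why this `__hash__` definition is needed: in Python, a class that defines `__eq__` without also defining `__hash__` has its instances become unhashable, which would block the set-based schema deduplication introduced elsewhere in this commit. A minimal sketch of the rule in plain Python (class names are illustrative, not Ray API):

    class WithoutHash:
        def __init__(self, name: str):
            self.name = name

        def __eq__(self, other) -> bool:
            return isinstance(other, WithoutHash) and self.name == other.name


    class WithHash(WithoutHash):
        def __hash__(self) -> int:
            # Same pattern as the diff: hash the fields that identify the type.
            return hash((type(self), self.name))


    try:
        {WithoutHash("obj")}
    except TypeError as err:
        print(err)  # unhashable type: 'WithoutHash'

    assert len({WithHash("obj"), WithHash("obj")}) == 1  # dedupes as intended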

python/ray/air/util/tensor_extensions/arrow.py

Lines changed: 8 additions & 0 deletions
@@ -574,6 +574,9 @@ def _need_variable_shaped_tensor_array(
                 shape = arr_type.shape
         return False
 
+    def __hash__(self) -> int:
+        return hash((type(self), self.extension_name, self.storage_type, self._shape))
+
 
 @PublicAPI(stability="beta")
 class ArrowTensorType(_BaseFixedShapeArrowTensorType):
@@ -584,6 +587,7 @@ class ArrowTensorType(_BaseFixedShapeArrowTensorType):
     """
 
     OFFSET_DTYPE = np.int32
+    __hash__ = _BaseFixedShapeArrowTensorType.__hash__
 
     def __init__(self, shape: Tuple[int, ...], dtype: pa.DataType):
         """
@@ -614,6 +618,7 @@ class ArrowTensorTypeV2(_BaseFixedShapeArrowTensorType):
     """Arrow ExtensionType (v2) for tensors (supporting tensors > 4Gb)."""
 
     OFFSET_DTYPE = np.int64
+    __hash__ = _BaseFixedShapeArrowTensorType.__hash__
 
     def __init__(self, shape: Tuple[int, ...], dtype: pa.DataType):
         """
@@ -1125,6 +1130,9 @@ def _extension_scalar_to_ndarray(self, scalar: "pa.ExtensionScalar") -> np.ndarray:
         data_buffer = raw_values.buffers()[1]
         return _to_ndarray_helper(shape, value_type, offset, data_buffer)
 
+    def __hash__(self) -> int:
+        return hash((type(self), self.extension_name, self.storage_type, self._ndim))
+
 
 # NOTE: We need to inherit from the mixin before pa.ExtensionArray to ensure that the
 # mixin's overriding methods appear first in the MRO.
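
The `__hash__ = _BaseFixedShapeArrowTensorType.__hash__` assignments look redundant but are deliberate: per the Python data model, a class that overrides `__eq__` has its `__hash__` implicitly set to None, even when a parent class provides one, so the parent implementation must be restored by hand. A standalone sketch of that behavior (names illustrative):

    class Base:
        def __hash__(self) -> int:
            return hash(type(self))


    class Child(Base):
        # Overriding __eq__ silently sets Child.__hash__ to None...
        def __eq__(self, other) -> bool:
            return type(self) is type(other)

        # ...so the parent's hash must be restored explicitly, as the diff
        # does for ArrowTensorType and ArrowTensorTypeV2.
        __hash__ = Base.__hash__


    assert hash(Child()) == hash(Child())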

python/ray/data/BUILD

Lines changed: 14 additions & 0 deletions
@@ -1239,6 +1239,20 @@ py_test(
     ],
 )
 
+py_test(
+    name = "test_unify_schemas_performance",
+    size = "small",
+    srcs = ["tests/test_unify_schemas_performance.py"],
+    tags = [
+        "exclusive",
+        "team:data",
+    ],
+    deps = [
+        ":conftest",
+        "//:ray_lib",
+    ],
+)
+
 py_test(
     name = "test_util",
     size = "small",
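
With this BUILD entry in place, the new test should run like any other Bazel py_test target in the repo, presumably along the lines of `bazel test //python/ray/data:test_unify_schemas_performance` from the repository root.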

python/ray/data/_internal/arrow_ops/transform_pyarrow.py

Lines changed: 8 additions & 0 deletions
@@ -172,6 +172,14 @@ def unify_schemas(
         ArrowVariableShapedTensorType,
     )
 
+    try:
+        if len(set(schemas)) == 1:
+            # Early exit because unifying can be expensive
+            return schemas.pop()
+    except Exception as e:
+        # Unsure if there are cases where schemas are NOT hashable
+        logger.warning(f"Failed to hash the schemas (for deduplication): {e}")
+
     schemas_to_unify = []
     schema_field_overrides = {}
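
This is the heart of the performance fix: when every incoming schema is identical, hashing them into a set short-circuits the expensive unification path, with a logged fallback if hashing fails. A standalone sketch of the same fast path, assuming hashable pyarrow schemas (`dedupe_or_none` is a hypothetical helper, not Ray API):

    import pyarrow as pa


    def dedupe_or_none(schemas):
        try:
            distinct = set(schemas)  # cheap when schemas hash and compare equal
        except TypeError:
            return None  # unhashable schema; caller falls back to full unification
        return distinct.pop() if len(distinct) == 1 else None


    s = pa.schema([("a", pa.int64()), ("b", pa.string())])
    assert dedupe_or_none([s, s, s]) is s
    assert dedupe_or_none([s, pa.schema([("c", pa.float64())])]) is None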

python/ray/data/_internal/equalize.py

Lines changed: 7 additions & 3 deletions
@@ -2,8 +2,12 @@
 
 from ray.data._internal.execution.interfaces import RefBundle
 from ray.data._internal.split import _calculate_blocks_rows, _split_at_indices
-from ray.data._internal.util import unify_ref_bundles_schema
-from ray.data.block import Block, BlockMetadata, BlockPartition
+from ray.data.block import (
+    Block,
+    BlockMetadata,
+    BlockPartition,
+    _take_first_non_empty_schema,
+)
 from ray.types import ObjectRef
 
 
@@ -41,7 +45,7 @@ def _equalize(
 
     # phase 2: based on the num rows needed for each shaved split, split the leftovers
     # in the shape that exactly matches the rows needed.
-    schema = unify_ref_bundles_schema(per_split_bundles)
+    schema = _take_first_non_empty_schema(bundle.schema for bundle in per_split_bundles)
     leftover_bundle = RefBundle(leftovers, owns_blocks=owned_by_consumer, schema=schema)
     leftover_splits = _split_leftovers(leftover_bundle, per_split_needed_rows)
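
`_take_first_non_empty_schema` is added to `ray.data.block` in this same commit, but its body isn't shown in these hunks. Based on the name and the call sites, a plausible sketch of its contract (a reconstruction, not the actual implementation):

    from typing import Any, Iterable, Optional


    def take_first_non_empty_schema(schemas: Iterable[Any]) -> Optional[Any]:
        # Lazily scan the per-bundle schemas and return the first one that
        # is present; return None if every bundle is missing a schema.
        for schema in schemas:
            if schema is not None:
                return schema
        return None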

python/ray/data/_internal/execution/interfaces/ref_bundle.py

Lines changed: 7 additions & 0 deletions
@@ -63,6 +63,13 @@ def __post_init__(self):
                 "The size in bytes of the block must be known: {}".format(b)
             )
 
+        import pyarrow as pa
+
+        # The schema metadata might be unhashable.
+        # We need schemas to be hashable for unification
+        if isinstance(self.schema, pa.lib.Schema):
+            self.schema = self.schema.remove_metadata()
+
     def __setattr__(self, key, value):
         if hasattr(self, key) and key in ["blocks", "owns_blocks"]:
             raise ValueError(f"The `{key}` field of RefBundle cannot be updated.")
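
`with_metadata()` and `remove_metadata()` are standard `pyarrow.Schema` methods; a small illustration of the normalization `__post_init__` now performs:

    import pyarrow as pa

    base = pa.schema([("a", pa.int64())])
    tagged = base.with_metadata({"origin": b"parquet"})  # attach key/value metadata

    assert tagged.metadata is not None
    stripped = tagged.remove_metadata()
    assert stripped.metadata is None  # metadata is gone...
    assert stripped.equals(base)      # ...while the fields are unchanged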

python/ray/data/_internal/execution/legacy_compat.py

Lines changed: 11 additions & 10 deletions
@@ -16,10 +16,11 @@
 from ray.data._internal.logical.util import record_operators_usage
 from ray.data._internal.plan import ExecutionPlan
 from ray.data._internal.stats import DatasetStats
-from ray.data._internal.util import (
-    unify_schemas_with_validation,
+from ray.data.block import (
+    BlockMetadata,
+    BlockMetadataWithSchema,
+    _take_first_non_empty_schema,
 )
-from ray.data.block import BlockMetadata, BlockMetadataWithSchema
 
 # Warn about tasks larger than this.
 TASK_SIZE_WARN_THRESHOLD_BYTES = 100000
@@ -171,18 +172,18 @@ def _get_initial_stats_from_plan(plan: ExecutionPlan) -> DatasetStats:
 def _bundles_to_block_list(bundles: Iterator[RefBundle]) -> BlockList:
     blocks, metadata = [], []
     owns_blocks = True
-    schemas = []
+    bundle_list = list(bundles)
+    schema = _take_first_non_empty_schema(
+        ref_bundle.schema for ref_bundle in bundle_list
+    )
 
-    for ref_bundle in bundles:
+    for ref_bundle in bundle_list:
         if not ref_bundle.owns_blocks:
             owns_blocks = False
         blocks.extend(ref_bundle.block_refs)
         metadata.extend(ref_bundle.metadata)
-        schemas.append(ref_bundle.schema)
-    unified_schema = unify_schemas_with_validation(schemas)
-    return BlockList(
-        blocks, metadata, owned_by_consumer=owns_blocks, schema=unified_schema
-    )
+
+    return BlockList(blocks, metadata, owned_by_consumer=owns_blocks, schema=schema)
 
 
 def _set_stats_uuid_recursive(stats: DatasetStats, dataset_uuid: str) -> None:
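
The switch to `bundle_list = list(bundles)` is not incidental: `bundles` is an `Iterator`, and scanning it once to pick out a schema would exhaust it before the block-collection loop ran. A generic illustration of the pitfall being avoided:

    it = iter([1, 2, 3])
    first = next(it, None)     # peeking at a bare iterator consumes it...
    assert list(it) == [2, 3]  # ...so later consumers see fewer items

    items = list(iter([1, 2, 3]))  # materialize once, then scan freely
    assert items[0] == 1
    assert items == [1, 2, 3]      # nothing is lost for the second pass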

python/ray/data/_internal/execution/operators/map_operator.py

Lines changed: 7 additions & 9 deletions
@@ -48,13 +48,14 @@
 )
 from ray.data._internal.execution.util import memory_string
 from ray.data._internal.stats import StatsDict
-from ray.data._internal.util import MemoryProfiler, unify_ref_bundles_schema
+from ray.data._internal.util import MemoryProfiler
 from ray.data.block import (
     Block,
     BlockAccessor,
     BlockExecStats,
     BlockMetadataWithSchema,
     BlockStats,
+    _take_first_non_empty_schema,
     to_stats,
 )
 from ray.data.context import DataContext
@@ -541,8 +542,6 @@ def _map_task(
     A generator of blocks, followed by the list of BlockMetadata for the blocks
     as the last generator return.
     """
-    from ray.data.block import BlockMetadataWithSchema
-
     logger.debug(
         "Executing map task of operator %s with task index %d",
         ctx.op_name,
@@ -662,14 +661,13 @@ def _get_bundle_size(bundle: RefBundle):
 def _merge_ref_bundles(*bundles: RefBundle) -> RefBundle:
     """Merge N ref bundles into a single bundle of multiple blocks."""
     # Check that at least one bundle is non-null.
-    assert any(bundle is not None for bundle in bundles)
+    bundles = [bundle for bundle in bundles if bundle is not None]
+    assert len(bundles) > 0
     blocks = list(
-        itertools.chain(
-            block for bundle in bundles if bundle is not None for block in bundle.blocks
-        )
+        itertools.chain(block for bundle in bundles for block in bundle.blocks)
     )
-    owns_blocks = all(bundle.owns_blocks for bundle in bundles if bundle is not None)
-    schema = unify_ref_bundles_schema(bundles)
+    owns_blocks = all(bundle.owns_blocks for bundle in bundles)
+    schema = _take_first_non_empty_schema(bundle.schema for bundle in bundles)
     return RefBundle(blocks, owns_blocks=owns_blocks, schema=schema)
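
The `_merge_ref_bundles` change is a readability refactor: filter out `None` inputs once up front, then let every later expression assume a clean list instead of repeating the `is not None` guard. A generic sketch of the same pattern:

    import itertools


    def merge(*maybe_seqs):
        # Filter once; everything below can assume non-None inputs.
        seqs = [s for s in maybe_seqs if s is not None]
        assert len(seqs) > 0
        return list(itertools.chain.from_iterable(seqs))


    assert merge([1, 2], None, [3]) == [1, 2, 3]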

python/ray/data/_internal/execution/streaming_executor_state.py

Lines changed: 7 additions & 4 deletions
@@ -284,7 +284,10 @@ def add_output(self, ref: RefBundle) -> None:
         """Move a bundle produced by the operator to its outqueue."""
 
         ref, diverged = dedupe_schemas_with_validation(
-            self._schema, ref, warn=not self._warned_on_schema_divergence
+            self._schema,
+            ref,
+            warn=not self._warned_on_schema_divergence,
+            enforce_schemas=self.op.data_context.enforce_schemas,
         )
         self._schema = ref.schema
         self._warned_on_schema_divergence |= diverged
@@ -756,7 +759,7 @@ def dedupe_schemas_with_validation(
     old_schema: Optional["Schema"],
     bundle: "RefBundle",
     warn: bool = True,
-    allow_divergent: bool = False,
+    enforce_schemas: bool = False,
 ) -> Tuple["RefBundle", bool]:
     """Unify/Dedupe two schemas, warning if warn=True
 
@@ -765,7 +768,7 @@
         the new schema will be used as the old schema.
         bundle: The new `RefBundle` to unify with the old schema.
         warn: Raise a warning if the schemas diverge.
-        allow_divergent: If `True`, allow the schemas to diverge and return unified schema.
+        enforce_schemas: If `True`, allow the schemas to diverge and return unified schema.
             If `False`, but keep the old schema.
 
     Returns:
@@ -792,7 +795,7 @@
             f"than the previous one. Previous schema: {old_schema}, "
             f"new schema: {bundle.schema}. This may lead to unexpected behavior."
         )
-        if allow_divergent:
+        if enforce_schemas:
            old_schema = unify_schemas_with_validation([old_schema, bundle.schema])
 
     return (
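
Besides the rename, the call site now threads `self.op.data_context.enforce_schemas` through instead of always relying on the default. A simplified sketch of the behavior the flag controls (not the real Ray implementation; `pa.unify_schemas` stands in for `unify_schemas_with_validation`):

    from typing import Optional, Tuple

    import pyarrow as pa


    def dedupe_schemas(
        old: Optional[pa.Schema], new: pa.Schema, enforce_schemas: bool = False
    ) -> Tuple[pa.Schema, bool]:
        # Returns (schema_to_keep, diverged).
        if old is None or old.equals(new):
            return new, False
        if enforce_schemas:
            return pa.unify_schemas([old, new]), True  # unify on divergence
        return old, True  # diverged, but keep the previous schema


    a = pa.schema([("x", pa.int64())])
    b = pa.schema([("x", pa.int64()), ("y", pa.string())])
    assert dedupe_schemas(a, a) == (a, False)
    assert dedupe_schemas(a, b, enforce_schemas=True)[0].equals(b)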

python/ray/data/_internal/logical/operators/from_operators.py

Lines changed: 8 additions & 5 deletions
@@ -4,8 +4,12 @@
 
 from ray.data._internal.execution.interfaces import RefBundle
 from ray.data._internal.logical.interfaces import LogicalOperator, SourceOperator
-from ray.data._internal.util import unify_block_metadata_schema
-from ray.data.block import Block, BlockMetadata, BlockMetadataWithSchema
+from ray.data._internal.util import unify_ref_bundles_schema
+from ray.data.block import (
+    Block,
+    BlockMetadata,
+    BlockMetadataWithSchema,
+)
 from ray.types import ObjectRef
 
 if TYPE_CHECKING:
@@ -28,12 +32,11 @@ def __init__(
             len(input_metadata),
         )
         # `owns_blocks` is False because this op may be shared by multiple Datasets.
-        self._schema = unify_block_metadata_schema(input_metadata)
         self._input_data = [
             RefBundle(
                 [(input_blocks[i], input_metadata[i])],
                 owns_blocks=False,
-                schema=self._schema,
+                schema=input_metadata[i].schema,
             )
             for i in range(len(input_blocks))
         ]
@@ -71,7 +74,7 @@ def infer_metadata(self) -> BlockMetadata:
         return self._cached_output_metadata
 
     def infer_schema(self):
-        return self._schema
+        return unify_ref_bundles_schema(self._input_data)
 
     def is_lineage_serializable(self) -> bool:
         # This operator isn't serializable because it contains ObjectRefs.
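
The net effect in this operator is an eager-to-lazy shift: instead of unifying all input schemas in `__init__` (paying the cost even when no caller asks for it), each `RefBundle` keeps its own schema and unification is deferred to `infer_schema()`. A generic sketch of the pattern (not Ray's actual operator):

    import pyarrow as pa


    class LazySchemaSource:
        def __init__(self, schemas):
            self._schemas = list(schemas)  # cheap: no unification up front

        def infer_schema(self) -> pa.Schema:
            # Pay the unification cost only when a schema is actually needed.
            return pa.unify_schemas(self._schemas)


    src = LazySchemaSource([pa.schema([("a", pa.int64())])] * 3)
    assert src.infer_schema().equals(pa.schema([("a", pa.int64())]))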
