Commit efc5ac5

Parent: d96ac25
[Datasets] [Out-of-Band Serialization: 1/3] Refactor LazyBlockList. (ray-project#23821)
This PR refactors `LazyBlockList` in service of out-of-band serialization (see the [mono-PR](ray-project#22616)) and is a precursor to an execution plan refactor (PR #2) and to adding the actual out-of-band serialization APIs (PR #3). This refactor includes:

1. `ReadTask`s are now a first-class concept, replacing the internal `calls` list.
2. Read-stage progress tracking is consolidated into `LazyBlockList._get_blocks_with_metadata()`, and more of the read-task complexity (e.g. the read remote function) is pushed into `LazyBlockList`, making `ray.data.read_datasource()` simpler.
3. Tasks are launched and metadata is fetched and cached more intelligently: `.iter_blocks_with_metadata()` now fetches metadata for completed read tasks instead of relying on the less accurate pre-read metadata, and some small bugs in the lazy ramp-up around progressive metadata fetching are fixed.

Item (1) is the most important for supporting out-of-band serialization, and it fundamentally changes the `LazyBlockList` data model: we need to be able to reference the underlying read tasks both when rewriting read stages during optimization and when serializing the lineage of the Dataset. See the [mono-PR](ray-project#22616) for more context.

Other changes:

1. Changed the stats actor to a global named actor singleton, obviating the need to serialize the actor handle with the Dataset stats; without this, we were encountering serialization failures.
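
To make item (1) concrete, here is a minimal sketch of the data model a first-class read task implies: a deferred read function paired with the metadata known before the read runs. This is an illustrative simplification, not the actual `ReadTask` class in `ray.data.datasource` (which carries additional state).

```python
from dataclasses import dataclass
from typing import Callable, Iterable

from ray.data.block import Block, BlockMetadata


@dataclass
class SimpleReadTask:
    """Illustrative stand-in for a first-class read task."""

    # Deferred function that produces the block(s) when invoked.
    read_fn: Callable[[], Iterable[Block]]
    # Metadata estimated before the read executes (row count, size, etc.).
    metadata: BlockMetadata

    def __call__(self) -> Iterable[Block]:
        return self.read_fn()
```

Because the stored object is the task itself, rather than an opaque lambda plus detached metadata, read stages can be rewritten during optimization and the Dataset's lineage can be serialized.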
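For the stats-actor change, the relevant pattern is get-or-create on a named actor, so every worker resolves the singleton by name instead of carrying a serialized handle. A rough sketch under assumed names (`_StatsActor`, `get_or_create_stats_actor`, and the actor name below are illustrative, not Ray's internals):

```python
import ray


@ray.remote(num_cpus=0)
class _StatsActor:
    """Illustrative singleton that records per-dataset stats."""

    def __init__(self):
        self._stats = {}

    def record(self, dataset_uuid, stats):
        self._stats[dataset_uuid] = stats


def get_or_create_stats_actor(name: str = "datasets_stats_actor"):
    # Resolve by name so no actor handle ever needs to be serialized
    # alongside the Dataset stats.
    try:
        return ray.get_actor(name)
    except ValueError:
        # Not created yet; create it once. (A real implementation would
        # also handle the race where two callers create concurrently.)
        return _StatsActor.options(name=name, lifetime="detached").remote()
```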

File tree

9 files changed: +454 −195 lines

python/ray/data/dataset.py

Lines changed: 22 additions & 15 deletions
@@ -57,6 +57,7 @@
     ParquetDatasource,
     BlockWritePathProvider,
     DefaultBlockWritePathProvider,
+    ReadTask,
     WriteResult,
 )
 from ray.data.datasource.file_based_datasource import (
@@ -988,26 +989,28 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":

         start_time = time.perf_counter()
         context = DatasetContext.get_current()
-        calls: List[Callable[[], ObjectRef[BlockPartition]]] = []
-        metadata: List[BlockPartitionMetadata] = []
-        block_partitions: List[ObjectRef[BlockPartition]] = []
+        tasks: List[ReadTask] = []
+        block_partition_refs: List[ObjectRef[BlockPartition]] = []
+        block_partition_meta_refs: List[ObjectRef[BlockPartitionMetadata]] = []

         datasets = [self] + list(other)
         for ds in datasets:
             bl = ds._plan.execute()
             if isinstance(bl, LazyBlockList):
-                calls.extend(bl._calls)
-                metadata.extend(bl._metadata)
-                block_partitions.extend(bl._block_partitions)
+                tasks.extend(bl._tasks)
+                block_partition_refs.extend(bl._block_partition_refs)
+                block_partition_meta_refs.extend(bl._block_partition_meta_refs)
             else:
-                calls.extend([None] * bl.initial_num_blocks())
-                metadata.extend(bl._metadata)
+                tasks.extend([ReadTask(lambda: None, meta) for meta in bl._metadata])
                 if context.block_splitting_enabled:
-                    block_partitions.extend(
+                    block_partition_refs.extend(
                         [ray.put([(b, m)]) for b, m in bl.get_blocks_with_metadata()]
                     )
                 else:
-                    block_partitions.extend(bl.get_blocks())
+                    block_partition_refs.extend(bl.get_blocks())
+                    block_partition_meta_refs.extend(
+                        [ray.put(meta) for meta in bl._metadata]
+                    )

         epochs = [ds._get_epoch() for ds in datasets]
         max_epoch = max(*epochs)
@@ -1028,7 +1031,8 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":
         dataset_stats.time_total_s = time.perf_counter() - start_time
         return Dataset(
             ExecutionPlan(
-                LazyBlockList(calls, metadata, block_partitions), dataset_stats
+                LazyBlockList(tasks, block_partition_refs, block_partition_meta_refs),
+                dataset_stats,
             ),
             max_epoch,
             self._lazy,
@@ -2548,6 +2552,7 @@ def repeat(self, times: Optional[int] = None) -> "DatasetPipeline[T]":
         # to enable fusion with downstream map stages.
         ctx = DatasetContext.get_current()
         if self._plan._is_read_stage() and ctx.optimize_fuse_read_stages:
+            self._plan._in_blocks.clear()
             blocks, read_stage = self._plan._rewrite_read_stage()
             outer_stats = DatasetStats(stages={}, parent=None)
         else:
@@ -2666,6 +2671,7 @@ def window(
         # to enable fusion with downstream map stages.
         ctx = DatasetContext.get_current()
         if self._plan._is_read_stage() and ctx.optimize_fuse_read_stages:
+            self._plan._in_blocks.clear()
             blocks, read_stage = self._plan._rewrite_read_stage()
             outer_stats = DatasetStats(stages={}, parent=None)
         else:
@@ -2749,12 +2755,13 @@ def fully_executed(self) -> "Dataset[T]":
         Returns:
             A Dataset with all blocks fully materialized in memory.
         """
-        blocks = self.get_internal_block_refs()
-        bar = ProgressBar("Force reads", len(blocks))
-        bar.block_until_complete(blocks)
+        blocks, metadata = [], []
+        for b, m in self._plan.execute().get_blocks_with_metadata():
+            blocks.append(b)
+            metadata.append(m)
         ds = Dataset(
             ExecutionPlan(
-                BlockList(blocks, self._plan.execute().get_metadata()),
+                BlockList(blocks, metadata),
                 self._plan.stats(),
                 dataset_uuid=self._get_uuid(),
             ),
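
At the API level, the behaviors touched above can be exercised as follows (a hedged usage sketch assuming a local Ray session; the block counts are illustrative):

```python
import ray

ray.init()

# union() of two lazily-read datasets now concatenates their read tasks
# and block-partition refs, rather than the old calls/metadata lists.
ds = ray.data.range(8).union(ray.data.range(8))

# fully_executed() now collects blocks together with their post-read
# metadata instead of re-fetching metadata after forcing the reads.
materialized = ds.fully_executed()
print(materialized.count())  # 16
```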

python/ray/data/impl/block_list.py

Lines changed: 3 additions & 31 deletions
@@ -1,15 +1,11 @@
 import math
-from typing import List, Iterator, Tuple, Any, Union, Optional, TYPE_CHECKING
-
-if TYPE_CHECKING:
-    import pyarrow
+from typing import List, Iterator, Tuple

 import numpy as np

 import ray
 from ray.types import ObjectRef
-from ray.data.block import Block, BlockMetadata, BlockAccessor
-from ray.data.impl.remote_fn import cached_remote_fn
+from ray.data.block import Block, BlockMetadata


 class BlockList:
@@ -26,11 +22,7 @@ def __init__(self, blocks: List[ObjectRef[Block]], metadata: List[BlockMetadata]
         self._num_blocks = len(self._blocks)
         self._metadata: List[BlockMetadata] = metadata

-    def set_metadata(self, i: int, metadata: BlockMetadata) -> None:
-        """Set the metadata for a given block."""
-        self._metadata[i] = metadata
-
-    def get_metadata(self) -> List[BlockMetadata]:
+    def get_metadata(self, fetch_if_missing: bool = False) -> List[BlockMetadata]:
         """Get the metadata for all blocks."""
         return self._metadata.copy()

@@ -182,23 +174,3 @@ def executed_num_blocks(self) -> int:
         doesn't know how many blocks will be produced until tasks finish.
         """
         return len(self.get_blocks())
-
-    def ensure_schema_for_first_block(self) -> Optional[Union["pyarrow.Schema", type]]:
-        """Ensure that the schema is set for the first block.
-
-        Returns None if the block list is empty.
-        """
-        get_schema = cached_remote_fn(_get_schema)
-        try:
-            block = next(self.iter_blocks())
-        except (StopIteration, ValueError):
-            # Dataset is empty (no blocks) or was manually cleared.
-            return None
-        schema = ray.get(get_schema.remote(block))
-        # Set the schema.
-        self._metadata[0].schema = schema
-        return schema
-
-
-def _get_schema(block: Block) -> Any:
-    return BlockAccessor.for_block(block).schema()
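
For the eager base class above, the new `fetch_if_missing` flag is a no-op, since `BlockList` metadata is always materialized; it exists so that lazy implementations can decide whether to block on outstanding metadata fetches. A rough sketch of that contract (illustrative only; `_fetch_missing_metadata()` is a hypothetical helper, not Ray's `LazyBlockList` code):

```python
from typing import List, Optional

from ray.data.block import BlockMetadata


class LazyMetadataExample:
    """Illustrative lazy block list with possibly-missing metadata."""

    def __init__(self, cached: List[Optional[BlockMetadata]]):
        # Entries are None until the corresponding read task's metadata
        # has been fetched.
        self._cached_metadata = cached

    def get_metadata(self, fetch_if_missing: bool = False) -> List[BlockMetadata]:
        if fetch_if_missing and any(m is None for m in self._cached_metadata):
            # Block until all outstanding metadata fetches complete.
            self._fetch_missing_metadata()
        # Callers that pass fetch_if_missing=False tolerate None gaps.
        return list(self._cached_metadata)

    def _fetch_missing_metadata(self) -> None:
        ...  # details elided in this sketch
```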
