|
19 | 19 | from dataclasses import dataclass, field |
20 | 20 | from enum import StrEnum, auto |
21 | 21 | from itertools import groupby, islice |
22 | | -from typing import Any |
| 22 | +from typing import Any, Protocol |
23 | 23 |
|
24 | 24 | import msgspec |
25 | 25 | from iris.env_resources import TaskResources as _TaskResources |
|
51 | 51 | logger = logging.getLogger(__name__) |
52 | 52 |
|
53 | 53 |
|
| 54 | +# --------------------------------------------------------------------------- |
| 55 | +# Shard protocol |
| 56 | +# --------------------------------------------------------------------------- |
| 57 | + |
| 58 | + |
| 59 | +class Shard(Protocol):  # structural interface; not decorated @runtime_checkable here, so isinstance(x, Shard) is not supported
| 60 | + """Protocol for a shard of data assigned to a single worker.
| 61 | +

| 62 | + Implementations:
| 63 | + - ListShard: backed by iterable references (source data, non-scatter)
| 64 | + - ScatterShard: backed by scatter Parquet files with predicate pushdown
| 65 | + """
| 66 | +
| 67 | + def __iter__(self) -> Iterator: ...  # presumably iterates the shard's items; confirm against ListShard/ScatterShard
| 68 | + def get_iterators(self) -> Iterator[Iterator]: ...  # presumably one iterator per underlying chunk, as consumed by _merge_sorted_chunks; TODO confirm
| 69 | + |
| 70 | + |
54 | 71 | @dataclass |
55 | 72 | class SourceItem: |
56 | 73 | """A source item with its shard assignment. |
@@ -568,7 +585,7 @@ def make_windows( |
568 | 585 |
|
569 | 586 |
|
570 | 587 | def _merge_sorted_chunks( |
571 | | - shard, key_fn: Callable, sort_fn: Callable | None = None, external_sort_dir: str | None = None |
| 588 | + shard: Shard, key_fn: Callable, sort_fn: Callable | None = None, external_sort_dir: str | None = None |
572 | 589 | ) -> Iterator[tuple[object, Iterator]]: |
573 | 590 | """Merge sorted chunks using k-way merge, yielding (key, items_iterator) groups. |
574 | 591 |
|
@@ -598,10 +615,11 @@ def merge_key(item): |
598 | 615 |
|
599 | 616 | # Check if external sort is needed BEFORE materializing all iterators. |
600 | 617 | # ScatterShard can decide using manifest stats (no file opens needed). |
| 618 | + from zephyr.shuffle import ScatterShard |
| 619 | + |
601 | 620 | use_external = ( |
602 | 621 | external_sort_dir is not None |
603 | | - and hasattr(shard, "needs_external_sort") |
604 | | - and hasattr(shard, "get_iterators") |
| 622 | + and isinstance(shard, ScatterShard) |
605 | 623 | and shard.needs_external_sort(_TaskResources.from_environment().memory_bytes) |
606 | 624 | ) |
607 | 625 |
|
|
0 commit comments