Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 113 additions & 54 deletions lib/zephyr/src/zephyr/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,118 @@
import re

import msgspec
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq

from zephyr import counters
from zephyr.expr import Expr

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Shared Parquet row-group reader
# ---------------------------------------------------------------------------


def _check_row_group_statistics(
    rg_meta: pq.RowGroupMetaData,
    equality_predicates: dict[str, object],
) -> bool:
    """Decide whether a row group could hold rows matching every predicate.

    Walks the column-chunk metadata of ``rg_meta`` and compares each
    predicate's target value against the stored min/max statistics.
    Returns ``False`` only when some column's statistics prove the target
    value lies outside its range; columns absent from
    ``equality_predicates``, or lacking usable statistics, never
    disqualify the row group.
    """
    for idx in range(rg_meta.num_columns):
        chunk_meta = rg_meta.column(idx)
        column_name = chunk_meta.path_in_schema
        if column_name not in equality_predicates:
            continue
        stats = chunk_meta.statistics
        # Without min/max statistics we must conservatively keep the group.
        if stats is None or not stats.has_min_max:
            continue
        target = equality_predicates[column_name]
        if target < stats.min or stats.max < target:
            return False
    return True


def iter_parquet_row_groups(
    source: str | pq.ParquetFile,
    *,
    columns: list[str] | None = None,
    row_filter: pc.Expression | None = None,
    row_start: int | None = None,
    row_end: int | None = None,
    equality_predicates: dict[str, object] | None = None,
) -> Iterator[pa.Table]:
    """Yield one ``pa.Table`` per qualifying row group with O(row_group) memory.

    Uses ``pq.ParquetFile`` instead of ``pyarrow.dataset`` to avoid the
    upstream memory leak (https://github.com/apache/arrow/issues/39808).

    Args:
        source: Path to parquet file or an already-open ``pq.ParquetFile``.
        columns: Columns to read (``None`` for all).
        row_filter: PyArrow compute expression applied after reading.
            When ``columns`` is specified, any additional columns needed by
            the filter are read automatically and stripped after filtering.
        row_start: First row to include (inclusive, before filtering).
            Defaults to 0 when only ``row_end`` is given.
        row_end: Last row to include (exclusive, before filtering).
            Defaults to the file's total row count when only ``row_start``
            is given.
        equality_predicates: Column-value pairs for statistics-based row group
            skipping. Row groups whose min/max statistics exclude the target
            value are not read at all.
    """
    pf = pq.ParquetFile(source) if isinstance(source, str) else source

    # Accept one-sided ranges instead of silently ignoring them: fill in
    # the missing bound so downstream logic always sees a full interval.
    has_row_range = row_start is not None or row_end is not None
    if has_row_range:
        row_start = 0 if row_start is None else row_start
        row_end = pf.metadata.num_rows if row_end is None else row_end

    # When both columns and row_filter are set, the filter may reference
    # columns not in the projection. Read the union and project after.
    # Field names are detected with word-boundary matching on the rendered
    # expression so that e.g. column "id" does not falsely match "user_id".
    # This is a heuristic on the expression's string form; a false positive
    # only costs an extra column read, never a wrong result.
    read_columns = columns
    need_project = False
    if columns is not None and row_filter is not None:
        filter_text = str(row_filter)
        all_schema_names = {pf.schema_arrow.field(i).name for i in range(len(pf.schema_arrow))}
        filter_cols = {
            name
            for name in all_schema_names
            if re.search(rf"\b{re.escape(name)}\b", filter_text)
        } - set(columns)
        if filter_cols:
            read_columns = list(columns) + sorted(filter_cols)
            need_project = True

    cumulative_rows = 0

    for i in range(pf.metadata.num_row_groups):
        rg_meta = pf.metadata.row_group(i)
        rg_num_rows = rg_meta.num_rows
        rg_start = cumulative_rows
        rg_end = cumulative_rows + rg_num_rows
        cumulative_rows = rg_end

        # Statistics-based skip: never reads data for excluded groups.
        if equality_predicates and not _check_row_group_statistics(rg_meta, equality_predicates):
            continue

        if has_row_range:
            if rg_end <= row_start:
                continue  # entirely before the requested range
            if rg_start >= row_end:
                return  # past the range; no later group can overlap

        table = pf.read_row_group(i, columns=read_columns)

        if has_row_range:
            is_interior = rg_start >= row_start and rg_end <= row_end
            if not is_interior:
                # Boundary row group: keep only the in-range slice.
                local_start = max(0, row_start - rg_start)
                local_end = min(rg_num_rows, row_end - rg_start)
                table = table.slice(local_start, local_end - local_start)

        if row_filter is not None:
            table = table.filter(row_filter)

        if need_project:
            table = table.select(columns)

        if len(table) > 0:
            yield table


# 16 MB read blocks with background prefetch for S3/remote reads.
_READ_BLOCK_SIZE = 16_000_000
_READ_CACHE_TYPE = "background"
Expand Down Expand Up @@ -163,71 +269,24 @@ def load_parquet(source: str | InputFileSpec) -> Iterator[dict]:
... )
>>> output_files = list(ctx.execute(ds))
"""
import pyarrow.dataset as pads

spec = _as_spec(source)
logger.info("Loading: %s", spec.path)
columns = spec.columns

pa_filter = None
if spec.filter_expr is not None:
from zephyr.expr import to_pyarrow_expr

pa_filter = to_pyarrow_expr(spec.filter_expr)

dataset = pads.dataset(spec.path, format="parquet")

# Handle empty parquet files (no data columns in schema)
schema_names = dataset.schema.names
if not schema_names:
return

if spec.row_start is not None and spec.row_end is not None:
# Row range first: select rows by position, then apply filter
cumulative_rows = 0
for fragment in dataset.get_fragments():
for rg_fragment in fragment.split_by_row_group():
# Get row group size from RowGroupInfo (no data read)
assert len(rg_fragment.row_groups) == 1
rg_info = rg_fragment.row_groups[0]
rg_num_rows = rg_info.num_rows
rg_start = cumulative_rows
rg_end = cumulative_rows + rg_num_rows

if rg_end > spec.row_start and rg_start < spec.row_end:
is_interior = rg_start >= spec.row_start and rg_end <= spec.row_end

if is_interior:
# Entirely within range: push filter down, yield all
table = rg_fragment.to_table(columns=columns, filter=pa_filter)
counters.increment("zephyr/records_in", len(table))
yield from table.to_pylist()
else:
# Boundary row group: slice first, then filter
table = rg_fragment.to_table(columns=columns)
local_start = max(0, spec.row_start - rg_start)
local_end = min(rg_num_rows, spec.row_end - rg_start)
sliced = table.slice(local_start, local_end - local_start)

if pa_filter is not None:
filtered = sliced.filter(pa_filter)
counters.increment("zephyr/records_in", len(filtered))
yield from filtered.to_pylist()
else:
counters.increment("zephyr/records_in", len(sliced))
yield from sliced.to_pylist()

cumulative_rows = rg_end
if cumulative_rows >= spec.row_end:
return
elif pa_filter is not None:
table = dataset.to_table(columns=columns, filter=pa_filter)
for table in iter_parquet_row_groups(
spec.path,
columns=spec.columns,
row_filter=pa_filter,
row_start=spec.row_start,
row_end=spec.row_end,
):
counters.increment("zephyr/records_in", len(table))
yield from table.to_pylist()
else:
for batch in dataset.to_batches(columns=columns):
counters.increment("zephyr/records_in", len(batch))
yield from batch.to_pylist()


def load_vortex(source: str | InputFileSpec) -> Iterator[dict]:
Expand Down
46 changes: 25 additions & 21 deletions lib/zephyr/src/zephyr/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
import cloudpickle
import fsspec
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as pad
import pyarrow.parquet as pq
from iris.env_resources import TaskResources as _TaskResources
from rigging.filesystem import open_url, url_to_fs
Expand Down Expand Up @@ -139,8 +137,9 @@ def _get_scatter_read_fs(num_files: int, sample_path: str, memory_fraction: floa
class ScatterParquetIterator:
"""Reference to sorted chunks for one target shard in one Parquet file.

Creates a ``pyarrow.dataset`` once (caching file metadata) and yields
lazy per-chunk iterators via Scanner with predicate pushdown.
Opens the file via ``pq.ParquetFile`` and uses Parquet row-group
statistics on ``(shard_idx, chunk_idx)`` for predicate pushdown,
avoiding the ``pyarrow.dataset`` memory leak (apache/arrow#39808).
"""

path: str
Expand All @@ -156,28 +155,33 @@ def __iter__(self) -> Iterator:
def get_chunk_iterators(self, batch_size: int = 1024) -> Iterator[Iterator]:
    """Yield one lazy iterator per sorted chunk.

    Opens the file once via ``pq.ParquetFile`` and uses row-group
    statistics to skip non-matching row groups (equivalent to dataset
    Scanner predicate pushdown for the scatter envelope columns).

    Args:
        batch_size: Unused. Retained for interface compatibility with the
            previous Scanner-based implementation; the row-group reader
            yields one table per row group instead of fixed-size batches.
    """
    _, fs_path = url_to_fs(self.path)
    pf = pq.ParquetFile(self.filesystem.open_input_file(fs_path))
    # Pick the payload column matching how the scatter writer encoded items.
    col = _ZEPHYR_SHUFFLE_PICKLED_COL if self.is_pickled else _ZEPHYR_SHUFFLE_ITEM_COL

    for chunk_idx in range(self.chunk_count):
        yield self._iter_chunk(pf, col, chunk_idx)

def _iter_chunk(self, pf: pq.ParquetFile, col: str, chunk_idx: int) -> Iterator:
from zephyr.readers import iter_parquet_row_groups

# The scatter writer writes one (shard_idx, chunk_idx) per row group,
# so equality_predicates on min/max statistics skip non-matching row
# groups without reading data — equivalent to dataset predicate pushdown.
for table in iter_parquet_row_groups(
pf,
columns=[col],
equality_predicates={
_ZEPHYR_SHUFFLE_SHARD_IDX_COL: self.shard_idx,
_ZEPHYR_SHUFFLE_CHUNK_IDX_COL: chunk_idx,
},
):
items = table.column(col).to_pylist()
if self.is_pickled:
yield from (pickle.loads(b) for b in items)
else:
Expand Down
Loading
Loading