Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 33 additions & 51 deletions lib/zephyr/src/zephyr/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import msgspec
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq

from zephyr import counters
from zephyr.expr import Expr
Expand Down Expand Up @@ -163,8 +164,6 @@ def load_parquet(source: str | InputFileSpec) -> Iterator[dict]:
... )
>>> output_files = list(ctx.execute(ds))
"""
import pyarrow.dataset as pads

spec = _as_spec(source)
logger.info("Loading: %s", spec.path)
columns = spec.columns
Expand All @@ -175,59 +174,42 @@ def load_parquet(source: str | InputFileSpec) -> Iterator[dict]:

pa_filter = to_pyarrow_expr(spec.filter_expr)

dataset = pads.dataset(spec.path, format="parquet")
# Read row-group-by-row-group via ParquetFile to avoid the
# pyarrow.dataset API which loads the entire file into Arrow's memory
# pool upfront, causing multi-GB RSS bloat on large files.
# See: https://github.com/apache/arrow/issues/39808
pf = pq.ParquetFile(spec.path)
has_row_range = spec.row_start is not None and spec.row_end is not None
cumulative_rows = 0

for i in range(pf.metadata.num_row_groups):
rg_num_rows = pf.metadata.row_group(i).num_rows
rg_start = cumulative_rows
rg_end = cumulative_rows + rg_num_rows
cumulative_rows = rg_end

if has_row_range:
assert spec.row_start is not None and spec.row_end is not None
if rg_end <= spec.row_start:
continue
if rg_start >= spec.row_end:
return

# Handle empty parquet files (no data columns in schema)
schema_names = dataset.schema.names
if not schema_names:
return
table = pf.read_row_group(i, columns=columns)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Apply filter before projecting row-group columns

load_parquet now reads each row group with columns=columns and applies table.filter(pa_filter) afterward. In pushed-down pipelines that combine filter + select (e.g. .filter(col("score") > 70).select("id")), _compute_file_pushdown can pass columns=["id"] while the predicate still references score, so this path drops the predicate column before filtering and causes table.filter to fail (or mis-evaluate). The prior dataset.to_table(columns=..., filter=...) flow did not have this ordering problem because the scanner could read predicate columns without projecting them to output.

Useful? React with 👍 / 👎.


if has_row_range:
assert spec.row_start is not None and spec.row_end is not None
is_interior = rg_start >= spec.row_start and rg_end <= spec.row_end
if not is_interior:
local_start = max(0, spec.row_start - rg_start)
local_end = min(rg_num_rows, spec.row_end - rg_start)
table = table.slice(local_start, local_end - local_start)

if pa_filter is not None:
table = table.filter(pa_filter)

if spec.row_start is not None and spec.row_end is not None:
# Row range first: select rows by position, then apply filter
cumulative_rows = 0
for fragment in dataset.get_fragments():
for rg_fragment in fragment.split_by_row_group():
# Get row group size from RowGroupInfo (no data read)
assert len(rg_fragment.row_groups) == 1
rg_info = rg_fragment.row_groups[0]
rg_num_rows = rg_info.num_rows
rg_start = cumulative_rows
rg_end = cumulative_rows + rg_num_rows

if rg_end > spec.row_start and rg_start < spec.row_end:
is_interior = rg_start >= spec.row_start and rg_end <= spec.row_end

if is_interior:
# Entirely within range: push filter down, yield all
table = rg_fragment.to_table(columns=columns, filter=pa_filter)
counters.increment("zephyr/records_in", len(table))
yield from table.to_pylist()
else:
# Boundary row group: slice first, then filter
table = rg_fragment.to_table(columns=columns)
local_start = max(0, spec.row_start - rg_start)
local_end = min(rg_num_rows, spec.row_end - rg_start)
sliced = table.slice(local_start, local_end - local_start)

if pa_filter is not None:
filtered = sliced.filter(pa_filter)
counters.increment("zephyr/records_in", len(filtered))
yield from filtered.to_pylist()
else:
counters.increment("zephyr/records_in", len(sliced))
yield from sliced.to_pylist()

cumulative_rows = rg_end
if cumulative_rows >= spec.row_end:
return
elif pa_filter is not None:
table = dataset.to_table(columns=columns, filter=pa_filter)
counters.increment("zephyr/records_in", len(table))
yield from table.to_pylist()
else:
for batch in dataset.to_batches(columns=columns):
counters.increment("zephyr/records_in", len(batch))
yield from batch.to_pylist()


def load_vortex(source: str | InputFileSpec) -> Iterator[dict]:
Expand Down
99 changes: 99 additions & 0 deletions lib/zephyr/tests/test_readers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Tests for parquet reader (load_parquet)."""


import pyarrow as pa
import pyarrow.parquet as pq

from zephyr.expr import ColumnExpr, CompareExpr, LiteralExpr
from zephyr.readers import InputFileSpec, load_parquet


def _write_test_parquet(path: str, records: list[dict], row_group_size: int = 2) -> None:
    """Persist *records* to *path* as parquet, split into deliberately tiny row groups.

    Small groups force load_parquet to exercise its per-row-group logic
    (row-range slicing at group boundaries) even with only a handful of rows.
    """
    pq.write_table(pa.Table.from_pylist(records), path, row_group_size=row_group_size)


RECORDS = [{"id": i, "name": f"row{i}", "score": float(i * 10)} for i in range(10)]


def test_load_parquet_plain(tmp_path):
    """A bare path string round-trips every record unchanged."""
    target = str(tmp_path / "data.parquet")
    _write_test_parquet(target, RECORDS)

    assert list(load_parquet(target)) == RECORDS


def test_load_parquet_columns(tmp_path):
    """Column projection keeps only the requested fields, in every row."""
    target = str(tmp_path / "data.parquet")
    _write_test_parquet(target, RECORDS)

    rows = list(load_parquet(InputFileSpec(path=target, columns=["id", "name"])))

    expected = [{"id": rec["id"], "name": rec["name"]} for rec in RECORDS]
    assert rows == expected


def test_load_parquet_row_range(tmp_path):
    """row_start/row_end select a half-open positional window.

    row_group_size=3 puts the window boundaries (2 and 7) inside row groups,
    exercising the boundary-slicing path rather than whole-group reads.
    """
    target = str(tmp_path / "data.parquet")
    _write_test_parquet(target, RECORDS, row_group_size=3)

    rows = list(load_parquet(InputFileSpec(path=target, row_start=2, row_end=7)))

    assert [row["id"] for row in rows] == [2, 3, 4, 5, 6]


def test_load_parquet_filter(tmp_path):
    """A pushed-down comparison predicate drops non-matching rows."""
    target = str(tmp_path / "data.parquet")
    _write_test_parquet(target, RECORDS)

    predicate = CompareExpr(op="ge", left=ColumnExpr(name="score"), right=LiteralExpr(value=50.0))
    rows = list(load_parquet(InputFileSpec(path=target, filter_expr=predicate)))

    # score >= 50.0 matches exactly ids 5..9.
    assert [row["id"] for row in rows] == [5, 6, 7, 8, 9]
    assert all(row["score"] >= 50.0 for row in rows)


def test_load_parquet_filter_and_row_range(tmp_path):
    """Row range and filter compose: positional slice first, predicate second."""
    target = str(tmp_path / "data.parquet")
    _write_test_parquet(target, RECORDS, row_group_size=3)

    predicate = CompareExpr(op="ge", left=ColumnExpr(name="score"), right=LiteralExpr(value=50.0))
    spec = InputFileSpec(path=target, row_start=1, row_end=8, filter_expr=predicate)
    rows = list(load_parquet(spec))

    # Window is ids 1..7; of those, only 5, 6, 7 have score >= 50.
    assert [row["id"] for row in rows] == [5, 6, 7]


def test_load_parquet_empty(tmp_path):
    """A parquet file that has a schema but zero rows yields nothing."""
    target = str(tmp_path / "empty.parquet")
    empty = pa.Table.from_pylist([], schema=pa.schema([("id", pa.int64())]))
    pq.write_table(empty, target)

    assert list(load_parquet(target)) == []


def test_load_parquet_no_dataset_api(tmp_path, monkeypatch):
    """load_parquet must stay off pyarrow.dataset entirely.

    Any cached module is dropped and the sys.modules entry is poisoned with
    None, so an attempted (re-)import inside load_parquet raises ImportError.
    """
    import sys

    target = str(tmp_path / "data.parquet")
    _write_test_parquet(target, RECORDS)

    sys.modules.pop("pyarrow.dataset", None)
    monkeypatch.setitem(sys.modules, "pyarrow.dataset", None)

    # Reading must still work without the dataset API available.
    assert len(list(load_parquet(target))) == len(RECORDS)
Loading