Commit 82bb2fa

ravwojdyla and claude committed
zephyr: widen inferred parquet schema via pa.unify_schemas
``_accumulate_tables`` infers its schema from the first micro-batch (``_MICRO_BATCH_SIZE=8``). If those first records happen to have ``None`` for a field — or to lack a field that appears later — downstream batches that would legitimately widen the schema either crashed with ``ArrowInvalid: Invalid null value`` or (in the new-field case) were silently truncated by ``pa.Table.from_pylist``.

Unify-widen the inferred schema on mismatch and reconcile chunks on yield via ``concat_tables(promote_options="permissive")``. Surface genuine incompatibilities (e.g. int vs string) as errors showing both schemas and the inference origin, so operators can diagnose without extra instrumentation. An explicit caller-provided schema is treated as a contract: mismatches raise without silent widening.

Tests cover: null→concrete widening, new-field-appears-later (previously silently dropped), and int-vs-string conflict surfacing.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent ef51f83 commit 82bb2fa
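
The failure mode and the fix can both be reproduced with public PyArrow APIs alone. Below is a minimal sketch, assuming a recent PyArrow with ``promote_options`` support; the batch of 8 mirrors ``_MICRO_BATCH_SIZE``, and the variable names are illustrative rather than taken from the commit:

```python
import pyarrow as pa

# First micro-batch: every record has None for "x", so schema inference
# pins the column to the null type.
first = [{"x": None}] * 8
inferred = pa.Table.from_pylist(first).schema  # x: null

# A later batch carries a concrete value; forcing it into the narrow
# inferred schema raises ArrowInvalid ("Invalid null value").
later = [{"x": "hello"}]
try:
    pa.Table.from_pylist(later, schema=inferred)
except pa.ArrowInvalid as e:
    print(f"narrow schema rejected the batch: {e}")

# The commit's approach: unify the inferred schema with the batch's own
# schema, rebuild the batch against the widened schema, and reconcile
# the earlier null-typed chunk permissively at concat time.
widened = pa.unify_schemas([inferred, pa.Table.from_pylist(later).schema])
chunks = [
    pa.Table.from_pylist(first, schema=inferred),
    pa.Table.from_pylist(later, schema=widened),
]
combined = pa.concat_tables(chunks, promote_options="permissive")
assert combined.schema.field("x").type == pa.string()
assert combined.column("x").to_pylist() == [None] * 8 + ["hello"]
```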

2 files changed

Lines changed: 119 additions & 6 deletions


lib/zephyr/src/zephyr/writers.py

Lines changed: 70 additions & 4 deletions
```diff
@@ -171,28 +171,94 @@ def _accumulate_tables(
     Converts records to PyArrow in micro-batches of ``_MICRO_BATCH_SIZE``,
     tracks byte size incrementally, and yields a single ``concat_tables``
     result each time the threshold is reached.
+
+    When the caller did not pass an explicit schema, the schema is inferred
+    from the first micro-batch. If a later micro-batch doesn't fit that
+    schema — e.g. early rows pinned a column as ``null`` and a later row
+    supplies a concrete value, or a new top-level column appears — the
+    schemas are unified via :func:`pa.unify_schemas` and the batch is
+    rebuilt against the widened schema. On yield, prior chunks whose
+    schemas differ are reconciled via ``concat_tables(promote_options=
+    "permissive")``. Genuinely incompatible schemas (e.g. ``int`` vs
+    ``string`` for the same field) still raise, with both schemas shown.
+
+    An explicit caller-provided schema is treated as a contract: mismatches
+    raise without silent widening.
     """
     chunks: list[pa.Table] = []
     bytesize = 0
     convert: Callable | None = None
+    schema_inferred = schema is None
+
+    def _raise_schema_mismatch(e: Exception, dicts: list[dict[str, Any]]) -> None:
+        actual_schema = pa.Table.from_pylist(dicts).schema
+        origin = (
+            f"inferred from first {_MICRO_BATCH_SIZE} records (no explicit schema passed)"
+            if schema_inferred
+            else "explicitly provided by caller"
+        )
+        raise pa.ArrowInvalid(
+            f"Schema mismatch converting batch to Arrow: {e}\n"
+            f"Expected schema ({origin}):\n{schema}\n"
+            f"Got schema:\n{actual_schema}"
+        ) from e
+
+    def _build_table(dicts: list[dict[str, Any]], schema: pa.Schema) -> tuple[pa.Table, pa.Schema]:
+        """Convert *dicts* to a table under *schema*, widening via ``pa.unify_schemas`` when needed.
+
+        Returns ``(table, schema)`` where ``schema`` may be wider than the
+        input. Handles two kinds of divergence: (1) ``from_pylist`` raises
+        because a field's type doesn't fit, (2) ``from_pylist`` would
+        silently drop extra top-level keys (new fields appearing only in
+        later batches). Raises (via :func:`_raise_schema_mismatch`) when
+        *schema* was explicitly provided by the caller, or when the
+        divergence isn't representable as a widening (e.g. ``int`` vs
+        ``string``).
+        """
+        mismatch_error: Exception | None = None
+        try:
+            table = pa.Table.from_pylist(dicts, schema=schema)
+        except (pa.ArrowInvalid, pa.ArrowTypeError, pa.ArrowNotImplementedError) as e:
+            mismatch_error = e
+
+        if mismatch_error is None:
+            extra_keys = {k for d in dicts for k in d.keys()} - set(schema.names)
+            if not extra_keys:
+                return table, schema
+            mismatch_error = pa.ArrowInvalid(f"extra top-level keys not in schema: {sorted(extra_keys)}")
+
+        if not schema_inferred:
+            _raise_schema_mismatch(mismatch_error, dicts)
+        new_schema = pa.Table.from_pylist(dicts).schema
+        try:
+            widened = pa.unify_schemas([schema, new_schema])
+        except (pa.ArrowInvalid, pa.ArrowTypeError, pa.ArrowNotImplementedError):
+            _raise_schema_mismatch(mismatch_error, dicts)
+        return pa.Table.from_pylist(dicts, schema=widened), widened
 
     for micro_batch in batchify(records, n=_MICRO_BATCH_SIZE):
         if convert is None:
             convert = asdict if is_dataclass(micro_batch[0]) else (lambda x: x)
         dicts = [convert(r) for r in micro_batch]
         if schema is None:
-            # NOTE: the _MICRO_BATCH_SIZE is fairly small, here we hope it's enough to infer "real" schema
+            # NOTE: _MICRO_BATCH_SIZE is small; if the initial schema turns
+            # out to be narrower than the stream's true schema, we widen
+            # below on the first mismatching batch.
             schema = infer_arrow_schema(dicts)
-        table = pa.Table.from_pylist(dicts, schema=schema)
+
+        table, schema = _build_table(dicts, schema)
         chunks.append(table)
         bytesize += table.nbytes
         if bytesize >= target_bytes:
-            yield pa.concat_tables(chunks)
+            # ``promote_options="permissive"`` reconciles chunks whose schemas
+            # widened mid-stream (e.g. a later chunk introduced a new column
+            # or widened ``null`` → concrete type).
+            yield pa.concat_tables(chunks, promote_options="permissive")
             chunks = []
             bytesize = 0
 
     if chunks:
-        yield pa.concat_tables(chunks)
+        yield pa.concat_tables(chunks, promote_options="permissive")
 
 
 def write_parquet_file(
```
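
Why ``_build_table`` needs the explicit ``extra_keys`` check above: ``pa.Table.from_pylist`` with an explicit schema does not raise on unknown keys, it drops them, so the new-field case never surfaces as an exception. A small illustrative sketch (the names here are hypothetical, not from the commit):

```python
import pyarrow as pa

schema = pa.schema([("x", pa.string())])
rows = [{"x": "b", "z": 42}]  # "z" appears only in a later batch

# from_pylist silently ignores keys the schema doesn't declare.
table = pa.Table.from_pylist(rows, schema=schema)
assert table.schema.names == ["x"]  # "z" was dropped, no error raised

# The commit detects this by comparing the union of row keys against
# the schema's field names.
extra_keys = {k for row in rows for k in row} - set(schema.names)
assert extra_keys == {"z"}
```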

lib/zephyr/tests/test_writers.py

Lines changed: 49 additions & 2 deletions
```diff
@@ -7,12 +7,11 @@
 import tempfile
 from pathlib import Path
 
+import pyarrow as pa
 import pyarrow.parquet as pq
 import pytest
 import vortex
 
-import pyarrow as pa
-
 from zephyr.writers import (
     atomic_rename,
     infer_arrow_schema,
@@ -151,6 +150,54 @@ def test_write_parquet_file_basic():
     assert len(table) == 2
 
 
+def test_write_parquet_file_widens_null_to_concrete_type():
+    """First batch pins a field as null; a later batch with a concrete type widens cleanly.
+
+    This is the stackv2 failure mode: the first ``_MICRO_BATCH_SIZE`` (=8)
+    records all had ``None`` for a field, pinning it to ``pa.null()`` —
+    later records with real values would fail without schema widening.
+    Behavior must: (a) succeed, (b) land the widened schema on disk, (c)
+    preserve all values from both batches.
+    """
+    records = [{"x": None}] * 8 + [{"x": "hello"}]
+    with tempfile.TemporaryDirectory() as tmpdir:
+        output_path = str(Path(tmpdir) / "test.parquet")
+        result = write_parquet_file(records, output_path)
+        assert result["count"] == 9
+
+        table = pq.read_table(output_path)
+        assert len(table) == 9
+        assert pa.types.is_string(table.schema.field("x").type)
+        xs = table.column("x").to_pylist()
+        assert xs[:8] == [None] * 8
+        assert xs[8] == "hello"
+
+
+def test_write_parquet_file_captures_fields_appearing_in_later_batches():
+    """A field absent from the first batch but present later must not be silently dropped."""
+    records = [{"x": "a"}] * 8 + [{"x": "b", "z": 42}]
+    with tempfile.TemporaryDirectory() as tmpdir:
+        output_path = str(Path(tmpdir) / "test.parquet")
+        result = write_parquet_file(records, output_path)
+        assert result["count"] == 9
+
+        table = pq.read_table(output_path)
+        assert "z" in table.schema.names, "field `z` must survive to disk, not be dropped"
+        assert table.column("z").to_pylist() == [None] * 8 + [42]
+
+
+def test_write_parquet_file_raises_on_incompatible_type_conflict():
+    """Genuine type conflicts (e.g. int vs string) must still raise a clear error."""
+    records = [{"x": i} for i in range(8)] + [{"x": "stringy"}]
+    with tempfile.TemporaryDirectory() as tmpdir:
+        output_path = str(Path(tmpdir) / "test.parquet")
+        with pytest.raises((pa.ArrowInvalid, pa.ArrowTypeError)) as excinfo:
+            write_parquet_file(records, output_path)
+        msg = str(excinfo.value)
+        assert "int" in msg.lower() or "int64" in msg.lower()
+        assert "string" in msg.lower()
+
+
 def test_write_parquet_file_empty():
     """Test writing an empty parquet file."""
     with tempfile.TemporaryDirectory() as tmpdir:
```
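
For the conflict test above, the raise path falls out of ``pa.unify_schemas`` itself: there is no widening that covers both ``int64`` and ``string``, so unification fails and ``_raise_schema_mismatch`` re-raises with both schemas attached. A quick illustrative check, catching the same exception tuple the writer catches rather than assuming PyArrow's exact exception type:

```python
import pyarrow as pa

int_schema = pa.schema([("x", pa.int64())])
str_schema = pa.schema([("x", pa.string())])

# No common supertype exists for int64 and string, so unification raises.
try:
    pa.unify_schemas([int_schema, str_schema])
except (pa.ArrowInvalid, pa.ArrowTypeError) as e:
    print(f"cannot unify int64 with string: {e}")
```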
