
Commit 0497c64

ravwojdyla authored
zephyr: widen inferred parquet schema via pa.unify_schemas (#5142)
* writers' ``_accumulate_tables`` infers schema from the first ``_MICRO_BATCH_SIZE=8`` records — so if those records have ``None`` for an optional field, the field gets pinned to ``pa.null()`` and later records with real values crash with ``ArrowInvalid: Invalid null value``
* real-world case: ``common-pile/stackv2``'s nested ``metadata.gha_language`` (959 null / 1041 str across ~2000 records) was deterministically failing
* separately, ``pa.Table.from_pylist`` **silently drops** top-level keys missing from the pinned schema — any new column appearing in a later batch was being truncated without a signal [^1]
* on mismatch, unify via ``pa.unify_schemas`` and rebuild the batch against the widened schema; reconcile prior chunks on yield via ``concat_tables(promote_options="permissive")``
* genuine type conflicts (e.g. ``int`` vs ``string`` for the same field) still raise with both schemas + inference origin shown, so operators can diagnose without extra instrumentation
* explicit caller-provided schemas are a contract — mismatches raise without silent widening

## Test plan

- [x] `test_write_parquet_file_widens_null_to_concrete_type` — null→string widening succeeds and lands the widened schema on disk
- [x] `test_write_parquet_file_captures_fields_appearing_in_later_batches` — new field survives to disk instead of being silently dropped
- [x] `test_write_parquet_file_raises_on_incompatible_type_conflict` — int vs string still surfaces as a clear error

[^1]: this silent-drop behavior was a latent data-loss bug; the new extra-keys detection catches it and routes through the same widen path.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Rafal Wojdyla <ravwojdyla@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
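For readers outside the zephyr codebase, a minimal sketch of the failure mode and the two PyArrow primitives the fix leans on. This assumes pyarrow >= 14.0, where ``concat_tables`` gained the ``promote_options`` keyword (older releases spelled this ``promote=True``); the column name ``x`` is illustrative:

```python
import pyarrow as pa

# Eight all-None records pin the column to null under plain inference,
# mirroring what the first micro-batch does to the stream's schema:
pinned = pa.Table.from_pylist([{"x": None}] * 8).schema
assert pinned.field("x").type == pa.null()

# A later record with a real value no longer fits the pinned schema:
# pa.Table.from_pylist([{"x": "hello"}], schema=pinned)  # raises ArrowInvalid

# unify_schemas widens null -> string (null promotes under the defaults):
widened = pa.unify_schemas([pinned, pa.schema([("x", pa.string())])])
assert widened.field("x").type == pa.string()

# Permissive concatenation reconciles chunks built under either schema:
merged = pa.concat_tables(
    [
        pa.Table.from_pylist([{"x": None}] * 8, schema=pinned),
        pa.Table.from_pylist([{"x": "hello"}], schema=widened),
    ],
    promote_options="permissive",
)
assert merged.column("x").to_pylist() == [None] * 8 + ["hello"]
```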
1 parent 91972ca commit 0497c64

2 files changed: 110 additions & 27 deletions

lib/zephyr/src/zephyr/writers.py

Lines changed: 69 additions & 17 deletions
```diff
@@ -171,42 +171,94 @@ def _accumulate_tables(
     Converts records to PyArrow in micro-batches of ``_MICRO_BATCH_SIZE``,
     tracks byte size incrementally, and yields a single ``concat_tables``
     result each time the threshold is reached.
+
+    When the caller did not pass an explicit schema, the schema is inferred
+    from the first micro-batch. If a later micro-batch doesn't fit that
+    schema — e.g. early rows pinned a column as ``null`` and a later row
+    supplies a concrete value, or a new top-level column appears — the
+    schemas are unified via :func:`pa.unify_schemas` and the batch is
+    rebuilt against the widened schema. On yield, prior chunks whose
+    schemas differ are reconciled via ``concat_tables(promote_options=
+    "permissive")``. Genuinely incompatible schemas (e.g. ``int`` vs
+    ``string`` for the same field) still raise, with both schemas shown.
+
+    An explicit caller-provided schema is treated as a contract: mismatches
+    raise without silent widening.
     """
     chunks: list[pa.Table] = []
     bytesize = 0
     convert: Callable | None = None
     schema_inferred = schema is None

+    def _raise_schema_mismatch(e: Exception, dicts: list[dict[str, Any]]) -> None:
+        actual_schema = pa.Table.from_pylist(dicts).schema
+        origin = (
+            f"inferred from first {_MICRO_BATCH_SIZE} records (no explicit schema passed)"
+            if schema_inferred
+            else "explicitly provided by caller"
+        )
+        raise pa.ArrowInvalid(
+            f"Schema mismatch converting batch to Arrow: {e}\n"
+            f"Expected schema ({origin}):\n{schema}\n"
+            f"Got schema:\n{actual_schema}"
+        ) from e
+
+    def _build_table(dicts: list[dict[str, Any]], schema: pa.Schema) -> tuple[pa.Table, pa.Schema]:
+        """Convert *dicts* to a table under *schema*, widening via ``pa.unify_schemas`` when needed.
+
+        Returns ``(table, schema)`` where ``schema`` may be wider than the
+        input. Handles two kinds of divergence: (1) ``from_pylist`` raises
+        because a field's type doesn't fit, (2) ``from_pylist`` would
+        silently drop extra top-level keys (new fields appearing only in
+        later batches). Raises (via :func:`_raise_schema_mismatch`) when
+        *schema* was explicitly provided by the caller, or when the
+        divergence isn't representable as a widening (e.g. ``int`` vs
+        ``string``).
+        """
+        mismatch_error: Exception | None = None
+        try:
+            table = pa.Table.from_pylist(dicts, schema=schema)
+        except (pa.ArrowInvalid, pa.ArrowTypeError, pa.ArrowNotImplementedError) as e:
+            mismatch_error = e
+
+        if mismatch_error is None:
+            extra_keys = {k for d in dicts for k in d.keys()} - set(schema.names)
+            if not extra_keys:
+                return table, schema
+            mismatch_error = pa.ArrowInvalid(f"extra top-level keys not in schema: {sorted(extra_keys)}")
+
+        if not schema_inferred:
+            _raise_schema_mismatch(mismatch_error, dicts)
+        new_schema = pa.Table.from_pylist(dicts).schema
+        try:
+            widened = pa.unify_schemas([schema, new_schema])
+        except (pa.ArrowInvalid, pa.ArrowTypeError, pa.ArrowNotImplementedError):
+            _raise_schema_mismatch(mismatch_error, dicts)
+        return pa.Table.from_pylist(dicts, schema=widened), widened
+
     for micro_batch in batchify(records, n=_MICRO_BATCH_SIZE):
         if convert is None:
             convert = asdict if is_dataclass(micro_batch[0]) else (lambda x: x)
         dicts = [convert(r) for r in micro_batch]
         if schema is None:
-            # NOTE: the _MICRO_BATCH_SIZE is fairly small, here we hope it's enough to infer "real" schema
+            # NOTE: _MICRO_BATCH_SIZE is small; if the initial schema turns
+            # out to be narrower than the stream's true schema, we widen
+            # below on the first mismatching batch.
             schema = infer_arrow_schema(dicts)
-        try:
-            table = pa.Table.from_pylist(dicts, schema=schema)
-        except (pa.ArrowInvalid, pa.ArrowTypeError, pa.ArrowNotImplementedError) as e:
-            actual_schema = pa.Table.from_pylist(dicts).schema
-            origin = (
-                f"inferred from first {_MICRO_BATCH_SIZE} records (no explicit schema passed)"
-                if schema_inferred
-                else "explicitly provided by caller"
-            )
-            raise pa.ArrowInvalid(
-                f"Schema mismatch converting batch to Arrow: {e}\n"
-                f"Expected schema ({origin}):\n{schema}\n"
-                f"Got schema:\n{actual_schema}"
-            ) from e
+
+        table, schema = _build_table(dicts, schema)
         chunks.append(table)
         bytesize += table.nbytes
         if bytesize >= target_bytes:
-            yield pa.concat_tables(chunks)
+            # ``promote_options="permissive"`` reconciles chunks whose schemas
+            # widened mid-stream (e.g. a later chunk introduced a new column
+            # or widened ``null`` → concrete type).
+            yield pa.concat_tables(chunks, promote_options="permissive")
             chunks = []
             bytesize = 0

     if chunks:
-        yield pa.concat_tables(chunks)
+        yield pa.concat_tables(chunks, promote_options="permissive")


 def write_parquet_file(
```
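One subtlety worth calling out from the hunk above: the extra-keys branch in ``_build_table`` exists because ``pa.Table.from_pylist`` with an explicit ``schema=`` does not reject unknown keys, it discards them. A minimal sketch of that behavior (plain PyArrow; the names ``schema`` and ``dicts`` here are illustrative, not taken from the module):

```python
import pyarrow as pa

# from_pylist with an explicit schema silently drops keys the schema lacks:
schema = pa.schema([("x", pa.string())])
table = pa.Table.from_pylist([{"x": "b", "z": 42}], schema=schema)
assert table.schema.names == ["x"]  # "z" vanished, and nothing was raised

# ...which is why the divergence has to be detected by hand, as the new
# _build_table does before accepting a conversion:
dicts = [{"x": "b", "z": 42}]
extra_keys = {k for d in dicts for k in d} - set(schema.names)
assert extra_keys == {"z"}  # routes into the unify/widen path instead
```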

lib/zephyr/tests/test_writers.py

Lines changed: 41 additions & 10 deletions
```diff
@@ -151,21 +151,52 @@ def test_write_parquet_file_basic():
     assert len(table) == 2


-def test_write_parquet_file_schema_mismatch_surfaces_both_schemas():
-    """On schema divergence, the raised error includes expected + actual schemas and inference origin."""
-    # First micro-batch (_MICRO_BATCH_SIZE=8) has `x` all None → inferred as pa.null().
-    # A later record with a real value for `x` then fails to fit that schema.
+def test_write_parquet_file_widens_null_to_concrete_type():
+    """First batch pins a field as null; a later batch with a concrete type widens cleanly.
+
+    This is the stackv2 failure mode: the first ``_MICRO_BATCH_SIZE`` (=8)
+    records all had ``None`` for a field, pinning it to ``pa.null()`` —
+    later records with real values would fail without schema widening.
+    Behavior must: (a) succeed, (b) land the widened schema on disk, (c)
+    preserve all values from both batches.
+    """
     records = [{"x": None}] * 8 + [{"x": "hello"}]
     with tempfile.TemporaryDirectory() as tmpdir:
         output_path = str(Path(tmpdir) / "test.parquet")
-        with pytest.raises(pa.ArrowInvalid) as excinfo:
+        result = write_parquet_file(records, output_path)
+        assert result["count"] == 9
+
+        table = pq.read_table(output_path)
+        assert len(table) == 9
+        assert pa.types.is_string(table.schema.field("x").type)
+        xs = table.column("x").to_pylist()
+        assert xs[:8] == [None] * 8
+        assert xs[8] == "hello"
+
+
+def test_write_parquet_file_captures_fields_appearing_in_later_batches():
+    """A field absent from the first batch but present later must not be silently dropped."""
+    records = [{"x": "a"}] * 8 + [{"x": "b", "z": 42}]
+    with tempfile.TemporaryDirectory() as tmpdir:
+        output_path = str(Path(tmpdir) / "test.parquet")
+        result = write_parquet_file(records, output_path)
+        assert result["count"] == 9
+
+        table = pq.read_table(output_path)
+        assert "z" in table.schema.names, "field `z` must survive to disk, not be dropped"
+        assert table.column("z").to_pylist() == [None] * 8 + [42]
+
+
+def test_write_parquet_file_raises_on_incompatible_type_conflict():
+    """Genuine type conflicts (e.g. int vs string) must still raise a clear error."""
+    records = [{"x": i} for i in range(8)] + [{"x": "stringy"}]
+    with tempfile.TemporaryDirectory() as tmpdir:
+        output_path = str(Path(tmpdir) / "test.parquet")
+        with pytest.raises((pa.ArrowInvalid, pa.ArrowTypeError)) as excinfo:
             write_parquet_file(records, output_path)
         msg = str(excinfo.value)
-        assert "Expected schema" in msg
-        assert "Got schema" in msg
-        assert "inferred from first" in msg
-        assert "x: null" in msg
-        assert "x: string" in msg
+        assert "int" in msg.lower() or "int64" in msg.lower()
+        assert "string" in msg.lower()


 def test_write_parquet_file_empty():
```
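Why the conflict test now asserts only loosely on the message: the decision between widening and raising is made by ``pa.unify_schemas``, and the surfaced text embeds PyArrow's own schema rendering rather than the fixed strings the old test matched. A minimal sketch of the unify gate refusing a genuine conflict (the exact exception class and wording are PyArrow's, hence the broad catch):

```python
import pyarrow as pa

# int64 vs string is not a widening, so unification refuses it — this is
# what sends _build_table into _raise_schema_mismatch for the conflict test:
try:
    pa.unify_schemas(
        [pa.schema([("x", pa.int64())]), pa.schema([("x", pa.string())])]
    )
except (pa.ArrowInvalid, pa.ArrowTypeError) as e:
    print(e)  # mentions the field and both types, e.g. int64 vs string
```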
