Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 81 additions & 34 deletions python/pyspark/sql/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union, overload

import pyspark
from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.errors import PySparkRuntimeError, PySparkValueError
from pyspark.sql.pandas.types import (
_dedup_names,
_deduplicate_field_names,
Expand Down Expand Up @@ -110,19 +110,20 @@ def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch":
@classmethod
def enforce_schema(
cls,
batch: "pa.RecordBatch",
batch: Union["pa.RecordBatch", "pa.Table"],
arrow_schema: "pa.Schema",
*,
arrow_cast: bool = True,
safecheck: bool = True,
) -> "pa.RecordBatch":
reorder_by_name: bool = True,
) -> Union["pa.RecordBatch", "pa.Table"]:
"""
Enforce target schema on a RecordBatch by reordering columns and coercing types.
Enforce a target schema on an Arrow RecordBatch or Table.

Parameters
----------
batch : pa.RecordBatch
Input RecordBatch to transform.
batch : pa.RecordBatch or pa.Table
Input to transform. Output is of the same container type.
arrow_schema : pa.Schema
Target Arrow schema. Callers should pre-compute this once via
to_arrow_schema() to avoid repeated conversion.
Expand All @@ -131,11 +132,26 @@ def enforce_schema(
If False, raise an error on type mismatch instead of casting.
safecheck : bool, default True
If True, use safe casting (fails on overflow/truncation).
reorder_by_name : bool, default True
If True, match columns by name and reorder to the target order; any
missing or extra names raise ``RESULT_COLUMN_NAMES_MISMATCH``. Output
columns are renamed to target names.
If False, match columns by position (ignore names) and preserve the
original column names in the output.

Returns
-------
pa.RecordBatch
RecordBatch with columns reordered and types coerced to match target schema.
pa.RecordBatch or pa.Table
Same container type as ``batch``, with columns matched (and possibly
reordered/cast) per the target schema.

Raises
------
PySparkRuntimeError
``RESULT_COLUMN_NAMES_MISMATCH`` when ``reorder_by_name=True`` and the
batch has missing or extra column names.
``RESULT_COLUMN_SCHEMA_MISMATCH`` when ``reorder_by_name=False`` and the
number of columns differs from the target schema.
``RESULT_COLUMN_TYPES_MISMATCH`` when any column's type does not match
the target (and either ``arrow_cast=False`` or the cast itself fails).
"""
import pyarrow as pa

Expand All @@ -146,37 +162,68 @@ def enforce_schema(
if batch.schema.equals(arrow_schema, check_metadata=False):
return batch

# Check if columns are in the same order (by name) as the target schema.
# If so, use index-based access (faster than name lookup).
batch_names = [batch.schema.field(i).name for i in range(batch.num_columns)]
target_names = [field.name for field in arrow_schema]
use_index = batch_names == target_names

coerced_arrays = []
for i, field in enumerate(arrow_schema):
try:
arr = batch.column(i) if use_index else batch.column(field.name)
except KeyError:
raise PySparkTypeError(
f"Result column '{field.name}' does not exist in the output. "
f"Expected schema: {arrow_schema}, got: {batch.schema}."
# Step 1: pick source columns from batch to align with target schema
if reorder_by_name:
batch_names = [batch.schema.field(i).name for i in range(batch.num_columns)]
missing = sorted(set(target_names) - set(batch_names))
extra = sorted(set(batch_names) - set(target_names))
if missing or extra:
raise PySparkRuntimeError(
errorClass="RESULT_COLUMN_NAMES_MISMATCH",
messageParameters={
"missing": f" Missing: {', '.join(missing)}." if missing else "",
"extra": f" Unexpected: {', '.join(extra)}." if extra else "",
},
)
if arr.type != field.type:
if not arrow_cast:
raise PySparkTypeError(
f"Result type of column '{field.name}' does not match "
f"the expected type. Expected: {field.type}, got: {arr.type}."
)
source_columns = [batch.column(name) for name in target_names]
output_names = target_names
else:
# Positional: require exact column-count match, then take columns by
# index, preserving the batch's original column names.
if batch.num_columns != len(arrow_schema):
raise PySparkRuntimeError(
errorClass="RESULT_COLUMN_SCHEMA_MISMATCH",
messageParameters={
"expected": str(len(arrow_schema)),
"actual": str(batch.num_columns),
},
)
source_columns = [batch.column(i) for i in range(len(arrow_schema))]
output_names = [batch.schema.field(i).name for i in range(len(arrow_schema))]

# Step 2: check types / cast, collect all mismatches
type_mismatches = []
coerced_arrays = []
for field, arr in zip(arrow_schema, source_columns):
if arr.type == field.type:
coerced_arrays.append(arr)
elif not arrow_cast:
type_mismatches.append((field.name, field.type, arr.type))
coerced_arrays.append(arr)
else:
try:
arr = arr.cast(target_type=field.type, safe=safecheck)
except (pa.ArrowInvalid, pa.ArrowTypeError) as e:
raise PySparkTypeError(
f"Result type of column '{field.name}' does not match "
f"the expected type. Expected: {field.type}, got: {arr.type}."
) from e
coerced_arrays.append(arr)
coerced_arrays.append(arr.cast(target_type=field.type, safe=safecheck))
except (pa.ArrowInvalid, pa.ArrowTypeError):
type_mismatches.append((field.name, field.type, arr.type))
coerced_arrays.append(arr)

if type_mismatches:
raise PySparkRuntimeError(
errorClass="RESULT_COLUMN_TYPES_MISMATCH",
messageParameters={
"mismatch": ", ".join(
f"column '{name}' (expected {expected}, actual {actual})"
for name, expected, actual in type_mismatches
)
},
)

return pa.RecordBatch.from_arrays(coerced_arrays, names=target_names)
# Preserve input container type (Table vs RecordBatch)
if isinstance(batch, pa.Table):
return pa.Table.from_arrays(coerced_arrays, names=output_names)
return pa.RecordBatch.from_arrays(coerced_arrays, names=output_names)

@classmethod
def to_pandas(
Expand Down
9 changes: 4 additions & 5 deletions python/pyspark/sql/tests/arrow/test_arrow_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,8 @@ def eval(self) -> Iterator["pa.Table"]:

with self.assertRaisesRegex(
PythonException,
r"(?s)Result column 'x' does not exist in the output\. "
r"Expected schema: x: int32\ny: string, "
r"got: wrong_col: int32\nanother_wrong_col: double\.",
r"(?s)\[RESULT_COLUMN_NAMES_MISMATCH\].*"
r"Missing: x, y\..*Unexpected: another_wrong_col, wrong_col\.",
):
result_df = MismatchedSchemaUDTF()
result_df.collect()
Expand Down Expand Up @@ -375,8 +374,8 @@ def eval(self) -> Iterator["pa.Table"]:
# Should fail with Arrow cast exception since string cannot be cast to int
with self.assertRaisesRegex(
PythonException,
"Result type of column 'id' does not match "
"the expected type. Expected: int32, got: string.",
r"(?s)\[RESULT_COLUMN_TYPES_MISMATCH\].*"
r"column 'id' \(expected int32, actual string\)",
):
result_df = StringToIntUDTF()
result_df.collect()
Expand Down
66 changes: 61 additions & 5 deletions python/pyspark/sql/tests/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import unittest
from zoneinfo import ZoneInfo

from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
from pyspark.sql.conversion import (
ArrowArrayToPandasConversion,
ArrowTableToRowsConversion,
Expand Down Expand Up @@ -185,27 +185,83 @@ def test_enforce_schema_arrow_cast_false(self):

batch = pa.RecordBatch.from_arrays([pa.array([1], type=pa.int32())], names=["x"])
target = pa.schema([("x", pa.int64())])
with self.assertRaises(PySparkTypeError):
with self.assertRaises(PySparkRuntimeError) as cm:
ArrowBatchTransformer.enforce_schema(batch, target, arrow_cast=False)
self.assertEqual(cm.exception.getCondition(), "RESULT_COLUMN_TYPES_MISMATCH")

def test_enforce_schema_safecheck(self):
"""safecheck=True rejects overflow; safecheck=False allows it."""
import pyarrow as pa

batch = pa.RecordBatch.from_arrays([pa.array([999], type=pa.int64())], names=["x"])
target = pa.schema([("x", pa.int8())])
with self.assertRaises(PySparkTypeError):
with self.assertRaises(PySparkRuntimeError) as cm:
ArrowBatchTransformer.enforce_schema(batch, target, safecheck=True)
self.assertEqual(cm.exception.getCondition(), "RESULT_COLUMN_TYPES_MISMATCH")
result = ArrowBatchTransformer.enforce_schema(batch, target, safecheck=False)
self.assertEqual(result.schema, target)

def test_enforce_schema_missing_column(self):
"""Missing column raises PySparkTypeError."""
"""Missing column raises RESULT_COLUMN_NAMES_MISMATCH."""
import pyarrow as pa

batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
with self.assertRaises(PySparkTypeError):
with self.assertRaises(PySparkRuntimeError) as cm:
ArrowBatchTransformer.enforce_schema(batch, pa.schema([("missing", pa.int64())]))
self.assertEqual(cm.exception.getCondition(), "RESULT_COLUMN_NAMES_MISMATCH")

def test_enforce_schema_extra_column(self):
"""Extra column raises RESULT_COLUMN_NAMES_MISMATCH with the extra name listed."""
import pyarrow as pa

batch = pa.RecordBatch.from_arrays([pa.array([1]), pa.array([2])], names=["a", "b"])
with self.assertRaises(PySparkRuntimeError) as cm:
ArrowBatchTransformer.enforce_schema(batch, pa.schema([("a", pa.int64())]))
self.assertEqual(cm.exception.getCondition(), "RESULT_COLUMN_NAMES_MISMATCH")
self.assertIn("b", str(cm.exception))

def test_enforce_schema_reorder_by_name(self):
"""reorder_by_name=True reorders input columns to match target schema order."""
import pyarrow as pa

batch = pa.RecordBatch.from_arrays([pa.array(["x"]), pa.array([1])], names=["b", "a"])
target = pa.schema([("a", pa.int64()), ("b", pa.string())])
result = ArrowBatchTransformer.enforce_schema(batch, target)
self.assertEqual(result.schema.names, ["a", "b"])
self.assertEqual(result.column(0).to_pylist(), [1])
self.assertEqual(result.column(1).to_pylist(), ["x"])

def test_enforce_schema_positional(self):
"""reorder_by_name=False matches columns by index, preserving input names."""
import pyarrow as pa

batch = pa.RecordBatch.from_arrays([pa.array([1]), pa.array(["x"])], names=["foo", "bar"])
target = pa.schema([("a", pa.int64()), ("b", pa.string())])
result = ArrowBatchTransformer.enforce_schema(batch, target, reorder_by_name=False)
# Input column names are preserved
self.assertEqual(result.schema.names, ["foo", "bar"])
self.assertEqual(result.column(0).to_pylist(), [1])
self.assertEqual(result.column(1).to_pylist(), ["x"])

def test_enforce_schema_positional_count_mismatch(self):
"""reorder_by_name=False with wrong column count raises RESULT_COLUMN_SCHEMA_MISMATCH."""
import pyarrow as pa

batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
target = pa.schema([("x", pa.int64()), ("y", pa.int64())])
with self.assertRaises(PySparkRuntimeError) as cm:
ArrowBatchTransformer.enforce_schema(batch, target, reorder_by_name=False)
self.assertEqual(cm.exception.getCondition(), "RESULT_COLUMN_SCHEMA_MISMATCH")

def test_enforce_schema_table_input(self):
"""enforce_schema accepts pa.Table and returns pa.Table."""
import pyarrow as pa

table = pa.table({"x": pa.array([1], type=pa.int32())})
target = pa.schema([("x", pa.int64())])
result = ArrowBatchTransformer.enforce_schema(table, target)
self.assertIsInstance(result, pa.Table)
self.assertEqual(result.schema, target)


@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
Expand Down
Loading