133 changes: 67 additions & 66 deletions python/pyspark/worker.py
@@ -26,7 +26,21 @@
 import inspect
 import itertools
 import json
-from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TYPE_CHECKING, Union
+from typing import (
+    Any,
+    Callable,
+    Iterable,
+    Iterator,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    TYPE_CHECKING,
+    Union,
+    overload,
+)
+
+T = TypeVar("T")
 
 if TYPE_CHECKING:
     from pyspark.sql.pandas._typing import GroupedBatch
@@ -234,46 +248,56 @@ def chain(f, g):
     return lambda *a: g(f(*a))
 
 
-def verify_result(expected_type: type) -> Callable[[Any], Iterator]:
-    """
-    Create a result verifier that checks both iterability and element types.
-
-    Returns a function that takes a UDF result, verifies it is iterable,
-    and lazily type-checks each element via map.
-
-    Parameters
-    ----------
-    expected_type : type
-        The expected Python/PyArrow type for each element
-        (e.g. pa.RecordBatch, pa.Array).
-    """
-
-    package = getattr(inspect.getmodule(expected_type), "__package__", "")
-    label: str = f"{package}.{expected_type.__name__}"
-
-    def check_element(element: Any) -> Any:
-        if not isinstance(element, expected_type):
-            raise PySparkTypeError(
-                errorClass="UDF_RETURN_TYPE",
-                messageParameters={
-                    "expected": f"iterator of {label}",
-                    "actual": f"iterator of {type(element).__name__}",
-                },
-            )
-        return element
-
-    def check(result: Any) -> Iterator:
-        if not isinstance(result, Iterator) and not hasattr(result, "__iter__"):
-            raise PySparkTypeError(
-                errorClass="UDF_RETURN_TYPE",
-                messageParameters={
-                    "expected": f"iterator of {label}",
-                    "actual": type(result).__name__,
-                },
-            )
-        return map(check_element, result)
-
-    return check
+@overload
+def verify_return_type(result: Any, expected_type: Type[T]) -> T: ...
+
+
+@overload
+def verify_return_type(result: Any, expected_type: Any) -> Any: ...
+
+
+def verify_return_type(result: Any, expected_type: Any) -> Any:
+    """
+    Verify a UDF return value against an expected type.
+
+    Returns ``result`` unchanged if ``isinstance(result, expected_type)``.
+    For ``Iterator[T]``, returns a lazy iterator that checks each element
+    against ``T`` on consumption. Raises ``PySparkTypeError`` on mismatch.
+    """
+    if getattr(expected_type, "_name", None) == "Iterator":
+        (element_type,) = expected_type.__args__
+        package = getattr(inspect.getmodule(element_type), "__package__", "")
+        label = f"iterator of {package}.{element_type.__name__}"
+
+        if not isinstance(result, Iterator):
+            raise PySparkTypeError(
+                errorClass="UDF_RETURN_TYPE",
+                messageParameters={"expected": label, "actual": type(result).__name__},
+            )
+
+        def check_element(element: T) -> T:
+            if not isinstance(element, element_type):
+                raise PySparkTypeError(
+                    errorClass="UDF_RETURN_TYPE",
+                    messageParameters={
+                        "expected": label,
+                        "actual": f"iterator of {type(element).__name__}",
+                    },
+                )
+            return element
+
+        return map(check_element, result)
+
+    if not isinstance(result, expected_type):
+        package = getattr(inspect.getmodule(expected_type), "__package__", "")
+        raise PySparkTypeError(
+            errorClass="UDF_RETURN_TYPE",
+            messageParameters={
+                "expected": f"{package}.{expected_type.__name__}",
+                "actual": type(result).__name__,
+            },
+        )
+    return result
 
 
 def verify_result_row_count(result_length: int, expected: int) -> None:
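
Note: for reference, a minimal usage sketch of the new helper's two modes, eager for a plain type and lazy for `Iterator[T]`. It assumes `verify_return_type` is importable from `pyspark.worker`, where this diff defines it, and that `pyarrow` is installed:

from typing import Iterator

import pyarrow as pa

from pyspark.errors import PySparkTypeError
from pyspark.worker import verify_return_type  # added in this diff

batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]})

# Plain type: checked eagerly, returned unchanged.
assert verify_return_type(batch, pa.RecordBatch) is batch

# Iterator[T]: elements are checked lazily, only when consumed.
checked = verify_return_type(iter([batch, "oops"]), Iterator[pa.RecordBatch])
next(checked)  # the RecordBatch passes
try:
    next(checked)  # the str fails its isinstance check on consumption
except PySparkTypeError:
    pass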
@@ -512,6 +536,8 @@ def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_return_schema):
 
 
 def wrap_cogrouped_map_arrow_udf(f, return_type, argspec, runner_conf):
+    import pyarrow as pa
+
     if runner_conf.assign_cols_by_name:
         expected_cols_and_types = {
             col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields
@@ -529,7 +555,8 @@ def wrapped(left_key_table, left_value_table, right_key_table, right_value_table):
         key = tuple(c[0] for c in key_table.columns)
         result = f(key, left_value_table, right_value_table)
 
-        verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types)
+        verify_return_type(result, pa.Table)
+        verify_arrow_result(result, runner_conf.assign_cols_by_name, expected_cols_and_types)
 
         return result.to_batches()
 
@@ -622,36 +649,6 @@ def verify_arrow_result(result, assign_cols_by_name, expected_cols_and_types):
         )
 
 
-def verify_arrow_table(table, assign_cols_by_name, expected_cols_and_types):
-    import pyarrow as pa
-
-    if not isinstance(table, pa.Table):
-        raise PySparkTypeError(
-            errorClass="UDF_RETURN_TYPE",
-            messageParameters={
-                "expected": "pyarrow.Table",
-                "actual": type(table).__name__,
-            },
-        )
-
-    verify_arrow_result(table, assign_cols_by_name, expected_cols_and_types)
-
-
-def verify_arrow_batch(batch, assign_cols_by_name, expected_cols_and_types):
-    import pyarrow as pa
-
-    if not isinstance(batch, pa.RecordBatch):
-        raise PySparkTypeError(
-            errorClass="UDF_RETURN_TYPE",
-            messageParameters={
-                "expected": "pyarrow.RecordBatch",
-                "actual": type(batch).__name__,
-            },
-        )
-
-    verify_arrow_result(batch, assign_cols_by_name, expected_cols_and_types)
-
-
 def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
     def wrapped(key_series, value_series):
         import pandas as pd
@@ -2561,8 +2558,8 @@ def func(split_index: int, data: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
         output_batches = udf_func(input_batches)
 
         # Post-processing
-        verified: Iterator[pa.RecordBatch] = verify_result(pa.RecordBatch)(output_batches)
-        yield from map(ArrowBatchTransformer.wrap_struct, verified)
+        verified_iter = verify_return_type(output_batches, Iterator[pa.RecordBatch])
+        yield from map(ArrowBatchTransformer.wrap_struct, verified_iter)
 
     # profiling is not supported for UDF
     return func, None, ser, ser
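
Note: one behavior change worth flagging here, if I'm reading the diff right. The removed `verify_result` accepted anything with `__iter__`, so a UDF could return a plain list of batches; `verify_return_type` with `Iterator[T]` requires a true iterator and raises eagerly otherwise. A short sketch, using only the API this diff defines:

from typing import Iterator

import pyarrow as pa

from pyspark.errors import PySparkTypeError
from pyspark.worker import verify_return_type

batches = [pa.RecordBatch.from_pydict({"x": [1]})]

try:
    # A list is iterable but not an Iterator, so this now raises immediately.
    verify_return_type(batches, Iterator[pa.RecordBatch])
except PySparkTypeError:
    pass

# Wrapping the list in iter() satisfies the stricter check.
assert next(verify_return_type(iter(batches), Iterator[pa.RecordBatch])) is batches[0]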
@@ -2626,7 +2623,7 @@ def extract_args(batch: pa.RecordBatch):
         args_iter = map(extract_args, data)
 
         # Call UDF and verify result type (iterator of pa.Array)
-        verified_iter = verify_result(pa.Array)(udf_func(args_iter))
+        verified_iter = verify_return_type(udf_func(args_iter), Iterator[pa.Array])
 
         # Process results: enforce schema and assemble into RecordBatch
         target_schema = pa.schema([pa.field("_0", arrow_return_type)])
@@ -2855,7 +2852,10 @@ def grouped_func(
             key = tuple(c[0] for c in keys.columns)
             result = grouped_udf(key, value_table)
 
-            verify_arrow_table(result, runner_conf.assign_cols_by_name, expected_cols_and_types)
+            verify_return_type(result, pa.Table)
+            verify_arrow_result(
+                result, runner_conf.assign_cols_by_name, expected_cols_and_types
+            )
 
             # Reorder columns if needed and wrap into struct
             for batch in result.to_batches():
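
Note: the grouped-map path now splits verification in two steps, replacing the deleted `verify_arrow_table` wrapper: `verify_return_type` checks the container type eagerly, then `verify_arrow_result` checks column names and types. A hedged sketch of the mistake the first step guards against, with a hypothetical wrong return value:

import pandas as pd
import pyarrow as pa

from pyspark.errors import PySparkTypeError
from pyspark.worker import verify_return_type

# Hypothetical: a grouped-map Arrow UDF returning a pandas DataFrame
# instead of the pyarrow.Table the runner expects.
not_a_table = pd.DataFrame({"a": [1, 2]})

try:
    verify_return_type(not_a_table, pa.Table)
except PySparkTypeError:
    pass  # raised with errorClass "UDF_RETURN_TYPE", per this diff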
@@ -2926,7 +2926,8 @@ def grouped_func(
 
             # Verify, reorder, and wrap each output batch
             for batch in result:
-                verify_arrow_batch(
+                verify_return_type(batch, pa.RecordBatch)
+                verify_arrow_result(
                     batch, runner_conf.assign_cols_by_name, expected_cols_and_types
                 )
                 if runner_conf.assign_cols_by_name: