Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/data/api/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ instantiate them directly, but you may encounter them when working with expressi
UnaryExpr
UDFExpr
StarExpr
UnresolvedExpr

Expression namespaces
------------------------------------
Expand Down
9 changes: 9 additions & 0 deletions python/ray/data/_internal/datasource/iceberg_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
StarExpr,
UDFExpr,
UnaryExpr,
UnresolvedExpr,
)
from ray.util import log_once
from ray.util.annotations import DeveloperAPI
Expand Down Expand Up @@ -176,6 +177,14 @@ def visit_star(
"Star expressions cannot be converted to Iceberg filter expressions."
)

def visit_unresolved(
self, expr: "UnresolvedExpr"
) -> "BooleanExpression | UnboundTerm[Any] | Literal[Any]":
"""Unresolved expressions cannot be converted to Iceberg expressions."""
raise TypeError(
"Unresolved expressions cannot be converted to Iceberg filter expressions."
)


def _get_read_task(
tasks: Iterable["FileScanTask"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
StarExpr,
UDFExpr,
UnaryExpr,
UnresolvedExpr,
_ExprVisitor,
col,
)
Expand Down Expand Up @@ -676,6 +677,23 @@ def visit_star(self, expr: StarExpr) -> Union[BlockColumn, ScalarType]:
"It should only be used in Project operations."
)

def visit_unresolved(self, expr: UnresolvedExpr) -> Union[BlockColumn, ScalarType]:
"""Visit an unresolved expression.

Args:
expr: The unresolved expression.

Returns:
This method does not return; it always raises TypeError.

Raises:
TypeError: UnresolvedExpr cannot be evaluated.
"""
raise TypeError(
"UnresolvedExpr cannot be evaluated. "
"Resolve it to a concrete expression before evaluation."
)

def visit_download(self, expr: DownloadExpr) -> Union[BlockColumn, ScalarType]:
"""Visit a download expression.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
StarExpr,
UDFExpr,
UnaryExpr,
UnresolvedExpr,
_CallableClassUDF,
_ExprVisitor,
)
Expand Down Expand Up @@ -78,6 +79,10 @@ def visit_download(self, expr: "Expr") -> None:
"""Visit a download expression (no columns to collect)."""
pass

def visit_unresolved(self, expr: UnresolvedExpr) -> None:
"""Visit an unresolved expression (no columns to collect)."""
pass


class _ColumnReferenceCollector(_ExprVisitorBase):
"""Visitor that collects all column references from expression trees.
Expand Down Expand Up @@ -106,6 +111,10 @@ def visit_column(self, expr: ColumnExpr) -> None:
"""
self._col_refs[expr.name] = None

def visit_unresolved(self, expr: UnresolvedExpr) -> None:
"""Visit an unresolved expression and collect its name."""
self._col_refs[expr.name] = None

def visit_alias(self, expr: AliasExpr) -> None:
"""Visit an alias expression and collect from its inner expression.

Expand Down Expand Up @@ -197,6 +206,11 @@ def visit_literal(self, expr: LiteralExpr) -> Expr:
"""
return expr

def visit_unresolved(self, expr: UnresolvedExpr) -> Expr:
"""Visit an unresolved expression and substitute it if possible."""
substitution = self._col_ref_substitutions.get(expr.name)
return substitution if substitution is not None else expr

def visit_binary(self, expr: BinaryExpr) -> Expr:
"""Visit a binary expression and rewrite its operands.

Expand Down Expand Up @@ -413,6 +427,9 @@ def visit_download(self, expr: "DownloadExpr") -> str:
def visit_star(self, expr: "StarExpr") -> str:
return self._make_tree_lines("COL(*)", expr=expr)

def visit_unresolved(self, expr: "UnresolvedExpr") -> str:
return self._make_tree_lines(f"UNRESOLVED({expr.name!r})", expr=expr)


class _InlineExprReprVisitor(_ExprVisitor[str]):
"""Visitor that generates concise inline string representations of expressions.
Expand Down Expand Up @@ -502,6 +519,10 @@ def visit_star(self, expr: "StarExpr") -> str:
"""Visit a star expression and return its inline representation."""
return "col(*)"

def visit_unresolved(self, expr: "UnresolvedExpr") -> str:
"""Visit an unresolved expression and return its inline representation."""
return f"unresolved({expr.name!r})"


def get_column_references(expr: Expr) -> List[str]:
"""Extract all column references from an expression.
Expand Down
39 changes: 36 additions & 3 deletions python/ray/data/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def visit(self, expr: "Expr") -> T:
return self.visit_download(expr)
elif isinstance(expr, StarExpr):
return self.visit_star(expr)
elif isinstance(expr, UnresolvedExpr):
return self.visit_unresolved(expr)
else:
raise TypeError(f"Unsupported expression type for conversion: {type(expr)}")

Expand Down Expand Up @@ -142,6 +144,10 @@ def visit_star(self, expr: "StarExpr") -> T:
def visit_download(self, expr: "DownloadExpr") -> T:
pass

@abstractmethod
def visit_unresolved(self, expr: "UnresolvedExpr") -> T:
pass


class _PyArrowExpressionVisitor(_ExprVisitor["pyarrow.compute.Expression"]):
"""Visitor that converts Ray Data expressions to PyArrow compute expressions."""
Expand Down Expand Up @@ -206,6 +212,11 @@ def visit_download(self, expr: "DownloadExpr") -> "pyarrow.compute.Expression":
def visit_star(self, expr: "StarExpr") -> "pyarrow.compute.Expression":
raise TypeError("Star expressions cannot be converted to PyArrow expressions")

def visit_unresolved(self, expr: "UnresolvedExpr") -> "pyarrow.compute.Expression":
raise TypeError(
"Unresolved expressions cannot be converted to PyArrow expressions"
)


@DeveloperAPI(stability="alpha")
@dataclass(frozen=True)
Expand All @@ -232,7 +243,7 @@ class Expr(ABC):
subclasses like ColumnExpr, LiteralExpr, etc.
"""

data_type: DataType
data_type: DataType | None

@property
def name(self) -> str | None:
Expand Down Expand Up @@ -1360,6 +1371,28 @@ def structurally_equals(self, other: Any) -> bool:
)


@DeveloperAPI(stability="alpha")
@dataclass(frozen=True, eq=False, repr=False)
class UnresolvedExpr(Expr):
"""Expression that represents an unresolved column reference.

This expression is a placeholder used when a column reference has not yet
been resolved against a concrete schema. It must be resolved before
evaluation or conversion to another expression system.
"""

_name: str
data_type: DataType | None = field(default=None, init=False)

@property
def name(self) -> str:
"""Get the unresolved column name."""
return self._name

def structurally_equals(self, other: Any) -> bool:
return isinstance(other, UnresolvedExpr) and self.name == other.name


@DeveloperAPI(stability="alpha")
@dataclass(frozen=True, eq=False, repr=False)
class StarExpr(Expr):
Expand All @@ -1377,8 +1410,7 @@ class StarExpr(Expr):
This means: keep all existing columns, then add/overwrite "new_col"
"""

# TODO: Add UnresolvedExpr. Both StarExpr and UnresolvedExpr won't have a defined data_type.
data_type: DataType = field(default_factory=lambda: DataType(object), init=False)
data_type: DataType | None = field(default=None, init=False)

def structurally_equals(self, other: Any) -> bool:
return isinstance(other, StarExpr)
Expand Down Expand Up @@ -1513,6 +1545,7 @@ def download(
"UDFExpr",
"DownloadExpr",
"AliasExpr",
"UnresolvedExpr",
"StarExpr",
"pyarrow_udf",
"udf",
Expand Down
16 changes: 16 additions & 0 deletions python/ray/data/tests/unit/expressions/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
BinaryExpr,
Operation,
UDFExpr,
UnresolvedExpr,
col,
download,
lit,
Expand Down Expand Up @@ -303,6 +304,13 @@ def test_star_expression_raises(self):
with pytest.raises(TypeError, match="Star expressions cannot be converted"):
star().to_pyarrow()

def test_unresolved_expression_raises(self):
"""Test that unresolved expressions raise TypeError."""
with pytest.raises(
TypeError, match="Unresolved expressions cannot be converted"
):
UnresolvedExpr("pending").to_pyarrow()


# ──────────────────────────────────────
# Iceberg Conversion Tests
Expand Down Expand Up @@ -538,6 +546,14 @@ def test_star_expression_raises(self):
):
visitor.visit(star())

def test_unresolved_expression_raises(self):
"""Test that unresolved expressions raise TypeError."""
visitor = _IcebergExpressionVisitor()
with pytest.raises(
TypeError, match="Unresolved expressions cannot be converted to Iceberg"
):
visitor.visit(UnresolvedExpr("pending"))

def test_is_in_requires_literal_list(self):
"""Test that IN/NOT_IN operations require literal lists."""
visitor = _IcebergExpressionVisitor()
Expand Down
36 changes: 35 additions & 1 deletion python/ray/data/tests/unit/expressions/test_core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Tests for core expression types and basic functionality.

This module tests:
- ColumnExpr, LiteralExpr, BinaryExpr, UnaryExpr, AliasExpr, StarExpr
- ColumnExpr, LiteralExpr, BinaryExpr, UnaryExpr, AliasExpr, StarExpr, UnresolvedExpr
- Structural equality for all expression types
- Expression tree repr (string representation)
- UDFExpr structural equality
Expand All @@ -24,6 +24,7 @@
StarExpr,
UDFExpr,
UnaryExpr,
UnresolvedExpr,
col,
download,
lit,
Expand Down Expand Up @@ -296,6 +297,37 @@ def test_star_structural_equality(self):
assert not star().structurally_equals(col("a"))


# ──────────────────────────────────────
# Unresolved Expression Tests
# ──────────────────────────────────────


class TestUnresolvedExpr:
"""Tests for UnresolvedExpr functionality."""

def test_unresolved_creation(self):
"""Test that UnresolvedExpr creates with correct name."""
expr = UnresolvedExpr("pending")
assert isinstance(expr, UnresolvedExpr)
assert expr.name == "pending"

@pytest.mark.parametrize(
"name1,name2,expected",
[
("a", "a", True),
("a", "b", False),
("column_name", "column_name", True),
("COL", "col", False),
],
ids=["same_name", "different_name", "long_name", "case_sensitive"],
)
def test_unresolved_structural_equality(self, name1, name2, expected):
"""Test structural equality for unresolved expressions."""
assert (
UnresolvedExpr(name1).structurally_equals(UnresolvedExpr(name2)) is expected
)


# ──────────────────────────────────────
# UDF Expression Tests
# ──────────────────────────────────────
Expand Down Expand Up @@ -393,13 +425,15 @@ class TestCrossTypeEquality:
(lit(1), lit(1) + 0),
(col("a"), col("a").alias("a")),
(col("a"), star()),
(col("a"), UnresolvedExpr("a")),
],
ids=[
"col_vs_lit",
"col_vs_binary",
"lit_vs_binary",
"col_vs_alias",
"col_vs_star",
"col_vs_unresolved",
],
)
def test_different_types_not_equal(self, expr1, expr2):
Expand Down