Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/data/api/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ instantiate them directly, but you may encounter them when working with expressi
UnaryExpr
UDFExpr
StarExpr
UnresolvedExpr

Expression namespaces
------------------------------------
Expand Down
9 changes: 9 additions & 0 deletions python/ray/data/_internal/datasource/iceberg_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
StarExpr,
UDFExpr,
UnaryExpr,
UnresolvedExpr,
)
from ray.util import log_once
from ray.util.annotations import DeveloperAPI
Expand Down Expand Up @@ -176,6 +177,14 @@ def visit_star(
"Star expressions cannot be converted to Iceberg filter expressions."
)

def visit_unresolved(
    self, expr: "UnresolvedExpr"
) -> "BooleanExpression | UnboundTerm[Any] | Literal[Any]":
    """Reject unresolved expressions during Iceberg filter conversion.

    Raises:
        TypeError: Always raised; ``expr`` is never inspected.
    """
    message = (
        "Unresolved expressions cannot be converted to Iceberg filter expressions."
    )
    raise TypeError(message)


def _get_read_task(
tasks: Iterable["FileScanTask"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
StarExpr,
UDFExpr,
UnaryExpr,
UnresolvedExpr,
_ExprVisitor,
col,
)
Expand Down Expand Up @@ -676,6 +677,20 @@ def visit_star(self, expr: StarExpr) -> Union[BlockColumn, ScalarType]:
"It should only be used in Project operations."
)

def visit_unresolved(self, expr: UnresolvedExpr) -> Union[BlockColumn, ScalarType]:
    """Visit an unresolved expression.

    This method never returns a value: it unconditionally raises.

    Args:
        expr: The unresolved expression.

    Raises:
        TypeError: Always raised, because an UnresolvedExpr cannot be
            evaluated.
    """
    raise TypeError(
        "UnresolvedExpr cannot be evaluated. "
        "Resolve it to a concrete expression before evaluation."
    )

def visit_download(self, expr: DownloadExpr) -> Union[BlockColumn, ScalarType]:
"""Visit a download expression.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
StarExpr,
UDFExpr,
UnaryExpr,
UnresolvedExpr,
_CallableClassUDF,
_ExprVisitor,
)
Expand Down Expand Up @@ -78,6 +79,10 @@ def visit_download(self, expr: "Expr") -> None:
"""Visit a download expression (no columns to collect)."""
pass

def visit_unresolved(self, expr: UnresolvedExpr) -> None:
    """Do nothing for an unresolved expression (no columns to collect)."""
    return None


class _ColumnReferenceCollector(_ExprVisitorBase):
"""Visitor that collects all column references from expression trees.
Expand Down Expand Up @@ -106,6 +111,10 @@ def visit_column(self, expr: ColumnExpr) -> None:
"""
self._col_refs[expr.name] = None

def visit_unresolved(self, expr: UnresolvedExpr) -> None:
    """Record the unresolved expression's name among the collected references."""
    collected = self._col_refs
    # Only the key matters; the stored value is always None.
    collected[expr.name] = None

def visit_alias(self, expr: AliasExpr) -> None:
"""Visit an alias expression and collect from its inner expression.

Expand Down Expand Up @@ -197,6 +206,11 @@ def visit_literal(self, expr: LiteralExpr) -> Expr:
"""
return expr

def visit_unresolved(self, expr: UnresolvedExpr) -> Expr:
    """Return the substitution registered for the expression's name.

    When no substitution is registered under ``expr.name`` (or the registered
    value is None), the expression itself is returned unchanged.
    """
    replacement = self._col_ref_substitutions.get(expr.name)
    if replacement is None:
        return expr
    return replacement

def visit_binary(self, expr: BinaryExpr) -> Expr:
"""Visit a binary expression and rewrite its operands.

Expand Down Expand Up @@ -413,6 +427,9 @@ def visit_download(self, expr: "DownloadExpr") -> str:
def visit_star(self, expr: "StarExpr") -> str:
    """Render a star expression as a tree node labelled ``COL(*)``."""
    label = "COL(*)"
    return self._make_tree_lines(label, expr=expr)

def visit_unresolved(self, expr: "UnresolvedExpr") -> str:
    """Render an unresolved expression as a tree node labelled with its name."""
    render = self._make_tree_lines
    return render(f"UNRESOLVED({expr.name!r})", expr=expr)


class _InlineExprReprVisitor(_ExprVisitor[str]):
"""Visitor that generates concise inline string representations of expressions.
Expand Down Expand Up @@ -502,6 +519,10 @@ def visit_star(self, expr: "StarExpr") -> str:
"""Visit a star expression and return its inline representation."""
return "col(*)"

def visit_unresolved(self, expr: "UnresolvedExpr") -> str:
    """Return the inline form ``unresolved(<repr of name>)`` for the expression."""
    quoted_name = repr(expr.name)
    return f"unresolved({quoted_name})"


def get_column_references(expr: Expr) -> List[str]:
"""Extract all column references from an expression.
Expand Down
39 changes: 36 additions & 3 deletions python/ray/data/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def visit(self, expr: "Expr") -> T:
return self.visit_download(expr)
elif isinstance(expr, StarExpr):
return self.visit_star(expr)
elif isinstance(expr, UnresolvedExpr):
return self.visit_unresolved(expr)
else:
raise TypeError(f"Unsupported expression type for conversion: {type(expr)}")

Expand Down Expand Up @@ -142,6 +144,10 @@ def visit_star(self, expr: "StarExpr") -> T:
def visit_download(self, expr: "DownloadExpr") -> T:
pass

@abstractmethod
def visit_unresolved(self, expr: "UnresolvedExpr") -> T:
    """Abstract hook that ``visit`` dispatches to for ``UnresolvedExpr`` nodes."""


class _PyArrowExpressionVisitor(_ExprVisitor["pyarrow.compute.Expression"]):
"""Visitor that converts Ray Data expressions to PyArrow compute expressions."""
Expand Down Expand Up @@ -206,6 +212,11 @@ def visit_download(self, expr: "DownloadExpr") -> "pyarrow.compute.Expression":
def visit_star(self, expr: "StarExpr") -> "pyarrow.compute.Expression":
    """Reject star expressions during PyArrow conversion.

    Raises:
        TypeError: Always raised; ``expr`` is never inspected.
    """
    message = "Star expressions cannot be converted to PyArrow expressions"
    raise TypeError(message)

def visit_unresolved(self, expr: "UnresolvedExpr") -> "pyarrow.compute.Expression":
    """Reject unresolved expressions during PyArrow conversion.

    Raises:
        TypeError: Always raised; ``expr`` is never inspected.
    """
    message = "Unresolved expressions cannot be converted to PyArrow expressions"
    raise TypeError(message)


@DeveloperAPI(stability="alpha")
@dataclass(frozen=True)
Expand All @@ -232,7 +243,7 @@ class Expr(ABC):
subclasses like ColumnExpr, LiteralExpr, etc.
"""

data_type: DataType
data_type: DataType | None

@property
def name(self) -> str | None:
Expand Down Expand Up @@ -1360,6 +1371,28 @@ def structurally_equals(self, other: Any) -> bool:
)


@DeveloperAPI(stability="alpha")
@dataclass(frozen=True, eq=False, repr=False)
class UnresolvedExpr(Expr):
    """Placeholder expression for a column reference that is not yet resolved.

    An instance stands in for a column reference until it has been resolved
    against a concrete schema. It has to be resolved before it is evaluated
    or converted to another expression system.
    """

    # Name of the referenced column; read through the ``name`` property.
    _name: str
    # Not an ``__init__`` argument (init=False); defaults to None.
    data_type: DataType | None = field(default=None, init=False)

    @property
    def name(self) -> str:
        """Return the name of the unresolved column."""
        return self._name

    def structurally_equals(self, other: Any) -> bool:
        """Return True only for another ``UnresolvedExpr`` with the same name."""
        if not isinstance(other, UnresolvedExpr):
            return False
        return self.name == other.name


@DeveloperAPI(stability="alpha")
@dataclass(frozen=True, eq=False, repr=False)
class StarExpr(Expr):
Expand All @@ -1377,8 +1410,7 @@ class StarExpr(Expr):
This means: keep all existing columns, then add/overwrite "new_col"
"""

# TODO: Add UnresolvedExpr. Both StarExpr and UnresolvedExpr won't have a defined data_type.
data_type: DataType = field(default_factory=lambda: DataType(object), init=False)
data_type: DataType | None = field(default=None, init=False)

def structurally_equals(self, other: Any) -> bool:
return isinstance(other, StarExpr)
Expand Down Expand Up @@ -1513,6 +1545,7 @@ def download(
"UDFExpr",
"DownloadExpr",
"AliasExpr",
"UnresolvedExpr",
"StarExpr",
"pyarrow_udf",
"udf",
Expand Down
16 changes: 16 additions & 0 deletions python/ray/data/tests/unit/expressions/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
BinaryExpr,
Operation,
UDFExpr,
UnresolvedExpr,
col,
download,
lit,
Expand Down Expand Up @@ -303,6 +304,13 @@ def test_star_expression_raises(self):
with pytest.raises(TypeError, match="Star expressions cannot be converted"):
star().to_pyarrow()

def test_unresolved_expression_raises(self):
    """Converting an unresolved expression to PyArrow must raise TypeError."""
    expected_message = "Unresolved expressions cannot be converted"
    with pytest.raises(TypeError, match=expected_message):
        UnresolvedExpr("pending").to_pyarrow()


# ──────────────────────────────────────
# Iceberg Conversion Tests
Expand Down Expand Up @@ -538,6 +546,14 @@ def test_star_expression_raises(self):
):
visitor.visit(star())

def test_unresolved_expression_raises(self):
    """The Iceberg visitor must raise TypeError for unresolved expressions."""
    iceberg_visitor = _IcebergExpressionVisitor()
    expected_message = "Unresolved expressions cannot be converted to Iceberg"
    with pytest.raises(TypeError, match=expected_message):
        iceberg_visitor.visit(UnresolvedExpr("pending"))

def test_is_in_requires_literal_list(self):
"""Test that IN/NOT_IN operations require literal lists."""
visitor = _IcebergExpressionVisitor()
Expand Down
36 changes: 35 additions & 1 deletion python/ray/data/tests/unit/expressions/test_core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Tests for core expression types and basic functionality.

This module tests:
- ColumnExpr, LiteralExpr, BinaryExpr, UnaryExpr, AliasExpr, StarExpr
- ColumnExpr, LiteralExpr, BinaryExpr, UnaryExpr, AliasExpr, StarExpr, UnresolvedExpr
- Structural equality for all expression types
- Expression tree repr (string representation)
- UDFExpr structural equality
Expand All @@ -24,6 +24,7 @@
StarExpr,
UDFExpr,
UnaryExpr,
UnresolvedExpr,
col,
download,
lit,
Expand Down Expand Up @@ -296,6 +297,37 @@ def test_star_structural_equality(self):
assert not star().structurally_equals(col("a"))


# ──────────────────────────────────────
# Unresolved Expression Tests
# ──────────────────────────────────────


class TestUnresolvedExpr:
    """Checks construction and structural equality of UnresolvedExpr."""

    def test_unresolved_creation(self):
        """A new UnresolvedExpr exposes the name it was built with."""
        pending = UnresolvedExpr("pending")
        assert isinstance(pending, UnresolvedExpr)
        assert pending.name == "pending"

    @pytest.mark.parametrize(
        "name1,name2,expected",
        [
            ("a", "a", True),
            ("a", "b", False),
            ("column_name", "column_name", True),
            ("COL", "col", False),
        ],
        ids=["same_name", "different_name", "long_name", "case_sensitive"],
    )
    def test_unresolved_structural_equality(self, name1, name2, expected):
        """Two unresolved expressions are structurally equal iff names match."""
        left = UnresolvedExpr(name1)
        right = UnresolvedExpr(name2)
        assert left.structurally_equals(right) is expected


# ──────────────────────────────────────
# UDF Expression Tests
# ──────────────────────────────────────
Expand Down Expand Up @@ -393,13 +425,15 @@ class TestCrossTypeEquality:
(lit(1), lit(1) + 0),
(col("a"), col("a").alias("a")),
(col("a"), star()),
(col("a"), UnresolvedExpr("a")),
],
ids=[
"col_vs_lit",
"col_vs_binary",
"lit_vs_binary",
"col_vs_alias",
"col_vs_star",
"col_vs_unresolved",
],
)
def test_different_types_not_equal(self, expr1, expr2):
Expand Down