Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/ray/data/_internal/datasource/iceberg_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
DownloadExpr,
LiteralExpr,
Operation,
RenameExpr,
StarExpr,
UDFExpr,
UnaryExpr,
Expand Down Expand Up @@ -151,6 +152,12 @@ def visit_alias(
"""Convert an aliased expression (just unwrap the alias)."""
return self.visit(expr.expr)

def visit_rename(
    self, expr: "RenameExpr"
) -> "BooleanExpression | UnboundTerm[Any] | Literal[Any]":
    """Convert a rename expression by converting the expression it wraps.

    The rename target name plays no part in the conversion; only
    ``expr.expr`` is visited.

    Args:
        expr: The rename expression to convert.

    Returns:
        Whatever ``self.visit`` produces for the wrapped expression.
    """
    wrapped = expr.expr
    return self.visit(wrapped)

def visit_udf(
self, expr: "UDFExpr"
) -> "BooleanExpression | UnboundTerm[Any] | Literal[Any]":
Expand Down
18 changes: 7 additions & 11 deletions python/ray/data/_internal/logical/rules/predicate_pushdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,10 @@ def _can_push_filter_through_projection(
- Rename chains with name reuse: rename({'a': 'b', 'b': 'c'}).filter(col('b'))
(where 'b' is valid output created by a->b)
"""
from ray.data._internal.logical.rules.projection_pushdown import (
_is_renaming_expr,
)
from ray.data._internal.planner.plan_expression.expression_visitors import (
_ColumnReferenceCollector,
)
from ray.data.expressions import AliasExpr
from ray.data.expressions import AliasExpr, RenameExpr

collector = _ColumnReferenceCollector()
collector.visit(filter_op.predicate_expr)
Expand All @@ -105,18 +102,17 @@ def _can_push_filter_through_projection(
# Collect output column names
output_columns.add(expr.name)

# Process AliasExpr (computed columns or renames)
if isinstance(expr, AliasExpr):
# Process RenameExpr and AliasExpr (computed columns)
if isinstance(expr, RenameExpr):
new_names.add(expr.name)
original_columns_being_renamed.add(expr.expr.name)
elif isinstance(expr, AliasExpr):
new_names.add(expr.name)

# Check computed column: with_column('d', 4) creates AliasExpr(lit(4), 'd')
if expr.name in predicate_columns and not _is_renaming_expr(expr):
if expr.name in predicate_columns:
return False # Computed column

# Track old names being renamed for later check
if _is_renaming_expr(expr):
original_columns_being_renamed.add(expr.expr.name)

# Check if filter references columns removed by explicit select
# Valid if: projection includes all columns (star) OR predicate columns exist in output
has_required_columns = (
Expand Down
17 changes: 8 additions & 9 deletions python/ray/data/_internal/logical/rules/projection_pushdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@
from ray.data._internal.planner.plan_expression.expression_visitors import (
_ColumnReferenceCollector,
_ColumnSubstitutionVisitor,
_is_col_expr,
)
from ray.data.expressions import (
AliasExpr,
ColumnExpr,
Expr,
RenameExpr,
StarExpr,
)

Expand Down Expand Up @@ -52,7 +51,7 @@ def _analyze_upstream_project(
"""
Analyze what the upstream project produces and identifies removed columns.

Example: Upstream exprs [col("x").alias("y")] → removed_by_renames = {"x"} if "x" not in output
Example: Upstream exprs [col("x")._rename("y")] → removed_by_renames = {"x"} if "x" not in output
"""
output_column_names = {
expr.name for expr in upstream_project.exprs if not isinstance(expr, StarExpr)
Expand Down Expand Up @@ -338,7 +337,8 @@ def _push_projection_into_read_op(cls, op: LogicalOperator) -> LogicalOperator:
# Check if it's a simple projection that could be pushed into
# read as a whole
is_projection = all(
_is_col_expr(expr) for expr in _filter_out_star(current_project.exprs)
isinstance(expr, (ColumnExpr, RenameExpr))
for expr in _filter_out_star(current_project.exprs)
)

if is_projection:
Expand Down Expand Up @@ -432,16 +432,15 @@ def _extract_input_columns_renaming_mapping(
def _get_renaming_mapping(expr: Expr) -> Tuple[str, str]:
assert _is_renaming_expr(expr)

alias: AliasExpr = expr

return alias.expr.name, alias.name
rename: RenameExpr = expr
return rename.expr.name, rename.name


def _is_renaming_expr(expr: Expr) -> bool:
is_renaming = isinstance(expr, AliasExpr) and expr._is_rename
is_renaming = isinstance(expr, RenameExpr)

assert not is_renaming or isinstance(
expr.expr, ColumnExpr
), f"Renaming expression expected to be of the shape alias(col('source'), 'target') (got {expr})"
), f"Renaming expression expected to be of the shape rename(col('source'), 'target') (got {expr})"

return is_renaming
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Expr,
LiteralExpr,
Operation,
RenameExpr,
StarExpr,
UDFExpr,
UnaryExpr,
Expand Down Expand Up @@ -661,6 +662,17 @@ def visit_alias(self, expr: AliasExpr) -> Union[BlockColumn, ScalarType]:
# Evaluate the inner expression
return self.visit(expr.expr)

def visit_rename(self, expr: RenameExpr) -> Union[BlockColumn, ScalarType]:
    """Evaluate a rename expression by evaluating the expression it wraps.

    The new name is not applied here; the result is exactly what visiting
    the inner expression yields.

    Args:
        expr: The rename expression.

    Returns:
        The result of visiting ``expr.expr``.
    """
    inner = expr.expr
    return self.visit(inner)

def visit_star(self, expr: StarExpr) -> Union[BlockColumn, ScalarType]:
"""Visit a star expression.

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from dataclasses import replace
from typing import Dict, List, TypeVar

from ray.data.expressions import (
Expand All @@ -9,6 +8,7 @@
Expr,
LiteralExpr,
Operation,
RenameExpr,
StarExpr,
UDFExpr,
UnaryExpr,
Expand Down Expand Up @@ -59,6 +59,10 @@ def visit_alias(self, expr: "AliasExpr") -> None:
"""Default implementation: recursively visit the inner expression."""
super().visit(expr.expr)

def visit_rename(self, expr: "RenameExpr") -> None:
    """Default implementation: recurse into the renamed inner expression."""
    inner = expr.expr
    super().visit(inner)

def visit_udf(self, expr: "UDFExpr") -> None:
"""Default implementation: recursively visit all arguments."""
for arg in expr.args:
Expand Down Expand Up @@ -117,6 +121,17 @@ def visit_alias(self, expr: AliasExpr) -> None:
"""
self.visit(expr.expr)

def visit_rename(self, expr: RenameExpr) -> None:
    """Collect from the expression wrapped by a rename.

    The rename node itself contributes nothing; collection happens while
    visiting ``expr.expr``.

    Args:
        expr: The rename expression.

    Returns:
        None (collection happens purely as a side effect).
    """
    inner = expr.expr
    self.visit(inner)


class _CallableClassUDFCollector(_ExprVisitorBase):
"""Visitor that collects all callable class UDFs from expression trees.
Expand Down Expand Up @@ -157,6 +172,10 @@ def visit_udf(self, expr: UDFExpr) -> None:
# Continue visiting child expressions
super().visit_udf(expr)

def visit_rename(self, expr: RenameExpr) -> None:
    """Visit a rename expression by deferring to the parent implementation.

    Nothing is collected from the rename node itself.
    """
    parent = super()
    parent.visit_rename(expr)


class _ColumnSubstitutionVisitor(_ExprVisitor[Expr]):
"""Visitor rebinding column references in ``Expression``s.
Expand Down Expand Up @@ -252,17 +271,24 @@ def visit_alias(self, expr: AliasExpr) -> Expr:
"""
# We unalias returned expression to avoid nested aliasing
visited = self.visit(expr.expr)._unalias()
# NOTE: We're carrying over all of the other aspects of the alias
# only replacing inner expre
return replace(
expr,
expr=visited,
# Alias expression will remain a renaming one (ie replacing source column)
# so long as it's referencing another column (and not otherwise)
#
# TODO replace w/ standalone rename expr
_is_rename=expr._is_rename and _is_col_expr(visited),
)
return AliasExpr(data_type=visited.data_type, expr=visited, _name=expr.name)

def visit_rename(self, expr: RenameExpr) -> Expr:
    """Rewrite the expression wrapped by a rename and re-wrap the result.

    Args:
        expr: The rename expression.

    Returns:
        A ``RenameExpr`` when the rewritten inner expression is still a
        ``ColumnExpr``; otherwise an ``AliasExpr`` carrying the same name.
    """
    # Strip any alias from the rewritten inner expression so the result is
    # never a name wrapped around another name.
    rewritten = self.visit(expr.expr)._unalias()
    # A rename only stays a rename while it points straight at a column.
    node_cls = RenameExpr if isinstance(rewritten, ColumnExpr) else AliasExpr
    return node_cls(data_type=rewritten.data_type, expr=rewritten, _name=expr.name)

def visit_download(self, expr: "Expr") -> Expr:
"""Visit a download expression (no rewriting needed).
Expand All @@ -287,12 +313,6 @@ def visit_star(self, expr: StarExpr) -> Expr:
return expr


def _is_col_expr(expr: Expr) -> bool:
    """Return True for a column reference or an alias directly over one."""
    if isinstance(expr, ColumnExpr):
        return True
    return isinstance(expr, AliasExpr) and isinstance(expr.expr, ColumnExpr)


class _TreeReprVisitor(_ExprVisitor[str]):
"""Visitor that generates a readable tree representation of expressions. Returns in pre-order traversal."""

Expand Down Expand Up @@ -381,9 +401,15 @@ def visit_unary(self, expr: "UnaryExpr") -> str:
)

def visit_alias(self, expr: "AliasExpr") -> str:
rename_marker = " [rename]" if expr._is_rename else ""
return self._make_tree_lines(
f"ALIAS({expr.name!r}){rename_marker}",
f"ALIAS({expr.name!r})",
children=[("", expr.expr)],
expr=expr,
)

def visit_rename(self, expr: "RenameExpr") -> str:
    """Render a rename node as ``RENAME(<name>)`` with its inner expression as the only child."""
    label = f"RENAME({expr.name!r})"
    only_child = ("", expr.expr)
    return self._make_tree_lines(label, children=[only_child], expr=expr)
Expand Down Expand Up @@ -478,6 +504,11 @@ def visit_alias(self, expr: "AliasExpr") -> str:
inner_str = self.visit(expr.expr)
return f"{inner_str}.alias({expr.name!r})"

def visit_rename(self, expr: "RenameExpr") -> str:
    """Render a rename inline as the inner expression followed by ``.rename(<name>)``."""
    return f"{self.visit(expr.expr)}.rename({expr.name!r})"

def visit_udf(self, expr: "UDFExpr") -> str:
"""Visit a UDF expression and return its inline representation."""
# Get function name for better readability
Expand Down
48 changes: 38 additions & 10 deletions python/ray/data/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def visit(self, expr: "Expr") -> T:
return self.visit_binary(expr)
elif isinstance(expr, UnaryExpr):
return self.visit_unary(expr)
elif isinstance(expr, RenameExpr):
return self.visit_rename(expr)
elif isinstance(expr, AliasExpr):
return self.visit_alias(expr)
elif isinstance(expr, UDFExpr):
Expand Down Expand Up @@ -126,6 +128,10 @@ def visit_binary(self, expr: "BinaryExpr") -> T:
def visit_unary(self, expr: "UnaryExpr") -> T:
pass

@abstractmethod
def visit_rename(self, expr: "RenameExpr") -> T:
    """Handle a ``RenameExpr`` node; concrete visitors must implement this."""

@abstractmethod
def visit_alias(self, expr: "AliasExpr") -> T:
    """Handle an ``AliasExpr`` node; concrete visitors must implement this."""
Expand Down Expand Up @@ -192,6 +198,9 @@ def visit_unary(self, expr: "UnaryExpr") -> "pyarrow.compute.Expression":
return _ARROW_EXPR_OPS_MAP[expr.op](operand)
raise ValueError(f"Unsupported unary operation for PyArrow: {expr.op}")

def visit_rename(self, expr: "RenameExpr") -> "pyarrow.compute.Expression":
    """Convert a rename by converting the expression it wraps; the new name is ignored."""
    wrapped = expr.expr
    return self.visit(wrapped)

def visit_alias(self, expr: "AliasExpr") -> "pyarrow.compute.Expression":
    """Convert an alias by converting the expression it wraps; the alias name is ignored."""
    wrapped = expr.expr
    return self.visit(wrapped)

Expand Down Expand Up @@ -239,7 +248,7 @@ def name(self) -> str | None:
"""Get the name associated with this expression.

Returns:
The name for expressions that have one (ColumnExpr, AliasExpr),
The name for expressions that have one (ColumnExpr, AliasExpr, RenameExpr),
None otherwise.
"""
return None
Expand Down Expand Up @@ -431,9 +440,7 @@ def alias(self, name: str) -> "Expr":
>>> expr = (col("price") * col("quantity")).alias("total")
>>> # Can be used with Dataset operations that support named expressions
"""
return AliasExpr(
data_type=self.data_type, expr=self, _name=name, _is_rename=False
)
return AliasExpr(data_type=self.data_type, expr=self, _name=name)

# rounding helpers
def ceil(self) -> "UDFExpr":
Expand Down Expand Up @@ -693,7 +700,7 @@ def name(self) -> str:
return self._name

def _rename(self, name: str):
return AliasExpr(self.data_type, self, name, _is_rename=True)
return RenameExpr(self.data_type, self, name)

def structurally_equals(self, other: Any) -> bool:
    """Return True when ``other`` is a ``ColumnExpr`` with the same name."""
    if not isinstance(other, ColumnExpr):
        return False
    return self.name == other.name
Expand Down Expand Up @@ -1335,7 +1342,6 @@ class AliasExpr(Expr):

expr: Expr
_name: str
_is_rename: bool

@property
def name(self) -> str:
Expand All @@ -1344,9 +1350,7 @@ def name(self) -> str:

def alias(self, name: str) -> "Expr":
# Always unalias before creating new one
return AliasExpr(
self.expr.data_type, self.expr, _name=name, _is_rename=self._is_rename
)
return AliasExpr(self.expr.data_type, self.expr, _name=name)

def _unalias(self) -> "Expr":
    """Return the expression wrapped by this alias."""
    wrapped = self.expr
    return wrapped
Expand All @@ -1356,7 +1360,30 @@ def structurally_equals(self, other: Any) -> bool:
isinstance(other, AliasExpr)
and self.expr.structurally_equals(other.expr)
and self.name == other.name
and self._is_rename == other._is_rename
)


@DeveloperAPI(stability="alpha")
@dataclass(frozen=True, eq=False, repr=False)
class RenameExpr(Expr):
    """Expression node that gives the wrapped expression a new name.

    This is a node type of its own rather than a flag on ``AliasExpr``.
    NOTE(review): the wrapped expression is presumably always a column
    reference -- confirm against the code that constructs this node.
    """

    # Expression being renamed.
    expr: Expr
    # Name reported by the ``name`` property.
    _name: str

    @property
    def name(self) -> str:
        """Return the new column name."""
        return self._name

    def _unalias(self) -> "Expr":
        """Return the wrapped expression, dropping the rename."""
        return self.expr

    def structurally_equals(self, other: Any) -> bool:
        """Return True when ``other`` is a ``RenameExpr`` with a structurally equal inner expression and the same name."""
        if not isinstance(other, RenameExpr):
            return False
        return self.expr.structurally_equals(other.expr) and self.name == other.name


Expand Down Expand Up @@ -1513,6 +1540,7 @@ def download(
"UDFExpr",
"DownloadExpr",
"AliasExpr",
"RenameExpr",
"StarExpr",
"pyarrow_udf",
"udf",
Expand Down
Loading