Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/ray/data/_internal/datasource/iceberg_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
DownloadExpr,
LiteralExpr,
Operation,
RenameExpr,
StarExpr,
UDFExpr,
UnaryExpr,
Expand Down Expand Up @@ -151,6 +152,12 @@ def visit_alias(
"""Convert an aliased expression (just unwrap the alias)."""
return self.visit(expr.expr)

def visit_rename(
self, expr: "RenameExpr"
) -> "BooleanExpression | UnboundTerm[Any] | Literal[Any]":
"""Convert a rename expression (just unwrap the rename)."""
return self.visit(expr.expr)

def visit_udf(
self, expr: "UDFExpr"
) -> "BooleanExpression | UnboundTerm[Any] | Literal[Any]":
Expand Down
18 changes: 7 additions & 11 deletions python/ray/data/_internal/logical/rules/predicate_pushdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,10 @@ def _can_push_filter_through_projection(
- Rename chains with name reuse: rename({'a': 'b', 'b': 'c'}).filter(col('b'))
(where 'b' is valid output created by a->b)
"""
from ray.data._internal.logical.rules.projection_pushdown import (
_is_renaming_expr,
)
from ray.data._internal.planner.plan_expression.expression_visitors import (
_ColumnReferenceCollector,
)
from ray.data.expressions import AliasExpr
from ray.data.expressions import AliasExpr, RenameExpr

collector = _ColumnReferenceCollector()
collector.visit(filter_op.predicate_expr)
Expand All @@ -105,18 +102,17 @@ def _can_push_filter_through_projection(
# Collect output column names
output_columns.add(expr.name)

# Process AliasExpr (computed columns or renames)
if isinstance(expr, AliasExpr):
# Process RenameExpr and AliasExpr (computed columns)
if isinstance(expr, RenameExpr):
new_names.add(expr.name)
original_columns_being_renamed.add(expr.expr.name)
elif isinstance(expr, AliasExpr):
new_names.add(expr.name)

# Check computed column: with_column('d', 4) creates AliasExpr(lit(4), 'd')
if expr.name in predicate_columns and not _is_renaming_expr(expr):
if expr.name in predicate_columns:
return False # Computed column

# Track old names being renamed for later check
if _is_renaming_expr(expr):
original_columns_being_renamed.add(expr.expr.name)

# Check if filter references columns removed by explicit select
# Valid if: projection includes all columns (star) OR predicate columns exist in output
has_required_columns = (
Expand Down
17 changes: 8 additions & 9 deletions python/ray/data/_internal/logical/rules/projection_pushdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@
from ray.data._internal.planner.plan_expression.expression_visitors import (
_ColumnReferenceCollector,
_ColumnSubstitutionVisitor,
_is_col_expr,
)
from ray.data.expressions import (
AliasExpr,
ColumnExpr,
Expr,
RenameExpr,
StarExpr,
)

Expand Down Expand Up @@ -52,7 +51,7 @@ def _analyze_upstream_project(
"""
Analyze what the upstream project produces and identifies removed columns.

Example: Upstream exprs [col("x").alias("y")] → removed_by_renames = {"x"} if "x" not in output
Example: Upstream exprs [col("x")._rename("y")] → removed_by_renames = {"x"} if "x" not in output
"""
output_column_names = {
expr.name for expr in upstream_project.exprs if not isinstance(expr, StarExpr)
Expand Down Expand Up @@ -338,7 +337,8 @@ def _push_projection_into_read_op(cls, op: LogicalOperator) -> LogicalOperator:
# Check if it's a simple projection that could be pushed into
# read as a whole
is_projection = all(
_is_col_expr(expr) for expr in _filter_out_star(current_project.exprs)
isinstance(expr, (ColumnExpr, RenameExpr))
for expr in _filter_out_star(current_project.exprs)
)

if is_projection:
Expand Down Expand Up @@ -432,16 +432,15 @@ def _extract_input_columns_renaming_mapping(
def _get_renaming_mapping(expr: Expr) -> Tuple[str, str]:
assert _is_renaming_expr(expr)

alias: AliasExpr = expr

return alias.expr.name, alias.name
rename: RenameExpr = expr
return rename.expr.name, rename.name


def _is_renaming_expr(expr: Expr) -> bool:
is_renaming = isinstance(expr, AliasExpr) and expr._is_rename
is_renaming = isinstance(expr, RenameExpr)

assert not is_renaming or isinstance(
expr.expr, ColumnExpr
), f"Renaming expression expected to be of the shape alias(col('source'), 'target') (got {expr})"
), f"Renaming expression expected to be of the shape rename(col('source'), 'target') (got {expr})"

return is_renaming
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Expr,
LiteralExpr,
Operation,
RenameExpr,
StarExpr,
UDFExpr,
UnaryExpr,
Expand Down Expand Up @@ -661,6 +662,17 @@ def visit_alias(self, expr: AliasExpr) -> Union[BlockColumn, ScalarType]:
# Evaluate the inner expression
return self.visit(expr.expr)

def visit_rename(self, expr: RenameExpr) -> Union[BlockColumn, ScalarType]:
"""Visit a rename expression and return the renamed result.

Args:
expr: The rename expression.

Returns:
A Block with the data from the inner expression.
"""
return self.visit(expr.expr)

def visit_star(self, expr: StarExpr) -> Union[BlockColumn, ScalarType]:
"""Visit a star expression.

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from dataclasses import replace
from typing import Dict, List, TypeVar

from ray.data.expressions import (
Expand All @@ -9,6 +8,7 @@
Expr,
LiteralExpr,
Operation,
RenameExpr,
StarExpr,
UDFExpr,
UnaryExpr,
Expand Down Expand Up @@ -59,6 +59,10 @@ def visit_alias(self, expr: "AliasExpr") -> None:
"""Default implementation: recursively visit the inner expression."""
super().visit(expr.expr)

def visit_rename(self, expr: "RenameExpr") -> None:
"""Default implementation: recursively visit the inner expression."""
super().visit(expr.expr)

def visit_udf(self, expr: "UDFExpr") -> None:
"""Default implementation: recursively visit all arguments."""
for arg in expr.args:
Expand Down Expand Up @@ -117,6 +121,17 @@ def visit_alias(self, expr: AliasExpr) -> None:
"""
self.visit(expr.expr)

def visit_rename(self, expr: RenameExpr) -> None:
"""Visit a rename expression and collect from its inner expression.

Args:
expr: The rename expression.

Returns:
None (only collects columns as a side effect).
"""
self.visit(expr.expr)


class _CallableClassUDFCollector(_ExprVisitorBase):
"""Visitor that collects all callable class UDFs from expression trees.
Expand Down Expand Up @@ -157,6 +172,10 @@ def visit_udf(self, expr: UDFExpr) -> None:
# Continue visiting child expressions
super().visit_udf(expr)

def visit_rename(self, expr: RenameExpr) -> None:
"""Visit a rename expression (no UDFs to collect)."""
super().visit_rename(expr)


class _ColumnSubstitutionVisitor(_ExprVisitor[Expr]):
"""Visitor rebinding column references in ``Expression``s.
Expand Down Expand Up @@ -252,17 +271,23 @@ def visit_alias(self, expr: AliasExpr) -> Expr:
"""
# We unalias returned expression to avoid nested aliasing
visited = self.visit(expr.expr)._unalias()
# NOTE: We're carrying over all of the other aspects of the alias
# only replacing inner expre
return replace(
expr,
expr=visited,
# Alias expression will remain a renaming one (ie replacing source column)
# so long as it's referencing another column (and not otherwise)
#
# TODO replace w/ standalone rename expr
_is_rename=expr._is_rename and _is_col_expr(visited),
)
return AliasExpr(data_type=visited.data_type, expr=visited, _name=expr.name)

def visit_rename(self, expr: RenameExpr) -> Expr:
"""Visit a rename expression and rewrite its inner expression.

Args:
expr: The rename expression.

Returns:
A rename expression when it still targets a column, otherwise an alias.
"""
visited = self.visit(expr.expr)
if isinstance(visited, ColumnExpr):
return RenameExpr(
data_type=visited.data_type, expr=visited, _name=expr.name
)
return AliasExpr(data_type=visited.data_type, expr=visited, _name=expr.name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with visit_alias, consider calling _unalias() on the visited expression. This would prevent creating nested aliases if the substitution itself results in an AliasExpr and would correctly preserve the RenameExpr type when possible.

For example, if col('a')._rename('b') is visited with a substitution {'a': col('c').alias('d')}:

  • The current implementation produces a nested alias: (col('c').alias('d')).alias('b').
  • With _unalias(), it would become col('c')._rename('b'), which seems more correct as the intermediate alias d is being renamed to b anyway.
Suggested change
visited = self.visit(expr.expr)
if isinstance(visited, ColumnExpr):
return RenameExpr(
data_type=visited.data_type, expr=visited, _name=expr.name
)
return AliasExpr(data_type=visited.data_type, expr=visited, _name=expr.name)
visited = self.visit(expr.expr)._unalias()
if isinstance(visited, ColumnExpr):
return RenameExpr(
data_type=visited.data_type, expr=visited, _name=expr.name
)
return AliasExpr(data_type=visited.data_type, expr=visited, _name=expr.name)


def visit_download(self, expr: "Expr") -> Expr:
"""Visit a download expression (no rewriting needed).
Expand All @@ -287,12 +312,6 @@ def visit_star(self, expr: StarExpr) -> Expr:
return expr


def _is_col_expr(expr: Expr) -> bool:
return isinstance(expr, ColumnExpr) or (
isinstance(expr, AliasExpr) and isinstance(expr.expr, ColumnExpr)
)


class _TreeReprVisitor(_ExprVisitor[str]):
"""Visitor that generates a readable tree representation of expressions. Returns in pre-order traversal."""

Expand Down Expand Up @@ -381,9 +400,15 @@ def visit_unary(self, expr: "UnaryExpr") -> str:
)

def visit_alias(self, expr: "AliasExpr") -> str:
rename_marker = " [rename]" if expr._is_rename else ""
return self._make_tree_lines(
f"ALIAS({expr.name!r}){rename_marker}",
f"ALIAS({expr.name!r})",
children=[("", expr.expr)],
expr=expr,
)

def visit_rename(self, expr: "RenameExpr") -> str:
return self._make_tree_lines(
f"RENAME({expr.name!r})",
children=[("", expr.expr)],
expr=expr,
)
Expand Down Expand Up @@ -478,6 +503,11 @@ def visit_alias(self, expr: "AliasExpr") -> str:
inner_str = self.visit(expr.expr)
return f"{inner_str}.alias({expr.name!r})"

def visit_rename(self, expr: "RenameExpr") -> str:
"""Visit a rename expression and return its inline representation."""
inner_str = self.visit(expr.expr)
return f"{inner_str}.rename({expr.name!r})"

def visit_udf(self, expr: "UDFExpr") -> str:
"""Visit a UDF expression and return its inline representation."""
# Get function name for better readability
Expand Down
45 changes: 35 additions & 10 deletions python/ray/data/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def visit(self, expr: "Expr") -> T:
return self.visit_binary(expr)
elif isinstance(expr, UnaryExpr):
return self.visit_unary(expr)
elif isinstance(expr, RenameExpr):
return self.visit_rename(expr)
elif isinstance(expr, AliasExpr):
return self.visit_alias(expr)
elif isinstance(expr, UDFExpr):
Expand Down Expand Up @@ -126,6 +128,10 @@ def visit_binary(self, expr: "BinaryExpr") -> T:
def visit_unary(self, expr: "UnaryExpr") -> T:
pass

@abstractmethod
def visit_rename(self, expr: "RenameExpr") -> T:
pass

@abstractmethod
def visit_alias(self, expr: "AliasExpr") -> T:
pass
Expand Down Expand Up @@ -192,6 +198,9 @@ def visit_unary(self, expr: "UnaryExpr") -> "pyarrow.compute.Expression":
return _ARROW_EXPR_OPS_MAP[expr.op](operand)
raise ValueError(f"Unsupported unary operation for PyArrow: {expr.op}")

def visit_rename(self, expr: "RenameExpr") -> "pyarrow.compute.Expression":
return self.visit(expr.expr)

def visit_alias(self, expr: "AliasExpr") -> "pyarrow.compute.Expression":
return self.visit(expr.expr)

Expand Down Expand Up @@ -239,7 +248,7 @@ def name(self) -> str | None:
"""Get the name associated with this expression.

Returns:
The name for expressions that have one (ColumnExpr, AliasExpr),
The name for expressions that have one (ColumnExpr, AliasExpr, RenameExpr),
None otherwise.
"""
return None
Expand Down Expand Up @@ -431,9 +440,7 @@ def alias(self, name: str) -> "Expr":
>>> expr = (col("price") * col("quantity")).alias("total")
>>> # Can be used with Dataset operations that support named expressions
"""
return AliasExpr(
data_type=self.data_type, expr=self, _name=name, _is_rename=False
)
return AliasExpr(data_type=self.data_type, expr=self, _name=name)

# rounding helpers
def ceil(self) -> "UDFExpr":
Expand Down Expand Up @@ -693,7 +700,7 @@ def name(self) -> str:
return self._name

def _rename(self, name: str):
return AliasExpr(self.data_type, self, name, _is_rename=True)
return RenameExpr(self.data_type, self, name)

def structurally_equals(self, other: Any) -> bool:
return isinstance(other, ColumnExpr) and self.name == other.name
Expand Down Expand Up @@ -1335,7 +1342,6 @@ class AliasExpr(Expr):

expr: Expr
_name: str
_is_rename: bool

@property
def name(self) -> str:
Expand All @@ -1344,9 +1350,7 @@ def name(self) -> str:

def alias(self, name: str) -> "Expr":
# Always unalias before creating new one
return AliasExpr(
self.expr.data_type, self.expr, _name=name, _is_rename=self._is_rename
)
return AliasExpr(self.expr.data_type, self.expr, _name=name)

def _unalias(self) -> "Expr":
return self.expr
Expand All @@ -1356,7 +1360,27 @@ def structurally_equals(self, other: Any) -> bool:
isinstance(other, AliasExpr)
and self.expr.structurally_equals(other.expr)
and self.name == other.name
and self._is_rename == other._is_rename
)


@DeveloperAPI(stability="alpha")
@dataclass(frozen=True, eq=False, repr=False)
class RenameExpr(Expr):
"""Expression that represents renaming a column."""

expr: Expr
_name: str

@property
def name(self) -> str:
"""Get the renamed column name."""
return self._name

def structurally_equals(self, other: Any) -> bool:
return (
isinstance(other, RenameExpr)
and self.expr.structurally_equals(other.expr)
and self.name == other.name
)


Expand Down Expand Up @@ -1513,6 +1537,7 @@ def download(
"UDFExpr",
"DownloadExpr",
"AliasExpr",
"RenameExpr",
"StarExpr",
"pyarrow_udf",
"udf",
Expand Down
Loading