Skip to content

feat: add builders for switch, cast expressions #78

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions examples/pyarrow_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Install pyarrow before running this example
# /// script
# dependencies = [
# "pyarrow==20.0.0",
# "substrait[extensions] @ file:///${PROJECT_ROOT}/"
# ]
# ///
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.substrait as pa_substrait
import substrait
from substrait.builders.plan import project, read_named_table

# Arrow schema of the input table: two int32 columns named "x" and "y".
arrow_schema = pa.schema([pa.field(name, pa.int32()) for name in ("x", "y")])

# Serialize the schema with pyarrow, convert the result to its pysubstrait
# form and keep only the `base_schema` part.
serialized_schema = pa_substrait.serialize_schema(arrow_schema)
substrait_schema = serialized_schema.to_pysubstrait().base_schema

# Let pyarrow serialize the expression `x + y` under the output name "total".
total = pc.field("x") + pc.field("y")
substrait_expr = pa_substrait.serialize_expressions(
    exprs=[total],
    names=["total"],
    schema=arrow_schema,
)

# Parse the serialized bytes into a substrait-python ExtendedExpression message.
pysubstrait_expr = substrait.proto.ExtendedExpression.FromString(bytes(substrait_expr))

# Read the named table "example", project the expression on top of it, then
# call the returned builder with None and print whatever it produces.
plan_builder = project(
    read_named_table("example", substrait_schema),
    expressions=[pysubstrait_expr],
)
table = plan_builder(None)
print(table)
198 changes: 186 additions & 12 deletions src/substrait/builders/extended_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Callable, Any, Union, Iterable

UnboundExtendedExpression = Callable[[stp.NamedStruct, ExtensionRegistry], stee.ExtendedExpression]
ExtendedExpressionOrUnbound = Union[stee.ExtendedExpression, UnboundExtendedExpression]

def _alias_or_inferred(
alias: Union[Iterable[str], str],
Expand All @@ -21,6 +22,13 @@ def _alias_or_inferred(
else:
return [f'{op}({",".join(args)})']

def resolve_expression(
    expression: ExtendedExpressionOrUnbound,
    base_schema: stp.NamedStruct,
    registry: ExtensionRegistry
) -> stee.ExtendedExpression:
    """Return `expression` as a bound ExtendedExpression.

    An argument that already is an ExtendedExpression is returned unchanged;
    anything else is treated as an unbound resolver and called with
    `base_schema` and `registry`.
    """
    if isinstance(expression, stee.ExtendedExpression):
        return expression
    return expression(base_schema, registry)

def literal(value: Any, type: stp.Type, alias: Union[Iterable[str], str] = None) -> UnboundExtendedExpression:
"""Builds a resolver for ExtendedExpression containing a literal expression"""
def resolve(base_schema: stp.NamedStruct, registry: ExtensionRegistry) -> stee.ExtendedExpression:
Expand Down Expand Up @@ -139,14 +147,14 @@ def resolve(
return resolve

def scalar_function(
uri: str, function: str, *expressions: UnboundExtendedExpression, alias: Union[Iterable[str], str] = None
uri: str, function: str, expressions: Iterable[ExtendedExpressionOrUnbound], alias: Union[Iterable[str], str] = None
):
"""Builds a resolver for ExtendedExpression containing a ScalarFunction expression"""
def resolve(
base_schema: stp.NamedStruct, registry: ExtensionRegistry
) -> stee.ExtendedExpression:
bound_expressions: Iterable[stee.ExtendedExpression] = [
e(base_schema, registry) for e in expressions
bound_expressions = [
resolve_expression(e, base_schema, registry) for e in expressions
]

expression_schemas = [
Expand Down Expand Up @@ -210,14 +218,14 @@ def resolve(
return resolve

def aggregate_function(
uri: str, function: str, *expressions: UnboundExtendedExpression, alias: Union[Iterable[str], str] = None
uri: str, function: str, expressions: Iterable[ExtendedExpressionOrUnbound], alias: Union[Iterable[str], str] = None
):
"""Builds a resolver for ExtendedExpression containing a AggregateFunction measure"""
def resolve(
base_schema: stp.NamedStruct, registry: ExtensionRegistry
) -> stee.ExtendedExpression:
bound_expressions: Iterable[stee.ExtendedExpression] = [
e(base_schema, registry) for e in expressions
resolve_expression(e, base_schema, registry) for e in expressions
]

expression_schemas = [
Expand Down Expand Up @@ -281,19 +289,19 @@ def resolve(
def window_function(
uri: str,
function: str,
*expressions: UnboundExtendedExpression,
partitions: Iterable[UnboundExtendedExpression] = [],
expressions: Iterable[ExtendedExpressionOrUnbound],
partitions: Iterable[ExtendedExpressionOrUnbound] = [],
alias: Union[Iterable[str], str] = None
):
"""Builds a resolver for ExtendedExpression containing a WindowFunction expression"""
def resolve(
base_schema: stp.NamedStruct, registry: ExtensionRegistry
) -> stee.ExtendedExpression:
bound_expressions: Iterable[stee.ExtendedExpression] = [
e(base_schema, registry) for e in expressions
resolve_expression(e, base_schema, registry) for e in expressions
]

bound_partitions = [e(base_schema, registry) for e in partitions]
bound_partitions = [resolve_expression(e, base_schema, registry) for e in partitions]

expression_schemas = [
infer_extended_expression_schema(b) for b in bound_expressions
Expand Down Expand Up @@ -363,17 +371,17 @@ def resolve(
return resolve


def if_then(ifs: Iterable[tuple[UnboundExtendedExpression, UnboundExtendedExpression]], _else: UnboundExtendedExpression, alias: Union[Iterable[str], str] = None):
def if_then(ifs: Iterable[tuple[ExtendedExpressionOrUnbound, ExtendedExpressionOrUnbound]], _else: ExtendedExpressionOrUnbound, alias: Union[Iterable[str], str] = None):
"""Builds a resolver for ExtendedExpression containing an IfThen expression"""
def resolve(
base_schema: stp.NamedStruct, registry: ExtensionRegistry
) -> stee.ExtendedExpression:
bound_ifs = [
(if_clause[0](base_schema, registry), if_clause[1](base_schema, registry))
(resolve_expression(if_clause[0], base_schema, registry), resolve_expression(if_clause[1], base_schema, registry))
for if_clause in ifs
]

bound_else = _else(base_schema, registry)
bound_else = resolve_expression(_else, base_schema, registry)

extension_uris = merge_extension_uris(
*[b[0].extension_uris for b in bound_ifs],
Expand Down Expand Up @@ -413,3 +421,169 @@ def resolve(
)

return resolve

def switch(match: ExtendedExpressionOrUnbound,
           ifs: Iterable[tuple[ExtendedExpressionOrUnbound, ExtendedExpressionOrUnbound]],
           _else: ExtendedExpressionOrUnbound):
    """Builds a resolver for ExtendedExpression containing a switch expression

    Args:
        match: expression (bound or unbound) whose value is switched on.
        ifs: `(if, then)` pairs. Only the `literal` field of each resolved
            `if` expression is used; the `then` expression is used as is.
        _else: expression (bound or unbound) used for the `else` branch.

    Returns:
        A resolver taking `(base_schema, registry)` and returning the
        ExtendedExpression.
    """
    def resolve(
        base_schema: stp.NamedStruct, registry: ExtensionRegistry
    ) -> stee.ExtendedExpression:
        bound_match = resolve_expression(match, base_schema, registry)
        bound_ifs = [
            (
                resolve_expression(a, base_schema, registry),
                resolve_expression(b, base_schema, registry)
            ) for a, b in ifs]
        bound_else = resolve_expression(_else, base_schema, registry)

        # Merge the extension references of every operand, including the `if`
        # side of each pair, so nothing referenced by the output is dropped.
        operands = [
            bound_match,
            *[e for pair in bound_ifs for e in pair],
            bound_else,
        ]

        extension_uris = merge_extension_uris(
            *[o.extension_uris for o in operands]
        )

        extensions = merge_extension_declarations(
            *[o.extensions for o in operands]
        )

        return stee.ExtendedExpression(
            referred_expr=[
                stee.ExpressionReference(
                    expression=stalg.Expression(
                        switch_expression=stalg.Expression.SwitchExpression(
                            match=bound_match.referred_expr[0].expression,
                            ifs=[
                                # `if` is a Python keyword, so the fields are
                                # passed through a dict.
                                # NOTE(review): only `.literal` of the `if`
                                # expression is read - presumably callers must
                                # pass literal expressions here; confirm.
                                stalg.Expression.SwitchExpression.IfValue(**{
                                    'if': i.referred_expr[0].expression.literal,
                                    'then': t.referred_expr[0].expression
                                })
                                for i, t in bound_ifs
                            ],
                            # `else` is a Python keyword as well.
                            **{
                                'else': bound_else.referred_expr[0].expression
                            }
                        )
                    ),
                    output_names=['switch']  # TODO construct name from inputs
                )
            ],
            base_schema=base_schema,
            extension_uris=extension_uris,
            extensions=extensions,
        )

    return resolve

def singular_or_list(value: ExtendedExpressionOrUnbound, options: Iterable[ExtendedExpressionOrUnbound]):
    """Builds a resolver for ExtendedExpression containing a SingularOrList expression

    `value` and every entry of `options` may be bound or unbound; each is
    resolved against the schema and registry handed to the resolver.
    """
    def resolve(
        base_schema: stp.NamedStruct, registry: ExtensionRegistry
    ) -> stee.ExtendedExpression:
        resolved_value = resolve_expression(value, base_schema, registry)
        resolved_options = [
            resolve_expression(option, base_schema, registry) for option in options
        ]

        # Extension references are merged value first, then options in order.
        operands = [resolved_value, *resolved_options]
        merged_uris = merge_extension_uris(
            *[operand.extension_uris for operand in operands]
        )
        merged_declarations = merge_extension_declarations(
            *[operand.extensions for operand in operands]
        )

        or_list = stalg.Expression.SingularOrList(
            value=resolved_value.referred_expr[0].expression,
            options=[
                option.referred_expr[0].expression for option in resolved_options
            ],
        )
        reference = stee.ExpressionReference(
            expression=stalg.Expression(singular_or_list=or_list),
            output_names=['singular_or_list'],  # TODO construct name from inputs
        )

        return stee.ExtendedExpression(
            referred_expr=[reference],
            base_schema=base_schema,
            extension_uris=merged_uris,
            extensions=merged_declarations,
        )

    return resolve

def multi_or_list(value: Iterable[ExtendedExpressionOrUnbound], options: Iterable[Iterable[ExtendedExpressionOrUnbound]]):
    """Builds a resolver for ExtendedExpression containing a MultiOrList expression

    Args:
        value: expressions (bound or unbound) forming the `value` list.
        options: one iterable of expressions per option; each becomes one
            `MultiOrList.Record`.

    Returns:
        A resolver taking `(base_schema, registry)` and returning the
        ExtendedExpression.
    """
    def resolve(
        base_schema: stp.NamedStruct, registry: ExtensionRegistry
    ) -> stee.ExtendedExpression:
        bound_value = [resolve_expression(e, base_schema, registry) for e in value]
        bound_options = [
            [resolve_expression(e, base_schema, registry) for e in o] for o in options
        ]

        extension_uris = merge_extension_uris(
            *[b.extension_uris for b in bound_value],
            *[e.extension_uris for b in bound_options for e in b],
        )

        # Extension declarations need the declaration merger; the URI merger
        # is only meant for `extension_uris`.
        extensions = merge_extension_declarations(
            *[b.extensions for b in bound_value],
            *[e.extensions for b in bound_options for e in b],
        )

        return stee.ExtendedExpression(
            referred_expr=[
                stee.ExpressionReference(
                    expression=stalg.Expression(
                        multi_or_list=stalg.Expression.MultiOrList(
                            value=[e.referred_expr[0].expression for e in bound_value],
                            options=[
                                stalg.Expression.MultiOrList.Record(
                                    fields=[e.referred_expr[0].expression for e in option]
                                )
                                for option in bound_options
                            ]
                        )
                    ),
                    output_names=['multi_or_list']  # TODO construct name from inputs
                )
            ],
            base_schema=base_schema,
            extension_uris=extension_uris,
            extensions=extensions,
        )

    return resolve

def cast(input: ExtendedExpressionOrUnbound, type: stp.Type):
    """Builds a resolver for ExtendedExpression containing a cast expression

    The cast targets `type` and always uses FAILURE_BEHAVIOR_RETURN_NULL.
    Extension URIs and declarations are taken over from the resolved input.
    """
    def resolve(
        base_schema: stp.NamedStruct, registry: ExtensionRegistry
    ) -> stee.ExtendedExpression:
        operand = resolve_expression(input, base_schema, registry)

        cast_message = stalg.Expression.Cast(
            input=operand.referred_expr[0].expression,
            type=type,
            failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL,
        )
        reference = stee.ExpressionReference(
            expression=stalg.Expression(cast=cast_message),
            output_names=['cast'],  # TODO construct name from inputs
        )

        return stee.ExtendedExpression(
            referred_expr=[reference],
            base_schema=base_schema,
            extension_uris=operand.extension_uris,
            extensions=operand.extensions,
        )

    return resolve
Loading