Skip to content

Commit 5dd5ef4

Browse files
authored
feat: add builders for switch, cast expressions (#78)
- Builders for switch, cast, singular_or_list, multi_or_list - Builders accept either `ExtendedExtension` or `UnboundExtendedExtension` - Adds example of using pyarrow expressions to build substrait plans
1 parent c453d69 commit 5dd5ef4

11 files changed

+477
-39
lines changed

examples/pyarrow_example.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Install pyarrow before running this example
2+
# /// script
3+
# dependencies = [
4+
# "pyarrow==20.0.0",
5+
# "substrait[extensions] @ file:///${PROJECT_ROOT}/"
6+
# ]
7+
# ///
8+
import pyarrow as pa
9+
import pyarrow.compute as pc
10+
import pyarrow.substrait as pa_substrait
11+
import substrait
12+
from substrait.builders.plan import project, read_named_table
13+
14+
arrow_schema = pa.schema([
15+
pa.field("x", pa.int32()),
16+
pa.field("y", pa.int32())
17+
])
18+
19+
substrait_schema = pa_substrait.serialize_schema(arrow_schema).to_pysubstrait().base_schema
20+
21+
substrait_expr = pa_substrait.serialize_expressions(
22+
exprs=[pc.field("x") + pc.field("y")],
23+
names=["total"],
24+
schema=arrow_schema
25+
)
26+
27+
pysubstrait_expr = substrait.proto.ExtendedExpression.FromString(bytes(substrait_expr))
28+
29+
table = read_named_table("example", substrait_schema)
30+
table = project(table, expressions=[pysubstrait_expr])(None)
31+
print(table)

src/substrait/builders/extended_expression.py

Lines changed: 186 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import Callable, Any, Union, Iterable
1111

1212
UnboundExtendedExpression = Callable[[stp.NamedStruct, ExtensionRegistry], stee.ExtendedExpression]
13+
ExtendedExpressionOrUnbound = Union[stee.ExtendedExpression, UnboundExtendedExpression]
1314

1415
def _alias_or_inferred(
1516
alias: Union[Iterable[str], str],
@@ -21,6 +22,13 @@ def _alias_or_inferred(
2122
else:
2223
return [f'{op}({",".join(args)})']
2324

25+
def resolve_expression(
26+
expression: ExtendedExpressionOrUnbound,
27+
base_schema: stp.NamedStruct,
28+
registry: ExtensionRegistry
29+
) -> stee.ExtendedExpression:
30+
return expression if isinstance(expression, stee.ExtendedExpression) else expression(base_schema, registry)
31+
2432
def literal(value: Any, type: stp.Type, alias: Union[Iterable[str], str] = None) -> UnboundExtendedExpression:
2533
"""Builds a resolver for ExtendedExpression containing a literal expression"""
2634
def resolve(base_schema: stp.NamedStruct, registry: ExtensionRegistry) -> stee.ExtendedExpression:
@@ -139,14 +147,14 @@ def resolve(
139147
return resolve
140148

141149
def scalar_function(
142-
uri: str, function: str, *expressions: UnboundExtendedExpression, alias: Union[Iterable[str], str] = None
150+
uri: str, function: str, expressions: Iterable[ExtendedExpressionOrUnbound], alias: Union[Iterable[str], str] = None
143151
):
144152
"""Builds a resolver for ExtendedExpression containing a ScalarFunction expression"""
145153
def resolve(
146154
base_schema: stp.NamedStruct, registry: ExtensionRegistry
147155
) -> stee.ExtendedExpression:
148-
bound_expressions: Iterable[stee.ExtendedExpression] = [
149-
e(base_schema, registry) for e in expressions
156+
bound_expressions = [
157+
resolve_expression(e, base_schema, registry) for e in expressions
150158
]
151159

152160
expression_schemas = [
@@ -210,14 +218,14 @@ def resolve(
210218
return resolve
211219

212220
def aggregate_function(
213-
uri: str, function: str, *expressions: UnboundExtendedExpression, alias: Union[Iterable[str], str] = None
221+
uri: str, function: str, expressions: Iterable[ExtendedExpressionOrUnbound], alias: Union[Iterable[str], str] = None
214222
):
215223
"""Builds a resolver for ExtendedExpression containing a AggregateFunction measure"""
216224
def resolve(
217225
base_schema: stp.NamedStruct, registry: ExtensionRegistry
218226
) -> stee.ExtendedExpression:
219227
bound_expressions: Iterable[stee.ExtendedExpression] = [
220-
e(base_schema, registry) for e in expressions
228+
resolve_expression(e, base_schema, registry) for e in expressions
221229
]
222230

223231
expression_schemas = [
@@ -281,19 +289,19 @@ def resolve(
281289
def window_function(
282290
uri: str,
283291
function: str,
284-
*expressions: UnboundExtendedExpression,
285-
partitions: Iterable[UnboundExtendedExpression] = [],
292+
expressions: Iterable[ExtendedExpressionOrUnbound],
293+
partitions: Iterable[ExtendedExpressionOrUnbound] = [],
286294
alias: Union[Iterable[str], str] = None
287295
):
288296
"""Builds a resolver for ExtendedExpression containing a WindowFunction expression"""
289297
def resolve(
290298
base_schema: stp.NamedStruct, registry: ExtensionRegistry
291299
) -> stee.ExtendedExpression:
292300
bound_expressions: Iterable[stee.ExtendedExpression] = [
293-
e(base_schema, registry) for e in expressions
301+
resolve_expression(e, base_schema, registry) for e in expressions
294302
]
295303

296-
bound_partitions = [e(base_schema, registry) for e in partitions]
304+
bound_partitions = [resolve_expression(e, base_schema, registry) for e in partitions]
297305

298306
expression_schemas = [
299307
infer_extended_expression_schema(b) for b in bound_expressions
@@ -363,17 +371,17 @@ def resolve(
363371
return resolve
364372

365373

366-
def if_then(ifs: Iterable[tuple[UnboundExtendedExpression, UnboundExtendedExpression]], _else: UnboundExtendedExpression, alias: Union[Iterable[str], str] = None):
374+
def if_then(ifs: Iterable[tuple[ExtendedExpressionOrUnbound, ExtendedExpressionOrUnbound]], _else: ExtendedExpressionOrUnbound, alias: Union[Iterable[str], str] = None):
367375
"""Builds a resolver for ExtendedExpression containing an IfThen expression"""
368376
def resolve(
369377
base_schema: stp.NamedStruct, registry: ExtensionRegistry
370378
) -> stee.ExtendedExpression:
371379
bound_ifs = [
372-
(if_clause[0](base_schema, registry), if_clause[1](base_schema, registry))
380+
(resolve_expression(if_clause[0], base_schema, registry), resolve_expression(if_clause[1], base_schema, registry))
373381
for if_clause in ifs
374382
]
375383

376-
bound_else = _else(base_schema, registry)
384+
bound_else = resolve_expression(_else, base_schema, registry)
377385

378386
extension_uris = merge_extension_uris(
379387
*[b[0].extension_uris for b in bound_ifs],
@@ -413,3 +421,169 @@ def resolve(
413421
)
414422

415423
return resolve
424+
425+
def switch(match: ExtendedExpressionOrUnbound,
426+
ifs: Iterable[tuple[ExtendedExpressionOrUnbound, ExtendedExpressionOrUnbound]],
427+
_else: ExtendedExpressionOrUnbound):
428+
"""Builds a resolver for ExtendedExpression containing a switch expression"""
429+
def resolve(
430+
base_schema: stp.NamedStruct, registry: ExtensionRegistry
431+
) -> stee.ExtendedExpression:
432+
bound_match = resolve_expression(match, base_schema, registry)
433+
bound_ifs = [
434+
(
435+
resolve_expression(a, base_schema, registry),
436+
resolve_expression(b, base_schema, registry)
437+
) for a, b in ifs]
438+
bound_else = resolve_expression(_else, base_schema, registry)
439+
440+
extension_uris = merge_extension_uris(
441+
bound_match.extension_uris,
442+
*[b.extension_uris for _, b in bound_ifs],
443+
bound_else.extension_uris
444+
)
445+
446+
extensions = merge_extension_declarations(
447+
bound_match.extensions,
448+
*[b.extensions for _, b in bound_ifs],
449+
bound_else.extensions
450+
)
451+
452+
return stee.ExtendedExpression(
453+
referred_expr=[
454+
stee.ExpressionReference(
455+
expression=stalg.Expression(
456+
switch_expression=stalg.Expression.SwitchExpression(
457+
match=bound_match.referred_expr[0].expression,
458+
ifs=[
459+
stalg.Expression.SwitchExpression.IfValue(**{
460+
'if': i.referred_expr[0].expression.literal,
461+
'then': t.referred_expr[0].expression
462+
})
463+
for i, t in bound_ifs
464+
],
465+
**{
466+
'else': bound_else.referred_expr[0].expression
467+
}
468+
)
469+
),
470+
output_names=['switch'] #TODO construct name from inputs
471+
)
472+
],
473+
base_schema=base_schema,
474+
extension_uris=extension_uris,
475+
extensions=extensions,
476+
)
477+
478+
return resolve
479+
480+
def singular_or_list(value: ExtendedExpressionOrUnbound, options: Iterable[ExtendedExpressionOrUnbound]):
481+
"""Builds a resolver for ExtendedExpression containing a SingularOrList expression"""
482+
def resolve(
483+
base_schema: stp.NamedStruct, registry: ExtensionRegistry
484+
) -> stee.ExtendedExpression:
485+
bound_value = resolve_expression(value, base_schema, registry)
486+
bound_options = [resolve_expression(o, base_schema, registry) for o in options]
487+
488+
extension_uris = merge_extension_uris(
489+
bound_value.extension_uris,
490+
*[b.extension_uris for b in bound_options]
491+
)
492+
493+
extensions = merge_extension_declarations(
494+
bound_value.extensions,
495+
*[b.extensions for b in bound_options]
496+
)
497+
498+
return stee.ExtendedExpression(
499+
referred_expr=[
500+
stee.ExpressionReference(
501+
expression=stalg.Expression(
502+
singular_or_list=stalg.Expression.SingularOrList(
503+
value=bound_value.referred_expr[0].expression,
504+
options=[
505+
o.referred_expr[0].expression
506+
for o in bound_options
507+
]
508+
)
509+
),
510+
output_names=['singular_or_list'] #TODO construct name from inputs
511+
)
512+
],
513+
base_schema=base_schema,
514+
extension_uris=extension_uris,
515+
extensions=extensions,
516+
)
517+
518+
return resolve
519+
520+
def multi_or_list(value: Iterable[ExtendedExpressionOrUnbound], options: Iterable[Iterable[ExtendedExpressionOrUnbound]]):
521+
"""Builds a resolver for ExtendedExpression containing a MultiOrList expression"""
522+
def resolve(
523+
base_schema: stp.NamedStruct, registry: ExtensionRegistry
524+
) -> stee.ExtendedExpression:
525+
bound_value = [resolve_expression(e, base_schema, registry) for e in value]
526+
bound_options = [
527+
[resolve_expression(e, base_schema, registry) for e in o] for o in options
528+
]
529+
530+
extension_uris = merge_extension_uris(
531+
*[b.extension_uris for b in bound_value],
532+
*[e.extension_uris for b in bound_options for e in b],
533+
)
534+
535+
extensions = merge_extension_uris(
536+
*[b.extensions for b in bound_value],
537+
*[e.extensions for b in bound_options for e in b],
538+
)
539+
540+
return stee.ExtendedExpression(
541+
referred_expr=[
542+
stee.ExpressionReference(
543+
expression=stalg.Expression(
544+
multi_or_list=stalg.Expression.MultiOrList(
545+
value=[e.referred_expr[0].expression for e in bound_value],
546+
options=[
547+
stalg.Expression.MultiOrList.Record(
548+
fields=[e.referred_expr[0].expression for e in option]
549+
)
550+
for option in bound_options
551+
]
552+
)
553+
),
554+
output_names=['multi_or_list'] #TODO construct name from inputs
555+
)
556+
],
557+
base_schema=base_schema,
558+
extension_uris=extension_uris,
559+
extensions=extensions,
560+
)
561+
562+
return resolve
563+
564+
def cast(input: ExtendedExpressionOrUnbound, type: stp.Type):
565+
"""Builds a resolver for ExtendedExpression containing a cast expression"""
566+
def resolve(
567+
base_schema: stp.NamedStruct, registry: ExtensionRegistry
568+
) -> stee.ExtendedExpression:
569+
bound_input = resolve_expression(input, base_schema, registry)
570+
571+
return stee.ExtendedExpression(
572+
referred_expr=[
573+
stee.ExpressionReference(
574+
expression=stalg.Expression(
575+
cast=stalg.Expression.Cast(
576+
input=bound_input.referred_expr[0].expression,
577+
type=type,
578+
failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL
579+
)
580+
),
581+
output_names=['cast'] #TODO construct name from inputs
582+
)
583+
],
584+
base_schema=base_schema,
585+
extension_uris=bound_input.extension_uris,
586+
extensions=bound_input.extensions,
587+
)
588+
589+
return resolve

0 commit comments

Comments
 (0)