diff --git a/examples/pyarrow_example.py b/examples/pyarrow_example.py new file mode 100644 index 0000000..d4c0d36 --- /dev/null +++ b/examples/pyarrow_example.py @@ -0,0 +1,31 @@ +# Install pyarrow before running this example +# /// script +# dependencies = [ +# "pyarrow==20.0.0", +# "substrait[extensions] @ file:///${PROJECT_ROOT}/" +# ] +# /// +import pyarrow as pa +import pyarrow.compute as pc +import pyarrow.substrait as pa_substrait +import substrait +from substrait.builders.plan import project, read_named_table + +arrow_schema = pa.schema([ + pa.field("x", pa.int32()), + pa.field("y", pa.int32()) +]) + +substrait_schema = pa_substrait.serialize_schema(arrow_schema).to_pysubstrait().base_schema + +substrait_expr = pa_substrait.serialize_expressions( + exprs=[pc.field("x") + pc.field("y")], + names=["total"], + schema=arrow_schema +) + +pysubstrait_expr = substrait.proto.ExtendedExpression.FromString(bytes(substrait_expr)) + +table = read_named_table("example", substrait_schema) +table = project(table, expressions=[pysubstrait_expr])(None) +print(table) \ No newline at end of file diff --git a/src/substrait/builders/extended_expression.py b/src/substrait/builders/extended_expression.py index fa4b444..39bcb66 100644 --- a/src/substrait/builders/extended_expression.py +++ b/src/substrait/builders/extended_expression.py @@ -10,6 +10,7 @@ from typing import Callable, Any, Union, Iterable UnboundExtendedExpression = Callable[[stp.NamedStruct, ExtensionRegistry], stee.ExtendedExpression] +ExtendedExpressionOrUnbound = Union[stee.ExtendedExpression, UnboundExtendedExpression] def _alias_or_inferred( alias: Union[Iterable[str], str], @@ -21,6 +22,13 @@ def _alias_or_inferred( else: return [f'{op}({",".join(args)})'] +def resolve_expression( + expression: ExtendedExpressionOrUnbound, + base_schema: stp.NamedStruct, + registry: ExtensionRegistry + ) -> stee.ExtendedExpression: + return expression if isinstance(expression, stee.ExtendedExpression) else 
expression(base_schema, registry) + def literal(value: Any, type: stp.Type, alias: Union[Iterable[str], str] = None) -> UnboundExtendedExpression: """Builds a resolver for ExtendedExpression containing a literal expression""" def resolve(base_schema: stp.NamedStruct, registry: ExtensionRegistry) -> stee.ExtendedExpression: @@ -139,14 +147,14 @@ def resolve( return resolve def scalar_function( - uri: str, function: str, *expressions: UnboundExtendedExpression, alias: Union[Iterable[str], str] = None + uri: str, function: str, expressions: Iterable[ExtendedExpressionOrUnbound], alias: Union[Iterable[str], str] = None ): """Builds a resolver for ExtendedExpression containing a ScalarFunction expression""" def resolve( base_schema: stp.NamedStruct, registry: ExtensionRegistry ) -> stee.ExtendedExpression: - bound_expressions: Iterable[stee.ExtendedExpression] = [ - e(base_schema, registry) for e in expressions + bound_expressions = [ + resolve_expression(e, base_schema, registry) for e in expressions ] expression_schemas = [ @@ -210,14 +218,14 @@ def resolve( return resolve def aggregate_function( - uri: str, function: str, *expressions: UnboundExtendedExpression, alias: Union[Iterable[str], str] = None + uri: str, function: str, expressions: Iterable[ExtendedExpressionOrUnbound], alias: Union[Iterable[str], str] = None ): """Builds a resolver for ExtendedExpression containing a AggregateFunction measure""" def resolve( base_schema: stp.NamedStruct, registry: ExtensionRegistry ) -> stee.ExtendedExpression: bound_expressions: Iterable[stee.ExtendedExpression] = [ - e(base_schema, registry) for e in expressions + resolve_expression(e, base_schema, registry) for e in expressions ] expression_schemas = [ @@ -281,8 +289,8 @@ def resolve( def window_function( uri: str, function: str, - *expressions: UnboundExtendedExpression, - partitions: Iterable[UnboundExtendedExpression] = [], + expressions: Iterable[ExtendedExpressionOrUnbound], + partitions: 
Iterable[ExtendedExpressionOrUnbound] = [], alias: Union[Iterable[str], str] = None ): """Builds a resolver for ExtendedExpression containing a WindowFunction expression""" @@ -290,10 +298,10 @@ def resolve( base_schema: stp.NamedStruct, registry: ExtensionRegistry ) -> stee.ExtendedExpression: bound_expressions: Iterable[stee.ExtendedExpression] = [ - e(base_schema, registry) for e in expressions + resolve_expression(e, base_schema, registry) for e in expressions ] - bound_partitions = [e(base_schema, registry) for e in partitions] + bound_partitions = [resolve_expression(e, base_schema, registry) for e in partitions] expression_schemas = [ infer_extended_expression_schema(b) for b in bound_expressions @@ -363,17 +371,17 @@ def resolve( return resolve -def if_then(ifs: Iterable[tuple[UnboundExtendedExpression, UnboundExtendedExpression]], _else: UnboundExtendedExpression, alias: Union[Iterable[str], str] = None): +def if_then(ifs: Iterable[tuple[ExtendedExpressionOrUnbound, ExtendedExpressionOrUnbound]], _else: ExtendedExpressionOrUnbound, alias: Union[Iterable[str], str] = None): """Builds a resolver for ExtendedExpression containing an IfThen expression""" def resolve( base_schema: stp.NamedStruct, registry: ExtensionRegistry ) -> stee.ExtendedExpression: bound_ifs = [ - (if_clause[0](base_schema, registry), if_clause[1](base_schema, registry)) + (resolve_expression(if_clause[0], base_schema, registry), resolve_expression(if_clause[1], base_schema, registry)) for if_clause in ifs ] - bound_else = _else(base_schema, registry) + bound_else = resolve_expression(_else, base_schema, registry) extension_uris = merge_extension_uris( *[b[0].extension_uris for b in bound_ifs], @@ -413,3 +421,169 @@ def resolve( ) return resolve + +def switch(match: ExtendedExpressionOrUnbound, + ifs: Iterable[tuple[ExtendedExpressionOrUnbound, ExtendedExpressionOrUnbound]], + _else: ExtendedExpressionOrUnbound): + """Builds a resolver for ExtendedExpression containing a switch 
expression""" + def resolve( + base_schema: stp.NamedStruct, registry: ExtensionRegistry + ) -> stee.ExtendedExpression: + bound_match = resolve_expression(match, base_schema, registry) + bound_ifs = [ + ( + resolve_expression(a, base_schema, registry), + resolve_expression(b, base_schema, registry) + ) for a, b in ifs] + bound_else = resolve_expression(_else, base_schema, registry) + + extension_uris = merge_extension_uris( + bound_match.extension_uris, + *[b.extension_uris for _, b in bound_ifs], + bound_else.extension_uris + ) + + extensions = merge_extension_declarations( + bound_match.extensions, + *[b.extensions for _, b in bound_ifs], + bound_else.extensions + ) + + return stee.ExtendedExpression( + referred_expr=[ + stee.ExpressionReference( + expression=stalg.Expression( + switch_expression=stalg.Expression.SwitchExpression( + match=bound_match.referred_expr[0].expression, + ifs=[ + stalg.Expression.SwitchExpression.IfValue(**{ + 'if': i.referred_expr[0].expression.literal, + 'then': t.referred_expr[0].expression + }) + for i, t in bound_ifs + ], + **{ + 'else': bound_else.referred_expr[0].expression + } + ) + ), + output_names=['switch'] #TODO construct name from inputs + ) + ], + base_schema=base_schema, + extension_uris=extension_uris, + extensions=extensions, + ) + + return resolve + +def singular_or_list(value: ExtendedExpressionOrUnbound, options: Iterable[ExtendedExpressionOrUnbound]): + """Builds a resolver for ExtendedExpression containing a SingularOrList expression""" + def resolve( + base_schema: stp.NamedStruct, registry: ExtensionRegistry + ) -> stee.ExtendedExpression: + bound_value = resolve_expression(value, base_schema, registry) + bound_options = [resolve_expression(o, base_schema, registry) for o in options] + + extension_uris = merge_extension_uris( + bound_value.extension_uris, + *[b.extension_uris for b in bound_options] + ) + + extensions = merge_extension_declarations( + bound_value.extensions, + *[b.extensions for b in 
bound_options] + ) + + return stee.ExtendedExpression( + referred_expr=[ + stee.ExpressionReference( + expression=stalg.Expression( + singular_or_list=stalg.Expression.SingularOrList( + value=bound_value.referred_expr[0].expression, + options=[ + o.referred_expr[0].expression + for o in bound_options + ] + ) + ), + output_names=['singular_or_list'] #TODO construct name from inputs + ) + ], + base_schema=base_schema, + extension_uris=extension_uris, + extensions=extensions, + ) + + return resolve + +def multi_or_list(value: Iterable[ExtendedExpressionOrUnbound], options: Iterable[Iterable[ExtendedExpressionOrUnbound]]): + """Builds a resolver for ExtendedExpression containing a MultiOrList expression (each entry of options becomes one Record of fields)""" + def resolve( + base_schema: stp.NamedStruct, registry: ExtensionRegistry + ) -> stee.ExtendedExpression: + bound_value = [resolve_expression(e, base_schema, registry) for e in value] + bound_options = [ + [resolve_expression(e, base_schema, registry) for e in o] for o in options + ] + + extension_uris = merge_extension_uris( + *[b.extension_uris for b in bound_value], + *[e.extension_uris for b in bound_options for e in b], + ) + + extensions = merge_extension_declarations( + *[b.extensions for b in bound_value], + *[e.extensions for b in bound_options for e in b], + ) + + return stee.ExtendedExpression( + referred_expr=[ + stee.ExpressionReference( + expression=stalg.Expression( + multi_or_list=stalg.Expression.MultiOrList( + value=[e.referred_expr[0].expression for e in bound_value], + options=[ + stalg.Expression.MultiOrList.Record( + fields=[e.referred_expr[0].expression for e in option] + ) + for option in bound_options + ] + ) + ), + output_names=['multi_or_list'] #TODO construct name from inputs + ) + ], + base_schema=base_schema, + extension_uris=extension_uris, + extensions=extensions, + ) + + return resolve + +def cast(input: ExtendedExpressionOrUnbound, type: stp.Type): + """Builds a resolver for ExtendedExpression containing a cast expression""" + def resolve( 
+ base_schema: stp.NamedStruct, registry: ExtensionRegistry + ) -> stee.ExtendedExpression: + bound_input = resolve_expression(input, base_schema, registry) + + return stee.ExtendedExpression( + referred_expr=[ + stee.ExpressionReference( + expression=stalg.Expression( + cast=stalg.Expression.Cast( + input=bound_input.referred_expr[0].expression, + type=type, + failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL + ) + ), + output_names=['cast'] #TODO construct name from inputs + ) + ], + base_schema=base_schema, + extension_uris=bound_input.extension_uris, + extensions=bound_input.extensions, + ) + + return resolve diff --git a/src/substrait/builders/plan.py b/src/substrait/builders/plan.py index 3075c4b..29fc193 100644 --- a/src/substrait/builders/plan.py +++ b/src/substrait/builders/plan.py @@ -12,7 +12,7 @@ import substrait.gen.proto.type_pb2 as stt import substrait.gen.proto.extended_expression_pb2 as stee from substrait.extension_registry import ExtensionRegistry -from substrait.builders.extended_expression import UnboundExtendedExpression +from substrait.builders.extended_expression import ExtendedExpressionOrUnbound, resolve_expression from substrait.type_inference import infer_plan_schema from substrait.utils import merge_extension_declarations, merge_extension_uris @@ -48,13 +48,12 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: def project( - plan: PlanOrUnbound, expressions: Iterable[UnboundExtendedExpression] + plan: PlanOrUnbound, expressions: Iterable[ExtendedExpressionOrUnbound] ) -> UnboundPlan: def resolve(registry: ExtensionRegistry) -> stp.Plan: _plan = plan if isinstance(plan, stp.Plan) else plan(registry) ns = infer_plan_schema(_plan) - bound_expressions: Iterable[stee.ExtendedExpression] = [ - e(ns, registry) for e in expressions] + bound_expressions: Iterable[stee.ExtendedExpression] = [resolve_expression(e, ns, registry) for e in expressions] start_index = len(_plan.relations[-1].root.names) @@ -82,12 +81,12 @@ def 
resolve(registry: ExtensionRegistry) -> stp.Plan: def filter( - plan: PlanOrUnbound, expression: UnboundExtendedExpression + plan: PlanOrUnbound, expression: ExtendedExpressionOrUnbound ) -> UnboundPlan: def resolve(registry: ExtensionRegistry) -> stp.Plan: bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry) ns = infer_plan_schema(bound_plan) - bound_expression: stee.ExtendedExpression = expression(ns, registry) + bound_expression: stee.ExtendedExpression = resolve_expression(expression, ns, registry) rel = stalg.Rel( filter=stalg.FilterRel( @@ -108,14 +107,14 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: def sort( plan: PlanOrUnbound, - expressions: Iterable[Union[UnboundExtendedExpression, tuple[UnboundExtendedExpression, stalg.SortField.SortDirection.ValueType]]] + expressions: Iterable[Union[ExtendedExpressionOrUnbound, tuple[ExtendedExpressionOrUnbound, stalg.SortField.SortDirection.ValueType]]] ) -> UnboundPlan: def resolve(registry: ExtensionRegistry) -> stp.Plan: bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry) ns = infer_plan_schema(bound_plan) bound_expressions = [(e, stalg.SortField.SORT_DIRECTION_ASC_NULLS_LAST) if not isinstance(e, tuple) else e for e in expressions] - bound_expressions = [(e[0](ns, registry), e[1]) for e in bound_expressions] + bound_expressions = [(resolve_expression(e[0], ns, registry), e[1]) for e in bound_expressions] rel = stalg.Rel( sort=stalg.SortRel( @@ -159,14 +158,14 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: return resolve def fetch(plan: PlanOrUnbound, - offset: UnboundExtendedExpression, - count: UnboundExtendedExpression) -> UnboundPlan: + offset: ExtendedExpressionOrUnbound, + count: ExtendedExpressionOrUnbound) -> UnboundPlan: def resolve(registry: ExtensionRegistry) -> stp.Plan: bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry) ns = infer_plan_schema(bound_plan) - bound_offset = offset(ns, registry) - bound_count = count(ns, registry) + 
bound_offset = resolve_expression(offset, ns, registry) + bound_count = resolve_expression(count, ns, registry) rel = stalg.Rel( fetch=stalg.FetchRel( @@ -191,7 +190,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: def join( left: PlanOrUnbound, right: PlanOrUnbound, - expression: UnboundExtendedExpression, + expression: ExtendedExpressionOrUnbound, type: stalg.JoinRel.JoinType, ) -> UnboundPlan: def resolve(registry: ExtensionRegistry) -> stp.Plan: @@ -207,7 +206,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: ), names=list(left_ns.names) + list(right_ns.names), ) - bound_expression: stee.ExtendedExpression = expression(ns, registry) + bound_expression: stee.ExtendedExpression = resolve_expression(expression, ns, registry) rel = stalg.Rel( join=stalg.JoinRel( @@ -260,15 +259,15 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: # TODO grouping sets def aggregate( input: PlanOrUnbound, - grouping_expressions: Iterable[UnboundExtendedExpression], - measures: Iterable[UnboundExtendedExpression], + grouping_expressions: Iterable[ExtendedExpressionOrUnbound], + measures: Iterable[ExtendedExpressionOrUnbound], ) -> UnboundPlan: def resolve(registry: ExtensionRegistry) -> stp.Plan: bound_input = input if isinstance(input, stp.Plan) else input(registry) ns = infer_plan_schema(bound_input) - bound_grouping_expressions = [e(ns, registry) for e in grouping_expressions] - bound_measures = [e(ns, registry) for e in measures] + bound_grouping_expressions = [resolve_expression(e, ns, registry) for e in grouping_expressions] + bound_measures = [resolve_expression(e, ns, registry) for e in measures] rel = stalg.Rel( aggregate=stalg.AggregateRel( diff --git a/tests/builders/extended_expression/test_aggregate_function.py b/tests/builders/extended_expression/test_aggregate_function.py index 50a2de9..304e1c7 100644 --- a/tests/builders/extended_expression/test_aggregate_function.py +++ b/tests/builders/extended_expression/test_aggregate_function.py @@ -39,8 
+39,8 @@ registry.register_extension_dict(yaml.safe_load(content), uri="test_uri") def test_aggregate_count(): - e = aggregate_function('test_uri', 'count', - literal(10, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))), + e = aggregate_function('test_uri', 'count', + expressions=[literal(10, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED)))], alias='count', )(named_struct, registry) diff --git a/tests/builders/extended_expression/test_cast.py b/tests/builders/extended_expression/test_cast.py new file mode 100644 index 0000000..2fcfca5 --- /dev/null +++ b/tests/builders/extended_expression/test_cast.py @@ -0,0 +1,49 @@ +import yaml + +import substrait.gen.proto.algebra_pb2 as stalg +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.extended_expression_pb2 as stee +from substrait.builders.extended_expression import cast, literal +from substrait.builders.type import i8, i16 +from substrait.extension_registry import ExtensionRegistry + +struct = stt.Type.Struct( + types=[ + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), + ] +) + +named_struct = stt.NamedStruct( + names=["order_id", "description", "order_total"], struct=struct +) + +registry = ExtensionRegistry(load_default_extensions=False) + +def test_cast(): + e = cast( + input=literal(3, i8()), + type=i16() + )(named_struct, registry) + + expected = stee.ExtendedExpression( + referred_expr=[ + stee.ExpressionReference( + expression=stalg.Expression( + cast=stalg.Expression.Cast( + type=stt.Type( + i16=stt.Type.I16(nullability=stt.Type.NULLABILITY_NULLABLE) + ), + input=stalg.Expression(literal=stalg.Expression.Literal(i8=3, nullable=True)), + failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL + ) + ), + output_names=["cast"], + ) + ], + 
base_schema=named_struct, + ) + + assert e == expected + diff --git a/tests/builders/extended_expression/test_multi_or_list.py b/tests/builders/extended_expression/test_multi_or_list.py new file mode 100644 index 0000000..d2efa2d --- /dev/null +++ b/tests/builders/extended_expression/test_multi_or_list.py @@ -0,0 +1,65 @@ +import yaml + +import substrait.gen.proto.algebra_pb2 as stalg +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.extended_expression_pb2 as stee +from substrait.builders.extended_expression import multi_or_list, literal +from substrait.builders.type import i8 +from substrait.extension_registry import ExtensionRegistry + +struct = stt.Type.Struct( + types=[ + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), + ] +) + +named_struct = stt.NamedStruct( + names=["order_id", "description", "order_total"], struct=struct +) + +registry = ExtensionRegistry(load_default_extensions=False) + +def test_multi_or_list(): + e = multi_or_list( + value=[literal(1, i8()), literal(2, i8())], + options=[ + [literal(1, i8()), literal(2, i8())], + [literal(3, i8()), literal(4, i8())] + ] + )(named_struct, registry) + + expected = stee.ExtendedExpression( + referred_expr=[ + stee.ExpressionReference( + expression=stalg.Expression( + multi_or_list=stalg.Expression.MultiOrList( + value=[ + stalg.Expression(literal=stalg.Expression.Literal(i8=1, nullable=True)), + stalg.Expression(literal=stalg.Expression.Literal(i8=2, nullable=True)) + ], + options=[ + stalg.Expression.MultiOrList.Record( + fields=[ + stalg.Expression(literal=stalg.Expression.Literal(i8=1, nullable=True)), + stalg.Expression(literal=stalg.Expression.Literal(i8=2, nullable=True)) + ] + ), + stalg.Expression.MultiOrList.Record( + fields=[ + stalg.Expression(literal=stalg.Expression.Literal(i8=3, 
nullable=True)), + stalg.Expression(literal=stalg.Expression.Literal(i8=4, nullable=True)) + ] + ) + ] + ) + ), + output_names=["multi_or_list"], + ) + ], + base_schema=named_struct, + ) + + assert e == expected + diff --git a/tests/builders/extended_expression/test_scalar_function.py b/tests/builders/extended_expression/test_scalar_function.py index 26aba8e..6c7024b 100644 --- a/tests/builders/extended_expression/test_scalar_function.py +++ b/tests/builders/extended_expression/test_scalar_function.py @@ -44,9 +44,10 @@ def test_sclar_add(): e = scalar_function('test_uri', 'test_func', - literal(10, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))), - literal(20, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))) - )(named_struct, registry) + expressions=[ + literal(10, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))), + literal(20, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))) + ])(named_struct, registry) expected = stee.ExtendedExpression( extension_uris=[ @@ -87,9 +88,14 @@ def test_sclar_add(): def test_nested_scalar_calls(): e = scalar_function('test_uri', 'is_positive', - scalar_function('test_uri', 'test_func', + expressions=[ + scalar_function('test_uri', 'test_func', + expressions=[ literal(10, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))), - literal(20, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED)))), + literal(20, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))) + ] + ) + ], alias='positive' )(named_struct, registry) diff --git a/tests/builders/extended_expression/test_singular_or_list.py b/tests/builders/extended_expression/test_singular_or_list.py new file mode 100644 index 0000000..927a109 --- /dev/null +++ b/tests/builders/extended_expression/test_singular_or_list.py @@ -0,0 +1,52 @@ +import yaml + +import substrait.gen.proto.algebra_pb2 as stalg +import 
substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.extended_expression_pb2 as stee +from substrait.builders.extended_expression import singular_or_list, literal +from substrait.builders.type import i8 +from substrait.extension_registry import ExtensionRegistry + +struct = stt.Type.Struct( + types=[ + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), + ] +) + +named_struct = stt.NamedStruct( + names=["order_id", "description", "order_total"], struct=struct +) + +registry = ExtensionRegistry(load_default_extensions=False) + +def test_singular_or_list(): + e = singular_or_list( + value=literal(3, i8()), + options=[ + literal(1, i8()), + literal(2, i8()) + ] + )(named_struct, registry) + + expected = stee.ExtendedExpression( + referred_expr=[ + stee.ExpressionReference( + expression=stalg.Expression( + singular_or_list=stalg.Expression.SingularOrList( + value=stalg.Expression(literal=stalg.Expression.Literal(i8=3, nullable=True)), + options=[ + stalg.Expression(literal=stalg.Expression.Literal(i8=1, nullable=True)), + stalg.Expression(literal=stalg.Expression.Literal(i8=2, nullable=True)) + ] + ) + ), + output_names=["singular_or_list"], + ) + ], + base_schema=named_struct, + ) + + assert e == expected + diff --git a/tests/builders/extended_expression/test_switch.py b/tests/builders/extended_expression/test_switch.py new file mode 100644 index 0000000..35095f8 --- /dev/null +++ b/tests/builders/extended_expression/test_switch.py @@ -0,0 +1,62 @@ +import yaml + +import substrait.gen.proto.algebra_pb2 as stalg +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.extended_expression_pb2 as stee +from substrait.builders.extended_expression import switch, literal +from substrait.builders.type import i8 +from substrait.extension_registry import ExtensionRegistry + +struct 
= stt.Type.Struct( + types=[ + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), + ] +) + +named_struct = stt.NamedStruct( + names=["order_id", "description", "order_total"], struct=struct +) + +registry = ExtensionRegistry(load_default_extensions=False) + +def test_switch(): + e = switch( + match=literal(3, i8()), + ifs=[ + (literal(1, i8()), literal(1, i8())), + (literal(2, i8()), literal(4, i8())) + ], + _else=literal(9, i8()) + )(named_struct, registry) + + expected = stee.ExtendedExpression( + referred_expr=[ + stee.ExpressionReference( + expression=stalg.Expression( + switch_expression=stalg.Expression.SwitchExpression( + match=stalg.Expression(literal=stalg.Expression.Literal(i8=3, nullable=True)), + ifs=[ + stalg.Expression.SwitchExpression.IfValue(**{ + 'if': stalg.Expression.Literal(i8=1, nullable=True), + 'then': stalg.Expression(literal=stalg.Expression.Literal(i8=1, nullable=True)) + }), + stalg.Expression.SwitchExpression.IfValue(**{ + 'if': stalg.Expression.Literal(i8=2, nullable=True), + 'then': stalg.Expression(literal=stalg.Expression.Literal(i8=4, nullable=True)) + }), + ], + **{ + 'else': stalg.Expression(literal=stalg.Expression.Literal(i8=9, nullable=True)) + } + ) + ), + output_names=["switch"], + ) + ], + base_schema=named_struct, + ) + + assert e == expected + diff --git a/tests/builders/extended_expression/test_window_function.py b/tests/builders/extended_expression/test_window_function.py index 9e3bd00..f9dad38 100644 --- a/tests/builders/extended_expression/test_window_function.py +++ b/tests/builders/extended_expression/test_window_function.py @@ -45,7 +45,7 @@ registry.register_extension_dict(yaml.safe_load(content), uri="test_uri") def test_row_number(): - e = window_function('test_uri', 'row_number', alias='rn')(named_struct, registry) + e = 
window_function('test_uri', 'row_number', expressions=[], alias='rn')(named_struct, registry) expected = stee.ExtendedExpression( extension_uris=[ diff --git a/tests/builders/plan/test_aggregate.py b/tests/builders/plan/test_aggregate.py index 175519e..bb16af8 100644 --- a/tests/builders/plan/test_aggregate.py +++ b/tests/builders/plan/test_aggregate.py @@ -38,7 +38,7 @@ def test_aggregate(): table = read_named_table('table', named_struct) group_expr = column('id') - measure_expr = aggregate_function('test_uri', 'count', column('is_applicable'), alias=['count']) + measure_expr = aggregate_function('test_uri', 'count', expressions=[column('is_applicable')], alias=['count']) actual = aggregate(table, grouping_expressions=[group_expr],