diff --git a/src/substrait/builders/extended_expression.py b/src/substrait/builders/extended_expression.py new file mode 100644 index 0000000..f757980 --- /dev/null +++ b/src/substrait/builders/extended_expression.py @@ -0,0 +1,384 @@ +import itertools +import substrait.gen.proto.algebra_pb2 as stalg +import substrait.gen.proto.type_pb2 as stp +import substrait.gen.proto.extended_expression_pb2 as stee +import substrait.gen.proto.extensions.extensions_pb2 as ste +from substrait.extension_registry import ExtensionRegistry +from substrait.utils import type_num_names, merge_extension_uris, merge_extension_declarations +from substrait.type_inference import infer_extended_expression_schema +from typing import Callable, Any, Union, Iterable + +UnboundExtendedExpression = Callable[[stp.NamedStruct, ExtensionRegistry], stee.ExtendedExpression] + +def _alias_or_inferred( + alias: Union[Iterable[str], str], + op: str, + args: Iterable[str], + ): + if alias: + return [alias] if isinstance(alias, str) else alias + else: + return [f'{op}({",".join(args)})'] + +def literal(value: Any, type: stp.Type, alias: Union[Iterable[str], str] = None) -> UnboundExtendedExpression: + """Builds a resolver for ExtendedExpression containing a literal expression""" + def resolve(base_schema: stp.NamedStruct, registry: ExtensionRegistry) -> stee.ExtendedExpression: + kind = type.WhichOneof('kind') + + if kind == "bool": + literal = stalg.Expression.Literal(boolean=value, nullable=type.bool.nullability == stp.Type.NULLABILITY_NULLABLE) + elif kind == "i8": + literal = stalg.Expression.Literal(i8=value, nullable=type.i8.nullability == stp.Type.NULLABILITY_NULLABLE) + elif kind == "i16": + literal = stalg.Expression.Literal(i16=value, nullable=type.i16.nullability == stp.Type.NULLABILITY_NULLABLE) + elif kind == "i32": + literal = stalg.Expression.Literal(i32=value, nullable=type.i32.nullability == stp.Type.NULLABILITY_NULLABLE) + elif kind == "i64": + literal = stalg.Expression.Literal(i64=value, nullable=type.i64.nullability == stp.Type.NULLABILITY_NULLABLE) + elif kind == "fp32": + literal = stalg.Expression.Literal(fp32=value, nullable=type.fp32.nullability == stp.Type.NULLABILITY_NULLABLE) + elif kind == "fp64": + literal = stalg.Expression.Literal(fp64=value, nullable=type.fp64.nullability == stp.Type.NULLABILITY_NULLABLE) + elif kind == "string": + literal = stalg.Expression.Literal(string=value, nullable=type.string.nullability == stp.Type.NULLABILITY_NULLABLE) + else: + raise Exception(f"Unknown literal type - {type}") + + return stee.ExtendedExpression( + referred_expr=[ + stee.ExpressionReference( + expression=stalg.Expression( + literal=literal + ), + output_names=_alias_or_inferred(alias, 'Literal', [str(value)]) + ) + ], + base_schema=base_schema, + ) + + return resolve + +def column(field: Union[str, int], alias: Union[Iterable[str], str] = None): + """Builds a resolver for ExtendedExpression containing a FieldReference expression + + Accepts either an index or a field name of a desired field. + """ + alias = [alias] if alias and isinstance(alias, str) else alias + + def resolve( + base_schema: stp.NamedStruct, registry: ExtensionRegistry + ) -> stee.ExtendedExpression: + lengths = [type_num_names(t) for t in base_schema.struct.types] + flat_indices = [0] + list(itertools.accumulate(lengths))[:-1] + + if isinstance(field, str): + column_index = list(base_schema.names).index(field) + field_index = flat_indices.index(column_index) + else: + field_index = field + + names_start = flat_indices[field_index] + names_end = ( + flat_indices[field_index + 1] + if len(flat_indices) > field_index + 1 + else None + ) + + return stee.ExtendedExpression( + referred_expr=[ + stee.ExpressionReference( + expression=stalg.Expression( + selection=stalg.Expression.FieldReference( + root_reference=stalg.Expression.FieldReference.RootReference(), + direct_reference=stalg.Expression.ReferenceSegment( + struct_field=stalg.Expression.ReferenceSegment.StructField( + field=field_index + ) + ), + ) + ), + output_names=list(base_schema.names)[names_start:names_end] + if not alias + else alias, + ) + ], + base_schema=base_schema, + ) + + return resolve + +def scalar_function( + uri: str, function: str, *expressions: UnboundExtendedExpression, alias: Union[Iterable[str], str] = None +): + """Builds a resolver for ExtendedExpression containing a ScalarFunction expression""" + def resolve( + base_schema: stp.NamedStruct, registry: ExtensionRegistry + ) -> stee.ExtendedExpression: + bound_expressions: Iterable[stee.ExtendedExpression] = [ + e(base_schema, registry) for e in expressions + ] + + expression_schemas = [ + infer_extended_expression_schema(b) for b in bound_expressions + ] + + signature = [typ for es in expression_schemas for typ in es.types] + + func = registry.lookup_function(uri, function, signature) + + if not func: + raise Exception(f"Unknown function {function} for {signature}") + + func_extension_uris = [ + ste.SimpleExtensionURI( + extension_uri_anchor=registry.lookup_uri(uri), uri=uri + ) + ] + + func_extensions = [ + ste.SimpleExtensionDeclaration( + extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( + extension_uri_reference=registry.lookup_uri(uri), + function_anchor=func[0].anchor, + name=function, + ) + ) + ] + + extension_uris = merge_extension_uris( + func_extension_uris, *[b.extension_uris for b in bound_expressions] + ) + + extensions = merge_extension_declarations( + func_extensions, *[b.extensions for b in bound_expressions] + ) + + return stee.ExtendedExpression( + referred_expr=[ + stee.ExpressionReference( + expression=stalg.Expression( + scalar_function=stalg.Expression.ScalarFunction( + function_reference=func[0].anchor, + arguments=[ + stalg.FunctionArgument( + value=e.referred_expr[0].expression + ) + for e in bound_expressions + ], + output_type=func[1], + ) + ), + output_names=_alias_or_inferred(alias, function, [e.referred_expr[0].output_names[0] for e in bound_expressions]), + ) + ], + base_schema=base_schema, + extension_uris=extension_uris, + extensions=extensions, + ) + + return resolve + +def aggregate_function( + uri: str, function: str, *expressions: UnboundExtendedExpression, alias: Union[Iterable[str], str] = None +): + """Builds a resolver for ExtendedExpression containing a AggregateFunction measure""" + def resolve( + base_schema: stp.NamedStruct, registry: ExtensionRegistry + ) -> stee.ExtendedExpression: + bound_expressions: Iterable[stee.ExtendedExpression] = [ + e(base_schema, registry) for e in expressions + ] + + expression_schemas = [ + infer_extended_expression_schema(b) for b in bound_expressions + ] + + signature = [typ for es in expression_schemas for typ in es.types] + + func = registry.lookup_function(uri, function, signature) + + if not func: + raise Exception(f"Unknown function {function} for {signature}") + + func_extension_uris = [ + ste.SimpleExtensionURI( + extension_uri_anchor=registry.lookup_uri(uri), uri=uri + ) + ] + + func_extensions = [ + ste.SimpleExtensionDeclaration( + extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( + extension_uri_reference=registry.lookup_uri(uri), + function_anchor=func[0].anchor, + name=function, + ) + ) + ] + + extension_uris = merge_extension_uris( + func_extension_uris, *[b.extension_uris for b in bound_expressions] + ) + + extensions = merge_extension_declarations( + func_extensions, *[b.extensions for b in bound_expressions] + ) + + return stee.ExtendedExpression( + referred_expr=[ + stee.ExpressionReference( + measure=stalg.AggregateFunction( + function_reference=func[0].anchor, + arguments=[ + stalg.FunctionArgument(value=e.referred_expr[0].expression) + for e in bound_expressions + ], + output_type=func[1], + ), + output_names=_alias_or_inferred(alias, 'IfThen', [e.referred_expr[0].output_names[0] for e in bound_expressions]), + ) + ], + base_schema=base_schema, + extension_uris=extension_uris, + extensions=extensions, + ) + + return resolve + + +# TODO bounds, sorts +def window_function( + uri: str, + function: str, + *expressions: UnboundExtendedExpression, + partitions: Iterable[UnboundExtendedExpression] = [], + alias: Union[Iterable[str], str] = None +): + """Builds a resolver for ExtendedExpression containing a WindowFunction expression""" + def resolve( + base_schema: stp.NamedStruct, registry: ExtensionRegistry + ) -> stee.ExtendedExpression: + bound_expressions: Iterable[stee.ExtendedExpression] = [ + e(base_schema, registry) for e in expressions + ] + + bound_partitions = [e(base_schema, registry) for e in partitions] + + expression_schemas = [ + infer_extended_expression_schema(b) for b in bound_expressions + ] + + signature = [typ for es in expression_schemas for typ in es.types] + + func = registry.lookup_function(uri, function, signature) + + if not func: + raise Exception(f"Unknown function {function} for {signature}") + + func_extension_uris = [ + ste.SimpleExtensionURI( + extension_uri_anchor=registry.lookup_uri(uri), uri=uri + ) + ] + + func_extensions = [ + ste.SimpleExtensionDeclaration( + extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( + extension_uri_reference=registry.lookup_uri(uri), + function_anchor=func[0].anchor, + name=function, + ) + ) + ] + + extension_uris = merge_extension_uris( + func_extension_uris, + *[b.extension_uris for b in bound_expressions], + *[b.extension_uris for b in bound_partitions], + ) + + extensions = merge_extension_declarations( + func_extensions, + *[b.extensions for b in bound_expressions], + *[b.extensions for b in bound_partitions], + ) + + return stee.ExtendedExpression( + referred_expr=[ + stee.ExpressionReference( + expression=stalg.Expression( + window_function=stalg.Expression.WindowFunction( + function_reference=func[0].anchor, + arguments=[ + stalg.FunctionArgument( + value=e.referred_expr[0].expression + ) + for e in bound_expressions + ], + output_type=func[1], + partitions=[ + e.referred_expr[0].expression for e in bound_partitions + ], + ) + ), + output_names=_alias_or_inferred(alias, function, [e.referred_expr[0].output_names[0] for e in bound_expressions]), + ) + ], + base_schema=base_schema, + extension_uris=extension_uris, + extensions=extensions, + ) + + return resolve + + +def if_then(ifs: Iterable[tuple[UnboundExtendedExpression, UnboundExtendedExpression]], _else: UnboundExtendedExpression, alias: Union[Iterable[str], str] = None): + """Builds a resolver for ExtendedExpression containing an IfThen expression""" + def resolve( + base_schema: stp.NamedStruct, registry: ExtensionRegistry + ) -> stee.ExtendedExpression: + bound_ifs = [ + (if_clause[0](base_schema, registry), if_clause[1](base_schema, registry)) + for if_clause in ifs + ] + + bound_else = _else(base_schema, registry) + + extension_uris = merge_extension_uris( + *[b[0].extension_uris for b in bound_ifs], + *[b[1].extension_uris for b in bound_ifs], + bound_else.extension_uris + ) + + extensions = merge_extension_declarations( + *[b[0].extensions for b in bound_ifs], + *[b[1].extensions for b in bound_ifs], + bound_else.extensions + ) + + return stee.ExtendedExpression( + referred_expr=[ + stee.ExpressionReference( + expression=stalg.Expression( + if_then=stalg.Expression.IfThen(**{ + 'ifs': [ + stalg.Expression.IfThen.IfClause(**{ + 'if': if_clause[0].referred_expr[0].expression, + 'then': if_clause[1].referred_expr[0].expression, + }) + for if_clause in bound_ifs + ], + 'else': bound_else.referred_expr[0].expression + }) + ), + output_names=_alias_or_inferred(alias, 'IfThen', [a for e in bound_ifs for a in [e[0].referred_expr[0].output_names[0], e[1].referred_expr[0].output_names[0]]] + + [bound_else.referred_expr[0].output_names[0]] + ), + ) + ], + base_schema=base_schema, + extension_uris=extension_uris, + extensions=extensions, + ) + + return resolve diff --git a/src/substrait/extended_expression.py b/src/substrait/extended_expression.py deleted file mode 100644 index f71762e..0000000 --- a/src/substrait/extended_expression.py +++ /dev/null @@ -1,156 +0,0 @@ -import itertools -import substrait.gen.proto.algebra_pb2 as stalg -import substrait.gen.proto.type_pb2 as stp -import substrait.gen.proto.extended_expression_pb2 as stee -import substrait.gen.proto.extensions.extensions_pb2 as ste -from substrait.function_registry import FunctionRegistry -from substrait.utils import type_num_names, merge_extension_uris, merge_extension_declarations -from substrait.type_inference import infer_extended_expression_schema -from typing import Callable, Any, Union - -UnboundExpression = Callable[[stp.NamedStruct, FunctionRegistry], stee.ExtendedExpression] - -def literal(value: Any, type: stp.Type, alias: str = None) -> UnboundExpression: - """Builds a resolver for ExtendedExpression containing a literal expression""" - def resolve(base_schema: stp.NamedStruct, registry: FunctionRegistry) -> stee.ExtendedExpression: - kind = type.WhichOneof('kind') - - if kind == "bool": - literal = stalg.Expression.Literal(boolean=value, nullable=type.bool.nullability == stp.Type.NULLABILITY_NULLABLE) - elif kind == "i8": - literal = stalg.Expression.Literal(i8=value, nullable=type.i8.nullability == stp.Type.NULLABILITY_NULLABLE) - elif kind == "i16": - literal = stalg.Expression.Literal(i16=value, nullable=type.i16.nullability == stp.Type.NULLABILITY_NULLABLE) - elif kind == "i32": - literal = stalg.Expression.Literal(i32=value, nullable=type.i32.nullability == stp.Type.NULLABILITY_NULLABLE) - elif kind == "i64": - literal = stalg.Expression.Literal(i64=value, nullable=type.i64.nullability == stp.Type.NULLABILITY_NULLABLE) - elif kind == "fp32": - literal = stalg.Expression.Literal(fp32=value, nullable=type.fp32.nullability == stp.Type.NULLABILITY_NULLABLE) - elif kind == "fp64": - literal = stalg.Expression.Literal(fp64=value, nullable=type.fp64.nullability == stp.Type.NULLABILITY_NULLABLE) - elif kind == "string": - literal = stalg.Expression.Literal(string=value, nullable=type.string.nullability == stp.Type.NULLABILITY_NULLABLE) - else: - raise Exception(f"Unknown literal type - {type}") - - return stee.ExtendedExpression( - referred_expr=[ - stee.ExpressionReference( - expression=stalg.Expression( - literal=literal - ), - output_names=[alias if alias else f'literal_{kind}'], - ) - ], - base_schema=base_schema, - ) - - return resolve - -def column(field: Union[str, int]): - """Builds a resolver for ExtendedExpression containing a FieldReference expression - - Accepts either an index or a field name of a desired field. - """ - def resolve(base_schema: stp.NamedStruct, registry: FunctionRegistry) -> stee.ExtendedExpression: - if isinstance(field, str): - column_index = list(base_schema.names).index(field) - lengths = [type_num_names(t) for t in base_schema.struct.types] - flat_indices = [0] + list(itertools.accumulate(lengths))[:-1] - field_index = flat_indices.index(column_index) - else: - field_index = field - - names_start = flat_indices[field_index] - names_end = ( - flat_indices[field_index + 1] - if len(flat_indices) > field_index + 1 - else None - ) - - return stee.ExtendedExpression( - referred_expr=[ - stee.ExpressionReference( - expression=stalg.Expression( - selection=stalg.Expression.FieldReference( - root_reference=stalg.Expression.FieldReference.RootReference(), - direct_reference=stalg.Expression.ReferenceSegment( - struct_field=stalg.Expression.ReferenceSegment.StructField( - field=field_index - ) - ), - ) - ), - output_names=list(base_schema.names)[names_start:names_end], - ) - ], - base_schema=base_schema, - ) - - return resolve - -def scalar_function(uri: str, function: str, *expressions: UnboundExpression, alias: str = None): - """Builds a resolver for ExtendedExpression containing a ScalarFunction expression""" - def resolve(base_schema: stp.NamedStruct, registry: FunctionRegistry) -> stee.ExtendedExpression: - bound_expressions: list[stee.ExtendedExpression] = [e(base_schema, registry) for e in expressions] - - expression_schemas = [infer_extended_expression_schema(b) for b in bound_expressions] - - signature = [typ for es in expression_schemas for typ in es.types] - - func = registry.lookup_function(uri, function, signature) - - if not func: - raise Exception('') - - func_extension_uris = [ - ste.SimpleExtensionURI( - extension_uri_anchor=registry.lookup_uri(uri), - uri=uri - ) - ] - - func_extensions = [ - ste.SimpleExtensionDeclaration( - extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( - extension_uri_reference=registry.lookup_uri(uri), - function_anchor=func[0].anchor, - name=function - ) - ) - ] - - extension_uris = merge_extension_uris( - func_extension_uris, - *[b.extension_uris for b in bound_expressions] - ) - - extensions = merge_extension_declarations( - func_extensions, - *[b.extensions for b in bound_expressions] - ) - - return stee.ExtendedExpression( - referred_expr=[ - stee.ExpressionReference( - expression=stalg.Expression( - scalar_function=stalg.Expression.ScalarFunction( - function_reference=func[0].anchor, - arguments=[ - stalg.FunctionArgument( - value=e.referred_expr[0].expression - ) for e in bound_expressions - ], - output_type=func[1] - ) - ), - output_names=[alias if alias else 'scalar_function'], - ) - ], - base_schema=base_schema, - extension_uris=extension_uris, - extensions=extensions - ) - - return resolve diff --git a/src/substrait/function_registry.py b/src/substrait/extension_registry.py similarity index 99% rename from src/substrait/function_registry.py rename to src/substrait/extension_registry.py index 7130fce..2774ae3 100644 --- a/src/substrait/function_registry.py +++ b/src/substrait/extension_registry.py @@ -224,7 +224,7 @@ def satisfies_signature(self, signature: tuple) -> Optional[str]: return output_type -class FunctionRegistry: +class ExtensionRegistry: def __init__(self, load_default_extensions=True) -> None: self._uri_mapping: dict = defaultdict(dict) self._uri_id_generator = itertools.count(1) diff --git a/tests/builders/extended_expression/test_aggregate_function.py b/tests/builders/extended_expression/test_aggregate_function.py new file mode 100644 index 0000000..50a2de9 --- /dev/null +++ b/tests/builders/extended_expression/test_aggregate_function.py @@ -0,0 +1,78 @@ +import yaml + +import substrait.gen.proto.algebra_pb2 as stalg +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.extended_expression_pb2 as stee +import substrait.gen.proto.extensions.extensions_pb2 as ste +from substrait.builders.extended_expression import aggregate_function, literal +from substrait.extension_registry import ExtensionRegistry + +struct = stt.Type.Struct( + types=[ + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), + ] +) + +named_struct = stt.NamedStruct( + names=["order_id", "description", "order_total"], struct=struct +) + +content = """%YAML 1.2 +--- +aggregate_functions: + - name: "count" + description: Count a set of values + impls: + - args: + - name: x + value: any + nullability: DECLARED_OUTPUT + decomposable: MANY + intermediate: i64 + return: i64 +""" + + +registry = ExtensionRegistry(load_default_extensions=False) +registry.register_extension_dict(yaml.safe_load(content), uri="test_uri") + +def test_aggregate_count(): + e = aggregate_function('test_uri', 'count', + literal(10, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))), + alias='count', + )(named_struct, registry) + + expected = stee.ExtendedExpression( + extension_uris=[ + ste.SimpleExtensionURI( + extension_uri_anchor=1, + uri='test_uri' + ) + ], + extensions=[ + ste.SimpleExtensionDeclaration( + extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( + extension_uri_reference=1, + function_anchor=1, + name='count' + ) + ) + ], + referred_expr=[ + stee.ExpressionReference( + measure=stalg.AggregateFunction( + function_reference=1, + arguments=[ + stalg.FunctionArgument(value=stalg.Expression(literal=stalg.Expression.Literal(i8=10, nullable=False))), + ], + output_type=stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)) + ), + output_names=["count"], + ) + ], + base_schema=named_struct, + ) + + assert e == expected diff --git a/tests/extended_expression/test_column.py b/tests/builders/extended_expression/test_column.py similarity index 98% rename from tests/extended_expression/test_column.py rename to tests/builders/extended_expression/test_column.py index 92cf568..306fe70 100644 --- a/tests/extended_expression/test_column.py +++ b/tests/builders/extended_expression/test_column.py @@ -1,7 +1,7 @@ import substrait.gen.proto.algebra_pb2 as stalg import substrait.gen.proto.type_pb2 as stt import substrait.gen.proto.extended_expression_pb2 as stee -from substrait.extended_expression import column +from substrait.builders.extended_expression import column struct = stt.Type.Struct( diff --git a/tests/builders/extended_expression/test_if_then.py b/tests/builders/extended_expression/test_if_then.py new file mode 100644 index 0000000..1adeb6d --- /dev/null +++ b/tests/builders/extended_expression/test_if_then.py @@ -0,0 +1,50 @@ +import substrait.gen.proto.algebra_pb2 as stalg +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.extended_expression_pb2 as stee +from substrait.builders.extended_expression import if_then, literal + + +struct = stt.Type.Struct( + types=[ + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), + ] +) + +named_struct = stt.NamedStruct( + names=["order_id", "description", "order_total"], struct=struct +) + +def test_if_else(): + actual = if_then( + ifs=[ + ( + literal(True, type=stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_REQUIRED))), + literal(10, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))) + ) + ], + _else=literal(20, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))) + )(named_struct, None) + + expected = stee.ExtendedExpression( + referred_expr=[ + stee.ExpressionReference( + expression=stalg.Expression( + if_then=stalg.Expression.IfThen(**{ + 'ifs': [ + stalg.Expression.IfThen.IfClause(**{ + 'if': stalg.Expression(literal=stalg.Expression.Literal(boolean=True, nullable=False)), + 'then': stalg.Expression(literal=stalg.Expression.Literal(i8=10, nullable=False)) + }) + ], + 'else': stalg.Expression(literal=stalg.Expression.Literal(i8=20, nullable=False)) + }) + ), + output_names=["IfThen(Literal(True),Literal(10),Literal(20))"], + ) + ], + base_schema=named_struct, + ) + + assert actual == expected diff --git a/tests/extended_expression/test_scalar_function.py b/tests/builders/extended_expression/test_scalar_function.py similarity index 94% rename from tests/extended_expression/test_scalar_function.py rename to tests/builders/extended_expression/test_scalar_function.py index 0f3fb8c..26aba8e 100644 --- a/tests/extended_expression/test_scalar_function.py +++ b/tests/builders/extended_expression/test_scalar_function.py @@ -4,8 +4,8 @@ import substrait.gen.proto.type_pb2 as stt import substrait.gen.proto.extended_expression_pb2 as stee import substrait.gen.proto.extensions.extensions_pb2 as ste -from substrait.extended_expression import scalar_function, literal -from substrait.function_registry import FunctionRegistry +from substrait.builders.extended_expression import scalar_function, literal +from substrait.extension_registry import ExtensionRegistry struct = stt.Type.Struct( types=[ @@ -39,14 +39,13 @@ """ -registry = FunctionRegistry(load_default_extensions=False) +registry = ExtensionRegistry(load_default_extensions=False) registry.register_extension_dict(yaml.safe_load(content), uri="test_uri") def test_sclar_add(): e = scalar_function('test_uri', 'test_func', literal(10, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))), - literal(20, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))), - alias='sum', + literal(20, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))) )(named_struct, registry) expected = stee.ExtendedExpression( @@ -77,7 +76,7 @@ def test_sclar_add(): output_type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED)) ) ), - output_names=["sum"], + output_names=["test_func(Literal(10),Literal(20))"], ) ], base_schema=named_struct, diff --git a/tests/builders/extended_expression/test_window_function.py b/tests/builders/extended_expression/test_window_function.py new file mode 100644 index 0000000..9e3bd00 --- /dev/null +++ b/tests/builders/extended_expression/test_window_function.py @@ -0,0 +1,80 @@ +import yaml + +import substrait.gen.proto.algebra_pb2 as stalg +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.extended_expression_pb2 as stee +import substrait.gen.proto.extensions.extensions_pb2 as ste +from substrait.builders.extended_expression import window_function, literal +from substrait.extension_registry import ExtensionRegistry + +struct = stt.Type.Struct( + types=[ + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), + ] +) + +named_struct = stt.NamedStruct( + names=["order_id", "description", "order_total"], struct=struct +) + +content = """%YAML 1.2 +--- +window_functions: + - name: "row_number" + description: "the number of the current row within its partition, starting at 1" + impls: + - args: [] + nullability: DECLARED_OUTPUT + decomposable: NONE + return: i64? + window_type: PARTITION + - name: "rank" + description: "the rank of the current row, with gaps." + impls: + - args: [] + nullability: DECLARED_OUTPUT + decomposable: NONE + return: i64? + window_type: PARTITION +""" + + +registry = ExtensionRegistry(load_default_extensions=False) +registry.register_extension_dict(yaml.safe_load(content), uri="test_uri") + +def test_row_number(): + e = window_function('test_uri', 'row_number', alias='rn')(named_struct, registry) + + expected = stee.ExtendedExpression( + extension_uris=[ + ste.SimpleExtensionURI( + extension_uri_anchor=1, + uri='test_uri' + ) + ], + extensions=[ + ste.SimpleExtensionDeclaration( + extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( + extension_uri_reference=1, + function_anchor=1, + name='row_number' + ) + ) + ], + referred_expr=[ + stee.ExpressionReference( + expression=stalg.Expression( + window_function=stalg.Expression.WindowFunction( + function_reference=1, + output_type=stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_NULLABLE)) + ) + ), + output_names=["rn"], + ) + ], + base_schema=named_struct, + ) + + assert e == expected diff --git a/tests/test_function_registry.py b/tests/test_function_registry.py index ef7387e..14a227e 100644 --- a/tests/test_function_registry.py +++ b/tests/test_function_registry.py @@ -1,7 +1,7 @@ import yaml from substrait.gen.proto.type_pb2 import Type -from substrait.function_registry import FunctionRegistry, covers +from substrait.extension_registry import ExtensionRegistry, covers from substrait.derivation_expression import _parse content = """%YAML 1.2 @@ -105,7 +105,7 @@ """ -registry = FunctionRegistry() +registry = ExtensionRegistry() registry.register_extension_dict(yaml.safe_load(content), uri="test")