Add builders for Substrait ExtendedExpression messages.

Introduces src/substrait/extended_expression.py (literal, column, scalar_function),
URI anchors in FunctionRegistry, infer_extended_expression_schema, the helpers in
src/substrait/utils.py, and tests for the new code.

NOTE(review): the blob ids on the `index` lines are carried over unchanged and do not
reflect the edited contents - regenerate the patch before relying on them.

diff --git a/src/substrait/extended_expression.py b/src/substrait/extended_expression.py
new file mode 100644
index 0000000..f71762e
--- /dev/null
+++ b/src/substrait/extended_expression.py
@@ -0,0 +1,159 @@
+import itertools
+import substrait.gen.proto.algebra_pb2 as stalg
+import substrait.gen.proto.type_pb2 as stp
+import substrait.gen.proto.extended_expression_pb2 as stee
+import substrait.gen.proto.extensions.extensions_pb2 as ste
+from substrait.function_registry import FunctionRegistry
+from substrait.utils import type_num_names, merge_extension_uris, merge_extension_declarations
+from substrait.type_inference import infer_extended_expression_schema
+from typing import Callable, Any, Union
+
+UnboundExpression = Callable[[stp.NamedStruct, FunctionRegistry], stee.ExtendedExpression]
+
+def literal(value: Any, type: stp.Type, alias: str = None) -> UnboundExpression:
+    """Builds a resolver for ExtendedExpression containing a literal expression"""
+    def resolve(base_schema: stp.NamedStruct, registry: FunctionRegistry) -> stee.ExtendedExpression:
+        kind = type.WhichOneof('kind')
+
+        if kind == "bool":
+            literal = stalg.Expression.Literal(boolean=value, nullable=type.bool.nullability == stp.Type.NULLABILITY_NULLABLE)
+        elif kind == "i8":
+            literal = stalg.Expression.Literal(i8=value, nullable=type.i8.nullability == stp.Type.NULLABILITY_NULLABLE)
+        elif kind == "i16":
+            literal = stalg.Expression.Literal(i16=value, nullable=type.i16.nullability == stp.Type.NULLABILITY_NULLABLE)
+        elif kind == "i32":
+            literal = stalg.Expression.Literal(i32=value, nullable=type.i32.nullability == stp.Type.NULLABILITY_NULLABLE)
+        elif kind == "i64":
+            literal = stalg.Expression.Literal(i64=value, nullable=type.i64.nullability == stp.Type.NULLABILITY_NULLABLE)
+        elif kind == "fp32":
+            literal = stalg.Expression.Literal(fp32=value, nullable=type.fp32.nullability == stp.Type.NULLABILITY_NULLABLE)
+        elif kind == "fp64":
+            literal = stalg.Expression.Literal(fp64=value, nullable=type.fp64.nullability == stp.Type.NULLABILITY_NULLABLE)
+        elif kind == "string":
+            literal = stalg.Expression.Literal(string=value, nullable=type.string.nullability == stp.Type.NULLABILITY_NULLABLE)
+        else:
+            raise Exception(f"Unknown literal type - {type}")
+
+        return stee.ExtendedExpression(
+            referred_expr=[
+                stee.ExpressionReference(
+                    expression=stalg.Expression(
+                        literal=literal
+                    ),
+                    output_names=[alias if alias else f'literal_{kind}'],
+                )
+            ],
+            base_schema=base_schema,
+        )
+
+    return resolve
+
+def column(field: Union[str, int]):
+    """Builds a resolver for ExtendedExpression containing a FieldReference expression
+
+    Accepts either an index or a field name of a desired field.
+    """
+    def resolve(base_schema: stp.NamedStruct, registry: FunctionRegistry) -> stee.ExtendedExpression:
+        # Offset of every top-level field inside the flattened (depth-first) name list.
+        # Computed up front because both the name lookup and the index lookup need it.
+        lengths = [type_num_names(t) for t in base_schema.struct.types]
+        flat_indices = [0] + list(itertools.accumulate(lengths))[:-1]
+
+        if isinstance(field, str):
+            column_index = list(base_schema.names).index(field)
+            field_index = flat_indices.index(column_index)
+        else:
+            field_index = field
+
+        names_start = flat_indices[field_index]
+        names_end = (
+            flat_indices[field_index + 1]
+            if len(flat_indices) > field_index + 1
+            else None
+        )
+
+        return stee.ExtendedExpression(
+            referred_expr=[
+                stee.ExpressionReference(
+                    expression=stalg.Expression(
+                        selection=stalg.Expression.FieldReference(
+                            root_reference=stalg.Expression.FieldReference.RootReference(),
+                            direct_reference=stalg.Expression.ReferenceSegment(
+                                struct_field=stalg.Expression.ReferenceSegment.StructField(
+                                    field=field_index
+                                )
+                            ),
+                        )
+                    ),
+                    output_names=list(base_schema.names)[names_start:names_end],
+                )
+            ],
+            base_schema=base_schema,
+        )
+
+    return resolve
+
+def scalar_function(uri: str, function: str, *expressions: UnboundExpression, alias: str = None):
+    """Builds a resolver for ExtendedExpression containing a ScalarFunction expression"""
+    def resolve(base_schema: stp.NamedStruct, registry: FunctionRegistry) -> stee.ExtendedExpression:
+        bound_expressions: list[stee.ExtendedExpression] = [e(base_schema, registry) for e in expressions]
+
+        expression_schemas = [infer_extended_expression_schema(b) for b in bound_expressions]
+
+        signature = [typ for es in expression_schemas for typ in es.types]
+
+        func = registry.lookup_function(uri, function, signature)
+
+        if not func:
+            raise Exception(f"Unknown function {function} for {signature}")
+
+        func_extension_uris = [
+            ste.SimpleExtensionURI(
+                extension_uri_anchor=registry.lookup_uri(uri),
+                uri=uri
+            )
+        ]
+
+        func_extensions = [
+            ste.SimpleExtensionDeclaration(
+                extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
+                    extension_uri_reference=registry.lookup_uri(uri),
+                    function_anchor=func[0].anchor,
+                    name=function
+                )
+            )
+        ]
+
+        extension_uris = merge_extension_uris(
+            func_extension_uris,
+            *[b.extension_uris for b in bound_expressions]
+        )
+
+        extensions = merge_extension_declarations(
+            func_extensions,
+            *[b.extensions for b in bound_expressions]
+        )
+
+        return stee.ExtendedExpression(
+            referred_expr=[
+                stee.ExpressionReference(
+                    expression=stalg.Expression(
+                        scalar_function=stalg.Expression.ScalarFunction(
+                            function_reference=func[0].anchor,
+                            arguments=[
+                                stalg.FunctionArgument(
+                                    value=e.referred_expr[0].expression
+                                ) for e in bound_expressions
+                            ],
+                            output_type=func[1]
+                        )
+                    ),
+                    output_names=[alias if alias else 'scalar_function'],
+                )
+            ],
+            base_schema=base_schema,
+            extension_uris=extension_uris,
+            extensions=extensions
+        )
+
+    return resolve
diff --git a/src/substrait/function_registry.py b/src/substrait/function_registry.py
index 101f2d7..7130fce 100644
--- a/src/substrait/function_registry.py
+++ b/src/substrait/function_registry.py
@@ -1,4 +1,3 @@
-from substrait.gen.proto.parameterized_types_pb2 import ParameterizedType
 from substrait.gen.proto.type_pb2 import Type
 from importlib.resources import files as importlib_files
 import itertools
@@ -227,6 +226,9 @@ def satisfies_signature(self, signature: tuple) -> Optional[str]:
 
 class FunctionRegistry:
     def __init__(self, load_default_extensions=True) -> None:
+        self._uri_mapping: dict = defaultdict(dict)
+        self._uri_id_generator = itertools.count(1)
+
         self._function_mapping: dict = defaultdict(dict)
         self._id_generator = itertools.count(1)
 
@@ -252,6 +254,8 @@ def register_extension_yaml(
         self.register_extension_dict(extension_definitions, uri)
 
     def register_extension_dict(self, definitions: dict, uri: str) -> None:
+        self._uri_mapping[uri] = next(self._uri_id_generator)
+
         for named_functions in definitions.values():
             for function in named_functions:
                 for impl in function.get("impls", []):
@@ -285,3 +289,8 @@ def lookup_function(
                     return (f, rtn)
 
         return None
+
+    def lookup_uri(self, uri: str) -> Optional[int]:
+        """Returns the anchor assigned to a registered extension URI (aliases are resolved first), or None if it is unknown"""
+        uri = self._uri_aliases.get(uri, uri)
+        return self._uri_mapping.get(uri, None)
diff --git a/src/substrait/type_inference.py b/src/substrait/type_inference.py
index 080af9a..082da29 100644
--- a/src/substrait/type_inference.py
+++ b/src/substrait/type_inference.py
@@ -1,4 +1,5 @@
 import substrait.gen.proto.algebra_pb2 as stalg
+import substrait.gen.proto.extended_expression_pb2 as stee
 import substrait.gen.proto.type_pb2 as stt
 
 
@@ -220,6 +221,18 @@ def infer_expression_type(
         raise Exception(f"Unknown rex_type {rex_type}")
 
 
+def infer_extended_expression_schema(ee: stee.ExtendedExpression) -> stt.Type.Struct:
+    """Infers the type of every referred expression against the base schema and returns them as a REQUIRED struct"""
+    exprs = [e for e in ee.referred_expr]
+
+    types = [infer_expression_type(e.expression, ee.base_schema.struct) for e in exprs]
+
+    return stt.Type.Struct(
+        types=types,
+        nullability=stt.Type.NULLABILITY_REQUIRED,
+    )
+
+
 def infer_rel_schema(rel: stalg.Rel) -> stt.Type.Struct:
     rel_type = rel.WhichOneof("rel_type")
 
diff --git a/src/substrait/utils.py b/src/substrait/utils.py
new file mode 100644
index 0000000..0f8764b
--- /dev/null
+++ b/src/substrait/utils.py
@@ -0,0 +1,53 @@
+import substrait.gen.proto.type_pb2 as stp
+import substrait.gen.proto.extensions.extensions_pb2 as ste
+from typing import Iterable
+
+def type_num_names(typ: stp.Type):
+    """Returns how many entries of a flattened NamedStruct names list the given type occupies"""
+    kind = typ.WhichOneof("kind")
+    if kind == "struct":
+        lengths = [type_num_names(t) for t in typ.struct.types]
+        return sum(lengths) + 1
+    elif kind == "list":
+        return type_num_names(typ.list.type)
+    elif kind == "map":
+        return type_num_names(typ.map.key) + type_num_names(typ.map.value)
+    else:
+        return 1
+
+def merge_extension_uris(*extension_uris: Iterable[ste.SimpleExtensionURI]):
+    """Merges multiple sets of SimpleExtensionURI objects into a single set.
+    The order of extensions is kept intact, while duplicates are discarded.
+    Assumes that there are no collisions (different extensions having identical anchors).
+    """
+    seen_uris = set()
+    ret = []
+
+    for uris in extension_uris:
+        for uri in uris:
+            if uri.uri not in seen_uris:
+                seen_uris.add(uri.uri)
+                ret.append(uri)
+
+    return ret
+
+def merge_extension_declarations(*extension_declarations: Iterable[ste.SimpleExtensionDeclaration]):
+    """Merges multiple sets of SimpleExtensionDeclaration objects into a single set.
+    The order of extension declarations is kept intact, while duplicates are discarded.
+    Assumes that there are no collisions (different extension declarations having identical anchors).
+    """
+
+    seen_extension_functions = set()
+    ret = []
+
+    for declarations in extension_declarations:
+        for declaration in declarations:
+            if declaration.WhichOneof('mapping_type') == 'extension_function':
+                ident = (declaration.extension_function.extension_uri_reference, declaration.extension_function.name)
+                if ident not in seen_extension_functions:
+                    seen_extension_functions.add(ident)
+                    ret.append(declaration)
+            else:
+                raise Exception(f"Unsupported extension declaration type - {declaration.WhichOneof('mapping_type')}") #TODO handle extension types
+
+    return ret
diff --git a/tests/extended_expression/test_column.py b/tests/extended_expression/test_column.py
new file mode 100644
index 0000000..92cf568
--- /dev/null
+++ b/tests/extended_expression/test_column.py
@@ -0,0 +1,109 @@
+import substrait.gen.proto.algebra_pb2 as stalg
+import substrait.gen.proto.type_pb2 as stt
+import substrait.gen.proto.extended_expression_pb2 as stee
+from substrait.extended_expression import column
+
+
+struct = stt.Type.Struct(
+    types=[
+        stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)),
+        stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)),
+        stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)),
+    ]
+)
+
+named_struct = stt.NamedStruct(
+    names=["order_id", "description", "order_total"], struct=struct
+)
+
+nested_struct = stt.Type.Struct(
+    types=[
+        stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)),
+        stt.Type(
+            struct=stt.Type.Struct(
+                types=[
+                    stt.Type(
+                        i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)
+                    ),
+                    stt.Type(
+                        fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)
+                    ),
+                ],
+                nullability=stt.Type.NULLABILITY_NULLABLE,
+            )
+        ),
+        stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)),
+    ]
+)
+
+nested_named_struct = stt.NamedStruct(
+    names=["order_id", "shop_details", "shop_id", "shop_total", "order_total"],
+    struct=nested_struct,
+)
+
+
+def test_column_no_nesting():
+    assert column("description")(named_struct, None) == stee.ExtendedExpression(
+        referred_expr=[
+            stee.ExpressionReference(
+                expression=stalg.Expression(
+                    selection=stalg.Expression.FieldReference(
+                        root_reference=stalg.Expression.FieldReference.RootReference(),
+                        direct_reference=stalg.Expression.ReferenceSegment(
+                            struct_field=stalg.Expression.ReferenceSegment.StructField(
+                                field=1
+                            )
+                        ),
+                    )
+                ),
+                output_names=["description"],
+            )
+        ],
+        base_schema=named_struct,
+    )
+
+
+def test_column_nesting():
+    assert column("order_total")(nested_named_struct, None) == stee.ExtendedExpression(
+        referred_expr=[
+            stee.ExpressionReference(
+                expression=stalg.Expression(
+                    selection=stalg.Expression.FieldReference(
+                        root_reference=stalg.Expression.FieldReference.RootReference(),
+                        direct_reference=stalg.Expression.ReferenceSegment(
+                            struct_field=stalg.Expression.ReferenceSegment.StructField(
+                                field=2
+                            )
+                        ),
+                    )
+                ),
+                output_names=["order_total"],
+            )
+        ],
+        base_schema=nested_named_struct,
+    )
+
+
+def test_column_nested_struct():
+    assert column("shop_details")(nested_named_struct, None) == stee.ExtendedExpression(
+        referred_expr=[
+            stee.ExpressionReference(
+                expression=stalg.Expression(
+                    selection=stalg.Expression.FieldReference(
+                        root_reference=stalg.Expression.FieldReference.RootReference(),
+                        direct_reference=stalg.Expression.ReferenceSegment(
+                            struct_field=stalg.Expression.ReferenceSegment.StructField(
+                                field=1
+                            )
+                        ),
+                    )
+                ),
+                output_names=["shop_details", "shop_id", "shop_total"],
+            )
+        ],
+        base_schema=nested_named_struct,
+    )
+
+
+def test_column_by_index():
+    assert column(1)(named_struct, None) == column("description")(named_struct, None)
diff --git a/tests/extended_expression/test_scalar_function.py b/tests/extended_expression/test_scalar_function.py
new file mode 100644
index 0000000..0f3fb8c
--- /dev/null
+++ b/tests/extended_expression/test_scalar_function.py
@@ -0,0 +1,148 @@
+import yaml
+
+import substrait.gen.proto.algebra_pb2 as stalg
+import substrait.gen.proto.type_pb2 as stt
+import substrait.gen.proto.extended_expression_pb2 as stee
+import substrait.gen.proto.extensions.extensions_pb2 as ste
+from substrait.extended_expression import scalar_function, literal
+from substrait.function_registry import FunctionRegistry
+
+struct = stt.Type.Struct(
+    types=[
+        stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)),
+        stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)),
+        stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)),
+    ]
+)
+
+named_struct = stt.NamedStruct(
+    names=["order_id", "description", "order_total"], struct=struct
+)
+
+content = """%YAML 1.2
+---
+scalar_functions:
+  - name: "test_func"
+    description: ""
+    impls:
+      - args:
+          - value: i8
+        variadic:
+          min: 2
+        return: i8
+  - name: "is_positive"
+    description: ""
+    impls:
+      - args:
+          - value: i8
+        return: boolean
+"""
+
+
+registry = FunctionRegistry(load_default_extensions=False)
+registry.register_extension_dict(yaml.safe_load(content), uri="test_uri")
+
+def test_scalar_add():
+    e = scalar_function('test_uri', 'test_func',
+        literal(10, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))),
+        literal(20, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))),
+        alias='sum',
+    )(named_struct, registry)
+
+    expected = stee.ExtendedExpression(
+        extension_uris=[
+            ste.SimpleExtensionURI(
+                extension_uri_anchor=1,
+                uri='test_uri'
+            )
+        ],
+        extensions=[
+            ste.SimpleExtensionDeclaration(
+                extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
+                    extension_uri_reference=1,
+                    function_anchor=1,
+                    name='test_func'
+                )
+            )
+        ],
+        referred_expr=[
+            stee.ExpressionReference(
+                expression=stalg.Expression(
+                    scalar_function=stalg.Expression.ScalarFunction(
+                        function_reference=1,
+                        arguments=[
+                            stalg.FunctionArgument(value=stalg.Expression(literal=stalg.Expression.Literal(i8=10, nullable=False))),
+                            stalg.FunctionArgument(value=stalg.Expression(literal=stalg.Expression.Literal(i8=20, nullable=False)))
+                        ],
+                        output_type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))
+                    )
+                ),
+                output_names=["sum"],
+            )
+        ],
+        base_schema=named_struct,
+    )
+
+    assert e == expected
+
+
+def test_nested_scalar_calls():
+    e = scalar_function('test_uri', 'is_positive',
+        scalar_function('test_uri', 'test_func',
+            literal(10, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))),
+            literal(20, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED)))),
+        alias='positive'
+    )(named_struct, registry)
+
+    expected = stee.ExtendedExpression(
+        extension_uris=[
+            ste.SimpleExtensionURI(
+                extension_uri_anchor=1,
+                uri='test_uri'
+            )
+        ],
+        extensions=[
+            ste.SimpleExtensionDeclaration(
+                extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
+                    extension_uri_reference=1,
+                    function_anchor=2,
+                    name='is_positive'
+                )
+            ),
+            ste.SimpleExtensionDeclaration(
+                extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
+                    extension_uri_reference=1,
+                    function_anchor=1,
+                    name='test_func'
+                )
+            )
+        ],
+        referred_expr=[
+            stee.ExpressionReference(
+                expression=stalg.Expression(
+                    scalar_function=stalg.Expression.ScalarFunction(
+                        function_reference=2,
+                        arguments=[
+                            stalg.FunctionArgument(
+                                value=stalg.Expression(
+                                    scalar_function=stalg.Expression.ScalarFunction(
+                                        function_reference=1,
+                                        arguments=[
+                                            stalg.FunctionArgument(value=stalg.Expression(literal=stalg.Expression.Literal(i8=10, nullable=False))),
+                                            stalg.FunctionArgument(value=stalg.Expression(literal=stalg.Expression.Literal(i8=20, nullable=False)))
+                                        ],
+                                        output_type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))
+                                    )
+                                )
+                            )
+                        ],
+                        output_type=stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_REQUIRED))
+                    )
+                ),
+                output_names=["positive"],
+            )
+        ],
+        base_schema=named_struct,
+    )
+
+    assert e == expected
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..6043364
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,88 @@
+import substrait.gen.proto.type_pb2 as stt
+from substrait.utils import type_num_names
+
+
+def test_type_num_names_flat_struct():
+    assert (
+        type_num_names(
+            stt.Type(
+                struct=stt.Type.Struct(
+                    types=[
+                        stt.Type(i64=stt.Type.I64()),
+                        stt.Type(string=stt.Type.String()),
+                        stt.Type(fp32=stt.Type.FP32()),
+                    ]
+                )
+            )
+        )
+        == 4
+    )
+
+
+def test_type_num_names_nested_struct():
+    assert (
+        type_num_names(
+            stt.Type(
+                struct=stt.Type.Struct(
+                    types=[
+                        stt.Type(i64=stt.Type.I64()),
+                        stt.Type(
+                            struct=stt.Type.Struct(
+                                types=[
+                                    stt.Type(i64=stt.Type.I64()),
+                                    stt.Type(fp32=stt.Type.FP32()),
+                                ]
+                            )
+                        ),
+                        stt.Type(fp32=stt.Type.FP32()),
+                    ]
+                )
+            )
+        )
+        == 6
+    )
+
+
+def test_type_num_names_flat_list():
+    assert (
+        type_num_names(
+            stt.Type(
+                struct=stt.Type.Struct(
+                    types=[
+                        stt.Type(i64=stt.Type.I64()),
+                        stt.Type(list=stt.Type.List(type=stt.Type(i64=stt.Type.I64()))),
+                        stt.Type(fp32=stt.Type.FP32()),
+                    ]
+                )
+            )
+        )
+        == 4
+    )
+
+
+def test_type_num_names_nested_list():
+    assert (
+        type_num_names(
+            stt.Type(
+                struct=stt.Type.Struct(
+                    types=[
+                        stt.Type(i64=stt.Type.I64()),
+                        stt.Type(
+                            list=stt.Type.List(
+                                type=stt.Type(
+                                    struct=stt.Type.Struct(
+                                        types=[
+                                            stt.Type(i64=stt.Type.I64()),
+                                            stt.Type(fp32=stt.Type.FP32()),
+                                        ]
+                                    )
+                                )
+                            )
+                        ),
+                        stt.Type(fp32=stt.Type.FP32()),
+                    ]
+                )
+            )
+        )
+        == 6
+    )