-
Notifications
You must be signed in to change notification settings - Fork 9
feat: extended expression builders #71
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0ab87e9
444da04
7ea05b8
e5168b2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
import itertools | ||
import substrait.gen.proto.algebra_pb2 as stalg | ||
import substrait.gen.proto.type_pb2 as stp | ||
import substrait.gen.proto.extended_expression_pb2 as stee | ||
import substrait.gen.proto.extensions.extensions_pb2 as ste | ||
from substrait.function_registry import FunctionRegistry | ||
from substrait.utils import type_num_names, merge_extension_uris, merge_extension_declarations | ||
from substrait.type_inference import infer_extended_expression_schema | ||
from typing import Callable, Any, Union | ||
|
||
UnboundExpression = Callable[[stp.NamedStruct, FunctionRegistry], stee.ExtendedExpression] | ||
|
||
def literal(value: Any, type: stp.Type, alias: str = None) -> UnboundExpression: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add comments for these methods (for instance, note how names are used). |
||
"""Builds a resolver for ExtendedExpression containing a literal expression""" | ||
def resolve(base_schema: stp.NamedStruct, registry: FunctionRegistry) -> stee.ExtendedExpression: | ||
kind = type.WhichOneof('kind') | ||
|
||
if kind == "bool": | ||
literal = stalg.Expression.Literal(boolean=value, nullable=type.bool.nullability == stp.Type.NULLABILITY_NULLABLE) | ||
elif kind == "i8": | ||
literal = stalg.Expression.Literal(i8=value, nullable=type.i8.nullability == stp.Type.NULLABILITY_NULLABLE) | ||
elif kind == "i16": | ||
literal = stalg.Expression.Literal(i16=value, nullable=type.i16.nullability == stp.Type.NULLABILITY_NULLABLE) | ||
elif kind == "i32": | ||
literal = stalg.Expression.Literal(i32=value, nullable=type.i32.nullability == stp.Type.NULLABILITY_NULLABLE) | ||
elif kind == "i64": | ||
literal = stalg.Expression.Literal(i64=value, nullable=type.i64.nullability == stp.Type.NULLABILITY_NULLABLE) | ||
elif kind == "fp32": | ||
literal = stalg.Expression.Literal(fp32=value, nullable=type.fp32.nullability == stp.Type.NULLABILITY_NULLABLE) | ||
elif kind == "fp64": | ||
literal = stalg.Expression.Literal(fp64=value, nullable=type.fp64.nullability == stp.Type.NULLABILITY_NULLABLE) | ||
elif kind == "string": | ||
literal = stalg.Expression.Literal(string=value, nullable=type.string.nullability == stp.Type.NULLABILITY_NULLABLE) | ||
else: | ||
raise Exception(f"Unknown literal type - {type}") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are also some unimplemented types here (such as dates, times, binary, uuid, etc.). Is it worth providing guidance here as to how to request/implement support? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that's a good point. I'll follow up with those implementations shortly and let's discuss this in that PR, if that's okay. Honestly I'm not sure how a literal builder for a "complex" type (something like IntervalCompound for example) should look like. |
||
|
||
return stee.ExtendedExpression( | ||
referred_expr=[ | ||
stee.ExpressionReference( | ||
expression=stalg.Expression( | ||
literal=literal | ||
), | ||
output_names=[alias if alias else f'literal_{kind}'], | ||
) | ||
], | ||
base_schema=base_schema, | ||
) | ||
|
||
return resolve | ||
|
||
def column(field: Union[str, int]): | ||
"""Builds a resolver for ExtendedExpression containing a FieldReference expression | ||
|
||
Accepts either an index or a field name of a desired field. | ||
""" | ||
def resolve(base_schema: stp.NamedStruct, registry: FunctionRegistry) -> stee.ExtendedExpression: | ||
if isinstance(field, str): | ||
column_index = list(base_schema.names).index(field) | ||
lengths = [type_num_names(t) for t in base_schema.struct.types] | ||
flat_indices = [0] + list(itertools.accumulate(lengths))[:-1] | ||
field_index = flat_indices.index(column_index) | ||
else: | ||
field_index = field | ||
|
||
names_start = flat_indices[field_index] | ||
names_end = ( | ||
flat_indices[field_index + 1] | ||
if len(flat_indices) > field_index + 1 | ||
else None | ||
) | ||
|
||
return stee.ExtendedExpression( | ||
referred_expr=[ | ||
stee.ExpressionReference( | ||
expression=stalg.Expression( | ||
selection=stalg.Expression.FieldReference( | ||
root_reference=stalg.Expression.FieldReference.RootReference(), | ||
direct_reference=stalg.Expression.ReferenceSegment( | ||
struct_field=stalg.Expression.ReferenceSegment.StructField( | ||
field=field_index | ||
) | ||
), | ||
) | ||
), | ||
output_names=list(base_schema.names)[names_start:names_end], | ||
) | ||
], | ||
base_schema=base_schema, | ||
) | ||
|
||
return resolve | ||
|
||
def scalar_function(uri: str, function: str, *expressions: UnboundExpression, alias: str = None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this accept regular expressions and extended expressions as well? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think there are 3 possible types to be used here:
I guess we can probably introduce a type like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The use case I'd like to enable is modifying an existing plan. I suppose you could convert an existing plan into a series of extended expressions. We don't need to do this right away though. A set of builders that work from scratch is a reasonable place to start and we can look at the other use cases later. |
||
"""Builds a resolver for ExtendedExpression containing a ScalarFunction expression""" | ||
def resolve(base_schema: stp.NamedStruct, registry: FunctionRegistry) -> stee.ExtendedExpression: | ||
bound_expressions: list[stee.ExtendedExpression] = [e(base_schema, registry) for e in expressions] | ||
|
||
expression_schemas = [infer_extended_expression_schema(b) for b in bound_expressions] | ||
|
||
signature = [typ for es in expression_schemas for typ in es.types] | ||
|
||
func = registry.lookup_function(uri, function, signature) | ||
|
||
if not func: | ||
raise Exception('') | ||
|
||
func_extension_uris = [ | ||
ste.SimpleExtensionURI( | ||
extension_uri_anchor=registry.lookup_uri(uri), | ||
uri=uri | ||
) | ||
] | ||
|
||
func_extensions = [ | ||
ste.SimpleExtensionDeclaration( | ||
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( | ||
extension_uri_reference=registry.lookup_uri(uri), | ||
function_anchor=func[0].anchor, | ||
name=function | ||
) | ||
) | ||
] | ||
|
||
extension_uris = merge_extension_uris( | ||
func_extension_uris, | ||
*[b.extension_uris for b in bound_expressions] | ||
) | ||
|
||
extensions = merge_extension_declarations( | ||
func_extensions, | ||
*[b.extensions for b in bound_expressions] | ||
) | ||
|
||
return stee.ExtendedExpression( | ||
referred_expr=[ | ||
stee.ExpressionReference( | ||
expression=stalg.Expression( | ||
scalar_function=stalg.Expression.ScalarFunction( | ||
function_reference=func[0].anchor, | ||
arguments=[ | ||
stalg.FunctionArgument( | ||
value=e.referred_expr[0].expression | ||
) for e in bound_expressions | ||
], | ||
output_type=func[1] | ||
) | ||
), | ||
output_names=[alias if alias else 'scalar_function'], | ||
) | ||
], | ||
base_schema=base_schema, | ||
extension_uris=extension_uris, | ||
extensions=extensions | ||
) | ||
|
||
return resolve |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import substrait.gen.proto.type_pb2 as stp | ||
import substrait.gen.proto.extensions.extensions_pb2 as ste | ||
from typing import Iterable | ||
|
||
def type_num_names(typ: stp.Type): | ||
kind = typ.WhichOneof("kind") | ||
if kind == "struct": | ||
lengths = [type_num_names(t) for t in typ.struct.types] | ||
return sum(lengths) + 1 | ||
elif kind == "list": | ||
return type_num_names(typ.list.type) | ||
elif kind == "map": | ||
return type_num_names(typ.map.key) + type_num_names(typ.map.value) | ||
else: | ||
return 1 | ||
|
||
def merge_extension_uris(*extension_uris: Iterable[ste.SimpleExtensionURI]): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A comment explaining what merge does (maintains order but removes duplicates) would be useful here. |
||
"""Merges multiple sets of SimpleExtensionURI objects into a single set. | ||
The order of extensions is kept intact, while duplicates are discarded. | ||
Assumes that there are no collisions (different extensions having identical anchors). | ||
""" | ||
seen_uris = set() | ||
ret = [] | ||
|
||
for uris in extension_uris: | ||
for uri in uris: | ||
if uri.uri not in seen_uris: | ||
seen_uris.add(uri.uri) | ||
ret.append(uri) | ||
|
||
return ret | ||
|
||
def merge_extension_declarations(*extension_declarations: Iterable[ste.SimpleExtensionDeclaration]): | ||
"""Merges multiple sets of SimpleExtensionDeclaration objects into a single set. | ||
The order of extension declarations is kept intact, while duplicates are discarded. | ||
Assumes that there are no collisions (different extension declarations having identical anchors). | ||
""" | ||
|
||
seen_extension_functions = set() | ||
ret = [] | ||
|
||
for declarations in extension_declarations: | ||
for declaration in declarations: | ||
if declaration.WhichOneof('mapping_type') == 'extension_function': | ||
ident = (declaration.extension_function.extension_uri_reference, declaration.extension_function.name) | ||
if ident not in seen_extension_functions: | ||
seen_extension_functions.add(ident) | ||
ret.append(declaration) | ||
else: | ||
raise Exception('') #TODO handle extension types | ||
|
||
return ret | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import substrait.gen.proto.algebra_pb2 as stalg | ||
import substrait.gen.proto.type_pb2 as stt | ||
import substrait.gen.proto.extended_expression_pb2 as stee | ||
from substrait.extended_expression import column | ||
|
||
|
||
struct = stt.Type.Struct( | ||
types=[ | ||
stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), | ||
stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), | ||
stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), | ||
] | ||
) | ||
|
||
named_struct = stt.NamedStruct( | ||
names=["order_id", "description", "order_total"], struct=struct | ||
) | ||
|
||
nested_struct = stt.Type.Struct( | ||
types=[ | ||
stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), | ||
stt.Type( | ||
struct=stt.Type.Struct( | ||
types=[ | ||
stt.Type( | ||
i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED) | ||
), | ||
stt.Type( | ||
fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE) | ||
), | ||
], | ||
nullability=stt.Type.NULLABILITY_NULLABLE, | ||
) | ||
), | ||
stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), | ||
] | ||
) | ||
|
||
nested_named_struct = stt.NamedStruct( | ||
names=["order_id", "shop_details", "shop_id", "shop_total", "order_total"], | ||
struct=nested_struct, | ||
) | ||
|
||
|
||
def test_column_no_nesting(): | ||
assert column("description")(named_struct, None) == stee.ExtendedExpression( | ||
referred_expr=[ | ||
stee.ExpressionReference( | ||
expression=stalg.Expression( | ||
selection=stalg.Expression.FieldReference( | ||
root_reference=stalg.Expression.FieldReference.RootReference(), | ||
direct_reference=stalg.Expression.ReferenceSegment( | ||
struct_field=stalg.Expression.ReferenceSegment.StructField( | ||
field=1 | ||
) | ||
), | ||
) | ||
), | ||
output_names=["description"], | ||
) | ||
], | ||
base_schema=named_struct, | ||
) | ||
|
||
|
||
def test_column_nesting(): | ||
assert column("order_total")(nested_named_struct, None) == stee.ExtendedExpression( | ||
referred_expr=[ | ||
stee.ExpressionReference( | ||
expression=stalg.Expression( | ||
selection=stalg.Expression.FieldReference( | ||
root_reference=stalg.Expression.FieldReference.RootReference(), | ||
direct_reference=stalg.Expression.ReferenceSegment( | ||
struct_field=stalg.Expression.ReferenceSegment.StructField( | ||
field=2 | ||
) | ||
), | ||
) | ||
), | ||
output_names=["order_total"], | ||
) | ||
], | ||
base_schema=nested_named_struct, | ||
) | ||
|
||
|
||
def test_column_nested_struct(): | ||
assert column("shop_details")(nested_named_struct, None) == stee.ExtendedExpression( | ||
referred_expr=[ | ||
stee.ExpressionReference( | ||
expression=stalg.Expression( | ||
selection=stalg.Expression.FieldReference( | ||
root_reference=stalg.Expression.FieldReference.RootReference(), | ||
direct_reference=stalg.Expression.ReferenceSegment( | ||
struct_field=stalg.Expression.ReferenceSegment.StructField( | ||
field=1 | ||
) | ||
), | ||
) | ||
), | ||
output_names=["shop_details", "shop_id", "shop_total"], | ||
) | ||
], | ||
base_schema=nested_named_struct, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a way to create a literal builder that works for both expressions and extended expressions (avoiding code duplication)? Perhaps we need to add more support to the other builder parts to accept either.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think a simple Expression builder makes a lot of sense, especially for literals. I'm simply trying to keep things uniform for starters as I don't have a clear idea how to tackle Expression builders for non-literal types.