-
Notifications
You must be signed in to change notification settings - Fork 9
feat: add builders for aggregate, window funcs and if-then #72
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,374 @@ | ||
import itertools | ||
import substrait.gen.proto.algebra_pb2 as stalg | ||
import substrait.gen.proto.type_pb2 as stp | ||
import substrait.gen.proto.extended_expression_pb2 as stee | ||
import substrait.gen.proto.extensions.extensions_pb2 as ste | ||
from substrait.extension_registry import ExtensionRegistry | ||
from substrait.utils import type_num_names, merge_extension_uris, merge_extension_declarations | ||
from substrait.type_inference import infer_extended_expression_schema | ||
from typing import Callable, Any, Union, Iterable | ||
|
||
UnboundExtendedExpression = Callable[[stp.NamedStruct, ExtensionRegistry], stee.ExtendedExpression] | ||
|
||
def literal(value: Any, type: stp.Type, alias: str = None) -> UnboundExtendedExpression: | ||
"""Builds a resolver for ExtendedExpression containing a literal expression""" | ||
def resolve(base_schema: stp.NamedStruct, registry: ExtensionRegistry) -> stee.ExtendedExpression: | ||
kind = type.WhichOneof('kind') | ||
|
||
if kind == "bool": | ||
literal = stalg.Expression.Literal(boolean=value, nullable=type.bool.nullability == stp.Type.NULLABILITY_NULLABLE) | ||
elif kind == "i8": | ||
literal = stalg.Expression.Literal(i8=value, nullable=type.i8.nullability == stp.Type.NULLABILITY_NULLABLE) | ||
elif kind == "i16": | ||
literal = stalg.Expression.Literal(i16=value, nullable=type.i16.nullability == stp.Type.NULLABILITY_NULLABLE) | ||
elif kind == "i32": | ||
literal = stalg.Expression.Literal(i32=value, nullable=type.i32.nullability == stp.Type.NULLABILITY_NULLABLE) | ||
elif kind == "i64": | ||
literal = stalg.Expression.Literal(i64=value, nullable=type.i64.nullability == stp.Type.NULLABILITY_NULLABLE) | ||
elif kind == "fp32": | ||
literal = stalg.Expression.Literal(fp32=value, nullable=type.fp32.nullability == stp.Type.NULLABILITY_NULLABLE) | ||
elif kind == "fp64": | ||
literal = stalg.Expression.Literal(fp64=value, nullable=type.fp64.nullability == stp.Type.NULLABILITY_NULLABLE) | ||
elif kind == "string": | ||
literal = stalg.Expression.Literal(string=value, nullable=type.string.nullability == stp.Type.NULLABILITY_NULLABLE) | ||
else: | ||
raise Exception(f"Unknown literal type - {type}") | ||
|
||
return stee.ExtendedExpression( | ||
referred_expr=[ | ||
stee.ExpressionReference( | ||
expression=stalg.Expression( | ||
literal=literal | ||
), | ||
output_names=[alias if alias else f'literal_{kind}'], | ||
) | ||
], | ||
base_schema=base_schema, | ||
) | ||
|
||
return resolve | ||
|
||
def column(field: Union[str, int], alias: str = None): | ||
"""Builds a resolver for ExtendedExpression containing a FieldReference expression | ||
|
||
Accepts either an index or a field name of a desired field. | ||
""" | ||
|
||
def resolve( | ||
base_schema: stp.NamedStruct, registry: ExtensionRegistry | ||
) -> stee.ExtendedExpression: | ||
lengths = [type_num_names(t) for t in base_schema.struct.types] | ||
flat_indices = [0] + list(itertools.accumulate(lengths))[:-1] | ||
|
||
if isinstance(field, str): | ||
column_index = list(base_schema.names).index(field) | ||
field_index = flat_indices.index(column_index) | ||
else: | ||
field_index = field | ||
|
||
names_start = flat_indices[field_index] | ||
names_end = ( | ||
flat_indices[field_index + 1] | ||
if len(flat_indices) > field_index + 1 | ||
else None | ||
) | ||
|
||
return stee.ExtendedExpression( | ||
referred_expr=[ | ||
stee.ExpressionReference( | ||
expression=stalg.Expression( | ||
selection=stalg.Expression.FieldReference( | ||
root_reference=stalg.Expression.FieldReference.RootReference(), | ||
direct_reference=stalg.Expression.ReferenceSegment( | ||
struct_field=stalg.Expression.ReferenceSegment.StructField( | ||
field=field_index | ||
) | ||
), | ||
) | ||
), | ||
output_names=list(base_schema.names)[names_start:names_end] | ||
if not alias | ||
else [alias], | ||
) | ||
], | ||
base_schema=base_schema, | ||
) | ||
|
||
return resolve | ||
|
||
def scalar_function( | ||
uri: str, function: str, *expressions: UnboundExtendedExpression, alias: str = None | ||
): | ||
"""Builds a resolver for ExtendedExpression containing a ScalarFunction expression""" | ||
|
||
def resolve( | ||
base_schema: stp.NamedStruct, registry: ExtensionRegistry | ||
) -> stee.ExtendedExpression: | ||
bound_expressions: Iterable[stee.ExtendedExpression] = [ | ||
e(base_schema, registry) for e in expressions | ||
] | ||
|
||
expression_schemas = [ | ||
infer_extended_expression_schema(b) for b in bound_expressions | ||
] | ||
|
||
signature = [typ for es in expression_schemas for typ in es.types] | ||
|
||
func = registry.lookup_function(uri, function, signature) | ||
|
||
if not func: | ||
raise Exception(f"Unknown function {function} for {signature}") | ||
|
||
func_extension_uris = [ | ||
ste.SimpleExtensionURI( | ||
extension_uri_anchor=registry.lookup_uri(uri), uri=uri | ||
) | ||
] | ||
|
||
func_extensions = [ | ||
ste.SimpleExtensionDeclaration( | ||
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( | ||
extension_uri_reference=registry.lookup_uri(uri), | ||
function_anchor=func[0].anchor, | ||
name=function, | ||
) | ||
) | ||
] | ||
|
||
extension_uris = merge_extension_uris( | ||
func_extension_uris, *[b.extension_uris for b in bound_expressions] | ||
) | ||
|
||
extensions = merge_extension_declarations( | ||
func_extensions, *[b.extensions for b in bound_expressions] | ||
) | ||
|
||
return stee.ExtendedExpression( | ||
referred_expr=[ | ||
stee.ExpressionReference( | ||
expression=stalg.Expression( | ||
scalar_function=stalg.Expression.ScalarFunction( | ||
function_reference=func[0].anchor, | ||
arguments=[ | ||
stalg.FunctionArgument( | ||
value=e.referred_expr[0].expression | ||
) | ||
for e in bound_expressions | ||
], | ||
output_type=func[1], | ||
) | ||
), | ||
output_names=[alias if alias else "scalar_function"], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be a unique name? Perhaps keep track of generated names and prepend a count for each reuse (and elsewhere). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I switched to a rudimentary way of generating names from argument names. (ibis does something like this) Pretty sure it's gonna fail if the output type of an expression contains a struct, but it's a start. I'm not sure if it's worth it to infer an output type just to get the expected cardinality of the |
||
) | ||
], | ||
base_schema=base_schema, | ||
extension_uris=extension_uris, | ||
extensions=extensions, | ||
) | ||
|
||
return resolve | ||
|
||
def aggregate_function( | ||
uri: str, function: str, *expressions: UnboundExtendedExpression, alias: str = None | ||
): | ||
"""Builds a resolver for ExtendedExpression containing a AggregateFunction measure""" | ||
|
||
def resolve( | ||
base_schema: stp.NamedStruct, registry: ExtensionRegistry | ||
) -> stee.ExtendedExpression: | ||
bound_expressions: Iterable[stee.ExtendedExpression] = [ | ||
e(base_schema, registry) for e in expressions | ||
] | ||
|
||
expression_schemas = [ | ||
infer_extended_expression_schema(b) for b in bound_expressions | ||
] | ||
|
||
signature = [typ for es in expression_schemas for typ in es.types] | ||
|
||
func = registry.lookup_function(uri, function, signature) | ||
|
||
if not func: | ||
raise Exception(f"Unknown function {function} for {signature}") | ||
|
||
func_extension_uris = [ | ||
ste.SimpleExtensionURI( | ||
extension_uri_anchor=registry.lookup_uri(uri), uri=uri | ||
) | ||
] | ||
|
||
func_extensions = [ | ||
ste.SimpleExtensionDeclaration( | ||
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( | ||
extension_uri_reference=registry.lookup_uri(uri), | ||
function_anchor=func[0].anchor, | ||
name=function, | ||
) | ||
) | ||
] | ||
|
||
extension_uris = merge_extension_uris( | ||
func_extension_uris, *[b.extension_uris for b in bound_expressions] | ||
) | ||
|
||
extensions = merge_extension_declarations( | ||
func_extensions, *[b.extensions for b in bound_expressions] | ||
) | ||
|
||
return stee.ExtendedExpression( | ||
referred_expr=[ | ||
stee.ExpressionReference( | ||
measure=stalg.AggregateFunction( | ||
function_reference=func[0].anchor, | ||
arguments=[ | ||
stalg.FunctionArgument(value=e.referred_expr[0].expression) | ||
for e in bound_expressions | ||
], | ||
output_type=func[1], | ||
), | ||
output_names=[alias if alias else "aggregate_function"], | ||
) | ||
], | ||
base_schema=base_schema, | ||
extension_uris=extension_uris, | ||
extensions=extensions, | ||
) | ||
|
||
return resolve | ||
|
||
|
||
# TODO bounds, sorts | ||
def window_function( | ||
uri: str, | ||
function: str, | ||
*expressions: UnboundExtendedExpression, | ||
partitions: Iterable[UnboundExtendedExpression] = [], | ||
alias: str = None, | ||
): | ||
"""Builds a resolver for ExtendedExpression containing a WindowFunction expression""" | ||
|
||
def resolve( | ||
base_schema: stp.NamedStruct, registry: ExtensionRegistry | ||
) -> stee.ExtendedExpression: | ||
bound_expressions: Iterable[stee.ExtendedExpression] = [ | ||
e(base_schema, registry) for e in expressions | ||
] | ||
|
||
bound_partitions = [e(base_schema, registry) for e in partitions] | ||
|
||
expression_schemas = [ | ||
infer_extended_expression_schema(b) for b in bound_expressions | ||
] | ||
|
||
signature = [typ for es in expression_schemas for typ in es.types] | ||
|
||
func = registry.lookup_function(uri, function, signature) | ||
|
||
if not func: | ||
raise Exception(f"Unknown function {function} for {signature}") | ||
|
||
func_extension_uris = [ | ||
ste.SimpleExtensionURI( | ||
extension_uri_anchor=registry.lookup_uri(uri), uri=uri | ||
) | ||
] | ||
|
||
func_extensions = [ | ||
ste.SimpleExtensionDeclaration( | ||
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( | ||
extension_uri_reference=registry.lookup_uri(uri), | ||
function_anchor=func[0].anchor, | ||
name=function, | ||
) | ||
) | ||
] | ||
|
||
extension_uris = merge_extension_uris( | ||
func_extension_uris, | ||
*[b.extension_uris for b in bound_expressions], | ||
*[b.extension_uris for b in bound_partitions], | ||
) | ||
|
||
extensions = merge_extension_declarations( | ||
func_extensions, | ||
*[b.extensions for b in bound_expressions], | ||
*[b.extensions for b in bound_partitions], | ||
) | ||
|
||
return stee.ExtendedExpression( | ||
referred_expr=[ | ||
stee.ExpressionReference( | ||
expression=stalg.Expression( | ||
window_function=stalg.Expression.WindowFunction( | ||
function_reference=func[0].anchor, | ||
arguments=[ | ||
stalg.FunctionArgument( | ||
value=e.referred_expr[0].expression | ||
) | ||
for e in bound_expressions | ||
], | ||
output_type=func[1], | ||
partitions=[ | ||
e.referred_expr[0].expression for e in bound_partitions | ||
], | ||
) | ||
), | ||
output_names=[alias if alias else "window_function"], | ||
) | ||
], | ||
base_schema=base_schema, | ||
extension_uris=extension_uris, | ||
extensions=extensions, | ||
) | ||
|
||
return resolve | ||
|
||
|
||
def if_then(ifs: Iterable[tuple[UnboundExtendedExpression, UnboundExtendedExpression]], _else: UnboundExtendedExpression, alias: str = None): | ||
"""Builds a resolver for ExtendedExpression containing a IfThen expression""" | ||
def resolve( | ||
base_schema: stp.NamedStruct, registry: ExtensionRegistry | ||
) -> stee.ExtendedExpression: | ||
bound_ifs = [ | ||
(if_clause[0](base_schema, registry), if_clause[1](base_schema, registry)) | ||
for if_clause in ifs | ||
] | ||
|
||
bound_else = _else(base_schema, registry) | ||
|
||
extension_uris = merge_extension_uris( | ||
*[b[0].extension_uris for b in bound_ifs], | ||
*[b[1].extension_uris for b in bound_ifs], | ||
bound_else.extension_uris | ||
) | ||
|
||
extensions = merge_extension_declarations( | ||
*[b[0].extensions for b in bound_ifs], | ||
*[b[1].extensions for b in bound_ifs], | ||
bound_else.extensions | ||
) | ||
|
||
return stee.ExtendedExpression( | ||
referred_expr=[ | ||
stee.ExpressionReference( | ||
expression=stalg.Expression( | ||
if_then=stalg.Expression.IfThen(**{ | ||
'ifs': [ | ||
stalg.Expression.IfThen.IfClause(**{ | ||
'if': if_clause[0].referred_expr[0].expression, | ||
'then': if_clause[1].referred_expr[0].expression, | ||
}) | ||
for if_clause in bound_ifs | ||
], | ||
'else': bound_else.referred_expr[0].expression | ||
}) | ||
), | ||
output_names=[alias if alias else "if_then"], | ||
) | ||
], | ||
base_schema=base_schema, | ||
extension_uris=extension_uris, | ||
extensions=extensions, | ||
) | ||
|
||
return resolve |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this file was deprecated in Python 3.3.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done. I still bump into issues occasionally when init files are missing, but we can always add them if there's a need for it later