Skip to content

feat: add builders for aggregate, window funcs and if-then #72

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this file was deprecated in Python 3.3.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done. I still bump into issues occasionally when init files are missing, but we can always add them if there's a need for it later

Empty file.
374 changes: 374 additions & 0 deletions src/substrait/builders/extended_expression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,374 @@
import itertools
import substrait.gen.proto.algebra_pb2 as stalg
import substrait.gen.proto.type_pb2 as stp
import substrait.gen.proto.extended_expression_pb2 as stee
import substrait.gen.proto.extensions.extensions_pb2 as ste
from substrait.extension_registry import ExtensionRegistry
from substrait.utils import type_num_names, merge_extension_uris, merge_extension_declarations
from substrait.type_inference import infer_extended_expression_schema
from typing import Callable, Any, Union, Iterable

UnboundExtendedExpression = Callable[[stp.NamedStruct, ExtensionRegistry], stee.ExtendedExpression]

def literal(value: Any, type: stp.Type, alias: str = None) -> UnboundExtendedExpression:
"""Builds a resolver for ExtendedExpression containing a literal expression"""
def resolve(base_schema: stp.NamedStruct, registry: ExtensionRegistry) -> stee.ExtendedExpression:
kind = type.WhichOneof('kind')

if kind == "bool":
literal = stalg.Expression.Literal(boolean=value, nullable=type.bool.nullability == stp.Type.NULLABILITY_NULLABLE)
elif kind == "i8":
literal = stalg.Expression.Literal(i8=value, nullable=type.i8.nullability == stp.Type.NULLABILITY_NULLABLE)
elif kind == "i16":
literal = stalg.Expression.Literal(i16=value, nullable=type.i16.nullability == stp.Type.NULLABILITY_NULLABLE)
elif kind == "i32":
literal = stalg.Expression.Literal(i32=value, nullable=type.i32.nullability == stp.Type.NULLABILITY_NULLABLE)
elif kind == "i64":
literal = stalg.Expression.Literal(i64=value, nullable=type.i64.nullability == stp.Type.NULLABILITY_NULLABLE)
elif kind == "fp32":
literal = stalg.Expression.Literal(fp32=value, nullable=type.fp32.nullability == stp.Type.NULLABILITY_NULLABLE)
elif kind == "fp64":
literal = stalg.Expression.Literal(fp64=value, nullable=type.fp64.nullability == stp.Type.NULLABILITY_NULLABLE)
elif kind == "string":
literal = stalg.Expression.Literal(string=value, nullable=type.string.nullability == stp.Type.NULLABILITY_NULLABLE)
else:
raise Exception(f"Unknown literal type - {type}")

return stee.ExtendedExpression(
referred_expr=[
stee.ExpressionReference(
expression=stalg.Expression(
literal=literal
),
output_names=[alias if alias else f'literal_{kind}'],
)
],
base_schema=base_schema,
)

return resolve

def column(field: Union[str, int], alias: str = None):
"""Builds a resolver for ExtendedExpression containing a FieldReference expression

Accepts either an index or a field name of a desired field.
"""

def resolve(
base_schema: stp.NamedStruct, registry: ExtensionRegistry
) -> stee.ExtendedExpression:
lengths = [type_num_names(t) for t in base_schema.struct.types]
flat_indices = [0] + list(itertools.accumulate(lengths))[:-1]

if isinstance(field, str):
column_index = list(base_schema.names).index(field)
field_index = flat_indices.index(column_index)
else:
field_index = field

names_start = flat_indices[field_index]
names_end = (
flat_indices[field_index + 1]
if len(flat_indices) > field_index + 1
else None
)

return stee.ExtendedExpression(
referred_expr=[
stee.ExpressionReference(
expression=stalg.Expression(
selection=stalg.Expression.FieldReference(
root_reference=stalg.Expression.FieldReference.RootReference(),
direct_reference=stalg.Expression.ReferenceSegment(
struct_field=stalg.Expression.ReferenceSegment.StructField(
field=field_index
)
),
)
),
output_names=list(base_schema.names)[names_start:names_end]
if not alias
else [alias],
)
],
base_schema=base_schema,
)

return resolve

def scalar_function(
uri: str, function: str, *expressions: UnboundExtendedExpression, alias: str = None
):
"""Builds a resolver for ExtendedExpression containing a ScalarFunction expression"""

def resolve(
base_schema: stp.NamedStruct, registry: ExtensionRegistry
) -> stee.ExtendedExpression:
bound_expressions: Iterable[stee.ExtendedExpression] = [
e(base_schema, registry) for e in expressions
]

expression_schemas = [
infer_extended_expression_schema(b) for b in bound_expressions
]

signature = [typ for es in expression_schemas for typ in es.types]

func = registry.lookup_function(uri, function, signature)

if not func:
raise Exception(f"Unknown function {function} for {signature}")

func_extension_uris = [
ste.SimpleExtensionURI(
extension_uri_anchor=registry.lookup_uri(uri), uri=uri
)
]

func_extensions = [
ste.SimpleExtensionDeclaration(
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
extension_uri_reference=registry.lookup_uri(uri),
function_anchor=func[0].anchor,
name=function,
)
)
]

extension_uris = merge_extension_uris(
func_extension_uris, *[b.extension_uris for b in bound_expressions]
)

extensions = merge_extension_declarations(
func_extensions, *[b.extensions for b in bound_expressions]
)

return stee.ExtendedExpression(
referred_expr=[
stee.ExpressionReference(
expression=stalg.Expression(
scalar_function=stalg.Expression.ScalarFunction(
function_reference=func[0].anchor,
arguments=[
stalg.FunctionArgument(
value=e.referred_expr[0].expression
)
for e in bound_expressions
],
output_type=func[1],
)
),
output_names=[alias if alias else "scalar_function"],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be a unique name? Perhaps keep track of generated names and prepend a count for each reuse (and elsewhere).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I switched to a rudimentary way of generating names from argument names. (ibis does something like this) Pretty sure it's gonna fail if the output type of an expression contains a struct, but it's a start. I'm not sure if it's worth it to infer an output type just to get the expected cardinality of the output_names array.

)
],
base_schema=base_schema,
extension_uris=extension_uris,
extensions=extensions,
)

return resolve

def aggregate_function(
uri: str, function: str, *expressions: UnboundExtendedExpression, alias: str = None
):
"""Builds a resolver for ExtendedExpression containing a AggregateFunction measure"""

def resolve(
base_schema: stp.NamedStruct, registry: ExtensionRegistry
) -> stee.ExtendedExpression:
bound_expressions: Iterable[stee.ExtendedExpression] = [
e(base_schema, registry) for e in expressions
]

expression_schemas = [
infer_extended_expression_schema(b) for b in bound_expressions
]

signature = [typ for es in expression_schemas for typ in es.types]

func = registry.lookup_function(uri, function, signature)

if not func:
raise Exception(f"Unknown function {function} for {signature}")

func_extension_uris = [
ste.SimpleExtensionURI(
extension_uri_anchor=registry.lookup_uri(uri), uri=uri
)
]

func_extensions = [
ste.SimpleExtensionDeclaration(
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
extension_uri_reference=registry.lookup_uri(uri),
function_anchor=func[0].anchor,
name=function,
)
)
]

extension_uris = merge_extension_uris(
func_extension_uris, *[b.extension_uris for b in bound_expressions]
)

extensions = merge_extension_declarations(
func_extensions, *[b.extensions for b in bound_expressions]
)

return stee.ExtendedExpression(
referred_expr=[
stee.ExpressionReference(
measure=stalg.AggregateFunction(
function_reference=func[0].anchor,
arguments=[
stalg.FunctionArgument(value=e.referred_expr[0].expression)
for e in bound_expressions
],
output_type=func[1],
),
output_names=[alias if alias else "aggregate_function"],
)
],
base_schema=base_schema,
extension_uris=extension_uris,
extensions=extensions,
)

return resolve


# TODO bounds, sorts
def window_function(
uri: str,
function: str,
*expressions: UnboundExtendedExpression,
partitions: Iterable[UnboundExtendedExpression] = [],
alias: str = None,
):
"""Builds a resolver for ExtendedExpression containing a WindowFunction expression"""

def resolve(
base_schema: stp.NamedStruct, registry: ExtensionRegistry
) -> stee.ExtendedExpression:
bound_expressions: Iterable[stee.ExtendedExpression] = [
e(base_schema, registry) for e in expressions
]

bound_partitions = [e(base_schema, registry) for e in partitions]

expression_schemas = [
infer_extended_expression_schema(b) for b in bound_expressions
]

signature = [typ for es in expression_schemas for typ in es.types]

func = registry.lookup_function(uri, function, signature)

if not func:
raise Exception(f"Unknown function {function} for {signature}")

func_extension_uris = [
ste.SimpleExtensionURI(
extension_uri_anchor=registry.lookup_uri(uri), uri=uri
)
]

func_extensions = [
ste.SimpleExtensionDeclaration(
extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
extension_uri_reference=registry.lookup_uri(uri),
function_anchor=func[0].anchor,
name=function,
)
)
]

extension_uris = merge_extension_uris(
func_extension_uris,
*[b.extension_uris for b in bound_expressions],
*[b.extension_uris for b in bound_partitions],
)

extensions = merge_extension_declarations(
func_extensions,
*[b.extensions for b in bound_expressions],
*[b.extensions for b in bound_partitions],
)

return stee.ExtendedExpression(
referred_expr=[
stee.ExpressionReference(
expression=stalg.Expression(
window_function=stalg.Expression.WindowFunction(
function_reference=func[0].anchor,
arguments=[
stalg.FunctionArgument(
value=e.referred_expr[0].expression
)
for e in bound_expressions
],
output_type=func[1],
partitions=[
e.referred_expr[0].expression for e in bound_partitions
],
)
),
output_names=[alias if alias else "window_function"],
)
],
base_schema=base_schema,
extension_uris=extension_uris,
extensions=extensions,
)

return resolve


def if_then(ifs: Iterable[tuple[UnboundExtendedExpression, UnboundExtendedExpression]], _else: UnboundExtendedExpression, alias: str = None):
"""Builds a resolver for ExtendedExpression containing a IfThen expression"""
def resolve(
base_schema: stp.NamedStruct, registry: ExtensionRegistry
) -> stee.ExtendedExpression:
bound_ifs = [
(if_clause[0](base_schema, registry), if_clause[1](base_schema, registry))
for if_clause in ifs
]

bound_else = _else(base_schema, registry)

extension_uris = merge_extension_uris(
*[b[0].extension_uris for b in bound_ifs],
*[b[1].extension_uris for b in bound_ifs],
bound_else.extension_uris
)

extensions = merge_extension_declarations(
*[b[0].extensions for b in bound_ifs],
*[b[1].extensions for b in bound_ifs],
bound_else.extensions
)

return stee.ExtendedExpression(
referred_expr=[
stee.ExpressionReference(
expression=stalg.Expression(
if_then=stalg.Expression.IfThen(**{
'ifs': [
stalg.Expression.IfThen.IfClause(**{
'if': if_clause[0].referred_expr[0].expression,
'then': if_clause[1].referred_expr[0].expression,
})
for if_clause in bound_ifs
],
'else': bound_else.referred_expr[0].expression
})
),
output_names=[alias if alias else "if_then"],
)
],
base_schema=base_schema,
extension_uris=extension_uris,
extensions=extensions,
)

return resolve
Loading