Skip to content

feat: extended expression builders #71

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 156 additions & 0 deletions src/substrait/extended_expression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import itertools
import substrait.gen.proto.algebra_pb2 as stalg
import substrait.gen.proto.type_pb2 as stp
import substrait.gen.proto.extended_expression_pb2 as stee
import substrait.gen.proto.extensions.extensions_pb2 as ste
from substrait.function_registry import FunctionRegistry
from substrait.utils import type_num_names, merge_extension_uris, merge_extension_declarations
from substrait.type_inference import infer_extended_expression_schema
from typing import Callable, Any, Union

UnboundExpression = Callable[[stp.NamedStruct, FunctionRegistry], stee.ExtendedExpression]

def literal(value: Any, type: stp.Type, alias: str = None) -> UnboundExpression:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to create a literal builder that works for both expressions and extended expressions (avoiding code duplication)? Perhaps we need to add more support to the other builder parts to accept either.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a simple Expression builder makes a lot of sense, especially for literals. I'm simply trying to keep things uniform for starters as I don't have a clear idea how to tackle Expression builders for non-literal types.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add comments for these methods (for instance, note how names are used).

"""Builds a resolver for ExtendedExpression containing a literal expression"""
def resolve(base_schema: stp.NamedStruct, registry: FunctionRegistry) -> stee.ExtendedExpression:
kind = type.WhichOneof('kind')

if kind == "bool":
literal = stalg.Expression.Literal(boolean=value, nullable=type.bool.nullability == stp.Type.NULLABILITY_NULLABLE)
elif kind == "i8":
literal = stalg.Expression.Literal(i8=value, nullable=type.i8.nullability == stp.Type.NULLABILITY_NULLABLE)
elif kind == "i16":
literal = stalg.Expression.Literal(i16=value, nullable=type.i16.nullability == stp.Type.NULLABILITY_NULLABLE)
elif kind == "i32":
literal = stalg.Expression.Literal(i32=value, nullable=type.i32.nullability == stp.Type.NULLABILITY_NULLABLE)
elif kind == "i64":
literal = stalg.Expression.Literal(i64=value, nullable=type.i64.nullability == stp.Type.NULLABILITY_NULLABLE)
elif kind == "fp32":
literal = stalg.Expression.Literal(fp32=value, nullable=type.fp32.nullability == stp.Type.NULLABILITY_NULLABLE)
elif kind == "fp64":
literal = stalg.Expression.Literal(fp64=value, nullable=type.fp64.nullability == stp.Type.NULLABILITY_NULLABLE)
elif kind == "string":
literal = stalg.Expression.Literal(string=value, nullable=type.string.nullability == stp.Type.NULLABILITY_NULLABLE)
else:
raise Exception(f"Unknown literal type - {type}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are also some unimplemented types here (such as dates, times, binary, uuid, etc.). Is it worth providing guidance here as to how to request/implement support?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point. I'll follow up with those implementations shortly; let's discuss this in that PR, if that's okay. Honestly, I'm not sure what a literal builder for a "complex" type (something like IntervalCompound, for example) should look like.


return stee.ExtendedExpression(
referred_expr=[
stee.ExpressionReference(
expression=stalg.Expression(
literal=literal
),
output_names=[alias if alias else f'literal_{kind}'],
)
],
base_schema=base_schema,
)

return resolve

def column(field: Union[str, int]):
    """Builds a resolver for ExtendedExpression containing a FieldReference expression

    Accepts either an index or a field name of a desired field.

    Args:
        field: Name of a top-level field of the base schema, or the position of
            a top-level field within ``base_schema.struct.types``.

    Returns:
        A callable taking ``(base_schema, registry)`` and returning an
        ExtendedExpression with a single root struct-field reference. Its
        ``output_names`` are the slice of ``base_schema.names`` owned by the
        selected field (more than one name when the field is a nested struct).

    Raises:
        ValueError: If ``field`` is a string that is not present in
            ``base_schema.names`` or does not start a top-level field.
        IndexError: If ``field`` is an integer beyond the top-level fields.
    """
    def resolve(base_schema: stp.NamedStruct, registry: FunctionRegistry) -> stee.ExtendedExpression:
        # ``registry`` is unused here; it is accepted so that every builder
        # shares the UnboundExpression call signature.

        # base_schema.names is a flat list in which a top-level field may own
        # several consecutive entries. flat_indices[i] is the offset of the
        # first name belonging to top-level field i. This must be computed for
        # both branches below: it was previously only defined when ``field``
        # was a string, which made integer lookups fail with UnboundLocalError.
        lengths = [type_num_names(t) for t in base_schema.struct.types]
        flat_indices = [0] + list(itertools.accumulate(lengths))[:-1]

        if isinstance(field, str):
            column_index = list(base_schema.names).index(field)
            # Only names that start a top-level field can be resolved.
            field_index = flat_indices.index(column_index)
        else:
            field_index = field

        names_start = flat_indices[field_index]
        # The last field has no successor, so its slice runs to the end.
        names_end = (
            flat_indices[field_index + 1]
            if len(flat_indices) > field_index + 1
            else None
        )

        return stee.ExtendedExpression(
            referred_expr=[
                stee.ExpressionReference(
                    expression=stalg.Expression(
                        selection=stalg.Expression.FieldReference(
                            root_reference=stalg.Expression.FieldReference.RootReference(),
                            direct_reference=stalg.Expression.ReferenceSegment(
                                struct_field=stalg.Expression.ReferenceSegment.StructField(
                                    field=field_index
                                )
                            ),
                        )
                    ),
                    output_names=list(base_schema.names)[names_start:names_end],
                )
            ],
            base_schema=base_schema,
        )

    return resolve

def scalar_function(uri: str, function: str, *expressions: UnboundExpression, alias: str = None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this accept regular expressions and extended expressions as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there are 3 possible types to be used here: Expression, ExtendedExpression and UnboundExpression. My thinking about pros/cons of each option is:

  • The downside of using a regular proto Expression type is that Expressions can't carry additional information regarding resources that are used in the expression tree, meaning that if you use an extension function or a cte (once this substrait PR is merged) that context is thrown away with a regular Expression type.

    The same applies to the more Rel-level builders as well. For example, a possible join builder could have one of the following signatures: join(left: Rel, right: Rel, join_type: str, condition: Expression) or join(left: Plan, right: Plan, join_type: str, condition: ExtendedExpression). I think we should prefer the second one because it allows each object to independently hold the full context of the extensions/CTEs used while building it, and allows them to be merged by the builder function later on.

  • ExtendedExpression is perfectly okay to use here with the only downside being that in order to build an ExtendedExpression the user needs to be aware of precisely in what context the expression is supposed to be used, in other words once ExtendedExpression is built, you've already tied it to a specific Rel schema where it can be used.

  • UnboundExpression is the currying equivalent of ExtendedExpression which allows the user to defer resolving precise Rel schema until later when it's actually used in the context of a specific Rel.

I guess we can probably introduce a type like Union[Expression, ExtendedExpression, UnboundExpression] and use it in similar builder signatures, if we want to cover as many use cases as possible. We will have to then do some light type matching on the builder side.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use case I'd like to enable is modifying an existing plan. I suppose you could convert an existing plan into a series of extended expressions. We don't need to do this right away though. A set of builders that work from scratch is a reasonable place to start and we can look at the other use cases later.

"""Builds a resolver for ExtendedExpression containing a ScalarFunction expression"""
def resolve(base_schema: stp.NamedStruct, registry: FunctionRegistry) -> stee.ExtendedExpression:
    """Bind the argument builders to ``base_schema`` and emit the function call.

    Each unbound argument is resolved against the same schema and registry,
    the argument types are inferred to form the lookup signature, and the
    resulting ExtendedExpression carries the extension URIs and declarations
    of the function itself merged with those of every argument.

    Raises:
        Exception: If the registry has no matching implementation of the
            function for the inferred argument types.
    """
    bound_expressions: list[stee.ExtendedExpression] = [e(base_schema, registry) for e in expressions]

    expression_schemas = [infer_extended_expression_schema(b) for b in bound_expressions]

    # Flatten the per-argument schemas into one positional type signature.
    signature = [typ for es in expression_schemas for typ in es.types]

    func = registry.lookup_function(uri, function, signature)

    if not func:
        # Previously raised with an empty message, which made a failed lookup
        # impossible to diagnose from the traceback alone.
        raise Exception(
            f"Function '{function}' from '{uri}' has no implementation matching the given argument types"
        )

    # func[0] is the registry entry, func[1] the resolved return type.
    function_anchor = func[0].anchor
    output_type = func[1]

    # The same anchor is used by the URI entry and by the declaration that
    # refers to it, so look it up once.
    uri_anchor = registry.lookup_uri(uri)

    func_extension_uris = [
        ste.SimpleExtensionURI(
            extension_uri_anchor=uri_anchor,
            uri=uri
        )
    ]

    func_extensions = [
        ste.SimpleExtensionDeclaration(
            extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
                extension_uri_reference=uri_anchor,
                function_anchor=function_anchor,
                name=function
            )
        )
    ]

    extension_uris = merge_extension_uris(
        func_extension_uris,
        *[b.extension_uris for b in bound_expressions]
    )

    extensions = merge_extension_declarations(
        func_extensions,
        *[b.extensions for b in bound_expressions]
    )

    return stee.ExtendedExpression(
        referred_expr=[
            stee.ExpressionReference(
                expression=stalg.Expression(
                    scalar_function=stalg.Expression.ScalarFunction(
                        function_reference=function_anchor,
                        arguments=[
                            # Only the first referred expression of each
                            # bound argument is used as the argument value.
                            stalg.FunctionArgument(
                                value=e.referred_expr[0].expression
                            ) for e in bound_expressions
                        ],
                        output_type=output_type
                    )
                ),
                output_names=[alias if alias else 'scalar_function'],
            )
        ],
        base_schema=base_schema,
        extension_uris=extension_uris,
        extensions=extensions
    )

return resolve
10 changes: 9 additions & 1 deletion src/substrait/function_registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from substrait.gen.proto.parameterized_types_pb2 import ParameterizedType
from substrait.gen.proto.type_pb2 import Type
from importlib.resources import files as importlib_files
import itertools
Expand Down Expand Up @@ -227,6 +226,9 @@ def satisfies_signature(self, signature: tuple) -> Optional[str]:

class FunctionRegistry:
def __init__(self, load_default_extensions=True) -> None:
self._uri_mapping: dict = defaultdict(dict)
self._uri_id_generator = itertools.count(1)

self._function_mapping: dict = defaultdict(dict)
self._id_generator = itertools.count(1)

Expand All @@ -252,6 +254,8 @@ def register_extension_yaml(
self.register_extension_dict(extension_definitions, uri)

def register_extension_dict(self, definitions: dict, uri: str) -> None:
self._uri_mapping[uri] = next(self._uri_id_generator)

for named_functions in definitions.values():
for function in named_functions:
for impl in function.get("impls", []):
Expand Down Expand Up @@ -285,3 +289,7 @@ def lookup_function(
return (f, rtn)

return None

def lookup_uri(self, uri: str) -> Optional[int]:
    """Return the anchor stored for ``uri``, or ``None`` when there is none.

    The URI is first translated through ``self._uri_aliases`` (falling back to
    the URI itself when it has no alias entry) and the result is then looked
    up in ``self._uri_mapping``.
    """
    # NOTE(review): _uri_aliases is populated outside the visible diff;
    # presumably it maps alternate URIs to the registered one - confirm.
    canonical_uri = self._uri_aliases.get(uri, uri)
    # .get() rather than indexing, so a miss neither raises nor inserts.
    return self._uri_mapping.get(canonical_uri)
12 changes: 12 additions & 0 deletions src/substrait/type_inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import substrait.gen.proto.algebra_pb2 as stalg
import substrait.gen.proto.extended_expression_pb2 as stee
import substrait.gen.proto.type_pb2 as stt


Expand Down Expand Up @@ -220,6 +221,17 @@ def infer_expression_type(
raise Exception(f"Unknown rex_type {rex_type}")


def infer_extended_expression_schema(ee: stee.ExtendedExpression) -> stt.Type.Struct:
    """Infer the struct made of the types of all expressions referred to by ``ee``.

    Every entry of ``ee.referred_expr`` is typed against the struct of
    ``ee.base_schema``; the results are returned, in order, as the members of
    a struct whose own nullability is ``NULLABILITY_REQUIRED``.
    """
    inferred_types = [
        infer_expression_type(reference.expression, ee.base_schema.struct)
        for reference in ee.referred_expr
    ]

    return stt.Type.Struct(
        types=inferred_types,
        nullability=stt.Type.NULLABILITY_REQUIRED,
    )


def infer_rel_schema(rel: stalg.Rel) -> stt.Type.Struct:
rel_type = rel.WhichOneof("rel_type")

Expand Down
53 changes: 53 additions & 0 deletions src/substrait/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import substrait.gen.proto.type_pb2 as stp
import substrait.gen.proto.extensions.extensions_pb2 as ste
from typing import Iterable

def type_num_names(typ: stp.Type):
    """Return how many name entries ``typ`` accounts for.

    A struct counts one for itself plus the counts of all its members, a list
    counts whatever its element type counts, a map counts its key type plus
    its value type, and every other kind counts one.
    """
    kind = typ.WhichOneof("kind")

    if kind == "struct":
        return 1 + sum(type_num_names(member) for member in typ.struct.types)

    if kind == "list":
        return type_num_names(typ.list.type)

    if kind == "map":
        # NOTE(review): key and value are counted separately, so a map of two
        # primitives yields 2 - confirm this matches the NamedStruct naming rules.
        return type_num_names(typ.map.key) + type_num_names(typ.map.value)

    return 1

def merge_extension_uris(*extension_uris: Iterable[ste.SimpleExtensionURI]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A comment explaining what merge does (maintains order but removes duplicates) would be useful here.

"""Merges multiple sets of SimpleExtensionURI objects into a single set.
The order of extensions is kept intact, while duplicates are discarded.
Assumes that there are no collisions (different extensions having identical anchors).
"""
seen_uris = set()
ret = []

for uris in extension_uris:
for uri in uris:
if uri.uri not in seen_uris:
seen_uris.add(uri.uri)
ret.append(uri)

return ret

def merge_extension_declarations(*extension_declarations: Iterable[ste.SimpleExtensionDeclaration]):
    """Merges multiple sets of SimpleExtensionDeclaration objects into a single set.
    The order of extension declarations is kept intact, while duplicates are discarded.
    Assumes that there are no collisions (different extension declarations having identical anchors).

    Two function declarations are considered duplicates when they share the
    same ``(extension_uri_reference, name)`` pair; the first one seen wins.

    Returns:
        A list of the distinct declarations in first-seen order.

    Raises:
        NotImplementedError: If a declaration is anything other than an
            ``extension_function`` mapping.
    """
    seen_extension_functions = set()
    ret = []

    for declarations in extension_declarations:
        for declaration in declarations:
            mapping_type = declaration.WhichOneof('mapping_type')

            if mapping_type != 'extension_function':
                # TODO handle extension types
                # Previously raised a bare Exception with an empty message;
                # name the offending mapping kind so the failure is actionable.
                raise NotImplementedError(
                    f"Merging extension declarations of kind '{mapping_type}' is not supported yet"
                )

            ident = (declaration.extension_function.extension_uri_reference, declaration.extension_function.name)
            if ident not in seen_extension_functions:
                seen_extension_functions.add(ident)
                ret.append(declaration)

    return ret

105 changes: 105 additions & 0 deletions tests/extended_expression/test_column.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import substrait.gen.proto.algebra_pb2 as stalg
import substrait.gen.proto.type_pb2 as stt
import substrait.gen.proto.extended_expression_pb2 as stee
from substrait.extended_expression import column


# Shorthand for the two nullability settings used by the fixtures below.
_REQUIRED = stt.Type.NULLABILITY_REQUIRED
_NULLABLE = stt.Type.NULLABILITY_NULLABLE

# Flat schema: three primitive columns, so there is exactly one name per type.
struct = stt.Type.Struct(
    types=[
        stt.Type(i64=stt.Type.I64(nullability=_REQUIRED)),
        stt.Type(string=stt.Type.String(nullability=_NULLABLE)),
        stt.Type(fp32=stt.Type.FP32(nullability=_NULLABLE)),
    ]
)

named_struct = stt.NamedStruct(
    names=["order_id", "description", "order_total"],
    struct=struct,
)

# Nested schema: the middle column is itself a struct with two members, so it
# owns three consecutive names ("shop_details", "shop_id", "shop_total").
_shop_details = stt.Type(
    struct=stt.Type.Struct(
        types=[
            stt.Type(i64=stt.Type.I64(nullability=_REQUIRED)),
            stt.Type(fp32=stt.Type.FP32(nullability=_NULLABLE)),
        ],
        nullability=_NULLABLE,
    )
)

nested_struct = stt.Type.Struct(
    types=[
        stt.Type(i64=stt.Type.I64(nullability=_REQUIRED)),
        _shop_details,
        stt.Type(fp32=stt.Type.FP32(nullability=_NULLABLE)),
    ]
)

nested_named_struct = stt.NamedStruct(
    names=["order_id", "shop_details", "shop_id", "shop_total", "order_total"],
    struct=nested_struct,
)


def test_column_no_nesting():
    """A name in a flat schema resolves to its own position and its own name."""
    actual = column("description")(named_struct, None)

    # "description" is the second of three primitive columns -> field 1.
    reference = stalg.Expression.FieldReference(
        root_reference=stalg.Expression.FieldReference.RootReference(),
        direct_reference=stalg.Expression.ReferenceSegment(
            struct_field=stalg.Expression.ReferenceSegment.StructField(field=1)
        ),
    )
    expected = stee.ExtendedExpression(
        referred_expr=[
            stee.ExpressionReference(
                expression=stalg.Expression(selection=reference),
                output_names=["description"],
            )
        ],
        base_schema=named_struct,
    )

    assert actual == expected


def test_column_nesting():
    """A column placed after a nested struct resolves to its top-level position."""
    actual = column("order_total")(nested_named_struct, None)

    # "order_total" is the fifth flattened name but the third top-level
    # field, so the reference must point at field 2.
    reference = stalg.Expression.FieldReference(
        root_reference=stalg.Expression.FieldReference.RootReference(),
        direct_reference=stalg.Expression.ReferenceSegment(
            struct_field=stalg.Expression.ReferenceSegment.StructField(field=2)
        ),
    )
    expected = stee.ExtendedExpression(
        referred_expr=[
            stee.ExpressionReference(
                expression=stalg.Expression(selection=reference),
                output_names=["order_total"],
            )
        ],
        base_schema=nested_named_struct,
    )

    assert actual == expected


def test_column_nested_struct():
    """Selecting a struct column yields its own name followed by its members' names."""
    actual = column("shop_details")(nested_named_struct, None)

    # "shop_details" is the second top-level field -> field 1; its output
    # names span the struct itself plus both of its members.
    reference = stalg.Expression.FieldReference(
        root_reference=stalg.Expression.FieldReference.RootReference(),
        direct_reference=stalg.Expression.ReferenceSegment(
            struct_field=stalg.Expression.ReferenceSegment.StructField(field=1)
        ),
    )
    expected = stee.ExtendedExpression(
        referred_expr=[
            stee.ExpressionReference(
                expression=stalg.Expression(selection=reference),
                output_names=["shop_details", "shop_id", "shop_total"],
            )
        ],
        base_schema=nested_named_struct,
    )

    assert actual == expected
Loading