Skip to content

Commit 4aae404

Browse files
authored
feat: add builders for aggregate, window funcs and if-then (#72)
- moves builders to a new directory - renames FunctionRegistry to ExtensionRegistry - adds new extended expression builders
1 parent 1e44dd8 commit 4aae404

9 files changed

+601
-166
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,384 @@
1+
import itertools
2+
import substrait.gen.proto.algebra_pb2 as stalg
3+
import substrait.gen.proto.type_pb2 as stp
4+
import substrait.gen.proto.extended_expression_pb2 as stee
5+
import substrait.gen.proto.extensions.extensions_pb2 as ste
6+
from substrait.extension_registry import ExtensionRegistry
7+
from substrait.utils import type_num_names, merge_extension_uris, merge_extension_declarations
8+
from substrait.type_inference import infer_extended_expression_schema
9+
from typing import Callable, Any, Union, Iterable
10+
11+
# Type of the resolvers handled by the builders in this module: a callable that
# receives the base schema (NamedStruct) and an ExtensionRegistry and returns
# the resolved ExtendedExpression.
UnboundExtendedExpression = Callable[[stp.NamedStruct, ExtensionRegistry], stee.ExtendedExpression]
12+
13+
def _alias_or_inferred(
14+
alias: Union[Iterable[str], str],
15+
op: str,
16+
args: Iterable[str],
17+
):
18+
if alias:
19+
return [alias] if isinstance(alias, str) else alias
20+
else:
21+
return [f'{op}({",".join(args)})']
22+
23+
def literal(value: Any, type: stp.Type, alias: Union[Iterable[str], str] = None) -> UnboundExtendedExpression:
    """Create a resolver producing an ExtendedExpression that holds one literal.

    Supported type kinds are bool, i8, i16, i32, i64, fp32, fp64 and string;
    resolving any other kind raises ``Exception``. The literal is marked
    nullable when the given type's nullability is NULLABILITY_NULLABLE.
    Without an alias the output name is ``Literal(<str(value)>)``.
    """
    # Maps the set oneof member of `type` to the Literal field that stores the value.
    literal_field_by_kind = {
        "bool": "boolean",
        "i8": "i8",
        "i16": "i16",
        "i32": "i32",
        "i64": "i64",
        "fp32": "fp32",
        "fp64": "fp64",
        "string": "string",
    }

    def resolve(base_schema: stp.NamedStruct, registry: ExtensionRegistry) -> stee.ExtendedExpression:
        # `registry` is not used here; it is accepted so that all resolvers
        # share the same call signature.
        kind = type.WhichOneof('kind')

        if kind not in literal_field_by_kind:
            raise Exception(f"Unknown literal type - {type}")

        is_nullable = getattr(type, kind).nullability == stp.Type.NULLABILITY_NULLABLE
        literal_message = stalg.Expression.Literal(
            **{literal_field_by_kind[kind]: value, "nullable": is_nullable}
        )

        reference = stee.ExpressionReference(
            expression=stalg.Expression(literal=literal_message),
            output_names=_alias_or_inferred(alias, 'Literal', [str(value)]),
        )

        return stee.ExtendedExpression(
            referred_expr=[reference],
            base_schema=base_schema,
        )

    return resolve
60+
61+
def column(field: Union[str, int], alias: Union[Iterable[str], str] = None):
    """Create a resolver producing an ExtendedExpression with a FieldReference.

    ``field`` is either the index of the wanted top-level field or its name.
    A name is looked up in ``base_schema.names`` and must sit at the first
    name position of a top-level field, otherwise ``list.index`` raises
    ``ValueError``.

    Without an alias the output names are the slice of ``base_schema.names``
    that belongs to the selected field.
    """
    if alias and isinstance(alias, str):
        alias = [alias]

    def resolve(
        base_schema: stp.NamedStruct, registry: ExtensionRegistry
    ) -> stee.ExtendedExpression:
        # presumably type_num_names(t) is the number of entries of
        # base_schema.names taken up by type t — confirm against substrait.utils
        name_counts = [type_num_names(t) for t in base_schema.struct.types]

        # Position in base_schema.names of the first name of each top-level
        # field: a running sum of the counts of all preceding fields.
        first_name_offsets = list(
            itertools.accumulate(name_counts[:-1], initial=0)
        )

        if isinstance(field, str):
            name_position = list(base_schema.names).index(field)
            field_index = first_name_offsets.index(name_position)
        else:
            field_index = field

        names_start = first_name_offsets[field_index]
        following_index = field_index + 1
        if len(first_name_offsets) > following_index:
            names_end = first_name_offsets[following_index]
        else:
            # Last field: its names run to the end of the list.
            names_end = None

        field_reference = stalg.Expression.FieldReference(
            root_reference=stalg.Expression.FieldReference.RootReference(),
            direct_reference=stalg.Expression.ReferenceSegment(
                struct_field=stalg.Expression.ReferenceSegment.StructField(
                    field=field_index
                )
            ),
        )

        if alias:
            output_names = alias
        else:
            output_names = list(base_schema.names)[names_start:names_end]

        return stee.ExtendedExpression(
            referred_expr=[
                stee.ExpressionReference(
                    expression=stalg.Expression(selection=field_reference),
                    output_names=output_names,
                )
            ],
            base_schema=base_schema,
        )

    return resolve
109+
110+
def scalar_function(
    uri: str, function: str, *expressions: UnboundExtendedExpression, alias: Union[Iterable[str], str] = None
):
    """Create a resolver producing an ExtendedExpression with a ScalarFunction call.

    The argument resolvers in ``expressions`` are bound first; their inferred
    types form the signature used to look the function up in the registry.
    Resolution raises ``Exception`` when the registry has no match. Without
    an alias the output name is ``function(<first output name of each argument>)``.
    """
    def resolve(
        base_schema: stp.NamedStruct, registry: ExtensionRegistry
    ) -> stee.ExtendedExpression:
        bound_args = [unbound(base_schema, registry) for unbound in expressions]

        # Flatten the inferred types of all arguments into one signature.
        signature = []
        for bound in bound_args:
            signature.extend(infer_extended_expression_schema(bound).types)

        # The lookup result is used as a pair: item 0 carries the function
        # anchor, item 1 is passed on as the output type.
        func = registry.lookup_function(uri, function, signature)
        if not func:
            raise Exception(f"Unknown function {function} for {signature}")

        uri_declaration = ste.SimpleExtensionURI(
            extension_uri_anchor=registry.lookup_uri(uri), uri=uri
        )
        function_declaration = ste.SimpleExtensionDeclaration(
            extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
                extension_uri_reference=registry.lookup_uri(uri),
                function_anchor=func[0].anchor,
                name=function,
            )
        )

        # Combine this function's declarations with those of its arguments.
        extension_uris = merge_extension_uris(
            [uri_declaration], *(bound.extension_uris for bound in bound_args)
        )
        extensions = merge_extension_declarations(
            [function_declaration], *(bound.extensions for bound in bound_args)
        )

        call = stalg.Expression.ScalarFunction(
            function_reference=func[0].anchor,
            arguments=[
                stalg.FunctionArgument(value=bound.referred_expr[0].expression)
                for bound in bound_args
            ],
            output_type=func[1],
        )
        argument_names = [
            bound.referred_expr[0].output_names[0] for bound in bound_args
        ]
        reference = stee.ExpressionReference(
            expression=stalg.Expression(scalar_function=call),
            output_names=_alias_or_inferred(alias, function, argument_names),
        )

        return stee.ExtendedExpression(
            referred_expr=[reference],
            base_schema=base_schema,
            extension_uris=extension_uris,
            extensions=extensions,
        )

    return resolve
180+
181+
def aggregate_function(
    uri: str, function: str, *expressions: UnboundExtendedExpression, alias: Union[Iterable[str], str] = None
):
    """Builds a resolver for ExtendedExpression containing an AggregateFunction measure

    The argument resolvers in ``expressions`` are bound first; their inferred
    types form the signature used to look the function up in the registry.
    Resolution raises ``Exception`` when the registry has no match. Without
    an alias the output name is ``function(<first output name of each argument>)``.
    """
    def resolve(
        base_schema: stp.NamedStruct, registry: ExtensionRegistry
    ) -> stee.ExtendedExpression:
        bound_expressions: Iterable[stee.ExtendedExpression] = [
            e(base_schema, registry) for e in expressions
        ]

        expression_schemas = [
            infer_extended_expression_schema(b) for b in bound_expressions
        ]

        # Flattened argument types, used as the lookup signature.
        signature = [typ for es in expression_schemas for typ in es.types]

        # The lookup result is used as a pair: item 0 carries the function
        # anchor, item 1 is passed on as the output type.
        func = registry.lookup_function(uri, function, signature)

        if not func:
            raise Exception(f"Unknown function {function} for {signature}")

        func_extension_uris = [
            ste.SimpleExtensionURI(
                extension_uri_anchor=registry.lookup_uri(uri), uri=uri
            )
        ]

        func_extensions = [
            ste.SimpleExtensionDeclaration(
                extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
                    extension_uri_reference=registry.lookup_uri(uri),
                    function_anchor=func[0].anchor,
                    name=function,
                )
            )
        ]

        # Combine this function's declarations with those of its arguments.
        extension_uris = merge_extension_uris(
            func_extension_uris, *[b.extension_uris for b in bound_expressions]
        )

        extensions = merge_extension_declarations(
            func_extensions, *[b.extensions for b in bound_expressions]
        )

        argument_names = [
            e.referred_expr[0].output_names[0] for e in bound_expressions
        ]

        return stee.ExtendedExpression(
            referred_expr=[
                stee.ExpressionReference(
                    # An aggregate goes into `measure`, not `expression`.
                    measure=stalg.AggregateFunction(
                        function_reference=func[0].anchor,
                        arguments=[
                            stalg.FunctionArgument(value=e.referred_expr[0].expression)
                            for e in bound_expressions
                        ],
                        output_type=func[1],
                    ),
                    # Name the output after the aggregate function itself
                    # (previously the hard-coded op name 'IfThen' was used).
                    output_names=_alias_or_inferred(alias, function, argument_names),
                )
            ],
            base_schema=base_schema,
            extension_uris=extension_uris,
            extensions=extensions,
        )

    return resolve
247+
248+
249+
# TODO: add support for window bounds and sorts
250+
def window_function(
    uri: str,
    function: str,
    *expressions: UnboundExtendedExpression,
    partitions: Iterable[UnboundExtendedExpression] = (),
    alias: Union[Iterable[str], str] = None
):
    """Builds a resolver for ExtendedExpression containing a WindowFunction expression

    The argument resolvers in ``expressions`` are bound first; their inferred
    types form the signature used to look the function up in the registry.
    Resolution raises ``Exception`` when the registry has no match.
    ``partitions`` are resolvers for the partition expressions; they do not
    take part in the function lookup. Without an alias the output name is
    ``function(<first output name of each argument>)``.
    """
    def resolve(
        base_schema: stp.NamedStruct, registry: ExtensionRegistry
    ) -> stee.ExtendedExpression:
        bound_expressions: Iterable[stee.ExtendedExpression] = [
            e(base_schema, registry) for e in expressions
        ]

        bound_partitions = [e(base_schema, registry) for e in partitions]

        expression_schemas = [
            infer_extended_expression_schema(b) for b in bound_expressions
        ]

        # Flattened argument types, used as the lookup signature.
        signature = [typ for es in expression_schemas for typ in es.types]

        # The lookup result is used as a pair: item 0 carries the function
        # anchor, item 1 is passed on as the output type.
        func = registry.lookup_function(uri, function, signature)

        if not func:
            raise Exception(f"Unknown function {function} for {signature}")

        func_extension_uris = [
            ste.SimpleExtensionURI(
                extension_uri_anchor=registry.lookup_uri(uri), uri=uri
            )
        ]

        func_extensions = [
            ste.SimpleExtensionDeclaration(
                extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction(
                    extension_uri_reference=registry.lookup_uri(uri),
                    function_anchor=func[0].anchor,
                    name=function,
                )
            )
        ]

        # Combine this function's declarations with those of its arguments
        # and of the partition expressions.
        extension_uris = merge_extension_uris(
            func_extension_uris,
            *[b.extension_uris for b in bound_expressions],
            *[b.extension_uris for b in bound_partitions],
        )

        extensions = merge_extension_declarations(
            func_extensions,
            *[b.extensions for b in bound_expressions],
            *[b.extensions for b in bound_partitions],
        )

        argument_names = [
            e.referred_expr[0].output_names[0] for e in bound_expressions
        ]

        return stee.ExtendedExpression(
            referred_expr=[
                stee.ExpressionReference(
                    expression=stalg.Expression(
                        window_function=stalg.Expression.WindowFunction(
                            function_reference=func[0].anchor,
                            arguments=[
                                stalg.FunctionArgument(
                                    value=e.referred_expr[0].expression
                                )
                                for e in bound_expressions
                            ],
                            output_type=func[1],
                            partitions=[
                                e.referred_expr[0].expression for e in bound_partitions
                            ],
                        )
                    ),
                    output_names=_alias_or_inferred(alias, function, argument_names),
                )
            ],
            base_schema=base_schema,
            extension_uris=extension_uris,
            extensions=extensions,
        )

    return resolve
333+
334+
335+
def if_then(ifs: Iterable[tuple[UnboundExtendedExpression, UnboundExtendedExpression]], _else: UnboundExtendedExpression, alias: Union[Iterable[str], str] = None):
    """Create a resolver producing an ExtendedExpression with an IfThen expression.

    ``ifs`` holds (condition, result) pairs of resolvers and ``_else`` is the
    resolver for the fallback value. Without an alias the output name is
    ``IfThen(...)`` listing the first output name of every condition and
    result in order, followed by that of the else expression.
    """
    def resolve(
        base_schema: stp.NamedStruct, registry: ExtensionRegistry
    ) -> stee.ExtendedExpression:
        bound_clauses = [
            (clause[0](base_schema, registry), clause[1](base_schema, registry))
            for clause in ifs
        ]
        bound_else = _else(base_schema, registry)

        bound_conditions = [condition for condition, _ in bound_clauses]
        bound_results = [result for _, result in bound_clauses]

        # Collect the declarations of all conditions, then all results, then
        # the else expression.
        extension_uris = merge_extension_uris(
            *(condition.extension_uris for condition in bound_conditions),
            *(result.extension_uris for result in bound_results),
            bound_else.extension_uris,
        )
        extensions = merge_extension_declarations(
            *(condition.extensions for condition in bound_conditions),
            *(result.extensions for result in bound_results),
            bound_else.extensions,
        )

        # `if` and `else` are Python keywords, so these protobuf fields can
        # only be passed through dict unpacking.
        clause_messages = [
            stalg.Expression.IfThen.IfClause(**{
                'if': condition.referred_expr[0].expression,
                'then': result.referred_expr[0].expression,
            })
            for condition, result in bound_clauses
        ]
        if_then_message = stalg.Expression.IfThen(**{
            'ifs': clause_messages,
            'else': bound_else.referred_expr[0].expression,
        })

        operand_names = []
        for condition, result in bound_clauses:
            operand_names.append(condition.referred_expr[0].output_names[0])
            operand_names.append(result.referred_expr[0].output_names[0])
        operand_names.append(bound_else.referred_expr[0].output_names[0])

        reference = stee.ExpressionReference(
            expression=stalg.Expression(if_then=if_then_message),
            output_names=_alias_or_inferred(alias, 'IfThen', operand_names),
        )

        return stee.ExtendedExpression(
            referred_expr=[reference],
            base_schema=base_schema,
            extension_uris=extension_uris,
            extensions=extensions,
        )

    return resolve

0 commit comments

Comments
 (0)