10
10
from typing import Callable , Any , Union , Iterable
11
11
12
12
UnboundExtendedExpression = Callable [[stp .NamedStruct , ExtensionRegistry ], stee .ExtendedExpression ]
13
+ ExtendedExpressionOrUnbound = Union [stee .ExtendedExpression , UnboundExtendedExpression ]
13
14
14
15
def _alias_or_inferred (
15
16
alias : Union [Iterable [str ], str ],
@@ -21,6 +22,13 @@ def _alias_or_inferred(
21
22
else :
22
23
return [f'{ op } ({ "," .join (args )} )' ]
23
24
25
+ def resolve_expression (
26
+ expression : ExtendedExpressionOrUnbound ,
27
+ base_schema : stp .NamedStruct ,
28
+ registry : ExtensionRegistry
29
+ ) -> stee .ExtendedExpression :
30
+ return expression if isinstance (expression , stee .ExtendedExpression ) else expression (base_schema , registry )
31
+
24
32
def literal (value : Any , type : stp .Type , alias : Union [Iterable [str ], str ] = None ) -> UnboundExtendedExpression :
25
33
"""Builds a resolver for ExtendedExpression containing a literal expression"""
26
34
def resolve (base_schema : stp .NamedStruct , registry : ExtensionRegistry ) -> stee .ExtendedExpression :
@@ -139,14 +147,14 @@ def resolve(
139
147
return resolve
140
148
141
149
def scalar_function (
142
- uri : str , function : str , * expressions : UnboundExtendedExpression , alias : Union [Iterable [str ], str ] = None
150
+ uri : str , function : str , expressions : Iterable [ ExtendedExpressionOrUnbound ] , alias : Union [Iterable [str ], str ] = None
143
151
):
144
152
"""Builds a resolver for ExtendedExpression containing a ScalarFunction expression"""
145
153
def resolve (
146
154
base_schema : stp .NamedStruct , registry : ExtensionRegistry
147
155
) -> stee .ExtendedExpression :
148
- bound_expressions : Iterable [ stee . ExtendedExpression ] = [
149
- e ( base_schema , registry ) for e in expressions
156
+ bound_expressions = [
157
+ resolve_expression ( e , base_schema , registry ) for e in expressions
150
158
]
151
159
152
160
expression_schemas = [
@@ -210,14 +218,14 @@ def resolve(
210
218
return resolve
211
219
212
220
def aggregate_function (
213
- uri : str , function : str , * expressions : UnboundExtendedExpression , alias : Union [Iterable [str ], str ] = None
221
+ uri : str , function : str , expressions : Iterable [ ExtendedExpressionOrUnbound ] , alias : Union [Iterable [str ], str ] = None
214
222
):
215
223
"""Builds a resolver for ExtendedExpression containing a AggregateFunction measure"""
216
224
def resolve (
217
225
base_schema : stp .NamedStruct , registry : ExtensionRegistry
218
226
) -> stee .ExtendedExpression :
219
227
bound_expressions : Iterable [stee .ExtendedExpression ] = [
220
- e ( base_schema , registry ) for e in expressions
228
+ resolve_expression ( e , base_schema , registry ) for e in expressions
221
229
]
222
230
223
231
expression_schemas = [
@@ -281,19 +289,19 @@ def resolve(
281
289
def window_function (
282
290
uri : str ,
283
291
function : str ,
284
- * expressions : UnboundExtendedExpression ,
285
- partitions : Iterable [UnboundExtendedExpression ] = [],
292
+ expressions : Iterable [ ExtendedExpressionOrUnbound ] ,
293
+ partitions : Iterable [ExtendedExpressionOrUnbound ] = [],
286
294
alias : Union [Iterable [str ], str ] = None
287
295
):
288
296
"""Builds a resolver for ExtendedExpression containing a WindowFunction expression"""
289
297
def resolve (
290
298
base_schema : stp .NamedStruct , registry : ExtensionRegistry
291
299
) -> stee .ExtendedExpression :
292
300
bound_expressions : Iterable [stee .ExtendedExpression ] = [
293
- e ( base_schema , registry ) for e in expressions
301
+ resolve_expression ( e , base_schema , registry ) for e in expressions
294
302
]
295
303
296
- bound_partitions = [e ( base_schema , registry ) for e in partitions ]
304
+ bound_partitions = [resolve_expression ( e , base_schema , registry ) for e in partitions ]
297
305
298
306
expression_schemas = [
299
307
infer_extended_expression_schema (b ) for b in bound_expressions
@@ -363,17 +371,17 @@ def resolve(
363
371
return resolve
364
372
365
373
366
- def if_then (ifs : Iterable [tuple [UnboundExtendedExpression , UnboundExtendedExpression ]], _else : UnboundExtendedExpression , alias : Union [Iterable [str ], str ] = None ):
374
+ def if_then (ifs : Iterable [tuple [ExtendedExpressionOrUnbound , ExtendedExpressionOrUnbound ]], _else : ExtendedExpressionOrUnbound , alias : Union [Iterable [str ], str ] = None ):
367
375
"""Builds a resolver for ExtendedExpression containing an IfThen expression"""
368
376
def resolve (
369
377
base_schema : stp .NamedStruct , registry : ExtensionRegistry
370
378
) -> stee .ExtendedExpression :
371
379
bound_ifs = [
372
- (if_clause [0 ]( base_schema , registry ), if_clause [1 ]( base_schema , registry ))
380
+ (resolve_expression ( if_clause [0 ], base_schema , registry ), resolve_expression ( if_clause [1 ], base_schema , registry ))
373
381
for if_clause in ifs
374
382
]
375
383
376
- bound_else = _else ( base_schema , registry )
384
+ bound_else = resolve_expression ( _else , base_schema , registry )
377
385
378
386
extension_uris = merge_extension_uris (
379
387
* [b [0 ].extension_uris for b in bound_ifs ],
@@ -413,3 +421,169 @@ def resolve(
413
421
)
414
422
415
423
return resolve
424
+
425
+ def switch (match : ExtendedExpressionOrUnbound ,
426
+ ifs : Iterable [tuple [ExtendedExpressionOrUnbound , ExtendedExpressionOrUnbound ]],
427
+ _else : ExtendedExpressionOrUnbound ):
428
+ """Builds a resolver for ExtendedExpression containing a switch expression"""
429
+ def resolve (
430
+ base_schema : stp .NamedStruct , registry : ExtensionRegistry
431
+ ) -> stee .ExtendedExpression :
432
+ bound_match = resolve_expression (match , base_schema , registry )
433
+ bound_ifs = [
434
+ (
435
+ resolve_expression (a , base_schema , registry ),
436
+ resolve_expression (b , base_schema , registry )
437
+ ) for a , b in ifs ]
438
+ bound_else = resolve_expression (_else , base_schema , registry )
439
+
440
+ extension_uris = merge_extension_uris (
441
+ bound_match .extension_uris ,
442
+ * [b .extension_uris for _ , b in bound_ifs ],
443
+ bound_else .extension_uris
444
+ )
445
+
446
+ extensions = merge_extension_declarations (
447
+ bound_match .extensions ,
448
+ * [b .extensions for _ , b in bound_ifs ],
449
+ bound_else .extensions
450
+ )
451
+
452
+ return stee .ExtendedExpression (
453
+ referred_expr = [
454
+ stee .ExpressionReference (
455
+ expression = stalg .Expression (
456
+ switch_expression = stalg .Expression .SwitchExpression (
457
+ match = bound_match .referred_expr [0 ].expression ,
458
+ ifs = [
459
+ stalg .Expression .SwitchExpression .IfValue (** {
460
+ 'if' : i .referred_expr [0 ].expression .literal ,
461
+ 'then' : t .referred_expr [0 ].expression
462
+ })
463
+ for i , t in bound_ifs
464
+ ],
465
+ ** {
466
+ 'else' : bound_else .referred_expr [0 ].expression
467
+ }
468
+ )
469
+ ),
470
+ output_names = ['switch' ] #TODO construct name from inputs
471
+ )
472
+ ],
473
+ base_schema = base_schema ,
474
+ extension_uris = extension_uris ,
475
+ extensions = extensions ,
476
+ )
477
+
478
+ return resolve
479
+
480
+ def singular_or_list (value : ExtendedExpressionOrUnbound , options : Iterable [ExtendedExpressionOrUnbound ]):
481
+ """Builds a resolver for ExtendedExpression containing a SingularOrList expression"""
482
+ def resolve (
483
+ base_schema : stp .NamedStruct , registry : ExtensionRegistry
484
+ ) -> stee .ExtendedExpression :
485
+ bound_value = resolve_expression (value , base_schema , registry )
486
+ bound_options = [resolve_expression (o , base_schema , registry ) for o in options ]
487
+
488
+ extension_uris = merge_extension_uris (
489
+ bound_value .extension_uris ,
490
+ * [b .extension_uris for b in bound_options ]
491
+ )
492
+
493
+ extensions = merge_extension_declarations (
494
+ bound_value .extensions ,
495
+ * [b .extensions for b in bound_options ]
496
+ )
497
+
498
+ return stee .ExtendedExpression (
499
+ referred_expr = [
500
+ stee .ExpressionReference (
501
+ expression = stalg .Expression (
502
+ singular_or_list = stalg .Expression .SingularOrList (
503
+ value = bound_value .referred_expr [0 ].expression ,
504
+ options = [
505
+ o .referred_expr [0 ].expression
506
+ for o in bound_options
507
+ ]
508
+ )
509
+ ),
510
+ output_names = ['singular_or_list' ] #TODO construct name from inputs
511
+ )
512
+ ],
513
+ base_schema = base_schema ,
514
+ extension_uris = extension_uris ,
515
+ extensions = extensions ,
516
+ )
517
+
518
+ return resolve
519
+
520
+ def multi_or_list (value : Iterable [ExtendedExpressionOrUnbound ], options : Iterable [Iterable [ExtendedExpressionOrUnbound ]]):
521
+ """Builds a resolver for ExtendedExpression containing a MultiOrList expression"""
522
+ def resolve (
523
+ base_schema : stp .NamedStruct , registry : ExtensionRegistry
524
+ ) -> stee .ExtendedExpression :
525
+ bound_value = [resolve_expression (e , base_schema , registry ) for e in value ]
526
+ bound_options = [
527
+ [resolve_expression (e , base_schema , registry ) for e in o ] for o in options
528
+ ]
529
+
530
+ extension_uris = merge_extension_uris (
531
+ * [b .extension_uris for b in bound_value ],
532
+ * [e .extension_uris for b in bound_options for e in b ],
533
+ )
534
+
535
+ extensions = merge_extension_uris (
536
+ * [b .extensions for b in bound_value ],
537
+ * [e .extensions for b in bound_options for e in b ],
538
+ )
539
+
540
+ return stee .ExtendedExpression (
541
+ referred_expr = [
542
+ stee .ExpressionReference (
543
+ expression = stalg .Expression (
544
+ multi_or_list = stalg .Expression .MultiOrList (
545
+ value = [e .referred_expr [0 ].expression for e in bound_value ],
546
+ options = [
547
+ stalg .Expression .MultiOrList .Record (
548
+ fields = [e .referred_expr [0 ].expression for e in option ]
549
+ )
550
+ for option in bound_options
551
+ ]
552
+ )
553
+ ),
554
+ output_names = ['multi_or_list' ] #TODO construct name from inputs
555
+ )
556
+ ],
557
+ base_schema = base_schema ,
558
+ extension_uris = extension_uris ,
559
+ extensions = extensions ,
560
+ )
561
+
562
+ return resolve
563
+
564
+ def cast (input : ExtendedExpressionOrUnbound , type : stp .Type ):
565
+ """Builds a resolver for ExtendedExpression containing a cast expression"""
566
+ def resolve (
567
+ base_schema : stp .NamedStruct , registry : ExtensionRegistry
568
+ ) -> stee .ExtendedExpression :
569
+ bound_input = resolve_expression (input , base_schema , registry )
570
+
571
+ return stee .ExtendedExpression (
572
+ referred_expr = [
573
+ stee .ExpressionReference (
574
+ expression = stalg .Expression (
575
+ cast = stalg .Expression .Cast (
576
+ input = bound_input .referred_expr [0 ].expression ,
577
+ type = type ,
578
+ failure_behavior = stalg .Expression .Cast .FAILURE_BEHAVIOR_RETURN_NULL
579
+ )
580
+ ),
581
+ output_names = ['cast' ] #TODO construct name from inputs
582
+ )
583
+ ],
584
+ base_schema = base_schema ,
585
+ extension_uris = bound_input .extension_uris ,
586
+ extensions = bound_input .extensions ,
587
+ )
588
+
589
+ return resolve
0 commit comments