4242class PythonSourceGeneratorTransformer (ast .NodeTransformer ):
4343 def __init__ (self ):
4444 self ._depth = None
45+ self ._tuple_depths = []
4546 self ._id_scopes = {}
4647 self ._projection_stack = []
4748
@@ -316,9 +317,15 @@ def visit_Select(self, node):
316317 raise TypeError ('Lambda function in Select() must have exactly one argument, found '
317318 + len (node .selector .args .args ))
318319 self .visit (node .source )
320+ if self ._depth in self ._tuple_depths :
321+ at_tuple = True
322+ original_source_rep = self .get_rep (node .source )
323+ node .source .rep = 'x'
324+ else :
325+ at_tuple = False
319326 self ._depth += 1
320327 self ._projection_stack .append (node .selector .args .args [0 ].arg )
321- if self ._depth > 2 :
328+ if self ._depth > 2 and not at_tuple :
322329 rep1 , rep2 = self ._projection_stack [- 2 ], self ._projection_stack [- 1 ]
323330 lambda_node = ast .Lambda (args = ast .arguments (args = [ast .arg (arg = rep1 ),
324331 ast .arg (arg = rep2 )]),
@@ -329,12 +336,17 @@ def visit_Select(self, node):
329336 else :
330337 call_node = ast .Call (func = node .selector , args = [node .source ])
331338 call_rep = self .get_rep (call_node )
332- node . rep = ('(lambda selection: ak.zip(selection,'
333- + ' depth_limit=(None if len(selection) == 1 else ' + repr (self ._depth ) + '))'
334- + ' if not isinstance(selection, ak.Array)'
335- + ' else selection)(' + call_rep + ')' )
339+ select_rep = ('(lambda selection: ak.zip(selection,'
340+ + ' depth_limit=(None if len(selection) == 1 else ' + repr (self ._depth )
341+ + '))' + ' if not isinstance(selection, ak.Array)'
342+ + ' else selection)(' + call_rep + ')' )
336343 self ._depth -= 1
337344 self ._projection_stack .pop ()
345+ if at_tuple :
346+ node .rep = ('ak.zip([' + select_rep + ' for x in ak.unzip(' + original_source_rep
347+ + ')])' )
348+ else :
349+ node .rep = select_rep
338350 return node
339351
340352 def visit_SelectMany (self , node ):
@@ -419,6 +431,7 @@ def visit_Choose(self, node):
419431 self .visit (node .source )
420432 node .rep = ('ak.combinations(' + self .get_rep (node .source ) + ', ' + self .get_rep (node .n )
421433 + ', axis=' + repr (self ._depth ) + ')' )
434+ self ._tuple_depths .append (self ._depth + 1 )
422435 return node
423436
424437 def visit_OrderBy (self , node , ascending = True ):
0 commit comments