Skip to content

Commit 937d9e1

Browse files
Merge pull request #84 from iris-hep/fix_83_fields-of-choose
Fix fields when used with `Choose`
2 parents ac39e6d + 3ae99d0 commit 937d9e1

File tree

2 files changed

+37
-5
lines changed

2 files changed

+37
-5
lines changed

func_adl_uproot/transformer.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
class PythonSourceGeneratorTransformer(ast.NodeTransformer):
4343
def __init__(self):
4444
self._depth = None
45+
self._tuple_depths = []
4546
self._id_scopes = {}
4647
self._projection_stack = []
4748

@@ -316,9 +317,15 @@ def visit_Select(self, node):
316317
raise TypeError('Lambda function in Select() must have exactly one argument, found '
317318
+ len(node.selector.args.args))
318319
self.visit(node.source)
320+
if self._depth in self._tuple_depths:
321+
at_tuple = True
322+
original_source_rep = self.get_rep(node.source)
323+
node.source.rep = 'x'
324+
else:
325+
at_tuple = False
319326
self._depth += 1
320327
self._projection_stack.append(node.selector.args.args[0].arg)
321-
if self._depth > 2:
328+
if self._depth > 2 and not at_tuple:
322329
rep1, rep2 = self._projection_stack[-2], self._projection_stack[-1]
323330
lambda_node = ast.Lambda(args=ast.arguments(args=[ast.arg(arg=rep1),
324331
ast.arg(arg=rep2)]),
@@ -329,12 +336,17 @@ def visit_Select(self, node):
329336
else:
330337
call_node = ast.Call(func=node.selector, args=[node.source])
331338
call_rep = self.get_rep(call_node)
332-
node.rep = ('(lambda selection: ak.zip(selection,'
333-
+ ' depth_limit=(None if len(selection) == 1 else ' + repr(self._depth) + '))'
334-
+ ' if not isinstance(selection, ak.Array)'
335-
+ ' else selection)(' + call_rep + ')')
339+
select_rep = ('(lambda selection: ak.zip(selection,'
340+
+ ' depth_limit=(None if len(selection) == 1 else ' + repr(self._depth)
341+
+ '))' + ' if not isinstance(selection, ak.Array)'
342+
+ ' else selection)(' + call_rep + ')')
336343
self._depth -= 1
337344
self._projection_stack.pop()
345+
if at_tuple:
346+
node.rep = ('ak.zip([' + select_rep + ' for x in ak.unzip(' + original_source_rep
347+
+ ')])')
348+
else:
349+
node.rep = select_rep
338350
return node
339351

340352
def visit_SelectMany(self, node):
@@ -419,6 +431,7 @@ def visit_Choose(self, node):
419431
self.visit(node.source)
420432
node.rep = ('ak.combinations(' + self.get_rep(node.source) + ', ' + self.get_rep(node.n)
421433
+ ', axis=' + repr(self._depth) + ')')
434+
self._tuple_depths.append(self._depth + 1)
422435
return node
423436

424437
def visit_OrderBy(self, node, ascending=True):

tests/test_executor.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,25 @@ def test_ast_executor_select_of_choose():
326326
assert ast_executor(python_ast).tolist() == [[], [1, 2, 5], []]
327327

328328

329+
def test_ast_executor_choose_zipped_dict():
330+
python_source = ("Select(EventDataset('tests/vectors_tree_file.root', 'tree'),"
331+
+ "lambda row: {'ints': row.int_vector_branch}.Zip().Choose(2))")
332+
python_ast = ast.parse(python_source)
333+
assert ast_executor(python_ast).tolist() == [[],
334+
[({'ints': -1}, {'ints': 2}),
335+
({'ints': -1}, {'ints': 3}),
336+
({'ints': 2}, {'ints': 3})],
337+
[]]
338+
339+
340+
def test_ast_executor_field_of_choose():
341+
python_source = ("Select(EventDataset('tests/vectors_tree_file.root', 'tree'),"
342+
+ "lambda row: {'int': row.int_vector_branch}.Zip().Choose(2)"
343+
+ '.Select(lambda pair: pair.Select(lambda record: record.int)))')
344+
python_ast = ast.parse(python_source)
345+
assert ast_executor(python_ast).tolist() == [[], [(-1, 2), (-1, 3), (2, 3)], []]
346+
347+
329348
def test_ast_executor_tofourvector():
330349
python_source = ("Select(EventDataset('tests/four-vector_tree_file.root', 'tree'),"
331350
+ "lambda row: Zip({'pt': row.pt_vector_branch, 'eta': row.eta_vector_branch,"

0 commit comments

Comments
 (0)