Skip to content

Commit 3f3f59e

Browse files
authored
Merge pull request #160 from djarecka/list_inp
list as a single element of an input
2 parents f9e72bf + a872608 commit 3f3f59e

File tree

3 files changed

+89
-2
lines changed

3 files changed

+89
-2
lines changed

pydra/engine/helpers.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
from .specs import Runtime, File
1212

1313

14-
def ensure_list(obj):
14+
def ensure_list(obj, tuple2list=False):
1515
if obj is None:
1616
return []
1717
if isinstance(obj, list):
1818
return obj
19+
elif tuple2list and isinstance(obj, tuple):
20+
return list(obj)
1921
return [obj]
2022

2123

pydra/engine/task.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def command_args(self):
240240
if value is not True:
241241
break
242242
else:
243-
cmd_add += ensure_list(value)
243+
cmd_add += ensure_list(value, tuple2list=True)
244244
if cmd_add is not None:
245245
pos_args.append((pos, cmd_add))
246246
# sorting all elements of the command

pydra/engine/tests/test_node_task.py

+85
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,91 @@ def test_task_state_2(plugin):
642642
assert nn.output_dir == []
643643

644644

645+
@pytest.mark.parametrize("plugin", Plugins)
646+
def test_task_state_3(plugin):
647+
""" task with a tuple as an input, and a simple splitter """
648+
nn = moment(name="NA", n=3, lst=[(2, 3, 4), (1, 2, 3)]).split(splitter="lst")
649+
assert np.allclose(nn.inputs.n, 3)
650+
assert np.allclose(nn.inputs.lst, [[2, 3, 4], [1, 2, 3]])
651+
assert nn.state.splitter == "NA.lst"
652+
653+
with Submitter(plugin=plugin) as sub:
654+
sub(nn)
655+
656+
# checking the results
657+
results = nn.result()
658+
for i, expected in enumerate([33, 12]):
659+
assert results[i].output.out == expected
660+
# checking the output_dir
661+
assert nn.output_dir
662+
for odir in nn.output_dir:
663+
assert odir.exists()
664+
665+
666+
@pytest.mark.parametrize("plugin", Plugins)
667+
def test_task_state_4(plugin):
668+
""" task with a tuple as an input, and the variable is part of the scalar splitter"""
669+
nn = moment(name="NA", n=[1, 3], lst=[(2, 3, 4), (1, 2, 3)]).split(
670+
splitter=("n", "lst")
671+
)
672+
assert np.allclose(nn.inputs.n, [1, 3])
673+
assert np.allclose(nn.inputs.lst, [[2, 3, 4], [1, 2, 3]])
674+
assert nn.state.splitter == ("NA.n", "NA.lst")
675+
676+
with Submitter(plugin=plugin) as sub:
677+
sub(nn)
678+
679+
# checking the results
680+
results = nn.result()
681+
for i, expected in enumerate([3, 12]):
682+
assert results[i].output.out == expected
683+
# checking the output_dir
684+
assert nn.output_dir
685+
for odir in nn.output_dir:
686+
assert odir.exists()
687+
688+
689+
@pytest.mark.parametrize("plugin", Plugins)
690+
def test_task_state_4_exception(plugin):
691+
""" task with a tuple as an input, and the variable is part of the scalar splitter
692+
the shapes are not matching, so exception should be raised
693+
"""
694+
nn = moment(name="NA", n=[1, 3, 3], lst=[(2, 3, 4), (1, 2, 3)]).split(
695+
splitter=("n", "lst")
696+
)
697+
assert np.allclose(nn.inputs.n, [1, 3, 3])
698+
assert np.allclose(nn.inputs.lst, [[2, 3, 4], [1, 2, 3]])
699+
assert nn.state.splitter == ("NA.n", "NA.lst")
700+
701+
with pytest.raises(Exception) as excinfo:
702+
with Submitter(plugin=plugin) as sub:
703+
sub(nn)
704+
assert "shape" in str(excinfo.value)
705+
706+
707+
@pytest.mark.parametrize("plugin", Plugins)
708+
def test_task_state_5(plugin):
709+
""" ask with a tuple as an input, and the variable is part of the outer splitter """
710+
nn = moment(name="NA", n=[1, 3], lst=[(2, 3, 4), (1, 2, 3)]).split(
711+
splitter=["n", "lst"]
712+
)
713+
assert np.allclose(nn.inputs.n, [1, 3])
714+
assert np.allclose(nn.inputs.lst, [[2, 3, 4], [1, 2, 3]])
715+
assert nn.state.splitter == ["NA.n", "NA.lst"]
716+
717+
with Submitter(plugin=plugin) as sub:
718+
sub(nn)
719+
720+
# checking the results
721+
results = nn.result()
722+
for i, expected in enumerate([3, 2, 33, 12]):
723+
assert results[i].output.out == expected
724+
# checking the output_dir
725+
assert nn.output_dir
726+
for odir in nn.output_dir:
727+
assert odir.exists()
728+
729+
645730
@pytest.mark.parametrize("plugin", Plugins)
646731
def test_task_state_comb_1(plugin):
647732
""" task with the simplest splitter and combiner"""

0 commit comments

Comments
 (0)