Skip to content

Commit 81af3b8

Browse files
authored
Merge pull request #211 from djarecka/master
small changes in template_update (closes #205)
2 parents 73ac982 + 199a3f1 commit 81af3b8

File tree

5 files changed

+50
-22
lines changed

5 files changed

+50
-22
lines changed

pydra/engine/core.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def _run(self, rerun=False, **kwargs):
381381
orig_inputs = attr.asdict(self.inputs)
382382
map_copyfiles = copyfile_input(self.inputs, self.output_dir)
383383
modified_inputs = template_update(self.inputs, map_copyfiles)
384-
if modified_inputs is not None:
384+
if modified_inputs:
385385
self.inputs = attr.evolve(self.inputs, **modified_inputs)
386386
self.audit.start_audit(odir)
387387
result = Result(output=None, runtime=None, errored=False)

pydra/engine/helpers_file.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ def template_update(inputs, map_copyfiles=None):
505505
f"output_file_template metadata for "
506506
"{fld.name} should be a string"
507507
)
508-
return {k: v for k, v in dict_.items() if getattr(inputs, k) != v}
508+
return {k: v for k, v in dict_.items() if getattr(inputs, k) is not v}
509509

510510

511511
def is_local_file(f):

pydra/engine/task.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -176,15 +176,20 @@ def __init__(
176176
)
177177
else:
178178
if not isinstance(return_info, tuple):
179-
return_info = (return_info,)
180-
output_spec = SpecInfo(
181-
name="Output",
182-
fields=[
183-
("out{}".format(n + 1), t)
184-
for n, t in enumerate(return_info)
185-
],
186-
bases=(BaseSpec,),
187-
)
179+
output_spec = SpecInfo(
180+
name="Output",
181+
fields=[("out", return_info)],
182+
bases=(BaseSpec,),
183+
)
184+
else:
185+
output_spec = SpecInfo(
186+
name="Output",
187+
fields=[
188+
("out{}".format(n + 1), t)
189+
for n, t in enumerate(return_info)
190+
],
191+
bases=(BaseSpec,),
192+
)
188193
elif "return" in func.__annotations__:
189194
raise NotImplementedError("Branch not implemented")
190195
self.output_spec = output_spec

pydra/engine/tests/test_task.py

+24-11
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,17 @@ def test_output():
2727
assert res.output.out == 5
2828

2929

30+
def test_numpy():
31+
""" checking if mark.task works for numpy functions"""
32+
np = pytest.importorskip("numpy")
33+
fft = mark.annotate({"a": np.ndarray, "return": float})(np.fft.fft)
34+
fft = mark.task(fft)()
35+
arr = np.array([[1, 10], [2, 20]])
36+
fft.inputs.a = arr
37+
res = fft()
38+
assert np.allclose(np.fft.fft(arr), res.output.out)
39+
40+
3041
@pytest.mark.xfail(reason="cp.dumps(func) depends on the system/setup, TODO!!")
3142
def test_checksum():
3243
nn = funaddtwo(a=3)
@@ -38,7 +49,9 @@ def test_checksum():
3849

3950
def test_annotated_func():
4051
@mark.task
41-
def testfunc(a: int, b: float = 0.1) -> ty.NamedTuple("Output", [("out1", float)]):
52+
def testfunc(
53+
a: int, b: float = 0.1
54+
) -> ty.NamedTuple("Output", [("out_out", float)]):
4255
return a + b
4356

4457
funky = testfunc(a=1)
@@ -48,14 +61,14 @@ def testfunc(a: int, b: float = 0.1) -> ty.NamedTuple("Output", [("out1", float)
4861
assert getattr(funky.inputs, "a") == 1
4962
assert getattr(funky.inputs, "b") == 0.1
5063
assert getattr(funky.inputs, "_func") is not None
51-
assert set(funky.output_names) == set(["out1"])
64+
assert set(funky.output_names) == set(["out_out"])
5265
# assert funky.inputs.hash == '17772c3aec9540a8dd3e187eecd2301a09c9a25c6e371ddd86e31e3a1ecfeefa'
5366
assert funky.__class__.__name__ + "_" + funky.inputs.hash == funky.checksum
5467

5568
result = funky()
5669
assert hasattr(result, "output")
57-
assert hasattr(result.output, "out1")
58-
assert result.output.out1 == 1.1
70+
assert hasattr(result.output, "out_out")
71+
assert result.output.out_out == 1.1
5972

6073
assert os.path.exists(funky.cache_dir / funky.checksum / "_result.pklz")
6174
funky.result() # should not recompute
@@ -64,7 +77,7 @@ def testfunc(a: int, b: float = 0.1) -> ty.NamedTuple("Output", [("out1", float)
6477
assert funky.result() is None
6578
funky()
6679
result = funky.result()
67-
assert result.output.out1 == 2.1
80+
assert result.output.out_out == 2.1
6881

6982
help = funky.help(returnhelp=True)
7083
assert help == [
@@ -74,7 +87,7 @@ def testfunc(a: int, b: float = 0.1) -> ty.NamedTuple("Output", [("out1", float)
7487
"- b: float (default: 0.1)",
7588
"- _func: str",
7689
"Output Parameters:",
77-
"- out1: float",
90+
"- out_out: float",
7891
]
7992

8093

@@ -150,13 +163,13 @@ def testfunc(a, b) -> int:
150163
assert getattr(funky.inputs, "a") == 10
151164
assert getattr(funky.inputs, "b") == 20
152165
assert getattr(funky.inputs, "_func") is not None
153-
assert set(funky.output_names) == set(["out1"])
166+
assert set(funky.output_names) == set(["out"])
154167
assert funky.__class__.__name__ + "_" + funky.inputs.hash == funky.checksum
155168

156169
result = funky()
157170
assert hasattr(result, "output")
158-
assert hasattr(result.output, "out1")
159-
assert result.output.out1 == 30
171+
assert hasattr(result.output, "out")
172+
assert result.output.out == 30
160173

161174
assert os.path.exists(funky.cache_dir / funky.checksum / "_result.pklz")
162175

@@ -165,7 +178,7 @@ def testfunc(a, b) -> int:
165178
assert funky.result() is None
166179
funky()
167180
result = funky.result()
168-
assert result.output.out1 == 31
181+
assert result.output.out == 31
169182
help = funky.help(returnhelp=True)
170183

171184
assert help == [
@@ -175,7 +188,7 @@ def testfunc(a, b) -> int:
175188
"- b: _empty",
176189
"- _func: str",
177190
"Output Parameters:",
178-
"- out1: int",
191+
"- out: int",
179192
]
180193

181194

pydra/mark/tests/test_functions.py

+10
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,16 @@ def square(in_val):
8787
assert res.output.squared == 4.0
8888

8989

90+
def test_return_halfannotated_annotated_task():
91+
@task
92+
@annotate({"in_val": float, "return": float})
93+
def square(in_val):
94+
return in_val ** 2
95+
96+
res = square(in_val=2.0)()
97+
assert res.output.out == 4.0
98+
99+
90100
def test_return_annotated_task_multiple_output():
91101
@task
92102
@annotate({"in_val": float, "return": {"squared": float, "cubed": float}})

0 commit comments

Comments
 (0)