Skip to content

Commit 1100d7c

Browse files
authored
Merge pull request #507 from tclose/serial-worker-fix
Fixes SerialWorker Implementation
2 parents 4db192a + 8033ff7 commit 1100d7c

File tree

4 files changed

+20
-21
lines changed

4 files changed

+20
-21
lines changed

pydra/engine/specs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ def _field_metadata(
567567
if "mandatory" in fld.metadata:
568568
if fld.metadata["mandatory"]:
569569
raise Exception(
570-
f"mandatory output for variable {fld.name} does not exit"
570+
f"mandatory output for variable {fld.name} does not exist"
571571
)
572572
return attr.NOTHING
573573
return val

pydra/engine/tests/test_shelltask.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4453,7 +4453,7 @@ def test_shell_cmd_non_existing_outputs_4(tmpdir):
44534453
# An exception should be raised because the second mandatory output does not exist
44544454
with pytest.raises(Exception) as excinfo:
44554455
shelly()
4456-
assert "mandatory output for variable out_2 does not exit" == str(excinfo.value)
4456+
assert "mandatory output for variable out_2 does not exist" == str(excinfo.value)
44574457
# checking if the first output was created
44584458
assert (Path(shelly.output_dir) / Path("test_1.nii")).exists()
44594459

pydra/engine/tests/test_submitter.py

+7
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,13 @@ def test_wf_with_state(plugin_dask_opt, tmpdir):
175175
assert res[2].output.out == 5
176176

177177

178+
def test_serial_wf():
179+
# Use serial plugin to execute workflow instead of CF
180+
wf = gen_basic_wf()
181+
res = wf(plugin="serial")
182+
assert res.output.out == 9
183+
184+
178185
@pytest.mark.skipif(not slurm_available, reason="slurm not installed")
179186
def test_slurm_wf(tmpdir):
180187
wf = gen_basic_wf()

pydra/engine/workers.py

+11-19
Original file line numberDiff line numberDiff line change
@@ -116,38 +116,30 @@ async def fetch_finished(self, futures):
116116
return pending.union(unqueued)
117117

118118

119-
class SerialPool:
120-
"""A simple class to imitate a pool executor of concurrent futures."""
121-
122-
def submit(self, interface, **kwargs):
123-
"""Send new task."""
124-
self.res = interface(**kwargs)
125-
126-
def result(self):
127-
"""Get the result of a task."""
128-
return self.res
129-
130-
def done(self):
131-
"""Return whether the task is finished."""
132-
return True
133-
134-
135119
class SerialWorker(Worker):
136120
"""A worker to execute linearly."""
137121

138122
def __init__(self):
139123
"""Initialize worker."""
140124
logger.debug("Initialize SerialWorker")
141-
self.pool = SerialPool()
142125

143126
def run_el(self, interface, rerun=False, **kwargs):
144127
"""Run a task."""
145-
self.pool.submit(interface=interface, rerun=rerun, **kwargs)
146-
return self.pool
128+
return self.exec_serial(interface, rerun=rerun)
147129

148130
def close(self):
149131
"""Return whether the task is finished."""
150132

133+
async def exec_serial(self, runnable, rerun=False):
134+
return runnable()
135+
136+
async def fetch_finished(self, futures):
137+
await asyncio.gather(*futures)
138+
return set([])
139+
140+
# async def fetch_finished(self, futures):
141+
# return await asyncio.wait(futures)
142+
151143

152144
class ConcurrentFuturesWorker(Worker):
153145
"""A worker to execute in parallel using Python's concurrent futures."""

0 commit comments

Comments
 (0)