Skip to content

Commit 77e57d8

Browse files
gtopper and Gal Topper authored
ParallelExecution step (#543)
* `ParallelExecution` step [ML-7689](https://iguazio.atlassian.net/browse/ML-7689) * Add runnable names, change output format * Various improvements and additions * Add error on duplicate runnable selection, similar to #545 * Make ParallelExecution friendly to mlrun serialization * Process and thread limits, always spawn, docs * Runtime, output format, docs * Export ParallelExecution * Improve error, add test * Type hints * Copy event body to protect against mutation, pass path * select_runnables can return None for all runnables * Add run_async, expose supported mechanisms * Fix test following async change * Fix result gathering on selection * Rename parameters, add comment, add kwargs * Add more docstrings, rename parameter * Add type annotations * Move mechanism list out of class * Add explicit max processes and threads defaults * Remove redundant ifs * Remove print, improve docstring --------- Co-authored-by: Gal Topper <galt@iguazio.com>
1 parent c1710a2 commit 77e57d8

File tree

3 files changed

+361
-1
lines changed

3 files changed

+361
-1
lines changed

storey/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
from .flow import Map # noqa: F401
5151
from .flow import MapClass # noqa: F401
5252
from .flow import MapWithState # noqa: F401
53+
from .flow import ParallelExecution # noqa: F401
54+
from .flow import ParallelExecutionRunnable # noqa: F401
5355
from .flow import Recover # noqa: F401
5456
from .flow import Reduce # noqa: F401
5557
from .flow import Rename # noqa: F401

storey/flow.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import copy
1717
import datetime
1818
import inspect
19+
import multiprocessing
20+
import os
1921
import pickle
2022
import time
2123
import traceback
@@ -1435,3 +1437,183 @@ def get_table(self, key):
14351437

14361438
def set_table(self, key, table):
14371439
self._tables[key] = table
1440+
1441+
1442+
class _ParallelExecutionRunnableResult:
    """
    Outcome of a single runnable invocation.

    :param runnable_name: Name of the runnable that produced the result
    :param data: Value returned by the runnable
    :param runtime: How long the runnable took to run, as measured by the caller
    """

    def __init__(self, runnable_name: str, data: Any, runtime: float):
        self.runnable_name, self.data, self.runtime = runnable_name, data, runtime
1447+
1448+
1449+
# Execution mechanisms that a ParallelExecutionRunnable may declare through its execution_mechanism attribute.
parallel_execution_mechanisms = ("multiprocessing", "threading", "asyncio", "naive")
1450+
1451+
1452+
class ParallelExecutionRunnable:
    """
    Runnable to be run by a ParallelExecution step. Subclasses must assign execution_mechanism with one of:
    * "multiprocessing" – To run in a separate process. This is appropriate for CPU or GPU intensive tasks as they
      would otherwise block the main process by holding Python's Global Interpreter Lock (GIL).
    * "threading" – To run in a separate thread. This is appropriate for blocking I/O tasks, as they would otherwise
      block the main event loop thread.
    * "asyncio" – To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the event
      loop to continue running while waiting for a response.
    * "naive" – To run in the main event loop. This is appropriate only for trivial computation and/or file I/O. It
      means that the runnable will not actually be run in parallel to anything else.

    Subclasses must also override the run() method, or run_async() when execution_mechanism="asyncio", with user code
    that handles the event and returns a result.

    Subclasses may optionally override the init() method if the user's implementation of run() requires prior
    initialization.

    :param name: Runnable name
    :raises ValueError: If the subclass did not set execution_mechanism to one of the supported mechanisms.
    """

    execution_mechanism: Optional[str] = None

    # ignore unused keyword arguments such as context which may be passed in by mlrun
    def __init__(self, name: str, **kwargs):
        if self.execution_mechanism not in parallel_execution_mechanisms:
            # Build the list from the canonical tuple so that the message can never drift from what is accepted.
            supported = ", ".join(f'"{mechanism}"' for mechanism in parallel_execution_mechanisms)
            raise ValueError(
                "ParallelExecutionRunnable's execution_mechanism attribute must be overridden with one of: "
                + supported
            )
        self.name = name

    def init(self) -> None:
        """Override this method to add initialization logic."""
        pass

    def run(self, body: Any, path: str) -> Any:
        """
        Override this method with the code this runnable should run. If execution_mechanism is "asyncio", override
        run_async() instead. The default implementation returns the body unchanged.

        :param body: Event body
        :param path: Event path
        :return: The runnable's result
        """
        return body

    async def run_async(self, body: Any, path: str) -> Any:
        """
        If execution_mechanism is "asyncio", override this method with the code this runnable should run. Otherwise,
        override run() instead. The default implementation returns the body unchanged.

        :param body: Event body
        :param path: Event path
        :return: The runnable's result
        """
        return body

    def _run(self, body: Any, path: str) -> "_ParallelExecutionRunnableResult":
        """Call run() and wrap its return value together with this runnable's name and the measured runtime."""
        start = time.monotonic()
        output = self.run(body, path)
        runtime = time.monotonic() - start
        return _ParallelExecutionRunnableResult(self.name, output, runtime)

    async def _async_run(self, body: Any, path: str) -> "_ParallelExecutionRunnableResult":
        """Await run_async() and wrap its return value together with this runnable's name and the measured runtime."""
        start = time.monotonic()
        output = await self.run_async(body, path)
        runtime = time.monotonic() - start
        return _ParallelExecutionRunnableResult(self.name, output, runtime)
1519+
1520+
1521+
class ParallelExecution(Flow):
    """
    Runs multiple jobs in parallel for each event.

    For every event, the selected runnables are each given the event body and path. The event body is then replaced
    with a dictionary of the form
    ``{"input": <original body>, "results": {<runnable name>: {"runtime": <float>, "output": <result>}}}``
    and the event is passed downstream.

    :param runnables: A list of ParallelExecutionRunnable instances.
    :param max_processes: Maximum number of processes to spawn. Defaults to the number of available CPUs, or 16 if
        number of CPUs can't be determined.
    :param max_threads: Maximum number of threads to start. Defaults to 32.
    :raises ValueError: If runnables is empty.
    """

    def __init__(
        self,
        runnables: list[ParallelExecutionRunnable],
        max_processes: Optional[int] = None,
        max_threads: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if not runnables:
            raise ValueError("ParallelExecution cannot be instantiated without at least one runnable")

        self.runnables = runnables
        self._runnable_by_name = {}

        self.max_processes = max_processes or os.cpu_count() or 16
        self.max_threads = max_threads or 32

    def select_runnables(self, event) -> Optional[Union[list[str], list[ParallelExecutionRunnable]]]:
        """
        Given an event, returns a list of runnables (or a list of runnable names) to execute on it. It can also return
        None, in which case all runnables are executed on the event, which is also the default.

        :param event: Event object
        """
        return None

    def _shutdown_executors(self):
        """Shut down any thread/process pools created by a previous _init() and forget them."""
        executors = getattr(self, "_executors", None) or {}
        for executor in executors.values():
            executor.shutdown()
        self._executors = {}

    def _init(self):
        """
        Index the runnables by name, call init() on each of them, and create the executors they require.

        :raises ValueError: If two runnables share a name, or a runnable declares an unsupported mechanism.
        """
        super()._init()
        # If _init() runs more than once on the same step instance, start from a clean slate. Otherwise, the
        # uniqueness check below would trip over entries left by the previous call, and the executors created by
        # the previous call would be overwritten without being shut down.
        self._runnable_by_name = {}
        self._shutdown_executors()

        num_processes = 0
        num_threads = 0
        for runnable in self.runnables:
            if runnable.name in self._runnable_by_name:
                raise ValueError(f"ParallelExecutionRunnable name '{runnable.name}' is not unique")
            self._runnable_by_name[runnable.name] = runnable
            runnable.init()
            if runnable.execution_mechanism == "multiprocessing":
                num_processes += 1
            elif runnable.execution_mechanism == "threading":
                num_threads += 1
            elif runnable.execution_mechanism not in ("asyncio", "naive"):
                raise ValueError(f"Unsupported execution mechanism: {runnable.execution_mechanism}")

        # enforce max
        num_processes = min(num_processes, self.max_processes)
        num_threads = min(num_threads, self.max_threads)

        if num_processes:
            mp_context = multiprocessing.get_context("spawn")
            self._executors["multiprocessing"] = ProcessPoolExecutor(max_workers=num_processes, mp_context=mp_context)
        if num_threads:
            self._executors["threading"] = ThreadPoolExecutor(max_workers=num_threads)

    def _resolve_selection(self, selection) -> list:
        """Translate runnable names into runnables and reject a selection that contains the same runnable twice."""
        resolved = []
        runnables_encountered = set()
        for runnable in selection:
            if isinstance(runnable, str):
                runnable = self._runnable_by_name[runnable]
            if id(runnable) in runnables_encountered:
                raise ValueError(f"select_runnables() returned more than one outlet named '{runnable.name}'")
            runnables_encountered.add(id(runnable))
            resolved.append(runnable)
        return resolved

    async def _do(self, event):
        """
        Run the selected runnables on the event, gather their results into the event body, and pass it downstream.

        On termination, the executors are shut down and the termination object is passed downstream.
        """
        if event is _termination_obj:
            self._shutdown_executors()
            return await self._do_downstream(_termination_obj)

        selection = self.select_runnables(event)
        if selection is None:
            selection = self.runnables
        # Validate the whole selection before scheduling anything, so that an invalid selection doesn't leave
        # already-started work running with nobody awaiting it.
        runnables = self._resolve_selection(selection)

        loop = asyncio.get_running_loop()
        futures = []
        for runnable in runnables:
            mechanism = runnable.execution_mechanism
            # Copy the body to protect the event against mutation by the runnable. The multiprocessing case is
            # exempt from the explicit copy.
            body = event.body if mechanism == "multiprocessing" else copy.deepcopy(event.body)
            if mechanism == "asyncio":
                future = loop.create_task(runnable._async_run(body, event.path))
            elif mechanism == "naive":
                # Runs right here, in the event loop; wrap the result in a future so it can be gathered like the rest
                future = loop.create_future()
                future.set_result(runnable._run(body, event.path))
            else:
                future = loop.run_in_executor(self._executors[mechanism], runnable._run, body, event.path)
            futures.append(future)

        results: list[_ParallelExecutionRunnableResult] = await asyncio.gather(*futures)
        event.body = {
            "input": event.body,
            "results": {
                result.runnable_name: {"runtime": result.runtime, "output": result.data} for result in results
            },
        }
        return await self._do_downstream(event)

tests/test_flow.py

Lines changed: 177 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,14 @@
7070
V3ioDriver,
7171
build_flow,
7272
)
73-
from storey.flow import Context, ReifyMetadata, Rename, _ConcurrentJobExecution
73+
from storey.flow import (
74+
Context,
75+
ParallelExecution,
76+
ParallelExecutionRunnable,
77+
ReifyMetadata,
78+
Rename,
79+
_ConcurrentJobExecution,
80+
)
7481

7582

7683
class ATestException(Exception):
@@ -4686,3 +4693,172 @@ def test_filters_type():
46864693
additional_filters=[[("city", "=", "Tel Aviv")], [("age", ">=", "40")]],
46874694
filter_column="start_time",
46884695
)
4696+
4697+
4698+
class RunnableBusyWait(ParallelExecutionRunnable):
    """Multiprocessing runnable that spins for one second and then returns the value assigned by init()."""

    execution_mechanism = "multiprocessing"
    _result = 0

    def init(self):
        # run() returns 1 only if init() was called first
        self._result = 1

    def run(self, data, path):
        began = time.monotonic()
        while True:
            if time.monotonic() - began >= 1:
                break
        return self._result
4710+
4711+
4712+
class RunnableSleep(ParallelExecutionRunnable):
    """Threading runnable that blocks for one second and then returns the value assigned by init()."""

    execution_mechanism = "threading"
    _result = 0

    def init(self):
        # run() returns 1 only if init() was called first
        self._result = 1

    def run(self, data, path):
        seconds = 1
        time.sleep(seconds)
        return self._result
4722+
4723+
4724+
class RunnableAsyncSleep(ParallelExecutionRunnable):
    """Asyncio runnable that sleeps for one second without blocking, then returns the value assigned by init()."""

    execution_mechanism = "asyncio"
    _result = 0

    def init(self):
        # run_async() returns 1 only if init() was called first
        self._result = 1

    async def run_async(self, data, path):
        await asyncio.sleep(1)
        return self._result
4735+
4736+
4737+
class RunnableNaiveNoOp(ParallelExecutionRunnable):
    """Naive runnable that does no work and immediately returns the value assigned by init()."""

    execution_mechanism = "naive"
    _result = 0

    def init(self):
        # run() returns 1 only if init() was called first
        self._result = 1

    def run(self, data, path):
        outcome = self._result
        return outcome
4746+
4747+
4748+
class RunnableWithError(ParallelExecutionRunnable):
    """Naive runnable that always fails; used to verify that unselected runnables are not run."""

    execution_mechanism = "naive"

    def run(self, data, path):
        failure = Exception("This shouldn't run!")
        raise failure
4753+
4754+
4755+
def test_parallel_execution_runnable_uniqueness():
    """Two runnables that share a name must be rejected when the step is initialized."""
    step = ParallelExecution([RunnableBusyWait("x"), RunnableBusyWait("x")])
    expected_error = "ParallelExecutionRunnable name 'x' is not unique"
    with pytest.raises(ValueError, match=expected_error):
        step._init()
4763+
4764+
4765+
def test_select_runnable_uniqueness():
    """select_runnables() naming the same runnable twice must fail the flow with a ValueError."""

    class MyParallelExecution(ParallelExecution):
        def select_runnables(self, event):
            return ["x"] * 2

    step = MyParallelExecution([RunnableNaiveNoOp(name) for name in ("x", "y")])

    source = SyncEmitSource()
    source.to(step)

    controller = source.run()
    controller.emit(0)
    controller.terminate()
    expected_error = r"select_runnables\(\) returned more than one outlet named 'x'"
    with pytest.raises(ValueError, match=expected_error):
        controller.await_termination()
4785+
4786+
4787+
def test_parallel_execution():
    """
    Run one runnable of every execution mechanism on a single event and check outputs, runtimes and total duration.

    The "error" runnable is excluded by select_runnables(), so the flow must complete without raising.
    """
    runnables = [
        RunnableWithError("error"),
        RunnableBusyWait("busy1"),
        RunnableBusyWait("busy2"),
        RunnableSleep("sleep1"),
        RunnableSleep("sleep2"),
        RunnableAsyncSleep("asleep1"),
        RunnableAsyncSleep("asleep2"),
        # Use the actual naive runnable so that the "naive" execution mechanism is exercised end to end
        RunnableNaiveNoOp("naive"),
    ]

    class MyParallelExecution(ParallelExecution):
        def select_runnables(self, event):
            return [runnable.name for runnable in runnables if runnable.name != "error"]

    parallel_execution = MyParallelExecution(runnables)
    reduce = Reduce([], lambda acc, x: acc + [x])

    source = SyncEmitSource()
    source.to(parallel_execution).to(reduce)

    start = time.monotonic()
    controller = source.run()
    controller.emit(0)
    controller.terminate()
    termination_result = controller.await_termination()
    end = time.monotonic()

    # The one-second runnables must overlap rather than run back to back
    assert end - start < 6
    termination_result = termination_result[0]
    assert termination_result.keys() == {"input", "results"}
    assert termination_result["input"] == 0
    results = termination_result["results"]
    assert results.keys() == {"busy1", "busy2", "sleep1", "sleep2", "asleep1", "asleep2", "naive"}
    for name, result in results.items():
        assert result["output"] == 1
        if name == "naive":
            # The no-op runnable returns immediately
            assert 0 <= result["runtime"] < 1
        else:
            assert 1 < result["runtime"] < 2
4825+
4826+
4827+
def test_invalid_runnable():
    """Instantiating the base runnable, which declares no execution mechanism, must raise a ValueError."""
    expected_error = (
        "ParallelExecutionRunnable's execution_mechanism attribute must be overridden with one of: "
        '"multiprocessing", "threading", "asyncio", "naive"'
    )
    with pytest.raises(ValueError, match=expected_error):
        ParallelExecutionRunnable("my_runnable")
4834+
4835+
4836+
class RunnableNaiveWithMutation(ParallelExecutionRunnable):
    """Naive runnable that increments data["n"] in place and returns the same, mutated, object."""

    execution_mechanism = "naive"

    def run(self, data, path):
        data["n"] = data["n"] + 1
        return data
4842+
4843+
4844+
def test_event_input_preservation():
    """A runnable that mutates the body it receives must not affect the "input" reported downstream."""
    step = ParallelExecution([RunnableNaiveWithMutation("x")])
    collector = Reduce([], lambda acc, x: acc + [x])

    source = SyncEmitSource()
    source.to(step).to(collector)

    controller = source.run()
    controller.emit({"n": 1})
    controller.terminate()
    emitted = controller.await_termination()

    first = emitted[0]
    assert first.keys() == {"input", "results"}
    # The original body is intact even though the runnable incremented its own copy
    assert first["input"] == {"n": 1}
    per_runnable = first["results"]
    assert per_runnable.keys() == {"x"}
    x_result = per_runnable["x"]
    assert x_result.keys() == {"runtime", "output"}
    assert x_result["output"] == {"n": 2}

0 commit comments

Comments
 (0)