|
16 | 16 | import copy |
17 | 17 | import datetime |
18 | 18 | import inspect |
| 19 | +import multiprocessing |
| 20 | +import os |
19 | 21 | import pickle |
20 | 22 | import time |
21 | 23 | import traceback |
@@ -1435,3 +1437,183 @@ def get_table(self, key): |
1435 | 1437 |
|
def set_table(self, key, table):
    """
    Associate a table with a key in this object's table registry.

    :param key: Key under which the table is stored in self._tables
    :param table: Value to store under key
    """
    self._tables[key] = table
| 1440 | + |
| 1441 | + |
class _ParallelExecutionRunnableResult:
    """
    Outcome of a single runnable invocation, as produced by ParallelExecutionRunnable._run() and _async_run().

    :param runnable_name: Name of the runnable that produced this result
    :param data: Value returned by the runnable
    :param runtime: Difference between two time.monotonic() readings taken around the runnable's execution
    """

    # Declared here for readers and type checkers; the values are assigned per instance in __init__
    runnable_name: str
    data: Any
    runtime: float

    def __init__(self, runnable_name: str, data: Any, runtime: float):
        self.runnable_name, self.data, self.runtime = runnable_name, data, runtime
| 1447 | + |
| 1448 | + |
# Values accepted for ParallelExecutionRunnable.execution_mechanism
parallel_execution_mechanisms = ("multiprocessing", "threading", "asyncio", "naive")


class ParallelExecutionRunnable:
    """
    Base class for a unit of work that a ParallelExecution step runs on each event.

    A subclass must set the execution_mechanism class attribute to one of:
    * "multiprocessing" – Run in a separate process. Suited to CPU or GPU intensive work, which would otherwise hold
      Python's Global Interpreter Lock (GIL) and block the main process.
    * "threading" – Run in a separate thread. Suited to blocking I/O, which would otherwise block the main event
      loop thread.
    * "asyncio" – Run as an asyncio task. Suited to I/O that is written with asyncio, so that the event loop keeps
      running while a response is awaited.
    * "naive" – Run directly in the main event loop. Suited only to trivial computation and/or file I/O, since the
      runnable is then not actually run in parallel to anything else.

    A subclass must also implement the user logic that handles the event and returns a result: override run(), or
    run_async() when execution_mechanism="asyncio".

    A subclass may additionally override init() when its run() implementation needs prior initialization.

    :param name: Runnable name
    """

    execution_mechanism: Optional[str] = None

    # ignore unused keyword arguments such as context which may be passed in by mlrun
    def __init__(self, name: str, **kwargs):
        mechanism_is_supported = self.execution_mechanism in parallel_execution_mechanisms
        if not mechanism_is_supported:
            raise ValueError(
                "ParallelExecutionRunnable's execution_mechanism attribute must be overridden with one of: "
                '"multiprocessing", "threading", "asyncio", "naive"'
            )
        self.name = name

    def init(self) -> None:
        """Hook for initialization logic; does nothing unless overridden."""
        return None

    def run(self, body: Any, path: str) -> Any:
        """
        User code for every execution mechanism except "asyncio" (for which run_async() is overridden instead).
        The default implementation returns the body unchanged.

        :param body: Event body
        :param path: Event path
        """
        return body

    async def run_async(self, body: Any, path: str) -> Any:
        """
        User code for the "asyncio" execution mechanism (for any other mechanism, run() is overridden instead).
        The default implementation returns the body unchanged.

        :param body: Event body
        :param path: Event path
        """
        return body

    def _run(self, body: Any, path: str) -> Any:
        # Time the synchronous user code and wrap its output together with this runnable's name
        started_at = time.monotonic()
        output = self.run(body, path)
        elapsed = time.monotonic() - started_at
        return _ParallelExecutionRunnableResult(self.name, output, elapsed)

    async def _async_run(self, body: Any, path: str) -> Any:
        # Same as _run(), but awaits the asynchronous user code
        started_at = time.monotonic()
        output = await self.run_async(body, path)
        elapsed = time.monotonic() - started_at
        return _ParallelExecutionRunnableResult(self.name, output, elapsed)
| 1519 | + |
| 1520 | + |
class ParallelExecution(Flow):
    """
    Runs multiple jobs in parallel for each event.

    For every event, each selected runnable is started according to its execution_mechanism. Once all of them have
    completed, the event body is replaced with a dictionary of the form
    {"input": <original body>, "results": {<runnable name>: {"runtime": <runtime>, "output": <runnable output>}}}
    and the event is passed downstream.

    :param runnables: A list of ParallelExecutionRunnable instances.
    :param max_processes: Maximum number of processes to spawn. Defaults to the number of available CPUs, or 16 if
        number of CPUs can't be determined.
    :param max_threads: Maximum number of threads to start. Defaults to 32.
    """

    def __init__(
        self,
        runnables: list[ParallelExecutionRunnable],
        max_processes: Optional[int] = None,
        max_threads: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if not runnables:
            raise ValueError("ParallelExecution cannot be instantiated without at least one runnable")

        self.runnables = runnables
        # Both are (re)built by _init(); defined here so they always exist, even if _init() was never called
        self._runnable_by_name = {}
        self._executors = {}

        self.max_processes = max_processes or os.cpu_count() or 16
        self.max_threads = max_threads or 32

    def select_runnables(self, event) -> Optional[Union[list[str], list[ParallelExecutionRunnable]]]:
        """
        Given an event, returns a list of runnables (or a list of runnable names) to execute on it. It can also return
        None, in which case all runnables are executed on the event, which is also the default.

        :param event: Event object
        """
        return None

    def _shutdown_executors(self):
        """Shut down the process/thread pools owned by this step, if any, and forget them."""
        executors = self._executors
        self._executors = {}
        for executor in executors.values():
            executor.shutdown()

    def _init(self):
        """
        Index the runnables by name, call init() on each of them, and create the executors they need.

        :raises ValueError: If two runnables share a name, or a runnable has an unsupported execution mechanism.
        """
        super()._init()
        # Start from a clean slate so that calling _init() more than once neither trips the uniqueness check below
        # nor leaks the pools created by a previous call
        self._shutdown_executors()
        self._runnable_by_name = {}

        num_processes = 0
        num_threads = 0
        for runnable in self.runnables:
            if runnable.name in self._runnable_by_name:
                raise ValueError(f"ParallelExecutionRunnable name '{runnable.name}' is not unique")
            self._runnable_by_name[runnable.name] = runnable
            runnable.init()
            if runnable.execution_mechanism == "multiprocessing":
                num_processes += 1
            elif runnable.execution_mechanism == "threading":
                num_threads += 1
            elif runnable.execution_mechanism not in ("asyncio", "naive"):
                raise ValueError(f"Unsupported execution mechanism: {runnable.execution_mechanism}")

        # enforce max
        num_processes = min(num_processes, self.max_processes)
        num_threads = min(num_threads, self.max_threads)

        # A pool is only created for a mechanism that at least one runnable actually uses
        if num_processes:
            mp_context = multiprocessing.get_context("spawn")
            self._executors["multiprocessing"] = ProcessPoolExecutor(max_workers=num_processes, mp_context=mp_context)
        if num_threads:
            self._executors["threading"] = ThreadPoolExecutor(max_workers=num_threads)

    async def _do(self, event):
        """
        Run the selected runnables on the event, gather their results into the event body, and pass it downstream.

        :param event: Event object, or the termination object
        :raises ValueError: If select_runnables() selects the same runnable more than once.
        :raises KeyError: If select_runnables() returns a name that does not match any runnable.
        """
        if event is _termination_obj:
            try:
                return await self._do_downstream(_termination_obj)
            finally:
                # NOTE(review): assumes termination is the last thing this step receives in a run, so the pools
                # can be released here; _init() creates fresh ones if the step is initialized again
                self._shutdown_executors()

        runnables = self.select_runnables(event)
        if runnables is None:
            runnables = self.runnables

        loop = asyncio.get_running_loop()
        futures = []
        runnables_encountered = set()
        for runnable in runnables:
            if isinstance(runnable, str):
                runnable = self._runnable_by_name[runnable]
            if id(runnable) in runnables_encountered:
                raise ValueError(f"select_runnables() returned more than one runnable named '{runnable.name}'")
            runnables_encountered.add(id(runnable))

            mechanism = runnable.execution_mechanism
            # Each in-process runnable gets its own copy of the body. The copy is skipped for multiprocessing,
            # presumably because the body is already serialized on its way to the worker process.
            runnable_input = event.body if mechanism == "multiprocessing" else copy.deepcopy(event.body)

            if mechanism == "asyncio":
                future = loop.create_task(runnable._async_run(runnable_input, event.path))
            elif mechanism == "naive":
                # Runs synchronously right here; the already-completed future lets it be gathered with the rest
                future = loop.create_future()
                future.set_result(runnable._run(runnable_input, event.path))
            else:
                executor = self._executors[mechanism]
                future = loop.run_in_executor(
                    executor,
                    runnable._run,
                    runnable_input,
                    event.path,
                )
            futures.append(future)

        results: list[_ParallelExecutionRunnableResult] = await asyncio.gather(*futures)
        event.body = {"input": event.body, "results": {}}
        for result in results:
            event.body["results"][result.runnable_name] = {"runtime": result.runtime, "output": result.data}
        return await self._do_downstream(event)
0 commit comments