diff --git a/pyproject.toml b/pyproject.toml index 44b97956..fbc3c038 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,10 @@ asyncio_mode = "auto" log_cli = true log_cli_level = "INFO" log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)" +filterwarnings = [ + "ignore::DeprecationWarning:google\\..*", + "ignore::DeprecationWarning:importlib\\..*" +] [tool.isort] profile = "black" diff --git a/tests/update/serialized_handling_of_n_messages.py b/tests/update/serialized_handling_of_n_messages.py new file mode 100644 index 00000000..5c78af37 --- /dev/null +++ b/tests/update/serialized_handling_of_n_messages.py @@ -0,0 +1,95 @@ +import asyncio +import logging +import uuid +from dataclasses import dataclass +from unittest.mock import patch + +import temporalio.api.common.v1 +import temporalio.api.enums.v1 +import temporalio.api.update.v1 +import temporalio.api.workflowservice.v1 +from temporalio.client import Client, WorkflowHandle +from temporalio.worker import Worker +from temporalio.workflow import UpdateMethodMultiParam + +from update.serialized_handling_of_n_messages import ( + MessageProcessor, + Result, + get_current_time, +) + + +async def test_continue_as_new_doesnt_lose_updates(client: Client): + with patch( + "temporalio.workflow.Info.is_continue_as_new_suggested", return_value=True + ): + tq = str(uuid.uuid4()) + wf = await client.start_workflow( + MessageProcessor.run, id=str(uuid.uuid4()), task_queue=tq + ) + update_requests = [ + UpdateRequest(wf, MessageProcessor.process_message, i) for i in range(10) + ] + for req in update_requests: + await req.wait_until_admitted() + + async with Worker( + client, + task_queue=tq, + workflows=[MessageProcessor], + activities=[get_current_time], + ): + for req in update_requests: + update_result = await req.task + assert update_result.startswith(req.expected_result_prefix()) + + +@dataclass +class UpdateRequest: + wf_handle: WorkflowHandle + update: UpdateMethodMultiParam + sequence_number: int + + def __post_init__(self): + self.task = asyncio.Task[Result]( + self.wf_handle.execute_update(self.update, args=[self.arg], id=self.id) + ) + + async def wait_until_admitted(self): + while True: + try: + return await self._poll_update_non_blocking() + except Exception as err: + logging.warning(err) + + async def _poll_update_non_blocking(self): + req = temporalio.api.workflowservice.v1.PollWorkflowExecutionUpdateRequest( + namespace=self.wf_handle._client.namespace, + update_ref=temporalio.api.update.v1.UpdateRef( + workflow_execution=temporalio.api.common.v1.WorkflowExecution( + workflow_id=self.wf_handle.id, + run_id="", + ), + update_id=self.id, + ), + identity=self.wf_handle._client.identity, + ) + res = await self.wf_handle._client.workflow_service.poll_workflow_execution_update( + req + ) + # TODO: @cretz how do we work with these raw proto objects? + assert "stage: UPDATE_WORKFLOW_EXECUTION_LIFECYCLE_STAGE_ADMITTED" in str(res) + + @property + def arg(self) -> str: + return str(self.sequence_number) + + @property + def id(self) -> str: + return str(self.sequence_number) + + def expected_result_prefix(self) -> str: + # TODO: Currently the server does not send updates to the worker in order of admission When + # this is fixed (https://github.com/temporalio/temporal/pull/5831), we can make a stronger + # assertion about the activity numbers used to construct each result. 
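+        # Once ordering is fixed, a stronger check along these lines should hold (sketch
+        # only; it relies on strictly serial, admission-order processing and on the
+        # two-activity numbering scheme in execute_processing_task):
+        #     assert update_result == f"{self.arg}-result-{2 * self.sequence_number + 1}-{2 * self.sequence_number + 2}"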
+ return f"{self.arg}-result" diff --git a/update/atomic_message_handlers_with_stateful_workflow.py b/update/atomic_message_handlers_with_stateful_workflow.py new file mode 100644 index 00000000..9fff00a9 --- /dev/null +++ b/update/atomic_message_handlers_with_stateful_workflow.py @@ -0,0 +1,222 @@ +import asyncio +from datetime import timedelta +import logging +from typing import Dict, List, Optional + +from temporalio import activity, common, workflow +from temporalio.client import Client, WorkflowHandle +from temporalio.worker import Worker + +@activity.defn +async def allocate_nodes_to_job(nodes: List[int], job_name: str): + print(f"Assigning nodes {nodes} to job {job_name}") + await asyncio.sleep(0.1) + +@activity.defn +async def deallocate_nodes_for_job(nodes: List[int], job_name: str): + print(f"Deallocating nodes {nodes} from job {job_name}") + await asyncio.sleep(0.1) + +@activity.defn +async def find_bad_nodes(nodes: List[int]) -> List[int]: + await asyncio.sleep(0.1) + bad_nodes = [n for n in nodes if n % 5 == 0] + print(f"Found bad nodes: {bad_nodes}") + return bad_nodes + +# This samples shows off +# - Making signal and update handlers only operate when the workflow is within a certain state +# (here between cluster_started and cluster_shutdown) +# - Using a lock to protect shared state shared by the workflow and its signal and update handlers +# interleaving writes +# - Running start_workflow with an initializer signal that you want to run before anything else. +@workflow.defn +class ClusterManager: + """ + A workflow to manage a cluster of compute nodes. + + The cluster is transitioned between operational and non-operational states by two signals: + `start_cluster` and `shutdown_cluster`. + + While it is active, the workflow maintains a mapping of nodes to assigned job, and exposes the + following API (implemented as updates): + + - allocate_n_nodes_to_job: attempt to find n free nodes, assign them to the job; return assigned node IDs + - delete_job: unassign any nodes assigned to job; return a success acknowledgement + - resize_job: assign or unassign nodes as needed; return assigned node IDs + + An API call made while the cluster is non-operational will block until the cluster is + operational. + + If an API call is made while another is in progress, it will block until all other thus-enqueued + requests are complete. + """ + + def __init__(self) -> None: + self.cluster_started = False + self.cluster_shutdown = False + self.nodes_lock = asyncio.Lock() + + @workflow.signal + async def start_cluster(self): + self.cluster_started = True + self.nodes : Dict[int, Optional[str]] = dict([(k, None) for k in range(25)]) + workflow.logger.info("Cluster started") + + @workflow.signal + async def shutdown_cluster(self): + await workflow.wait_condition(lambda: self.cluster_started) + self.cluster_shutdown = True + workflow.logger.info("Cluster shut down") + + @workflow.update + async def allocate_n_nodes_to_job(self, job_name: str, num_nodes: int, ) -> List[int]: + """ + Attempt to find n free nodes, assign them to the job, return assigned node IDs. 
+ """ + await workflow.wait_condition(lambda: self.cluster_started) + assert not self.cluster_shutdown + + await self.nodes_lock.acquire() + try: + unassigned_nodes = [k for k, v in self.nodes.items() if v is None] + if len(unassigned_nodes) < num_nodes: + raise ValueError(f"Cannot allocate {num_nodes} nodes; have only {len(unassigned_nodes)} available") + assigned_nodes = unassigned_nodes[:num_nodes] + await self._allocate_nodes_to_job(assigned_nodes, job_name) + return assigned_nodes + finally: + self.nodes_lock.release() + + + async def _allocate_nodes_to_job(self, assigned_nodes: List[int], job_name: str): + await workflow.execute_activity( + allocate_nodes_to_job, args=[assigned_nodes, job_name], start_to_close_timeout=timedelta(seconds=10) + ) + for node in assigned_nodes: + self.nodes[node] = job_name + + + @workflow.update + async def delete_job(self, job_name: str) -> str: + """ + Unassign any nodes assigned to job; return a success acknowledgement. + """ + await workflow.wait_condition(lambda: self.cluster_started) + assert not self.cluster_shutdown + await self.nodes_lock.acquire() + try: + nodes_to_free = [k for k, v in self.nodes.items() if v == job_name] + await self._deallocate_nodes_for_job(nodes_to_free, job_name) + return "Done" + finally: + self.nodes_lock.release() + + async def _deallocate_nodes_for_job(self, nodes_to_free: List[int], job_name: str): + await workflow.execute_activity( + deallocate_nodes_for_job, args=[nodes_to_free, job_name], start_to_close_timeout=timedelta(seconds=10) + ) + for node in nodes_to_free: + self.nodes[node] = None + + + @workflow.update + async def resize_job(self, job_name: str, new_size: int) -> List[int]: + """ + Assign or unassign nodes as needed; return assigned node IDs. + """ + await workflow.wait_condition(lambda: self.cluster_started) + assert not self.cluster_shutdown + await self.nodes_lock.acquire() + try: + allocated_nodes = [k for k, v in self.nodes.items() if v == job_name] + delta = new_size - len(allocated_nodes) + if delta == 0: + return allocated_nodes + elif delta > 0: + unassigned_nodes = [k for k, v in self.nodes.items() if v is None] + if len(unassigned_nodes) < delta: + raise ValueError(f"Cannot allocate {delta} nodes; have only {len(unassigned_nodes)} available") + nodes_to_assign = unassigned_nodes[:delta] + await self._allocate_nodes_to_job(nodes_to_assign, job_name) + return allocated_nodes + nodes_to_assign + else: + nodes_to_deallocate = allocated_nodes[delta:] + await self._deallocate_nodes_for_job(nodes_to_deallocate, job_name) + return list(filter(lambda x: x not in nodes_to_deallocate, allocated_nodes)) + finally: + self.nodes_lock.release() + + async def perform_health_checks(self): + await self.nodes_lock.acquire() + try: + assigned_nodes = [k for k, v in self.nodes.items() if v is not None] + bad_nodes = await workflow.execute_activity(find_bad_nodes, assigned_nodes, start_to_close_timeout=timedelta(seconds=10)) + for node in bad_nodes: + self.nodes[node] = "BAD!" 
+        finally:
+            self.nodes_lock.release()
+
+    @workflow.run
+    async def run(self):
+        await workflow.wait_condition(lambda: self.cluster_started)
+        # Run periodic health checks until the cluster is shut down. (The loop must exit on
+        # shutdown, otherwise the workflow would never complete.)
+        while not self.cluster_shutdown:
+            try:
+                await workflow.wait_condition(
+                    lambda: self.cluster_shutdown, timeout=timedelta(seconds=1)
+                )
+            except asyncio.TimeoutError:
+                pass
+            await self.perform_health_checks()
+
+
+async def do_cluster_lifecycle(wf: WorkflowHandle):
+    allocation_updates = []
+    for i in range(6):
+        allocation_updates.append(
+            wf.execute_update(ClusterManager.allocate_n_nodes_to_job, args=[f"job-{i}", 2])
+        )
+    await asyncio.gather(*allocation_updates)
+
+    resize_updates = []
+    for i in range(6):
+        resize_updates.append(
+            wf.execute_update(ClusterManager.resize_job, args=[f"job-{i}", 4])
+        )
+    await asyncio.gather(*resize_updates)
+
+    deletion_updates = []
+    for i in range(6):
+        deletion_updates.append(wf.execute_update(ClusterManager.delete_job, f"job-{i}"))
+    await asyncio.gather(*deletion_updates)
+
+    await wf.signal(ClusterManager.shutdown_cluster)
+    print("Cluster shut down")
+
+
+async def main():
+    client = await Client.connect("localhost:7233")
+
+    async with Worker(
+        client,
+        task_queue="tq",
+        workflows=[ClusterManager],
+        activities=[allocate_nodes_to_job, deallocate_nodes_for_job, find_bad_nodes],
+    ):
+        wf = await client.start_workflow(
+            ClusterManager.run,
+            id="wid",
+            task_queue="tq",
+            id_reuse_policy=common.WorkflowIDReusePolicy.TERMINATE_IF_RUNNING,
+            start_signal="start_cluster",
+        )
+        await do_cluster_lifecycle(wf)
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+    asyncio.run(main())
\ No newline at end of file
diff --git a/update/job_runner_I1.py b/update/job_runner_I1.py
new file mode 100644
index 00000000..05eb4646
--- /dev/null
+++ b/update/job_runner_I1.py
@@ -0,0 +1,183 @@
+import asyncio
+from dataclasses import dataclass
+from datetime import datetime, timedelta
+import logging
+from typing import Optional
+
+from temporalio import common, workflow, activity
+from temporalio.client import Client, WorkflowHandle
+from temporalio.worker import Worker
+
+
+JobID = str
+
+
+@dataclass
+class Job:
+    id: JobID
+    depends_on: list[JobID]
+    after_time: Optional[int]
+    name: str
+    run: str
+    python_interpreter_version: Optional[str]
+
+
+@dataclass
+class JobOutput:
+    status: int
+    stdout: str
+    stderr: str
+
+
+@workflow.defn
+class JobRunner:
+    """
+    Jobs must be executed in the order dictated by the job dependency graph (see
+    `job.depends_on`) and not before `job.after_time`.
+    """
+
+    def __init__(self) -> None:
+        self._pending_tasks = 0
+        self.completed_tasks = set[JobID]()
+        self.handler_mutex = asyncio.Lock()
+
+    def all_handlers_completed(self) -> bool:
+        # We are considering adding an API like `all_handlers_completed` to SDKs. We've added
+        # self._pending_tasks to this workflow in lieu of it being built into the SDKs.
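+        # (Why track this: `run` below only continues-as-new once this is true, so no
+        #  in-flight handler is abandoned by continue_as_new. A built-in along the lines
+        #  of `workflow.all_handlers_completed()` is the proposed shape; the name is
+        #  hypothetical here.)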
+        return not self._pending_tasks
+
+    @workflow.run
+    async def run(self):
+        await workflow.wait_condition(
+            lambda: (
+                workflow.info().is_continue_as_new_suggested()
+                and self.all_handlers_completed()
+            )
+        )
+        workflow.continue_as_new()
+
+    def ready_to_execute(self, job: Job) -> bool:
+        if not set(job.depends_on) <= self.completed_tasks:
+            return False
+        if after_time := job.after_time:
+            if float(after_time) > workflow.now().timestamp():
+                return False
+        return True
+
+    @workflow.update
+    async def run_shell_script_job(self, job: Job) -> JobOutput:
+        self._pending_tasks += 1
+        await workflow.wait_condition(lambda: self.ready_to_execute(job))
+        await self.handler_mutex.acquire()
+
+        try:
+            if security_errors := await workflow.execute_activity(
+                run_shell_script_security_linter,
+                args=[job.run],
+                start_to_close_timeout=timedelta(seconds=10),
+            ):
+                return JobOutput(status=1, stdout="", stderr=security_errors)
+            job_output = await workflow.execute_activity(
+                run_job, args=[job], start_to_close_timeout=timedelta(seconds=10)
+            )
+            return job_output
+        finally:
+            # FIXME: unbounded memory usage
+            self.completed_tasks.add(job.id)
+            self.handler_mutex.release()
+            self._pending_tasks -= 1
+
+    @workflow.update
+    async def run_python_job(self, job: Job) -> JobOutput:
+        self._pending_tasks += 1
+        await workflow.wait_condition(lambda: self.ready_to_execute(job))
+        await self.handler_mutex.acquire()
+
+        try:
+            if not await workflow.execute_activity(
+                check_python_interpreter_version,
+                args=[job.python_interpreter_version],
+                start_to_close_timeout=timedelta(seconds=10),
+            ):
+                return JobOutput(
+                    status=1,
+                    stdout="",
+                    stderr=f"Python interpreter version {job.python_interpreter_version} is not available",
+                )
+            job_output = await workflow.execute_activity(
+                run_job, args=[job], start_to_close_timeout=timedelta(seconds=10)
+            )
+            return job_output
+        finally:
+            # FIXME: unbounded memory usage
+            self.completed_tasks.add(job.id)
+            self.handler_mutex.release()
+            self._pending_tasks -= 1
+
+
+@activity.defn
+async def run_job(job: Job) -> JobOutput:
+    await asyncio.sleep(0.1)
+    stdout = f"Ran job {job.name} at {datetime.now()}"
+    print(stdout)
+    return JobOutput(status=0, stdout=stdout, stderr="")
+
+
+@activity.defn
+async def run_shell_script_security_linter(code: str) -> str:
+    # The user's organization requires that all shell scripts pass an in-house linter that checks
+    # for shell scripting constructions deemed insecure.
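+    # An actual linter check might look roughly like this (illustrative sketch only;
+    # the patterns are hypothetical):
+    #     findings = [p for p in ("rm -rf", "curl | sh", "eval ") if p in code]
+    #     if findings:
+    #         return f"insecure constructions found: {findings}"
+    # This sample simulates a clean pass.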
+    await asyncio.sleep(0.1)
+    return ""
+
+
+@activity.defn
+async def check_python_interpreter_version(version: str) -> bool:
+    await asyncio.sleep(0.1)
+    version_is_available = True
+    return version_is_available
+
+
+async def app(wf: WorkflowHandle):
+    job_1 = Job(
+        id="1",
+        depends_on=[],
+        after_time=None,
+        name="should-run-first",
+        run="echo 'Hello world 1!'",
+        python_interpreter_version=None,
+    )
+    job_2 = Job(
+        id="2",
+        depends_on=["1"],
+        after_time=None,
+        name="should-run-second",
+        run="print('Hello world 2!')",
+        python_interpreter_version=None,
+    )
+    await asyncio.gather(
+        wf.execute_update(JobRunner.run_python_job, job_2),
+        wf.execute_update(JobRunner.run_shell_script_job, job_1),
+    )
+
+
+async def main():
+    client = await Client.connect("localhost:7233")
+    async with Worker(
+        client,
+        task_queue="tq",
+        workflows=[JobRunner],
+        activities=[
+            run_job,
+            run_shell_script_security_linter,
+            check_python_interpreter_version,
+        ],
+    ):
+        wf = await client.start_workflow(
+            JobRunner.run,
+            id="wid",
+            task_queue="tq",
+            id_reuse_policy=common.WorkflowIDReusePolicy.TERMINATE_IF_RUNNING,
+        )
+        await app(wf)
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+    asyncio.run(main())
diff --git a/update/job_runner_I1_native.py b/update/job_runner_I1_native.py
new file mode 100644
index 00000000..72807f24
--- /dev/null
+++ b/update/job_runner_I1_native.py
@@ -0,0 +1,252 @@
+import asyncio
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from datetime import datetime, timedelta
+import inspect
+import logging
+from typing import Callable, Optional, Type
+
+from temporalio import common, workflow, activity
+from temporalio.client import Client, WorkflowHandle
+from temporalio.worker import Worker
+
+# This file contains a proposal for how the Python SDK could provide I1:WaitUntilReadyToExecute
+# functionality to help users defer processing, control interleaving of handler coroutines, and
+# ensure processing is complete before workflow completion.
+
+
+JobID = str
+
+
+@dataclass
+class Job:
+    id: JobID
+    depends_on: list[JobID]
+    after_time: Optional[int]
+    name: str
+    run: str
+    python_interpreter_version: Optional[str]
+
+
+@dataclass
+class JobOutput:
+    status: int
+    stdout: str
+    stderr: str
+
+
+##
+## SDK internals toy prototype
+##
+
+# TODO: use generics to satisfy serializable interface. Faking for now by using user-defined classes.
+I = Job
+O = JobOutput
+
+UpdateID = str
+Workflow = Type
+
+
+@dataclass
+class Update:
+    id: UpdateID
+    arg: I
+
+
+_sdk_internals_pending_tasks_count = 0
+_sdk_internals_handler_mutex = asyncio.Lock()
+
+
+def _sdk_internals_all_handlers_completed() -> bool:
+    # We are considering adding an API like `all_handlers_completed` to SDKs. Here, a module-level
+    # pending-tasks counter stands in for it until it is built into the SDKs.
+    return not _sdk_internals_pending_tasks_count
+
+
+@asynccontextmanager
+async def _sdk_internals__track_pending__wait_until_ready__synchronize(
+    execute_condition: Callable[[], bool]
+):
+    global _sdk_internals_pending_tasks_count
+    _sdk_internals_pending_tasks_count += 1
+    await workflow.wait_condition(execute_condition)
+    # TODO: honor max_concurrent, using a semaphore (if not None).
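+    # A sketch (max_concurrent is a hypothetical parameter): replace the mutex with
+    #     _sdk_internals_handler_semaphore = asyncio.Semaphore(max_concurrent)
+    # and guard the yield with `async with _sdk_internals_handler_semaphore:`.
+    # The mutex below is just the max_concurrent=1 special case.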
+ await _sdk_internals_handler_mutex.acquire() + try: + yield + finally: + _sdk_internals_handler_mutex.release() + _sdk_internals_pending_tasks_count -= 1 + + +class SDKInternals: + # Here, the SDK is wrapping the user's update handlers with the required wait-until-ready, + # pending tasks tracking, and synchronization functionality. This is a fake implementation: the + # real implementation will automatically inspect and wrap the user's declared update handlers. + + def ready_to_execute(self, update: Update) -> bool: + # Overridden by users who wish to control order of execution + return True + + @workflow.update + async def run_shell_script_job(self, arg: I) -> O: + handler = getattr(self, "_" + inspect.currentframe().f_code.co_name) + async with _sdk_internals__track_pending__wait_until_ready__synchronize( + lambda: self.ready_to_execute(Update(arg.id, arg)) + ): + return await handler(arg) + + @workflow.update + async def run_python_job(self, arg: I) -> O: + handler = getattr(self, "_" + inspect.currentframe().f_code.co_name) + async with _sdk_internals__track_pending__wait_until_ready__synchronize( + lambda: self.ready_to_execute(Update(arg.id, arg)) + ): + return await handler(arg) + + +# Monkey-patch proposed new public API +setattr(workflow, "all_handlers_completed", _sdk_internals_all_handlers_completed) +setattr(workflow, "Update", Update) +## +## END SDK internals prototype +## + + +@workflow.defn +class JobRunner(SDKInternals): + """ + Jobs must be executed in order dictated by job dependency graph (see `job.depends_on`) and + not before `job.after_time`. + """ + + def __init__(self) -> None: + self.completed_tasks = set[JobID]() + + @workflow.run + async def run(self): + await workflow.wait_condition( + lambda: ( + workflow.info().is_continue_as_new_suggested() + and workflow.all_handlers_completed() + ) + ) + workflow.continue_as_new() + + def ready_to_execute(self, update: workflow.Update) -> bool: + job = update.arg + if not set(job.depends_on) <= self.completed_tasks: + return False + if after_time := job.after_time: + if float(after_time) > workflow.now().timestamp(): + return False + return True + + # These are the real handler functions. When we implement SDK support, these will use the + # decorator form commented out below, and will not use an underscore prefix. 
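+    # (Dispatch note: the SDKInternals stubs resolve the real handler by name —
+    # inspect.currentframe().f_code.co_name is e.g. "run_shell_script_job", so
+    # getattr(self, "_run_shell_script_job") lands on the implementation below.)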
+ + # @workflow.update(max_concurrent=1, execute_condition=ready_to_execute) + async def _run_shell_script_job(self, job: Job) -> JobOutput: + if security_errors := await workflow.execute_activity( + run_shell_script_security_linter, + args=[job.run], + start_to_close_timeout=timedelta(seconds=10), + ): + return JobOutput(status=1, stdout="", stderr=security_errors) + job_output = await workflow.execute_activity( + run_job, args=[job], start_to_close_timeout=timedelta(seconds=10) + ) + # FIXME: unbounded memory usage + self.completed_tasks.add(job.id) + return job_output + + # @workflow.update(max_concurrent=1, execute_condition=ready_to_execute) + async def _run_python_job(self, job: Job) -> JobOutput: + if not await workflow.execute_activity( + check_python_interpreter_version, + args=[job.python_interpreter_version], + start_to_close_timeout=timedelta(seconds=10), + ): + return JobOutput( + status=1, + stdout="", + stderr=f"Python interpreter version {job.python_interpreter_version} is not available", + ) + job_output = await workflow.execute_activity( + run_job, args=[job], start_to_close_timeout=timedelta(seconds=10) + ) + # FIXME: unbounded memory usage + self.completed_tasks.add(job.id) + return job_output + + +@activity.defn +async def run_job(job: Job) -> JobOutput: + await asyncio.sleep(0.1) + stdout = f"Ran job {job.name} at {datetime.now()}" + print(stdout) + return JobOutput(status=0, stdout=stdout, stderr="") + + +@activity.defn +async def run_shell_script_security_linter(code: str) -> str: + # The user's organization requires that all shell scripts pass an in-house linter that checks + # for shell scripting constructions deemed insecure. + await asyncio.sleep(0.1) + return "" + + +@activity.defn +async def check_python_interpreter_version(version: str) -> bool: + await asyncio.sleep(0.1) + version_is_available = True + return version_is_available + + +async def app(wf: WorkflowHandle): + job_1 = Job( + id="1", + depends_on=[], + after_time=None, + name="should-run-first", + run="echo 'Hello world 1!'", + python_interpreter_version=None, + ) + job_2 = Job( + id="2", + depends_on=["1"], + after_time=None, + name="should-run-second", + run="print('Hello world 2!')", + python_interpreter_version=None, + ) + await asyncio.gather( + wf.execute_update(JobRunner.run_python_job, job_2), + wf.execute_update(JobRunner.run_shell_script_job, job_1), + ) + + +async def main(): + client = await Client.connect("localhost:7233") + async with Worker( + client, + task_queue="tq", + workflows=[JobRunner], + activities=[ + run_job, + run_shell_script_security_linter, + check_python_interpreter_version, + ], + ): + wf = await client.start_workflow( + JobRunner.run, + id="wid", + task_queue="tq", + id_reuse_policy=common.WorkflowIDReusePolicy.TERMINATE_IF_RUNNING, + ) + await app(wf) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + asyncio.run(main()) diff --git a/update/job_runner_I2.py b/update/job_runner_I2.py new file mode 100644 index 00000000..b33a68c8 --- /dev/null +++ b/update/job_runner_I2.py @@ -0,0 +1,232 @@ +import asyncio +from collections import OrderedDict +from dataclasses import dataclass +from datetime import datetime, timedelta +from enum import Enum +import logging +from typing import Awaitable, Callable, Optional + +from temporalio import common, workflow, activity +from temporalio.client import Client, WorkflowHandle +from temporalio.worker import Worker + + +JobID = str + + +class JobStatus(Enum): + BLOCKED = 1 + UNBLOCKED = 2 + + +@dataclass 
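+# (Note on the dataclass below: status_value is a plain int because the default
+# JSON-based dataclass converter round-trips ints cleanly, and the status property
+# pair re-exposes it as a JobStatus. Depending on SDK version, an IntEnum field may
+# also serialize directly — treat that as an assumption to verify.)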
+class Job:
+    id: JobID
+    depends_on: list[JobID]
+    after_time: Optional[int]
+    name: str
+    run: str
+    python_interpreter_version: Optional[str]
+    # TODO: How to handle enums in dataclasses with Temporal's ser/de.
+    status_value: int = JobStatus.BLOCKED.value
+
+    @property
+    def status(self):
+        return JobStatus(self.status_value)
+
+    @status.setter
+    def status(self, status: JobStatus):
+        self.status_value = status.value
+
+
+@dataclass
+class JobOutput:
+    status: int
+    stdout: str
+    stderr: str
+
+
+@dataclass
+class Task:
+    input: Job
+    handler: Callable[["JobRunner", Job], Awaitable[JobOutput]]
+    output: Optional[JobOutput] = None
+
+
+@workflow.defn
+class JobRunner:
+    """
+    Jobs must be executed in the order dictated by the job dependency graph (see
+    `job.depends_on`) and not before `job.after_time`.
+    """
+
+    def __init__(self) -> None:
+        self.task_queue = OrderedDict[JobID, Task]()
+        self.completed_tasks = set[JobID]()
+
+    def all_handlers_completed(self):
+        # We are considering adding an API like `all_handlers_completed` to SDKs. In this
+        # particular case, the user doesn't actually need the new API, since they are forced to
+        # track pending tasks in their queue implementation.
+        return not self.task_queue
+
+    # Note some undesirable things:
+    # 1. The update handler functions have become generic enqueuers; the "real" handler functions
+    #    are some other methods that don't have the @workflow.update decorator.
+    # 2. The update handler functions have to store a reference to the real handler in the queue.
+    # 3. The workflow `run` method is *much* more complicated and bug-prone here, compared to
+    #    I1:WaitUntilReadyToExecuteHandler.
+
+    @workflow.run
+    async def run(self):
+        """
+        Process all tasks in the queue serially, in the main workflow coroutine.
+        """
+        # Note: there are many mistakes a user can make while implementing this workflow.
+        while not (
+            workflow.info().is_continue_as_new_suggested()
+            and self.all_handlers_completed()
+        ):
+            await workflow.wait_condition(lambda: bool(self.task_queue))
+            for id, task in list(self.task_queue.items()):
+                job = task.input
+                if job.status == JobStatus.UNBLOCKED:
+                    # Setting task.output unblocks the update handler awaiting this task.
+                    task.output = await task.handler(self, job)
+                    del self.task_queue[id]
+                    self.completed_tasks.add(id)
+            for id, task in self.task_queue.items():
+                job = task.input
+                if job.status == JobStatus.BLOCKED and self.ready_to_execute(job):
+                    job.status = JobStatus.UNBLOCKED
+        workflow.continue_as_new()
+
+    def ready_to_execute(self, job: Job) -> bool:
+        if not set(job.depends_on) <= self.completed_tasks:
+            return False
+        if after_time := job.after_time:
+            if float(after_time) > workflow.now().timestamp():
+                return False
+        return True
+
+    async def _enqueue_job_and_wait_for_result(
+        self, job: Job, handler: Callable[["JobRunner", Job], Awaitable[JobOutput]]
+    ) -> JobOutput:
+        task = Task(job, handler)
+        self.task_queue[job.id] = task
+        await workflow.wait_condition(lambda: task.output is not None)
+        # Footgun: a user might well think that they can record task completion here, but in fact
+        # it deadlocks.
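+        # (Roughly why: once the queue is non-empty, the run loop's wait_condition is
+        #  immediately true, so the main coroutine can keep looping without ever yielding
+        #  to this handler coroutine; a completion recorded here would come too late to
+        #  unblock dependents.)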
+ # self.completed_tasks.add(job.id) + assert task.output + return task.output + + @workflow.update + async def run_shell_script_job(self, job: Job) -> JobOutput: + return await self._enqueue_job_and_wait_for_result( + job, JobRunner._actually_run_shell_script_job + ) + + async def _actually_run_shell_script_job(self, job: Job) -> JobOutput: + if security_errors := await workflow.execute_activity( + run_shell_script_security_linter, + args=[job.run], + start_to_close_timeout=timedelta(seconds=10), + ): + return JobOutput(status=1, stdout="", stderr=security_errors) + job_output = await workflow.execute_activity( + run_job, args=[job], start_to_close_timeout=timedelta(seconds=10) + ) + return job_output + + @workflow.update + async def run_python_job(self, job: Job) -> JobOutput: + return await self._enqueue_job_and_wait_for_result( + job, JobRunner._actually_run_python_job + ) + + async def _actually_run_python_job(self, job: Job) -> JobOutput: + if not await workflow.execute_activity( + check_python_interpreter_version, + args=[job.python_interpreter_version], + start_to_close_timeout=timedelta(seconds=10), + ): + return JobOutput( + status=1, + stdout="", + stderr=f"Python interpreter version {job.python_interpreter_version} is not available", + ) + job_output = await workflow.execute_activity( + run_job, args=[job], start_to_close_timeout=timedelta(seconds=10) + ) + return job_output + + +@activity.defn +async def run_job(job: Job) -> JobOutput: + await asyncio.sleep(0.1) + stdout = f"Ran job {job.name} at {datetime.now()}" + print(stdout) + return JobOutput(status=0, stdout=stdout, stderr="") + + +@activity.defn +async def run_shell_script_security_linter(code: str) -> str: + # The user's organization requires that all shell scripts pass an in-house linter that checks + # for shell scripting constructions deemed insecure. 
+ await asyncio.sleep(0.1) + return "" + + +@activity.defn +async def check_python_interpreter_version(version: str) -> bool: + await asyncio.sleep(0.1) + version_is_available = True + return version_is_available + + +async def app(wf: WorkflowHandle): + job_1 = Job( + id="1", + depends_on=[], + after_time=None, + name="should-run-first", + run="echo 'Hello world 1!'", + python_interpreter_version=None, + ) + job_2 = Job( + id="2", + depends_on=["1"], + after_time=None, + name="should-run-second", + run="print('Hello world 2!')", + python_interpreter_version=None, + ) + await asyncio.gather( + wf.execute_update(JobRunner.run_python_job, job_2), + wf.execute_update(JobRunner.run_shell_script_job, job_1), + ) + + +async def main(): + client = await Client.connect("localhost:7233") + async with Worker( + client, + task_queue="tq", + workflows=[JobRunner], + activities=[ + run_job, + run_shell_script_security_linter, + check_python_interpreter_version, + ], + ): + wf = await client.start_workflow( + JobRunner.run, + id="wid", + task_queue="tq", + id_reuse_policy=common.WorkflowIDReusePolicy.TERMINATE_IF_RUNNING, + ) + await app(wf) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + asyncio.run(main()) diff --git a/update/job_runner_I2_native.py b/update/job_runner_I2_native.py new file mode 100644 index 00000000..e7a439b4 --- /dev/null +++ b/update/job_runner_I2_native.py @@ -0,0 +1,294 @@ +import asyncio +from collections import OrderedDict +from dataclasses import dataclass +from datetime import datetime, timedelta +from enum import Enum +import inspect +import logging +from typing import Awaitable, Callable, Optional, Type + +from temporalio import common, workflow, activity +from temporalio.client import Client, WorkflowHandle +from temporalio.worker import Worker + +# This file contains a proposal for how the Python SDK could provide a native update queue +# (I2:PushToQueue) to help users defer processing, control interleaving of handler coroutines, and +# ensure processing is complete before workflow completion. + +## +## user code +## + + +JobID = str + + +class JobStatus(Enum): + BLOCKED = 1 + UNBLOCKED = 2 + + +@dataclass +class Job: + id: JobID + depends_on: list[JobID] + after_time: Optional[int] + name: str + run: str + python_interpreter_version: Optional[str] + # TODO: How to handle enums in dataclasses with Temporal's ser/de. + status_value: int = JobStatus.BLOCKED.value + + @property + def status(self): + return JobStatus(self.status_value) + + @status.setter + def status(self, status: JobStatus): + self.status_value = status.value + + +@dataclass +class JobOutput: + status: int + stdout: str + stderr: str + + +## +## SDK internals toy prototype +## + +# TODO: use generics to satisfy serializable interface. Faking for now by using user-defined classes. +I = Job +O = JobOutput + +UpdateID = str +Workflow = Type + + +@dataclass +class Update: + arg: I # real implementation will support multiple args + handler: Callable[[Workflow, I], Awaitable[O]] + output: Optional[O] = None + + @property + def id(self): + # In our real implementation the SDK will have native access to the update ID. Currently + # this example is assuming the user passes it in the update arg. 
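+        # (The SDK could expose the ID natively inside handlers — it has since gained
+        #  something like workflow.current_update_info() — but this prototype predates
+        #  that, so the ID rides along in the update argument.)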
+        return self.arg.id
+
+    async def handle(self, wf: Workflow) -> O:
+        # TODO: error-handling
+        # TODO: prevent handling an update twice
+        # Setting self.output unblocks the update handler awaiting this update.
+        self.output = await self.handler(wf, self.arg)
+        del workflow.update_queue[self.id]
+        return self.output
+
+
+async def _sdk_internals_enqueue_job_and_wait_for_result(
+    arg: I, handler: Callable[[Type, I], Awaitable[O]]
+) -> O:
+    update = Update(arg, handler)
+    workflow.update_queue[update.id] = update
+    await workflow.wait_condition(lambda: update.output is not None)
+    assert update.output
+    return update.output
+
+
+class SDKInternals:
+    # Here, the SDK is wrapping the user's update handlers with the required enqueue-and-wait
+    # functionality. This is a fake implementation: the real implementation will automatically
+    # inspect and wrap the user's declared update handlers.
+
+    @workflow.update
+    async def run_shell_script_job(self, arg: I) -> O:
+        handler = getattr(self.__class__, "_" + inspect.currentframe().f_code.co_name)
+        return await _sdk_internals_enqueue_job_and_wait_for_result(arg, handler)
+
+    @workflow.update
+    async def run_python_job(self, arg: I) -> O:
+        handler = getattr(self.__class__, "_" + inspect.currentframe().f_code.co_name)
+        return await _sdk_internals_enqueue_job_and_wait_for_result(arg, handler)
+
+
+# Monkey-patch proposed new public API
+setattr(workflow, "update_queue", OrderedDict[UpdateID, Update]())
+# The queue-processing style doesn't need an `all_handlers_completed` API: that condition is true
+# iff workflow.update_queue is empty.
+
+##
+## END SDK internals prototype
+##
+
+
+@workflow.defn
+class JobRunner(SDKInternals):
+    """
+    Jobs must be executed in the order dictated by the job dependency graph (see
+    `job.depends_on`) and not before `job.after_time`.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.completed_tasks = set[JobID]()
+
+    # Note some desirable things:
+    # 1. The update handler functions are now "real" handler functions.
+
+    # Note some undesirable things:
+    # 1. The workflow `run` method is still *much* more complicated and bug-prone here, compared
+    #    to I1:WaitUntilReadyToExecuteHandler.
+
+    @workflow.run
+    async def run(self):
+        """
+        Process all tasks in the queue serially, in the main workflow coroutine.
+        """
+        # Note: a user can make mistakes while implementing this workflow, due to the unblocking
+        # algorithm that this particular example requires when implemented via queue-processing
+        # in the main workflow coroutine. (This example is simpler to implement by making each
+        # handler invocation wait until it should execute, allowing the execution to take place
+        # in the handler coroutine with a mutex held; see job_runner_I1.py and
+        # job_runner_I1_native.py.)
+        while (
+            workflow.update_queue or not workflow.info().is_continue_as_new_suggested()
+        ):
+            await workflow.wait_condition(lambda: bool(workflow.update_queue))
+            for id, update in list(workflow.update_queue.items()):
+                job = update.arg
+                if job.status == JobStatus.UNBLOCKED:
+                    # This is how a user manually handles an update. Note that it takes a
+                    # reference to the workflow instance, since an update handler has access to
+                    # the workflow instance.
+                    await update.handle(self)
+
+                    # FIXME: unbounded memory usage; this example use-case needs to know which
+                    # updates have completed. Perhaps the real problem here lies with the example,
+                    # i.e. the example needs to be made more realistic.
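+                    # (One possible mitigation, sketch only: carry completed_tasks
+                    #  through continue_as_new as a workflow argument and prune IDs
+                    #  that no queued job still lists in depends_on.)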
+ self.completed_tasks.add(id) + for id, update in workflow.update_queue.items(): + job = update.arg + if job.status == JobStatus.BLOCKED and self.ready_to_execute(job): + job.status = JobStatus.UNBLOCKED + workflow.continue_as_new() + + def ready_to_execute(self, job: Job) -> bool: + if not set(job.depends_on) <= self.completed_tasks: + return False + if after_time := job.after_time: + if float(after_time) > workflow.now().timestamp(): + return False + return True + + # These are the real handler functions. When we implement SDK support, these will use the + # @workflow.update decorator and will not use an underscore prefix. + # TBD update decorator argument name: + # queue=True + # enqueue=True + # auto=False + # auto_handle=False + # manual=True + + # @workflow.update(queue=True) + async def _run_shell_script_job(self, job: Job) -> JobOutput: + if security_errors := await workflow.execute_activity( + run_shell_script_security_linter, + args=[job.run], + start_to_close_timeout=timedelta(seconds=10), + ): + return JobOutput(status=1, stdout="", stderr=security_errors) + job_output = await workflow.execute_activity( + run_job, args=[job], start_to_close_timeout=timedelta(seconds=10) + ) + return job_output + + # @workflow.update(queue=True) + async def _run_python_job(self, job: Job) -> JobOutput: + if not await workflow.execute_activity( + check_python_interpreter_version, + args=[job.python_interpreter_version], + start_to_close_timeout=timedelta(seconds=10), + ): + return JobOutput( + status=1, + stdout="", + stderr=f"Python interpreter version {job.python_interpreter_version} is not available", + ) + job_output = await workflow.execute_activity( + run_job, args=[job], start_to_close_timeout=timedelta(seconds=10) + ) + return job_output + + +@activity.defn +async def run_job(job: Job) -> JobOutput: + await asyncio.sleep(0.1) + stdout = f"Ran job {job.name} at {datetime.now()}" + print(stdout) + return JobOutput(status=0, stdout=stdout, stderr="") + + +@activity.defn +async def run_shell_script_security_linter(code: str) -> str: + # The user's organization requires that all shell scripts pass an in-house linter that checks + # for shell scripting constructions deemed insecure. 
+ await asyncio.sleep(0.1) + return "" + + +@activity.defn +async def check_python_interpreter_version(version: str) -> bool: + await asyncio.sleep(0.1) + version_is_available = True + return version_is_available + + +async def app(wf: WorkflowHandle): + job_1 = Job( + id="1", + depends_on=[], + after_time=None, + name="should-run-first", + run="echo 'Hello world 1!'", + python_interpreter_version=None, + ) + job_2 = Job( + id="2", + depends_on=["1"], + after_time=None, + name="should-run-second", + run="print('Hello world 2!')", + python_interpreter_version=None, + ) + await asyncio.gather( + wf.execute_update(JobRunner.run_python_job, job_2), + wf.execute_update(JobRunner.run_shell_script_job, job_1), + ) + + +async def main(): + client = await Client.connect("localhost:7233") + async with Worker( + client, + task_queue="tq", + workflows=[JobRunner], + activities=[ + run_job, + run_shell_script_security_linter, + check_python_interpreter_version, + ], + ): + wf = await client.start_workflow( + JobRunner.run, + id="wid", + task_queue="tq", + id_reuse_policy=common.WorkflowIDReusePolicy.TERMINATE_IF_RUNNING, + ) + await app(wf) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + asyncio.run(main()) diff --git a/update/job_runner_base.py b/update/job_runner_base.py new file mode 100644 index 00000000..47f0a513 --- /dev/null +++ b/update/job_runner_base.py @@ -0,0 +1,146 @@ +import asyncio +from dataclasses import dataclass +from datetime import datetime, timedelta +import logging +from typing import Optional + +from temporalio import common, workflow, activity +from temporalio.client import Client, WorkflowHandle +from temporalio.worker import Worker + + +JobID = str + + +@dataclass +class Job: + id: JobID + depends_on: list[JobID] + after_time: Optional[int] + name: str + run: str + python_interpreter_version: Optional[str] + + +@dataclass +class JobOutput: + status: int + stdout: str + stderr: str + + +@workflow.defn +class JobRunner: + """ + Jobs must be executed in order dictated by job dependency graph (see `job.depends_on`) and + not before `job.after_time`. 
+ """ + + @workflow.run + async def run(self): + await workflow.wait_condition( + lambda: workflow.info().is_continue_as_new_suggested() + ) + workflow.continue_as_new() + + @workflow.update + async def run_shell_script_job(self, job: Job) -> JobOutput: + if security_errors := await workflow.execute_activity( + run_shell_script_security_linter, + args=[job.run], + start_to_close_timeout=timedelta(seconds=10), + ): + return JobOutput(status=1, stdout="", stderr=security_errors) + job_output = await workflow.execute_activity( + run_job, args=[job], start_to_close_timeout=timedelta(seconds=10) + ) + return job_output + + @workflow.update + async def run_python_job(self, job: Job) -> JobOutput: + if not await workflow.execute_activity( + check_python_interpreter_version, + args=[job.python_interpreter_version], + start_to_close_timeout=timedelta(seconds=10), + ): + return JobOutput( + status=1, + stdout="", + stderr=f"Python interpreter version {job.python_interpreter_version} is not available", + ) + job_output = await workflow.execute_activity( + run_job, args=[job], start_to_close_timeout=timedelta(seconds=10) + ) + return job_output + + +@activity.defn +async def run_job(job: Job) -> JobOutput: + await asyncio.sleep(0.1) + stdout = f"Ran job {job.name} at {datetime.now()}" + print(stdout) + return JobOutput(status=0, stdout=stdout, stderr="") + + +@activity.defn +async def run_shell_script_security_linter(code: str) -> str: + # The user's organization requires that all shell scripts pass an in-house linter that checks + # for shell scripting constructions deemed insecure. + await asyncio.sleep(0.1) + return "" + + +@activity.defn +async def check_python_interpreter_version(version: str) -> bool: + await asyncio.sleep(0.1) + version_is_available = True + return version_is_available + + +async def app(wf: WorkflowHandle): + job_1 = Job( + id="1", + depends_on=[], + after_time=None, + name="should-run-first", + run="echo 'Hello world 1!'", + python_interpreter_version=None, + ) + job_2 = Job( + id="2", + depends_on=["1"], + after_time=None, + name="should-run-second", + run="print('Hello world 2!')", + python_interpreter_version=None, + ) + await asyncio.gather( + wf.execute_update(JobRunner.run_python_job, job_2), + wf.execute_update(JobRunner.run_shell_script_job, job_1), + ) + + +async def main(): + client = await Client.connect("localhost:7233") + async with Worker( + client, + task_queue="tq", + workflows=[JobRunner], + activities=[ + run_job, + run_shell_script_security_linter, + check_python_interpreter_version, + ], + ): + wf = await client.start_workflow( + JobRunner.run, + id="wid", + task_queue="tq", + id_reuse_policy=common.WorkflowIDReusePolicy.TERMINATE_IF_RUNNING, + ) + await app(wf) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + asyncio.run(main()) diff --git a/update/serialized_handling_of_n_messages.py b/update/serialized_handling_of_n_messages.py new file mode 100644 index 00000000..4c9a0d5c --- /dev/null +++ b/update/serialized_handling_of_n_messages.py @@ -0,0 +1,114 @@ +import asyncio +import logging +from asyncio import Future +from collections import deque +from datetime import timedelta + +from temporalio import activity, common, workflow +from temporalio.client import Client, WorkflowHandle +from temporalio.worker import Worker + +Arg = str +Result = str + +# Problem: +# ------- +# - Your workflow receives an unbounded number of updates. +# - Each update must be processed by calling two activities. 
+# - The next update may not start processing until the previous has completed. + +# Solution: +# -------- +# Enqueue updates, and process items from the queue in a single coroutine (the main workflow +# coroutine). + +# Discussion: +# ---------- +# The queue is used because Temporal's async update & signal handlers will interleave if they +# contain multiple yield points. An alternative would be to use standard async handler functions, +# with handling being done with an asyncio.Lock held. The queue approach would be necessary if we +# need to process in an order other than arrival. + + +@workflow.defn +class MessageProcessor: + + def __init__(self): + self.queue = deque[tuple[Arg, Future[Result]]]() + + @workflow.run + async def run(self): + while True: + await workflow.wait_condition(lambda: len(self.queue) > 0) + while self.queue: + arg, fut = self.queue.popleft() + fut.set_result(await self.execute_processing_task(arg)) + if workflow.info().is_continue_as_new_suggested(): + # Footgun: If we don't let the event loop tick, then CAN will end the workflow + # before the update handler is notified that the result future has completed. + # See https://github.com/temporalio/features/issues/481 + await asyncio.sleep(0) # Let update handler complete + print("CAN") + return workflow.continue_as_new() + + # Note: handler must be async if we are both enqueuing, and returning an update result + # => We could add SDK APIs to manually complete updates. + @workflow.update + async def process_message(self, arg: Arg) -> Result: + # Footgun: handler may need to wait for workflow initialization after CAN + # See https://github.com/temporalio/features/issues/400 + # await workflow.wait_condition(lambda: hasattr(self, "queue")) + fut = Future[Result]() + self.queue.append((arg, fut)) # Note: update validation gates enqueue + return await fut + + async def execute_processing_task(self, arg: Arg) -> Result: + # The purpose of the two activities and the result string format is to permit checks that + # the activities of different tasks do not interleave. + t1, t2 = [ + await workflow.execute_activity( + get_current_time, start_to_close_timeout=timedelta(seconds=10) + ) + for _ in range(2) + ] + return f"{arg}-result-{t1}-{t2}" + + +time = 0 + + +@activity.defn +async def get_current_time() -> int: + global time + time += 1 + return time + + +async def app(wf: WorkflowHandle): + for i in range(20): + print(f"app(): sending update {i}") + result = await wf.execute_update(MessageProcessor.process_message, f"arg {i}") + print(f"app(): {result}") + + +async def main(): + client = await Client.connect("localhost:7233") + + async with Worker( + client, + task_queue="tq", + workflows=[MessageProcessor], + activities=[get_current_time], + ): + wf = await client.start_workflow( + MessageProcessor.run, + id="wid", + task_queue="tq", + id_reuse_policy=common.WorkflowIDReusePolicy.TERMINATE_IF_RUNNING, + ) + await asyncio.gather(app(wf), wf.result()) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + asyncio.run(main())
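+# For reference, a sketch of the lock-based alternative mentioned in the Discussion
+# above (adequate when arrival-order processing is acceptable; the class name is
+# illustrative, and execute_processing_task is the method defined earlier):
+#
+#     @workflow.defn
+#     class LockingMessageProcessor:
+#         def __init__(self):
+#             self.lock = asyncio.Lock()
+#
+#         @workflow.update
+#         async def process_message(self, arg: Arg) -> Result:
+#             async with self.lock:
+#                 return await self.execute_processing_task(arg)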