Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions cadence/worker/_worker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import logging
import uuid
from typing import Unpack, cast

Expand All @@ -8,6 +9,8 @@
from cadence.worker._decision import DecisionWorker
from cadence.worker._types import WorkerOptions, _DEFAULT_WORKER_OPTIONS

logger = logging.getLogger(__name__)


class Worker:
def __init__(
Expand All @@ -17,6 +20,7 @@ def __init__(
registry: Registry,
**kwargs: Unpack[WorkerOptions],
) -> None:
self._tasks: list[asyncio.Task[None]] = []
self._client = client
self._task_list = task_list

Expand All @@ -35,11 +39,27 @@ def task_list(self) -> str:
return self._task_list

async def run(self) -> None:
async with asyncio.TaskGroup() as tg:
if not self._options["disable_workflow_worker"]:
tg.create_task(self._decision_worker.run())
if not self._options["disable_activity_worker"]:
tg.create_task(self._activity_worker.run())
if not self._options["disable_workflow_worker"]:
self._tasks.append(asyncio.create_task(self._decision_worker.run()))
if not self._options["disable_activity_worker"]:
self._tasks.append(asyncio.create_task(self._activity_worker.run()))

async def close(self) -> None:
for task in self._tasks:
task.cancel()
results = await asyncio.gather(*self._tasks, return_exceptions=True)
for result in results:
if isinstance(result, BaseException) and not isinstance(
result, asyncio.CancelledError
):
logger.error("Worker task failed", exc_info=result)

async def __aenter__(self) -> "Worker":
await self.run()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.close()


def _validate_and_copy_defaults(
Expand Down
14 changes: 4 additions & 10 deletions tests/cadence/worker/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,16 @@ async def poll(_, timeout=0.0):
type(client).domain = PropertyMock(return_value="domain")
type(client).identity = PropertyMock(return_value="identity")

worker = Worker(
async with Worker(
client,
"task_list",
Registry(),
activity_task_pollers=1,
decision_task_pollers=1,
identity="identity",
)

task = asyncio.create_task(worker.run())

# Wait until both polled
await both_waited.wait()
task.cancel()
with pytest.raises(asyncio.CancelledError):
await task
):
# Wait until both polled
await both_waited.wait()

worker_stub.PollForDecisionTask.assert_called_once_with(
PollForDecisionTaskRequest(
Expand Down
13 changes: 3 additions & 10 deletions tests/integration_tests/helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import pathlib
from contextlib import asynccontextmanager
import pathlib
from typing import AsyncGenerator, Unpack, cast

from google.protobuf.proto_json import serialize, parse
Expand All @@ -24,14 +23,8 @@ async def worker(
self, registry: Registry, **kwargs: Unpack[WorkerOptions]
) -> AsyncGenerator[Worker, None]:
async with self.client() as client:
worker = Worker(client, self.test_name, registry, **kwargs)
task = asyncio.create_task(worker.run())
yield worker
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
async with Worker(client, self.test_name, registry, **kwargs) as w:
yield w

def load_history(self, path: str) -> history.History:
file = pathlib.Path(self.fspath).with_name(path)
Expand Down
Loading