Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Always use barrier #2

Merged
merged 6 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ pip install dramatiq-workflow
Then, add the `dramatiq-workflow` middleware to your dramatiq broker:

```python
from dramatiq.rate_limits.backends import RedisBackend
from dramatiq.rate_limits.backends import RedisBackend # or MemcachedBackend
from dramatiq_workflow import WorkflowMiddleware

backend = RedisBackend()
backend = RedisBackend() # or MemcachedBackend()
broker.add_middleware(WorkflowMiddleware(backend))
```

Expand Down Expand Up @@ -153,6 +153,36 @@ workflow = Workflow(
In this example, Task 2 will run roughly 1 second after Task 1 finishes, and
Task 3 and will run 2 seconds after Task 2 finishes.

### Barrier

`dramatiq-workflow` uses a barrier mechanism to keep track of the current state
of a workflow. For example, every time a task in a `Group` is completed, the
barrier is decreased by one. When the barrier reaches zero, the next task in
the outer `Chain` is scheduled to run.

By default, `dramatiq-workflow` uses a custom `AtMostOnceBarrier` that ensures
the barrier is never released more than once. When the barrier reaches zero, an
additional key is set in the backend to prevent releasing the barrier again. In
almost all cases, this is the desired behavior since releasing a barrier more
than once could lead to duplicate tasks being scheduled - which would have
severe compounding effects in a workflow with many `Group` tasks.

However, there is a small chance that the barrier is never released. This can
happen when the configured `rate_limiter_backend` loses its state or when the
worker unexpectedly crashes before scheduling the next task in the workflow.

To configure a different barrier implementation such as dramatiq's default
`Barrier`, you can pass it to the `WorkflowMiddleware`:

```python
from dramatiq.rate_limits import Barrier
from dramatiq.rate_limits.backends import RedisBackend
from dramatiq_workflow import WorkflowMiddleware

backend = RedisBackend()
broker.add_middleware(WorkflowMiddleware(backend, barrier=Barrier))
```

## License

This project is licensed under the MIT License. See the [LICENSE](LICENSE) file
Expand Down
24 changes: 14 additions & 10 deletions dramatiq_workflow/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ._constants import CALLBACK_BARRIER_TTL, OPTION_KEY_CALLBACKS
from ._helpers import workflow_with_completion_callbacks
from ._middleware import WorkflowMiddleware, workflow_noop
from ._models import Barrier, Chain, CompletionCallbacks, Group, Message, WithDelay, WorkflowType
from ._models import Chain, CompletionCallbacks, Group, Message, WithDelay, WorkflowType
from ._serialize import serialize_callbacks, serialize_workflow

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -156,27 +156,31 @@ def __augment_message(self, message: Message, completion_callbacks: CompletionCa
)

@property
def __rate_limiter_backend(self):
if not hasattr(self, "__cached_rate_limiter_backend"):
def __middleware(self) -> WorkflowMiddleware:
if not hasattr(self, "__cached_middleware"):
for middleware in self.broker.middleware:
if isinstance(middleware, WorkflowMiddleware):
self.__cached_rate_limiter_backend = middleware.rate_limiter_backend
self.__cached_middleware = middleware
break
else:
raise RuntimeError(
"WorkflowMiddleware middleware not found! Did you forget "
"to set it up? It is required if you want to use "
"workflows."
)
return self.__cached_rate_limiter_backend
return self.__cached_middleware

def __create_barrier(self, count: int):
if count == 1:
# No need to create a distributed barrier if there is only one task
return None
@property
def __rate_limiter_backend(self) -> dramatiq.rate_limits.RateLimiterBackend:
return self.__middleware.rate_limiter_backend

@property
def __barrier(self) -> type[dramatiq.rate_limits.Barrier]:
return self.__middleware.barrier

def __create_barrier(self, count: int):
completion_uuid = str(uuid4())
completion_barrier = Barrier(self.__rate_limiter_backend, completion_uuid, ttl=CALLBACK_BARRIER_TTL)
completion_barrier = self.__barrier(self.__rate_limiter_backend, completion_uuid, ttl=CALLBACK_BARRIER_TTL)
completion_barrier.create(count)
logger.debug("Barrier created: %s (%d tasks)", completion_uuid, count)
return completion_uuid
Expand Down
11 changes: 8 additions & 3 deletions dramatiq_workflow/_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,22 @@
import dramatiq
import dramatiq.rate_limits

from ._barrier import AtMostOnceBarrier
from ._constants import OPTION_KEY_CALLBACKS
from ._helpers import workflow_with_completion_callbacks
from ._models import Barrier
from ._serialize import unserialize_callbacks, unserialize_workflow

logger = logging.getLogger(__name__)


class WorkflowMiddleware(dramatiq.Middleware):
def __init__(self, rate_limiter_backend: dramatiq.rate_limits.RateLimiterBackend):
def __init__(
self,
rate_limiter_backend: dramatiq.rate_limits.RateLimiterBackend,
barrier: type[dramatiq.rate_limits.Barrier] = AtMostOnceBarrier,
):
self.rate_limiter_backend = rate_limiter_backend
self.barrier = barrier

def after_process_boot(self, broker: dramatiq.Broker):
broker.declare_actor(workflow_noop)
Expand All @@ -36,7 +41,7 @@ def after_process_message(
while len(completion_callbacks) > 0:
completion_id, remaining_workflow, propagate = completion_callbacks[-1]
if completion_id is not None:
barrier = Barrier(self.rate_limiter_backend, completion_id)
barrier = self.barrier(self.rate_limiter_backend, completion_id)
if not barrier.wait(block=False):
logger.debug("Barrier not completed: %s", completion_id)
break
Expand Down
3 changes: 0 additions & 3 deletions dramatiq_workflow/_models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import dramatiq
import dramatiq.rate_limits

from ._barrier import AtMostOnceBarrier


class Chain:
def __init__(self, *tasks: "WorkflowType"):
Expand Down Expand Up @@ -38,7 +36,6 @@ def __eq__(self, other):
return isinstance(other, WithDelay) and self.task == other.task and self.delay == other.delay


Barrier = AtMostOnceBarrier
Message = dramatiq.Message
WorkflowType = Message | Chain | Group | WithDelay

Expand Down
23 changes: 16 additions & 7 deletions dramatiq_workflow/tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,24 @@
from unittest import mock

import dramatiq
import dramatiq.rate_limits

from .. import Chain, Group, WithDelay, Workflow, WorkflowMiddleware
from .._serialize import serialize_workflow, unserialize_workflow


class WorkflowTests(unittest.TestCase):
def setUp(self):
self.broker = mock.MagicMock(middleware=[WorkflowMiddleware(mock.MagicMock())])
self.rate_limiter_backend = mock.create_autospec(dramatiq.rate_limits.RateLimiterBackend, instance=True)
self.barrier = mock.create_autospec(dramatiq.rate_limits.Barrier)
self.broker = mock.MagicMock(
middleware=[
WorkflowMiddleware(
rate_limiter_backend=self.rate_limiter_backend,
barrier=self.barrier,
)
]
)
self.task = mock.MagicMock()
self.task.message.side_effect = lambda *args, **kwargs: self.__make_message(
self.__generate_id(), *args, **kwargs
Expand Down Expand Up @@ -70,7 +80,7 @@ def test_simple_workflow(self, time_mock):
message_options={
"workflow_completion_callbacks": [
(
None,
mock.ANY, # Accept any callback ID
{
"__type__": "chain",
"children": [
Expand Down Expand Up @@ -207,7 +217,7 @@ def test_chain_with_delay(self, time_mock):
message_options={
"workflow_completion_callbacks": [
(
None,
mock.ANY, # Accept any callback ID
{
"__type__": "chain",
"children": [
Expand All @@ -230,10 +240,10 @@ def test_chain_with_delay(self, time_mock):
),
delay=10,
)
self.barrier.assert_called_once_with(self.rate_limiter_backend, mock.ANY, ttl=mock.ANY)

@mock.patch("dramatiq_workflow._base.time.time")
@mock.patch("dramatiq_workflow._base.Barrier")
def test_group_with_delay(self, barrier_mock, time_mock):
def test_group_with_delay(self, time_mock):
time_mock.return_value = 1717526000.12
updated_timestamp = time_mock.return_value * 1000
workflow = Workflow(
Expand Down Expand Up @@ -301,8 +311,7 @@ def test_serialize_unserialize(self):
self.assertEqual(workflow.workflow, unserialized)

@mock.patch("dramatiq_workflow._base.time.time")
@mock.patch("dramatiq_workflow._base.Barrier")
def test_additive_delays(self, barrier_mock, time_mock):
def test_additive_delays(self, time_mock):
time_mock.return_value = 1717526000.12
updated_timestamp = time_mock.return_value * 1000
workflow = Workflow(
Expand Down
Loading