Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Always use barrier #2

Merged
merged 6 commits into from
Dec 20, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Allow using custom barrier implementation
  • Loading branch information
pencil committed Dec 20, 2024
commit cd0c06ff28b59cb6cc338118a7240b7ec7918244
21 changes: 15 additions & 6 deletions dramatiq_workflow/_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import time
import typing
from uuid import uuid4

import dramatiq
@@ -8,7 +9,7 @@
from ._constants import CALLBACK_BARRIER_TTL, OPTION_KEY_CALLBACKS
from ._helpers import workflow_with_completion_callbacks
from ._middleware import WorkflowMiddleware, workflow_noop
from ._models import Barrier, Chain, CompletionCallbacks, Group, Message, WithDelay, WorkflowType
from ._models import Chain, CompletionCallbacks, Group, Message, WithDelay, WorkflowType
from ._serialize import serialize_callbacks, serialize_workflow

logger = logging.getLogger(__name__)
@@ -156,27 +157,35 @@ def __augment_message(self, message: Message, completion_callbacks: CompletionCa
)

@property
def __rate_limiter_backend(self):
if not hasattr(self, "__cached_rate_limiter_backend"):
def __middleware(self) -> WorkflowMiddleware:
if not hasattr(self, "__cached_middleware"):
for middleware in self.broker.middleware:
if isinstance(middleware, WorkflowMiddleware):
self.__cached_rate_limiter_backend = middleware.rate_limiter_backend
self.__cached_middleware = middleware
break
else:
raise RuntimeError(
"WorkflowMiddleware middleware not found! Did you forget "
"to set it up? It is required if you want to use "
"workflows."
)
return self.__cached_rate_limiter_backend
return self.__cached_middleware

@property
def __rate_limiter_backend(self) -> dramatiq.rate_limits.RateLimiterBackend:
return self.__middleware.rate_limiter_backend

@property
def __barrier(self) -> typing.Type[dramatiq.rate_limits.Barrier]:
return self.__middleware.barrier

def __create_barrier(self, count: int):
if count == 1:
# No need to create a distributed barrier if there is only one task
return None

completion_uuid = str(uuid4())
completion_barrier = Barrier(self.__rate_limiter_backend, completion_uuid, ttl=CALLBACK_BARRIER_TTL)
completion_barrier = self.__barrier(self.__rate_limiter_backend, completion_uuid, ttl=CALLBACK_BARRIER_TTL)
completion_barrier.create(count)
logger.debug("Barrier created: %s (%d tasks)", completion_uuid, count)
return completion_uuid
12 changes: 9 additions & 3 deletions dramatiq_workflow/_middleware.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
import logging
import typing

import dramatiq
import dramatiq.rate_limits

from ._barrier import AtMostOnceBarrier
from ._constants import OPTION_KEY_CALLBACKS
from ._helpers import workflow_with_completion_callbacks
from ._models import Barrier
from ._serialize import unserialize_callbacks, unserialize_workflow

logger = logging.getLogger(__name__)


class WorkflowMiddleware(dramatiq.Middleware):
def __init__(self, rate_limiter_backend: dramatiq.rate_limits.RateLimiterBackend):
def __init__(
self,
rate_limiter_backend: dramatiq.rate_limits.RateLimiterBackend,
barrier: typing.Type[dramatiq.rate_limits.Barrier] = AtMostOnceBarrier,
):
self.rate_limiter_backend = rate_limiter_backend
self.barrier = barrier

def after_process_boot(self, broker: dramatiq.Broker):
broker.declare_actor(workflow_noop)
@@ -36,7 +42,7 @@ def after_process_message(
while len(completion_callbacks) > 0:
completion_id, remaining_workflow, propagate = completion_callbacks[-1]
if completion_id is not None:
barrier = Barrier(self.rate_limiter_backend, completion_id)
barrier = self.barrier(self.rate_limiter_backend, completion_id)
if not barrier.wait(block=False):
logger.debug("Barrier not completed: %s", completion_id)
break
3 changes: 0 additions & 3 deletions dramatiq_workflow/_models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import dramatiq
import dramatiq.rate_limits

from ._barrier import AtMostOnceBarrier


class Chain:
def __init__(self, *tasks: "WorkflowType"):
@@ -38,7 +36,6 @@ def __eq__(self, other):
return isinstance(other, WithDelay) and self.task == other.task and self.delay == other.delay


Barrier = AtMostOnceBarrier
Message = dramatiq.Message
WorkflowType = Message | Chain | Group | WithDelay

19 changes: 14 additions & 5 deletions dramatiq_workflow/tests/test_workflow.py
Original file line number Diff line number Diff line change
@@ -2,14 +2,24 @@
from unittest import mock

import dramatiq
import dramatiq.rate_limits

from .. import Chain, Group, WithDelay, Workflow, WorkflowMiddleware
from .._serialize import serialize_workflow, unserialize_workflow


class WorkflowTests(unittest.TestCase):
def setUp(self):
self.broker = mock.MagicMock(middleware=[WorkflowMiddleware(mock.MagicMock())])
self.rate_limiter_backend = mock.create_autospec(dramatiq.rate_limits.RateLimiterBackend, instance=True)
self.barrier = mock.create_autospec(dramatiq.rate_limits.Barrier)
self.broker = mock.MagicMock(
middleware=[
WorkflowMiddleware(
rate_limiter_backend=self.rate_limiter_backend,
barrier=self.barrier,
)
]
)
self.task = mock.MagicMock()
self.task.message.side_effect = lambda *args, **kwargs: self.__make_message(
self.__generate_id(), *args, **kwargs
@@ -230,10 +240,10 @@ def test_chain_with_delay(self, time_mock):
),
delay=10,
)
self.barrier.assert_called_once_with(self.rate_limiter_backend, mock.ANY, ttl=mock.ANY)

@mock.patch("dramatiq_workflow._base.time.time")
@mock.patch("dramatiq_workflow._base.Barrier")
def test_group_with_delay(self, barrier_mock, time_mock):
def test_group_with_delay(self, time_mock):
time_mock.return_value = 1717526000.12
updated_timestamp = time_mock.return_value * 1000
workflow = Workflow(
@@ -301,8 +311,7 @@ def test_serialize_unserialize(self):
self.assertEqual(workflow.workflow, unserialized)

@mock.patch("dramatiq_workflow._base.time.time")
@mock.patch("dramatiq_workflow._base.Barrier")
def test_additive_delays(self, barrier_mock, time_mock):
def test_additive_delays(self, time_mock):
time_mock.return_value = 1717526000.12
updated_timestamp = time_mock.return_value * 1000
workflow = Workflow(