Skip to content
Merged
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,7 @@ run-server:
.PHONY: run-worker
run-worker:
celery -A chronos.worker worker --loglevel=info --autoscale 4,2 -E

.PHONY: run-dispatcher
run-dispatcher:
celery -A chronos.worker worker -Q dispatcher -c 1 --without-heartbeat --without-mingle --soft-time-limit=0 --time-limit=0
3 changes: 2 additions & 1 deletion Procfile
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
web: uvicorn chronos.main:app --host=0.0.0.0 --port=${PORT:-5000}
worker: celery -A chronos.worker worker --loglevel=error --concurrency 2 -E
worker: celery -A chronos.worker worker --loglevel=error --concurrency 2 -E
dispatcher: celery -A chronos.worker worker -Q dispatcher -c 1 --without-heartbeat --without-mingle --soft-time-limit=0 --time-limit=0
21 changes: 21 additions & 0 deletions chronos/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,26 @@ class Settings(BaseSettings):
dft_timezone: str = 'Europe/London'
tc2_shared_key: str = 'test-key'

# Round-robin dispatcher feature flag.
# Allows the dispatcher mode to be turned off gradually if we want to
# switch back to the previous mode of handling tasks.
use_round_robin: bool = True

# --- Round-robin dispatcher tuning ---

# Backpressure threshold: if the broker queue (`celery`) has this many or more
# pending tasks, the dispatcher pauses dispatching temporarily.
dispatcher_max_celery_queue: int = 100
# Maximum number of jobs dispatched in one round-robin cycle
# (a cycle normally runs every `dispatcher_cycle_delay_seconds`, default 10 ms).
dispatcher_batch_limit: int = 100
# Sleep duration between normal dispatcher cycles while there is work.
dispatcher_cycle_delay_seconds: float = 0.01
# Sleep duration when no active branches have queued jobs.
dispatcher_idle_delay_seconds: float = 1.0

# --- Webhook HTTP client tuning ---

# Per-request timeout (seconds) for outgoing webhook HTTP calls.
webhook_http_timeout_seconds: float = 8.0
# Connection-pool cap for the webhook HTTP client.
webhook_http_max_connections: int = 250

# Read local env file for local variables
model_config = SettingsConfigDict(env_file='.env', extra='allow')
Empty file added chronos/tasks/__init__.py
Empty file.
122 changes: 122 additions & 0 deletions chronos/tasks/dispatcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""
Round-robin job dispatcher for fair processing across branches.

The dispatcher runs as a continuous while loop on a dedicated worker,
cycling through branches with pending jobs and dispatching them to
Celery workers with backpressure control.

#### IMPORTANT ######
The dispatcher worker must be run separately with time limits
disabled via CLI flags (--soft-time-limit=0 --time-limit=0). Setting
soft_time_limit=0 on the task decorator alone is NOT sufficient because
billiard's pool.apply_async uses `or` (not `is None`) to check overrides,
so 0 is treated as falsy and falls back to the global default.

COMMAND TO RUN:

celery -A chronos.worker worker -Q dispatcher -c 1 \
--without-heartbeat --without-mingle \
--soft-time-limit=0 --time-limit=0
"""

import json
import logging
from bisect import bisect_right

from pydantic import ValidationError

from chronos.utils import settings

# Named logger (rather than __name__) so dispatcher-related modules can share
# the same 'chronos.dispatcher' logging channel.
dispatch_logger = logging.getLogger('chronos.dispatcher')


def dispatch_cycle(batch_limit: int | None = None) -> int:
    """
    Execute one round-robin dispatch cycle.

    Cycles through all branches with pending jobs, starting just after the
    stored cursor position, and dispatches at most one job per branch (and at
    most ``batch_limit`` jobs in total) to Celery workers.

    Args:
        batch_limit: Cap on jobs dispatched this cycle. Defaults to
            ``settings.dispatcher_batch_limit``, resolved at call time — the
            previous ``def``-line default captured the settings value once at
            import time, so runtime settings changes / test overrides were
            silently ignored.

    Returns:
        The number of jobs dispatched in this cycle.

    Uses bisect_right to find the correct start position even when the cursor
    points to a branch that is no longer active. This ensures fairness: if
    cursor=7 and active branches are [5, 10], we start from 10 (the next
    branch after 7 in sorted order), not 5.
    """
    if batch_limit is None:
        # Resolve lazily rather than in the signature: a default evaluated at
        # import time would pin whatever value settings had back then.
        batch_limit = settings.dispatcher_batch_limit

    # Local import avoids a circular import (chronos.worker imports this module).
    from chronos.worker import celery_app, job_queue

    # All branches that currently have pending jobs, sorted ascending.
    active_branches = job_queue.get_active_branches()
    if not active_branches:
        dispatch_logger.info('No active branches found by dispatcher')
        return 0

    # Resume just after the last serviced branch (round-robin cursor).
    start_index = 0
    cursor = job_queue.get_cursor()
    if cursor is not None:
        # bisect_right gives the insertion point in the sorted branch list,
        # i.e. the index of the first branch strictly greater than the cursor.
        start_index = bisect_right(active_branches, cursor) % len(active_branches)

    dispatched = 0

    # Rotate the active branch list so iteration starts at start_index.
    # For example [1, 2, 3, 4, 5] with start at 4 becomes [4, 5, 1, 2, 3].
    branches_to_process = active_branches[start_index:] + active_branches[:start_index]
    for branch_id in branches_to_process:
        if dispatched >= batch_limit:
            # Per-cycle limit reached; the cursor lets the next cycle resume
            # from where we stopped.
            break

        # Phase 1: peek and validate — poison payloads are acked/discarded.
        # Only catches deserialization errors (such as a corrupt JSON object)
        # which are permanent and can never succeed. Redis errors are not
        # caught here — they propagate to the outer while-loop which sleeps
        # and retries the whole cycle. This prevents a transient peek()
        # failure from discarding a valid job.
        try:
            payload = job_queue.peek(branch_id)
            if payload is None:
                continue

            task = celery_app.tasks.get(payload.task_name)
            if task is None:
                dispatch_logger.error('Unknown task %s for branch %d, skipping', payload.task_name, branch_id)
                job_queue.ack(branch_id)
                continue
        except (json.JSONDecodeError, ValidationError):
            dispatch_logger.exception('Poison payload for branch %d, discarding', branch_id)
            try:
                job_queue.ack(branch_id)
            except Exception:
                dispatch_logger.exception('Failed to ack poison job for branch %d', branch_id)
            continue

        # Phase 2: dispatch — failures here DON'T ack.
        # apply_async() can fail from broker errors or permanent serialization
        # errors. We can't distinguish them, so we leave the job in the queue
        # and skip this branch for now. Transient errors resolve next cycle;
        # permanent errors stall only this branch (other branches are
        # unaffected) and produce repeated log errors for operator attention.
        try:
            task.apply_async(kwargs=payload.kwargs)
        except Exception:
            dispatch_logger.exception(
                'Failed to dispatch %s for branch %d, will retry next cycle', payload.task_name, branch_id
            )
            continue

        # Phase 3: post-dispatch — the job was dispatched, ack it off the queue.
        # NOTE: an ack failure here propagates to the outer loop, so the job
        # may be dispatched again next cycle (at-least-once semantics).
        job_queue.ack(branch_id)
        try:
            job_queue.set_cursor(branch_id)
        except Exception:
            dispatch_logger.exception('Failed to update cursor after dispatching for branch %d', branch_id)
        dispatched += 1
        dispatch_logger.info('Dispatched %s for branch %d', payload.task_name, branch_id)

    return dispatched
160 changes: 160 additions & 0 deletions chronos/tasks/queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import json
from datetime import UTC, datetime
from typing import Any, Optional

import redis
from pydantic import BaseModel

from chronos.utils import settings

# Redis LIST key for a single branch's pending jobs; formatted with the branch id.
BRANCH_KEY_TEMPLATE = 'jobs:branch:{}'
# Redis SET of branch ids that currently have at least one queued job.
ACTIVE_BRANCHES_KEY = 'jobs:branches:active'
# Redis key holding the last branch id the dispatcher serviced (round-robin cursor).
CURSOR_KEY = 'jobs:dispatcher:cursor'


class JobPayload(BaseModel):
    """
    Serialised form of a job as stored in Redis (JSON via model_dump_json).
    """

    # Registered Celery task name to execute when the job is dispatched.
    task_name: str
    # Branch the job belongs to; used for per-branch queue routing.
    branch_id: int
    # Keyword arguments forwarded to the Celery task.
    kwargs: dict[str, Any]
    # Timestamp (UTC) recorded when the job was enqueued.
    enqueued_at: datetime


class JobQueue:
    """
    Redis LIST-backed queue with one list per branch.

    Key layout (constants defined at module level):
      - one LIST per branch holding JSON-serialised JobPayload, oldest first
      - a SET of branch ids that currently have pending jobs
      - a single cursor key tracking the dispatcher's round-robin position
    """

    # Lua script executed atomically by Redis: pop the head job of the branch
    # list (KEYS[1]) and, if that leaves the list empty, remove the branch id
    # (ARGV[1]) from the active-branches set (KEYS[2]). Doing both in one
    # script prevents a race between the LPOP and the SREM.
    _ACK_SCRIPT = """
    redis.call('LPOP', KEYS[1])
    if redis.call('LLEN', KEYS[1]) == 0 then
        redis.call('SREM', KEYS[2], ARGV[1])
    end
    return 1
    """

    # SHA of _ACK_SCRIPT once loaded into Redis; cached at class level so it
    # is loaded once and shared by all instances.
    _ack_script_sha: Optional[str] = None

    def __init__(self, redis_client: redis.Redis):
        # Injected Redis client; all queue state lives server-side in Redis.
        self.redis_client = redis_client

    def _get_queue_key(self, branch_id: int) -> str:
        """
        Return the Redis LIST key for the given branch's queue.
        """
        return BRANCH_KEY_TEMPLATE.format(branch_id)

    def enqueue(self, task_name: str, branch_id: int, **kwargs):
        """Add a job to a branch's queue.

        Args:
            task_name: Name of the Celery task to execute.
            branch_id: Branch ID for queue routing (not passed to the task).
            **kwargs: Arguments to pass to the task.
        """
        payload = JobPayload(
            task_name=task_name,
            branch_id=branch_id,
            kwargs=kwargs,
            enqueued_at=datetime.now(UTC),
        )
        queue_key = self._get_queue_key(branch_id)

        # Pipeline the list push and active-set add so both commands are sent
        # in a single round trip.
        pipeline = self.redis_client.pipeline()
        pipeline.rpush(queue_key, payload.model_dump_json())
        pipeline.sadd(ACTIVE_BRANCHES_KEY, str(branch_id))
        pipeline.execute()

    def peek(self, branch_id: int) -> JobPayload | None:
        """
        Non-destructively return the oldest JobPayload for a branch, or None
        if the branch's queue is empty. May raise json.JSONDecodeError or
        pydantic.ValidationError on a corrupt stored payload.
        """
        queue_key = self._get_queue_key(branch_id)
        data = self.redis_client.lindex(queue_key, 0)
        if data is None:
            return None
        return JobPayload(**json.loads(data))

    def _run_ack_script(self, queue_key: str, branch_id: int) -> None:
        """
        Run the atomic Lua ack script via its cached SHA, re-loading the
        script on NoScriptError (e.g. after a Redis restart or SCRIPT FLUSH).
        """
        try:
            self.redis_client.evalsha(
                JobQueue._ack_script_sha,
                2,
                queue_key,
                ACTIVE_BRANCHES_KEY,
                str(branch_id),
            )
        except redis.exceptions.NoScriptError:
            # Script cache was flushed server-side; load it again and retry once.
            JobQueue._ack_script_sha = self.redis_client.script_load(JobQueue._ACK_SCRIPT)
            self.redis_client.evalsha(
                JobQueue._ack_script_sha,
                2,
                queue_key,
                ACTIVE_BRANCHES_KEY,
                str(branch_id),
            )

    def ack(self, branch_id: int) -> None:
        """
        Acknowledge and remove the oldest job from a branch's queue. Also
        removes the branch from the active set if it has no jobs left.
        """
        queue_key = self._get_queue_key(branch_id)
        if JobQueue._ack_script_sha is None:
            # First use in this process: load the script and cache its SHA.
            JobQueue._ack_script_sha = self.redis_client.script_load(JobQueue._ACK_SCRIPT)
        self._run_ack_script(queue_key, branch_id)

    def get_active_branches(self) -> list[int]:
        """
        Get sorted list of branch IDs with pending jobs.
        """
        branch_ids = self.redis_client.smembers(ACTIVE_BRANCHES_KEY)
        return sorted(int(bid) for bid in branch_ids)

    def has_active_jobs(self) -> bool:
        # Cardinality check of the active-branches set; avoids fetching members.
        return self.redis_client.scard(ACTIVE_BRANCHES_KEY) > 0

    def get_cursor(self) -> Optional[int]:
        # Missing key (None) means no cursor has been set yet.
        cursor = self.redis_client.get(CURSOR_KEY)
        return int(cursor) if cursor else None

    def set_cursor(self, branch_id: int) -> None:
        # Persist the last-serviced branch id so dispatch cycles resume fairly.
        self.redis_client.set(CURSOR_KEY, str(branch_id))

    def get_queue_length(self, branch_id: int) -> int:
        """
        Get the length of a branch's queue. O(1) Redis operation.
        """
        return self.redis_client.llen(self._get_queue_key(branch_id))

    def get_celery_queue_length(self) -> int:
        """Get pending tasks in the Celery default queue. O(1) Redis operation.

        IMPORTANT: This measures only the broker queue (tasks waiting to be picked
        up by a worker). It does NOT count tasks currently being executed by workers.
        With 2 workers busy and LLEN=0, real system load is 2, not 0. See Edge Case 26
        for analysis of why this is acceptable and Appendix B for future improvements.
        """
        return self.redis_client.llen('celery')

    def clear_all(self) -> None:
        """
        Clear all job queue data. Testing/dev only.
        """
        # NOTE: assert is stripped under `python -O`; acceptable only because
        # this is a test/dev helper, never a production guard.
        assert settings.testing or settings.dev_mode
        branch_ids = self.get_active_branches()
        pipe = self.redis_client.pipeline()
        for branch_id in branch_ids:
            pipe.delete(self._get_queue_key(branch_id))
        pipe.delete(ACTIVE_BRANCHES_KEY)
        pipe.delete(CURSOR_KEY)
        pipe.execute()
24 changes: 24 additions & 0 deletions chronos/tasks/worker_startup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
Worker startup hook for the job dispatcher.

Automatically starts the continuous dispatcher when a Celery worker
consuming from the 'dispatcher' queue becomes ready.
"""

import logging

from celery.signals import worker_ready

logger = logging.getLogger('chronos.dispatcher')


@worker_ready.connect
def start_dispatcher_on_worker_ready(sender, **kwargs):
    """
    Start the continuous job dispatcher when this worker becomes ready.

    Fires on Celery's ``worker_ready`` signal; only acts if the worker is
    consuming from the 'dispatcher' queue.
    """
    # Entries in consume_from may be queue objects (with a .name) or plain
    # strings; normalise each one before comparing.
    is_dispatcher_worker = any(
        (entry.name if hasattr(entry, 'name') else entry) == 'dispatcher'
        for entry in sender.app.amqp.queues.consume_from
    )
    if not is_dispatcher_worker:
        return

    logger.info('Dispatcher worker ready, starting job dispatcher task')
    # Local import to avoid circular dependency: worker_startup.py is imported
    # at the bottom of worker.py, so worker.py isn't fully loaded yet.
    from chronos.worker import job_dispatcher_task

    job_dispatcher_task.delay()
Loading