Skip to content

Commit e92ea2a

Browse files
zzstoatzz and claude committed
Make backend public and fix all type errors
- Change _backend to backend (make it public, not protected) - Worker legitimately needs backend access - tight coupling is OK
- Fix Backend.lock() protocol signature (remove decorator, return AsyncContextManager)
- Add type: ignore for uuid_extensions import (external lib without stubs)
- Update all references from docket._backend to docket.backend
- Update tests to use public backend

All type errors in docket.py, worker.py, backends/base.py, and execution.py are now resolved (0 errors in modified files).

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 13b5499 commit e92ea2a

File tree

5 files changed

+50
-52
lines changed

5 files changed

+50
-52
lines changed

src/docket/backends/base.py

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -6,10 +6,9 @@
66

77
from __future__ import annotations
88

9-
from contextlib import asynccontextmanager
109
from datetime import datetime
1110
from typing import (
12-
AsyncGenerator,
11+
AsyncContextManager,
1312
Collection,
1413
Protocol,
1514
Sequence,
@@ -278,22 +277,21 @@ async def release_concurrency_slot(
278277
...
279278

280279
# Distributed lock operations
281-
@asynccontextmanager
282-
async def lock(
280+
def lock(
283281
self,
284282
key: str,
285283
timeout: float,
286284
blocking: bool = True,
287-
) -> AsyncGenerator[None, None]:
285+
) -> AsyncContextManager[None]:
288286
"""Acquire a distributed lock.
289287
290288
Args:
291289
key: Lock key
292290
timeout: Lock timeout in seconds
293291
blocking: Whether to block waiting for lock
294292
295-
Yields:
296-
None when lock is acquired
293+
Returns:
294+
Async context manager that yields when lock is acquired
297295
298296
Raises:
299297
LockError: If lock cannot be acquired (when blocking=False)

src/docket/docket.py

Lines changed: 24 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
)
2222

2323
from opentelemetry import trace
24-
from uuid_extensions import uuid7
24+
from uuid_extensions import uuid7 # type: ignore[import-untyped]
2525

2626
from .backends import Backend, RedisBackend
2727
from .execution import (
@@ -95,9 +95,9 @@ async def my_task(greeting: str, recipient: str) -> None:
9595

9696
tasks: dict[str, TaskFunction]
9797
strike_list: StrikeList
98+
backend: Backend
9899

99100
_monitor_strikes_task: asyncio.Task[None]
100-
_backend: Backend
101101

102102
def __init__(
103103
self,
@@ -124,7 +124,7 @@ def __init__(
124124
self.heartbeat_interval = heartbeat_interval
125125
self.missed_heartbeats = missed_heartbeats
126126
# Create backend based on URL scheme
127-
self._backend = RedisBackend(name, url)
127+
self.backend = RedisBackend(name, url)
128128

129129
@property
130130
def worker_group_name(self) -> str:
@@ -136,7 +136,7 @@ async def __aenter__(self) -> Self:
136136
self.tasks = {fn.__name__: fn for fn in standard_tasks}
137137
self.strike_list = StrikeList()
138138

139-
await self._backend.initialize()
139+
await self.backend.initialize()
140140
self._monitor_strikes_task = asyncio.create_task(self._monitor_strikes())
141141

142142
return self
@@ -159,7 +159,7 @@ async def __aexit__(
159159
except asyncio.CancelledError:
160160
pass
161161

162-
await self._backend.close()
162+
await self.backend.close()
163163

164164
def register(self, function: TaskFunction) -> None:
165165
"""Register a task with the Docket.
@@ -265,7 +265,7 @@ async def scheduler(*args: P.args, **kwargs: P.kwargs) -> Execution:
265265
)
266266
return execution
267267

268-
await self._backend.schedule_task(execution, replace=False)
268+
await self.backend.schedule_task(execution, replace=False)
269269

270270
TASKS_ADDED.add(1, {**self.labels(), **execution.general_labels()})
271271
TASKS_SCHEDULED.add(1, {**self.labels(), **execution.general_labels()})
@@ -340,7 +340,7 @@ async def scheduler(*args: P.args, **kwargs: P.kwargs) -> Execution:
340340
)
341341
return execution
342342

343-
await self._backend.schedule_task(execution, replace=True)
343+
await self.backend.schedule_task(execution, replace=True)
344344

345345
TASKS_REPLACED.add(1, {**self.labels(), **execution.general_labels()})
346346
TASKS_CANCELLED.add(1, {**self.labels(), **execution.general_labels()})
@@ -376,7 +376,7 @@ async def schedule(self, execution: Execution) -> None:
376376
"code.function.name": execution.function.__name__,
377377
},
378378
):
379-
await self._backend.schedule_task(execution, replace=False)
379+
await self.backend.schedule_task(execution, replace=False)
380380

381381
TASKS_SCHEDULED.add(1, {**self.labels(), **execution.general_labels()})
382382

@@ -390,30 +390,30 @@ async def cancel(self, key: str) -> None:
390390
"docket.cancel",
391391
attributes={**self.labels(), "docket.key": key},
392392
):
393-
await self._backend.cancel_task(key)
393+
await self.backend.cancel_task(key)
394394

395395
TASKS_CANCELLED.add(1, self.labels())
396396

397397
@property
398398
def queue_key(self) -> str:
399-
return self._backend.queue_key
399+
return self.backend.queue_key
400400

401401
@property
402402
def stream_key(self) -> str:
403-
return self._backend.stream_key
403+
return self.backend.stream_key
404404

405405
def known_task_key(self, key: str) -> str:
406-
return self._backend.known_task_key(key)
406+
return self.backend.known_task_key(key)
407407

408408
def parked_task_key(self, key: str) -> str:
409-
return self._backend.parked_task_key(key)
409+
return self.backend.parked_task_key(key)
410410

411411
def stream_id_key(self, key: str) -> str:
412-
return self._backend.stream_id_key(key)
412+
return self.backend.stream_id_key(key)
413413

414414
@property
415415
def strike_key(self) -> str:
416-
return self._backend.strike_key
416+
return self.backend.strike_key
417417

418418
async def strike(
419419
self,
@@ -469,14 +469,14 @@ async def _send_strike_instruction(self, instruction: StrikeInstruction) -> None
469469
**instruction.labels(),
470470
},
471471
):
472-
await self._backend.send_strike_instruction(instruction)
472+
await self.backend.send_strike_instruction(instruction)
473473
self.strike_list.update(instruction)
474474

475475
async def _monitor_strikes(self) -> NoReturn:
476476
last_id = "0-0"
477477
while True:
478478
try:
479-
instructions = await self._backend.receive_strike_instructions(
479+
instructions = await self.backend.receive_strike_instructions(
480480
last_id, timeout_ms=60_000
481481
)
482482
for message_id, instruction in instructions:
@@ -518,7 +518,7 @@ async def snapshot(self) -> DocketSnapshot:
518518
total_tasks,
519519
future_executions,
520520
running_executions,
521-
) = await self._backend.snapshot(self.tasks)
521+
) = await self.backend.snapshot(self.tasks)
522522
workers = await self.workers()
523523

524524
return DocketSnapshot(
@@ -527,13 +527,13 @@ async def snapshot(self) -> DocketSnapshot:
527527

528528
@property
529529
def workers_set(self) -> str:
530-
return self._backend.workers_set
530+
return self.backend.workers_set
531531

532532
def worker_tasks_set(self, worker_name: str) -> str:
533-
return self._backend.worker_tasks_set(worker_name)
533+
return self.backend.worker_tasks_set(worker_name)
534534

535535
def task_workers_set(self, task_name: str) -> str:
536-
return self._backend.task_workers_set(task_name)
536+
return self.backend.task_workers_set(task_name)
537537

538538
async def workers(self) -> Collection[WorkerInfo]:
539539
"""Get a list of all workers that have sent heartbeats to the Docket.
@@ -544,7 +544,7 @@ async def workers(self) -> Collection[WorkerInfo]:
544544
heartbeat_timeout = (
545545
self.heartbeat_interval.total_seconds() * self.missed_heartbeats
546546
)
547-
return await self._backend.get_workers(heartbeat_timeout)
547+
return await self.backend.get_workers(heartbeat_timeout)
548548

549549
async def task_workers(self, task_name: str) -> Collection[WorkerInfo]:
550550
"""Get a list of all workers that are able to execute a given task.
@@ -558,7 +558,7 @@ async def task_workers(self, task_name: str) -> Collection[WorkerInfo]:
558558
heartbeat_timeout = (
559559
self.heartbeat_interval.total_seconds() * self.missed_heartbeats
560560
)
561-
return await self._backend.get_task_workers(task_name, heartbeat_timeout)
561+
return await self.backend.get_task_workers(task_name, heartbeat_timeout)
562562

563563
async def clear(self) -> int:
564564
"""Clear all pending and scheduled tasks from the docket.
@@ -574,4 +574,4 @@ async def clear(self) -> int:
574574
"docket.clear",
575575
attributes=self.labels(),
576576
):
577-
return await self._backend.clear()
577+
return await self.backend.clear()

src/docket/worker.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -245,11 +245,11 @@ async def _worker_loop(self, forever: bool = False):
245245

246246
async def check_for_work() -> bool:
247247
logger.debug("Checking for work", extra=log_context)
248-
return await self.docket._backend.check_for_work()
248+
return await self.docket.backend.check_for_work()
249249

250250
async def get_redeliveries() -> RedisReadGroupResponse:
251251
logger.debug("Getting redeliveries", extra=log_context)
252-
redeliveries = await self.docket._backend.reclaim_stale_tasks(
252+
redeliveries = await self.docket.backend.reclaim_stale_tasks(
253253
worker_name=self.name,
254254
max_count=available_slots,
255255
stale_timeout_ms=int(self.redelivery_timeout.total_seconds() * 1000),
@@ -258,7 +258,7 @@ async def get_redeliveries() -> RedisReadGroupResponse:
258258

259259
async def get_new_deliveries() -> RedisReadGroupResponse:
260260
logger.debug("Getting new deliveries", extra=log_context)
261-
messages = await self.docket._backend.claim_tasks(
261+
messages = await self.docket.backend.claim_tasks(
262262
worker_name=self.name,
263263
max_count=available_slots,
264264
timeout_ms=int(self.minimum_check_interval.total_seconds() * 1000),
@@ -301,7 +301,7 @@ async def process_completed_tasks() -> None:
301301

302302
async def ack_message(message_id: RedisMessageID) -> None:
303303
logger.debug("Acknowledging message", extra=log_context)
304-
await self.docket._backend.acknowledge_task(message_id)
304+
await self.docket.backend.acknowledge_task(message_id)
305305

306306
has_work: bool = True
307307

@@ -364,7 +364,7 @@ async def _scheduler_loop(
364364
try:
365365
logger.debug("Scheduling due tasks", extra=log_context)
366366
now = datetime.now(timezone.utc)
367-
total_work, due_work = await self.docket._backend.move_due_tasks(now)
367+
total_work, due_work = await self.docket.backend.move_due_tasks(now)
368368

369369
if due_work > 0:
370370
logger.debug(
@@ -388,7 +388,7 @@ async def _scheduler_loop(
388388

389389
async def _schedule_all_automatic_perpetual_tasks(self) -> None:
390390
try:
391-
async with self.docket._backend.lock(
391+
async with self.docket.backend.lock(
392392
f"{self.docket.name}:perpetual:lock", timeout=10, blocking=False
393393
):
394394
for task_function in self.docket.tasks.values():
@@ -419,7 +419,7 @@ async def _delete_known_task(
419419
return
420420

421421
logger.debug("Deleting known task", extra=self._log_context())
422-
await self.docket._backend.delete_known_task(key)
422+
await self.docket.backend.delete_known_task(key)
423423

424424
async def _execute(self, execution: Execution) -> None:
425425
log_context = {**self._log_context(), **execution.specific_labels()}
@@ -486,7 +486,7 @@ async def _execute(self, execution: Execution) -> None:
486486
# Try to acquire a concurrency slot
487487
expiry_seconds = self.redelivery_timeout.total_seconds()
488488
acquired = (
489-
await self.docket._backend.acquire_concurrency_slot(
489+
await self.docket.backend.acquire_concurrency_slot(
490490
key=concurrency_key,
491491
max_concurrent=concurrency_limit.max_concurrent,
492492
worker_name=self.name,
@@ -615,7 +615,7 @@ async def _execute(self, execution: Execution) -> None:
615615
scope = concurrency_limit.scope or self.docket.name
616616
concurrency_key = f"{scope}:concurrency:{concurrency_limit.argument_name}:{argument_value}"
617617

618-
await self.docket._backend.release_concurrency_slot(
618+
await self.docket.backend.release_concurrency_slot(
619619
key=concurrency_key,
620620
worker_name=self.name,
621621
task_key=execution.key,
@@ -740,7 +740,7 @@ async def _heartbeat(self) -> None:
740740
stream_depth,
741741
overdue_depth,
742742
schedule_depth,
743-
) = await self.docket._backend.record_heartbeat(
743+
) = await self.docket.backend.record_heartbeat(
744744
worker_name=self.name,
745745
task_names=task_names,
746746
timestamp=now,

tests/test_docket.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -150,14 +150,14 @@ async def test_clear_no_redis_key_leaks(docket: Docket, the_task: AsyncMock):
150150
await docket.add(the_task, when=future)("scheduled1")
151151
await docket.add(the_task, when=future + timedelta(seconds=1))("scheduled2")
152152

153-
async with docket._backend.redis() as r:
153+
async with docket.backend.redis() as r:
154154
keys_before = cast(list[str], await r.keys("*")) # type: ignore
155155
keys_before_count = len(keys_before)
156156

157157
result = await docket.clear()
158158
assert result == 5
159159

160-
async with docket._backend.redis() as r:
160+
async with docket.backend.redis() as r:
161161
keys_after = cast(list[str], await r.keys("*")) # type: ignore
162162
keys_after_count = len(keys_after)
163163

0 commit comments

Comments
 (0)