Skip to content

Commit 399296a

Browse files
desertaxleclaude
andcommitted
Address code review comments: validation, logging, and constants
This commit addresses three review feedback items: 1. **Input Validation for Progress Methods** - Added validation in Progress.set_total() - must be positive - Added validation in Progress.set() - must be non-negative and <= total - Raises ValueError with descriptive messages - Added 4 tests to verify validation behavior 2. **Improved Lua Script Error Handling** - Added logging when mark_task_completed() encounters missing keys - Lua script returns 0 when keys don't exist (unchanged) - Now logs WARNING when result is 0, with key details - Added DEBUG logging when script is evicted and reloaded - Updated test to verify warning is logged 3. **Constant for Default Progress Total** - Added DEFAULT_PROGRESS_TOTAL = 100 constant - Used in ProgressInfo dataclass default - Used in ProgressInfo.from_record() fallback - Used in Progress.__aenter__() initialization - Eliminates hardcoded 100 throughout codebase Changes: - src/docket/state.py: Added constant, logging, improved error handling - src/docket/dependencies.py: Added validation, used constant - tests/test_dependencies.py: Added 4 validation tests - tests/test_state.py: Updated test to verify logging Test results: 58 tests passed, 96% coverage for state.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 86fe92e commit 399296a

File tree

4 files changed

+140
-10
lines changed

4 files changed

+140
-10
lines changed

src/docket/dependencies.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -680,12 +680,14 @@ def __init__(self) -> None:
680680
self._current: int = 0
681681

682682
async def __aenter__(self) -> "Progress":
683+
from docket.state import DEFAULT_PROGRESS_TOTAL
684+
683685
execution = self.execution.get()
684686
docket = self.docket.get()
685687

686688
self._key = execution.key
687689
self._docket = docket
688-
self._total = 100
690+
self._total = DEFAULT_PROGRESS_TOTAL
689691
self._current = 0
690692
self._store = TaskStateStore(docket, docket.record_ttl)
691693

@@ -708,8 +710,13 @@ async def set_total(self, total: int) -> None:
708710
"""Set the total expected progress value.
709711
710712
Args:
711-
total: Total expected progress value
713+
total: Total expected progress value (must be positive)
714+
715+
Raises:
716+
ValueError: If total is not positive
712717
"""
718+
if total <= 0:
719+
raise ValueError(f"Progress total must be positive, got {total}")
713720
self._total = total
714721
await self._store.set_task_progress(
715722
self._key, ProgressInfo(current=self._current, total=self._total)
@@ -727,8 +734,17 @@ async def set(self, current: int) -> None:
727734
"""Set the current progress value directly.
728735
729736
Args:
730-
current: Current progress value
737+
current: Current progress value (must be non-negative and <= total)
738+
739+
Raises:
740+
ValueError: If current is negative or exceeds total
731741
"""
742+
if current < 0:
743+
raise ValueError(f"Progress current must be non-negative, got {current}")
744+
if current > self._total:
745+
raise ValueError(
746+
f"Progress current ({current}) cannot exceed total ({self._total})"
747+
)
732748
self._current = current
733749
await self._store.set_task_progress(
734750
self._key, ProgressInfo(current=self._current, total=self._total)

src/docket/state.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from dataclasses import dataclass, field
23
from datetime import datetime, timezone
34
from typing import TYPE_CHECKING, Any, cast
@@ -7,6 +8,12 @@
78
if TYPE_CHECKING:
89
from docket import Docket
910

11+
logger: logging.Logger = logging.getLogger(__name__)
12+
13+
14+
# Default total value for progress tracking
15+
DEFAULT_PROGRESS_TOTAL = 100
16+
1017

1118
@dataclass
1219
class ProgressInfo:
@@ -18,7 +25,7 @@ class ProgressInfo:
1825
"""
1926

2027
current: int = field(default=0)
21-
total: int = field(default=100)
28+
total: int = field(default=DEFAULT_PROGRESS_TOTAL)
2229

2330
@property
2431
def percentage(self) -> float:
@@ -34,7 +41,7 @@ def to_record(self) -> dict[str, int]:
3441
def from_record(cls, record: dict[str, int]) -> "ProgressInfo":
3542
return cls(
3643
current=record.get("current", 0),
37-
total=record.get("total", 100),
44+
total=record.get("total", DEFAULT_PROGRESS_TOTAL),
3845
)
3946

4047

@@ -286,7 +293,7 @@ async def mark_task_completed(self, key: str) -> None:
286293

287294
try:
288295
# Execute using cached SHA
289-
await redis.evalsha( # pyright: ignore[reportUnknownMemberType,reportGeneralTypeIssues]
296+
result = await redis.evalsha( # pyright: ignore[reportUnknownMemberType,reportGeneralTypeIssues]
290297
TaskStateStore._completion_script_sha,
291298
2, # number of keys
292299
progress_key,
@@ -295,15 +302,27 @@ async def mark_task_completed(self, key: str) -> None:
295302
self.record_ttl,
296303
)
297304
except NoScriptError:
305+
# Script was evicted from Redis, reload and retry
306+
logger.debug("Lua script evicted from Redis, reloading for key %s", key)
298307
TaskStateStore._completion_script_sha = cast(
299308
str,
300309
await redis.script_load(self._COMPLETION_SCRIPT), # pyright: ignore[reportUnknownMemberType,reportGeneralTypeIssues]
301310
)
302-
await redis.evalsha( # pyright: ignore[reportUnknownMemberType,reportGeneralTypeIssues]
311+
result = await redis.evalsha( # pyright: ignore[reportUnknownMemberType,reportGeneralTypeIssues]
303312
TaskStateStore._completion_script_sha,
304313
2, # number of keys
305314
progress_key,
306315
state_key,
307316
now,
308317
self.record_ttl,
309318
)
319+
320+
# Log if task state didn't exist (script returns 0)
321+
if result == 0:
322+
logger.warning(
323+
"Task state not found when marking completed: %s "
324+
"(progress key: %s, state key: %s)",
325+
key,
326+
progress_key,
327+
state_key,
328+
)

tests/test_dependencies.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,3 +515,88 @@ async def task_with_progress_context(progress: Progress = Progress()):
515515

516516
assert entered
517517
assert exited
518+
519+
520+
async def test_progress_set_total_validation(docket: Docket, worker: Worker):
521+
"""Progress.set_total() should validate input."""
522+
from docket.dependencies import Progress
523+
524+
validation_error = None
525+
526+
async def task_with_invalid_total(progress: Progress = Progress()):
527+
nonlocal validation_error
528+
try:
529+
await progress.set_total(-10)
530+
except ValueError as e:
531+
validation_error = e
532+
533+
docket.register(task_with_invalid_total)
534+
await docket.add(task_with_invalid_total, key="validation-test")()
535+
await worker.run_until_finished()
536+
537+
assert validation_error is not None
538+
assert "must be positive" in str(validation_error)
539+
540+
541+
async def test_progress_set_total_zero_validation(docket: Docket, worker: Worker):
542+
"""Progress.set_total() should reject zero."""
543+
from docket.dependencies import Progress
544+
545+
validation_error = None
546+
547+
async def task_with_zero_total(progress: Progress = Progress()):
548+
nonlocal validation_error
549+
try:
550+
await progress.set_total(0)
551+
except ValueError as e:
552+
validation_error = e
553+
554+
docket.register(task_with_zero_total)
555+
await docket.add(task_with_zero_total, key="zero-validation-test")()
556+
await worker.run_until_finished()
557+
558+
assert validation_error is not None
559+
assert "must be positive" in str(validation_error)
560+
561+
562+
async def test_progress_set_negative_validation(docket: Docket, worker: Worker):
563+
"""Progress.set() should validate negative values."""
564+
from docket.dependencies import Progress
565+
566+
validation_error = None
567+
568+
async def task_with_negative_current(progress: Progress = Progress()):
569+
nonlocal validation_error
570+
try:
571+
await progress.set(-5)
572+
except ValueError as e:
573+
validation_error = e
574+
575+
docket.register(task_with_negative_current)
576+
await docket.add(task_with_negative_current, key="negative-validation-test")()
577+
await worker.run_until_finished()
578+
579+
assert validation_error is not None
580+
assert "must be non-negative" in str(validation_error)
581+
582+
583+
async def test_progress_set_exceeds_total_validation(docket: Docket, worker: Worker):
584+
"""Progress.set() should validate current doesn't exceed total."""
585+
from docket.dependencies import Progress
586+
587+
validation_error = None
588+
589+
async def task_with_exceeding_current(progress: Progress = Progress()):
590+
nonlocal validation_error
591+
await progress.set_total(100)
592+
try:
593+
await progress.set(150)
594+
except ValueError as e:
595+
validation_error = e
596+
597+
docket.register(task_with_exceeding_current)
598+
await docket.add(task_with_exceeding_current, key="exceeds-validation-test")()
599+
await worker.run_until_finished()
600+
601+
assert validation_error is not None
602+
assert "cannot exceed total" in str(validation_error)

tests/test_state.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from datetime import datetime, timezone
44

5+
import pytest
56

67
from docket import Docket, Worker
78
from docket.state import ProgressInfo, TaskState, TaskStateStore
@@ -290,12 +291,21 @@ async def test_get_task_state_missing_progress_key(self, docket: Docket):
290291
state = await store.get_task_state("test-task-key")
291292
assert state is None
292293

293-
async def test_mark_task_completed_nonexistent(self, docket: Docket):
294-
"""Marking nonexistent task as completed should not error."""
294+
async def test_mark_task_completed_nonexistent(
295+
self, docket: Docket, caplog: pytest.LogCaptureFixture
296+
):
297+
"""Marking nonexistent task as completed should not error but should log warning."""
298+
import logging
299+
295300
store = TaskStateStore(docket, record_ttl=3600)
296301

297302
# Should not raise an exception
298-
await store.mark_task_completed("nonexistent-key")
303+
with caplog.at_level(logging.WARNING):
304+
await store.mark_task_completed("nonexistent-key")
305+
306+
# Should log a warning about missing task state
307+
assert "Task state not found when marking completed" in caplog.text
308+
assert "nonexistent-key" in caplog.text
299309

300310
async def test_mark_task_completed_missing_total(self, docket: Docket):
301311
"""Marking task completed with missing total field should not error."""

0 commit comments

Comments
 (0)