Skip to content

Commit c163350

Browse files
committed
Add cancel() method to QueueJobStore
Allow transitioning queued or in_progress jobs to the cancelled state, matching the JobStatus enum value that was previously unreachable. Includes tests for both valid transitions and the rejection of cancelling already-completed jobs.
1 parent 3e72389 commit c163350

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

src/docverse/storage/queue_job_store.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,29 @@ async def fail(
201201
await self._session.refresh(row)
202202
return QueueJob.model_validate(row, from_attributes=True)
203203

204+
async def cancel(self, job_id: int) -> QueueJob:
205+
"""Mark job cancelled, set date_completed=now().
206+
207+
Raises
208+
------
209+
InvalidJobStateError
210+
If the job is not in queued or in_progress status.
211+
"""
212+
row = await self._get_row(job_id)
213+
allowed = {JobStatus.queued.value, JobStatus.in_progress.value}
214+
if row.status not in allowed:
215+
msg = (
216+
f"Cannot cancel job {job_id}: "
217+
f"expected 'queued'/'in_progress', "
218+
f"got '{row.status}'"
219+
)
220+
raise InvalidJobStateError(msg)
221+
row.status = JobStatus.cancelled.value
222+
row.date_completed = datetime.now(tz=UTC)
223+
await self._session.flush()
224+
await self._session.refresh(row)
225+
return QueueJob.model_validate(row, from_attributes=True)
226+
204227
async def _get_row(self, job_id: int) -> SqlQueueJob:
205228
"""Fetch a SqlQueueJob row by id, raising if not found."""
206229
result = await self._session.execute(

tests/storage/queue_job_store_test.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,47 @@ async def test_fail_job(
152152
assert failed.errors == {"message": "something went wrong"}
153153

154154

155+
@pytest.mark.asyncio
156+
async def test_cancel_queued_job(
157+
db_session: async_scoped_session[AsyncSession],
158+
store: QueueJobStore,
159+
) -> None:
160+
async with db_session.begin():
161+
job = await store.create(kind=JobKind.build_processing, org_id=1)
162+
cancelled = await store.cancel(job.id)
163+
await db_session.commit()
164+
assert cancelled.status == JobStatus.cancelled
165+
assert cancelled.date_completed is not None
166+
167+
168+
@pytest.mark.asyncio
169+
async def test_cancel_in_progress_job(
170+
db_session: async_scoped_session[AsyncSession],
171+
store: QueueJobStore,
172+
) -> None:
173+
async with db_session.begin():
174+
job = await store.create(kind=JobKind.build_processing, org_id=1)
175+
await store.start(job.id)
176+
cancelled = await store.cancel(job.id)
177+
await db_session.commit()
178+
assert cancelled.status == JobStatus.cancelled
179+
assert cancelled.date_completed is not None
180+
181+
182+
@pytest.mark.asyncio
183+
async def test_cancel_completed_job_raises(
184+
db_session: async_scoped_session[AsyncSession],
185+
store: QueueJobStore,
186+
) -> None:
187+
async with db_session.begin():
188+
job = await store.create(kind=JobKind.build_processing, org_id=1)
189+
await store.start(job.id)
190+
await store.complete(job.id)
191+
with pytest.raises(InvalidJobStateError):
192+
await store.cancel(job.id)
193+
await db_session.commit()
194+
195+
155196
@pytest.mark.asyncio
156197
async def test_get_by_public_id(
157198
db_session: async_scoped_session[AsyncSession],

0 commit comments

Comments
 (0)