Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: retry queue items #7649

Merged
merged 4 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions invokeai/app/api/routers/session_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ClearResult,
EnqueueBatchResult,
PruneResult,
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemDTO,
Expand Down Expand Up @@ -135,6 +136,19 @@ async def cancel_by_destination(
)


@session_queue_router.put(
"/{queue_id}/retry_items_by_id",
operation_id="retry_items_by_id",
responses={200: {"model": RetryItemsResult}},
)
async def retry_items_by_id(
queue_id: str = Path(description="The queue id to perform this operation on"),
item_ids: list[int] = Body(description="The queue item ids to retry"),
) -> RetryItemsResult:
"""Immediately cancels all queue items with the given origin"""
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)


@session_queue_router.put(
"/{queue_id}/clear",
operation_id="clear",
Expand Down
6 changes: 6 additions & 0 deletions invokeai/app/services/events/events_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueItemsRetriedEvent,
QueueItemStatusChangedEvent,
)

Expand All @@ -39,6 +40,7 @@
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
RetryItemsResult,
SessionQueueItem,
SessionQueueStatus,
)
Expand Down Expand Up @@ -99,6 +101,10 @@ def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
"""Emitted when a batch is enqueued"""
self.dispatch(BatchEnqueuedEvent.build(enqueue_result))

def emit_queue_items_retried(self, retry_result: "RetryItemsResult") -> None:
"""Emitted when a list of queue items are retried"""
self.dispatch(QueueItemsRetriedEvent.build(retry_result))

def emit_queue_cleared(self, queue_id: str) -> None:
"""Emitted when a queue is cleared"""
self.dispatch(QueueClearedEvent.build(queue_id))
Expand Down
17 changes: 17 additions & 0 deletions invokeai/app/services/events/events_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
QUEUE_ITEM_STATUS,
BatchStatus,
EnqueueBatchResult,
RetryItemsResult,
SessionQueueItem,
SessionQueueStatus,
)
Expand Down Expand Up @@ -290,6 +291,22 @@ def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
)


@payload_schema.register
class QueueItemsRetriedEvent(QueueEventBase):
"""Event model for queue_items_retried"""

__event_name__ = "queue_items_retried"

retried_item_ids: list[int] = Field(description="The IDs of the queue items that were retried")

@classmethod
def build(cls, retry_result: RetryItemsResult) -> "QueueItemsRetriedEvent":
return cls(
queue_id=retry_result.queue_id,
retried_item_ids=retry_result.retried_item_ids,
)


@payload_schema.register
class QueueClearedEvent(QueueEventBase):
"""Event model for queue_cleared"""
Expand Down
6 changes: 6 additions & 0 deletions invokeai/app/services/session_queue/session_queue_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
IsEmptyResult,
IsFullResult,
PruneResult,
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemDTO,
Expand Down Expand Up @@ -139,3 +140,8 @@ def get_queue_item(self, item_id: int) -> SessionQueueItem:
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
"""Sets the session for a session queue item. Use this to update the session state."""
pass

@abstractmethod
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
"""Retries the given queue items"""
pass
27 changes: 18 additions & 9 deletions invokeai/app/services/session_queue/session_queue_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,9 @@ class SessionQueueItemWithoutGraph(BaseModel):
field_values: Optional[list[NodeFieldValue]] = Field(
default=None, description="The field values that were used for this queue item"
)
retried_from_item_id: Optional[int] = Field(
default=None, description="The item_id of the queue item that this item was retried from"
)

@classmethod
def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
Expand Down Expand Up @@ -344,6 +347,11 @@ class EnqueueBatchResult(BaseModel):
priority: int = Field(description="The priority of the enqueued batch")


class RetryItemsResult(BaseModel):
queue_id: str = Field(description="The ID of the queue")
retried_item_ids: list[int] = Field(description="The IDs of the queue items that were retried")


class ClearResult(BaseModel):
"""Result of clearing the session queue"""

Expand Down Expand Up @@ -481,6 +489,7 @@ class SessionQueueValueToInsert(NamedTuple):
workflow: Optional[str] # workflow json
origin: str | None
destination: str | None
retried_from_item_id: int | None = None


ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
Expand All @@ -493,16 +502,16 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
session.id = uuid_string()
values_to_insert.append(
SessionQueueValueToInsert(
queue_id, # queue_id
session.model_dump_json(warnings=False, exclude_none=True), # session (json)
session.id, # session_id
batch.batch_id, # batch_id
queue_id=queue_id,
session=session.model_dump_json(warnings=False, exclude_none=True), # as json
session_id=session.id,
batch_id=batch.batch_id,
# must use pydantic_encoder bc field_values is a list of models
json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json)
priority, # priority
json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json)
batch.origin, # origin
batch.destination, # destination
field_values=json.dumps(field_values, default=to_jsonable_python) if field_values else None, # as json
priority=priority,
workflow=json.dumps(workflow, default=to_jsonable_python) if workflow else None, # as json
origin=batch.origin,
destination=batch.destination,
)
)
return values_to_insert
Expand Down
77 changes: 75 additions & 2 deletions invokeai/app/services/session_queue/session_queue_sqlite.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import json
import sqlite3
import threading
from typing import Optional, Union, cast

from pydantic_core import to_jsonable_python

from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
from invokeai.app.services.session_queue.session_queue_common import (
Expand All @@ -18,11 +21,13 @@
IsEmptyResult,
IsFullResult,
PruneResult,
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueItemNotFoundError,
SessionQueueStatus,
SessionQueueValueToInsert,
calc_session_count,
prepare_values_to_insert,
)
Expand Down Expand Up @@ -130,8 +135,8 @@ def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBa

self.__cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
Expand Down Expand Up @@ -761,3 +766,71 @@ def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQ
canceled=counts.get("canceled", 0),
total=total,
)

def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
"""Retries the given queue items"""
try:
self.__lock.acquire()

values_to_insert: list[SessionQueueValueToInsert] = []
retried_item_ids: list[int] = []

for item_id in item_ids:
queue_item = self.get_queue_item(item_id)

if queue_item.status not in ("failed", "canceled"):
continue

retried_item_ids.append(item_id)

field_values_json = (
json.dumps(queue_item.field_values, default=to_jsonable_python) if queue_item.field_values else None
)
workflow_json = (
json.dumps(queue_item.workflow, default=to_jsonable_python) if queue_item.workflow else None
)
cloned_session = GraphExecutionState(graph=queue_item.session.graph)
cloned_session_json = cloned_session.model_dump_json(warnings=False, exclude_none=True)

retried_from_item_id = (
queue_item.retried_from_item_id
if queue_item.retried_from_item_id is not None
else queue_item.item_id
)

value_to_insert = SessionQueueValueToInsert(
queue_id=queue_item.queue_id,
batch_id=queue_item.batch_id,
destination=queue_item.destination,
field_values=field_values_json,
origin=queue_item.origin,
priority=queue_item.priority,
workflow=workflow_json,
session=cloned_session_json,
session_id=cloned_session.id,
retried_from_item_id=retried_from_item_id,
)
values_to_insert.append(value_to_insert)

# TODO(psyche): Handle max queue size?

self.__cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)

self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
retry_result = RetryItemsResult(
queue_id=queue_id,
retried_item_ids=retried_item_ids,
)
self.__invoker.services.events.emit_queue_items_retried(retry_result)
return retry_result
2 changes: 2 additions & 0 deletions invokeai/app/services/shared/sqlite/sqlite_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_15 import build_migration_15
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_16 import build_migration_16
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator


Expand Down Expand Up @@ -53,6 +54,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_13())
migrator.register_migration(build_migration_14())
migrator.register_migration(build_migration_15())
migrator.register_migration(build_migration_16())
migrator.run_migrations()

return db
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import sqlite3

from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration


class Migration16Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._add_retried_from_item_id_col(cursor)

def _add_retried_from_item_id_col(self, cursor: sqlite3.Cursor) -> None:
"""
- Adds `retried_from_item_id` column to the session queue table.
"""

cursor.execute("ALTER TABLE session_queue ADD COLUMN retried_from_item_id INTEGER;")


def build_migration_16() -> Migration:
"""
Build the migration from database version 15 to 16.

This migration does the following:
- Adds `retried_from_item_id` column to the session queue table.
"""
migration_16 = Migration(
from_version=15,
to_version=16,
callback=Migration16Callback(),
)

return migration_16
3 changes: 3 additions & 0 deletions invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@
"cancelTooltip": "Cancel Current Item",
"cancelSucceeded": "Item Canceled",
"cancelFailed": "Problem Canceling Item",
"retrySucceeded": "Item Retried",
"retryFailed": "Problem Retrying Item",
"confirm": "Confirm",
"prune": "Prune",
"pruneTooltip": "Prune {{item_count}} Completed Items",
Expand All @@ -239,6 +241,7 @@
"clearFailed": "Problem Clearing Queue",
"cancelBatch": "Cancel Batch",
"cancelItem": "Cancel Item",
"retryItem": "Retry Item",
"cancelBatchSucceeded": "Batch Canceled",
"cancelBatchFailed": "Problem Canceling Batch",
"clearQueueAlertDialog": "Clearing the queue immediately cancels any processing items and clears the queue entirely. Pending filters will be canceled.",
Expand Down
3 changes: 2 additions & 1 deletion invokeai/frontend/web/src/app/types/invokeai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ export type AppFeature =
| 'modelCache'
| 'bulkDownload'
| 'starterModels'
| 'hfToken';
| 'hfToken'
| 'retryQueueItem';
/**
* A disable-able Stable Diffusion feature
*/
Expand Down
Loading