|
2 | 2 |
|
3 | 3 |
|
4 | 4 | try: |
5 | | - from sqlalchemy import Table, delete, select |
| 5 | + from sqlalchemy import Table, delete, func, select |
6 | 6 | from sqlalchemy.ext.asyncio import ( |
7 | 7 | AsyncEngine, |
8 | 8 | AsyncSession, |
|
21 | 21 |
|
22 | 22 | from a2a.server.context import ServerCallContext |
23 | 23 | from a2a.server.models import Base, TaskModel, create_task_model |
24 | | -from a2a.server.tasks.task_store import TaskStore |
25 | | -from a2a.types import Task # Task is the Pydantic model |
| 24 | +from a2a.server.tasks.task_store import TaskStore, TasksPage |
| 25 | +from a2a.types import ListTasksParams, Task |
| 26 | +from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE |
26 | 27 |
|
27 | 28 |
|
28 | 29 | logger = logging.getLogger(__name__) |
@@ -147,6 +148,54 @@ async def get( |
147 | 148 | logger.debug('Task %s not found in store.', task_id) |
148 | 149 | return None |
149 | 150 |
|
| 151 | + async def list( |
| 152 | + self, params: ListTasksParams, context: ServerCallContext | None = None |
| 153 | + ) -> TasksPage: |
| 154 | + """Retrieves all tasks from the database.""" |
| 155 | + await self._ensure_initialized() |
| 156 | + async with self.async_session_maker() as session: |
| 157 | + page_number = int(params.page_token) if params.page_token else 0 |
| 158 | + page_size = params.page_size or DEFAULT_LIST_TASKS_PAGE_SIZE |
| 159 | + offset = page_number * page_size |
| 160 | + |
| 161 | + # Base query for filtering |
| 162 | + base_stmt = select(self.task_model) |
| 163 | + if params.context_id: |
| 164 | + base_stmt = base_stmt.where( |
| 165 | + self.task_model.context_id == params.context_id |
| 166 | + ) |
| 167 | + if params.status is not None: |
| 168 | + base_stmt = base_stmt.where( |
| 169 | + self.task_model.status['state'].as_string() |
| 170 | + == params.status.value |
| 171 | + ) |
| 172 | + |
| 173 | + # Get total count |
| 174 | + count_stmt = select(func.count()).select_from(base_stmt.alias()) |
| 175 | + total_count = (await session.execute(count_stmt)).scalar_one() |
| 176 | + |
| 177 | + # Get paginated results |
| 178 | + stmt = ( |
| 179 | + base_stmt.order_by(self.task_model.id.desc()) |
| 180 | + .limit(page_size) |
| 181 | + .offset(offset) |
| 182 | + ) |
| 183 | + result = await session.execute(stmt) |
| 184 | + tasks_models = result.scalars().all() |
| 185 | + tasks = [self._from_orm(task_model) for task_model in tasks_models] |
| 186 | + |
| 187 | + next_page_token = ( |
| 188 | + str(page_number + 1) |
| 189 | + if total_count > (page_number + 1) * page_size |
| 190 | + else '' |
| 191 | + ) |
| 192 | + |
| 193 | + return TasksPage( |
| 194 | + tasks=tasks, |
| 195 | + total_size=total_count, |
| 196 | + next_page_token=next_page_token, |
| 197 | + ) |
| 198 | + |
150 | 199 | async def delete( |
151 | 200 | self, task_id: str, context: ServerCallContext | None = None |
152 | 201 | ) -> None: |
|
0 commit comments