diff --git a/llama_stack/apis/common/responses.py b/llama_stack/apis/common/responses.py
index b3bb5cb6b..47796d597 100644
--- a/llama_stack/apis/common/responses.py
+++ b/llama_stack/apis/common/responses.py
@@ -17,6 +17,8 @@ class PaginatedResponse(BaseModel):
     :param data: The list of items for the current page
     :param has_more: Whether there are more items available after this set
+    :param url: The URL for fetching the next page of results. Only set when has_more is True.
     """
 
     data: list[dict[str, Any]]
     has_more: bool
+    url: str | None = None
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 0c8c70306..b0e587402 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -12,6 +12,7 @@
 import sys
 import traceback
+import urllib.parse
 import warnings
 from contextlib import asynccontextmanager
 from importlib.metadata import version as parse_version
 from pathlib import Path
@@ -25,6 +26,7 @@
 from openai import BadRequestError
 from pydantic import BaseModel, ValidationError
 
+from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.request_headers import (
@@ -202,8 +204,40 @@ async def endpoint(request: Request, **kwargs):
             )
             return StreamingResponse(gen, media_type="text/event-stream")
         else:
-            value = func(**kwargs)
-            return await maybe_await(value)
+            # Execute the actual implementation function
+            result_value = func(**kwargs)
+            value = await maybe_await(result_value)
+
+            # If the result is a paginated response with more items, attach a next-page URL
+            if isinstance(value, PaginatedResponse) and value.has_more:
+                try:
+                    # Retrieve pagination params from the original call kwargs
+                    limit = kwargs.get("limit")
+                    start_index = kwargs.get("start_index") or 0
+
+                    # Ensure params are integers
+                    limit = int(limit) if limit is not None else None
+                    start_index = int(start_index)
+
+                    if limit is not None and limit > 0:
+                        next_start_index = start_index + limit
+
+                        # Build query params for the next page URL
+                        next_params = dict(request.query_params)
+                        next_params["start_index"] = str(next_start_index)
+                        next_params["limit"] = str(limit)
+
+                        # Construct the full URL for the next page
+                        next_url = str(request.url.replace(query=urllib.parse.urlencode(next_params)))
+                        value.url = next_url
+                    else:
+                        logger.warning(
+                            f"PaginatedResponse has_more=True but limit is missing or invalid for request: {request.url}"
+                        )
+                except (ValueError, TypeError) as e:
+                    logger.error(f"Error processing pagination parameters for URL generation: {e}", exc_info=True)
+
+            return value
         except Exception as e:
             logger.exception(f"Error executing endpoint {route=} {method=}")
             raise translate_exception(e) from e
diff --git a/tests/integration/datasets/test_datasets.py b/tests/integration/datasets/test_datasets.py
index 18b31d39c..2b8cd1906 100644
--- a/tests/integration/datasets/test_datasets.py
+++ b/tests/integration/datasets/test_datasets.py
@@ -8,6 +8,7 @@
 import base64
 import mimetypes
 import os
+from urllib.parse import parse_qs, urlparse
 
 import pytest
 
@@ -33,7 +34,7 @@ def data_url_from_file(file_path: str) -> str:
 
 @pytest.mark.skip(reason="flaky. Couldn't find 'llamastack/simpleqa' on the Hugging Face Hub")
 @pytest.mark.parametrize(
-    "purpose, source, provider_id, limit",
+    "purpose, source, provider_id, limit, total_expected",
     [
         (
             "eval/messages-answer",
@@ -42,7 +43,8 @@ def data_url_from_file(file_path: str) -> str:
                 "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
             },
             "huggingface",
-            10,
+            5,  # Request 5; expect another page
+            10,  # Assumed total (> 5)
         ),
         (
             "eval/messages-answer",
@@ -62,10 +64,20 @@ def data_url_from_file(file_path: str) -> str:
                     ],
                     "answer": "Paris",
                 },
+                {
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": "Third message",
+                        }
+                    ],
+                    "answer": "Third answer",
+                },
             ],
         },
         "localfs",
-            2,
+            2,  # Request 2; expect another page
+            3,  # Total is 3
         ),
         (
             "eval/messages-answer",
@@ -74,26 +86,60 @@ def data_url_from_file(file_path: str) -> str:
             "eval/messages-answer",
             {
                 "type": "uri",
                 "uri": data_url_from_file(os.path.join(os.path.dirname(__file__), "test_dataset.csv")),
             },
             "localfs",
-            5,
+            3,  # Request 3; expect another page
+            5,  # Total is 5
+        ),
+        (
+            "eval/messages-answer",
+            {
+                "type": "uri",
+                "uri": data_url_from_file(os.path.join(os.path.dirname(__file__), "test_dataset.csv")),
+            },
+            "localfs",
+            5,  # Request all 5; expect no more pages
+            5,  # Total is 5
         ),
     ],
 )
-def test_register_and_iterrows(llama_stack_client, purpose, source, provider_id, limit):
+def test_register_and_iterrows(llama_stack_client, purpose, source, provider_id, limit, total_expected):
     dataset = llama_stack_client.datasets.register(
         purpose=purpose,
         source=source,
     )
     assert dataset.identifier is not None
     assert dataset.provider_id == provider_id
 
-    iterrow_response = llama_stack_client.datasets.iterrows(dataset.identifier, limit=limit)
-    assert len(iterrow_response.data) == limit
+    # Initial request
+    start_index = 0
+    iterrow_response = llama_stack_client.datasets.iterrows(dataset.identifier, limit=limit, start_index=start_index)
+    assert len(iterrow_response.data) == min(limit, total_expected)
+
+    # Check pagination fields
+    expected_has_more = (start_index + limit) < total_expected
+    assert iterrow_response.has_more == expected_has_more
+
+    if expected_has_more:
+        assert iterrow_response.url is not None, "PaginatedResponse url should be set when has_more is True"
+        # Parse the next-page URL and verify its pagination parameters
+        parsed_url = urlparse(iterrow_response.url)
+        query_params = parse_qs(parsed_url.query)
+        assert "start_index" in query_params, "Next page URL must contain start_index"
+        assert int(query_params["start_index"][0]) == start_index + limit, "Next page URL start_index is incorrect"
+        assert "limit" in query_params, "Next page URL must contain limit"
+        assert int(query_params["limit"][0]) == limit, "Next page URL limit is incorrect"
+        assert parsed_url.path == f"/datasets/{dataset.identifier}/iterrows", "Next page URL path is incorrect"
+        # Following the next-page URL is not exercised here; it would bypass the client method.
+    else:
+        assert iterrow_response.url is None, "PaginatedResponse url should be None when has_more is False"
+
+    # List and check presence
     dataset_list = llama_stack_client.datasets.list()
     assert dataset.identifier in [d.identifier for d in dataset_list]
 
+    # Unregister and check absence
     llama_stack_client.datasets.unregister(dataset.identifier)
     dataset_list = llama_stack_client.datasets.list()
     assert dataset.identifier not in [d.identifier for d in dataset_list]