Skip to content

feat: Refactor Pagination: Remove Next URL Field and Update Tests #2045

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions llama_stack/apis/common/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any
from typing import Any, Dict, List, Optional

from pydantic import BaseModel

Expand All @@ -17,7 +17,9 @@ class PaginatedResponse(BaseModel):

:param data: The list of items for the current page
:param has_more: Whether there are more items available after this set
:param url: Optional URL to fetch the next page of results. Only populated (non-None) when has_more is True.
"""

data: list[dict[str, Any]]
data: List[Dict[str, Any]]
has_more: bool
url: Optional[str] = None
45 changes: 43 additions & 2 deletions llama_stack/distribution/server/server.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand All @@ -12,6 +13,7 @@
import sys
import traceback
import warnings
import urllib.parse
from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version
from pathlib import Path
Expand All @@ -25,6 +27,7 @@
from openai import BadRequestError
from pydantic import BaseModel, ValidationError

from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import (
Expand Down Expand Up @@ -202,8 +205,46 @@ async def endpoint(request: Request, **kwargs):
)
return StreamingResponse(gen, media_type="text/event-stream")
else:
value = func(**kwargs)
return await maybe_await(value)
# Execute the actual implementation function
result_value = func(**kwargs)
value = await maybe_await(result_value)

# Check if the result is a PaginatedResponse and needs a next URL
if isinstance(value, PaginatedResponse) and value.has_more:
try:
# Retrieve pagination params from original call kwargs
limit = kwargs.get("limit")
start_index = kwargs.get("start_index", 0) # Default to 0 if not provided

# Ensure params are integers
limit = int(limit) if limit is not None else None
start_index = int(start_index) if start_index is not None else 0

if limit is not None and limit > 0:
next_start_index = start_index + limit

# Build query params for the next page URL
next_params = dict(request.query_params)
next_params['start_index'] = str(next_start_index)
# Ensure limit is also included/updated if necessary
next_params['limit'] = str(limit)

# Construct the full URL for the next page
next_url = str(request.url.replace(query=urllib.parse.urlencode(next_params)))
# Assign the URL to the response object (assuming 'url' field exists)
value.url = next_url
else:
# Log a warning when has_more is True but no usable limit was supplied, since a next-page URL cannot be built without it.
logger.warning(f"PaginatedResponse has_more=True but limit is missing or invalid for request: {request.url}")

except (ValueError, TypeError) as e:
logger.error(f"Error processing pagination parameters for URL generation: {e}", exc_info=True)
except AttributeError:
# This can only happen if the PaginatedResponse model does not yet define the 'url' field.
# It should not occur once the model in llama_stack/apis/common/responses.py is updated to include it.
logger.error("PaginatedResponse object does not have a 'url' attribute. Ensure the model definition is updated.", exc_info=True)

return value # Return the (potentially modified) value
except Exception as e:
logger.exception(f"Error executing endpoint {route=} {method=}")
raise translate_exception(e) from e
Expand Down
66 changes: 59 additions & 7 deletions tests/integration/datasets/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand All @@ -8,6 +9,7 @@
import base64
import mimetypes
import os
from urllib.parse import urlparse, parse_qs

import pytest

Expand All @@ -33,7 +35,7 @@ def data_url_from_file(file_path: str) -> str:

@pytest.mark.skip(reason="flaky. Couldn't find 'llamastack/simpleqa' on the Hugging Face Hub")
@pytest.mark.parametrize(
"purpose, source, provider_id, limit",
"purpose, source, provider_id, limit, total_expected",
[
(
"eval/messages-answer",
Expand All @@ -42,7 +44,8 @@ def data_url_from_file(file_path: str) -> str:
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
},
"huggingface",
10,
5, # Request 5, expect more
10, # Assume total > 5
),
(
"eval/messages-answer",
Expand All @@ -62,10 +65,20 @@ def data_url_from_file(file_path: str) -> str:
],
"answer": "Paris",
},
{
"messages": [
{
"role": "user",
"content": "Third message",
}
],
"answer": "Third answer",
},
],
},
"localfs",
2,
2, # Request 2, expect more
3, # Total is 3
),
(
"eval/messages-answer",
Expand All @@ -74,23 +87,62 @@ def data_url_from_file(file_path: str) -> str:
"uri": data_url_from_file(os.path.join(os.path.dirname(__file__), "test_dataset.csv")),
},
"localfs",
5,
3, # Request 3, expect more
5, # Total is 5
),
(
"eval/messages-answer",
{
"type": "uri",
"uri": data_url_from_file(os.path.join(os.path.dirname(__file__), "test_dataset.csv")),
},
"localfs",
5, # Request all 5, expect no more
5, # Total is 5
),
],
)
def test_register_and_iterrows(llama_stack_client, purpose, source, provider_id, limit):
def test_register_and_iterrows(llama_stack_client, purpose, source, provider_id, limit, total_expected):
dataset = llama_stack_client.datasets.register(
purpose=purpose,
source=source,
)
assert dataset.identifier is not None
assert dataset.provider_id == provider_id
iterrow_response = llama_stack_client.datasets.iterrows(dataset.identifier, limit=limit)
assert len(iterrow_response.data) == limit

# Initial request
start_index = 0
iterrow_response = llama_stack_client.datasets.iterrows(dataset.identifier, limit=limit, start_index=start_index)
assert len(iterrow_response.data) == min(limit, total_expected)

# Check pagination fields
expected_has_more = (start_index + limit) < total_expected
assert iterrow_response.has_more == expected_has_more

if expected_has_more:
assert hasattr(iterrow_response, "url"), "PaginatedResponse should have a 'url' field when has_more is True"
assert iterrow_response.url is not None, "PaginatedResponse url should not be None when has_more is True"
# Parse the URL to check parameters
parsed_url = urlparse(iterrow_response.url)
query_params = parse_qs(parsed_url.query)
assert "start_index" in query_params, "Next page URL must contain start_index"
assert int(query_params["start_index"][0]) == start_index + limit, "Next page URL start_index is incorrect"
assert "limit" in query_params, "Next page URL must contain limit"
assert int(query_params["limit"][0]) == limit, "Next page URL limit is incorrect"
assert parsed_url.path == f"/datasets/{dataset.identifier}/iterrows", "Next page URL path is incorrect"

# Optionally, make a request to the next page URL (requires client base_url to be set)
# This is more complex as it bypasses the client method

else:
assert not hasattr(iterrow_response, "url") or iterrow_response.url is None, "PaginatedResponse url should be None or missing when has_more is False"


# List and check presence
dataset_list = llama_stack_client.datasets.list()
assert dataset.identifier in [d.identifier for d in dataset_list]

# Unregister and check absence
llama_stack_client.datasets.unregister(dataset.identifier)
dataset_list = llama_stack_client.datasets.list()
assert dataset.identifier not in [d.identifier for d in dataset_list]
Loading