From d20ab1675ca1e4536c25489cd44e7914ff746cf5 Mon Sep 17 00:00:00 2001 From: qodot Date: Sun, 8 Sep 2024 16:23:12 +0900 Subject: [PATCH 1/2] add `_inject_async_pagination` to run async for loop for django queryset --- ninja/pagination.py | 50 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/ninja/pagination.py b/ninja/pagination.py index 47ab70de..be25d535 100644 --- a/ninja/pagination.py +++ b/ninja/pagination.py @@ -1,3 +1,4 @@ +from collections.abc import AsyncIterable, Iterable import inspect from abc import ABC, abstractmethod from functools import partial, wraps @@ -250,6 +251,55 @@ def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any: return view_with_pagination +def _inject_async_pagination( + func: Callable, + paginator_class: Type[AsyncPaginationBase], + **paginator_params: Any, +) -> Callable: + paginator = paginator_class(**paginator_params) + if not hasattr(paginator, "apaginate_queryset"): + raise ConfigError("Pagination class not configured for async requests") + + @wraps(func) + async def paginated_view(request: HttpRequest, **kwargs: Any) -> Any: + pagination_params = kwargs.pop("ninja_pagination") + if paginator.pass_parameter: + kwargs[paginator.pass_parameter] = pagination_params + + result = await func(request, **kwargs) + paginated_result = await paginator.apaginate_queryset( + result, pagination=pagination_params, request=request, **kwargs + ) + + if paginator.Output: # type: ignore + items = paginated_result[paginator.items_attribute] + if isinstance(items, QuerySet) or isinstance(items, AsyncIterable): + new_items = [r async for r in items] + elif isinstance(items, Iterable): + new_items = list(items) + else: + raise TypeError("Unexpected type") + + paginated_result[paginator.items_attribute] = new_items + + return paginated_result + + contribute_operation_args( + paginated_view, + "ninja_pagination", + paginator.Input, + paginator.InputSource, + ) + + if paginator.Output: # type: ignore + contribute_operation_callback( + paginated_view, + partial(make_response_paginated, paginator), + ) + + return paginated_view + + class RouterPaginated(Router): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) From 5a8b63cead1195d282f03541263b6a2381981b0e Mon Sep 17 00:00:00 2001 From: qodot Date: Sun, 8 Sep 2024 16:23:40 +0900 Subject: [PATCH 2/2] apply `_inject_async_pagination` --- ninja/pagination.py | 79 +++++++++++++----------------------- tests/test_async_paginate.py | 42 +++++++++++++++++++ 2 files changed, 71 insertions(+), 50 deletions(-) create mode 100644 tests/test_async_paginate.py diff --git a/ninja/pagination.py b/ninja/pagination.py index be25d535..f4dc7a77 100644 --- a/ninja/pagination.py +++ b/ninja/pagination.py @@ -171,69 +171,45 @@ def my_view(request): ) if isfunction: - return _inject_pagination(func_or_pgn_class, pagination_class) + if is_async_callable(func_or_pgn_class): + return _inject_async_pagination(func_or_pgn_class, pagination_class) + else: + return _inject_pagination(func_or_pgn_class, pagination_class) if not isnotset: pagination_class = func_or_pgn_class def wrapper(func: Callable) -> Any: - return _inject_pagination(func, pagination_class, **paginator_params) + if is_async_callable(func): + return _inject_async_pagination(func, pagination_class, **paginator_params) + else: + return _inject_pagination(func, pagination_class, **paginator_params) return wrapper def _inject_pagination( func: Callable, - paginator_class: Type[Union[PaginationBase, AsyncPaginationBase]], + paginator_class: Type[PaginationBase], **paginator_params: Any, ) -> Callable: paginator = paginator_class(**paginator_params) - if is_async_callable(func): - if not hasattr(paginator, "apaginate_queryset"): - raise ConfigError("Pagination class not configured for async requests") - - @wraps(func) - async def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any: - pagination_params = kwargs.pop("ninja_pagination") - if paginator.pass_parameter: - kwargs[paginator.pass_parameter] = pagination_params - - items = await func(request, **kwargs) - - result = await paginator.apaginate_queryset( - items, pagination=pagination_params, request=request, **kwargs - ) - - async def evaluate(results: Union[List, QuerySet]) -> AsyncGenerator: - for result in results: - yield result - - if paginator.Output: # type: ignore - result[paginator.items_attribute] = [ - result - async for result in evaluate(result[paginator.items_attribute]) - ] - return result - - else: - - @wraps(func) - def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any: - pagination_params = kwargs.pop("ninja_pagination") - if paginator.pass_parameter: - kwargs[paginator.pass_parameter] = pagination_params - - items = func(request, **kwargs) - - result = paginator.paginate_queryset( - items, pagination=pagination_params, request=request, **kwargs - ) - if paginator.Output: # type: ignore - result[paginator.items_attribute] = list( - result[paginator.items_attribute] - ) - # ^ forcing queryset evaluation #TODO: check why pydantic did not do it here - return result + + @wraps(func) + def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any: + pagination_params = kwargs.pop("ninja_pagination") + if paginator.pass_parameter: + kwargs[paginator.pass_parameter] = pagination_params + + items = func(request, **kwargs) + + result = paginator.paginate_queryset( + items, pagination=pagination_params, request=request, **kwargs + ) + if paginator.Output: # type: ignore + result[paginator.items_attribute] = list(result[paginator.items_attribute]) + # ^ forcing queryset evaluation #TODO: check why pydantic did not do it here + return result contribute_operation_args( view_with_pagination, @@ -310,7 +286,10 @@ def add_api_operation( ) -> None: response = kwargs["response"] if is_collection_type(response): - view_func = _inject_pagination(view_func, self.pagination_class) + if is_async_callable(view_func): + view_func = _inject_async_pagination(view_func, self.pagination_class) + else: + view_func = _inject_pagination(view_func, self.pagination_class) return super().add_api_operation(path, methods, view_func, **kwargs) diff --git a/tests/test_async_paginate.py b/tests/test_async_paginate.py new file mode 100644 index 00000000..b61f0648 --- /dev/null +++ b/tests/test_async_paginate.py @@ -0,0 +1,42 @@ +from typing import List + +import pytest + +from ninja import NinjaAPI, Schema +from ninja.pagination import paginate +from ninja.testing import TestAsyncClient + +from someapp.models import Event + +api = NinjaAPI() + + +class DummySchema(Schema): + id: int + name: str + + +@api.get("/async_view_return_queryset/", response=List[DummySchema]) +@paginate +async def async_view_return_queryset(request, **kwargs) -> None: + return Event.objects.all() + + +@api.get("/async_view_return_list/", response=List[DummySchema]) +@paginate +async def async_view_return_list(request, **kwargs) -> None: + return [] + + +@pytest.mark.asyncio +@pytest.mark.django_db +async def test_success__async_paginated_async_view_return_queryset() -> None: + client = TestAsyncClient(api) + await client.get("/async_view_return_queryset/") # not raising any exception + + +@pytest.mark.asyncio +@pytest.mark.django_db +async def test_success__async_paginated_async_view_return_list() -> None: + client = TestAsyncClient(api) + await client.get("/async_view_return_list/") # not raising any exception