Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SynchronousOnlyOperation error if @paginate with async view returns Django queryset #1293

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 79 additions & 50 deletions ninja/pagination.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import AsyncIterable, Iterable
import inspect
from abc import ABC, abstractmethod
from functools import partial, wraps
Expand Down Expand Up @@ -170,69 +171,45 @@ def my_view(request):
)

if isfunction:
return _inject_pagination(func_or_pgn_class, pagination_class)
if is_async_callable(func_or_pgn_class):
return _inject_async_pagination(func_or_pgn_class, pagination_class)
else:
return _inject_pagination(func_or_pgn_class, pagination_class)

if not isnotset:
pagination_class = func_or_pgn_class

def wrapper(func: Callable) -> Any:
return _inject_pagination(func, pagination_class, **paginator_params)
if is_async_callable(func):
return _inject_async_pagination(func, pagination_class, **paginator_params)
else:
return _inject_pagination(func, pagination_class, **paginator_params)

return wrapper


def _inject_pagination(
func: Callable,
paginator_class: Type[Union[PaginationBase, AsyncPaginationBase]],
paginator_class: Type[PaginationBase],
**paginator_params: Any,
) -> Callable:
paginator = paginator_class(**paginator_params)
if is_async_callable(func):
if not hasattr(paginator, "apaginate_queryset"):
raise ConfigError("Pagination class not configured for async requests")

@wraps(func)
async def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
pagination_params = kwargs.pop("ninja_pagination")
if paginator.pass_parameter:
kwargs[paginator.pass_parameter] = pagination_params

items = await func(request, **kwargs)

result = await paginator.apaginate_queryset(
items, pagination=pagination_params, request=request, **kwargs
)

async def evaluate(results: Union[List, QuerySet]) -> AsyncGenerator:
for result in results:
yield result

if paginator.Output: # type: ignore
result[paginator.items_attribute] = [
result
async for result in evaluate(result[paginator.items_attribute])
]
return result

else:

@wraps(func)
def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
pagination_params = kwargs.pop("ninja_pagination")
if paginator.pass_parameter:
kwargs[paginator.pass_parameter] = pagination_params

items = func(request, **kwargs)

result = paginator.paginate_queryset(
items, pagination=pagination_params, request=request, **kwargs
)
if paginator.Output: # type: ignore
result[paginator.items_attribute] = list(
result[paginator.items_attribute]
)
# ^ forcing queryset evaluation #TODO: check why pydantic did not do it here
return result

@wraps(func)
def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
pagination_params = kwargs.pop("ninja_pagination")
if paginator.pass_parameter:
kwargs[paginator.pass_parameter] = pagination_params

items = func(request, **kwargs)

result = paginator.paginate_queryset(
items, pagination=pagination_params, request=request, **kwargs
)
if paginator.Output: # type: ignore
result[paginator.items_attribute] = list(result[paginator.items_attribute])
# ^ forcing queryset evaluation #TODO: check why pydantic did not do it here
return result

contribute_operation_args(
view_with_pagination,
Expand All @@ -250,6 +227,55 @@ def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
return view_with_pagination


def _inject_async_pagination(
func: Callable,
paginator_class: Type[AsyncPaginationBase],
**paginator_params: Any,
) -> Callable:
paginator = paginator_class(**paginator_params)
if not hasattr(paginator, "apaginate_queryset"):
raise ConfigError("Pagination class not configured for async requests")

@wraps(func)
async def paginated_view(request: HttpRequest, **kwargs: Any) -> Any:
pagination_params = kwargs.pop("ninja_pagination")
if paginator.pass_parameter:
kwargs[paginator.pass_parameter] = pagination_params

result = await func(request, **kwargs)
paginated_result = await paginator.apaginate_queryset(
result, pagination=pagination_params, request=request, **kwargs
)

if paginator.Output: # type: ignore
items = paginated_result[paginator.items_attribute]
if isinstance(items, QuerySet) or isinstance(items, AsyncIterable):
new_items = [r async for r in items]
elif isinstance(items, Iterable):
new_items = list(items)
else:
raise TypeError("Unexpected type")

paginated_result[paginator.items_attribute] = new_items

return paginated_result

contribute_operation_args(
paginated_view,
"ninja_pagination",
paginator.Input,
paginator.InputSource,
)

if paginator.Output: # type: ignore
contribute_operation_callback(
paginated_view,
partial(make_response_paginated, paginator),
)

return paginated_view


class RouterPaginated(Router):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
Expand All @@ -260,7 +286,10 @@ def add_api_operation(
) -> None:
response = kwargs["response"]
if is_collection_type(response):
view_func = _inject_pagination(view_func, self.pagination_class)
if is_async_callable(view_func):
view_func = _inject_async_pagination(view_func, self.pagination_class)
else:
view_func = _inject_pagination(view_func, self.pagination_class)
return super().add_api_operation(path, methods, view_func, **kwargs)


Expand Down
42 changes: 42 additions & 0 deletions tests/test_async_paginate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import List

import pytest

from ninja import NinjaAPI, Schema
from ninja.pagination import paginate
from ninja.testing import TestAsyncClient

from someapp.models import Event

api = NinjaAPI()


class DummySchema(Schema):
id: int
name: str


@api.get("/async_view_return_queryset/", response=List[DummySchema])
@paginate
async def async_view_return_queryset(request, **kwargs) -> None:
return Event.objects.all()


@api.get("/async_view_return_list/", response=List[DummySchema])
@paginate
async def async_view_return_list(request, **kwargs) -> None:
return []


@pytest.mark.asyncio
@pytest.mark.django_db
async def test_success__async_paginated_async_view_return_queryset() -> None:
client = TestAsyncClient(api)
await client.get("/async_view_return_queryset/") # not raising any exception


@pytest.mark.asyncio
@pytest.mark.django_db
async def test_success__async_paginated_async_view_return_list() -> None:
client = TestAsyncClient(api)
await client.get("/async_view_return_list/") # not raising any exception
Loading