diff --git a/docs/guide/relay.md b/docs/guide/relay.md index cd4438f2..f7a431e3 100644 --- a/docs/guide/relay.md +++ b/docs/guide/relay.md @@ -69,3 +69,36 @@ For more customization options, like changing the pagination algorithm, adding e to the `Connection`/`Edge` type, take a look at the [official strawberry relay integration](https://strawberry.rocks/docs/guides/relay) as those are properly explained there. + +## Cursor-based connections + +As an alternative to the default `ListConnection`, `DjangoCursorConnection` is also available. +It supports pagination through a Django `QuerySet` via "true" cursors. +`ListConnection` uses slicing to achieve pagination, which can negatively affect performance for huge datasets, +because large page numbers require a large `OFFSET` in SQL. +Instead, `DjangoCursorConnection` uses range queries such as `Q(due_date__gte=...)` for pagination. In combination +with an index, this makes for more efficient queries. + +`DjangoCursorConnection` requires a _strictly_ ordered `QuerySet`, that is, no two entries in the `QuerySet` +may be considered equal by its ordering. `order_by('due_date')`, for example, is not strictly ordered, because two +items could have the same due date. `DjangoCursorConnection` will automatically resolve such situations by +also ordering by the primary key. + +When the order of the connection is configurable by the user (for example via +[`@strawberry_django.order`](./ordering.md)), cursors created by `DjangoCursorConnection` will not be compatible +between different orderings. + +The drawback of cursor-based pagination is that users cannot jump to a particular page directly. Therefore +cursor-based pagination is better suited for special use cases like an infinitely scrollable list. + +Otherwise, `DjangoCursorConnection` behaves like the other connection classes: + +```python +@strawberry.type +class Query: + fruit: DjangoCursorConnection[FruitType] = strawberry_django.connection() + + @strawberry_django.connection(DjangoCursorConnection[FruitType]) + def fruit_with_custom_resolver(self) -> list[Fruit]: + return Fruit.objects.all() +``` diff --git a/strawberry_django/fields/field.py b/strawberry_django/fields/field.py index 91e41fc2..4e3fac74 100644 --- a/strawberry_django/fields/field.py +++ b/strawberry_django/fields/field.py @@ -323,8 +323,10 @@ def get_queryset(self, queryset, info, **kwargs): ) # If optimizer extension is enabled, optimize this queryset - ext = optimizer.optimizer.get() - if ext is not None: + if ( + not self.disable_optimization + and (ext := optimizer.optimizer.get()) is not None + ): queryset = ext.optimize(queryset, info=info) return queryset diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index cba1c712..5b86d290 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -41,7 +41,7 @@ from strawberry.relay.utils import SliceMetadata from strawberry.schema.schema import Schema from strawberry.schema.schema_converter import get_arguments -from strawberry.types import get_object_definition +from strawberry.types import get_object_definition, has_object_definition from strawberry.types.base import StrawberryContainer from strawberry.types.info import Info from strawberry.types.lazy_type import LazyType @@ -94,7 +94,6 @@ "optimize", ] -NESTED_PREFETCH_MARK = "_strawberry_nested_prefetch_optimized" _M = TypeVar("_M", bound=models.Model) _sentinel = object() @@ -496,6 +495,8 @@ def _optimize_prefetch_queryset( StrawberryDjangoField, ) + from .relay_cursor import
DjangoCursorConnection, apply_cursor_pagination + if ( not config or not config.enable_nested_relations_prefetch @@ -571,6 +572,17 @@ def _optimize_prefetch_queryset( limit=slice_metadata.end - slice_metadata.start, max_results=connection_extension.max_results, ) + elif connection_type is DjangoCursorConnection: + qs, _ = apply_cursor_pagination( + qs, + related_field_id=related_field_id, + info=Info(_raw_info=info, _field=field), + first=field_kwargs.get("first"), + last=field_kwargs.get("last"), + before=field_kwargs.get("before"), + after=field_kwargs.get("after"), + max_results=connection_extension.max_results, + ) else: mark_optimized = False @@ -1246,13 +1258,20 @@ def _get_model_hints_from_connection( if edge.name.value != "edges": continue - e_definition = get_object_definition(relay.Edge, strict=True) - e_type = e_definition.resolve_generic( - relay.Edge[cast("type[relay.Node]", n_type)], - ) + e_field = object_definition.get_field("edges") + if e_field is None: + break + + e_definition = e_field.type + while isinstance(e_definition, StrawberryContainer): + e_definition = e_definition.of_type + if has_object_definition(e_definition): + e_definition = get_object_definition(e_definition, strict=True) + assert isinstance(e_definition, StrawberryObjectDefinition) + e_gql_definition = _get_gql_definition( schema, - get_object_definition(e_type, strict=True), + e_definition, ) assert isinstance(e_gql_definition, (GraphQLObjectType, GraphQLInterfaceType)) e_info = _generate_selection_resolve_info( @@ -1451,20 +1470,17 @@ def optimize( def is_optimized(qs: QuerySet) -> bool: - return get_queryset_config(qs).optimized or is_optimized_by_prefetching(qs) + config = get_queryset_config(qs) + return config.optimized or config.optimized_by_prefetching def mark_optimized_by_prefetching(qs: QuerySet[_M]) -> QuerySet[_M]: - # This is a bit of a hack, but there is no easy way to mark a related manager - # as optimized at this phase, so we just add a mark to the queryset that - # we can check leater on using is_optimized_by_prefetching - return qs.annotate(**{ - NESTED_PREFETCH_MARK: models.Value(True), - }) + get_queryset_config(qs).optimized_by_prefetching = True + return qs def is_optimized_by_prefetching(qs: QuerySet) -> bool: - return NESTED_PREFETCH_MARK in qs.query.annotations + return get_queryset_config(qs).optimized_by_prefetching optimizer: contextvars.ContextVar[DjangoOptimizerExtension | None] = ( diff --git a/strawberry_django/queryset.py b/strawberry_django/queryset.py index d2457d47..14052ad5 100644 --- a/strawberry_django/queryset.py +++ b/strawberry_django/queryset.py @@ -9,6 +9,8 @@ if TYPE_CHECKING: from strawberry import Info + from strawberry_django.relay_cursor import OrderingDescriptor + _M = TypeVar("_M", bound=Model) CONFIG_KEY = "_strawberry_django_config" @@ -17,11 +19,16 @@ @dataclasses.dataclass class StrawberryDjangoQuerySetConfig: optimized: bool = False + optimized_by_prefetching: bool = False type_get_queryset_did_run: bool = False + ordering_descriptors: list[OrderingDescriptor] | None = None def get_queryset_config(queryset: QuerySet) -> StrawberryDjangoQuerySetConfig: - return getattr(queryset, CONFIG_KEY, None) or StrawberryDjangoQuerySetConfig() + config = getattr(queryset, CONFIG_KEY, None) + if config is None: + setattr(queryset, CONFIG_KEY, (config := StrawberryDjangoQuerySetConfig())) + return config def run_type_get_queryset( diff --git a/strawberry_django/relay_cursor.py b/strawberry_django/relay_cursor.py new file mode 100644 index 00000000..62d3f72b --- 
/dev/null +++ b/strawberry_django/relay_cursor.py @@ -0,0 +1,462 @@ +import json +from dataclasses import dataclass +from json import JSONDecodeError +from typing import Any, ClassVar, Optional, cast + +import strawberry +from asgiref.sync import sync_to_async +from django.core.exceptions import ValidationError +from django.db import DEFAULT_DB_ALIAS, models +from django.db.models import Expression, F, OrderBy, Q, QuerySet, Value, Window +from django.db.models.constants import LOOKUP_SEP +from django.db.models.expressions import Col +from django.db.models.functions import RowNumber +from django.db.models.sql.datastructures import BaseTable +from strawberry import Info, relay +from strawberry.relay import NodeType, PageInfo, from_base64, to_base64 +from strawberry.relay.types import NodeIterableType +from strawberry.relay.utils import should_resolve_list_connection_edges +from strawberry.types import get_object_definition +from strawberry.types.base import StrawberryContainer +from strawberry.utils.await_maybe import AwaitableOrValue +from strawberry.utils.inspect import in_async_context +from typing_extensions import Self + +from strawberry_django import django_resolver +from strawberry_django.optimizer import is_optimized_by_prefetching +from strawberry_django.pagination import apply_window_pagination, get_total_count +from strawberry_django.queryset import get_queryset_config + + +def _get_order_by(qs: QuerySet) -> list[OrderBy]: + return [ + expr + for expr, _ in qs.query.get_compiler( + using=qs._db or DEFAULT_DB_ALIAS # type: ignore + ).get_order_by() + ] + + +@dataclass +class OrderingDescriptor: + attname: str + order_by: OrderBy + # we have to assume everything is nullable by default + maybe_null: bool = True + + def get_comparator(self, value: Any, before: bool) -> Optional[Q]: + if value is None: + # 1. When nulls are first: + # 1.1 there is nothing before "null" + # 1.2 after "null" comes everything non-null + # 2. When nulls are last: + # 2.1 there is nothing after "null" + # 2.2 before "null" comes everything non-null + # => 1.1 and 2.1 require no checks + # => 1.2 and 2.2 require an "is not null" check + if bool(self.order_by.nulls_first) ^ before: + return Q((f"{self.attname}{LOOKUP_SEP}isnull", False)) + return None + lookup = "lt" if before ^ self.order_by.descending else "gt" + cmp = Q((f"{self.attname}{LOOKUP_SEP}{lookup}", value)) + + if self.maybe_null and bool(self.order_by.nulls_first) == before: + # if nulls are first, "before any value" can also mean "is null" + # if nulls are last, "after any value" can also mean "is null" + cmp |= Q((f"{self.attname}{LOOKUP_SEP}isnull", True)) + return cmp + + def get_eq(self, value) -> Q: + if value is None: + return Q((f"{self.attname}{LOOKUP_SEP}isnull", True)) + return Q((f"{self.attname}{LOOKUP_SEP}exact", value)) + + +def annotate_ordering_fields( + qs: QuerySet, +) -> tuple[QuerySet, list[OrderingDescriptor], list[OrderBy]]: + annotations = {} + descriptors = [] + new_defer = None + new_only = None + order_bys = _get_order_by(qs) + pk_in_order = False + for index, order_by in enumerate(order_bys): + if isinstance(order_by.expression, Col) and isinstance( + # Col.alias is missing from django-types + qs.query.alias_map[order_by.expression.alias], # type: ignore + BaseTable, + ): + field_name = order_by.expression.field.name + # if it's a field in the base table, just make sure it is not deferred (e.g. 
by the optimizer) + existing, defer = qs.query.deferred_loading + if defer and field_name in existing: + # Query is in "defer fields" mode and our field is being deferred + if new_defer is None: + new_defer = set(existing) + new_defer.discard(field_name) + elif not defer and field_name not in existing: + # Query is in "only these fields" mode and our field is not in the set of fields + if new_only is None: + new_only = set(existing) + new_only.add(field_name) + descriptors.append( + OrderingDescriptor( + order_by.expression.field.attname, + order_by, + maybe_null=order_by.expression.field.null, + ) + ) + if order_by.expression.field.primary_key: + pk_in_order = True + else: + dynamic_field = f"_strawberry_order_field_{index}" + annotations[dynamic_field] = order_by.expression + descriptors.append(OrderingDescriptor(dynamic_field, order_by)) + + if new_defer is not None: + # defer is additive, so clear it first + qs = qs.defer(None).defer(*new_defer) + elif new_only is not None: + # only overwrites + qs = qs.only(*new_only) + + if not pk_in_order: + # Ensure we always have a clearly defined order by ordering by pk if it isn't in the order already + # We cannot use QuerySet.order_by, because it validates the order expressions again, + # but we're operating on the OrderBy expressions which have already been resolved by the compiler + # In case the user has previously ordered by an aggregate like so: + # qs.annotate(_c=Count("foo")).order_by("_c") # noqa: ERA001 + # then the OrderBy we get here would trigger a ValidationError by QuerySet.order_by. + # But we only want to append to the existing order (and the existing order must be valid already) + # So this is safe. + pk_order = F("pk").resolve_expression(qs.query).asc() + order_bys.append(pk_order) + descriptors.append(OrderingDescriptor("pk", pk_order, maybe_null=False)) + qs = qs._chain() # type: ignore + qs.query.order_by += (pk_order,) + return qs.annotate(**annotations), descriptors, order_bys + + +def build_tuple_compare( + descriptors: list[OrderingDescriptor], + cursor_values: list[Optional[str]], + before: bool, +) -> Q: + current = None + for descriptor, field_value in zip(reversed(descriptors), reversed(cursor_values)): + if field_value is None: + value_expr = None + else: + output_field = descriptor.order_by.expression.output_field + value_expr = Value(field_value, output_field=output_field) + cmp = descriptor.get_comparator(value_expr, before) + if current is None: + current = cmp + else: + eq = descriptor.get_eq(value_expr) + current = cmp | (eq & current) if cmp is not None else eq & current + return current if current is not None else Q() + + +class AttrHelper: + pass + + +def _extract_expression_value( + model: models.Model, expr: Expression, attname: str +) -> Optional[str]: + output_field = expr.output_field + # Unfortunately Field.value_to_string operates on the object, not a direct value + # So we have to potentially construct a fake object + # If the output field's attname doesn't match, we have to construct a fake object + # Additionally, the output field may not have an attname at all + # if expressions are used + field_attname = getattr(output_field, "attname", None) + if not field_attname: + # If the field doesn't have an attname, it's a dynamically constructed field, + # for the purposes of output_field in an expression. 
Just set its attname, it doesn't hurt anything + output_field.attname = field_attname = attname + obj: Any + if field_attname != attname: + obj = AttrHelper() + setattr(obj, output_field.attname, getattr(model, attname)) + else: + obj = model + value = output_field.value_from_object(obj) + if value is None: + return None + # value_to_string is missing from django-types + return output_field.value_to_string(obj) # type: ignore + + +def apply_cursor_pagination( + qs: QuerySet, + *, + related_field_id: Optional[str] = None, + info: Info, + before: Optional[str], + after: Optional[str], + first: Optional[int], + last: Optional[int], + max_results: Optional[int], +) -> tuple[QuerySet, list[OrderingDescriptor]]: + max_results = ( + max_results if max_results is not None else info.schema.config.relay_max_results + ) + + qs, ordering_descriptors, original_order_by = annotate_ordering_fields(qs) + if after: + after_cursor = OrderedCollectionCursor.from_cursor(after, ordering_descriptors) + qs = qs.filter( + build_tuple_compare(ordering_descriptors, after_cursor.field_values, False) + ) + if before: + before_cursor = OrderedCollectionCursor.from_cursor( + before, ordering_descriptors + ) + qs = qs.filter( + build_tuple_compare(ordering_descriptors, before_cursor.field_values, True) + ) + + slice_: Optional[slice] = None + if first is not None and last is not None: + if last > max_results: + raise ValueError(f"Argument 'last' cannot be higher than {max_results}.") + # if first and last are given, we have to + # - reverse the order in the DB so we can use slicing to apply [:last], + # otherwise we would have to know the total count to apply slicing from the end + # - We still need to apply forward-direction [:first] slicing, and according to the Relay spec, + # it shall happen before [:last] slicing. To do this, we use a window function with a RowNumber ordered + # in the original direction, which is opposite the actual query order. 
+ # This query is likely not very efficient, but using last _and_ first together is discouraged by the + # spec anyway + qs = ( + qs.reverse() + .annotate( + _strawberry_row_number_fwd=Window( + RowNumber(), + order_by=original_order_by, + ), + ) + .filter( + _strawberry_row_number_fwd__lte=first + 1, + ) + ) + # we're overfetching by two, in both directions + slice_ = slice(last + 2) + elif first is not None: + if first < 0: + raise ValueError("Argument 'first' must be a non-negative integer.") + if first > max_results: + raise ValueError(f"Argument 'first' cannot be higher than {max_results}.") + slice_ = slice(first + 1) + elif last is not None: + # when using last, optimize by reversing the QuerySet ordering in the DB, + # then slicing from the end (which is now the start in QuerySet ordering) + # and then iterating the results in reverse to restore the original order + if last < 0: + raise ValueError("Argument 'last' must be a non-negative integer.") + if last > max_results: + raise ValueError(f"Argument 'last' cannot be higher than {max_results}.") + slice_ = slice(last + 1) + qs = qs.reverse() + if related_field_id is not None: + # we always apply window pagination for nested connections, + # because we want its total count annotation + offset = slice_.start or 0 if slice_ is not None else 0 + qs = apply_window_pagination( + qs, + related_field_id=related_field_id, + offset=offset, + limit=slice_.stop - offset if slice_ is not None else None, + ) + elif slice_ is not None: + qs = qs[slice_] + + get_queryset_config(qs).ordering_descriptors = ordering_descriptors + + return qs, ordering_descriptors + + +@dataclass +class OrderedCollectionCursor: + field_values: list[Any] + + @classmethod + def from_model( + cls, model: models.Model, descriptors: list[OrderingDescriptor] + ) -> Self: + values = [ + _extract_expression_value( + model, descriptor.order_by.expression, descriptor.attname + ) + for descriptor in descriptors + ] + return cls(field_values=values) + + @classmethod + def from_cursor(cls, cursor: str, descriptors: list[OrderingDescriptor]) -> Self: + type_, values_json = from_base64(cursor) + if type_ != DjangoCursorEdge.PREFIX: + raise ValueError("Invalid cursor") + try: + string_values = json.loads(values_json) + except JSONDecodeError as e: + raise ValueError("Invalid cursor") from e + if ( + not isinstance(string_values, list) + or len(string_values) != len(descriptors) + or any(not (v is None or isinstance(v, str)) for v in string_values) + ): + raise ValueError("Invalid cursor") + + try: + decoded_values = [ + d.order_by.expression.output_field.to_python(v) + for d, v in zip(descriptors, string_values) + ] + except ValidationError as e: + raise ValueError("Invalid cursor") from e + + return cls(decoded_values) + + def to_cursor(self) -> str: + return to_base64( + DjangoCursorEdge.PREFIX, + json.dumps(self.field_values, separators=(",", ":")), + ) + + +@strawberry.type(name="CursorEdge", description="An edge in a connection.") +class DjangoCursorEdge(relay.Edge[relay.NodeType]): + PREFIX: ClassVar[str] = "orderedcursor" + + @classmethod + def resolve_edge( + cls, node: NodeType, *, cursor: Optional[OrderedCollectionCursor] = None + ) -> Self: + assert cursor is not None + return cls(cursor=cursor.to_cursor(), node=node) + + +@strawberry.type( + name="CursorConnection", description="A connection to a list of items." 
+) +class DjangoCursorConnection(relay.Connection[relay.NodeType]): + total_count_qs: strawberry.Private[Optional[QuerySet]] = None + + @strawberry.field(description="Total quantity of existing nodes.") + @django_resolver + def total_count(self) -> int: + assert self.total_count_qs is not None + + return get_total_count(self.total_count_qs) + + # TODO: Django CursorEdge should not exist, but relay.Edge has a hardcoded prefix currently + edges: list[DjangoCursorEdge[NodeType]] = strawberry.field( # type: ignore + description="Contains the nodes in this connection" + ) + + @classmethod + def resolve_connection( + cls, + nodes: NodeIterableType[NodeType], + *, + info: Info, + before: Optional[str] = None, + after: Optional[str] = None, + first: Optional[int] = None, + last: Optional[int] = None, + max_results: Optional[int] = None, + **kwargs: Any, + ) -> AwaitableOrValue[Self]: + if not isinstance(nodes, QuerySet): + raise TypeError("DjangoCursorConnection requires a QuerySet") + total_count_qs: QuerySet = nodes + qs: QuerySet + if not is_optimized_by_prefetching(nodes): + qs, ordering_descriptors = apply_cursor_pagination( + nodes, + info=info, + before=before, + after=after, + first=first, + last=last, + max_results=max_results, + ) + else: + qs = nodes + ordering_descriptors = get_queryset_config(qs).ordering_descriptors + assert ordering_descriptors is not None + + type_def = get_object_definition(cls) + assert type_def + field_def = type_def.get_field("edges") + assert field_def + + field = field_def.resolve_type(type_definition=type_def) + while isinstance(field, StrawberryContainer): + field = field.of_type + + edge_class = cast("DjangoCursorEdge[NodeType]", field) + + if not should_resolve_list_connection_edges(info): + return cls( + edges=[], + total_count_qs=total_count_qs, + page_info=PageInfo( + start_cursor=None, + end_cursor=None, + has_previous_page=False, + has_next_page=False, + ), + ) + + def finish_resolving(): + nonlocal qs + has_previous_page = has_next_page = False + + results = list(qs) + + if first is not None: + if last is None: + has_next_page = len(results) > first + results = results[:first] + # we're paginating forwards _and_ backwards + # remove the (potentially) overfetched row in the forwards direction first + elif ( + results + and getattr(results[0], "_strawberry_row_number_fwd", 0) > first + ): + has_next_page = True + results = results[1:] + + if last is not None: + has_previous_page = len(results) > last + results = results[:last] + + it = reversed(results) if last is not None else results + + edges = [ + edge_class.resolve_edge( + cls.resolve_node(v, info=info, **kwargs), + cursor=OrderedCollectionCursor.from_model(v, ordering_descriptors), + ) + for v in it + ] + + return cls( + edges=edges, + total_count_qs=total_count_qs, + page_info=PageInfo( + start_cursor=edges[0].cursor if edges else None, + end_cursor=edges[-1].cursor if edges else None, + has_previous_page=has_previous_page, + has_next_page=has_next_page, + ), + ) + + if in_async_context() and qs._result_cache is None: # type: ignore + return sync_to_async(finish_resolving)() + return finish_resolving() diff --git a/tests/relay/test_cursor_pagination.py b/tests/relay/test_cursor_pagination.py new file mode 100644 index 00000000..c4fb548b --- /dev/null +++ b/tests/relay/test_cursor_pagination.py @@ -0,0 +1,1416 @@ +import datetime +from typing import Optional, cast + +import pytest +import strawberry +from django.db.models import F, OrderBy, QuerySet, Value +from django.db.models.aggregates 
import Count +from pytest_mock import MockFixture +from strawberry.relay import GlobalID, Node, to_base64 + +import strawberry_django +from strawberry_django.optimizer import DjangoOptimizerExtension +from strawberry_django.relay_cursor import ( + DjangoCursorConnection, + DjangoCursorEdge, +) +from tests.projects.models import Milestone, Project +from tests.utils import assert_num_queries + + +@strawberry_django.order(Project) +class ProjectOrder: + id: strawberry.auto + name: strawberry.auto + due_date: strawberry.auto + + @strawberry_django.order_field() + def milestone_count( + self, queryset: QuerySet, value: strawberry_django.Ordering, prefix: str + ) -> "tuple[QuerySet, list[OrderBy]]": + queryset = queryset.annotate(_milestone_count=Count(f"{prefix}milestone")) + return queryset, [value.resolve("_milestone_count")] + + +@strawberry_django.order(Milestone) +class MilestoneOrder: + due_date: strawberry.auto + project: ProjectOrder + + @strawberry_django.order_field() + def days_left( + self, queryset: QuerySet, value: strawberry_django.Ordering, prefix: str + ) -> "tuple[QuerySet, list[OrderBy]]": + queryset = queryset.alias( + _days_left=Value(datetime.date(2025, 12, 31)) - F(f"{prefix}due_date") + ) + return queryset, [value.resolve("_days_left")] + + +@strawberry_django.type(Milestone, order=MilestoneOrder) +class MilestoneType(Node): + due_date: strawberry.auto + project: "ProjectType" + + @classmethod + def get_queryset(cls, qs: QuerySet, info): + if not qs.ordered: + qs = qs.order_by("project__name", "pk") + return qs + + +@strawberry_django.type(Project, order=ProjectOrder) +class ProjectType(Node): + name: str + due_date: datetime.date + milestones: DjangoCursorConnection[MilestoneType] = strawberry_django.connection() + + @classmethod + def get_queryset(cls, qs: QuerySet, info): + if not qs.ordered: + qs = qs.order_by("name", "pk") + return qs + + +@strawberry.type() +class Query: + project: Optional[ProjectType] = strawberry_django.node() + projects: DjangoCursorConnection[ProjectType] = strawberry_django.connection() + milestones: DjangoCursorConnection[MilestoneType] = strawberry_django.connection() + + @strawberry_django.connection( + DjangoCursorConnection[ProjectType], disable_optimization=True + ) + @staticmethod + def deferred_projects() -> list[ProjectType]: + result = Project.objects.all().order_by("name").defer("name") + return cast("list[ProjectType]", result) + + @strawberry_django.connection(DjangoCursorConnection[ProjectType]) + @staticmethod + def projects_with_resolver() -> list[ProjectType]: + return cast("list[ProjectType]", Project.objects.all().order_by("-pk")) + + +schema = strawberry.Schema(query=Query, extensions=[DjangoOptimizerExtension()]) + + +@pytest.fixture +def test_objects(): + pa = Project.objects.create(id=1, name="Project A") + pc1 = Project.objects.create(id=2, name="Project C") + Project.objects.create(id=5, name="Project C") + pb = Project.objects.create(id=3, name="Project B") + Project.objects.create(id=6, name="Project D") + Project.objects.create(id=4, name="Project E") + + Milestone.objects.create(id=1, project=pb, due_date=datetime.date(2025, 6, 1)) + Milestone.objects.create(id=2, project=pb, due_date=datetime.date(2025, 6, 2)) + Milestone.objects.create(id=3, project=pc1, due_date=datetime.date(2025, 6, 1)) + Milestone.objects.create(id=4, project=pa, due_date=datetime.date(2025, 6, 5)) + + +@pytest.mark.django_db(transaction=True) +def test_cursor_pagination(test_objects): + query = """ + query TestQuery { + projects { + edges { 
+ cursor + node { id name } + } + } + } + """ + with assert_num_queries(1): + result = schema.execute_sync(query) + assert result.data == { + "projects": { + "edges": [ + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project A","1"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "1")), + "name": "Project A", + }, + }, + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project B","3"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "3")), + "name": "Project B", + }, + }, + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project C","2"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "2")), + "name": "Project C", + }, + }, + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project C","5"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "5")), + "name": "Project C", + }, + }, + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project D","6"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "6")), + "name": "Project D", + }, + }, + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project E","4"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "4")), + "name": "Project E", + }, + }, + ] + } + } + + +@pytest.mark.django_db(transaction=True) +def test_cursor_pagination_custom_resolver(test_objects): + query = """ + query TestQuery($after: String, $first: Int) { + projectsWithResolver(after: $after, first: $first) { + edges { + cursor + node { id name } + } + } + } + """ + with assert_num_queries(1): + result = schema.execute_sync( + query, + { + "after": to_base64(DjangoCursorEdge.PREFIX, '["6"]'), + "first": 2, + }, + ) + assert result.data == { + "projectsWithResolver": { + "edges": [ + { + "cursor": to_base64(DjangoCursorEdge.PREFIX, '["5"]'), + "node": { + "id": str(GlobalID("ProjectType", "5")), + "name": "Project C", + }, + }, + { + "cursor": to_base64(DjangoCursorEdge.PREFIX, '["4"]'), + "node": { + "id": str(GlobalID("ProjectType", "4")), + "name": "Project E", + }, + }, + ] + } + } + + +@pytest.mark.django_db(transaction=True) +def test_forward_pagination(test_objects): + query = """ + query TestQuery($first: Int, $after: String) { + projects(first: $first, after: $after) { + edges { + cursor + node { id name } + } + pageInfo { + startCursor + endCursor + hasNextPage + } + } + } + """ + with assert_num_queries(1): + result = schema.execute_sync( + query, + { + "first": 3, + "after": to_base64(DjangoCursorEdge.PREFIX, '["Project B","3"]'), + }, + ) + assert result.data == { + "projects": { + "pageInfo": { + "startCursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project C","2"]' + ), + "endCursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project D","6"]' + ), + "hasNextPage": True, + }, + "edges": [ + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project C","2"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "2")), + "name": "Project C", + }, + }, + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project C","5"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "5")), + "name": "Project C", + }, + }, + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project D","6"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "6")), + "name": "Project D", + }, + }, + ], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_forward_pagination_first_page(test_objects): + query = """ + query TestQuery($first: Int, $after: String) { + projects(first: $first, after: $after) { + edges { + cursor + node { id name } + } + pageInfo { + startCursor + endCursor + hasPreviousPage + 
hasNextPage + } + } + } + """ + with assert_num_queries(1): + result = schema.execute_sync( + query, + { + "first": 1, + "after": None, + }, + ) + assert result.data == { + "projects": { + "pageInfo": { + "startCursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project A","1"]' + ), + "endCursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project A","1"]' + ), + "hasPreviousPage": False, + "hasNextPage": True, + }, + "edges": [ + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project A","1"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "1")), + "name": "Project A", + }, + }, + ], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_forward_pagination_last_page(test_objects): + query = """ + query TestQuery($first: Int, $after: String) { + projects(first: $first, after: $after) { + edges { + cursor + node { id name } + } + pageInfo { + startCursor + endCursor + hasNextPage + } + } + } + """ + with assert_num_queries(1): + result = schema.execute_sync( + query, + { + "first": 10, + "after": to_base64(DjangoCursorEdge.PREFIX, '["Project D","6"]'), + }, + ) + assert result.data == { + "projects": { + "pageInfo": { + "startCursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project E","4"]' + ), + "endCursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project E","4"]' + ), + "hasNextPage": False, + }, + "edges": [ + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project E","4"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "4")), + "name": "Project E", + }, + }, + ], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_backward_pagination(test_objects): + query = """ + query TestQuery($last: Int, $before: String) { + projects(last: $last, before: $before) { + edges { + cursor + node { id name } + } + pageInfo { + startCursor + endCursor + hasPreviousPage + } + } + } + """ + with assert_num_queries(1): + result = schema.execute_sync( + query, + { + "last": 2, + "before": to_base64(DjangoCursorEdge.PREFIX, '["Project C","5"]'), + }, + ) + assert result.data == { + "projects": { + "pageInfo": { + "startCursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project B","3"]' + ), + "endCursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project C","2"]' + ), + "hasPreviousPage": True, + }, + "edges": [ + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project B","3"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "3")), + "name": "Project B", + }, + }, + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project C","2"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "2")), + "name": "Project C", + }, + }, + ], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_backward_pagination_first_page(test_objects): + query = """ + query TestQuery($last: Int, $before: String) { + projects(last: $last, before: $before) { + edges { + cursor + node { id name } + } + pageInfo { + startCursor + endCursor + hasNextPage + hasPreviousPage + } + } + } + """ + with assert_num_queries(1): + result = schema.execute_sync( + query, + { + "last": 2, + "before": None, + }, + ) + assert result.data == { + "projects": { + "pageInfo": { + "startCursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project D","6"]' + ), + "endCursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project E","4"]' + ), + "hasPreviousPage": True, + "hasNextPage": False, + }, + "edges": [ + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project D","6"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "6")), + "name": "Project D", + }, + }, + { + "cursor": to_base64( 
+ DjangoCursorEdge.PREFIX, '["Project E","4"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "4")), + "name": "Project E", + }, + }, + ], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_backward_pagination_last_page(test_objects): + query = """ + query TestQuery($last: Int, $before: String) { + projects(last: $last, before: $before) { + edges { + cursor + node { id name } + } + pageInfo { + startCursor + endCursor + hasPreviousPage + } + } + } + """ + with assert_num_queries(1): + result = schema.execute_sync( + query, + { + "last": 2, + "before": to_base64(DjangoCursorEdge.PREFIX, '["Project C","2"]'), + }, + ) + assert result.data == { + "projects": { + "pageInfo": { + "startCursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project A","1"]' + ), + "endCursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project B","3"]' + ), + "hasPreviousPage": False, + }, + "edges": [ + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project A","1"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "1")), + "name": "Project A", + }, + }, + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project B","3"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "3")), + "name": "Project B", + }, + }, + ], + } + } + + +@pytest.mark.parametrize( + ("first", "last", "pks", "has_next", "has_previous"), + [ + (4, 2, [3, 4], True, True), + (6, 2, [5, 6], False, True), + (4, 4, [1, 2, 3, 4], True, False), + (6, 6, [1, 2, 3, 4, 5, 6], False, False), + (8, 4, [3, 4, 5, 6], False, True), + (4, 8, [1, 2, 3, 4], True, False), + ], +) +@pytest.mark.django_db(transaction=True) +def test_first_and_last_pagination( + first, last, pks, has_next, has_previous, test_objects +): + query = """ + query TestQuery($first: Int, $last: Int) { + projects(first: $first, last: $last, order: { id: ASC }) { + edges { + cursor + node { id } + } + pageInfo { + startCursor + endCursor + hasNextPage + hasPreviousPage + } + } + } + """ + with assert_num_queries(1): + result = schema.execute_sync( + query, + { + "first": first, + "last": last, + }, + ) + assert result.data == { + "projects": { + "pageInfo": { + "startCursor": to_base64(DjangoCursorEdge.PREFIX, f'["{pks[0]}"]'), + "endCursor": to_base64(DjangoCursorEdge.PREFIX, f'["{pks[-1]}"]'), + "hasPreviousPage": has_previous, + "hasNextPage": has_next, + }, + "edges": [ + { + "cursor": to_base64(DjangoCursorEdge.PREFIX, f'["{pk}"]'), + "node": { + "id": str(GlobalID("ProjectType", str(pk))), + }, + } + for pk in pks + ], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_empty_connection(): + query = """ + query TestQuery { + projects { + edges { + cursor + node { id name } + } + pageInfo { + startCursor + endCursor + hasNextPage + hasPreviousPage + } + } + } + """ + with assert_num_queries(1): + result = schema.execute_sync( + query, + ) + assert result.data == { + "projects": { + "pageInfo": { + "startCursor": None, + "endCursor": None, + "hasNextPage": False, + "hasPreviousPage": False, + }, + "edges": [], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_cursor_pagination_custom_order(test_objects): + query = """ + query TestQuery($first: Int, $after: String) { + projects(first: $first, after: $after, order: { name: DESC id: ASC }) { + edges { + cursor + node { id name } + } + } + } + """ + with assert_num_queries(1): + result = schema.execute_sync( + query, + { + "first": 2, + "after": to_base64(DjangoCursorEdge.PREFIX, '["Project E","4"]'), + }, + ) + assert result.data == { + "projects": { + "edges": [ + { + "cursor": 
to_base64( + DjangoCursorEdge.PREFIX, '["Project D","6"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "6")), + "name": "Project D", + }, + }, + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project C","2"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "2")), + "name": "Project C", + }, + }, + ] + } + } + + +@pytest.mark.django_db(transaction=True) +def test_cursor_pagination_joined_field_order(test_objects): + query = """ + query TestQuery { + milestones(order: { dueDate: DESC, project: { name: ASC } }) { + edges { + cursor + node { id dueDate project { id name } } + } + } + } + """ + with assert_num_queries(2): + result = schema.execute_sync(query) + assert result.data == { + "milestones": { + "edges": [ + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, + '["2025-06-05","Project A","4"]', + ), + "node": { + "id": str(GlobalID("MilestoneType", "4")), + "dueDate": "2025-06-05", + "project": { + "id": str(GlobalID("ProjectType", "1")), + "name": "Project A", + }, + }, + }, + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, + '["2025-06-02","Project B","2"]', + ), + "node": { + "id": str(GlobalID("MilestoneType", "2")), + "dueDate": "2025-06-02", + "project": { + "id": str(GlobalID("ProjectType", "3")), + "name": "Project B", + }, + }, + }, + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, + '["2025-06-01","Project B","1"]', + ), + "node": { + "id": str(GlobalID("MilestoneType", "1")), + "dueDate": "2025-06-01", + "project": { + "id": str(GlobalID("ProjectType", "3")), + "name": "Project B", + }, + }, + }, + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, + '["2025-06-01","Project C","3"]', + ), + "node": { + "id": str(GlobalID("MilestoneType", "3")), + "dueDate": "2025-06-01", + "project": { + "id": str(GlobalID("ProjectType", "2")), + "name": "Project C", + }, + }, + }, + ] + } + } + + +@pytest.mark.django_db(transaction=True) +def test_cursor_pagination_expression_order(test_objects): + query = """ + query TestQuery($after: String) { + milestones(after: $after, order: { daysLeft: ASC }) { + edges { + cursor + node { id } + } + } + } + """ + with assert_num_queries(1): + result = schema.execute_sync( + query, + {"after": to_base64(DjangoCursorEdge.PREFIX, '["209 00:00:00","4"]')}, + ) + assert result.data == { + "milestones": { + "edges": [ + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["212 00:00:00","2"]' + ), + "node": { + "id": str(GlobalID("MilestoneType", "2")), + }, + }, + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["213 00:00:00","1"]' + ), + "node": { + "id": str(GlobalID("MilestoneType", "1")), + }, + }, + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["213 00:00:00","3"]' + ), + "node": { + "id": str(GlobalID("MilestoneType", "3")), + }, + }, + ] + } + } + + +@pytest.mark.django_db(transaction=True) +def test_cursor_pagination_agg_expression_order(test_objects): + query = """ + query TestQuery($after: String, $first: Int) { + projects(after: $after, first: $first, order: { milestoneCount: DESC }) { + edges { + cursor + node { id } + } + } + } + """ + with assert_num_queries(1): + result = schema.execute_sync( + query, + { + "after": None, + "first": 4, + }, + ) + assert result.data == { + "projects": { + "edges": [ + { + "cursor": to_base64(DjangoCursorEdge.PREFIX, '["2","3"]'), + "node": { + "id": str(GlobalID("ProjectType", "3")), + }, + }, + { + "cursor": to_base64(DjangoCursorEdge.PREFIX, '["1","1"]'), + "node": { + "id": str(GlobalID("ProjectType", "1")), + }, + }, + { + "cursor": 
to_base64(DjangoCursorEdge.PREFIX, '["1","2"]'), + "node": { + "id": str(GlobalID("ProjectType", "2")), + }, + }, + { + "cursor": to_base64(DjangoCursorEdge.PREFIX, '["0","4"]'), + "node": { + "id": str(GlobalID("ProjectType", "4")), + }, + }, + ] + } + } + + +@pytest.mark.django_db(transaction=True) +def test_cursor_pagination_order_field_deferred(test_objects): + query = """ + query TestQuery { + deferredProjects(first: 2) { + edges { + cursor + node { id } + } + } + } + """ + with assert_num_queries(1): + result = schema.execute_sync(query) + assert result.data == { + "deferredProjects": { + "edges": [ + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project A","1"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "1")), + }, + }, + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project B","3"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "3")), + }, + }, + ] + } + } + + +@pytest.mark.django_db(transaction=True) +@pytest.mark.parametrize( + ("order", "pks"), + [ + ("DESC_NULLS_FIRST", [1, 4, 3, 2]), + ("DESC_NULLS_LAST", [3, 2, 1, 4]), + ("ASC_NULLS_FIRST", [1, 4, 2, 3]), + ("ASC_NULLS_LAST", [2, 3, 1, 4]), + ], +) +@pytest.mark.parametrize("offset", [0, 1, 2, 3]) +def test_cursor_pagination_order_with_nulls(order, pks, offset): + pa = Project.objects.create(id=1, name="Project A", due_date=None) + pc = Project.objects.create( + id=2, name="Project C", due_date=datetime.date(2025, 6, 2) + ) + pb = Project.objects.create( + id=3, name="Project B", due_date=datetime.date(2025, 6, 5) + ) + pd = Project.objects.create(id=4, name="Project D", due_date=None) + projects_lookup = {p.pk: p for p in (pa, pb, pc, pd)} + projects = [projects_lookup[pk] for pk in pks] + query = """ + query TestQuery($after: String, $first: Int, $order: Ordering!) 
{ + projects(after: $after, first: $first, order: { dueDate: $order }) { + edges { + cursor + node { id name } + } + } + } + """ + + def make_cursor(project: Project) -> str: + due_date_part = ( + f'"{project.due_date.isoformat()}"' if project.due_date else "null" + ) + return to_base64(DjangoCursorEdge.PREFIX, f'[{due_date_part},"{project.pk}"]') + + with assert_num_queries(1): + result = schema.execute_sync( + query, + { + "order": order, + "after": make_cursor(projects[offset]), + "first": 2, + }, + ) + assert result.data == { + "projects": { + "edges": [ + { + "cursor": make_cursor(project), + "node": { + "id": str(GlobalID("ProjectType", str(project.pk))), + "name": project.name, + }, + } + for project in projects[offset + 1 : offset + 3] + ] + } + } + + +@pytest.mark.django_db(transaction=True) +async def test_cursor_pagination_async(test_objects): + query = """ + query TestQuery { + projects { + edges { + cursor + node { id name } + } + } + } + """ + result = await schema.execute(query) + assert result.data == { + "projects": { + "edges": [ + { + "cursor": to_base64(DjangoCursorEdge.PREFIX, '["Project A","1"]'), + "node": { + "id": str(GlobalID("ProjectType", "1")), + "name": "Project A", + }, + }, + { + "cursor": to_base64(DjangoCursorEdge.PREFIX, '["Project B","3"]'), + "node": { + "id": str(GlobalID("ProjectType", "3")), + "name": "Project B", + }, + }, + { + "cursor": to_base64(DjangoCursorEdge.PREFIX, '["Project C","2"]'), + "node": { + "id": str(GlobalID("ProjectType", "2")), + "name": "Project C", + }, + }, + { + "cursor": to_base64(DjangoCursorEdge.PREFIX, '["Project C","5"]'), + "node": { + "id": str(GlobalID("ProjectType", "5")), + "name": "Project C", + }, + }, + { + "cursor": to_base64(DjangoCursorEdge.PREFIX, '["Project D","6"]'), + "node": { + "id": str(GlobalID("ProjectType", "6")), + "name": "Project D", + }, + }, + { + "cursor": to_base64(DjangoCursorEdge.PREFIX, '["Project E","4"]'), + "node": { + "id": str(GlobalID("ProjectType", "4")), + "name": "Project E", + }, + }, + ] + } + } + + +@pytest.mark.django_db(transaction=True) +def test_nested_cursor_pagination_in_single(): + pa = Project.objects.create(id=1, name="Project A") + pb = Project.objects.create(id=2, name="Project B") + + Milestone.objects.create(id=1, project=pb, due_date=datetime.date(2025, 6, 1)) + Milestone.objects.create(id=2, project=pb, due_date=datetime.date(2025, 6, 2)) + Milestone.objects.create(id=3, project=pb, due_date=datetime.date(2025, 6, 1)) + Milestone.objects.create(id=4, project=pa, due_date=datetime.date(2025, 6, 5)) + Milestone.objects.create(id=5, project=pa, due_date=datetime.date(2025, 6, 1)) + + query = """ + query TestQuery($id: GlobalID!) 
{ + project(id: $id) { + id + milestones(first: 2, order: { dueDate: ASC }) { + edges { + cursor + node { id dueDate } + } + } + } + } + """ + with assert_num_queries(2): + result = schema.execute_sync(query, {"id": str(GlobalID("ProjectType", "2"))}) + assert result.data == { + "project": { + "id": str(GlobalID("ProjectType", "2")), + "milestones": { + "edges": [ + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, + '["2025-06-01","1"]', + ), + "node": { + "id": str(GlobalID("MilestoneType", "1")), + "dueDate": "2025-06-01", + }, + }, + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, + '["2025-06-01","3"]', + ), + "node": { + "id": str(GlobalID("MilestoneType", "3")), + "dueDate": "2025-06-01", + }, + }, + ] + }, + }, + } + + +@pytest.mark.django_db(transaction=True) +def test_nested_cursor_pagination(): + pa = Project.objects.create(id=1, name="Project A") + pb = Project.objects.create(id=2, name="Project B") + + Milestone.objects.create(id=1, project=pb, due_date=datetime.date(2025, 6, 1)) + Milestone.objects.create(id=2, project=pb, due_date=datetime.date(2025, 6, 2)) + Milestone.objects.create(id=3, project=pb, due_date=datetime.date(2025, 6, 1)) + Milestone.objects.create(id=4, project=pa, due_date=datetime.date(2025, 6, 5)) + Milestone.objects.create(id=5, project=pa, due_date=datetime.date(2025, 6, 1)) + + query = """ + query TestQuery { + projects { + edges { + cursor + node { + id + milestones(first: 2, order: { dueDate: ASC }) { + pageInfo { hasNextPage } + edges { + cursor + node { id dueDate } + } + } + } + } + } + } + """ + with assert_num_queries(2): + result = schema.execute_sync(query) + assert result.data == { + "projects": { + "edges": [ + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project A","1"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "1")), + "milestones": { + "pageInfo": {"hasNextPage": False}, + "edges": [ + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, + '["2025-06-01","5"]', + ), + "node": { + "id": str(GlobalID("MilestoneType", "5")), + "dueDate": "2025-06-01", + }, + }, + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, + '["2025-06-05","4"]', + ), + "node": { + "id": str(GlobalID("MilestoneType", "4")), + "dueDate": "2025-06-05", + }, + }, + ], + }, + }, + }, + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, '["Project B","2"]' + ), + "node": { + "id": str(GlobalID("ProjectType", "2")), + "milestones": { + "pageInfo": {"hasNextPage": True}, + "edges": [ + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, + '["2025-06-01","1"]', + ), + "node": { + "id": str(GlobalID("MilestoneType", "1")), + "dueDate": "2025-06-01", + }, + }, + { + "cursor": to_base64( + DjangoCursorEdge.PREFIX, + '["2025-06-01","3"]', + ), + "node": { + "id": str(GlobalID("MilestoneType", "3")), + "dueDate": "2025-06-01", + }, + }, + ], + }, + }, + }, + ] + } + } + + +@pytest.mark.django_db(transaction=True) +@pytest.mark.parametrize("first", [None, 3]) +@pytest.mark.parametrize("after", [None, to_base64(DjangoCursorEdge.PREFIX, '["2"]')]) +@pytest.mark.parametrize("last", [None, 3]) +@pytest.mark.parametrize("before", [None, to_base64(DjangoCursorEdge.PREFIX, '["2"]')]) +def test_total_count_ignores_pagination(test_objects, first, after, before, last): + query = """ + query TestQuery($first: Int, $after: String, $last: Int, $before: String) { + projects(first: $first, after: $after, last: $last, before: $before, order: { id: ASC }) { + totalCount + } + } + """ + with assert_num_queries(1): + result = schema.execute_sync( + query, {"first": first, 
"after": after, "last": last, "before": before} + ) + assert result.data == {"projects": {"totalCount": 6}} + + +@pytest.mark.django_db(transaction=True) +def test_total_count_works_with_edges(test_objects): + query = """ + query TestQuery($first: Int, $after: String, $last: Int, $before: String) { + projects(first: $first, after: $after, last: $last, before: $before, order: { id: ASC }) { + totalCount + edges { + node { + id + } + } + } + } + """ + with assert_num_queries(2): + result = schema.execute_sync( + query, {"first": 3, "after": to_base64(DjangoCursorEdge.PREFIX, '["2"]')} + ) + assert result.data == { + "projects": { + "totalCount": 6, + "edges": [ + {"node": {"id": str(GlobalID("ProjectType", "3"))}}, + {"node": {"id": str(GlobalID("ProjectType", "4"))}}, + {"node": {"id": str(GlobalID("ProjectType", "5"))}}, + ], + } + } + + +@pytest.mark.django_db(transaction=True) +def test_nested_total_count(): + p1 = Project.objects.create() + p2 = Project.objects.create() + + p1m = [Milestone.objects.create(project=p1) for _ in range(3)] + p2m = [Milestone.objects.create(project=p2) for _ in range(2)] + + query = """ + query TestQuery { + projects(first: 2, order: { id: ASC }) { + edges { + node { + id + milestones { totalCount edges { node { id } } } + } + } + } + } + """ + with assert_num_queries(2): + result = schema.execute_sync(query) + assert result.data == { + "projects": { + "edges": [ + { + "node": { + "id": str(GlobalID("ProjectType", str(p1.pk))), + "milestones": { + "totalCount": 3, + "edges": [ + { + "node": { + "id": str( + GlobalID("MilestoneType", str(m.pk)) + ) + } + } + for m in p1m + ], + }, + } + }, + { + "node": { + "id": str(GlobalID("ProjectType", str(p2.pk))), + "milestones": { + "totalCount": 2, + "edges": [ + { + "node": { + "id": str( + GlobalID("MilestoneType", str(m.pk)) + ) + } + } + for m in p2m + ], + }, + } + }, + ], + } + } + + +@pytest.mark.django_db(transaction=True) +@pytest.mark.parametrize( + "cursor", + [ + *( + to_base64(DjangoCursorEdge.PREFIX, c) + for c in ("", "[]", "[1]", "{}", "foo", '["foo"]') + ), + to_base64("foo", "bar"), + to_base64("foo", '["1"]'), + ], +) +def test_invalid_cursor(cursor, test_objects): + query = """ + query TestQuery($after: String) { + projects(after: $after, order: { id: ASC }) { + edges { + cursor + node { + id + } + } + } + } + """ + result = schema.execute_sync(query, {"after": cursor}) + assert result.data is None + assert result.errors + assert result.errors[0].message == "Invalid cursor" + + +@pytest.mark.django_db(transaction=True) +@pytest.mark.parametrize( + ("first", "last", "error_message"), + [ + (-1, None, "Argument 'first' must be a non-negative integer."), + (None, -1, "Argument 'last' must be a non-negative integer."), + (150, None, "Argument 'first' cannot be higher than 100."), + (None, 150, "Argument 'last' cannot be higher than 100."), + (30, 150, "Argument 'last' cannot be higher than 100."), + ], +) +def test_invalid_offsets(first, last, error_message, test_objects): + query = """ + query TestQuery($first: Int, $last: Int) { + projects(first: $first, last: $last, order: { id: ASC }) { + edges { + cursor + node { + id + } + } + } + } + """ + result = schema.execute_sync(query, {"first": first, "last": last}) + assert result.data is None + assert result.errors + assert result.errors[0].message == error_message + + +@pytest.mark.django_db(transaction=True) +def test_cursor_connection_rejects_non_querysets(mocker: MockFixture): + with pytest.raises(TypeError): + 
DjangoCursorConnection.resolve_connection( + list(Project.objects.all()), info=mocker.Mock() + ) diff --git a/tests/test_queryset_config.py b/tests/test_queryset_config.py new file mode 100644 index 00000000..02657ac3 --- /dev/null +++ b/tests/test_queryset_config.py @@ -0,0 +1,45 @@ +import pytest +from django.db.models import Prefetch + +from strawberry_django.queryset import get_queryset_config +from tests.projects.models import Milestone, Project + + +def test_queryset_config_survives_filter(): + qs = Project.objects.all() + config = get_queryset_config(qs) + config.optimized = True + new_qs = qs.filter(pk=1) + assert get_queryset_config(new_qs).optimized is True + + +def test_queryset_config_survives_prefetch_related(): + qs = Project.objects.all() + config = get_queryset_config(qs) + config.optimized = True + new_qs = qs.prefetch_related("milestones") + assert get_queryset_config(new_qs).optimized is True + + +def test_queryset_config_survives_select_related(): + qs = Milestone.objects.all() + config = get_queryset_config(qs) + config.optimized = True + new_qs = qs.select_related("project") + assert get_queryset_config(new_qs).optimized is True + + +@pytest.mark.django_db(transaction=True) +def test_queryset_config_survives_in_prefetch_queryset(): + Project.objects.create() + qs = Milestone.objects.all() + config = get_queryset_config(qs) + config.optimized = True + + project = ( + Project.objects.all() + .prefetch_related(Prefetch("milestones", queryset=qs)) + .get() + ) + + assert get_queryset_config(project.milestones.all()).optimized is True