From 21765ab4b74c65cf985017c54847d57093340d9c Mon Sep 17 00:00:00 2001 From: Michael Chase <3686226+reallistic@users.noreply.github.com> Date: Mon, 7 Apr 2025 22:46:19 -0400 Subject: [PATCH 1/2] check for async callable object and function for relay resolver --- ariadne/contrib/relay/objects.py | 4 ++-- ariadne/utils.py | 7 +++++++ tests/relay/test_objects.py | 25 +++++++++++++++++++++++++ 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/ariadne/contrib/relay/objects.py b/ariadne/contrib/relay/objects.py index 67883050..83d3547d 100644 --- a/ariadne/contrib/relay/objects.py +++ b/ariadne/contrib/relay/objects.py @@ -16,7 +16,7 @@ ) from ariadne.contrib.relay.utils import decode_global_id from ariadne.types import Resolver -from ariadne.utils import type_get_extension, type_set_extension +from ariadne.utils import is_async_callable, type_get_extension, type_set_extension class RelayObjectType(ObjectType): @@ -33,7 +33,7 @@ def __init__( def resolve_wrapper(self, resolver: ConnectionResolver): def wrapper(obj, info, *args, **kwargs): connection_arguments = self.connection_arguments_class(**kwargs) - if iscoroutinefunction(resolver): + if is_async_callable(resolver): async def async_my_extension(): relay_connection = await resolver( diff --git a/ariadne/utils.py b/ariadne/utils.py index 9034b75e..e2cc5d78 100644 --- a/ariadne/utils.py +++ b/ariadne/utils.py @@ -1,4 +1,5 @@ import asyncio +import inspect from collections.abc import Mapping from functools import wraps from typing import Any, Callable, Optional, Union, cast @@ -264,3 +265,9 @@ def type_get_extension( object_type: GraphQLNamedType, extension_name: str, fallback: Any = None ) -> Any: return getattr(object_type, "extensions", {}).get(extension_name, fallback) + + +def is_async_callable(obj: Any) -> bool: + return inspect.iscoroutinefunction(obj) or ( + callable(obj) and inspect.iscoroutinefunction(obj.__call__) + ) diff --git a/tests/relay/test_objects.py b/tests/relay/test_objects.py index 9d0d1e78..215ca8ec 100644 --- a/tests/relay/test_objects.py +++ b/tests/relay/test_objects.py @@ -271,3 +271,28 @@ def test_relay_node_query_faction( assert result.errors is None assert result.data == {"node": {"bid": "RmFjdGlvbjoy", "name": "Galactic Empire"}} + + +@pytest.mark.asyncio +async def test_relay_object_resolve_wrapper_awaitable(friends_connection): + class Resolver: + async def __call__(self, *_, **__): + return friends_connection + + object_type = RelayObjectType("User") + wrapped_resolver = object_type.resolve_wrapper(Resolver()) + + result = await wrapped_resolver(None, None, first=10) + assert result == { + "totalCount": 2, + "edges": [ + {"node": {"id": "VXNlcjox", "name": "Alice"}, "cursor": "VXNlcjox"}, + {"node": {"id": "VXNlcjoy", "name": "Bob"}, "cursor": "VXNlcjoy"}, + ], + "pageInfo": { + "hasNextPage": False, + "hasPreviousPage": False, + "startCursor": "VXNlcjox", + "endCursor": "VXNlcjoy", + }, + } From 3b1d13ce6312c825152328ab8645d492b040fdd3 Mon Sep 17 00:00:00 2001 From: Michael Chase <3686226+reallistic@users.noreply.github.com> Date: Mon, 7 Apr 2025 23:04:47 -0400 Subject: [PATCH 2/2] fix relay typing --- ariadne/contrib/relay/objects.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ariadne/contrib/relay/objects.py b/ariadne/contrib/relay/objects.py index 83d3547d..cd2795b4 100644 --- a/ariadne/contrib/relay/objects.py +++ b/ariadne/contrib/relay/objects.py @@ -1,3 +1,4 @@ +from collections.abc import Awaitable from inspect import iscoroutinefunction from typing import Optional, cast @@ -6,6 +7,7 @@ from graphql.type import GraphQLSchema from ariadne import InterfaceType, ObjectType +from ariadne.contrib.relay import RelayConnection from ariadne.contrib.relay.arguments import ( ConnectionArguments, ConnectionArgumentsTypeUnion, @@ -36,11 +38,13 @@ def wrapper(obj, info, *args, **kwargs): if is_async_callable(resolver): async def async_my_extension(): - relay_connection = await resolver( + relay_connection = resolver( obj, info, connection_arguments, *args, **kwargs ) if is_awaitable(relay_connection): - relay_connection = await relay_connection + relay_connection = await cast( + Awaitable[RelayConnection], relay_connection + ) return { "totalCount": relay_connection.total, "edges": relay_connection.get_edges(),