|
1 | 1 | from inspect import iscoroutinefunction
|
| 2 | +from typing import Awaitable |
2 | 3 | from typing import Optional, cast
|
3 | 4 |
|
4 | 5 | from graphql import GraphQLNamedType
|
5 | 6 | from graphql.pyutils import is_awaitable
|
6 | 7 | from graphql.type import GraphQLSchema
|
7 | 8 |
|
8 | 9 | from ariadne import InterfaceType, ObjectType
|
| 10 | +from ariadne.contrib.relay import RelayConnection |
9 | 11 | from ariadne.contrib.relay.arguments import (
|
10 | 12 | ConnectionArguments,
|
11 | 13 | ConnectionArgumentsTypeUnion,
|
@@ -36,11 +38,11 @@ def wrapper(obj, info, *args, **kwargs):
|
36 | 38 | if is_async_callable(resolver):
|
37 | 39 |
|
38 | 40 | async def async_my_extension():
|
39 |
| - relay_connection = await resolver( |
| 41 | + relay_connection = resolver( |
40 | 42 | obj, info, connection_arguments, *args, **kwargs
|
41 | 43 | )
|
42 | 44 | if is_awaitable(relay_connection):
|
43 |
| - relay_connection = await relay_connection |
| 45 | + relay_connection = await cast(Awaitable[RelayConnection], relay_connection) |
44 | 46 | return {
|
45 | 47 | "totalCount": relay_connection.total,
|
46 | 48 | "edges": relay_connection.get_edges(),
|
|
0 commit comments