1
+ import enum
1
2
import warnings
2
3
from functools import partial
3
4
4
- import six
5
5
from promise import Promise , is_thenable
6
6
from sqlalchemy .orm .query import Query
7
7
8
8
from graphene import NonNull
9
9
from graphene .relay import Connection , ConnectionField
10
- from graphene .relay .connection import PageInfo
11
- from graphql_relay .connection .arrayconnection import connection_from_list_slice
10
+ from graphene .relay .connection import connection_adapter , page_info_adapter
11
+ from graphql_relay .connection .arrayconnection import \
12
+ connection_from_array_slice
12
13
13
14
from .batching import get_batch_resolver
14
- from .utils import get_query
15
+ from .utils import EnumValue , get_query
15
16
16
17
17
18
class UnsortedSQLAlchemyConnectionField (ConnectionField ):
18
19
@property
19
20
def type (self ):
20
21
from .types import SQLAlchemyObjectType
21
22
22
- _type = super (ConnectionField , self ).type
23
- nullable_type = get_nullable_type (_type )
23
+ type_ = super (ConnectionField , self ).type
24
+ nullable_type = get_nullable_type (type_ )
24
25
if issubclass (nullable_type , Connection ):
25
- return _type
26
+ return type_
26
27
assert issubclass (nullable_type , SQLAlchemyObjectType ), (
27
28
"SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}"
28
29
).format (nullable_type .__name__ )
@@ -31,7 +32,7 @@ def type(self):
31
32
), "The type {} doesn't have a connection" .format (
32
33
nullable_type .__name__
33
34
)
34
- assert _type == nullable_type , (
35
+ assert type_ == nullable_type , (
35
36
"Passing a SQLAlchemyObjectType instance is deprecated. "
36
37
"Pass the connection type instead accessible via SQLAlchemyObjectType.connection"
37
38
)
@@ -53,15 +54,19 @@ def resolve_connection(cls, connection_type, model, info, args, resolved):
53
54
_len = resolved .count ()
54
55
else :
55
56
_len = len (resolved )
56
- connection = connection_from_list_slice (
57
- resolved ,
58
- args ,
57
+
58
+ def adjusted_connection_adapter (edges , pageInfo ):
59
+ return connection_adapter (connection_type , edges , pageInfo )
60
+
61
+ connection = connection_from_array_slice (
62
+ array_slice = resolved ,
63
+ args = args ,
59
64
slice_start = 0 ,
60
- list_length = _len ,
61
- list_slice_length = _len ,
62
- connection_type = connection_type ,
63
- pageinfo_type = PageInfo ,
65
+ array_length = _len ,
66
+ array_slice_length = _len ,
67
+ connection_type = adjusted_connection_adapter ,
64
68
edge_type = connection_type .Edge ,
69
+ page_info_type = page_info_adapter ,
65
70
)
66
71
connection .iterable = resolved
67
72
connection .length = _len
@@ -77,7 +82,7 @@ def connection_resolver(cls, resolver, connection_type, model, root, info, **arg
77
82
78
83
return on_resolve (resolved )
79
84
80
- def get_resolver (self , parent_resolver ):
85
+ def wrap_resolve (self , parent_resolver ):
81
86
return partial (
82
87
self .connection_resolver ,
83
88
parent_resolver ,
@@ -88,8 +93,8 @@ def get_resolver(self, parent_resolver):
88
93
89
94
# TODO Rename this to SortableSQLAlchemyConnectionField
90
95
class SQLAlchemyConnectionField (UnsortedSQLAlchemyConnectionField ):
91
- def __init__ (self , type , * args , ** kwargs ):
92
- nullable_type = get_nullable_type (type )
96
+ def __init__ (self , type_ , * args , ** kwargs ):
97
+ nullable_type = get_nullable_type (type_ )
93
98
if "sort" not in kwargs and issubclass (nullable_type , Connection ):
94
99
# Let super class raise if type is not a Connection
95
100
try :
@@ -103,16 +108,25 @@ def __init__(self, type, *args, **kwargs):
103
108
)
104
109
elif "sort" in kwargs and kwargs ["sort" ] is None :
105
110
del kwargs ["sort" ]
106
- super (SQLAlchemyConnectionField , self ).__init__ (type , * args , ** kwargs )
111
+ super (SQLAlchemyConnectionField , self ).__init__ (type_ , * args , ** kwargs )
107
112
108
113
@classmethod
109
114
def get_query (cls , model , info , sort = None , ** args ):
110
115
query = get_query (model , info .context )
111
116
if sort is not None :
112
- if isinstance (sort , six .string_types ):
113
- query = query .order_by (sort .value )
114
- else :
115
- query = query .order_by (* (col .value for col in sort ))
117
+ if not isinstance (sort , list ):
118
+ sort = [sort ]
119
+ sort_args = []
120
+ # ensure consistent handling of graphene Enums, enum values and
121
+ # plain strings
122
+ for item in sort :
123
+ if isinstance (item , enum .Enum ):
124
+ sort_args .append (item .value .value )
125
+ elif isinstance (item , EnumValue ):
126
+ sort_args .append (item .value )
127
+ else :
128
+ sort_args .append (item )
129
+ query = query .order_by (* sort_args )
116
130
return query
117
131
118
132
@@ -123,7 +137,7 @@ class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
123
137
Use at your own risk.
124
138
"""
125
139
126
- def get_resolver (self , parent_resolver ):
140
+ def wrap_resolve (self , parent_resolver ):
127
141
return partial (
128
142
self .connection_resolver ,
129
143
self .resolver ,
@@ -148,13 +162,13 @@ def default_connection_field_factory(relationship, registry, **field_kwargs):
148
162
__connectionFactory = UnsortedSQLAlchemyConnectionField
149
163
150
164
151
- def createConnectionField (_type , ** field_kwargs ):
165
+ def createConnectionField (type_ , ** field_kwargs ):
152
166
warnings .warn (
153
167
'createConnectionField is deprecated and will be removed in the next '
154
168
'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' ,
155
169
DeprecationWarning ,
156
170
)
157
- return __connectionFactory (_type , ** field_kwargs )
171
+ return __connectionFactory (type_ , ** field_kwargs )
158
172
159
173
160
174
def registerConnectionFieldFactory (factoryMethod ):
0 commit comments