Skip to content

Commit d6dd67e

Browse files
richin13DoctorJohnzsiciarzcolelin26
authored
Graphene v3 (tests) (#317)
Co-authored-by: Jonathan Ehwald <[email protected]> Co-authored-by: Zbigniew Siciarz <[email protected]> Co-authored-by: Cole Lin <[email protected]>
1 parent cba727c commit d6dd67e

21 files changed

+196
-220
lines changed

.github/workflows/tests.yml

+25-25
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,31 @@ jobs:
88
strategy:
99
max-parallel: 10
1010
matrix:
11-
sql-alchemy: ["1.2", "1.3"]
11+
sql-alchemy: ["1.2", "1.3", "1.4"]
1212
python-version: ["3.6", "3.7", "3.8", "3.9"]
1313

1414
steps:
15-
- uses: actions/checkout@v2
16-
- name: Set up Python ${{ matrix.python-version }}
17-
uses: actions/setup-python@v2
18-
with:
19-
python-version: ${{ matrix.python-version }}
20-
- name: Install dependencies
21-
run: |
22-
python -m pip install --upgrade pip
23-
pip install tox tox-gh-actions
24-
- name: Test with tox
25-
run: tox
26-
env:
27-
SQLALCHEMY: ${{ matrix.sql-alchemy }}
28-
TOXENV: ${{ matrix.toxenv }}
29-
- name: Upload coverage.xml
30-
if: ${{ matrix.sql-alchemy == '1.3' && matrix.python-version == '3.9' }}
31-
uses: actions/upload-artifact@v2
32-
with:
33-
name: graphene-sqlalchemy-coverage
34-
path: coverage.xml
35-
if-no-files-found: error
36-
- name: Upload coverage.xml to codecov
37-
if: ${{ matrix.sql-alchemy == '1.3' && matrix.python-version == '3.9' }}
38-
uses: codecov/codecov-action@v1
15+
- uses: actions/checkout@v2
16+
- name: Set up Python ${{ matrix.python-version }}
17+
uses: actions/setup-python@v2
18+
with:
19+
python-version: ${{ matrix.python-version }}
20+
- name: Install dependencies
21+
run: |
22+
python -m pip install --upgrade pip
23+
pip install tox tox-gh-actions
24+
- name: Test with tox
25+
run: tox
26+
env:
27+
SQLALCHEMY: ${{ matrix.sql-alchemy }}
28+
TOXENV: ${{ matrix.toxenv }}
29+
- name: Upload coverage.xml
30+
if: ${{ matrix.sql-alchemy == '1.4' && matrix.python-version == '3.9' }}
31+
uses: actions/upload-artifact@v2
32+
with:
33+
name: graphene-sqlalchemy-coverage
34+
path: coverage.xml
35+
if-no-files-found: error
36+
- name: Upload coverage.xml to codecov
37+
if: ${{ matrix.sql-alchemy == '1.4' && matrix.python-version == '3.9' }}
38+
uses: codecov/codecov-action@v1

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,6 @@ target/
6969
# Databases
7070
*.sqlite3
7171
.vscode
72+
73+
# mypy cache
74+
.mypy_cache/

graphene_sqlalchemy/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from .fields import SQLAlchemyConnectionField
33
from .utils import get_query, get_session
44

5-
__version__ = "2.3.0"
5+
__version__ = "3.0.0b1"
66

77
__all__ = [
88
"__version__",

graphene_sqlalchemy/batching.py

+33-16
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
1+
import aiodataloader
12
import sqlalchemy
2-
from promise import dataloader, promise
33
from sqlalchemy.orm import Session, strategies
44
from sqlalchemy.orm.query import QueryContext
55

6+
from .utils import is_sqlalchemy_version_less_than
7+
68

79
def get_batch_resolver(relationship_prop):
810

911
# Cache this across `batch_load_fn` calls
1012
# This is so SQL string generation is cached under-the-hood via `bakery`
1113
selectin_loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),))
1214

13-
class RelationshipLoader(dataloader.DataLoader):
15+
class RelationshipLoader(aiodataloader.DataLoader):
1416
cache = False
1517

16-
def batch_load_fn(self, parents): # pylint: disable=method-hidden
18+
async def batch_load_fn(self, parents):
1719
"""
1820
Batch loads the relationships of all the parents as one SQL statement.
1921
@@ -52,21 +54,36 @@ def batch_load_fn(self, parents): # pylint: disable=method-hidden
5254
states = [(sqlalchemy.inspect(parent), True) for parent in parents]
5355

5456
# For our purposes, the query_context will only used to get the session
55-
query_context = QueryContext(session.query(parent_mapper.entity))
56-
57-
selectin_loader._load_for_path(
58-
query_context,
59-
parent_mapper._path_registry,
60-
states,
61-
None,
62-
child_mapper,
63-
)
64-
65-
return promise.Promise.resolve([getattr(parent, relationship_prop.key) for parent in parents])
57+
query_context = None
58+
if is_sqlalchemy_version_less_than('1.4'):
59+
query_context = QueryContext(session.query(parent_mapper.entity))
60+
else:
61+
parent_mapper_query = session.query(parent_mapper.entity)
62+
query_context = parent_mapper_query._compile_context()
63+
64+
if is_sqlalchemy_version_less_than('1.4'):
65+
selectin_loader._load_for_path(
66+
query_context,
67+
parent_mapper._path_registry,
68+
states,
69+
None,
70+
child_mapper
71+
)
72+
else:
73+
selectin_loader._load_for_path(
74+
query_context,
75+
parent_mapper._path_registry,
76+
states,
77+
None,
78+
child_mapper,
79+
None
80+
)
81+
82+
return [getattr(parent, relationship_prop.key) for parent in parents]
6683

6784
loader = RelationshipLoader()
6885

69-
def resolve(root, info, **args):
70-
return loader.load(root)
86+
async def resolve(root, info, **args):
87+
return await loader.load(root)
7188

7289
return resolve

graphene_sqlalchemy/converter.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
from enum import EnumMeta
1+
from functools import singledispatch
22

3-
from singledispatch import singledispatch
43
from sqlalchemy import types
54
from sqlalchemy.dialects import postgresql
65
from sqlalchemy.orm import interfaces, strategies
@@ -21,6 +20,11 @@
2120
except ImportError:
2221
ChoiceType = JSONType = ScalarListType = TSVectorType = object
2322

23+
try:
24+
from sqlalchemy_utils.types.choice import EnumTypeImpl
25+
except ImportError:
26+
EnumTypeImpl = object
27+
2428

2529
is_selectin_available = getattr(strategies, 'SelectInLoader', None)
2630

@@ -110,9 +114,9 @@ def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, conn
110114

111115

112116
def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs):
113-
if 'type' not in field_kwargs:
117+
if 'type_' not in field_kwargs:
114118
# TODO The default type should be dependent on the type of the property propety.
115-
field_kwargs['type'] = String
119+
field_kwargs['type_'] = String
116120

117121
return Field(
118122
resolver=resolver,
@@ -156,7 +160,8 @@ def inner(fn):
156160

157161
def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs):
158162
column = column_prop.columns[0]
159-
field_kwargs.setdefault('type', convert_sqlalchemy_type(getattr(column, "type", None), column, registry))
163+
164+
field_kwargs.setdefault('type_', convert_sqlalchemy_type(getattr(column, "type", None), column, registry))
160165
field_kwargs.setdefault('required', not is_column_nullable(column))
161166
field_kwargs.setdefault('description', get_column_doc(column))
162167

@@ -221,7 +226,7 @@ def convert_enum_to_enum(type, column, registry=None):
221226
@convert_sqlalchemy_type.register(ChoiceType)
222227
def convert_choice_to_enum(type, column, registry=None):
223228
name = "{}_{}".format(column.table.name, column.name).upper()
224-
if isinstance(type.choices, EnumMeta):
229+
if isinstance(type.type_impl, EnumTypeImpl):
225230
# type.choices may be Enum/IntEnum, in ChoiceType both presented as EnumMeta
226231
# do not use from_enum here because we can have more than one enum column in table
227232
return Enum(name, list((v.name, v.value) for v in type.choices))

graphene_sqlalchemy/enums.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import six
21
from sqlalchemy.orm import ColumnProperty
32
from sqlalchemy.types import Enum as SQLAlchemyEnumType
43

@@ -63,7 +62,7 @@ def enum_for_field(obj_type, field_name):
6362
if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyObjectType):
6463
raise TypeError(
6564
"Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type))
66-
if not field_name or not isinstance(field_name, six.string_types):
65+
if not field_name or not isinstance(field_name, str):
6766
raise TypeError(
6867
"Expected a field name, but got: {!r}".format(field_name))
6968
registry = obj_type._meta.registry

graphene_sqlalchemy/fields.py

+40-26
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,29 @@
1+
import enum
12
import warnings
23
from functools import partial
34

4-
import six
55
from promise import Promise, is_thenable
66
from sqlalchemy.orm.query import Query
77

88
from graphene import NonNull
99
from graphene.relay import Connection, ConnectionField
10-
from graphene.relay.connection import PageInfo
11-
from graphql_relay.connection.arrayconnection import connection_from_list_slice
10+
from graphene.relay.connection import connection_adapter, page_info_adapter
11+
from graphql_relay.connection.arrayconnection import \
12+
connection_from_array_slice
1213

1314
from .batching import get_batch_resolver
14-
from .utils import get_query
15+
from .utils import EnumValue, get_query
1516

1617

1718
class UnsortedSQLAlchemyConnectionField(ConnectionField):
1819
@property
1920
def type(self):
2021
from .types import SQLAlchemyObjectType
2122

22-
_type = super(ConnectionField, self).type
23-
nullable_type = get_nullable_type(_type)
23+
type_ = super(ConnectionField, self).type
24+
nullable_type = get_nullable_type(type_)
2425
if issubclass(nullable_type, Connection):
25-
return _type
26+
return type_
2627
assert issubclass(nullable_type, SQLAlchemyObjectType), (
2728
"SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}"
2829
).format(nullable_type.__name__)
@@ -31,7 +32,7 @@ def type(self):
3132
), "The type {} doesn't have a connection".format(
3233
nullable_type.__name__
3334
)
34-
assert _type == nullable_type, (
35+
assert type_ == nullable_type, (
3536
"Passing a SQLAlchemyObjectType instance is deprecated. "
3637
"Pass the connection type instead accessible via SQLAlchemyObjectType.connection"
3738
)
@@ -53,15 +54,19 @@ def resolve_connection(cls, connection_type, model, info, args, resolved):
5354
_len = resolved.count()
5455
else:
5556
_len = len(resolved)
56-
connection = connection_from_list_slice(
57-
resolved,
58-
args,
57+
58+
def adjusted_connection_adapter(edges, pageInfo):
59+
return connection_adapter(connection_type, edges, pageInfo)
60+
61+
connection = connection_from_array_slice(
62+
array_slice=resolved,
63+
args=args,
5964
slice_start=0,
60-
list_length=_len,
61-
list_slice_length=_len,
62-
connection_type=connection_type,
63-
pageinfo_type=PageInfo,
65+
array_length=_len,
66+
array_slice_length=_len,
67+
connection_type=adjusted_connection_adapter,
6468
edge_type=connection_type.Edge,
69+
page_info_type=page_info_adapter,
6570
)
6671
connection.iterable = resolved
6772
connection.length = _len
@@ -77,7 +82,7 @@ def connection_resolver(cls, resolver, connection_type, model, root, info, **arg
7782

7883
return on_resolve(resolved)
7984

80-
def get_resolver(self, parent_resolver):
85+
def wrap_resolve(self, parent_resolver):
8186
return partial(
8287
self.connection_resolver,
8388
parent_resolver,
@@ -88,8 +93,8 @@ def get_resolver(self, parent_resolver):
8893

8994
# TODO Rename this to SortableSQLAlchemyConnectionField
9095
class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
91-
def __init__(self, type, *args, **kwargs):
92-
nullable_type = get_nullable_type(type)
96+
def __init__(self, type_, *args, **kwargs):
97+
nullable_type = get_nullable_type(type_)
9398
if "sort" not in kwargs and issubclass(nullable_type, Connection):
9499
# Let super class raise if type is not a Connection
95100
try:
@@ -103,16 +108,25 @@ def __init__(self, type, *args, **kwargs):
103108
)
104109
elif "sort" in kwargs and kwargs["sort"] is None:
105110
del kwargs["sort"]
106-
super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs)
111+
super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs)
107112

108113
@classmethod
109114
def get_query(cls, model, info, sort=None, **args):
110115
query = get_query(model, info.context)
111116
if sort is not None:
112-
if isinstance(sort, six.string_types):
113-
query = query.order_by(sort.value)
114-
else:
115-
query = query.order_by(*(col.value for col in sort))
117+
if not isinstance(sort, list):
118+
sort = [sort]
119+
sort_args = []
120+
# ensure consistent handling of graphene Enums, enum values and
121+
# plain strings
122+
for item in sort:
123+
if isinstance(item, enum.Enum):
124+
sort_args.append(item.value.value)
125+
elif isinstance(item, EnumValue):
126+
sort_args.append(item.value)
127+
else:
128+
sort_args.append(item)
129+
query = query.order_by(*sort_args)
116130
return query
117131

118132

@@ -123,7 +137,7 @@ class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
123137
Use at your own risk.
124138
"""
125139

126-
def get_resolver(self, parent_resolver):
140+
def wrap_resolve(self, parent_resolver):
127141
return partial(
128142
self.connection_resolver,
129143
self.resolver,
@@ -148,13 +162,13 @@ def default_connection_field_factory(relationship, registry, **field_kwargs):
148162
__connectionFactory = UnsortedSQLAlchemyConnectionField
149163

150164

151-
def createConnectionField(_type, **field_kwargs):
165+
def createConnectionField(type_, **field_kwargs):
152166
warnings.warn(
153167
'createConnectionField is deprecated and will be removed in the next '
154168
'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.',
155169
DeprecationWarning,
156170
)
157-
return __connectionFactory(_type, **field_kwargs)
171+
return __connectionFactory(type_, **field_kwargs)
158172

159173

160174
def registerConnectionFieldFactory(factoryMethod):

graphene_sqlalchemy/registry.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from collections import defaultdict
22

3-
import six
43
from sqlalchemy.types import Enum as SQLAlchemyEnumType
54

65
from graphene import Enum
@@ -43,7 +42,7 @@ def register_orm_field(self, obj_type, field_name, orm_field):
4342
raise TypeError(
4443
"Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)
4544
)
46-
if not field_name or not isinstance(field_name, six.string_types):
45+
if not field_name or not isinstance(field_name, str):
4746
raise TypeError("Expected a field name, but got: {!r}".format(field_name))
4847
self._registry_orm_fields[obj_type][field_name] = orm_field
4948

graphene_sqlalchemy/tests/conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def convert_composite_class(composite, registry):
2222
return graphene.Field(graphene.Int)
2323

2424

25-
@pytest.yield_fixture(scope="function")
25+
@pytest.fixture(scope="function")
2626
def session_factory():
2727
engine = create_engine(test_db_url)
2828
Base.metadata.create_all(engine)

0 commit comments

Comments
 (0)