Skip to content

Commit b52fda9

Browse files
authored
fix: Fix missing fields from auto generated connections (#137)
We changed auto generated edges/connections to inherit from `relay.Edge`/`relay.Connection`, which requires some extra fields to be instantiated. Fix #97
1 parent ba0dceb commit b52fda9

File tree

3 files changed

+147
-7
lines changed

3 files changed

+147
-7
lines changed

RELEASE.md

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Release type: patch
2+
3+
Fix an issue where auto generated connections were missing some expected
4+
attributes to be properly instantiated.

src/strawberry_sqlalchemy_mapper/mapper.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -463,13 +463,22 @@ def make_connection_wrapper_resolver(
463463
edge_type = self._edge_type_for(type_name)
464464

465465
async def wrapper(self, info: Info):
466+
# TODO: Add pagination support to dataloader resolvers
467+
edges = [
468+
edge_type.resolve_edge(
469+
related_object,
470+
cursor=i,
471+
)
472+
for i, related_object in enumerate(await resolver(self, info))
473+
]
466474
return connection_type(
467-
edges=[
468-
edge_type(
469-
node=related_object,
470-
)
471-
for related_object in await resolver(self, info)
472-
]
475+
edges=edges,
476+
page_info=relay.PageInfo(
477+
has_next_page=False,
478+
has_previous_page=False,
479+
start_cursor=edges[0].cursor if edges else None,
480+
end_cursor=edges[-1].cursor if edges else None,
481+
),
473482
)
474483

475484
setattr(wrapper, _IS_GENERATED_RESOLVER_KEY, True)
@@ -732,7 +741,9 @@ def convert(type_: Any) -> Any:
732741
StrawberryField,
733742
field(
734743
resolver=self.association_proxy_resolver_for(
735-
mapper, descriptor, strawberry_type # type: ignore[arg-type]
744+
mapper,
745+
descriptor,
746+
strawberry_type, # type: ignore[arg-type]
736747
)
737748
),
738749
)

tests/relay/test_auto_connections.py

+125
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
from typing import Any
2+
3+
import pytest
4+
import strawberry
5+
from sqlalchemy import Column, ForeignKey, Integer, String
6+
from sqlalchemy.ext.asyncio.engine import AsyncEngine
7+
from sqlalchemy.orm import relationship
8+
from strawberry.relay.utils import to_base64
9+
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper
10+
from strawberry_sqlalchemy_mapper.loader import StrawberrySQLAlchemyLoader
11+
12+
13+
@pytest.fixture
14+
def user_and_group_tables(base: Any):
15+
class User(base):
16+
__tablename__ = "user"
17+
id = Column(Integer, autoincrement=True, primary_key=True)
18+
name = Column(String(50), nullable=False)
19+
group_id = Column(Integer, ForeignKey("group.id"))
20+
group = relationship("Group", back_populates="users")
21+
22+
class Group(base):
23+
__tablename__ = "group"
24+
id = Column(Integer, autoincrement=True, primary_key=True)
25+
name = Column(String, nullable=False)
26+
users = relationship("User", back_populates="group")
27+
28+
return User, Group
29+
30+
31+
@pytest.mark.asyncio
32+
async def test_query_auto_generated_connection(
33+
base: Any,
34+
async_engine: AsyncEngine,
35+
async_sessionmaker,
36+
user_and_group_tables,
37+
):
38+
user_table, group_table = user_and_group_tables
39+
40+
async with async_engine.begin() as conn:
41+
await conn.run_sync(base.metadata.create_all)
42+
mapper = StrawberrySQLAlchemyMapper()
43+
44+
global User, Group
45+
try:
46+
47+
@mapper.type(user_table)
48+
class User:
49+
...
50+
51+
@mapper.type(group_table)
52+
class Group:
53+
...
54+
55+
@strawberry.type
56+
class Query:
57+
@strawberry.field
58+
async def group(self, id: strawberry.ID) -> Group:
59+
session = async_sessionmaker()
60+
return await session.get(group_table, int(id))
61+
62+
schema = strawberry.Schema(query=Query)
63+
64+
query = """\
65+
query GetGroup ($id: ID!) {
66+
group(id: $id) {
67+
id
68+
name
69+
users {
70+
pageInfo {
71+
hasNextPage
72+
hasPreviousPage
73+
startCursor
74+
endCursor
75+
}
76+
edges {
77+
node {
78+
id
79+
name
80+
}
81+
}
82+
}
83+
}
84+
}
85+
"""
86+
87+
async with async_sessionmaker(expire_on_commit=False) as session:
88+
group = group_table(name="Foo Bar")
89+
user1 = user_table(name="User 1", group=group)
90+
user2 = user_table(name="User 2", group=group)
91+
user3 = user_table(name="User 3", group=group)
92+
session.add_all([group, user1, user2, user3])
93+
await session.commit()
94+
95+
result = await schema.execute(
96+
query,
97+
variable_values={"id": group.id},
98+
context_value={
99+
"sqlalchemy_loader": StrawberrySQLAlchemyLoader(
100+
async_bind_factory=async_sessionmaker
101+
)
102+
},
103+
)
104+
assert result.errors is None
105+
assert result.data == {
106+
"group": {
107+
"id": group.id,
108+
"name": "Foo Bar",
109+
"users": {
110+
"pageInfo": {
111+
"hasNextPage": False,
112+
"hasPreviousPage": False,
113+
"startCursor": to_base64("arrayconnection", "0"),
114+
"endCursor": to_base64("arrayconnection", "2"),
115+
},
116+
"edges": [
117+
{"node": {"id": user1.id, "name": "User 1"}},
118+
{"node": {"id": user2.id, "name": "User 2"}},
119+
{"node": {"id": user3.id, "name": "User 3"}},
120+
],
121+
},
122+
},
123+
}
124+
finally:
125+
del User, Group

0 commit comments

Comments
 (0)