Skip to content

Commit 7cd14e8

Browse files
fruitymedleysourcery-ai[bot]jojoCkk3
authored
Optional Pagination (#168)
* Pre-commit-update * Add option to exclude relay pagination * Use standard resolver for exclude relay * Clarify docstring Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * Documentation * Add exclude relay test * Add release.md * Use property of class * Restore relay.py * Upd test to use list * fix tests * adding type ignore on return * fix test * use isinstance on test --------- Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> Co-authored-by: jojo <[email protected]> Co-authored-by: Ckk3 <[email protected]>
1 parent d1fc069 commit 7cd14e8

File tree

5 files changed

+64
-30
lines changed

5 files changed

+64
-30
lines changed

.pre-commit-config.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
repos:
22
- repo: https://github.com/psf/black
3-
rev: 23.9.1
3+
rev: 24.4.2
44
hooks:
55
- id: black
66
exclude: ^tests/\w+/snapshots/
77

88
- repo: https://github.com/astral-sh/ruff-pre-commit
9-
rev: v0.0.289
9+
rev: v0.4.5
1010
hooks:
1111
- id: ruff
1212
exclude: ^tests/\w+/snapshots/
@@ -18,13 +18,13 @@ repos:
1818
exclude: (CHANGELOG|TWEET).md
1919

2020
- repo: https://github.com/pre-commit/mirrors-prettier
21-
rev: v3.0.3
21+
rev: v4.0.0-alpha.8
2222
hooks:
2323
- id: prettier
2424
files: '^docs/.*\.mdx?$'
2525

2626
- repo: https://github.com/pre-commit/pre-commit-hooks
27-
rev: v4.4.0
27+
rev: v4.6.0
2828
hooks:
2929
- id: trailing-whitespace
3030
- id: check-merge-conflict

RELEASE.md

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Release type: minor
2+
3+
Add an optional function to exclude relationships from relay pagination and use traditional strawberry lists.
4+
Default behavior preserves original behavior for backwords compatibilty.

src/strawberry_sqlalchemy_mapper/mapper.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,11 @@ class StrawberrySQLAlchemyType(Generic[BaseModelType]):
150150

151151
@overload
152152
@classmethod
153-
def from_type(cls, type_: type, *, strict: Literal[True]) -> Self:
154-
...
153+
def from_type(cls, type_: type, *, strict: Literal[True]) -> Self: ...
155154

156155
@overload
157156
@classmethod
158-
def from_type(cls, type_: type, *, strict: bool = False) -> Optional[Self]:
159-
...
157+
def from_type(cls, type_: type, *, strict: bool = False) -> Optional[Self]: ...
160158

161159
@classmethod
162160
def from_type(
@@ -374,7 +372,7 @@ def _convert_column_to_strawberry_type(
374372
return type_annotation
375373

376374
def _convert_relationship_to_strawberry_type(
377-
self, relationship: RelationshipProperty
375+
self, relationship: RelationshipProperty, use_list: bool = False
378376
) -> Union[Type[Any], ForwardRef]:
379377
"""
380378
Given a SQLAlchemy relationship, return the type annotation for the field in the
@@ -387,6 +385,10 @@ def _convert_relationship_to_strawberry_type(
387385
else:
388386
self._related_type_models.add(relationship_model)
389387
if relationship.uselist:
388+
# Use list if excluding relay pagination
389+
if use_list:
390+
return List[ForwardRef(type_name)] # type: ignore
391+
390392
return self._connection_type_for(type_name)
391393
else:
392394
if self._get_relationship_is_optional(relationship):
@@ -524,14 +526,14 @@ async def resolve(self, info: Info):
524526
return resolve
525527

526528
def connection_resolver_for(
527-
self, relationship: RelationshipProperty
529+
self, relationship: RelationshipProperty, use_list=False
528530
) -> Callable[..., Awaitable[Any]]:
529531
"""
530532
Return an async field resolver for the given relationship that
531533
returns a Connection instead of an array of objects.
532534
"""
533535
relationship_resolver = self.relationship_resolver_for(relationship)
534-
if relationship.uselist:
536+
if relationship.uselist and not use_list:
535537
return self.make_connection_wrapper_resolver(
536538
relationship_resolver,
537539
self.model_to_type_or_interface_name(relationship.entity.entity), # type: ignore[arg-type]
@@ -666,6 +668,7 @@ def convert(type_: Any) -> Any:
666668
generated_field_keys = []
667669

668670
excluded_keys = getattr(type_, "__exclude__", [])
671+
list_keys = getattr(type_, "__use_list__", [])
669672

670673
# if the type inherits from another mapped type, then it may have
671674
# generated resolvers. These will be treated by dataclasses as having
@@ -690,7 +693,8 @@ def convert(type_: Any) -> Any:
690693
):
691694
continue
692695
strawberry_type = self._convert_relationship_to_strawberry_type(
693-
relationship
696+
relationship,
697+
key in list_keys,
694698
)
695699
self._add_annotation(
696700
type_,
@@ -700,7 +704,12 @@ def convert(type_: Any) -> Any:
700704
)
701705
sqlalchemy_field = cast(
702706
StrawberryField,
703-
field(resolver=self.connection_resolver_for(relationship)),
707+
field(
708+
resolver=self.connection_resolver_for(
709+
relationship,
710+
key in list_keys,
711+
)
712+
),
704713
)
705714
assert not sqlalchemy_field.init
706715
setattr(

src/strawberry_sqlalchemy_mapper/relay.py

+7-14
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,7 @@ def resolve_model_nodes(
158158
info: Optional[Info] = None,
159159
node_ids: Iterable[Union[str, relay.GlobalID]],
160160
required: Literal[True],
161-
) -> AwaitableOrValue[Iterable[_T]]:
162-
...
161+
) -> AwaitableOrValue[Iterable[_T]]: ...
163162

164163

165164
@overload
@@ -174,8 +173,7 @@ def resolve_model_nodes(
174173
info: Optional[Info] = None,
175174
node_ids: None = None,
176175
required: Literal[True],
177-
) -> AwaitableOrValue[Iterable[_T]]:
178-
...
176+
) -> AwaitableOrValue[Iterable[_T]]: ...
179177

180178

181179
@overload
@@ -190,8 +188,7 @@ def resolve_model_nodes(
190188
info: Optional[Info] = None,
191189
node_ids: Iterable[Union[str, relay.GlobalID]],
192190
required: Literal[False],
193-
) -> AwaitableOrValue[Iterable[Optional[_T]]]:
194-
...
191+
) -> AwaitableOrValue[Iterable[Optional[_T]]]: ...
195192

196193

197194
@overload
@@ -206,8 +203,7 @@ def resolve_model_nodes(
206203
info: Optional[Info] = None,
207204
node_ids: None = None,
208205
required: Literal[False],
209-
) -> AwaitableOrValue[Optional[Iterable[_T]]]:
210-
...
206+
) -> AwaitableOrValue[Optional[Iterable[_T]]]: ...
211207

212208

213209
@overload
@@ -229,8 +225,7 @@ def resolve_model_nodes(
229225
Iterable[Optional[_T]],
230226
Optional[Query[_T]],
231227
]
232-
]:
233-
...
228+
]: ...
234229

235230

236231
def resolve_model_nodes(
@@ -307,8 +302,7 @@ def resolve_model_node(
307302
session: Session,
308303
info: Optional[Info] = ...,
309304
required: Literal[False] = ...,
310-
) -> AwaitableOrValue[Optional[_T]]:
311-
...
305+
) -> AwaitableOrValue[Optional[_T]]: ...
312306

313307

314308
@overload
@@ -323,8 +317,7 @@ def resolve_model_node(
323317
session: Session,
324318
info: Optional[Info] = ...,
325319
required: Literal[True],
326-
) -> AwaitableOrValue[_T]:
327-
...
320+
) -> AwaitableOrValue[_T]: ...
328321

329322

330323
def resolve_model_node(

tests/test_mapper.py

+31-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sqlalchemy.dialects.postgresql.array import ARRAY
99
from sqlalchemy.orm import relationship
1010
from strawberry.scalars import JSON as StrawberryJSON
11-
from strawberry.types.base import StrawberryOptional
11+
from strawberry.types.base import StrawberryList, StrawberryOptional
1212
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper
1313

1414

@@ -263,6 +263,35 @@ class Lawyer:
263263
assert {"Employee", "Lawyer"} == {t.__name__ for t in additional_types}
264264

265265

266+
def test_use_list(employee_and_department_tables, mapper):
267+
Employee, Department = employee_and_department_tables
268+
269+
@mapper.type(Employee)
270+
class Employee:
271+
pass
272+
273+
@mapper.type(Department)
274+
class Department:
275+
__use_list__ = ["employees"]
276+
277+
mapper.finalize()
278+
additional_types = list(mapper.mapped_types.values())
279+
assert len(additional_types) == 2
280+
mapped_employee_type = additional_types[0]
281+
assert mapped_employee_type.__name__ == "Employee"
282+
mapped_department_type = additional_types[1]
283+
assert mapped_department_type.__name__ == "Department"
284+
assert len(mapped_department_type.__strawberry_definition__.fields) == 3
285+
department_type_fields = mapped_department_type.__strawberry_definition__.fields
286+
287+
name = next(
288+
(field for field in department_type_fields if field.name == "employees"), None
289+
)
290+
assert name is not None
291+
assert isinstance(name.type, StrawberryOptional) is False
292+
assert isinstance(name.type, StrawberryList) is True
293+
294+
266295
def test_type_relationships(employee_and_department_tables, mapper):
267296
Employee, _ = employee_and_department_tables
268297

@@ -297,8 +326,7 @@ class Department:
297326
@strawberry.type
298327
class Query:
299328
@strawberry.field
300-
def departments(self) -> Department:
301-
...
329+
def departments(self) -> Department: ...
302330

303331
mapper.finalize()
304332
schema = strawberry.Schema(query=Query)

0 commit comments

Comments
 (0)