Skip to content

Commit ff3e419

Browse files
committed
updated mapper
1 parent 3f7f13d commit ff3e419

File tree

1 file changed

+37
-47
lines changed

1 file changed

+37
-47
lines changed

src/strawberry_sqlalchemy_mapper/mapper.py

+37-47
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,7 @@ def from_type(cls, type_: type, *, strict: Literal[True]) -> Self: ...
155155

156156
@overload
157157
@classmethod
158-
def from_type(cls, type_: type, *,
159-
strict: bool = False) -> Optional[Self]: ...
158+
def from_type(cls, type_: type, *, strict: bool = False) -> Optional[Self]: ...
160159

161160
@classmethod
162161
def from_type(
@@ -167,8 +166,7 @@ def from_type(
167166
) -> Optional[Self]:
168167
definition = getattr(type_, cls.TYPE_KEY_NAME, None)
169168
if strict and definition is None:
170-
raise TypeError(
171-
f"{type_!r} does not have a StrawberrySQLAlchemyType in it")
169+
raise TypeError(f"{type_!r} does not have a StrawberrySQLAlchemyType in it")
172170
return definition
173171

174172

@@ -231,12 +229,11 @@ class StrawberrySQLAlchemyMapper(Generic[BaseModelType]):
231229

232230
def __init__(
233231
self,
234-
model_to_type_name: Optional[Callable[[
235-
Type[BaseModelType]], str]] = None,
236-
model_to_interface_name: Optional[Callable[[
237-
Type[BaseModelType]], str]] = None,
238-
extra_sqlalchemy_type_to_strawberry_type_map: Optional[Mapping[Type[TypeEngine], Type[Any]]
239-
] = None,
232+
model_to_type_name: Optional[Callable[[Type[BaseModelType]], str]] = None,
233+
model_to_interface_name: Optional[Callable[[Type[BaseModelType]], str]] = None,
234+
extra_sqlalchemy_type_to_strawberry_type_map: Optional[
235+
Mapping[Type[TypeEngine], Type[Any]]
236+
] = None,
240237
) -> None:
241238
if model_to_type_name is None:
242239
model_to_type_name = self._default_model_to_type_name
@@ -299,8 +296,7 @@ def _edge_type_for(self, type_name: str) -> Type[Any]:
299296
"""
300297
edge_name = f"{type_name}Edge"
301298
if edge_name not in self.edge_types:
302-
lazy_type = StrawberrySQLAlchemyLazy(
303-
type_name=type_name, mapper=self)
299+
lazy_type = StrawberrySQLAlchemyLazy(type_name=type_name, mapper=self)
304300
self.edge_types[edge_name] = edge_type = strawberry.type(
305301
dataclasses.make_dataclass(
306302
edge_name,
@@ -319,15 +315,15 @@ def _connection_type_for(self, type_name: str) -> Type[Any]:
319315
connection_name = f"{type_name}Connection"
320316
if connection_name not in self.connection_types:
321317
edge_type = self._edge_type_for(type_name)
322-
lazy_type = StrawberrySQLAlchemyLazy(
323-
type_name=type_name, mapper=self)
318+
lazy_type = StrawberrySQLAlchemyLazy(type_name=type_name, mapper=self)
324319
self.connection_types[connection_name] = connection_type = strawberry.type(
325320
dataclasses.make_dataclass(
326321
connection_name,
327322
[
328323
("edges", List[edge_type]), # type: ignore[valid-type]
329324
],
330-
bases=(relay.ListConnection[lazy_type],), # type: ignore[valid-type]
325+
# type: ignore[valid-type]
326+
bases=(relay.ListConnection[lazy_type],),
331327
)
332328
)
333329
setattr(connection_type, _GENERATED_FIELD_KEYS_KEY, ["edges"])
@@ -457,8 +453,7 @@ def _get_association_proxy_annotation(
457453
strawberry_type.__forward_arg__
458454
)
459455
else:
460-
strawberry_type = self._connection_type_for(
461-
strawberry_type.__name__)
456+
strawberry_type = self._connection_type_for(strawberry_type.__name__)
462457
return strawberry_type
463458

464459
def make_connection_wrapper_resolver(
@@ -509,24 +504,25 @@ async def resolve(self, info: Info):
509504
else:
510505
if relationship.secondary is None:
511506
relationship_key = tuple(
512-
[
513-
getattr(self, local.key)
514-
for local, _ in relationship.local_remote_pairs or []
515-
if local.key
516-
]
507+
getattr(self, local.key)
508+
for local, _ in relationship.local_remote_pairs or []
509+
if local.key
517510
)
518511
else:
519512
# If has a secondary table, gets only the first ID as additional IDs require a separate query
520513
if not relationship.local_remote_pairs:
521514
raise InvalidLocalRemotePairs(
522-
f"{relationship.entity.entity.__name__} -- {relationship.parent.entity.__name__}")
515+
f"{relationship.entity.entity.__name__} -- {relationship.parent.entity.__name__}"
516+
)
523517

524-
local_remote_pairs_secondary_table_local = relationship.local_remote_pairs[
525-
0][0]
518+
local_remote_pairs_secondary_table_local = (
519+
relationship.local_remote_pairs[0][0]
520+
)
526521
relationship_key = tuple(
527522
[
528523
getattr(
529-
self, str(local_remote_pairs_secondary_table_local.key)),
524+
self, str(local_remote_pairs_secondary_table_local.key)
525+
),
530526
]
531527
)
532528

@@ -560,7 +556,8 @@ def connection_resolver_for(
560556
return self.make_connection_wrapper_resolver(
561557
relationship_resolver,
562558
self.model_to_type_or_interface_name(
563-
relationship.entity.entity), # type: ignore[arg-type]
559+
relationship.entity.entity # type: ignore[arg-type]
560+
),
564561
)
565562
else:
566563
return relationship_resolver
@@ -578,15 +575,13 @@ def association_proxy_resolver_for(
578575
Return an async field resolver for the given association proxy.
579576
"""
580577
in_between_relationship = mapper.relationships[descriptor.target_collection]
581-
in_between_resolver = self.relationship_resolver_for(
582-
in_between_relationship)
578+
in_between_resolver = self.relationship_resolver_for(in_between_relationship)
583579
in_between_mapper: Mapper = mapper.relationships[ # type: ignore[assignment]
584580
descriptor.target_collection
585581
].entity
586582
assert descriptor.value_attr in in_between_mapper.relationships
587583
end_relationship = in_between_mapper.relationships[descriptor.value_attr]
588-
end_relationship_resolver = self.relationship_resolver_for(
589-
end_relationship)
584+
end_relationship_resolver = self.relationship_resolver_for(end_relationship)
590585
end_type_name = self.model_to_type_or_interface_name(
591586
end_relationship.entity.entity # type: ignore[arg-type]
592587
)
@@ -613,8 +608,7 @@ async def resolve(self, info: Info):
613608
if outputs and isinstance(outputs[0], list):
614609
outputs = list(chain.from_iterable(outputs))
615610
else:
616-
outputs = [
617-
output for output in outputs if output is not None]
611+
outputs = [output for output in outputs if output is not None]
618612
else:
619613
outputs = await end_relationship_resolver(in_between_objects, info)
620614
if not isinstance(outputs, collections.abc.Iterable):
@@ -710,8 +704,7 @@ def convert(type_: Any) -> Any:
710704
setattr(type_, key, field(resolver=val))
711705
generated_field_keys.append(key)
712706

713-
self._handle_columns(
714-
mapper, type_, excluded_keys, generated_field_keys)
707+
self._handle_columns(mapper, type_, excluded_keys, generated_field_keys)
715708
relationship: RelationshipProperty
716709
for key, relationship in mapper.relationships.items():
717710
if (
@@ -813,7 +806,8 @@ def convert(type_: Any) -> Any:
813806
# ignore inherited `is_type_of`
814807
if "is_type_of" not in type_.__dict__:
815808
type_.is_type_of = (
816-
lambda obj, info: type(obj) == model or type(obj) == type_
809+
lambda obj, info: type(obj) == model # noqa: E721
810+
or type(obj) == type_ # noqa: E721
817811
)
818812

819813
# Default querying methods for relay
@@ -833,7 +827,8 @@ def convert(type_: Any) -> Any:
833827
setattr(
834828
type_,
835829
attr,
836-
types.MethodType(func, type_), # type: ignore[arg-type]
830+
# type: ignore[arg-type]
831+
types.MethodType(func, type_),
837832
)
838833

839834
# Adjust types that inherit from other types/interfaces that implement Node
@@ -846,8 +841,7 @@ def convert(type_: Any) -> Any:
846841
setattr(
847842
type_,
848843
attr,
849-
types.MethodType(
850-
cast(classmethod, meth).__func__, type_),
844+
types.MethodType(cast(classmethod, meth).__func__, type_),
851845
)
852846

853847
# need to make fields that are already in the type
@@ -875,8 +869,7 @@ def convert(type_: Any) -> Any:
875869
model=model,
876870
),
877871
)
878-
setattr(mapped_type, _GENERATED_FIELD_KEYS_KEY,
879-
generated_field_keys)
872+
setattr(mapped_type, _GENERATED_FIELD_KEYS_KEY, generated_field_keys)
880873
setattr(mapped_type, _ORIGINAL_TYPE_KEY, type_)
881874
return mapped_type
882875

@@ -916,16 +909,14 @@ def _fix_annotation_namespaces(self) -> None:
916909
self.edge_types.values(),
917910
self.connection_types.values(),
918911
):
919-
strawberry_definition = get_object_definition(
920-
mapped_type, strict=True)
912+
strawberry_definition = get_object_definition(mapped_type, strict=True)
921913
for f in strawberry_definition.fields:
922914
if f.name in getattr(mapped_type, _GENERATED_FIELD_KEYS_KEY):
923915
namespace = {}
924916
if hasattr(mapped_type, _ORIGINAL_TYPE_KEY):
925917
namespace.update(
926918
sys.modules[
927-
getattr(mapped_type,
928-
_ORIGINAL_TYPE_KEY).__module__
919+
getattr(mapped_type, _ORIGINAL_TYPE_KEY).__module__
929920
].__dict__
930921
)
931922
namespace.update(self.mapped_types)
@@ -956,8 +947,7 @@ def _map_unmapped_relationships(self) -> None:
956947
if type_name not in self.mapped_interfaces:
957948
unmapped_interface_models.add(model)
958949
for model in unmapped_models:
959-
self.type(model)(
960-
type(self.model_to_type_name(model), (object,), {}))
950+
self.type(model)(type(self.model_to_type_name(model), (object,), {}))
961951
for model in unmapped_interface_models:
962952
self.interface(model)(
963953
type(self.model_to_interface_name(model), (object,), {})

0 commit comments

Comments
 (0)