@@ -155,8 +155,7 @@ def from_type(cls, type_: type, *, strict: Literal[True]) -> Self: ...
155
155
156
156
@overload
157
157
@classmethod
158
- def from_type (cls , type_ : type , * ,
159
- strict : bool = False ) -> Optional [Self ]: ...
158
+ def from_type (cls , type_ : type , * , strict : bool = False ) -> Optional [Self ]: ...
160
159
161
160
@classmethod
162
161
def from_type (
@@ -167,8 +166,7 @@ def from_type(
167
166
) -> Optional [Self ]:
168
167
definition = getattr (type_ , cls .TYPE_KEY_NAME , None )
169
168
if strict and definition is None :
170
- raise TypeError (
171
- f"{ type_ !r} does not have a StrawberrySQLAlchemyType in it" )
169
+ raise TypeError (f"{ type_ !r} does not have a StrawberrySQLAlchemyType in it" )
172
170
return definition
173
171
174
172
@@ -231,12 +229,11 @@ class StrawberrySQLAlchemyMapper(Generic[BaseModelType]):
231
229
232
230
def __init__ (
233
231
self ,
234
- model_to_type_name : Optional [Callable [[
235
- Type [BaseModelType ]], str ]] = None ,
236
- model_to_interface_name : Optional [Callable [[
237
- Type [BaseModelType ]], str ]] = None ,
238
- extra_sqlalchemy_type_to_strawberry_type_map : Optional [Mapping [Type [TypeEngine ], Type [Any ]]
239
- ] = None ,
232
+ model_to_type_name : Optional [Callable [[Type [BaseModelType ]], str ]] = None ,
233
+ model_to_interface_name : Optional [Callable [[Type [BaseModelType ]], str ]] = None ,
234
+ extra_sqlalchemy_type_to_strawberry_type_map : Optional [
235
+ Mapping [Type [TypeEngine ], Type [Any ]]
236
+ ] = None ,
240
237
) -> None :
241
238
if model_to_type_name is None :
242
239
model_to_type_name = self ._default_model_to_type_name
@@ -299,8 +296,7 @@ def _edge_type_for(self, type_name: str) -> Type[Any]:
299
296
"""
300
297
edge_name = f"{ type_name } Edge"
301
298
if edge_name not in self .edge_types :
302
- lazy_type = StrawberrySQLAlchemyLazy (
303
- type_name = type_name , mapper = self )
299
+ lazy_type = StrawberrySQLAlchemyLazy (type_name = type_name , mapper = self )
304
300
self .edge_types [edge_name ] = edge_type = strawberry .type (
305
301
dataclasses .make_dataclass (
306
302
edge_name ,
@@ -319,15 +315,15 @@ def _connection_type_for(self, type_name: str) -> Type[Any]:
319
315
connection_name = f"{ type_name } Connection"
320
316
if connection_name not in self .connection_types :
321
317
edge_type = self ._edge_type_for (type_name )
322
- lazy_type = StrawberrySQLAlchemyLazy (
323
- type_name = type_name , mapper = self )
318
+ lazy_type = StrawberrySQLAlchemyLazy (type_name = type_name , mapper = self )
324
319
self .connection_types [connection_name ] = connection_type = strawberry .type (
325
320
dataclasses .make_dataclass (
326
321
connection_name ,
327
322
[
328
323
("edges" , List [edge_type ]), # type: ignore[valid-type]
329
324
],
330
- bases = (relay .ListConnection [lazy_type ],), # type: ignore[valid-type]
325
+ # type: ignore[valid-type]
326
+ bases = (relay .ListConnection [lazy_type ],),
331
327
)
332
328
)
333
329
setattr (connection_type , _GENERATED_FIELD_KEYS_KEY , ["edges" ])
@@ -457,8 +453,7 @@ def _get_association_proxy_annotation(
457
453
strawberry_type .__forward_arg__
458
454
)
459
455
else :
460
- strawberry_type = self ._connection_type_for (
461
- strawberry_type .__name__ )
456
+ strawberry_type = self ._connection_type_for (strawberry_type .__name__ )
462
457
return strawberry_type
463
458
464
459
def make_connection_wrapper_resolver (
@@ -509,24 +504,25 @@ async def resolve(self, info: Info):
509
504
else :
510
505
if relationship .secondary is None :
511
506
relationship_key = tuple (
512
- [
513
- getattr (self , local .key )
514
- for local , _ in relationship .local_remote_pairs or []
515
- if local .key
516
- ]
507
+ getattr (self , local .key )
508
+ for local , _ in relationship .local_remote_pairs or []
509
+ if local .key
517
510
)
518
511
else :
519
512
# If has a secondary table, gets only the first ID as additional IDs require a separate query
520
513
if not relationship .local_remote_pairs :
521
514
raise InvalidLocalRemotePairs (
522
- f"{ relationship .entity .entity .__name__ } -- { relationship .parent .entity .__name__ } " )
515
+ f"{ relationship .entity .entity .__name__ } -- { relationship .parent .entity .__name__ } "
516
+ )
523
517
524
- local_remote_pairs_secondary_table_local = relationship .local_remote_pairs [
525
- 0 ][0 ]
518
+ local_remote_pairs_secondary_table_local = (
519
+ relationship .local_remote_pairs [0 ][0 ]
520
+ )
526
521
relationship_key = tuple (
527
522
[
528
523
getattr (
529
- self , str (local_remote_pairs_secondary_table_local .key )),
524
+ self , str (local_remote_pairs_secondary_table_local .key )
525
+ ),
530
526
]
531
527
)
532
528
@@ -560,7 +556,8 @@ def connection_resolver_for(
560
556
return self .make_connection_wrapper_resolver (
561
557
relationship_resolver ,
562
558
self .model_to_type_or_interface_name (
563
- relationship .entity .entity ), # type: ignore[arg-type]
559
+ relationship .entity .entity # type: ignore[arg-type]
560
+ ),
564
561
)
565
562
else :
566
563
return relationship_resolver
@@ -578,15 +575,13 @@ def association_proxy_resolver_for(
578
575
Return an async field resolver for the given association proxy.
579
576
"""
580
577
in_between_relationship = mapper .relationships [descriptor .target_collection ]
581
- in_between_resolver = self .relationship_resolver_for (
582
- in_between_relationship )
578
+ in_between_resolver = self .relationship_resolver_for (in_between_relationship )
583
579
in_between_mapper : Mapper = mapper .relationships [ # type: ignore[assignment]
584
580
descriptor .target_collection
585
581
].entity
586
582
assert descriptor .value_attr in in_between_mapper .relationships
587
583
end_relationship = in_between_mapper .relationships [descriptor .value_attr ]
588
- end_relationship_resolver = self .relationship_resolver_for (
589
- end_relationship )
584
+ end_relationship_resolver = self .relationship_resolver_for (end_relationship )
590
585
end_type_name = self .model_to_type_or_interface_name (
591
586
end_relationship .entity .entity # type: ignore[arg-type]
592
587
)
@@ -613,8 +608,7 @@ async def resolve(self, info: Info):
613
608
if outputs and isinstance (outputs [0 ], list ):
614
609
outputs = list (chain .from_iterable (outputs ))
615
610
else :
616
- outputs = [
617
- output for output in outputs if output is not None ]
611
+ outputs = [output for output in outputs if output is not None ]
618
612
else :
619
613
outputs = await end_relationship_resolver (in_between_objects , info )
620
614
if not isinstance (outputs , collections .abc .Iterable ):
@@ -710,8 +704,7 @@ def convert(type_: Any) -> Any:
710
704
setattr (type_ , key , field (resolver = val ))
711
705
generated_field_keys .append (key )
712
706
713
- self ._handle_columns (
714
- mapper , type_ , excluded_keys , generated_field_keys )
707
+ self ._handle_columns (mapper , type_ , excluded_keys , generated_field_keys )
715
708
relationship : RelationshipProperty
716
709
for key , relationship in mapper .relationships .items ():
717
710
if (
@@ -813,7 +806,8 @@ def convert(type_: Any) -> Any:
813
806
# ignore inherited `is_type_of`
814
807
if "is_type_of" not in type_ .__dict__ :
815
808
type_ .is_type_of = (
816
- lambda obj , info : type (obj ) == model or type (obj ) == type_
809
+ lambda obj , info : type (obj ) == model # noqa: E721
810
+ or type (obj ) == type_ # noqa: E721
817
811
)
818
812
819
813
# Default querying methods for relay
@@ -833,7 +827,8 @@ def convert(type_: Any) -> Any:
833
827
setattr (
834
828
type_ ,
835
829
attr ,
836
- types .MethodType (func , type_ ), # type: ignore[arg-type]
830
+ # type: ignore[arg-type]
831
+ types .MethodType (func , type_ ),
837
832
)
838
833
839
834
# Adjust types that inherit from other types/interfaces that implement Node
@@ -846,8 +841,7 @@ def convert(type_: Any) -> Any:
846
841
setattr (
847
842
type_ ,
848
843
attr ,
849
- types .MethodType (
850
- cast (classmethod , meth ).__func__ , type_ ),
844
+ types .MethodType (cast (classmethod , meth ).__func__ , type_ ),
851
845
)
852
846
853
847
# need to make fields that are already in the type
@@ -875,8 +869,7 @@ def convert(type_: Any) -> Any:
875
869
model = model ,
876
870
),
877
871
)
878
- setattr (mapped_type , _GENERATED_FIELD_KEYS_KEY ,
879
- generated_field_keys )
872
+ setattr (mapped_type , _GENERATED_FIELD_KEYS_KEY , generated_field_keys )
880
873
setattr (mapped_type , _ORIGINAL_TYPE_KEY , type_ )
881
874
return mapped_type
882
875
@@ -916,16 +909,14 @@ def _fix_annotation_namespaces(self) -> None:
916
909
self .edge_types .values (),
917
910
self .connection_types .values (),
918
911
):
919
- strawberry_definition = get_object_definition (
920
- mapped_type , strict = True )
912
+ strawberry_definition = get_object_definition (mapped_type , strict = True )
921
913
for f in strawberry_definition .fields :
922
914
if f .name in getattr (mapped_type , _GENERATED_FIELD_KEYS_KEY ):
923
915
namespace = {}
924
916
if hasattr (mapped_type , _ORIGINAL_TYPE_KEY ):
925
917
namespace .update (
926
918
sys .modules [
927
- getattr (mapped_type ,
928
- _ORIGINAL_TYPE_KEY ).__module__
919
+ getattr (mapped_type , _ORIGINAL_TYPE_KEY ).__module__
929
920
].__dict__
930
921
)
931
922
namespace .update (self .mapped_types )
@@ -956,8 +947,7 @@ def _map_unmapped_relationships(self) -> None:
956
947
if type_name not in self .mapped_interfaces :
957
948
unmapped_interface_models .add (model )
958
949
for model in unmapped_models :
959
- self .type (model )(
960
- type (self .model_to_type_name (model ), (object ,), {}))
950
+ self .type (model )(type (self .model_to_type_name (model ), (object ,), {}))
961
951
for model in unmapped_interface_models :
962
952
self .interface (model )(
963
953
type (self .model_to_interface_name (model ), (object ,), {})
0 commit comments