@@ -2,15 +2,15 @@ class EagerLoadingMixin:
22 """
33 Mixin to apply eager loading optimisations to a ViewSet.
44
5- Dynamically applies `selected_related ()` for ForeignKey fields and
6- `prefetch_related()` from ManyToMany/reverse relationships. This improves
7- query efficiency and prevents N+1 problems
5+ Handles the running of `select_related ()` ( for ForeignKey fields) and
6+ `prefetch_related()` ( from ManyToMany/reverse relationships) queryset methods
7+ to allow developers to solve N+1 problems on Open5e endpoints.
88
99 ## Usage
1010 1. Make sure your ViewSet inherits from `EagerLoadingMixin` before its base
1111 class (ie. ReadOnlyModelViewSet).
1212 2. Re-define `select_related_fields` and `prefetch_related_fields` lists on
13- the child ViewSet to specify relationships to optimise .
13+ the child ViewSet to specify relationships to select related / pre-fetch .
1414
1515 ## Usage Example
1616 ```
@@ -28,24 +28,69 @@ class CreatureViewSet(EagerLoadingMixin, viewsets.ReadOnlyModelViewSet):
2828 prefetch_related_fields = []
2929
3030 def get_queryset (self ):
31+ """
32+ Builds the queryset with optimised eager loading based on the requested and excluded fields.
33+ """
3134 queryset = super ().get_queryset ()
32- requested_fields = self .request .query_params .get ('fields' , '' ).split (',' )
33- filtered_select_fields = self .filter_fields (self .select_related_fields , requested_fields )
34- filtered_prefetch_fields = self .filter_fields (self .prefetch_related_fields , requested_fields )
35+
36+ # Check fields included or excluded via query parameter. We use this data
37+ # so that we only eagerly load fields actually returned by the API.
38+ requested_fields = self .parse_requested_fields ()
39+ excluded_fields = self .parse_excluded_fields ()
40+
41+ filtered_select_fields = self .filter_fields (self .select_related_fields , requested_fields , excluded_fields )
42+ filtered_prefetch_fields = self .filter_fields (self .prefetch_related_fields , requested_fields , excluded_fields )
3543
3644 return queryset \
3745 .select_related (* filtered_select_fields ) \
3846 .prefetch_related (* filtered_prefetch_fields )
47+
48+ def parse_requested_fields (self ):
49+ """
50+ Parses the 'fields' query param into a list of requested field paths.
51+ """
52+ requested_fields = self .request .query_params .get ('fields' , '' )
53+ requested_fields = requested_fields .split (',' )
54+ requested_fields = [field for field in requested_fields if field ]
55+ return requested_fields
56+
57+ def parse_excluded_fields (self ):
58+ """
59+ Parses 'exclude' query params into a flat list of field paths for use in eager loading
60+ """
61+ excluded_fields = []
62+ for key , value in self .request .query_params .items ():
63+ if key == 'exclude' :
64+ excluded_fields += value .split (',' )
65+ elif key .endswith ('__exclude' ):
66+ prefix = key .removesuffix ('__exclude' )
67+ excluded_fields += [f'{ prefix } __{ field } ' for field in value .split (',' )]
68+ return excluded_fields
3969
40- def filter_fields (self , related_fields , requested_fields ):
70+ def filter_fields (self , related_fields , requested_fields = None , excluded_fields = None ):
4171 """
42- Filters'related_fields' according to whether they are included in
43- 'requested_fields'. Used to remove fields from eager loading if they are
44- not requested (and thus not returned by API), avoiding unnecessary DB calls
72+ Filters 'related_fields' according to whether they are included in
73+ 'requested_fields' or 'excluded_fields' . Used to remove fields from eager
74+ loading if they are not returned by API call to avoid unnecessary DB calls
4575 """
46- if not any (requested_fields ):
47- return related_fields
48- return [
49- related_field for related_field in related_fields
50- if any (related_field == req or related_field .startswith (req + '__' ) for req in requested_fields )
51- ]
76+ # avoids mutable default argument issues: set to empty list if param missing
77+ requested_fields = requested_fields or []
78+ excluded_fields = excluded_fields or []
79+
80+ def field_matches (field , targets ):
81+ # Returns True if 'field' equals any 'target', or is a child path of one
82+ return any (field == target or field .startswith (target + '__' ) for target in targets )
83+
84+ if requested_fields :
85+ related_fields = [
86+ related_field for related_field in related_fields
87+ if field_matches (related_field , requested_fields )
88+ ]
89+
90+ if excluded_fields :
91+ related_fields = [
92+ related_field for related_field in related_fields
93+ if not field_matches (related_field , excluded_fields )
94+ ]
95+
96+ return related_fields
0 commit comments