Skip to content

Commit a248077

Browse files
Merge pull request #919 from calumbell/918/nested-exclude-query-param
[BUGFIX] `?exclude` query parameter now works on nested fields
2 parents 5d18bf0 + 30200fc commit a248077

2 files changed

Lines changed: 110 additions & 30 deletions

File tree

api_v2/views/mixins/eager_loading_mixin.py

Lines changed: 62 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@ class EagerLoadingMixin:
22
"""
33
Mixin to apply eager loading optimisations to a ViewSet.
44
5-
Dynamically applies `selected_related()` for ForeignKey fields and
6-
`prefetch_related()` from ManyToMany/reverse relationships. This improves
7-
query efficiency and prevents N+1 problems
5+
Handles the running of `select_related()` (for ForeignKey fields) and
6+
`prefetch_related()` (from ManyToMany/reverse relationships) queryset methods
7+
to allow developers to solve N+1 problems on Open5e endpoints.
88
99
## Usage
1010
1. Make sure your ViewSet inherits from `EagerLoadingMixin` before its base
1111
class (ie. ReadOnlyModelViewSet).
1212
2. Re-define `select_related_fields` and `prefetch_related_fields` lists on
13-
the child ViewSet to specify relationships to optimise.
13+
the child ViewSet to specify relationships to select related / pre-fetch.
1414
1515
## Usage Example
1616
```
@@ -28,24 +28,69 @@ class CreatureViewSet(EagerLoadingMixin, viewsets.ReadOnlyModelViewSet):
2828
prefetch_related_fields = []
2929

3030
def get_queryset(self):
31+
"""
32+
Builds the queryset with optimised eager loading based on the requested and excluded fields.
33+
"""
3134
queryset = super().get_queryset()
32-
requested_fields = self.request.query_params.get('fields', '').split(',')
33-
filtered_select_fields = self.filter_fields(self.select_related_fields, requested_fields)
34-
filtered_prefetch_fields = self.filter_fields(self.prefetch_related_fields, requested_fields)
35+
36+
# Check fields included or excluded via query parameter. We use this data
37+
# so that we only eagerly load fields actually returned by the API.
38+
requested_fields = self.parse_requested_fields()
39+
excluded_fields = self.parse_excluded_fields()
40+
41+
filtered_select_fields = self.filter_fields(self.select_related_fields, requested_fields, excluded_fields)
42+
filtered_prefetch_fields = self.filter_fields(self.prefetch_related_fields, requested_fields, excluded_fields)
3543

3644
return queryset \
3745
.select_related(*filtered_select_fields) \
3846
.prefetch_related(*filtered_prefetch_fields)
47+
48+
def parse_requested_fields(self):
49+
"""
50+
Parses the 'fields' query param into a list of requested field paths.
51+
"""
52+
requested_fields = self.request.query_params.get('fields', '')
53+
requested_fields = requested_fields.split(',')
54+
requested_fields = [field for field in requested_fields if field]
55+
return requested_fields
56+
57+
def parse_excluded_fields(self):
58+
"""
59+
Parses 'exclude' query params into a flat list of field paths for use in eager loading
60+
"""
61+
excluded_fields = []
62+
for key, value in self.request.query_params.items():
63+
if key == 'exclude':
64+
excluded_fields += value.split(',')
65+
elif key.endswith('__exclude'):
66+
prefix = key.removesuffix('__exclude')
67+
excluded_fields += [f'{prefix}__{field}' for field in value.split(',')]
68+
return excluded_fields
3969

40-
def filter_fields(self, related_fields, requested_fields):
70+
def filter_fields(self, related_fields, requested_fields=None, excluded_fields=None):
4171
"""
42-
Filters'related_fields' according to whether they are included in
43-
'requested_fields'. Used to remove fields from eager loading if they are
44-
not requested (and thus not returned by API), avoiding unnecessary DB calls
72+
Filters 'related_fields' according to whether they are included in
73+
'requested_fields' or 'excluded_fields'. Used to remove fields from eager
74+
loading if they are not returned by API call to avoid unnecessary DB calls
4575
"""
46-
if not any(requested_fields):
47-
return related_fields
48-
return [
49-
related_field for related_field in related_fields
50-
if any(related_field == req or related_field.startswith(req + '__') for req in requested_fields)
51-
]
76+
# avoids mutable default argument issues: set to empty list if param missing
77+
requested_fields = requested_fields or []
78+
excluded_fields = excluded_fields or []
79+
80+
def field_matches(field, targets):
81+
# Returns True if 'field' equals any 'target', or is a child path of one
82+
return any(field == target or field.startswith(target + '__') for target in targets)
83+
84+
if requested_fields:
85+
related_fields = [
86+
related_field for related_field in related_fields
87+
if field_matches(related_field, requested_fields)
88+
]
89+
90+
if excluded_fields:
91+
related_fields = [
92+
related_field for related_field in related_fields
93+
if not field_matches(related_field, excluded_fields)
94+
]
95+
96+
return related_fields
Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,23 @@
11
class ExcludeFieldsMixin:
2+
"""
3+
This Mixin supports dynamically excluding returned fields of serializers that
4+
inherit from it via the `?exclude` query parameter.
5+
6+
Syntactically similar to the default `?field` DRF query parameter. Nested
7+
fields are similarly excluded via the '__' operator (see Examples).
8+
9+
## Usage
10+
1. Make sure your ViewSet inherits from `ExcludeFieldsMixin` before its base
11+
class (ie. ReadOnlyModelViewSet).
12+
2. Pass exclude params in the request query string to remove fields from the response.
13+
14+
# Exclude top-level fields
15+
GET /v2/creatures/?exclude=traits,actions
16+
17+
# Exclude nested fields
18+
GET /v2/creatures/?actions__exclude=attacks
19+
"""
20+
221
def get_serializer_class(self):
322

423
# Handle other mixins that might also override get_serializer_class
@@ -7,25 +26,41 @@ def get_serializer_class(self):
726
else:
827
serializer_class = getattr(self, 'serializer_class')
928

10-
# just return the regular serializer if there is no request
29+
# Return base serializer if there is no request. This stops calculation of
30+
# excluded fields for nested serializers, avoiding unnecessary computing.
1131
if not hasattr(self, 'request') or not hasattr(self.request, 'query_params'):
1232
return serializer_class
1333

14-
exclude_fields = self.request.query_params.get('exclude', '').split(',')
34+
# Iterates over params, scans for any 'exclude' or '<field>_exclude' keys
35+
# and builds a dict mapping API field paths to lists of field to remove from each
36+
# e.g. '?exclude=id&document__exclude=permalink' becomes:
37+
# { '': ['id'], 'document': ['permalink'] }
38+
fields_to_exclude = {}
39+
for key, value in self.request.query_params.items():
40+
if key == 'exclude':
41+
fields_to_exclude[''] = value.split(',')
42+
elif key.endswith('__exclude'):
43+
fields_to_exclude[key.removesuffix('__exclude')] = value.split(',')
1544

16-
if not exclude_fields:
45+
if not fields_to_exclude:
1746
return serializer_class
18-
19-
# create a new serializer with 'exclude_fields' removed and return it
47+
48+
# Walks the serializer tree removing fields at each level flagged for removal.
49+
# 'path' tracks where we are in the tree, which we use as a key to index into
50+
# 'fields_to_exclude' to check which fields to remove. We then recurse into
51+
# nested serializers and apply the same logic.
52+
def strip_excluded_fields(fields, path=''):
53+
for excluded_field in fields_to_exclude.get(path, []):
54+
fields.pop(excluded_field, None)
55+
for field_name, field in fields.items():
56+
nested_serializer = getattr(field, 'child', field)
57+
if hasattr(nested_serializer, 'fields'):
58+
nested_path = f'{path}__{field_name}' if path else field_name
59+
strip_excluded_fields(nested_serializer.fields, nested_path)
60+
2061
class DynamicSerializer(serializer_class):
2162
def __init__(self, *args, **kwargs):
2263
super().__init__(*args, **kwargs)
64+
strip_excluded_fields(self.fields)
2365

24-
excluded_fields = []
25-
for field in exclude_fields:
26-
if field in self.fields:
27-
self.fields.pop(field)
28-
excluded_fields.append(field)
29-
30-
return DynamicSerializer
31-
66+
return DynamicSerializer

0 commit comments

Comments
 (0)