Skip to content

Commit 2bc8eb7

Browse files
committed
fixed bug on EagerLoadingMixin
1 parent 8a3f15a commit 2bc8eb7

File tree

1 file changed

+27
-23
lines changed

1 file changed

+27
-23
lines changed

api_v2/views/mixins.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,16 @@ class (ie. ReadOnlyModelViewSet).
1212
2. Re-define `select_related_fields` and `prefetch_related_fields` lists on
1313
the child ViewSet to specify relationships to optimise.
1414
15-
## Example
15+
## Usage Example
1616
```
1717
# EagerLoadingMixin inhertired before base-case
1818
class CreatureViewSet(EagerLoadingMixin, viewsets.ReadOnlyModelViewSet):
1919
queryset = models.Creature.objects.all().order_by('pk')
2020
serializer_class = serializers.CreatureSerializer
2121
filterset_class = CreatureFilterSet
2222
23-
# ForeignKey relations to optimise with select_related()
24-
select_related_fields = []
25-
# ManyToMany / reverse relations to optimise with prefetch_related()
26-
prefetch_related_fields = []
23+
select_related_fields = [] # ForeignKey relations to optimise with select_related()
24+
prefetch_related_fields = [] # ManyToMany/reverse relations to optimise with prefetch_related()
2725
```
2826
"""
2927

@@ -34,26 +32,32 @@ class CreatureViewSet(EagerLoadingMixin, viewsets.ReadOnlyModelViewSet):
3432
def get_queryset(self):
3533
"""Override DRF's default get_queryset() method to apply eager loading"""
3634
queryset = super().get_queryset()
37-
request = self.request
3835

39-
# Get query parameters
40-
requested_fields = request.query_params.get('fields', '')
41-
depth = int(request.query_params.get('depth', 0))
36+
# Get query parameters from request
37+
requested_fields = self.request.query_params.get('fields', '').split(',')
38+
depth = int(self.request.query_params.get('depth', 0))
4239

43-
if requested_fields:
44-
requested_fields = set(requested_fields.split(','))
45-
else:
46-
# If no fields requested, apply all opitmisations
47-
requested_fields = set(self.select_related_fields + self.prefetch_related_fields)
48-
49-
# Filter fields based on on which have been requested
50-
select_fields = [field for field in self.select_related_fields if field in requested_fields]
51-
prefetch_fields = [field for field in self.prefetch_related_fields if field in requested_fields]
40+
# if no fields are passed via query param, select/prefetch all fields defined on the view
41+
if not requested_fields:
42+
queryset = queryset.select_related(*self.select_related_fields)
43+
queryset = queryset.prefetch_related(*self.prefetch_related_fields)
44+
return queryset
45+
46+
# filter selected fields against fields requested by user via query params
47+
# this stops Django prefetching data that isn't even returned by this view
48+
select_fields = []
49+
for field_to_select in self.select_related_fields:
50+
if any(field_in_request in field_to_select for field_in_request in requested_fields):
51+
select_fields.append(field_to_select)
5252

53-
# Apply optimisations
54-
if select_fields:
55-
queryset = queryset.select_related(*select_fields)
56-
if prefetch_fields:
57-
queryset = queryset.prefetch_related(*prefetch_fields)
53+
# filter prefetch fields against fields requested by user via query params
54+
# this stops Django prefetching data that isn't even returned by this view
55+
prefetch_fields = []
56+
for field_to_prefetch in self.prefetch_related_fields:
57+
if any(field_in_request in field_to_prefetch for field_in_request in requested_fields):
58+
prefetch_fields.append(field_to_prefetch)
5859

60+
# Apply filtered optimisations
61+
queryset = queryset.select_related(*select_fields)
62+
queryset = queryset.prefetch_related(*prefetch_fields)
5963
return queryset

0 commit comments

Comments
 (0)