Skip to content

Commit a36918d

Browse files
authored
Merge pull request open5e#647 from calumbell/refactor/eager-loading-mixin
Refactor eager loading code into `EagerLoadingMixin`
2 parents 556293b + 9e5a1b8 commit a36918d

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

api_v2/views/mixins.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
class EagerLoadingMixin:
2+
"""
3+
Mixin to apply eager loading optimisations to a ViewSet.
4+
5+
Dynamically applies `selected_related()` for ForeignKey fields and
6+
`prefetch_related()` from ManyToMany/reverse relationships. This improves
7+
query efficiency and prevents N+1 problems
8+
9+
## Usage
10+
1. Make sure your ViewSet inherits from `EagerLoadingMixin` before its base
11+
class (ie. ReadOnlyModelViewSet).
12+
2. Re-define `select_related_fields` and `prefetch_related_fields` lists on
13+
the child ViewSet to specify relationships to optimise.
14+
15+
## Example
16+
```
17+
# EagerLoadingMixin inhertired before base-case
18+
class CreatureViewSet(EagerLoadingMixin, viewsets.ReadOnlyModelViewSet):
19+
queryset = models.Creature.objects.all().order_by('pk')
20+
serializer_class = serializers.CreatureSerializer
21+
filterset_class = CreatureFilterSet
22+
23+
# ForeignKey relations to optimise with select_related()
24+
select_related_fields = []
25+
# ManyToMany / reverse relations to optimise with prefetch_related()
26+
prefetch_related_fields = []
27+
```
28+
"""
29+
30+
# Override these lists in child views
31+
select_related_fields = [] # ForeignKeys to optimise
32+
prefetch_related_fields = [] # ManyToMany & reverse relationships to prefetch
33+
34+
def get_queryset(self):
35+
"""Override DRF's default get_queryset() method to apply eager loading"""
36+
queryset = super().get_queryset()
37+
request = self.request
38+
39+
# Get query parameters
40+
requested_fields = request.query_params.get('fields', '')
41+
depth = int(request.query_params.get('depth', 0))
42+
43+
if requested_fields:
44+
requested_fields = set(requested_fields.split(','))
45+
else:
46+
# If no fields requested, apply all opitmisations
47+
requested_fields = set(self.select_related_fields + self.prefetch_related_fields)
48+
49+
# Filter fields based on on which have been requested
50+
select_fields = [field for field in self.select_related_fields if field in requested_fields]
51+
prefetch_fields = [field for field in self.prefetch_related_fields if field in requested_fields]
52+
53+
# Apply optimisations
54+
if select_fields:
55+
queryset = queryset.select_related(*select_fields)
56+
if prefetch_fields:
57+
queryset = queryset.prefetch_related(*prefetch_fields)
58+
59+
return queryset

0 commit comments

Comments
 (0)