@@ -12,18 +12,16 @@ class (ie. ReadOnlyModelViewSet).
1212  2. Re-define `select_related_fields` and `prefetch_related_fields` lists on  
1313  the child ViewSet to specify relationships to optimise. 
1414   
15-   ## Example 
15+   ## Usage  Example 
1616  ``` 
1717    # EagerLoadingMixin inhertired before base-case 
1818    class CreatureViewSet(EagerLoadingMixin, viewsets.ReadOnlyModelViewSet): 
1919      queryset = models.Creature.objects.all().order_by('pk') 
2020      serializer_class = serializers.CreatureSerializer 
2121      filterset_class = CreatureFilterSet 
2222
23-       # ForeignKey relations to optimise with select_related() 
24-       select_related_fields = [] 
25-       # ManyToMany / reverse relations to optimise with prefetch_related() 
26-       prefetch_related_fields = [] 
23+       select_related_fields = []   # ForeignKey relations to optimise with select_related()       
24+       prefetch_related_fields = [] # ManyToMany/reverse relations to optimise with prefetch_related() 
2725  ``` 
2826  """ 
2927
@@ -34,26 +32,32 @@ class CreatureViewSet(EagerLoadingMixin, viewsets.ReadOnlyModelViewSet):
3432  def  get_queryset (self ):
3533    """Override DRF's default get_queryset() method to apply eager loading""" 
3634    queryset  =  super ().get_queryset ()
37-     request  =  self .request 
3835
39-     # Get query parameters 
40-     requested_fields  =  request .query_params .get ('fields' , '' )
41-     depth  =  int (request .query_params .get ('depth' , 0 ))
36+     # Get query parameters from request  
37+     requested_fields  =  self . request .query_params .get ('fields' , ''  ). split ( ', ' )
38+     depth  =  int (self . request .query_params .get ('depth' , 0 ))
4239
43-     if  requested_fields :
44-       requested_fields  =  set (requested_fields .split (',' ))
45-     else :
46-       # If no fields requested, apply all opitmisations 
47-       requested_fields  =  set (self .select_related_fields  +  self .prefetch_related_fields )
48- 
49-     # Filter fields based on on which have been requested 
50-     select_fields  =  [field  for  field  in  self .select_related_fields  if  field  in  requested_fields ]
51-     prefetch_fields  =  [field  for  field  in  self .prefetch_related_fields  if  field  in  requested_fields ]
40+     # if no fields are passed via query param, select/prefetch all fields defined on the view 
41+     if  not  requested_fields :
42+       queryset  =  queryset .select_related (* self .select_related_fields )
43+       queryset  =  queryset .prefetch_related (* self .prefetch_related_fields )
44+       return  queryset 
45+     
46+     # filter selected fields against fields requested by user via query params  
47+     # this stops Django prefetching data that isn't even returned by this view 
48+     select_fields  =  []
49+     for  field_to_select  in  self .select_related_fields :
50+       if  any (field_in_request  in  field_to_select  for  field_in_request  in  requested_fields ):
51+         select_fields .append (field_to_select )
5252
53-     # Apply optimisations 
54-     if  select_fields :
55-       queryset  =  queryset .select_related (* select_fields )
56-     if  prefetch_fields :
57-       queryset  =  queryset .prefetch_related (* prefetch_fields )
53+     # filter prefetch fields against fields requested by user via query params 
54+     # this stops Django prefetching data that isn't even returned by this view 
55+     prefetch_fields  =  []
56+     for  field_to_prefetch  in  self .prefetch_related_fields :
57+       if  any (field_in_request  in  field_to_prefetch  for  field_in_request  in  requested_fields ):
58+         prefetch_fields .append (field_to_prefetch )
5859
60+     # Apply filtered optimisations 
61+     queryset  =  queryset .select_related (* select_fields )
62+     queryset  =  queryset .prefetch_related (* prefetch_fields )
5963    return  queryset 
0 commit comments