Skip to content

Commit 0fd78b9

Browse files
committed
Closes #21263: Prefetch related objects after creating/updating objects via REST API
1 parent 359179f commit 0fd78b9

File tree

2 files changed

+43
-8
lines changed

2 files changed

+43
-8
lines changed

netbox/netbox/api/viewsets/__init__.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,26 @@ def dispatch(self, request, *args, **kwargs):
170170

171171
# Creates
172172

173+
def create(self, request, *args, **kwargs):
174+
serializer = self.get_serializer(data=request.data)
175+
serializer.is_valid(raise_exception=True)
176+
bulk_create = getattr(serializer, 'many', False)
177+
self.perform_create(serializer)
178+
179+
# After creating the instance(s), re-initialize the serializer with a queryset
180+
# to ensure related objects are prefetched.
181+
if bulk_create:
182+
instance_pks = [obj.pk for obj in serializer.instance]
183+
qs = self.get_queryset().filter(pk__in=instance_pks).order_by('pk')
184+
else:
185+
qs = self.get_queryset().get(pk=serializer.instance.pk)
186+
187+
# Re-serialize the instance(s) with prefetched data
188+
serializer = self.get_serializer(qs, many=bulk_create)
189+
190+
headers = self.get_success_headers(serializer.data)
191+
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
192+
173193
def perform_create(self, serializer):
174194
model = self.queryset.model
175195
logger = logging.getLogger(f'netbox.api.views.{self.__class__.__name__}')
@@ -186,9 +206,20 @@ def perform_create(self, serializer):
186206
# Updates
187207

188208
def update(self, request, *args, **kwargs):
189-
# Hotwire get_object() to ensure we save a pre-change snapshot
190-
self.get_object = self.get_object_with_snapshot
191-
return super().update(request, *args, **kwargs)
209+
partial = kwargs.pop('partial', False)
210+
instance = self.get_object_with_snapshot()
211+
serializer = self.get_serializer(instance, data=request.data, partial=partial)
212+
serializer.is_valid(raise_exception=True)
213+
self.perform_update(serializer)
214+
215+
# After updating the instance, re-initialize the serializer with a queryset
216+
# to ensure related objects are prefetched.
217+
qs = self.get_queryset().get(pk=serializer.instance.pk)
218+
219+
# Re-serialize the instance(s) with prefetched data
220+
serializer = self.get_serializer(qs)
221+
222+
return Response(serializer.data)
192223

193224
def perform_update(self, serializer):
194225
model = self.queryset.model

netbox/netbox/api/viewsets/mixins.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,23 +108,27 @@ def bulk_update(self, request, *args, **kwargs):
108108
obj.pop('id'): obj for obj in request.data
109109
}
110110

111-
data = self.perform_bulk_update(qs, update_data, partial=partial)
111+
object_pks = self.perform_bulk_update(qs, update_data, partial=partial)
112112

113-
return Response(data, status=status.HTTP_200_OK)
113+
# Prefetch related objects for all updated instances
114+
qs = self.get_queryset().filter(pk__in=object_pks)
115+
serializer = self.get_serializer(qs, many=True)
116+
117+
return Response(serializer.data, status=status.HTTP_200_OK)
114118

115119
def perform_bulk_update(self, objects, update_data, partial):
120+
updated_pks = []
116121
with transaction.atomic(using=router.db_for_write(self.queryset.model)):
117-
data_list = []
118122
for obj in objects:
119123
data = update_data.get(obj.id)
120124
if hasattr(obj, 'snapshot'):
121125
obj.snapshot()
122126
serializer = self.get_serializer(obj, data=data, partial=partial)
123127
serializer.is_valid(raise_exception=True)
124128
self.perform_update(serializer)
125-
data_list.append(serializer.data)
129+
updated_pks.append(obj.pk)
126130

127-
return data_list
131+
return updated_pks
128132

129133
def bulk_partial_update(self, request, *args, **kwargs):
130134
kwargs['partial'] = True

0 commit comments

Comments
 (0)