Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions netbox/dcim/migrations/0226_modulebay_rebuild_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from django.db import migrations
import mptt.managers
import mptt.models


def rebuild_mptt(apps, schema_editor):
"""
Rebuild the MPTT tree for ModuleBay to apply new ordering.
"""
ModuleBay = apps.get_model('dcim', 'ModuleBay')

# Set MPTTMeta with the correct order_insertion_by
class MPTTMeta:
order_insertion_by = ('module', 'name',)

ModuleBay.MPTTMeta = MPTTMeta
ModuleBay._mptt_meta = mptt.models.MPTTOptions(MPTTMeta)

manager = mptt.managers.TreeManager()
manager.model = ModuleBay
manager.contribute_to_class(ModuleBay, 'objects')
manager.rebuild()


class Migration(migrations.Migration):
dependencies = [
('dcim', '0225_gfk_indexes'),
]

operations = [
migrations.RunPython(code=rebuild_mptt, reverse_code=migrations.RunPython.noop),
]
2 changes: 1 addition & 1 deletion netbox/dcim/models/device_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -1273,7 +1273,7 @@ class Meta(ModularComponentModel.Meta):
verbose_name_plural = _('module bays')

class MPTTMeta:
order_insertion_by = ('module',)
order_insertion_by = ('module', 'name',)

def clean(self):
super().clean()
Expand Down
12 changes: 10 additions & 2 deletions netbox/dcim/models/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.db.models.signals import post_save
from django.utils.translation import gettext_lazy as _
from jsonschema.exceptions import ValidationError as JSONValidationError
from mptt.models import MPTTModel

from dcim.choices import *
from dcim.utils import update_interface_bridges
Expand Down Expand Up @@ -329,7 +330,8 @@ def save(self, *args, **kwargs):
component._location = self.device.location
component._rack = self.device.rack

if component_model is not ModuleBay:
# we handle create and update separately - this is for create
if not issubclass(component_model, MPTTModel):
component_model.objects.bulk_create(create_instances)
# Emit the post_save signal for each newly created object
for component in create_instances:
Expand All @@ -342,11 +344,13 @@ def save(self, *args, **kwargs):
update_fields=None
)
else:
# ModuleBays must be saved individually for MPTT
# MPTT models must be saved individually to maintain tree structure
for instance in create_instances:
instance.save()

update_fields = ['module']

# we handle create and update separately - this is for update
component_model.objects.bulk_update(update_instances, update_fields)
# Emit the post_save signal for each updated object
for component in update_instances:
Expand All @@ -359,5 +363,9 @@ def save(self, *args, **kwargs):
update_fields=update_fields
)

# Rebuild MPTT tree if needed (bulk_update bypasses model save)
if issubclass(component_model, MPTTModel) and update_instances:
component_model.objects.rebuild()

# Interface bridges have to be set after interface instantiation
update_interface_bridges(self.device, self.module_type.interfacetemplates, self)
70 changes: 45 additions & 25 deletions netbox/netbox/views/generic/bulk_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,30 +438,12 @@ def save_object(self, object_form, request):
"""
return object_form.save()

def create_and_update_objects(self, form, request):
def _process_import_records(self, form, request, records, prefetched_objects):
"""
Process CSV import records and save objects.
"""
saved_objects = []

records = list(form.cleaned_data['data'])

# Prefetch objects to be updated, if any
prefetch_ids = [int(record['id']) for record in records if record.get('id')]

# check for duplicate IDs
duplicate_pks = [pk for pk, count in Counter(prefetch_ids).items() if count > 1]
if duplicate_pks:
error_msg = _(
"Duplicate objects found: {model} with ID(s) {ids} appears multiple times"
).format(
model=title(self.queryset.model._meta.verbose_name),
ids=', '.join(str(pk) for pk in sorted(duplicate_pks))
)
raise ValidationError(error_msg)

prefetched_objects = {
obj.pk: obj
for obj in self.queryset.model.objects.filter(id__in=prefetch_ids)
} if prefetch_ids else {}

for i, record in enumerate(records, start=1):
object_id = int(record.pop('id')) if record.get('id') else None

Expand Down Expand Up @@ -526,6 +508,37 @@ def create_and_update_objects(self, form, request):

return saved_objects

def create_and_update_objects(self, form, request):
records = list(form.cleaned_data['data'])

# Prefetch objects to be updated, if any
prefetch_ids = [int(record['id']) for record in records if record.get('id')]

# check for duplicate IDs
duplicate_pks = [pk for pk, count in Counter(prefetch_ids).items() if count > 1]
if duplicate_pks:
error_msg = _(
"Duplicate objects found: {model} with ID(s) {ids} appears multiple times"
).format(
model=title(self.queryset.model._meta.verbose_name),
ids=', '.join(str(pk) for pk in sorted(duplicate_pks))
)
raise ValidationError(error_msg)

prefetched_objects = {
obj.pk: obj
for obj in self.queryset.model.objects.filter(id__in=prefetch_ids)
} if prefetch_ids else {}

# For MPTT models, delay tree updates until all saves are complete
if issubclass(self.queryset.model, MPTTModel):
with self.queryset.model.objects.delay_mptt_updates():
saved_objects = self._process_import_records(form, request, records, prefetched_objects)
else:
saved_objects = self._process_import_records(form, request, records, prefetched_objects)

return saved_objects

#
# Request handlers
#
Expand Down Expand Up @@ -895,9 +908,16 @@ def post(self, request):
renamed_pks = self._rename_objects(form, selected_objects)

if '_apply' in request.POST:
for obj in selected_objects:
setattr(obj, self.field_name, obj.new_name)
obj.save()
# For MPTT models, delay tree updates until all saves are complete
if issubclass(self.queryset.model, MPTTModel):
with self.queryset.model.objects.delay_mptt_updates():
for obj in selected_objects:
setattr(obj, self.field_name, obj.new_name)
obj.save()
else:
for obj in selected_objects:
setattr(obj, self.field_name, obj.new_name)
obj.save()

# Enforce constrained permissions
if self.queryset.filter(pk__in=renamed_pks).count() != len(selected_objects):
Expand Down