diff --git a/openwisp_firmware_upgrader/api/serializers.py b/openwisp_firmware_upgrader/api/serializers.py index 112648d3..b498bb94 100644 --- a/openwisp_firmware_upgrader/api/serializers.py +++ b/openwisp_firmware_upgrader/api/serializers.py @@ -84,13 +84,30 @@ class Meta(BaseMeta): class UpgradeOperationSerializer(serializers.ModelSerializer): class Meta: model = UpgradeOperation - fields = ("id", "device", "image", "status", "log", "modified", "created") + fields = ( + "id", + "device", + "image", + "status", + "log", + "modified", + "created", + "upgrade_options", + ) class DeviceUpgradeOperationSerializer(serializers.ModelSerializer): class Meta: model = UpgradeOperation - fields = ("id", "device", "image", "status", "log", "modified") + fields = ( + "id", + "device", + "image", + "status", + "log", + "modified", + "upgrade_options", + ) class BatchUpgradeOperationListSerializer(BaseSerializer): @@ -116,9 +133,11 @@ class Meta: class DeviceFirmwareSerializer(ValidatedModelSerializer): + upgrade_options = serializers.JSONField(required=False, allow_null=True) + class Meta: model = DeviceFirmware - fields = ("id", "image", "installed", "modified") + fields = ("id", "image", "installed", "modified", "upgrade_options") read_only_fields = ("installed", "modified") def validate(self, data): @@ -142,8 +161,58 @@ def validate(self, data): ) } ) + # Validate upgrade_options if provided + upgrade_options = data.get("upgrade_options") + if upgrade_options is not None: + # Create a temporary UpgradeOperation to validate upgrade_options + # This will trigger the model's validation logic + temp_operation = UpgradeOperation( + device=device, + image=image, + upgrade_options=upgrade_options or {}, + ) + try: + temp_operation.validate_upgrade_options() + except ValidationError as e: + raise serializers.ValidationError({"upgrade_options": e.messages}) return super().validate(data) + def to_representation(self, instance): + """ + Include upgrade_options from the latest upgrade operation if available. + """ + ret = super().to_representation(instance) + # Get upgrade_options from the latest upgrade operation for this device + try: + latest_operation = instance.device.upgradeoperation_set.latest("created") + ret["upgrade_options"] = latest_operation.upgrade_options + except UpgradeOperation.DoesNotExist: + ret["upgrade_options"] = {} + return ret + + def create(self, validated_data): + """ + Extract upgrade_options from validated_data and pass it to model.save() + """ + upgrade_options = validated_data.pop("upgrade_options", None) + if upgrade_options is None: + upgrade_options = {} + instance = DeviceFirmware(**validated_data) + instance.save(upgrade_options=upgrade_options) + return instance + + def update(self, instance, validated_data): + """ + Extract upgrade_options from validated_data and pass it to model.save() + """ + upgrade_options = validated_data.pop("upgrade_options", None) + if upgrade_options is None: + upgrade_options = {} + for attr, value in validated_data.items(): + setattr(instance, attr, value) + instance.save(upgrade_options=upgrade_options) + return instance + def _get_device_object(self, device_id): try: device = Device.objects.get(id=device_id) diff --git a/openwisp_firmware_upgrader/api/views.py b/openwisp_firmware_upgrader/api/views.py index 2818fe6c..ec0cf188 100644 --- a/openwisp_firmware_upgrader/api/views.py +++ b/openwisp_firmware_upgrader/api/views.py @@ -1,3 +1,5 @@ +import json + import swapper from django.core.exceptions import ValidationError from django.http import Http404 @@ -86,9 +88,36 @@ def post(self, request, pk): """ Upgrades all the devices related to the specified build ID. """ - upgrade_all = request.POST.get("upgrade_all") is not None + upgrade_all = ( + request.data.get("upgrade_all") is not None + or request.POST.get("upgrade_all") is not None + ) + upgrade_options = request.data.get("upgrade_options") + # If not in request.data, try request.POST (for form data) + if upgrade_options is None: + upgrade_options = request.POST.get("upgrade_options") + # Parse upgrade_options if it's a string (from form data) + if isinstance(upgrade_options, str): + try: + upgrade_options = json.loads(upgrade_options) + except (json.JSONDecodeError, ValueError): + upgrade_options = {} + if upgrade_options is None: + upgrade_options = {} instance = self.get_object() - batch = instance.batch_upgrade(firmwareless=upgrade_all) + # Validate upgrade_options by creating a temporary BatchUpgradeOperation + temp_batch = BatchUpgradeOperation( + build=instance, upgrade_options=upgrade_options + ) + try: + temp_batch.full_clean() + except ValidationError as e: + return Response( + {"upgrade_options": e.messages}, status=status.HTTP_400_BAD_REQUEST + ) + batch = instance.batch_upgrade( + firmwareless=upgrade_all, upgrade_options=upgrade_options + ) return Response({"batch": str(batch.pk)}, status=201) def get(self, request, pk): diff --git a/openwisp_firmware_upgrader/base/models.py b/openwisp_firmware_upgrader/base/models.py index defd3226..36193b08 100644 --- a/openwisp_firmware_upgrader/base/models.py +++ b/openwisp_firmware_upgrader/base/models.py @@ -55,12 +55,15 @@ class Meta: def validate_upgrade_options(self): if not self.upgrade_options: return - if not getattr(self.upgrader_class, "SCHEMA"): + upgrader_class = self.upgrader_class + if not upgrader_class: + return + if not getattr(upgrader_class, "SCHEMA", None): raise ValidationError( _("Using upgrade options is not allowed with this upgrader.") ) try: - self.upgrader_class.validate_upgrade_options(self.upgrade_options) + upgrader_class.validate_upgrade_options(self.upgrade_options) except jsonschema.ValidationError: raise ValidationError("The upgrade options are invalid") except FirmwareUpgradeOptionsException as error: diff --git a/openwisp_firmware_upgrader/tests/test_api.py b/openwisp_firmware_upgrader/tests/test_api.py index fb40ed59..08ddc810 100644 --- a/openwisp_firmware_upgrader/tests/test_api.py +++ b/openwisp_firmware_upgrader/tests/test_api.py @@ -321,7 +321,7 @@ def test_api_batch_upgrade(self): self.assertEqual(DeviceFirmware.objects.count(), 0) with self.subTest("Existing build"): url = reverse("upgrader:api_build_batch_upgrade", args=[build.pk]) - with self.assertNumQueries(8): + with self.assertNumQueries(10): r = self.client.post(url) self.assertEqual(BatchUpgradeOperation.objects.count(), 1) self.assertEqual(DeviceFirmware.objects.count(), 0) @@ -403,7 +403,7 @@ def test_api_shared_build_batch_upgrade(self): ) with self.subTest("Test superuser can mass upgrade shared build"): - with self.assertNumQueries(5): + with self.assertNumQueries(7): response = self.client.post(path) self.assertEqual(response.status_code, 201) batch = BatchUpgradeOperation.objects.first() @@ -413,7 +413,35 @@ def test_build_upgradeable_404(self): url = reverse("upgrader:api_build_batch_upgrade", args=[uuid.uuid4()]) with self.assertNumQueries(4): r = self.client.get(url) - self.assertEqual(r.status_code, 404) + self.assertEqual(r.status_code, 404) + + def test_api_batch_upgrade_with_upgrade_options(self): + """Test batch upgrade accepts upgrade_options""" + build = self._create_build() + url = reverse("upgrader:api_build_batch_upgrade", args=[build.pk]) + upgrade_options = {"c": True, "F": True} + r = self.client.post( + url, {"upgrade_options": upgrade_options}, content_type="application/json" + ) + self.assertEqual(r.status_code, 201) + batch = BatchUpgradeOperation.objects.first() + self.assertEqual(batch.upgrade_options, upgrade_options) + + def test_api_batch_upgrade_validates_upgrade_options(self): + """Test batch upgrade validates upgrade_options""" + # Create a build with a device so upgrader_class can be determined + env = self._create_upgrade_env() + build = env["build2"] + url = reverse("upgrader:api_build_batch_upgrade", args=[build.pk]) + # -n and -c are mutually exclusive + invalid_upgrade_options = {"c": True, "n": True} + r = self.client.post( + url, + {"upgrade_options": invalid_upgrade_options}, + content_type="application/json", + ) + self.assertEqual(r.status_code, 400) + self.assertIn("upgrade_options", r.data) self.assertEqual(BatchUpgradeOperation.objects.count(), 0) @@ -794,6 +822,18 @@ def test_batchupgradeoperation_view(self): r = self.client.get(url) self.assertEqual(r.data, serialized) + def test_batchupgradeoperation_includes_upgrade_options(self): + """Test that batch upgrade operation includes upgrade_options in response""" + build = self._create_build() + batch = BatchUpgradeOperation.objects.create( + build=build, upgrade_options={"c": True, "F": True} + ) + url = reverse("upgrader:api_batchupgradeoperation_detail", args=[batch.pk]) + r = self.client.get(url) + self.assertEqual(r.status_code, 200) + self.assertIn("upgrade_options", r.data) + self.assertEqual(r.data["upgrade_options"], {"c": True, "F": True}) + class TestFirmwareImageViews(TestAPIUpgraderMixin, TestCase): def _serialize_image(self, firmware): @@ -1159,7 +1199,7 @@ def test_device_firmware_detail_get(self): url = reverse( "upgrader:api_devicefirmware_detail", args=[device_fw1.device.pk] ) - with self.assertNumQueries(9): + with self.assertNumQueries(13): r = self.client.get(url, {"format": "api"}) self.assertEqual(r.status_code, 200) serializer_detail = self._serialize_device_firmware(device_fw1) @@ -1207,7 +1247,7 @@ def test_device_firmware_detail_create(self): self.assertEqual(DeviceFirmware.objects.count(), 0) self.assertEqual(UpgradeOperation.objects.count(), 0) - with self.assertNumQueries(26): + with self.assertNumQueries(29): data = {"image": image1a.pk} # This API view allows the creation # of new devicefirmware objects with @@ -1245,7 +1285,7 @@ def test_device_firmware_detail_create_shared_image(self): self.assertEqual(UpgradeOperation.objects.count(), 0) self.client.force_login(self.administrator) - with self.assertNumQueries(25): + with self.assertNumQueries(28): data = {"image": shared_image.pk} r = self.client.put( f"{path}?format=api", data, content_type="application/json" @@ -1275,7 +1315,7 @@ def test_device_firmware_detail_update(self): self.assertEqual(DeviceFirmware.objects.count(), 2) self.assertEqual(UpgradeOperation.objects.count(), 0) - with self.assertNumQueries(27): + with self.assertNumQueries(30): data = {"image": image2a.pk} r = self.client.put( f"{url}?format=api", data, content_type="application/json" @@ -1314,7 +1354,7 @@ def test_device_firmware_detail_partial_update(self): self.assertEqual(DeviceFirmware.objects.count(), 2) self.assertEqual(UpgradeOperation.objects.count(), 0) - with self.assertNumQueries(27): + with self.assertNumQueries(31): data = {"image": image2a.pk} r = self.client.patch( f"{url}?format=api", data, content_type="application/json" @@ -1336,6 +1376,41 @@ def test_device_firmware_detail_partial_update(self): self.assertNotIn(f"{image2b}", repsonse) self.assertNotIn(f"{image2}", repsonse) + def test_device_firmware_detail_with_upgrade_options(self): + """Test device firmware update accepts upgrade_options""" + env = self._create_upgrade_env() + device_fw = env["device_fw1"] + new_image = env["image2a"] + url = reverse("upgrader:api_devicefirmware_detail", args=[device_fw.device.pk]) + upgrade_options = {"c": True, "F": True} + r = self.client.put( + url, + {"image": new_image.pk, "upgrade_options": upgrade_options}, + content_type="application/json", + ) + self.assertEqual(r.status_code, 200) + upgrade_operation = UpgradeOperation.objects.first() + self.assertEqual(upgrade_operation.upgrade_options, upgrade_options) + self.assertIn("upgrade_options", r.data) + self.assertEqual(r.data["upgrade_options"], upgrade_options) + + def test_device_firmware_detail_validates_upgrade_options(self): + """Test device firmware update validates upgrade_options""" + env = self._create_upgrade_env() + device_fw = env["device_fw1"] + new_image = env["image2a"] + url = reverse("upgrader:api_devicefirmware_detail", args=[device_fw.device.pk]) + # -n and -c are mutually exclusive + invalid_upgrade_options = {"c": True, "n": True} + r = self.client.put( + url, + {"image": new_image.pk, "upgrade_options": invalid_upgrade_options}, + content_type="application/json", + ) + self.assertEqual(r.status_code, 400) + self.assertIn("upgrade_options", r.data) + self.assertEqual(UpgradeOperation.objects.count(), 0) + def test_device_firmware_detail_multitenancy(self): ( d1, @@ -1349,7 +1424,7 @@ def test_device_firmware_detail_multitenancy(self): with self.subTest("Test device firmware detail org manager"): self._login("org1_manager", "tester") url = reverse("upgrader:api_devicefirmware_detail", args=[d1.pk]) - with self.assertNumQueries(7): + with self.assertNumQueries(8): r = self.client.get(url, {"format": "api"}) self.assertEqual(r.status_code, 200) serializer_detail = self._serialize_device_firmware(device_fw1) @@ -1382,7 +1457,7 @@ def test_device_firmware_detail_multitenancy(self): with self.subTest("Test device firmware detail org admin"): self._login("org_admin", "tester") url = reverse("upgrader:api_devicefirmware_detail", args=[d1.pk]) - with self.assertNumQueries(6): + with self.assertNumQueries(10): r = self.client.get(url, {"format": "api"}) self.assertEqual(r.status_code, 200) serializer_detail = self._serialize_device_firmware(device_fw1) @@ -1390,7 +1465,7 @@ def test_device_firmware_detail_multitenancy(self): self.assertContains(r, f"{image1}") self.assertNotContains(r, f"{image2}") url = reverse("upgrader:api_devicefirmware_detail", args=[d2.pk]) - with self.assertNumQueries(6): + with self.assertNumQueries(10): r = self.client.get(url, {"format": "api"}) self.assertEqual(r.status_code, 200) serializer_detail = self._serialize_device_firmware(device_fw2) @@ -1557,6 +1632,22 @@ def test_device_uo_list_multitenancy(self): serializer_list = self._serialize_device_upgrade_operation(device_uo2) self.assertEqual(r.data["results"], [serializer_list]) + def test_device_upgrade_operation_includes_upgrade_options(self): + """Test that device upgrade operation includes upgrade_options in response""" + device_fw = self._create_device_firmware(upgrade=True) + upgrade_operation = UpgradeOperation.objects.first() + upgrade_operation.upgrade_options = {"c": True, "F": True} + upgrade_operation.save() + url = reverse( + "upgrader:api_deviceupgradeoperation_list", args=[device_fw.device.pk] + ) + r = self.client.get(url) + self.assertEqual(r.status_code, 200) + self.assertGreater(len(r.data["results"]), 0) + result = r.data["results"][0] + self.assertIn("upgrade_options", result) + self.assertEqual(result["upgrade_options"], {"c": True, "F": True}) + class TestUpgradeOperationViews(TestAPIUpgraderMixin, TestCase): def _serialize_upgrade_operation(self, uo, many=False): @@ -1787,6 +1878,20 @@ def test_uo_list_detail_multitenancy(self): serializer_list = self._serialize_upgrade_operation(uo_qs, many=True) self.assertEqual(r.data["results"], serializer_list) + def test_upgrade_operation_includes_upgrade_options(self): + """Test that upgrade operation includes upgrade_options in response""" + self._create_device_firmware(upgrade=True) + upgrade_operation = UpgradeOperation.objects.first() + upgrade_operation.upgrade_options = {"c": True, "F": True} + upgrade_operation.save() + url = reverse( + "upgrader:api_upgradeoperation_detail", args=[upgrade_operation.pk] + ) + r = self.client.get(url) + self.assertEqual(r.status_code, 200) + self.assertIn("upgrade_options", r.data) + self.assertEqual(r.data["upgrade_options"], {"c": True, "F": True}) + class TestOrgAPIMixin(TestAPIUpgraderMixin, TestCase): def _serialize_build(self, build):