diff --git a/lego/apps/events/tests/test_events_api.py b/lego/apps/events/tests/test_events_api.py index 5e2d3c675..3d2655fa9 100644 --- a/lego/apps/events/tests/test_events_api.py +++ b/lego/apps/events/tests/test_events_api.py @@ -1,3 +1,4 @@ +import typing from copy import deepcopy from datetime import timedelta from unittest import mock, skipIf @@ -226,6 +227,14 @@ }, ] +_test_payload_error_messages = { + "required_field": "This field is required.", + "not_blank_field": "This field may not be blank.", + "not_null_field": "This field may not be null.", + "valid_choice": "is not valid choice.", + "wrong_format": "Datetime has wrong format.", +} + _test_registration_data = {"user": 1} webhook_secret = getattr(settings, "STRIPE_WEBHOOK_SECRET", None) @@ -509,6 +518,15 @@ class CreateEventsTestCase(BaseAPITestCase): "test_users.yaml", "test_events.yaml", ] + PAYLOAD_FIELDS: typing.ClassVar[list[str]] = [ + "title", + "description", + "text", + "event_type", + "location", + "start_time", + "end_time", + ] def setUp(self): self.abakus_user = User.objects.all().first() @@ -518,6 +536,65 @@ def setUp(self): self.assertEqual(self.event_response.status_code, status.HTTP_201_CREATED) self.event_id = self.event_response.json().pop("id", None) + def test_fields_required_on_create(self) -> None: + response = self.client.post(_get_list_url(), {}) + for field in self.PAYLOAD_FIELDS: + self.assertEqual( + response.data.get(field)[0], + _test_payload_error_messages["required_field"], + ) + + def test_fields_required_on_update(self) -> None: + response = self.client.put(_get_detail_url(self.event_id), {}) + for field in self.PAYLOAD_FIELDS: + self.assertEqual( + response.data.get(field)[0], + _test_payload_error_messages["required_field"], + ) + + def test_empty_fields_in_payload(self) -> None: + test_data = [ + {"method": "post", "is_detail": False}, + {"method": "put", "is_detail": True}, + ] + + _nul_fields = ["title", "text", "location"] + _date_fields = ["start_time", "end_time"] + + for data in test_data: + with self.subTest(): + url = ( + _get_detail_url(self.event_id) + if data["is_detail"] + else _get_list_url() + ) + response = getattr(self.client, data["method"])( + url, + { + "title": "", + "desciption": "", + "text": "", + "event_type": "", + "location": "", + "start_time": "", + "end_time": "", + }, + ) + for field in _nul_fields: + self.assertEqual( + response.data.get(field)[0], + _test_payload_error_messages["not_blank_field"], + ) + for field in _date_fields: + self.assertIn( + _test_payload_error_messages["wrong_format"], + response.data.get(field)[0], + ) + self.assertEqual( + response.data.get("description")[0], + _test_payload_error_messages["required_field"], + ) + def test_event_creation(self): """Test event creation with pools""" self.assertIsNotNone(self.event_id)