Skip to content

Commit 877ac19

Browse files
authored
Merge pull request #265 from Mosquito-Alert/fix_openapi_schema
Fix openapi schema + prepare for openapi generator
2 parents 488b33f + d3ccd94 commit 877ac19

File tree

10 files changed

+142
-76
lines changed

10 files changed

+142
-76
lines changed

api/auth/schema.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,20 @@
1-
from drf_spectacular.contrib.rest_framework_simplejwt import SimpleJWTScheme
1+
from rest_framework import serializers
2+
3+
from drf_spectacular.contrib.rest_framework_simplejwt import SimpleJWTScheme, TokenObtainPairSerializerExtension
4+
from drf_spectacular.utils import inline_serializer
25

36

47
class AppUserJWTAuthentication(SimpleJWTScheme):
58
target_class = "api.auth.authentication.AppUserJWTAuthentication"
9+
10+
class AppUserTokenObtainPairSerializer(TokenObtainPairSerializerExtension):
11+
target_class = "api.auth.serializers.AppUserTokenObtainPairSerializer"
12+
13+
def map_serializer(self, auto_schema, direction):
14+
Fixed = inline_serializer('Fixed', fields={
15+
'uuid': serializers.UUIDField(write_only=True),
16+
'password': serializers.CharField(write_only=True),
17+
'access': serializers.CharField(read_only=True),
18+
'refresh': serializers.CharField(read_only=True),
19+
})
20+
return auto_schema._map_serializer(Fixed, direction)

api/auth/serializers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
class AppUserTokenObtainSerializer(TokenObtainSerializer):
12-
uuid = serializers.UUIDField(required=True)
12+
uuid = serializers.UUIDField(write_only=True, required=True)
1313

1414
def __init__(self, *args, **kwargs):
1515
super().__init__(*args, **kwargs)

api/base_serializers.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,29 +29,36 @@ def __new__(cls, *args, **kwargs):
2929
raise ImproperlyConfigured(
3030
"`{cls}.model` must be a django Model".format(cls=cls.__name__)
3131
)
32-
return super().__new__(cls, *args, **kwargs)
3332

34-
def __init__(self, *args, **kwargs):
35-
super().__init__(*args, **kwargs)
36-
field_value_serializer_mapping = self.field_value_serializer_mapping
37-
self.field_value_serializer_mapping = {}
33+
instance = super().__new__(cls, *args, **kwargs)
34+
35+
field_value_serializer_mapping = cls.field_value_serializer_mapping
36+
instance.field_value_serializer_mapping = {}
3837
for field_value, serializer in field_value_serializer_mapping.items():
39-
if callable(serializer):
38+
if isinstance(serializer, serializers.Serializer):
39+
serializer = serializer
40+
else:
4041
serializer = serializer(*args, **kwargs)
41-
serializer.parent = self
42+
serializer.parent = instance
43+
44+
instance.field_value_serializer_mapping[field_value] = serializer
4245

43-
self.field_value_serializer_mapping[field_value] = serializer
46+
return instance
4447

4548
def to_representation(self, instance):
4649
serializer = self._get_serializer_for_instance(instance=instance)
47-
return serializer.to_representation(instance)
50+
ret = serializer.to_representation(instance)
51+
ret[self.resource_type_field_name] = getattr(instance, self.resource_type_field_name)
52+
return ret
4853

4954
def to_internal_value(self, data):
5055
if self.instance:
5156
serializer = self._get_serializer_for_instance(instance=self.instance)
5257
else:
5358
serializer = self._get_serializer_for_data(data=data)
54-
return serializer.to_internal_value(data=data)
59+
ret = serializer.to_internal_value(data=data)
60+
ret[self.resource_type_field_name] = data[self.resource_type_field_name]
61+
return ret
5562

5663
def create(self, validated_data):
5764
serializer = self._get_serializer_for_data(data=validated_data)
@@ -82,7 +89,9 @@ def run_validation(self, data=empty):
8289
serializer = self._get_serializer_for_instance(instance=self.instance)
8390
else:
8491
serializer = self._get_serializer_for_data(data=data)
85-
return serializer.run_validation(data)
92+
validated_data = serializer.run_validation(data)
93+
validated_data[self.resource_type_field_name] = data[self.resource_type_field_name]
94+
return validated_data
8695

8796
def _get_serializer_for_type(self, type_value):
8897
try:

api/schema.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1+
from drf_spectacular.contrib.rest_polymorphic import PolymorphicSerializerExtension
2+
from drf_spectacular.drainage import warn
13
from drf_spectacular.extensions import (
24
OpenApiSerializerFieldExtension,
3-
OpenApiSerializerExtension,
45
)
56
from drf_spectacular.plumbing import build_object_type, build_basic_type
67
from drf_spectacular.types import OpenApiTypes
7-
from drf_spectacular.plumbing import force_instance
88

99

1010
class PointFieldExtension(OpenApiSerializerFieldExtension):
@@ -20,31 +20,57 @@ def map_serializer_field(self, auto_schema, direction):
2020
required=["latitude", "longitude"] if self.target.required else None
2121
)
2222

23-
class FieldPolymorphicSerializerExtension(OpenApiSerializerExtension):
23+
class FieldPolymorphicSerializerExtension(PolymorphicSerializerExtension):
2424
target_class = "api.base_serializers.FieldPolymorphicSerializer"
2525

26-
def get_name(self):
27-
return self.target.model._meta.model_name
28-
26+
# NOTE: even if the implementation in PolymorphicSerializerExtension works pretty well,
27+
# it causes errors on openapi-generator.
28+
# That is because it's creating a Typed component (see self.build_typed_component())
29+
# And it's a bug in the generator: https://github.com/OpenAPITools/openapi-generator/issues/19261
30+
# See: https://redocly.com/docs/resources/discriminator
2931
def map_serializer(self, auto_schema, direction):
3032
sub_components = []
33+
serializer = self.target
34+
3135
for (
32-
resource_type_field_name,
36+
resource_type_field_name_value,
3337
sub_serializer,
3438
) in self.target.field_value_serializer_mapping.items():
35-
sub_serializer = force_instance(sub_serializer)
36-
resolved = auto_schema.resolve_serializer(sub_serializer, direction)
37-
sub_components.append((resource_type_field_name, resolved.ref))
39+
sub_serializer.partial = serializer.partial
40+
component = auto_schema.resolve_serializer(sub_serializer, direction)
41+
42+
# Define the discriminator field schema
43+
field_schema = build_basic_type(OpenApiTypes.STR)
44+
field_schema['enum'] = [resource_type_field_name_value,] # NOTE: in openapi 3.1 is 'const'
45+
# field_schema['default'] = resource_type_field_name_value
46+
47+
component.schema['properties'] = {
48+
serializer.resource_type_field_name: field_schema,
49+
**component.schema['properties']
50+
}
51+
component.schema['required'].append(serializer.resource_type_field_name)
52+
53+
sub_components.append((resource_type_field_name_value, component.ref))
54+
55+
if not resource_type_field_name_value:
56+
warn(
57+
f'discriminator mapping key is empty for {sub_serializer.__class__}. '
58+
f'this might lead to code generation issues.'
59+
)
60+
61+
one_of_list = []
62+
for _, ref in sub_components:
63+
if ref not in one_of_list:
64+
one_of_list.append(ref)
3865

3966
return {
40-
"oneOf": [schema for _, schema in sub_components],
67+
"oneOf": one_of_list,
4168
"discriminator": {
42-
"propertyName": self.target.resource_type_field_name,
69+
"propertyName": serializer.resource_type_field_name,
4370
"mapping": {
44-
resource_type: schema["$ref"]
45-
for resource_type, schema in sub_components
46-
},
47-
},
71+
resource_type_field_name_value: ref["$ref"]
72+
for resource_type_field_name_value, ref in sub_components},
73+
}
4874
}
4975

5076

api/serializers.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@
66

77
from rest_framework import serializers
88

9-
from drf_spectacular.utils import extend_schema_field
10-
from drf_spectacular.types import OpenApiTypes
11-
129
from drf_extra_fields.geo_fields import PointField
1310

1411
from tigaserver_app.models import (
@@ -103,8 +100,7 @@ class DetailNotificationSerializer(serializers.ModelSerializer):
103100
title = serializers.SerializerMethodField(read_only=True)
104101
body = serializers.SerializerMethodField(read_only=True)
105102

106-
@extend_schema_field(OpenApiTypes.BOOL)
107-
def get_seen(self, obj):
103+
def get_seen(self, obj) -> bool:
108104
user = self.context.get("request").user
109105
if not isinstance(user, TigaUser):
110106
return False
@@ -116,17 +112,15 @@ def get_seen(self, obj):
116112
user=self.context.get("request").user
117113
).exists() or (obj.user == user and obj.acknowledged)
118114

119-
@extend_schema_field(OpenApiTypes.STR)
120-
def get_title(self, obj):
115+
def get_title(self, obj) -> str:
121116
if obj.notification_content is None:
122117
return ""
123118

124119
return obj.notification_content.get_title(
125120
language_code=self.context.get("request").LANGUAGE_CODE
126121
)
127122

128-
@extend_schema_field(OpenApiTypes.STR)
129-
def get_body(self, obj):
123+
def get_body(self, obj) -> str:
130124
if obj.notification_content is None:
131125
return ""
132126

@@ -173,21 +167,29 @@ class BaseNotificationCreateSerializer(serializers.ModelSerializer):
173167
def RECEIVER_TYPE(self):
174168
raise NotImplementedError
175169

176-
receiver_type = WritableSerializerMethodField(
177-
field_class=serializers.ChoiceField,
178-
choices=["user", "topic"],
179-
required=True
180-
)
181-
182170
title_en = serializers.CharField(write_only=True)
183171
body_en = serializers.CharField(write_only=True)
184172

185173
created_at = serializers.DateTimeField(source="date_comment", read_only=True)
186174
expert = serializers.HiddenField(default=serializers.CurrentUserDefault())
187175

188-
@extend_schema_field(OpenApiTypes.STR)
189-
def get_receiver_type(self, obj):
190-
return self.RECEIVER_TYPE
176+
def __init__(self, *args, **kwargs):
177+
# Call the parent constructor first
178+
super().__init__(*args, **kwargs)
179+
180+
# Add a dynamic field
181+
self.fields['receiver_type'] = serializers.ChoiceField(
182+
choices=[self.RECEIVER_TYPE,],
183+
write_only=True,
184+
required=True
185+
)
186+
187+
# Re-order the fields with 'receiver_type' at the start
188+
from collections import OrderedDict
189+
self.fields = OrderedDict(
190+
[('receiver_type', self.fields['receiver_type'])] +
191+
[(key, field) for key, field in self.fields.items() if key != 'receiver_type']
192+
)
191193

192194
@transaction.atomic
193195
def create(self, validated_data):
@@ -203,7 +205,12 @@ def create(self, validated_data):
203205

204206
class Meta:
205207
model = Notification
206-
fields = ("id", "receiver_type", "created_at", "title_en", "body_en")
208+
fields = (
209+
"id",
210+
"created_at",
211+
"title_en",
212+
"body_en"
213+
)
207214

208215

209216
class UserNotificationCreateSerializer(BaseNotificationCreateSerializer):
@@ -235,11 +242,10 @@ def validate(self, data):
235242
return data
236243

237244
class Meta(BaseNotificationCreateSerializer.Meta):
238-
fields = (
239-
"receiver_type",
245+
fields = BaseNotificationCreateSerializer.Meta.fields + (
240246
"report_uuid",
241247
"user_uuid",
242-
) + BaseNotificationCreateSerializer.Meta.fields
248+
)
243249

244250

245251
class TopicNotificationCreateSerializer(BaseNotificationCreateSerializer):
@@ -268,10 +274,7 @@ def create(self, validated_data):
268274
return instance
269275

270276
class Meta(BaseNotificationCreateSerializer.Meta):
271-
fields = (
272-
"receiver_type",
273-
"topic_code",
274-
) + BaseNotificationCreateSerializer.Meta.fields
277+
fields = BaseNotificationCreateSerializer.Meta.fields + ("topic_code", )
275278

276279
#### END NOTIFICATION SERIALIZERS ####
277280

@@ -390,7 +393,6 @@ class Meta:
390393
"short_id",
391394
"user_uuid",
392395
"user",
393-
"type",
394396
"session_id",
395397
"created_at",
396398
"sent_at",

api/tests/integration/notifications/create.tavern.yml

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,7 @@ stages:
7575
method: "POST"
7676
response:
7777
status_code: 201
78-
json: &response_validation
79-
id: !anyint
80-
receiver_type: "user"
81-
created_at: !re_fullmatch \d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.\d{6}Z
78+
json: !force_format_include "{response_data_validation}"
8279
- name: Create topic notification using auth user with permissions
8380
request:
8481
url: "{api_live_url}/{endpoint}/"
@@ -92,6 +89,4 @@ stages:
9289
method: "POST"
9390
response:
9491
status_code: 201
95-
json:
96-
<<: *response_validation
97-
receiver_type: "topic"
92+
json: !force_format_include "{response_data_validation}"

api/tests/integration/notifications/schema.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,11 @@ description: Login information for test server
66

77
variables:
88
endpoint: "notifications"
9+
response_data_validation:
10+
id: !anyint
11+
report_uuid: !anything
12+
expert_id: !anything
13+
created_at: !re_fullmatch \d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.\d{6}Z
14+
title: !anystr
15+
body: !anystr
16+
seen: !anybool

api/urls.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from drf_spectacular.settings import spectacular_settings
77
from drf_spectacular.utils import extend_schema
8-
from drf_spectacular.views import SpectacularRedocView, SpectacularAPIView
8+
from drf_spectacular.views import SpectacularRedocView, SpectacularAPIView, SpectacularJSONAPIView
99

1010
from .views import (
1111
UserViewSet,
@@ -51,6 +51,7 @@ def get(self, request, *args, **kwargs):
5151
api_urlpatterns += router.urls
5252

5353
urlpatterns = [
54-
path("schema/", SpectacularAPIView.as_view(), name="schema"),
54+
path("openapi.yml", SpectacularAPIView.as_view(), name="schema"),
55+
path("openapi.json", SpectacularJSONAPIView.as_view(), name="schema-json"),
5556
re_path("$^", CustomRedocView.as_view(url_name="schema"), name="redoc"),
5657
] + api_urlpatterns

0 commit comments

Comments
 (0)