Skip to content

Commit 1a24d89

Browse files
committed
Accept nested JSON string fields in Multipart
1 parent f9bf70b commit 1a24d89

File tree

7 files changed

+142
-10
lines changed

7 files changed

+142
-10
lines changed

api/parsers.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import json
2+
from django.http import QueryDict
3+
from rest_framework import parsers
4+
5+
# NOTE: This class is needed to work with auto-generated OpenAPI SDKs.
6+
# It's important to mention that MultiParser from DRF needs from nested
7+
# dotted notation, e.g: location.point.latitude, location.point.longitude
8+
# But most OpenAPI SDKs (like openapi-generator) do not support that.
9+
# They only support nested JSON objects (encoded to string!), e.g:
10+
# location: '{"point": {"latitude": .., "longitude": ..} }'
11+
# This class converts those JSON strings into dotted notation keys.
12+
# If ever need to use bracket notation see: https://github.com/remigermain/nested-multipart-parser/
13+
class MultiPartJsonNestedParser(parsers.MultiPartParser):
14+
"""
15+
A custom multipart parser that extends MultiPartParser.
16+
17+
It parses nested JSON strings found in the value of form data fields
18+
and converts them into dotted notation keys in the QueryDict.
19+
"""
20+
def parse(self, stream, media_type=None, parser_context=None):
21+
"""
22+
Parses the multi-part request data and converts nested JSON to dotted notation.
23+
24+
Returns a tuple of (QueryDict, MultiValueDict).
25+
"""
26+
# Call the base parser to get the initial QueryDict (data) and MultiValueDict (files)
27+
result = super().parse(stream, media_type, parser_context)
28+
data = result.data
29+
files = result.files
30+
31+
# Create a mutable copy of the data QueryDict for modification
32+
mutable_data = data.copy()
33+
new_data = {}
34+
35+
# Iterate over all keys in the QueryDict
36+
for key, value_list in mutable_data.lists():
37+
# A value_list from QueryDict is always a list of strings
38+
39+
# 1. Attempt to parse the first value as JSON if it seems like a dictionary
40+
# We assume non-list values (like 'created_at') are single-element lists.
41+
# If the list has multiple elements, we treat the field as a list of non-JSON strings
42+
# and leave it alone (e.g., 'tags': ['tag1', 'tag2']).
43+
if len(value_list) == 1 and isinstance(value_list[0], str) and value_list[0].strip().startswith('{'):
44+
try:
45+
json_data = json.loads(value_list[0])
46+
# 2. Flatten the JSON dictionary into dotted notation
47+
flattened = self._flatten_dict(json_data, parent_key=key)
48+
# 3. Add the flattened data to our new_data dictionary
49+
new_data.update(flattened)
50+
51+
# Remove the original key as it's been expanded
52+
# This is implicitly done by building new_data, but for clarity:
53+
# mutable_data.pop(key)
54+
55+
except json.JSONDecodeError:
56+
# Not valid JSON, treat it as a regular string field
57+
new_data[key] = value_list
58+
59+
else:
60+
# Field is not a single JSON string, e.g., 'note': [''] or 'tags': ['tag1', 'tag2']
61+
# Keep the original data intact
62+
new_data[key] = value_list
63+
64+
# Convert the resulting dictionary back into a QueryDict
65+
# We need to construct it carefully as QueryDict expects lists of values
66+
final_data = QueryDict('', mutable=True)
67+
for k, v in new_data.items():
68+
# v will be either a list (from original data) or a single value (from flattened json)
69+
if isinstance(v, list):
70+
final_data.setlist(k, v)
71+
else:
72+
final_data[k] = v
73+
74+
return parsers.DataAndFiles(final_data, files)
75+
76+
def _flatten_dict(self, d, parent_key='', sep='.'):
77+
"""
78+
Recursively flattens a nested dictionary into a single-level dictionary
79+
with dotted keys.
80+
"""
81+
items = []
82+
for k, v in d.items():
83+
new_key = parent_key + sep + k if parent_key else k
84+
if isinstance(v, dict):
85+
# Recurse into nested dictionaries
86+
items.extend(self._flatten_dict(v, new_key, sep=sep).items())
87+
elif isinstance(v, list):
88+
# Handle lists by keeping the key and setting the value as the list
89+
# This is a simplification; a more complex parser might flatten lists too.
90+
items.append((new_key, v))
91+
else:
92+
# Add simple key-value pair
93+
items.append((new_key, v))
94+
95+
# When converting back to QueryDict, simple values (not lists) should be
96+
# left as single values for the QueryDict to handle correctly.
97+
final_flat_dict = {}
98+
for k, v in items:
99+
# Important: QueryDict expects lists for multi-value fields.
100+
# If the value is a list (from the JSON), keep it as a list.
101+
if isinstance(v, list):
102+
final_flat_dict[k] = v
103+
else:
104+
# For single values (str, int, float, bool, None), QueryDict will
105+
# automatically wrap it in a list upon assignment.
106+
# However, for consistency with how QueryDict works in general, we
107+
# store the single value.
108+
final_flat_dict[k] = str(v) # Convert to string for form data
109+
110+
return final_flat_dict

api/serializers.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from uuid import UUID
55

66
from django.contrib.auth import get_user_model
7+
from django.contrib.gis.geos import Point
78
from django.db import transaction
89

910
from drf_spectacular.utils import extend_schema_field
@@ -551,7 +552,23 @@ class AdmBoundarySerializer(serializers.Serializer):
551552
source = serializers.CharField(required=True, allow_null=False)
552553
level = serializers.IntegerField(required=True, min_value=0)
553554

554-
point = PointField(required=True)
555+
class PointSerializer(serializers.Serializer):
556+
latitude = WritableSerializerMethodField(
557+
field_class=serializers.FloatField,
558+
required=True,
559+
)
560+
longitude = WritableSerializerMethodField(
561+
field_class=serializers.FloatField,
562+
required=True,
563+
)
564+
565+
def get_latitude(self, obj: Point) -> float:
566+
return obj.y
567+
568+
def get_longitude(self, obj: Point) -> float:
569+
return obj.x
570+
571+
point = PointSerializer(required=True)
555572
timezone = TimeZoneSerializerChoiceField(read_only=True, allow_null=True)
556573
country = CountrySerializer(read_only=True, allow_null=True)
557574
adm_boundaries = AdmBoundarySerializer(many=True, read_only=True)
@@ -579,8 +596,8 @@ def to_internal_value(self, data):
579596
preffix = "selected"
580597

581598
point = ret.pop("point")
582-
ret[f"{preffix}_location_lat"] = point.y
583-
ret[f"{preffix}_location_lon"] = point.x
599+
ret[f"{preffix}_location_lat"] = point['latitude']
600+
ret[f"{preffix}_location_lon"] = point['longitude']
584601

585602
return ret
586603

api/tests/integration/bites/create.tavern.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ stages:
4141
data: &request_bite_data
4242
created_at: '2024-01-01T00:00:00Z'
4343
sent_at: '2024-01-01T00:30:00Z'
44-
location.point: !raw '{"latitude": 41.67419, "longitude": 2.79036}'
44+
location.point.latitude: 41.67419
45+
location.point.longitude: 2.79036
4546
location.source: 'auto'
4647
note: "Test"
4748
tags:

api/tests/integration/breeding_sites/create.tavern.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ stages:
4747
data: &request_site_data
4848
created_at: '2024-01-01T00:00:00Z'
4949
sent_at: '2024-01-01T00:30:00Z'
50-
location.point: !raw '{"latitude": 41.67419, "longitude": 2.79036}'
50+
location.point.latitude: 41.67419
51+
location.point.longitude: 2.79036
5152
location.source: 'auto'
5253
note: "Test"
5354
tags:

api/tests/integration/observations/create.tavern.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ stages:
4747
data: &request_site_data
4848
created_at: '2024-01-01T00:00:00Z'
4949
sent_at: '2024-01-01T00:30:00Z'
50-
location.point: !raw '{"latitude": 41.67419, "longitude": 2.79036}'
50+
location.point.latitude: 41.67419
51+
location.point.longitude: 2.79036
5152
location.source: 'auto'
5253
note: "Test"
5354
tags:

api/views.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
UpdateModelMixin,
2626
DestroyModelMixin,
2727
)
28-
from rest_framework.parsers import MultiPartParser, FormParser
28+
from rest_framework.parsers import FormParser
2929
from rest_framework.permissions import AllowAny, IsAuthenticated, SAFE_METHODS
3030
from rest_framework.response import Response
3131
from rest_framework.settings import api_settings
@@ -57,6 +57,7 @@
5757
TaxonFilter
5858
)
5959
from .mixins import IdentificationTaskNestedAttribute
60+
from .parsers import MultiPartJsonNestedParser
6061
from .serializers import (
6162
PartnerSerializer,
6263
CampaignSerializer,
@@ -333,7 +334,7 @@ def get_parsers(self):
333334
# Since photos are required on POST, only allow
334335
# parasers that allow files.
335336
if self.request and self.request.method == 'POST':
336-
return [MultiPartParser(), FormParser()]
337+
return [MultiPartJsonNestedParser(), FormParser()]
337338
return super().get_parsers()
338339

339340
class BiteViewSet(BaseReportViewSet):

api/viewsets.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from rest_framework.authentication import TokenAuthentication
22
from rest_framework.viewsets import GenericViewSet as DRFGenericViewSet
33
from rest_framework.pagination import PageNumberPagination
4-
from rest_framework.parsers import JSONParser, FormParser, MultiPartParser
4+
from rest_framework.parsers import JSONParser, FormParser
55
from rest_framework.renderers import JSONRenderer
66
from rest_framework_nested.viewsets import NestedViewSetMixin as OriginalNestedViewSetMixin, _force_mutable
77

88
from .auth.authentication import AppUserJWTAuthentication, NonAppUserSessionAuthentication
9+
from .parsers import MultiPartJsonNestedParser
910
from .permissions import UserObjectPermissions, IsMobileUser, DjangoRegularUserModelPermissions
1011

1112

@@ -31,7 +32,7 @@ def pagination_class(self):
3132
return self._pagination_class
3233

3334
permission_classes = (UserObjectPermissions,)
34-
parser_classes = (JSONParser, FormParser, MultiPartParser)
35+
parser_classes = (JSONParser, FormParser, MultiPartJsonNestedParser)
3536
renderer_classes = (JSONRenderer,)
3637

3738

0 commit comments

Comments
 (0)