Skip to content

Commit d3110da

Browse files
Copilotjdkent
andauthored
Fix null coordinates validation and cloning for Point objects (#1138)
* Initial plan * Implement null coordinates validation and partial cloning fix Co-authored-by: jdkent <[email protected]> * Complete null coordinates fix - validation and cloning support Co-authored-by: jdkent <[email protected]> * Refactor PointSchema: simplify analysis_id logic and fix style issues - Move ValidationError import to top of file - Simplify analysis_id extraction logic to avoid duplication - Fix flake8 style issues (whitespace, line length) - Apply black formatting for consistent code style Co-authored-by: jdkent <[email protected]> * fix style issues --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: jdkent <[email protected]> Co-authored-by: James Kent <[email protected]>
1 parent fa59a47 commit d3110da

File tree

2 files changed

+143
-10
lines changed

2 files changed

+143
-10
lines changed

store/backend/neurostore/schemas/data.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
pre_load,
88
post_load,
99
EXCLUDE,
10+
ValidationError,
1011
)
1112

1213
from sqlalchemy import func
@@ -273,26 +274,52 @@ class PointSchema(BaseDataSchema):
273274
coordinates = fields.List(fields.Float(), dump_only=True)
274275

275276
# deserialization
276-
x = fields.Float(load_only=True)
277-
y = fields.Float(load_only=True)
278-
z = fields.Float(load_only=True)
277+
x = fields.Float(load_only=True, allow_none=True)
278+
y = fields.Float(load_only=True, allow_none=True)
279+
z = fields.Float(load_only=True, allow_none=True)
279280

280281
class Meta:
281282
additional = ("kind", "space", "image", "label_id")
282-
allow_none = ("kind", "space", "image", "label_id")
283+
allow_none = ("kind", "space", "image", "label_id", "x", "y", "z")
283284

284285
@pre_load
285286
def process_values(self, data, **kwargs):
286-
# PointValues need special handling
287-
if data.get("coordinates"):
288-
coords = [float(c) for c in data.pop("coordinates")]
289-
data["x"], data["y"], data["z"] = coords
287+
# Handle case where data might be a string ID instead of dict
288+
if not isinstance(data, dict):
289+
return data
290+
291+
# Only process coordinates if they exist in the data
292+
if "coordinates" in data and data["coordinates"] is not None:
293+
coords = data.pop("coordinates")
294+
295+
# Check if all coordinates are null
296+
if all(c is None for c in coords):
297+
# During cloning, allow null coordinates but store them as None
298+
if self.context.get("clone"):
299+
data["x"], data["y"], data["z"] = None, None, None
300+
else:
301+
# Don't save points with all null coordinates to database
302+
raise ValidationError("Points cannot have all null coordinates")
303+
else:
304+
# Convert coordinates to float, handling potential null values
305+
try:
306+
converted_coords = [
307+
float(c) if c is not None else None for c in coords
308+
]
309+
data["x"], data["y"], data["z"] = converted_coords
310+
except (TypeError, ValueError) as e:
311+
raise ValidationError(f"Invalid coordinate values: {e}")
290312

291313
if data.get("order") is None:
292-
if data.get("analysis_id") is not None:
314+
# Extract analysis_id first, then check if it exists
315+
analysis_id = data.get("analysis_id") or (
316+
data.get("analysis") if isinstance(data.get("analysis"), str) else None
317+
)
318+
319+
if analysis_id:
293320
max_order = (
294321
db.session.query(func.max(Point.order))
295-
.filter_by(analysis_id=data["analysis_id"])
322+
.filter_by(analysis_id=analysis_id)
296323
.scalar()
297324
)
298325
data["order"] = 1 if max_order is None else max_order + 1
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
#!/usr/bin/env python3
2+
3+
import pytest
4+
from marshmallow import ValidationError
5+
6+
7+
def test_point_schema_null_coordinates_validation():
8+
"""Test that PointSchema rejects null coordinates for new points"""
9+
from neurostore.schemas.data import PointSchema
10+
11+
point_data = {
12+
"analysis": "test_analysis_id",
13+
"coordinates": [None, None, None],
14+
"space": "MNI",
15+
"order": 0,
16+
}
17+
18+
schema = PointSchema()
19+
20+
with pytest.raises(ValidationError) as exc_info:
21+
schema.load(point_data)
22+
23+
assert "Points cannot have all null coordinates" in str(exc_info.value)
24+
25+
26+
def test_point_schema_null_coordinates_cloning():
27+
"""Test that PointSchema handles null coordinates during cloning"""
28+
from neurostore.schemas.data import PointSchema
29+
30+
point_data = {
31+
"analysis": "test_analysis_id",
32+
"coordinates": [None, None, None],
33+
"space": "MNI",
34+
"order": 0,
35+
}
36+
37+
schema = PointSchema(context={"clone": True})
38+
39+
# During cloning, null coordinates should be allowed and stored as None
40+
result = schema.load(point_data)
41+
42+
assert result["x"] is None
43+
assert result["y"] is None
44+
assert result["z"] is None
45+
46+
47+
def test_point_schema_valid_coordinates():
48+
"""Test that valid coordinates work normally"""
49+
from neurostore.schemas.data import PointSchema
50+
51+
point_data = {
52+
"analysis": "test_analysis_id",
53+
"coordinates": [1.0, 2.0, 3.0],
54+
"space": "MNI",
55+
"order": 0,
56+
}
57+
58+
schema = PointSchema()
59+
result = schema.load(point_data)
60+
61+
assert result["x"] == 1.0
62+
assert result["y"] == 2.0
63+
assert result["z"] == 3.0
64+
65+
66+
def test_analysis_schema_allows_null_coordinates_during_cloning():
67+
"""Test that AnalysisSchema allows null coordinate points during cloning"""
68+
from neurostore.schemas.data import AnalysisSchema
69+
70+
analysis_data = {
71+
"study": "test_study_id",
72+
"name": "test_analysis",
73+
"points": [
74+
{
75+
"analysis": "test_analysis_id",
76+
"coordinates": [1.0, 2.0, 3.0],
77+
"space": "MNI",
78+
"order": 0,
79+
},
80+
{
81+
"analysis": "test_analysis_id",
82+
"coordinates": [None, None, None],
83+
"space": "MNI",
84+
"order": 1,
85+
},
86+
],
87+
}
88+
89+
schema = AnalysisSchema(context={"clone": True, "nested": True})
90+
result = schema.load(analysis_data)
91+
92+
# Should have both points, including the null coordinate one
93+
points = result.get("points", [])
94+
assert len(points) == 2 # Both points should be present
95+
96+
# Verify the valid point has valid coordinates
97+
valid_point = next(p for p in points if p["x"] == 1.0)
98+
assert valid_point["x"] == 1.0
99+
assert valid_point["y"] == 2.0
100+
assert valid_point["z"] == 3.0
101+
102+
# Verify the null coordinate point has null coordinates
103+
null_point = next(p for p in points if p["x"] is None)
104+
assert null_point["x"] is None
105+
assert null_point["y"] is None
106+
assert null_point["z"] is None

0 commit comments

Comments
 (0)