Skip to content

Commit a45c051

Browse files
authored
[ENH] add ability to update study upon cloning (#682)
* add ability to update study upon cloning * update function signatures
1 parent 58520ac commit a45c051

File tree

5 files changed

+72
-29
lines changed

5 files changed

+72
-29
lines changed

store/neurostore/resources/base.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -529,13 +529,15 @@ def post(self):
529529
args = parser.parse(self._user_args, request, location="query")
530530
source_id = args.get("source_id")
531531
source = args.get("source") or "neurostore"
532+
533+
unknown = self.__class__._schema.opts.unknown
534+
data = parser.parse(
535+
self.__class__._schema(exclude=("id",)), request, unknown=unknown
536+
)
537+
532538
if source_id:
533-
data = self._load_from_source(source, source_id)
534-
else:
535-
unknown = self.__class__._schema.opts.unknown
536-
data = parser.parse(
537-
self.__class__._schema(exclude=("id",)), request, unknown=unknown
538-
)
539+
data = self._load_from_source(source, source_id, data)
540+
539541
args["nested"] = bool(args.get("nested") or request.args.get("source_id"))
540542

541543
with db.session.no_autoflush:

store/neurostore/resources/data.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
"nested": fields.Boolean(load_default=False),
5454
}
5555

56-
5756
# Individual resource views
5857

5958

@@ -191,12 +190,12 @@ def insert_data(self, id, data):
191190
return data
192191

193192
@classmethod
194-
def _load_from_source(cls, source, source_id):
193+
def _load_from_source(cls, source, source_id, data=None):
195194
if source == "neurostore":
196-
return cls.load_from_neurostore(source_id)
195+
return cls.load_from_neurostore(source_id, data)
197196

198197
@classmethod
199-
def load_from_neurostore(cls, source_id):
198+
def load_from_neurostore(cls, source_id, data=None):
200199
annotation = cls._model.query.filter_by(id=source_id).first_or_404()
201200
parent_source_id = annotation.source_id
202201
parent_source = annotation.source
@@ -418,16 +417,16 @@ def serialize_records(self, records, args, exclude=tuple()):
418417
return super().serialize_records(records, args, exclude)
419418

420419
@classmethod
421-
def _load_from_source(cls, source, source_id):
420+
def _load_from_source(cls, source, source_id, data=None):
422421
if source == "neurostore":
423-
return cls.load_from_neurostore(source_id)
422+
return cls.load_from_neurostore(source_id, data)
424423
elif source == "neurovault":
425-
return cls.load_from_neurovault(source_id)
424+
return cls.load_from_neurovault(source_id, data)
426425
elif source == "pubmed":
427-
return cls.load_from_pubmed(source_id)
426+
return cls.load_from_pubmed(source_id, data)
428427

429428
@classmethod
430-
def load_from_neurostore(cls, source_id):
429+
def load_from_neurostore(cls, source_id, data=None):
431430
study = cls._model.query.filter_by(id=source_id).first_or_404()
432431
parent_source_id = study.source_id
433432
parent_source = study.source
@@ -437,22 +436,26 @@ def load_from_neurostore(cls, source_id):
437436
parent_source = parent.source
438437
parent_source_id = parent.source_id
439438

440-
context = {"clone": True, "nested": True}
441-
schema = cls._schema(context=context)
442-
dump_study = schema.dump(study)
443-
data = schema.load(dump_study)
444-
data["source"] = "neurostore"
445-
data["source_id"] = source_id
446-
data["source_updated_at"] = study.updated_at or study.created_at
447-
data["base_study"] = {"id": study.base_study_id}
448-
return data
439+
update_schema = cls._schema(context={"nested": True})
440+
clone_data = update_schema.load(update_schema.dump(study))
441+
# update data with new source
442+
clone_data.update(data)
443+
444+
context = {"nested": True, "clone": True}
445+
return_schema = cls._schema(context=context)
446+
clone_data = return_schema.load(return_schema.dump(clone_data))
447+
clone_data["source"] = "neurostore"
448+
clone_data["source_id"] = source_id
449+
clone_data["source_updated_at"] = study.updated_at or study.created_at
450+
clone_data["base_study"] = {"id": study.base_study_id}
451+
return clone_data
449452

450453
@classmethod
451-
def load_from_neurovault(cls, source_id):
454+
def load_from_neurovault(cls, source_id, data=None):
452455
pass
453456

454457
@classmethod
455-
def load_from_pubmed(cls, source_id):
458+
def load_from_pubmed(cls, source_id, data=None):
456459
pass
457460

458461
def pre_nested_record_update(record):

store/neurostore/schemas/data.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,15 +209,16 @@ class PointSchema(BaseDataSchema):
209209
cluster_size = fields.Float(allow_none=True)
210210
subpeak = fields.Boolean(allow_none=True)
211211
order = fields.Integer()
212+
coordinates = fields.List(fields.Float(), dump_only=True)
212213

213214
# deserialization
214215
x = fields.Float(load_only=True)
215216
y = fields.Float(load_only=True)
216217
z = fields.Float(load_only=True)
217218

218219
class Meta:
219-
additional = ("kind", "space", "coordinates", "image", "label_id")
220-
allow_none = ("kind", "space", "coordinates", "image", "label_id")
220+
additional = ("kind", "space", "image", "label_id")
221+
allow_none = ("kind", "space", "image", "label_id")
221222

222223
@pre_load
223224
def process_values(self, data, **kwargs):
@@ -227,6 +228,14 @@ def process_values(self, data, **kwargs):
227228
data["x"], data["y"], data["z"] = coords
228229
return data
229230

231+
@pre_dump
232+
def pre_dump_process(self, data, **kwargs):
233+
if getattr(data, "coordinates", None) or data.get("coordinates"):
234+
return data
235+
if isinstance(data, dict):
236+
data["coordinates"] = [data["x"], data["y"], data["z"]]
237+
return data
238+
230239

231240
class AnalysisConditionSchema(BaseDataSchema):
232241
weight = fields.Float()

store/neurostore/tests/api/test_base_studies.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def test_info_base_study(auth_client, ingest_neurosynth, session):
5858
assert single_reg_resp.status_code == 200
5959
info_fields = ["has_coordinates", "has_images", "studysets"]
6060
for field in info_fields:
61-
assert field in single_info_resp.json()['versions'][0]
61+
assert field in single_info_resp.json()["versions"][0]
6262
assert "id" in single_info_resp.json()["versions"][0]
6363
assert isinstance(single_reg_resp.json()["versions"][0], str)
6464

store/neurostore/tests/api/test_studies.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22

33
from ...models import Studyset, Study, User, Analysis
4+
from ...schemas import StudySchema
45

56

67
def test_create_study_as_user_and_analysis_as_bot(auth_clients, session):
@@ -111,6 +112,34 @@ def test_clone_studies(auth_client, ingest_neurosynth, session):
111112
)
112113

113114

115+
def test_clone_studies_with_data(auth_client, ingest_neurosynth, session):
116+
study_entry = Study.query.first()
117+
schema = StudySchema(context={"nested": True})
118+
study_data = schema.dump(study_entry)
119+
half_points = len(study_data["analyses"][0]["points"]) // 2
120+
121+
first_analysis_points = study_data["analyses"][0]["points"][0:half_points]
122+
second_analysis_points = study_data["analyses"][0]["points"][half_points:]
123+
first_analysis_points[0]["coordinates"] = [0, 0, 0]
124+
second_analysis_points[0]["coordinates"] = [0, 0, 0]
125+
study_data["analyses"][0]["points"] = first_analysis_points
126+
study_data["analyses"].append(
127+
{"name": "new analysis", "points": second_analysis_points}
128+
)
129+
130+
resp = auth_client.post(
131+
f"/api/studies/?source_id={study_entry.id}",
132+
data=study_data,
133+
)
134+
data = resp.json()
135+
assert data["name"] == study_entry.name
136+
assert data["source_id"] == study_entry.id
137+
assert data["source"] == "neurostore"
138+
assert data["analyses"][0]["points"][0]["coordinates"] == [0, 0, 0]
139+
assert data["analyses"][1]["points"][0]["coordinates"] == [0, 0, 0]
140+
assert "new analysis" in [a["name"] for a in data["analyses"]]
141+
142+
114143
def test_private_studies(user_data, auth_clients, session):
115144
from ...resources.users import User
116145

0 commit comments

Comments
 (0)