Skip to content

Commit 454ea54

Browse files
authored
[ENH] add conditions and weights to specification (#597)
* add studyset-references endpoint * allow version to be none * update how specification is updated * handle boolean inputs * update openapi * update openapi
1 parent b0d35b2 commit 454ea54

File tree

10 files changed

+190
-14
lines changed

10 files changed

+190
-14
lines changed

compose/neurosynth_compose/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from .auth import User
1717

1818
__all__ = [
19+
"Condition",
20+
"SpecificationCondition",
1921
"Specification",
2022
"Studyset",
2123
"StudysetReference",

compose/neurosynth_compose/models/analysis.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""TODO: PLACE INTO THE NEUROSYNTH APP"""
22
from sqlalchemy.orm import relationship, backref
33
from sqlalchemy.sql import func
4+
from sqlalchemy.ext.associationproxy import association_proxy
45
import shortuuid
56
import secrets
67

@@ -31,16 +32,35 @@ class BaseMixin(object):
3132
# relationship("User", backref=cls.__tablename__, uselist=False)
3233

3334

35+
class Condition(BaseMixin, db.Model):
36+
__tablename__ = "conditions"
37+
name = db.Column(db.Text)
38+
description = db.Column(db.Text)
39+
40+
41+
class SpecificationCondition(BaseMixin, db.Model):
42+
__tablename__ = "specification_conditions"
43+
weight = db.Column(db.Float)
44+
specification_id = db.Column(
45+
db.Text, db.ForeignKey("specifications.id"), index=True, primary_key=True
46+
)
47+
condition_id = db.Column(
48+
db.Text, db.ForeignKey("conditions.id"), index=True, primary_key=True
49+
)
50+
condition = relationship("Condition", backref=backref("specification_conditions"))
51+
specification = relationship("Specification", backref=backref("specification_conditions"))
52+
53+
3454
class Specification(BaseMixin, db.Model):
3555
__tablename__ = "specifications"
3656

3757
type = db.Column(db.Text)
3858
estimator = db.Column(db.JSON)
3959
filter = db.Column(db.Text)
40-
contrast = db.Column(db.JSON)
60+
weights = association_proxy("specification_conditions", "weight")
61+
conditions = association_proxy("specification_conditions", "condition")
4162
corrector = db.Column(db.JSON)
4263
user_id = db.Column(db.Text, db.ForeignKey("users.external_id"))
43-
4464
user = relationship("User", backref=backref("specifications"))
4565

4666

compose/neurosynth_compose/openapi

compose/neurosynth_compose/resources/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from .analysis import (
2+
ConditionsResource,
3+
SpecificationConditionsResource,
24
MetaAnalysesView,
35
MetaAnalysisResultsView,
46
NeurovaultCollectionsView,
@@ -15,6 +17,8 @@
1517
from .users import UsersView
1618

1719
__all__ = [
20+
"ConditionsResource",
21+
"SpecificationConditionsResource",
1822
"MetaAnalysesView",
1923
"MetaAnalysisResultsView",
2024
"NeurovaultCollectionsView",

compose/neurosynth_compose/resources/analysis.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
2+
from collections import ChainMap
13
import pathlib
4+
from operator import itemgetter
25

36
import connexion
47
from flask import abort, request, jsonify, current_app
@@ -13,6 +16,8 @@
1316

1417
from ..database import db
1518
from ..models.analysis import ( # noqa E401
19+
Condition,
20+
SpecificationCondition,
1621
Studyset,
1722
Annotation,
1823
MetaAnalysis,
@@ -29,6 +34,8 @@
2934
from ..models.auth import User
3035

3136
from ..schemas import ( # noqa E401
37+
ConditionSchema,
38+
SpecificationConditionSchema,
3239
MetaAnalysisSchema,
3340
AnnotationSchema,
3441
StudysetSchema,
@@ -87,6 +94,7 @@ class ClassView(cls):
8794
class BaseView(MethodView):
8895
_model = None
8996
_nested = {}
97+
_attribute_name = None
9098

9199
@classmethod
92100
def _external_request(cls, data, record, id):
@@ -110,6 +118,8 @@ def update_or_create(cls, data, id=None, commit=True):
110118

111119
only_ids = set(data.keys()) - set(["id"]) == set()
112120

121+
if cls._model is Condition:
122+
record = cls._model.query.filter_by(name=data.get('name')).first() or cls._model()
113123
if id is None:
114124
record = cls._model()
115125
record.user = current_user
@@ -137,29 +147,44 @@ def update_or_create(cls, data, id=None, commit=True):
137147
# check if external request updated the data already
138148
committed = cls._external_request(data, record, id)
139149

150+
# get nested attributes
151+
nested_keys = [
152+
item for key in cls._nested.keys()
153+
for item in (key if isinstance(key, tuple) else (key,))
154+
]
155+
140156
# Update all non-nested attributes
141157
if not committed:
142158
for k, v in data.items():
143-
if k not in cls._nested and k not in ["id", "user"]:
159+
if k not in nested_keys and k not in ["id", "user"]:
144160
setattr(record, k, v)
145161

146162
to_commit.append(record)
147163

148164
# Update nested attributes recursively
149165
for field, res_name in cls._nested.items():
166+
field = (field,) if not isinstance(field, tuple) else field
167+
try:
168+
rec_data = itemgetter(*field)(data)
169+
except KeyError:
170+
rec_data = None
171+
150172
ResCls = globals()[res_name]
151-
if data.get(field) is not None:
152-
if isinstance(data.get(field), list):
173+
if rec_data is not None:
174+
if isinstance(rec_data, tuple):
175+
rec_data = [dict(ChainMap(*rc)) for rc in zip(*rec_data)]
176+
if isinstance(rec_data, list):
153177
nested = [
154178
ResCls.update_or_create(rec, commit=False)
155-
for rec in data.get(field)
179+
for rec in rec_data
156180
]
157181
to_commit.extend(nested)
158182
else:
159-
nested = ResCls.update_or_create(data.get(field), commit=False)
183+
nested = ResCls.update_or_create(rec_data, commit=False)
160184
to_commit.append(nested)
161-
162-
setattr(record, field, nested)
185+
update_field = field if len(field) == 1 else (cls._attribute_name,)
186+
for f in update_field:
187+
setattr(record, f, nested)
163188

164189
if commit:
165190
db.session.add_all(to_commit)
@@ -374,7 +399,10 @@ class StudysetsView(ObjectView, ListView):
374399

375400
@view_maker
376401
class SpecificationsView(ObjectView, ListView):
377-
pass
402+
_nested = {
403+
("conditions", "weights"): "SpecificationConditionsResource",
404+
}
405+
_attribute_name = "specification_conditions"
378406

379407

380408
@view_maker
@@ -387,6 +415,16 @@ class AnnotationReferencesResource(ObjectView):
387415
pass
388416

389417

418+
@view_maker
419+
class ConditionsResource(ObjectView):
420+
pass
421+
422+
423+
@view_maker
424+
class SpecificationConditionsResource(ObjectView):
425+
_nested = {"condition": "ConditionsResource"}
426+
427+
390428
@view_maker
391429
class MetaAnalysisResultsView(ObjectView, ListView):
392430
_nested = {

compose/neurosynth_compose/schemas/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from .analysis import (
2+
ConditionSchema,
3+
SpecificationConditionSchema,
24
MetaAnalysisSchema,
35
MetaAnalysisResultSchema,
46
NeurovaultCollectionSchema,
@@ -17,6 +19,8 @@
1719

1820

1921
__all__ = [
22+
"ConditionSchema",
23+
"SpecificationConditionSchema",
2024
"MetaAnalysisSchema",
2125
"MetaAnalysisResultSchema",
2226
"ResultInitSchema",

compose/neurosynth_compose/schemas/analysis.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from marshmallow import fields, Schema, utils, post_load, post_dump
1+
from marshmallow import fields, Schema, utils, post_load, post_dump, pre_load
22

33

44
# neurovault api base URL
@@ -122,6 +122,22 @@ class BaseSchema(Schema):
122122
username = fields.String(attribute="user.name", dump_only=True)
123123

124124

125+
class ConditionSchema(Schema):
126+
id = PGSQLString()
127+
created_at = fields.DateTime()
128+
updated_at = fields.DateTime(allow_none=True)
129+
name = PGSQLString()
130+
description = PGSQLString()
131+
132+
133+
class SpecificationConditionSchema(Schema):
134+
id = PGSQLString()
135+
created_at = fields.DateTime()
136+
updated_at = fields.DateTime(allow_none=True)
137+
condition = fields.Pluck(ConditionSchema, "name")
138+
weight = fields.Number()
139+
140+
125141
class EstimatorSchema(Schema):
126142
type = fields.String()
127143
args = fields.Dict()
@@ -152,11 +168,69 @@ class SpecificationSchema(BaseSchema):
152168
contrast = PGSQLString(allow_none=True)
153169
filter = PGSQLString(allow_none=True)
154170
corrector = fields.Dict(allow_none=True)
171+
_conditions = fields.Pluck(
172+
SpecificationConditionSchema,
173+
"condition",
174+
many=True,
175+
allow_none=True,
176+
load_only=True,
177+
attribute="conditions",
178+
data_key="conditions",
179+
)
180+
conditions = fields.Pluck(
181+
ConditionSchema,
182+
"name",
183+
many=True,
184+
allow_none=True,
185+
dump_only=True
186+
)
187+
weights = fields.List(
188+
fields.Float(),
189+
allow_none=True,
190+
dump_only=True,
191+
)
192+
_weights = fields.Pluck(
193+
SpecificationConditionSchema,
194+
"weight",
195+
many=True,
196+
allow_none=True,
197+
load_only=True,
198+
data_key="weights",
199+
attribute="weights",
200+
)
155201

156202
class Meta:
157203
additional = ("name", "description")
158204
allow_none = ("name", "description")
159205

206+
@post_dump
207+
def to_bool(self, data, **kwargs):
208+
conditions = data.get("conditions", None)
209+
if conditions:
210+
output_conditions = conditions[:]
211+
for i, cond in enumerate(conditions):
212+
if cond.lower() == "true":
213+
output_conditions[i] = True
214+
elif cond.lower() == "false":
215+
output_conditions[i] = False
216+
data['conditions'] = conditions
217+
218+
return data
219+
220+
@pre_load
221+
def to_string(self, data, **kwargs):
222+
conditions = data.get("conditions", None)
223+
if conditions:
224+
output_conditions = conditions[:]
225+
for i, cond in enumerate(conditions):
226+
if cond is True:
227+
output_conditions[i] = 'true'
228+
elif cond is False:
229+
output_conditions[i] = 'false'
230+
data['conditions'] = output_conditions
231+
232+
return data
233+
160234

161235
class StudysetSchema(BaseSchema):
162236
snapshot = fields.Dict()
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,37 @@
1+
import pytest
2+
3+
14
def test_get_specification(session, app, auth_client, user_data):
25
get = auth_client.get("/api/specifications")
36
assert get.status_code == 200
7+
8+
9+
@pytest.mark.parametrize(
10+
"specification_data",
11+
[
12+
{
13+
"estimator": {"type": "ALE"},
14+
"type": "cbma",
15+
"conditions": ["open", "closed"],
16+
"weights": [1, -1],
17+
"corrector": {"type": "FDRCorrector"},
18+
"filter": "eyes",
19+
},
20+
{
21+
"estimator": {"type": "ALE"},
22+
"type": "cbma",
23+
"conditions": [True],
24+
"weights": [1],
25+
"corrector": {"type": "FDRCorrector"},
26+
"filter": "eyes",
27+
},
28+
]
29+
)
30+
def test_create_and_get_spec(session, app, auth_client, user_data, specification_data):
31+
create_spec = auth_client.post("/api/specifications", data=specification_data)
32+
33+
assert create_spec.status_code == 200
34+
35+
view_spec = auth_client.get(f"/api/specifications/{create_spec.json['id']}")
36+
37+
assert create_spec.json == view_spec.json

store/neurostore/openapi

0 commit comments

Comments
 (0)