Skip to content

Commit ece81ce

Browse files
authored
fix spec updates (#617)
* fix spec updates * fix flaky test
1 parent 7795865 commit ece81ce

File tree

5 files changed

+84
-25
lines changed

5 files changed

+84
-25
lines changed

compose/neurosynth_compose/models/analysis.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,11 @@ class SpecificationCondition(BaseMixin, db.Model):
4848
db.Text, db.ForeignKey("conditions.id"), index=True, primary_key=True
4949
)
5050
condition = relationship("Condition", backref=backref("specification_conditions"))
51-
specification = relationship("Specification", backref=backref("specification_conditions"))
51+
specification = relationship(
52+
"Specification", backref=backref("specification_conditions")
53+
)
54+
user_id = db.Column(db.Text, db.ForeignKey("users.external_id"))
55+
user = relationship("User", backref=backref("specification_conditions"))
5256

5357

5458
class Specification(BaseMixin, db.Model):
@@ -59,6 +63,7 @@ class Specification(BaseMixin, db.Model):
5963
filter = db.Column(db.Text)
6064
weights = association_proxy("specification_conditions", "weight")
6165
conditions = association_proxy("specification_conditions", "condition")
66+
database_studyset = db.Column(db.Text)
6267
corrector = db.Column(db.JSON)
6368
user_id = db.Column(db.Text, db.ForeignKey("users.external_id"))
6469
user = relationship("User", backref=backref("specifications"))

compose/neurosynth_compose/resources/analysis.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
from collections import ChainMap
32
import pathlib
43
from operator import itemgetter
@@ -119,7 +118,10 @@ def update_or_create(cls, data, id=None, commit=True):
119118
only_ids = set(data.keys()) - set(["id"]) == set()
120119

121120
if cls._model is Condition:
122-
record = cls._model.query.filter_by(name=data.get('name')).first() or cls._model()
121+
record = (
122+
cls._model.query.filter_by(name=data.get("name")).first()
123+
or cls._model()
124+
)
123125
if id is None:
124126
record = cls._model()
125127
record.user = current_user
@@ -149,7 +151,8 @@ def update_or_create(cls, data, id=None, commit=True):
149151

150152
# get nested attributes
151153
nested_keys = [
152-
item for key in cls._nested.keys()
154+
item
155+
for key in cls._nested.keys()
153156
for item in (key if isinstance(key, tuple) else (key,))
154157
]
155158

@@ -164,25 +167,42 @@ def update_or_create(cls, data, id=None, commit=True):
164167
# Update nested attributes recursively
165168
for field, res_name in cls._nested.items():
166169
field = (field,) if not isinstance(field, tuple) else field
170+
if set(data.keys()).issubset(field):
171+
field = (list(data.keys())[0],)
172+
167173
try:
168174
rec_data = itemgetter(*field)(data)
169175
except KeyError:
170176
rec_data = None
171177

172178
ResCls = globals()[res_name]
179+
173180
if rec_data is not None:
174181
if isinstance(rec_data, tuple):
175182
rec_data = [dict(ChainMap(*rc)) for rc in zip(*rec_data)]
183+
# get ids of existing nested attributes
184+
existing_nested = None
185+
if cls._attribute_name:
186+
existing_nested = getattr(record, cls._attribute_name, None)
187+
188+
if existing_nested and len(existing_nested) == len(rec_data):
189+
_ = [
190+
rd.update({"id": ns.id})
191+
for rd, ns in zip(
192+
rec_data, getattr(record, cls._attribute_name)
193+
)
194+
]
176195
if isinstance(rec_data, list):
177196
nested = [
178-
ResCls.update_or_create(rec, commit=False)
179-
for rec in rec_data
197+
ResCls.update_or_create(rec, commit=False) for rec in rec_data
180198
]
181199
to_commit.extend(nested)
182200
else:
183201
nested = ResCls.update_or_create(rec_data, commit=False)
184202
to_commit.append(nested)
185-
update_field = field if len(field) == 1 else (cls._attribute_name,)
203+
update_field = (
204+
field if not cls._attribute_name else (cls._attribute_name,)
205+
)
186206
for f in update_field:
187207
setattr(record, f, nested)
188208

compose/neurosynth_compose/schemas/analysis.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,7 @@ class ConditionSchema(Schema):
130130
description = PGSQLString()
131131

132132

133-
class SpecificationConditionSchema(Schema):
134-
id = PGSQLString()
135-
created_at = fields.DateTime()
136-
updated_at = fields.DateTime(allow_none=True)
133+
class SpecificationConditionSchema(BaseSchema):
137134
condition = fields.Pluck(ConditionSchema, "name")
138135
weight = fields.Number()
139136

@@ -152,7 +149,7 @@ class StudysetReferenceSchema(Schema):
152149
exclude=("snapshot",),
153150
metadata={"pluck": "id"},
154151
many=True,
155-
dump_only=True
152+
dump_only=True,
156153
)
157154

158155

@@ -165,6 +162,7 @@ class SpecificationSchema(BaseSchema):
165162
mask = PGSQLString(allow_none=True)
166163
transformer = PGSQLString(allow_none=True)
167164
estimator = fields.Nested("EstimatorSchema")
165+
database_studyset = PGSQLString(allow_none=True)
168166
contrast = PGSQLString(allow_none=True)
169167
filter = PGSQLString(allow_none=True)
170168
corrector = fields.Dict(allow_none=True)
@@ -178,11 +176,7 @@ class SpecificationSchema(BaseSchema):
178176
data_key="conditions",
179177
)
180178
conditions = fields.Pluck(
181-
ConditionSchema,
182-
"name",
183-
many=True,
184-
allow_none=True,
185-
dump_only=True
179+
ConditionSchema, "name", many=True, allow_none=True, dump_only=True
186180
)
187181
weights = fields.List(
188182
fields.Float(),
@@ -213,7 +207,7 @@ def to_bool(self, data, **kwargs):
213207
output_conditions[i] = True
214208
elif cond.lower() == "false":
215209
output_conditions[i] = False
216-
data['conditions'] = conditions
210+
data["conditions"] = conditions
217211

218212
return data
219213

@@ -224,10 +218,10 @@ def to_string(self, data, **kwargs):
224218
output_conditions = conditions[:]
225219
for i, cond in enumerate(conditions):
226220
if cond is True:
227-
output_conditions[i] = 'true'
221+
output_conditions[i] = "true"
228222
elif cond is False:
229-
output_conditions[i] = 'false'
230-
data['conditions'] = output_conditions
223+
output_conditions[i] = "false"
224+
data["conditions"] = output_conditions
231225

232226
return data
233227

compose/neurosynth_compose/tests/api/test_specification.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_get_specification(session, app, auth_client, user_data):
2525
"corrector": {"type": "FDRCorrector"},
2626
"filter": "eyes",
2727
},
28-
]
28+
],
2929
)
3030
def test_create_and_get_spec(session, app, auth_client, user_data, specification_data):
3131
create_spec = auth_client.post("/api/specifications", data=specification_data)
@@ -35,3 +35,44 @@ def test_create_and_get_spec(session, app, auth_client, user_data, specification
3535
view_spec = auth_client.get(f"/api/specifications/{create_spec.json['id']}")
3636

3737
assert create_spec.json == view_spec.json
38+
39+
40+
@pytest.mark.parametrize(
41+
"attribute,value",
42+
[
43+
("estimator", {"type": "MKDA"}),
44+
("type", "ibma"),
45+
("conditions", ["yes", "no"]),
46+
("weights", [1, 1]),
47+
("corrector", {"type": "FWECorrector"}),
48+
("filter", "bunny"),
49+
("database_studyset", "neurostore"),
50+
],
51+
)
52+
def test_update_spec(session, app, auth_client, user_data, attribute, value):
53+
specification_data = {
54+
"estimator": {"type": "ALE"},
55+
"type": "cbma",
56+
"conditions": ["open", "closed"],
57+
"weights": [1, -1],
58+
"corrector": {"type": "FDRCorrector"},
59+
"filter": "eyes",
60+
}
61+
create_spec = auth_client.post("/api/specifications", data=specification_data)
62+
63+
assert create_spec.status_code == 200
64+
65+
spec_id = create_spec.json["id"]
66+
67+
update_spec = auth_client.put(
68+
f"/api/specifications/{spec_id}", data={attribute: value}
69+
)
70+
assert update_spec.status_code == 200
71+
72+
get_spec = auth_client.get(f"/api/specifications/{spec_id}")
73+
assert get_spec.status_code == 200
74+
75+
if isinstance(value, list):
76+
assert set(get_spec.json[attribute]) == set(value)
77+
else:
78+
assert get_spec.json[attribute] == value
Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
21
def test_studyset_references(session, app, auth_client, user_data):
32
nonnested = auth_client.get("/api/studyset-references?nested=false")
43
nested = auth_client.get("/api/studyset-references?nested=true")
54

65
assert nonnested.status_code == nested.status_code == 200
7-
assert isinstance(nonnested.json['results'][0]['studysets'][0], str)
8-
assert isinstance(nested.json['results'][0]['studysets'][0], dict)
6+
assert isinstance(nonnested.json["results"][0]["studysets"][0], str)
7+
assert isinstance(nested.json["results"][0]["studysets"][0], dict)

0 commit comments

Comments
 (0)