Skip to content

Commit 5152232

Browse files
committed
Merge branch 'fix/project_metadata' into staging
2 parents dcd6979 + 55ca1cc commit 5152232

File tree

11 files changed: +256 additions, −33 deletions

compose/backend/manage.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
from neurosynth_compose.database import db
1515
from neurosynth_compose import models
1616
from neurosynth_compose.ingest import neurostore as ingest_nstore
17+
from neurosynth_compose.backfill_extraction_metadata import (
18+
add_missing_extraction_ids,
19+
)
1720

1821

1922
app.config.from_object(os.environ["APP_SETTINGS"])
@@ -44,3 +47,12 @@ def create_meta_analyses(n_studysets, neurostore_url):
4447
if n_studysets is not None:
4548
n_studysets = int(n_studysets)
4649
ingest_nstore.create_meta_analyses(url=neurostore_url, n_studysets=n_studysets)
50+
51+
52+
@app.cli.command("backfill-extraction-metadata")
def backfill_extraction_metadata():
    """Add missing extractionMetadata ids to project provenance."""
    # Delegate the actual backfill to the shared helper, then report counts.
    updated, skipped = add_missing_extraction_ids()
    # NOTE(review): assumes `click` is imported at module level in manage.py
    # (not visible in this hunk) — confirm.
    summary = f"Updated {updated} project(s); skipped {skipped} project(s) with no changes."
    click.echo(summary)

compose/backend/neurosynth_compose/scripts/__init__.py

Whitespace-only changes.

compose/backend/neurosynth_compose/backfill_extraction_metadata.py (filename missing from the page rendering; inferred from the `from neurosynth_compose.backfill_extraction_metadata import ...` line added to manage.py — confirm)

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import logging
2+
from typing import Tuple
3+
4+
from sqlalchemy import select
5+
6+
from neurosynth_compose.database import db
7+
from neurosynth_compose.models.analysis import Project
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
def add_missing_extraction_ids(session=None) -> Tuple[int, int]:
    """Add null ``studysetId``/``annotationId`` keys to ``extractionMetadata``.

    Scans every ``Project`` and, when its provenance contains an
    ``extractionMetadata`` dict missing either id key, inserts the key with
    a ``None`` value. Projects without a dict-valued ``extractionMetadata``
    are skipped untouched.

    Parameters
    ----------
    session : optional
        SQLAlchemy session to use; defaults to ``db.session``.

    Returns
    -------
    Tuple[int, int]
        ``(updated, skipped)`` project counts.

    Raises
    ------
    Exception
        Re-raises any commit failure after rolling the session back.
    """
    sess = session or db.session
    updated = 0
    skipped = 0

    projects = sess.scalars(select(Project)).all()

    for project in projects:
        # Work on copies instead of mutating the loaded JSON in place:
        # in-place mutation leaves the attribute value equal to its committed
        # state, so SQLAlchemy's change detection can skip the UPDATE at
        # flush time (unless the column uses MutableDict — not visible here).
        provenance = dict(project.provenance or {})
        extraction_metadata = provenance.get("extractionMetadata")

        if not isinstance(extraction_metadata, dict):
            skipped += 1
            continue

        extraction_metadata = dict(extraction_metadata)
        changed = False

        if "studysetId" not in extraction_metadata:
            extraction_metadata["studysetId"] = None
            changed = True

        if "annotationId" not in extraction_metadata:
            extraction_metadata["annotationId"] = None
            changed = True

        if changed:
            provenance["extractionMetadata"] = extraction_metadata
            # Assigning a brand-new dict reliably marks the attribute dirty.
            project.provenance = provenance
            updated += 1
        else:
            skipped += 1

    if updated:
        try:
            sess.commit()
        except Exception:
            # Keep the session usable, log with traceback, then surface.
            sess.rollback()
            logger.exception(
                "Failed to commit extractionMetadata backfill for projects."
            )
            raise
    else:
        # Nothing changed; release any transaction opened by the reads.
        sess.rollback()

    return updated, skipped

store/backend/neurostore/ingest/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,8 @@ def ingest_neurosynth(max_rows=None):
364364

365365
# add notes to annotation
366366
annot.note_keys = {
367-
k: _check_type(v) for k, v in annotation_row._asdict().items()
367+
k: {"type": _check_type(v) or "string", "order": idx}
368+
for idx, (k, v) in enumerate(annotation_row._asdict().items())
368369
}
369370
annot.annotation_analyses = notes
370371
for note in notes:

store/backend/neurostore/resources/data.py

Lines changed: 84 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,84 @@ def get_affected_ids(self, ids):
393393
}
394394
return unique_ids
395395

396+
@staticmethod
def _ordered_note_keys(note_keys):
    """Return the key names of *note_keys*.

    Empty/None input yields ``[]``. Keys are returned alphabetized when
    they already happen to be in sorted order, otherwise in their
    original insertion order.
    """
    if not note_keys:
        return []
    names = list(note_keys)
    return sorted(names) if names == sorted(names) else names
405+
406+
@classmethod
def _normalize_note_keys(cls, note_keys):
    """Coerce a raw ``note_keys`` mapping into ``{key: {"type", "order"}}``.

    Returns ``None`` when *note_keys* is ``None`` (caller decides what to
    do), rejects non-dict input via ``abort_validation`` (presumably an
    HTTP validation abort — confirm its semantics), and otherwise gives
    every key a validated type and a unique integer order.
    """
    if note_keys is None:
        return None
    if not isinstance(note_keys, dict):
        abort_validation("`note_keys` must be an object.")

    # Keys come back alphabetized only when already sorted; otherwise
    # insertion order is preserved (see _ordered_note_keys).
    ordered_keys = cls._ordered_note_keys(note_keys)
    normalized = OrderedDict()
    used_orders = set()  # integer orders already claimed by earlier keys
    next_order = 0  # lowest candidate for auto-assigned orders

    for key in ordered_keys:
        # A missing/None descriptor becomes {}, whose type check below
        # then fails and aborts.
        descriptor = note_keys.get(key) or {}
        note_type = descriptor.get("type")
        if note_type not in {"string", "number", "boolean"}:
            abort_validation(
                "Invalid `type` for note_keys entry "
                f"'{key}', choose from: ['boolean', 'number', 'string']."
            )

        order = descriptor.get("order")
        # bool is a subclass of int, so exclude it explicitly; any other
        # non-int value is discarded and an order is auto-assigned.
        if isinstance(order, bool) or (
            order is not None and not isinstance(order, int)
        ):
            order = None

        if isinstance(order, int) and order not in used_orders:
            # Honor an explicit, not-yet-claimed order.
            used_orders.add(order)
            if order >= next_order:
                next_order = order + 1
        else:
            # Missing or duplicate order: take the lowest free slot.
            while next_order in used_orders:
                next_order += 1
            order = next_order
            used_orders.add(order)
            next_order += 1

        normalized[key] = {"type": note_type, "order": order}

    return normalized
447+
448+
@classmethod
def _merge_note_keys(cls, existing, additions):
    """Merge *additions* (a mapping of key -> type) into *existing*.

    *existing* is normalized first; keys already present keep their order
    and get their type refreshed, while new keys are appended with the
    next free order after the current maximum.
    """
    merged = cls._normalize_note_keys(existing or {}) or OrderedDict()

    # Collect the integer orders already in use to place new keys after them.
    claimed = {
        entry.get("order") for entry in merged.values() if isinstance(entry, dict)
    }
    claimed = {o for o in claimed if isinstance(o, int)}
    order_cursor = max(claimed, default=-1) + 1

    for name, declared_type in additions.items():
        if name in merged:
            entry = merged[name] or {}
            entry["type"] = declared_type or entry.get("type") or "string"
            merged[name] = entry
        else:
            merged[name] = {
                "type": declared_type or "string",
                "order": order_cursor,
            }
            order_cursor += 1

    return merged
473+
396474
@classmethod
397475
def load_nested_records(cls, data, record=None):
398476
if not data:
@@ -554,6 +632,9 @@ def put(self, id):
554632
schema = self._schema()
555633
data = schema.load(request_data)
556634

635+
if "note_keys" in data:
636+
data["note_keys"] = self._normalize_note_keys(data["note_keys"])
637+
557638
pipeline_payload = data.pop("pipelines", [])
558639

559640
args = {}
@@ -942,12 +1023,10 @@ def _apply_pipeline_columns(self, annotation, data, specs, column_counter):
9421023

9431024
if column_types:
9441025
if data.get("note_keys") is None:
945-
note_keys = dict(annotation.note_keys or {})
1026+
note_keys = self._normalize_note_keys(annotation.note_keys or {})
9461027
else:
947-
note_keys = dict(data["note_keys"])
948-
for key, value_type in column_types.items():
949-
note_keys[key] = value_type or "string"
950-
data["note_keys"] = note_keys
1028+
note_keys = self._normalize_note_keys(data["note_keys"])
1029+
data["note_keys"] = self._merge_note_keys(note_keys, column_types)
9511030

9521031
data["annotation_analyses"] = list(note_map.values())
9531032

store/backend/neurostore/schemas/data.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,70 @@ class AnnotationPipelineSchema(BaseSchema):
663663
columns = fields.List(fields.String(), required=True)
664664

665665

666+
class NoteKeysField(fields.Field):
    """Marshmallow field for ``note_keys``: ``{key: {"type", "order"}}``.

    Serializes descriptors to plain ``{"type", "order"}`` dicts and, on
    load, validates types and assigns every key a unique integer order.
    """

    # Closed set of accepted note value types.
    allowed_types = {"string", "number", "boolean"}

    def _serialize(self, value, attr, obj, **kwargs):
        # Dump each descriptor down to just type/order; non-dict
        # descriptors are silently dropped from the output.
        if not value:
            return {}
        serialized = {}
        for key, descriptor in value.items():
            if not isinstance(descriptor, dict):
                continue
            serialized[key] = {
                "type": descriptor.get("type"),
                "order": descriptor.get("order"),
            }
        return serialized

    def _deserialize(self, value, attr, data, **kwargs):
        # None loads as an empty mapping; anything but a dict is rejected.
        if value is None:
            return {}
        if not isinstance(value, dict):
            raise ValidationError("`note_keys` must be an object.")

        normalized = {}
        used_orders = set()  # orders claimed while walking the keys below
        explicit_orders = []
        # First pass: find the highest explicitly requested order so
        # auto-assigned orders start after it.
        # NOTE(review): bools pass this isinstance(int) check (bool is a
        # subclass of int), unlike the per-key pass below which excludes
        # them — confirm whether order=True inflating next_order is intended.
        for descriptor in value.values():
            if isinstance(descriptor, dict) and isinstance(
                descriptor.get("order"), int
            ):
                explicit_orders.append(descriptor["order"])
        next_order = max(explicit_orders, default=-1) + 1

        for key, descriptor in value.items():
            if not isinstance(descriptor, dict):
                raise ValidationError("Each note key must map to an object.")

            note_type = descriptor.get("type")
            if note_type not in self.allowed_types:
                raise ValidationError(
                    f"Invalid note type for '{key}', choose from: {sorted(self.allowed_types)}"
                )

            order = descriptor.get("order")
            # Discard bools and any other non-int order values.
            if isinstance(order, bool) or (
                order is not None and not isinstance(order, int)
            ):
                order = None

            if isinstance(order, int) and order not in used_orders:
                # Honor an explicit, not-yet-claimed order.
                used_orders.add(order)
                if order >= next_order:
                    next_order = order + 1
            else:
                # Missing or duplicate order: take the lowest free slot.
                while next_order in used_orders:
                    next_order += 1
                order = next_order
                used_orders.add(order)
                next_order += 1

            normalized[key] = {"type": note_type, "order": order}

        return normalized
728+
729+
666730
class AnnotationSchema(BaseDataSchema):
667731
# serialization
668732
studyset_id = fields.String(data_key="studyset")
@@ -675,7 +739,7 @@ class AnnotationSchema(BaseDataSchema):
675739
source_id = fields.String(dump_only=True, allow_none=True)
676740
source_updated_at = fields.DateTime(dump_only=True, allow_none=True)
677741

678-
note_keys = fields.Dict()
742+
note_keys = NoteKeysField()
679743
metadata = fields.Dict(attribute="metadata_", dump_only=True)
680744
# deserialization
681745
metadata_ = fields.Dict(data_key="metadata", load_only=True, allow_none=True)

0 commit comments

Comments
 (0)