Skip to content

Commit a15779c

Browse files
committed
run style
1 parent 53f6a8c commit a15779c

File tree

3 files changed

+569
-2
lines changed

3 files changed

+569
-2
lines changed

store/backend/neurostore/resources/data.py

Lines changed: 363 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import string
2+
from collections import Counter, OrderedDict
23
from copy import deepcopy
4+
from datetime import datetime
35
from sqlalchemy import func, text
46
from sqlalchemy.exc import SQLAlchemyError
57

@@ -47,7 +49,7 @@
4749
PipelineConfig,
4850
Pipeline,
4951
)
50-
from ..models.data import StudysetStudy, BaseStudy
52+
from ..models.data import StudysetStudy, BaseStudy, _check_type
5153
from ..utils import parse_json_filter, build_jsonpath
5254

5355

@@ -56,6 +58,7 @@
5658
AnalysisConditionSchema,
5759
StudysetStudySchema,
5860
EntitySchema,
61+
PipelineStudyResultSchema,
5962
)
6063
from ..schemas.data import StudysetSnapshot
6164

@@ -544,6 +547,365 @@ def db_validation(self, record, data):
544547
"annotation request must contain all analyses from the studyset."
545548
)
546549

550+
def put(self, id):
    """Replace/update the annotation identified by ``id``.

    Loads the request body through the resource schema, pops the optional
    ``pipelines`` payload, and — when pipelines were requested — merges the
    matching pipeline result columns into the annotation's notes before
    persisting.  Returns the serialized, updated record.

    Args:
        id: Primary-key id of the annotation to update.

    Returns:
        dict: The schema dump of the updated record.
    """
    # Merge path params / defaults into the raw JSON body, then validate it.
    request_data = self.insert_data(id, request.json)
    schema = self._schema()
    data = schema.load(request_data)

    # Pipeline descriptors are handled separately from the model fields.
    pipeline_payload = data.pop("pipelines", [])

    args = {}
    # Only eager-load nested relationships when the payload touches them.
    if set(self._o2m.keys()).intersection(set(data.keys())):
        args["nested"] = True

    q = self._model.query.filter_by(id=id)
    q = self.eager_load(q, args)
    # .one() raises if the record does not exist (or is duplicated).
    input_record = q.one()

    if pipeline_payload:
        # Validate/canonicalize the descriptors, then write the pipeline
        # columns into `data` (notes and note_keys) before persistence.
        specs, column_counter = self._normalize_pipeline_specs(pipeline_payload)
        self._apply_pipeline_columns(input_record, data, specs, column_counter)

    self.db_validation(input_record, data)

    # no_autoflush: prevent SQLAlchemy from flushing half-built state while
    # the record graph is being mutated.
    with db.session.no_autoflush:
        record = self.__class__.update_or_create(data, id, record=input_record)

    with db.session.no_autoflush:
        unique_ids = self.get_affected_ids([id])
        clear_cache(unique_ids)

    try:
        self.update_base_studies(unique_ids.get("base-studies"))
        self.update_annotations(unique_ids.get("annotations"))
    except SQLAlchemyError as e:
        db.session.rollback()
        abort_validation(str(e))

    # Flush so generated ids/defaults are populated before serialization.
    db.session.flush()

    response = schema.dump(record)

    # Commit only after a successful dump so a serialization error does not
    # leave a committed-but-unreported change.
    db.session.commit()

    return response
592+
593+
@staticmethod
594+
def _extract_analysis_id(note):
595+
if not isinstance(note, dict):
596+
return None
597+
analysis = note.get("analysis")
598+
if isinstance(analysis, dict):
599+
return analysis.get("id")
600+
return analysis
601+
602+
def _normalize_pipeline_specs(self, pipelines):
603+
specs = []
604+
column_counter = Counter()
605+
606+
if not isinstance(pipelines, list):
607+
field_err = make_field_error("pipelines", pipelines, code="INVALID_FORMAT")
608+
abort_validation(
609+
"`pipelines` must be provided as a list of pipeline descriptors.",
610+
[field_err],
611+
)
612+
613+
for idx, payload in enumerate(pipelines):
614+
if not isinstance(payload, dict):
615+
field_err = make_field_error(
616+
f"pipelines[{idx}]", payload, code="INVALID_VALUE"
617+
)
618+
abort_validation(
619+
"Each pipeline descriptor must be an object.", [field_err]
620+
)
621+
622+
name = payload.get("name")
623+
if not isinstance(name, str) or not name.strip():
624+
field_err = make_field_error(
625+
f"pipelines[{idx}].name", name, code="MISSING_VALUE"
626+
)
627+
abort_validation(
628+
"Pipeline entries must include a non-empty `name`.", [field_err]
629+
)
630+
name = name.strip()
631+
632+
columns = payload.get("columns")
633+
if not isinstance(columns, (list, tuple)) or not columns:
634+
field_err = make_field_error(
635+
f"pipelines[{idx}].columns", columns, code="MISSING_VALUE"
636+
)
637+
abort_validation(
638+
"Pipeline entries must include a non-empty `columns` list.",
639+
[field_err],
640+
)
641+
normalized_columns = []
642+
for column in columns:
643+
if isinstance(column, str) and column.strip():
644+
col = column.strip()
645+
if col not in normalized_columns:
646+
normalized_columns.append(col)
647+
if not normalized_columns:
648+
field_err = make_field_error(
649+
f"pipelines[{idx}].columns", columns, code="INVALID_VALUE"
650+
)
651+
abort_validation(
652+
"Columns must contain at least one non-empty string.", [field_err]
653+
)
654+
655+
version = payload.get("version")
656+
if version is not None and not isinstance(version, str):
657+
field_err = make_field_error(
658+
f"pipelines[{idx}].version", version, code="INVALID_VALUE"
659+
)
660+
abort_validation(
661+
"`version` must be a string when provided.", [field_err]
662+
)
663+
config_id = payload.get("config_id")
664+
if config_id is not None and not isinstance(config_id, str):
665+
field_err = make_field_error(
666+
f"pipelines[{idx}].config_id", config_id, code="INVALID_VALUE"
667+
)
668+
abort_validation(
669+
"`config_id` must be a string when provided.", [field_err]
670+
)
671+
672+
spec = {
673+
"name": name,
674+
"columns": normalized_columns,
675+
"version": version.strip() if isinstance(version, str) else None,
676+
"config_id": config_id.strip() if isinstance(config_id, str) else None,
677+
}
678+
specs.append(spec)
679+
680+
for column in normalized_columns:
681+
column_counter[column] += 1
682+
683+
return specs, column_counter
684+
685+
def _build_note_map(self, annotation, incoming_notes):
686+
note_map = OrderedDict()
687+
analysis_context = {}
688+
689+
for aa in annotation.annotation_analyses:
690+
studyset_study = getattr(aa, "studyset_study", None)
691+
study = getattr(studyset_study, "study", None) if studyset_study else None
692+
if study is None and aa.analysis:
693+
study = aa.analysis.study
694+
base_study_id = getattr(study, "base_study_id", None)
695+
analysis_context[aa.analysis_id] = {
696+
"base_study_id": base_study_id,
697+
"study_id": aa.study_id,
698+
"studyset_id": aa.studyset_id,
699+
}
700+
note_map[aa.analysis_id] = {
701+
"id": aa.id,
702+
"analysis": {"id": aa.analysis_id},
703+
"studyset_study": {
704+
"study": {"id": aa.study_id},
705+
"studyset": {"id": aa.studyset_id},
706+
},
707+
"note": dict(aa.note or {}),
708+
}
709+
710+
for note in incoming_notes or []:
711+
if not isinstance(note, dict):
712+
continue
713+
analysis_id = self._extract_analysis_id(note)
714+
if not analysis_id:
715+
continue
716+
if analysis_id in note_map:
717+
merged = note_map[analysis_id]
718+
if "note" in note and isinstance(note["note"], dict):
719+
merged["note"] = dict(note["note"])
720+
if "studyset_study" in note and note["studyset_study"]:
721+
merged["studyset_study"] = note["studyset_study"]
722+
if "analysis" in note and note["analysis"]:
723+
merged["analysis"] = note["analysis"]
724+
if "id" in note and note["id"]:
725+
merged["id"] = note["id"]
726+
else:
727+
note_map[analysis_id] = note
728+
analysis_context.setdefault(
729+
analysis_id,
730+
{"base_study_id": None, "study_id": None, "studyset_id": None},
731+
)
732+
733+
return note_map, analysis_context
734+
735+
def _fetch_pipeline_data(self, spec, base_study_ids):
    """Fetch the latest pipeline result per base study for one pipeline spec.

    Resolves the pipeline by name (aborting when missing), optionally pins a
    config and/or version from the spec (aborting when those don't exist),
    then queries result rows and keeps only the most recent row per base
    study.  Config/version labels are resolved from the spec when given,
    otherwise inferred from the returned rows.

    Args:
        spec: Normalized descriptor with keys name/columns/version/config_id.
        base_study_ids: Iterable of base-study ids to restrict the query to.

    Returns:
        tuple: (dict of base_study_id -> {flat, config_id, version, timestamp},
                resolved version label or "latest",
                resolved config id or "latest").
    """
    pipeline = Pipeline.query.filter_by(name=spec["name"]).first()
    if not pipeline:
        field_err = make_field_error("pipeline", spec["name"], code="NOT_FOUND")
        abort_validation(f"Pipeline '{spec['name']}' does not exist.", [field_err])

    config_query = PipelineConfig.query.filter_by(pipeline_id=pipeline.id)

    resolved_config = None
    if spec["config_id"]:
        resolved_config = config_query.filter_by(id=spec["config_id"]).first()
        if not resolved_config:
            field_err = make_field_error(
                "config_id", spec["config_id"], code="NOT_FOUND"
            )
            abort_validation(
                f"Pipeline '{spec['name']}' does not have config '{spec['config_id']}'.",
                [field_err],
            )

    configs_for_version = []
    if spec["version"]:
        configs_for_version = config_query.filter_by(version=spec["version"]).all()
        if not configs_for_version:
            field_err = make_field_error(
                "version", spec["version"], code="NOT_FOUND"
            )
            abort_validation(
                f"Pipeline '{spec['name']}' does not have version '{spec['version']}'.",
                [field_err],
            )
        # A version that maps to exactly one config pins that config.
        if not resolved_config and len(configs_for_version) == 1:
            resolved_config = configs_for_version[0]

    query = (
        db.session.query(
            PipelineStudyResult.base_study_id,
            PipelineStudyResult.result_data,
            PipelineStudyResult.date_executed,
            PipelineStudyResult.created_at,
            PipelineStudyResult.config_id,
            PipelineConfig.version,
        )
        .join(PipelineConfig, PipelineStudyResult.config_id == PipelineConfig.id)
        .filter(PipelineConfig.pipeline_id == pipeline.id)
    )

    if base_study_ids:
        query = query.filter(PipelineStudyResult.base_study_id.in_(base_study_ids))
    if spec["config_id"]:
        query = query.filter(PipelineConfig.id == spec["config_id"])
    if spec["version"]:
        query = query.filter(PipelineConfig.version == spec["version"])

    rows = query.all()

    # Keep only the newest row per base study; rows without any timestamp
    # fall back to datetime.min so they lose ties against dated rows.
    # NOTE(review): assumes naive datetimes — comparing a tz-aware
    # date_executed against datetime.min would raise; confirm column tz.
    per_base = {}
    for row in rows:
        timestamp = row.date_executed or row.created_at or datetime.min
        existing = per_base.get(row.base_study_id)
        if existing is None or timestamp > existing["timestamp"]:
            result_data = (
                row.result_data if isinstance(row.result_data, dict) else {}
            )
            # flatten_dict presumably produces flat key -> value pairs from
            # the nested result payload — verify against the schema helper.
            flattened = (
                PipelineStudyResultSchema.flatten_dict(result_data)
                if isinstance(result_data, dict)
                else {}
            )
            per_base[row.base_study_id] = {
                "flat": flattened,
                "config_id": row.config_id,
                "version": row.version,
                "timestamp": timestamp,
            }

    # Infer the config when every returned row agrees on a single one.
    if not resolved_config and per_base:
        config_ids = {
            entry["config_id"] for entry in per_base.values() if entry["config_id"]
        }
        if len(config_ids) == 1:
            resolved_config = PipelineConfig.query.filter_by(
                id=config_ids.pop()
            ).first()

    # Resolve the version label: spec wins, then a unanimous row version,
    # then the lexicographically greatest row version, then the config's.
    resolved_version = spec["version"]
    if not resolved_version:
        versions = {
            entry["version"] for entry in per_base.values() if entry["version"]
        }
        if len(versions) == 1:
            resolved_version = versions.pop()
        elif len(versions) > 1:
            resolved_version = sorted(versions)[-1]
        elif resolved_config:
            resolved_version = resolved_config.version

    # Resolve the config id label with the same spec-first precedence.
    resolved_config_id = spec["config_id"]
    if not resolved_config_id:
        if resolved_config:
            resolved_config_id = resolved_config.id
        else:
            config_ids = {
                entry["config_id"]
                for entry in per_base.values()
                if entry["config_id"]
            }
            if len(config_ids) == 1:
                resolved_config_id = config_ids.pop()

    return (
        per_base,
        resolved_version or "latest",
        resolved_config_id or "latest",
    )
850+
851+
def _apply_pipeline_columns(self, annotation, data, specs, column_counter):
    """Write pipeline-derived column values into the annotation payload.

    For every spec, fetches the latest pipeline results for the affected
    base studies and writes each requested column's value into every
    analysis note.  Column names requested by more than one spec are
    disambiguated with a ``<name>_<version>_<config>`` suffix.  Updates
    ``data["note_keys"]`` with inferred column types and replaces
    ``data["annotation_analyses"]`` with the merged note list.
    """
    incoming_notes = data.get("annotation_analyses") or []
    note_map, analysis_context = self._build_note_map(annotation, incoming_notes)

    base_study_ids = {
        context["base_study_id"]
        for context in analysis_context.values()
        if context.get("base_study_id")
    }

    column_types = {}

    # Without any base studies there is nothing to fetch for any spec.
    if base_study_ids:
        for spec in specs:
            pipeline_data, resolved_version, resolved_config_id = (
                self._fetch_pipeline_data(spec, base_study_ids)
            )
            version_label = str(resolved_version or "latest").replace(" ", "_")
            config_label = str(resolved_config_id or "latest").replace(" ", "_")
            suffix_label = f"{spec['name']}_{version_label}_{config_label}"

            for column in spec["columns"]:
                # Suffix only when the bare name collides across specs.
                shared = column_counter[column] > 1
                key_name = f"{column}_{suffix_label}" if shared else column

                for analysis_id, payload in note_map.items():
                    context = analysis_context.get(analysis_id) or {}
                    entry = pipeline_data.get(context.get("base_study_id"))
                    flat_values = entry["flat"] if entry else {}
                    value = flat_values.get(column)

                    payload.setdefault("note", {})[key_name] = value

                    # Track a single type per key; conflicting detections
                    # degrade to "string", as do all-missing values.
                    detected = _check_type(value)
                    previous = column_types.get(key_name)
                    if detected:
                        column_types[key_name] = (
                            "string"
                            if previous and previous != detected
                            else detected
                        )
                    elif previous is None:
                        column_types[key_name] = "string"

    if column_types:
        if data.get("note_keys") is None:
            note_keys = dict(annotation.note_keys or {})
        else:
            note_keys = dict(data["note_keys"])
        for key, value_type in column_types.items():
            note_keys[key] = value_type or "string"
        data["note_keys"] = note_keys

    data["annotation_analyses"] = list(note_map.values())
908+
547909

548910
@view_maker
549911
class BaseStudiesView(ObjectView, ListView):

0 commit comments

Comments
 (0)