diff --git a/compose/backend/neurosynth_compose/resources/analysis.py b/compose/backend/neurosynth_compose/resources/analysis.py
index 1bce1445..115b6aae 100644
--- a/compose/backend/neurosynth_compose/resources/analysis.py
+++ b/compose/backend/neurosynth_compose/resources/analysis.py
@@ -317,21 +317,21 @@ def insert_data(self, id, data):
 
 
 LIST_USER_ARGS = {
-    "search": fields.String(missing=None),
-    "sort": fields.String(missing="created_at"),
-    "page": fields.Int(missing=1),
-    "desc": fields.Boolean(missing=True),
-    "page_size": fields.Int(missing=20, validate=lambda val: val < 100),
-    "source_id": fields.String(missing=None),
-    "source": fields.String(missing=None),
-    "unique": fields.Boolean(missing=False),
-    "nested": fields.Boolean(missing=False),
-    "user_id": fields.String(missing=None),
-    "dataset_id": fields.String(missing=None),
-    "export": fields.Boolean(missing=False),
-    "data_type": fields.String(missing=None),
-    "info": fields.Boolean(missing=False),
-    "ids": fields.List(fields.String(), missing=None),
+    "search": fields.String(load_default=None),
+    "sort": fields.String(load_default="created_at"),
+    "page": fields.Int(load_default=1),
+    "desc": fields.Boolean(load_default=True),
+    "page_size": fields.Int(load_default=20, validate=lambda val: val < 100),
+    "source_id": fields.String(load_default=None),
+    "source": fields.String(load_default=None),
+    "unique": fields.Boolean(load_default=False),
+    "nested": fields.Boolean(load_default=False),
+    "user_id": fields.String(load_default=None),
+    "dataset_id": fields.String(load_default=None),
+    "export": fields.Boolean(load_default=False),
+    "data_type": fields.String(load_default=None),
+    "info": fields.Boolean(load_default=False),
+    "ids": fields.List(fields.String(), load_default=None),
 }
 
 
@@ -702,8 +702,8 @@ def db_validation(self, data):
     def post(self):
        clone_args = parser.parse(
             {
-                "source_id": fields.String(missing=None),
-                "copy_annotations": fields.Boolean(missing=True),
+                "source_id": fields.String(load_default=None),
+                "copy_annotations": fields.Boolean(load_default=True),
             },
             request,
             location="query",
diff --git a/compose/backend/neurosynth_compose/schemas/analysis.py b/compose/backend/neurosynth_compose/schemas/analysis.py
index 79b688e7..7e9f7867 100644
--- a/compose/backend/neurosynth_compose/schemas/analysis.py
+++ b/compose/backend/neurosynth_compose/schemas/analysis.py
@@ -8,6 +8,13 @@
 NS_BASE = "https://neurostore.org/api"
 
 
+class ContextSchema(Schema):
+    def __init__(self, *args, **kwargs):
+        context = kwargs.pop("context", {})
+        super().__init__(*args, **kwargs)
+        self.context = context or {}
+
+
 class BytesField(fields.Field):
     def _deserialize(self, value, attr, data, **kwargs):
         if isinstance(value, str):
@@ -29,7 +36,7 @@ def _deserialize(self, value, attr, data, **kwargs):
         return result.replace("\x00", "\uFFFD")
 
 
-class ResultInitSchema(Schema):
+class ResultInitSchema(ContextSchema):
     meta_analysis_id = fields.String(load_only=True)
     meta_analysis = fields.Pluck(
         "MetaAnalysisSchema", "id", attribute="meta_analysis", dump_only=True
@@ -39,15 +46,15 @@ class ResultInitSchema(Schema):
     specification_snapshot = fields.Dict()
 
 
-class ResultUploadSchema(Schema):
-    statistical_maps = fields.Raw(
-        metadata={"type": "string", "format": "binary"}, many=True
+class ResultUploadSchema(ContextSchema):
+    statistical_maps = fields.List(
+        fields.Raw(metadata={"type": "string", "format": "binary"})
     )
-    cluster_tables = fields.Raw(
-        metadata={"type": "string", "format": "binary"}, many=True
+    cluster_tables = fields.List(
+        fields.Raw(metadata={"type": "string", "format": "binary"})
     )
-    diagnostic_tables = fields.Raw(
-        metadata={"type": "string", "format": "binary"}, many=True
+    diagnostic_tables = fields.List(
+        fields.Raw(metadata={"type": "string", "format": "binary"})
     )
     method_description = fields.String()
 
@@ -59,6 +66,23 @@ class StringOrNested(fields.Nested):
         "invalid_utf8": "Not a valid utf-8 string.",
     }
 
+    def __init__(self, nested, *args, **kwargs):
+        super().__init__(nested, **kwargs)
+        self._explicit_context = {}
+
+    @property
+    def context(self):
+        if self._explicit_context:
+            return self._explicit_context
+        parent = getattr(self, "parent", None)
+        if parent is not None and hasattr(parent, "context"):
+            return parent.context
+        return {}
+
+    @context.setter
+    def context(self, value):
+        self._explicit_context = value or {}
+
     def _serialize(self, value, attr, obj, **kwargs):
         if value is None:
             return None
@@ -127,7 +151,7 @@ def _deserialize(self, value, attr, data, **kwargs):
             raise self.make_error("invalid_utf8") from error
 
 
-class BaseSchema(Schema):
+class BaseSchema(ContextSchema):
     id = PGSQLString(metadata={"info_field": True})
     created_at = fields.DateTime()
     updated_at = fields.DateTime(allow_none=True)
@@ -137,7 +161,7 @@ class BaseSchema(Schema):
     )
 
 
-class ConditionSchema(Schema):
+class ConditionSchema(ContextSchema):
     id = PGSQLString()
     created_at = fields.DateTime()
     updated_at = fields.DateTime(allow_none=True)
@@ -147,15 +171,15 @@ class ConditionSchema(Schema):
 
 
 class SpecificationConditionSchema(BaseSchema):
     condition = fields.Pluck(ConditionSchema, "name")
-    weight = fields.Number()
+    weight = fields.Float()
 
 
-class EstimatorSchema(Schema):
+class EstimatorSchema(ContextSchema):
     type = fields.String()
     args = fields.Dict()
 
 
-class StudysetReferenceSchema(Schema):
+class StudysetReferenceSchema(ContextSchema):
     id = PGSQLString()
     created_at = fields.DateTime()
     updated_at = fields.DateTime(allow_none=True)
@@ -168,7 +192,7 @@ class StudysetReferenceSchema(Schema):
     )
 
 
-class AnnotationReferenceSchema(Schema):
+class AnnotationReferenceSchema(ContextSchema):
     id = PGSQLString()
 
 
@@ -177,6 +201,8 @@ class SpecificationSchema(BaseSchema):
     mask = PGSQLString(allow_none=True)
     transformer = PGSQLString(allow_none=True)
     estimator = fields.Nested("EstimatorSchema")
+    name = PGSQLString(allow_none=True)
+    description = PGSQLString(allow_none=True)
     database_studyset = PGSQLString(allow_none=True)
     contrast = PGSQLString(allow_none=True)
     filter = PGSQLString(allow_none=True)
@@ -208,10 +234,6 @@ class SpecificationSchema(BaseSchema):
         attribute="weights",
     )
 
-    class Meta:
-        additional = ("name", "description")
-        allow_none = ("name", "description")
-
     @post_dump
     def to_bool(self, data, **kwargs):
         conditions = data.get("conditions", None)
diff --git a/compose/backend/neurosynth_compose/schemas/users.py b/compose/backend/neurosynth_compose/schemas/users.py
index 3472c29d..293cf5b4 100644
--- a/compose/backend/neurosynth_compose/schemas/users.py
+++ b/compose/backend/neurosynth_compose/schemas/users.py
@@ -7,8 +7,10 @@
 
 
 class UserSchema(BaseSchema):
-    name = fields.Str(description="User full name")
-    external_id = fields.Str(description="External authentication service user ID")
+    name = fields.Str(metadata={"description": "User full name"})
+    external_id = fields.Str(
+        metadata={"description": "External authentication service user ID"}
+    )
 
     class Meta:
         unknown = EXCLUDE
diff --git a/compose/backend/pyproject.toml b/compose/backend/pyproject.toml
index dad0c19e..a6aedde8 100644
--- a/compose/backend/pyproject.toml
+++ b/compose/backend/pyproject.toml
@@ -24,7 +24,7 @@ dependencies = [
     "flask-sqlalchemy~=3.1",
     "gunicorn~=23.0",
     "ipython~=9.6",
-    "marshmallow~=3.0",
+    "marshmallow~=4.0",
     "pandas~=2.0",
     "psycopg2-binary~=2.8",
     "pyld~=2.0",
diff --git a/store/backend/neurostore/resources/data.py b/store/backend/neurostore/resources/data.py
index 7a3195a5..9e92c132 100644
--- a/store/backend/neurostore/resources/data.py
+++ b/store/backend/neurostore/resources/data.py
@@ -506,6 +506,11 @@ def load_from_neurostore(cls, source_id, data=None):
         schema = cls._schema(context=context)
         tmp_data = schema.dump(annotation)
         data = schema.load(tmp_data)
+        # Ensure cloned payload does not reference original primary keys
+        data.pop("id", None)
+        for note in data.get("annotation_analyses") or []:
+            if isinstance(note, dict):
+                note.pop("id", None)
         data["source"] = "neurostore"
         data["source_id"] = source_id
         data["source_updated_at"] = annotation.updated_at or annotation.created_at
diff --git a/store/backend/neurostore/resources/pipeline.py b/store/backend/neurostore/resources/pipeline.py
index fc2a1c7f..af73e0aa 100644
--- a/store/backend/neurostore/resources/pipeline.py
+++ b/store/backend/neurostore/resources/pipeline.py
@@ -96,8 +96,10 @@ class PipelineStudyResultsView(ObjectView, ListView):
     "feature_flatten": fields.Bool(load_default=False),
     "feature_display": fields.List(
         fields.String(),
-        description="List of pipeline results. format: pipeline_name[:version]",
         load_default=[],
+        metadata={
+            "description": "List of pipeline results. format: pipeline_name[:version]"
+        },
     ),
 }
 
diff --git a/store/backend/neurostore/schemas/data.py b/store/backend/neurostore/schemas/data.py
index 4611b82c..0265629c 100644
--- a/store/backend/neurostore/schemas/data.py
+++ b/store/backend/neurostore/schemas/data.py
@@ -44,18 +44,41 @@ def _deserialize(self, value, attr, data, **kwargs):
 
 class ObjToString(fields.Field):
     def __init__(self, *args, **kwargs):
+        self.many = kwargs.pop("many", False)
         super().__init__(*args, **kwargs)
-        self.many = kwargs.get("many", False)
+
+    def _serialize_single(self, value):
+        if value is None:
+            return None
+        if isinstance(value, dict):
+            # Already serialized payload
+            return value.get("id") if "id" in value else value
+        if hasattr(value, "id"):
+            return str(value.id)
+        return str(value)
 
     def _serialize(self, value, attr, obj, **kwargs):
         if self.many:
-            return [v.id for v in value]
-        return str(value.id)
+            return [self._serialize_single(v) for v in value]
+        return self._serialize_single(value)
 
     def _deserialize(self, value, attr, data, **kwargs):
         if self.many:
-            return [{"id": v} if isinstance(v, str) else v for v in value]
-        return {"id": value} if isinstance(value, str) else value
+            return [
+                (
+                    {"id": v}
+                    if isinstance(v, (str, int))
+                    else (v if isinstance(v, dict) else {"id": getattr(v, "id", v)})
+                )
+                for v in value
+            ]
+        if isinstance(value, (str, int)):
+            return {"id": value}
+        if isinstance(value, dict):
+            return value
+        if hasattr(value, "id"):
+            return {"id": value.id}
+        return {"id": value}
 
 
 class StringOrNested(fields.Nested):
@@ -64,6 +87,20 @@
     def __init__(self, nested, *args, **kwargs):
         super().__init__(nested, **kwargs)
         self.string_field = ObjToString(*args, **kwargs)
+        self._explicit_context = {}
+
+    @property
+    def context(self):
+        if self._explicit_context:
+            return self._explicit_context
+        parent = getattr(self, "parent", None)
+        if parent is not None and hasattr(parent, "context"):
+            return parent.context
+        return {}
+
+    @context.setter
+    def context(self, value):
+        self._explicit_context = value or {}
 
     def _modify_schema(self):
         """Only relevant when nested=True"""
@@ -157,7 +194,7 @@ def __init__(self, meta, *args, **kwargs):
 class BaseSchema(Schema):
     def __init__(self, *args, **kwargs):
         exclude = kwargs.get("exclude") or self.opts.exclude
-        context = kwargs.get("context", {})
+        context = kwargs.pop("context", {})
         only = kwargs.get("only")
 
         # if cloning and not only id, exclude id fields (unless preserve_on_clone is True)
@@ -200,6 +237,7 @@ def __init__(self, *args, **kwargs):
                 exclude += (f,)
         kwargs["exclude"] = exclude
         super().__init__(*args, **kwargs)
+        self.context = context or {}
 
     OPTIONS_CLASS = BaseSchemaOpts
     # normal return key
@@ -220,16 +258,15 @@ class BaseDataSchema(BaseSchema):
         attribute="user.name",
         dump_only=True,
         metadata={"info_field": True},
-        default=None,
+        dump_default=None,
     )
     created_at = fields.DateTime(dump_only=True, metadata={"info_field": True})
     updated_at = fields.DateTime(dump_only=True, metadata={"info_field": True})
 
 
 class ConditionSchema(BaseDataSchema):
-    class Meta:
-        additional = ("name", "description")
-        allow_none = ("name", "description")
+    name = fields.String(allow_none=True)
+    description = fields.String(allow_none=True)
 
     # Override the id field to preserve it during cloning
     id = fields.String(
@@ -240,9 +277,8 @@ class Meta:
 class EntitySchema(BaseDataSchema):
     analysis_id = fields.String(data_key="analysis", metadata={"id_field": True})
 
-    class Meta:
-        additional = ("level", "label")
-        allow_none = ("level", "label")
+    level = fields.String(allow_none=True)
+    label = fields.String(allow_none=True)
 
 
 class ImageSchema(BaseDataSchema):
@@ -251,15 +287,15 @@
     # analysis = fields.Pluck("AnalysisSchema", "id", metadata={"id_field": True})
     analysis_name = fields.String(allow_none=True, dump_only=True)
     add_date = fields.DateTime(dump_only=True)
-
-    class Meta:
-        additional = ("url", "filename", "space", "value_type")
-        allow_none = ("url", "filename", "space", "value_type")
+    url = fields.String(allow_none=True)
+    filename = fields.String(allow_none=True)
+    space = fields.String(allow_none=True)
+    value_type = fields.String(allow_none=True)
 
 
 class PointValueSchema(BaseSchema):
-    class Meta:
-        additional = allow_none = ("kind", "value")
+    kind = fields.String(allow_none=True)
+    value = fields.Float(allow_none=True)
 
 
 class PointSchema(BaseDataSchema):
@@ -270,19 +306,19 @@
     entities = fields.Nested(EntitySchema, many=True, load_only=True)
     cluster_size = fields.Float(allow_none=True)
     subpeak = fields.Boolean(allow_none=True)
-    deactivation = fields.Boolean(missing=False, allow_none=True)
+    deactivation = fields.Boolean(load_default=False, allow_none=True)
     order = fields.Integer()
     coordinates = fields.List(fields.Float(), dump_only=True)
+    kind = fields.String(allow_none=True)
+    space = fields.String(allow_none=True)
+    image = fields.String(allow_none=True)
+    label_id = fields.Float(allow_none=True)
 
     # deserialization
     x = fields.Float(load_only=True, allow_none=True)
     y = fields.Float(load_only=True, allow_none=True)
     z = fields.Float(load_only=True, allow_none=True)
 
-    class Meta:
-        additional = ("kind", "space", "image", "label_id")
-        allow_none = ("kind", "space", "image", "label_id", "x", "y", "z")
-
     @pre_load
     def process_values(self, data, **kwargs):
         # Handle case where data might be a string ID instead of dict
@@ -366,10 +402,8 @@ class AnalysisSchema(BaseDataSchema):
     points = StringOrNested(PointSchema, many=True)
     weights = fields.List(fields.Float())
     entities = fields.Nested(EntitySchema, many=True, load_only=True)
-
-    class Meta:
-        additional = ("name", "description")
-        allow_none = ("name", "description")
+    name = fields.String(allow_none=True)
+    description = fields.String(allow_none=True)
 
     @pre_load
     def load_values(self, data, **kwargs):
@@ -419,6 +453,15 @@ class StudySetStudyInfoSchema(Schema):
 class BaseStudySchema(BaseDataSchema):
     metadata = fields.Dict(attribute="metadata_", dump_only=True)
     metadata_ = fields.Dict(data_key="metadata", load_only=True, allow_none=True)
+    name = fields.String(allow_none=True)
+    description = fields.String(allow_none=True)
+    publication = fields.String(allow_none=True)
+    doi = fields.String(allow_none=True)
+    pmid = fields.String(allow_none=True)
+    pmcid = fields.String(allow_none=True)
+    authors = fields.String(allow_none=True)
+    year = fields.Integer(allow_none=True)
+    level = fields.String(allow_none=True)
     versions = StringOrNested("StudySchema", many=True)
     features = fields.Method("get_features")
     ace_fulltext = fields.String(load_only=True, allow_none=True)
@@ -448,30 +491,6 @@ def get_features(self, obj):
 
         return features
 
-    class Meta:
-        additional = (
-            "name",
-            "description",
-            "publication",
-            "doi",
-            "pmid",
-            "pmcid",
-            "authors",
-            "year",
-            "level",
-        )
-        allow_none = (
-            "name",
-            "description",
-            "publication",
-            "doi",
-            "pmid",
-            "pmcid",
-            "authors",
-            "year",
-            "level",
-        )
-
     @pre_load
     def check_nulls(self, data, **kwargs):
         """
@@ -512,6 +531,15 @@ def check_nulls(self, data, **kwargs):
 class StudySchema(BaseDataSchema):
     metadata = fields.Dict(attribute="metadata_", dump_only=True)
     metadata_ = fields.Dict(data_key="metadata", load_only=True, allow_none=True)
+    name = fields.String(allow_none=True)
+    description = fields.String(allow_none=True)
+    publication = fields.String(allow_none=True)
+    doi = fields.String(allow_none=True)
+    pmid = fields.String(allow_none=True)
+    pmcid = fields.String(allow_none=True)
+    authors = fields.String(allow_none=True)
+    year = fields.Integer(allow_none=True)
+    level = fields.String(allow_none=True)
     analyses = StringOrNested(AnalysisSchema, many=True)
     source = fields.String(
         dump_only=True, metadata={"info_field": True}, allow_none=True
@@ -532,28 +560,6 @@ class StudySchema(BaseDataSchema):
     class Meta:
         # by default exclude this
         exclude = ("has_coordinates", "has_images", "studysets")
-        additional = (
-            "name",
-            "description",
-            "publication",
-            "doi",
-            "pmid",
-            "pmcid",
-            "authors",
-            "year",
-            "level",
-        )
-        allow_none = (
-            "name",
-            "description",
-            "publication",
-            "doi",
-            "pmid",
-            "pmcid",
-            "authors",
-            "year",
-            "level",
-        )
 
     @pre_load
     def check_nulls(self, data, **kwargs):
@@ -573,10 +579,13 @@ class StudysetSchema(BaseDataSchema):
     source = fields.String(dump_only=True, allow_none=True)
     source_id = fields.String(dump_only=True, allow_none=True)
     source_updated_at = fields.DateTime(dump_only=True, allow_none=True)
+    name = fields.String(allow_none=True)
+    description = fields.String(allow_none=True)
+    publication = fields.String(allow_none=True)
+    doi = fields.String(allow_none=True)
+    pmid = fields.String(allow_none=True)
 
     class Meta:
-        additional = ("name", "description", "publication", "doi", "pmid")
-        allow_none = ("name", "description", "publication", "doi", "pmid")
         render_module = orjson
 
 
@@ -644,10 +653,8 @@ class AnnotationSchema(BaseDataSchema):
     metadata = fields.Dict(attribute="metadata_", dump_only=True)
     # deserialization
     metadata_ = fields.Dict(data_key="metadata", load_only=True, allow_none=True)
-
-    class Meta:
-        additional = ("name", "description")
-        allow_none = ("name", "description")
+    name = fields.String(allow_none=True)
+    description = fields.String(allow_none=True)
 
     @pre_load
     def add_studyset_id(self, data, **kwargs):
diff --git a/store/backend/neurostore/schemas/pipeline.py b/store/backend/neurostore/schemas/pipeline.py
index 07406b1e..9df6a9a9 100644
--- a/store/backend/neurostore/schemas/pipeline.py
+++ b/store/backend/neurostore/schemas/pipeline.py
@@ -51,20 +51,25 @@ class PipelineStudyResultSchema(BaseSchema):
 
     # Execution metadata
     date_executed = fields.DateTime(
-        dump_only=True, description="Timestamp of pipeline execution", allow_none=True
+        dump_only=True,
+        allow_none=True,
+        metadata={"description": "Timestamp of pipeline execution"},
     )
 
     # Result and input data
-    result_data = fields.Dict(description="Pipeline execution results", allow_none=True)
+    result_data = fields.Dict(
+        allow_none=True, metadata={"description": "Pipeline execution results"}
+    )
     file_inputs = fields.Dict(
-        description="Files used as input for the pipeline", allow_none=True
+        allow_none=True,
+        metadata={"description": "Files used as input for the pipeline"},
     )
 
     # Pipeline execution status
     status = fields.Str(
         validate=lambda x: x in ["SUCCESS", "FAILURE", "ERROR", "UNKNOWN"],
         required=True,
-        description="Current status of the pipeline execution",
+        metadata={"description": "Current status of the pipeline execution"},
     )
 
     class Meta:
@@ -126,27 +131,30 @@ class PipelineEmbeddingSchema(BaseSchema):
 
     # Execution metadata
     date_executed = fields.DateTime(
-        dump_only=True, description="Timestamp of pipeline execution", allow_none=True
+        dump_only=True,
+        allow_none=True,
+        metadata={"description": "Timestamp of pipeline execution"},
     )
 
     # Result and input data
     file_inputs = fields.Dict(
-        description="Files used as input for the pipeline", allow_none=True
+        allow_none=True,
+        metadata={"description": "Files used as input for the pipeline"},
     )
 
     # Pipeline execution status
     status = fields.Str(
         validate=lambda x: x in ["SUCCESS", "FAILURE", "ERROR", "UNKNOWN"],
         required=True,
-        description="Current status of the pipeline execution",
+        metadata={"description": "Current status of the pipeline execution"},
     )
 
     # Vector embedding data
     embedding = fields.List(
         fields.Float(),
         required=True,
-        description="Vector embedding data",
         allow_none=False,
+        metadata={"description": "Vector embedding data"},
     )
 
     class Meta:
diff --git a/store/backend/neurostore/tests/api/test_annotations.py b/store/backend/neurostore/tests/api/test_annotations.py
index 45274c8e..f43f216e 100644
--- a/store/backend/neurostore/tests/api/test_annotations.py
+++ b/store/backend/neurostore/tests/api/test_annotations.py
@@ -18,7 +18,9 @@ def test_post_blank_annotation(auth_client, ingest_neurosynth, session):
     assert annot.annotation_analyses[0].user_id == annot.user_id
 
 
-def test_blank_annotation_populates_note_fields(auth_client, ingest_neurosynth, session):
+def test_blank_annotation_populates_note_fields(
+    auth_client, ingest_neurosynth, session
+):
     dset = Studyset.query.first()
     note_keys = {"included": "boolean", "quality": "string"}
     payload = {
diff --git a/store/backend/pyproject.toml b/store/backend/pyproject.toml
index d12d0d63..df7eb5e5 100644
--- a/store/backend/pyproject.toml
+++ b/store/backend/pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
     "flask-sqlalchemy~=3.1", # fix multiple instance error: https://flask-sqlalchemy.palletsprojects.com/en/3.0.x/changes/#version-3-0-3
     "gunicorn~=23.0",
     "ipython~=8.18",
-    "marshmallow~=3.0",
+    "marshmallow~=4.0",
     "numpy<2.0.0",
     "pandas~=2.3",
     "psycopg2-binary~=2.8",