Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions compose/backend/neurosynth_compose/resources/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,21 +317,21 @@ def insert_data(self, id, data):


LIST_USER_ARGS = {
"search": fields.String(missing=None),
"sort": fields.String(missing="created_at"),
"page": fields.Int(missing=1),
"desc": fields.Boolean(missing=True),
"page_size": fields.Int(missing=20, validate=lambda val: val < 100),
"source_id": fields.String(missing=None),
"source": fields.String(missing=None),
"unique": fields.Boolean(missing=False),
"nested": fields.Boolean(missing=False),
"user_id": fields.String(missing=None),
"dataset_id": fields.String(missing=None),
"export": fields.Boolean(missing=False),
"data_type": fields.String(missing=None),
"info": fields.Boolean(missing=False),
"ids": fields.List(fields.String(), missing=None),
"search": fields.String(load_default=None),
"sort": fields.String(load_default="created_at"),
"page": fields.Int(load_default=1),
"desc": fields.Boolean(load_default=True),
"page_size": fields.Int(load_default=20, validate=lambda val: val < 100),
"source_id": fields.String(load_default=None),
"source": fields.String(load_default=None),
"unique": fields.Boolean(load_default=False),
"nested": fields.Boolean(load_default=False),
"user_id": fields.String(load_default=None),
"dataset_id": fields.String(load_default=None),
"export": fields.Boolean(load_default=False),
"data_type": fields.String(load_default=None),
"info": fields.Boolean(load_default=False),
"ids": fields.List(fields.String(), load_default=None),
}


Expand Down Expand Up @@ -702,8 +702,8 @@ def db_validation(self, data):
def post(self):
clone_args = parser.parse(
{
"source_id": fields.String(missing=None),
"copy_annotations": fields.Boolean(missing=True),
"source_id": fields.String(load_default=None),
"copy_annotations": fields.Boolean(load_default=True),
},
request,
location="query",
Expand Down
58 changes: 40 additions & 18 deletions compose/backend/neurosynth_compose/schemas/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@
NS_BASE = "https://neurostore.org/api"


class ContextSchema(Schema):
def __init__(self, *args, **kwargs):
context = kwargs.pop("context", {})
super().__init__(*args, **kwargs)
self.context = context or {}


class BytesField(fields.Field):
def _deserialize(self, value, attr, data, **kwargs):
if isinstance(value, str):
Expand All @@ -29,7 +36,7 @@ def _deserialize(self, value, attr, data, **kwargs):
return result.replace("\x00", "\uFFFD")


class ResultInitSchema(Schema):
class ResultInitSchema(ContextSchema):
meta_analysis_id = fields.String(load_only=True)
meta_analysis = fields.Pluck(
"MetaAnalysisSchema", "id", attribute="meta_analysis", dump_only=True
Expand All @@ -39,15 +46,15 @@ class ResultInitSchema(Schema):
specification_snapshot = fields.Dict()


class ResultUploadSchema(Schema):
statistical_maps = fields.Raw(
metadata={"type": "string", "format": "binary"}, many=True
class ResultUploadSchema(ContextSchema):
statistical_maps = fields.List(
fields.Raw(metadata={"type": "string", "format": "binary"})
)
cluster_tables = fields.Raw(
metadata={"type": "string", "format": "binary"}, many=True
cluster_tables = fields.List(
fields.Raw(metadata={"type": "string", "format": "binary"})
)
diagnostic_tables = fields.Raw(
metadata={"type": "string", "format": "binary"}, many=True
diagnostic_tables = fields.List(
fields.Raw(metadata={"type": "string", "format": "binary"})
)
method_description = fields.String()

Expand All @@ -59,6 +66,23 @@ class StringOrNested(fields.Nested):
"invalid_utf8": "Not a valid utf-8 string.",
}

def __init__(self, nested, *args, **kwargs):
super().__init__(nested, **kwargs)
self._explicit_context = {}

@property
def context(self):
if self._explicit_context:
return self._explicit_context
parent = getattr(self, "parent", None)
if parent is not None and hasattr(parent, "context"):
return parent.context
return {}

@context.setter
def context(self, value):
self._explicit_context = value or {}

def _serialize(self, value, attr, obj, **kwargs):
if value is None:
return None
Expand Down Expand Up @@ -127,7 +151,7 @@ def _deserialize(self, value, attr, data, **kwargs):
raise self.make_error("invalid_utf8") from error


class BaseSchema(Schema):
class BaseSchema(ContextSchema):
id = PGSQLString(metadata={"info_field": True})
created_at = fields.DateTime()
updated_at = fields.DateTime(allow_none=True)
Expand All @@ -137,7 +161,7 @@ class BaseSchema(Schema):
)


class ConditionSchema(Schema):
class ConditionSchema(ContextSchema):
id = PGSQLString()
created_at = fields.DateTime()
updated_at = fields.DateTime(allow_none=True)
Expand All @@ -147,15 +171,15 @@ class ConditionSchema(Schema):

class SpecificationConditionSchema(BaseSchema):
condition = fields.Pluck(ConditionSchema, "name")
weight = fields.Number()
weight = fields.Float()


class EstimatorSchema(Schema):
class EstimatorSchema(ContextSchema):
type = fields.String()
args = fields.Dict()


class StudysetReferenceSchema(Schema):
class StudysetReferenceSchema(ContextSchema):
id = PGSQLString()
created_at = fields.DateTime()
updated_at = fields.DateTime(allow_none=True)
Expand All @@ -168,7 +192,7 @@ class StudysetReferenceSchema(Schema):
)


class AnnotationReferenceSchema(Schema):
class AnnotationReferenceSchema(ContextSchema):
id = PGSQLString()


Expand All @@ -177,6 +201,8 @@ class SpecificationSchema(BaseSchema):
mask = PGSQLString(allow_none=True)
transformer = PGSQLString(allow_none=True)
estimator = fields.Nested("EstimatorSchema")
name = PGSQLString(allow_none=True)
description = PGSQLString(allow_none=True)
database_studyset = PGSQLString(allow_none=True)
contrast = PGSQLString(allow_none=True)
filter = PGSQLString(allow_none=True)
Expand Down Expand Up @@ -208,10 +234,6 @@ class SpecificationSchema(BaseSchema):
attribute="weights",
)

class Meta:
additional = ("name", "description")
allow_none = ("name", "description")

@post_dump
def to_bool(self, data, **kwargs):
conditions = data.get("conditions", None)
Expand Down
6 changes: 4 additions & 2 deletions compose/backend/neurosynth_compose/schemas/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@


class UserSchema(BaseSchema):
name = fields.Str(description="User full name")
external_id = fields.Str(description="External authentication service user ID")
name = fields.Str(metadata={"description": "User full name"})
external_id = fields.Str(
metadata={"description": "External authentication service user ID"}
)

class Meta:
unknown = EXCLUDE
2 changes: 1 addition & 1 deletion compose/backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"flask-sqlalchemy~=3.1",
"gunicorn~=23.0",
"ipython~=9.6",
"marshmallow~=3.0",
"marshmallow~=4.0",
"pandas~=2.0",
"psycopg2-binary~=2.8",
"pyld~=2.0",
Expand Down
5 changes: 5 additions & 0 deletions store/backend/neurostore/resources/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,11 @@ def load_from_neurostore(cls, source_id, data=None):
schema = cls._schema(context=context)
tmp_data = schema.dump(annotation)
data = schema.load(tmp_data)
# Ensure cloned payload does not reference original primary keys
data.pop("id", None)
for note in data.get("annotation_analyses") or []:
if isinstance(note, dict):
note.pop("id", None)
data["source"] = "neurostore"
data["source_id"] = source_id
data["source_updated_at"] = annotation.updated_at or annotation.created_at
Expand Down
4 changes: 3 additions & 1 deletion store/backend/neurostore/resources/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,10 @@ class PipelineStudyResultsView(ObjectView, ListView):
"feature_flatten": fields.Bool(load_default=False),
"feature_display": fields.List(
fields.String(),
description="List of pipeline results. format: pipeline_name[:version]",
load_default=[],
metadata={
"description": "List of pipeline results. format: pipeline_name[:version]"
},
),
}

Expand Down
Loading
Loading