|
1 | 1 | import string |
| 2 | +from copy import deepcopy |
2 | 3 | from sqlalchemy import func, text |
| 4 | +from sqlalchemy.exc import SQLAlchemyError |
3 | 5 |
|
4 | 6 | from pgvector.sqlalchemy import Vector |
5 | 7 | from flask import request |
@@ -87,6 +89,7 @@ class StudysetsView(ObjectView, ListView): |
87 | 89 | _view_fields = { |
88 | 90 | **LIST_CLONE_ARGS, |
89 | 91 | **LIST_NESTED_ARGS, |
| 92 | + "copy_annotations": fields.Boolean(load_default=True), |
90 | 93 | } |
91 | 94 | # reorg int o2m and m2o |
92 | 95 | _o2m = {"studies": "StudiesView", "annotations": "AnnotationsView"} |
@@ -213,6 +216,157 @@ def serialize_records(self, records, args): |
213 | 216 | return content |
214 | 217 | return super().serialize_records(records, args) |
215 | 218 |
|
| 219 | + def post(self): |
| 220 | + args = parser.parse(self._user_args, request, location="query") |
| 221 | + copy_annotations = args.pop("copy_annotations", True) |
| 222 | + source_id = args.get("source_id") |
| 223 | + |
| 224 | + if not source_id: |
| 225 | + return super().post() |
| 226 | + |
| 227 | + source = args.get("source") or "neurostore" |
| 228 | + if source != "neurostore": |
| 229 | + field_err = make_field_error("source", source, valid_options=["neurostore"]) |
| 230 | + abort_unprocessable( |
| 231 | + "invalid source, choose from: 'neurostore'", [field_err] |
| 232 | + ) |
| 233 | + |
| 234 | + unknown = self.__class__._schema.opts.unknown |
| 235 | + data = parser.parse( |
| 236 | + self.__class__._schema(exclude=("id",)), request, unknown=unknown |
| 237 | + ) |
| 238 | + |
| 239 | + clone_payload, source_record = self._build_clone_payload(source_id, data) |
| 240 | + |
| 241 | + # ensure nested serialization when cloning |
| 242 | + args["nested"] = bool(args.get("nested") or request.args.get("source_id")) |
| 243 | + |
| 244 | + with db.session.no_autoflush: |
| 245 | + record = self.__class__.update_or_create(clone_payload) |
| 246 | + |
| 247 | + unique_ids = self.get_affected_ids([record.id]) |
| 248 | + clear_cache(unique_ids) |
| 249 | + |
| 250 | + db.session.flush() |
| 251 | + |
| 252 | + self.update_base_studies(unique_ids.get("base-studies")) |
| 253 | + |
| 254 | + try: |
| 255 | + if copy_annotations: |
| 256 | + self._clone_annotations(source_record, record) |
| 257 | + self.update_annotations(unique_ids.get("annotations")) |
| 258 | + except SQLAlchemyError as e: |
| 259 | + db.session.rollback() |
| 260 | + abort_validation(str(e)) |
| 261 | + |
| 262 | + response_context = dict(args) |
| 263 | + response = self.__class__._schema(context=response_context).dump(record) |
| 264 | + |
| 265 | + db.session.commit() |
| 266 | + |
| 267 | + return response |
| 268 | + |
| 269 | + def _build_clone_payload(self, source_id, override_data): |
| 270 | + source_record = ( |
| 271 | + Studyset.query.options( |
| 272 | + selectinload(Studyset.studies), |
| 273 | + selectinload(Studyset.annotations).options( |
| 274 | + selectinload(Annotation.annotation_analyses) |
| 275 | + ), |
| 276 | + ) |
| 277 | + .filter_by(id=source_id) |
| 278 | + .first() |
| 279 | + ) |
| 280 | + |
| 281 | + if source_record is None: |
| 282 | + abort_not_found(Studyset.__name__, source_id) |
| 283 | + |
| 284 | + payload = { |
| 285 | + "name": source_record.name, |
| 286 | + "description": source_record.description, |
| 287 | + "publication": source_record.publication, |
| 288 | + "doi": source_record.doi, |
| 289 | + "pmid": source_record.pmid, |
| 290 | + "authors": source_record.authors, |
| 291 | + "metadata_": ( |
| 292 | + deepcopy(source_record.metadata_) |
| 293 | + if source_record.metadata_ is not None |
| 294 | + else None |
| 295 | + ), |
| 296 | + "public": source_record.public, |
| 297 | + "studies": [{"id": study.id} for study in source_record.studies], |
| 298 | + "source": "neurostore", |
| 299 | + "source_id": self._resolve_neurostore_origin(source_record), |
| 300 | + "source_updated_at": source_record.updated_at or source_record.created_at, |
| 301 | + } |
| 302 | + |
| 303 | + if payload.get("metadata_") is None: |
| 304 | + payload.pop("metadata_", None) |
| 305 | + |
| 306 | + if override_data: |
| 307 | + payload.update(override_data) |
| 308 | + |
| 309 | + return payload, source_record |
| 310 | + |
| 311 | + def _clone_annotations(self, source_record, cloned_record): |
| 312 | + if not source_record.annotations: |
| 313 | + return |
| 314 | + |
| 315 | + owner_id = cloned_record.user_id |
| 316 | + |
| 317 | + for annotation in source_record.annotations: |
| 318 | + clone_annotation = Annotation( |
| 319 | + name=annotation.name, |
| 320 | + description=annotation.description, |
| 321 | + source="neurostore", |
| 322 | + source_id=self._resolve_neurostore_origin(annotation), |
| 323 | + source_updated_at=annotation.updated_at or annotation.created_at, |
| 324 | + user_id=owner_id, |
| 325 | + metadata_=( |
| 326 | + deepcopy(annotation.metadata_) if annotation.metadata_ else None |
| 327 | + ), |
| 328 | + public=annotation.public, |
| 329 | + note_keys=( |
| 330 | + deepcopy(annotation.note_keys) if annotation.note_keys else {} |
| 331 | + ), |
| 332 | + ) |
| 333 | + clone_annotation.studyset = cloned_record |
| 334 | + db.session.add(clone_annotation) |
| 335 | + db.session.flush() |
| 336 | + |
| 337 | + analyses_to_create = [] |
| 338 | + for aa in annotation.annotation_analyses: |
| 339 | + analyses_to_create.append( |
| 340 | + AnnotationAnalysis( |
| 341 | + annotation_id=clone_annotation.id, |
| 342 | + analysis_id=aa.analysis_id, |
| 343 | + note=deepcopy(aa.note) if aa.note else {}, |
| 344 | + user_id=owner_id, |
| 345 | + study_id=aa.study_id, |
| 346 | + studyset_id=cloned_record.id, |
| 347 | + ) |
| 348 | + ) |
| 349 | + |
| 350 | + if analyses_to_create: |
| 351 | + db.session.add_all(analyses_to_create) |
| 352 | + |
| 353 | + @staticmethod |
| 354 | + def _resolve_neurostore_origin(record): |
| 355 | + source_id = record.id |
| 356 | + parent_source_id = record.source_id |
| 357 | + parent_source = getattr(record, "source", None) |
| 358 | + Model = type(record) |
| 359 | + |
| 360 | + while parent_source_id is not None and parent_source == "neurostore": |
| 361 | + source_id = parent_source_id |
| 362 | + parent = Model.query.filter_by(id=parent_source_id).first() |
| 363 | + if parent is None: |
| 364 | + break |
| 365 | + parent_source_id = parent.source_id |
| 366 | + parent_source = getattr(parent, "source", None) |
| 367 | + |
| 368 | + return source_id |
| 369 | + |
216 | 370 |
|
217 | 371 | @view_maker |
218 | 372 | class AnnotationsView(ObjectView, ListView): |
|
0 commit comments