Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 37 additions & 7 deletions compose_runner/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import json
import io
import pickle
from datetime import date, datetime
from importlib import import_module
from pathlib import Path
from uuid import UUID

import requests
import neurosynth_compose_sdk
Expand Down Expand Up @@ -227,7 +229,7 @@ def _get_entity_snapshot_record(self, entity_name, documents):
# Old API format: list of {id, md5} snapshot summaries
ref_document = document.get(ref_key)
if isinstance(ref_document, dict):
for summary_document in (ref_document.get(summary_key) or []):
for summary_document in ref_document.get(summary_key) or []:
snapshot_id = self._extract_document_id(summary_document)
if snapshot_id is not None:
break
Expand All @@ -239,9 +241,11 @@ def _get_entity_snapshot_record(self, entity_name, documents):
id=snapshot_id
).to_dict()
else:
snapshot_document = self.compose_api.snapshot_annotations_id_get(
id=snapshot_id
).to_dict()
snapshot_document = (
self.compose_api.snapshot_annotations_id_get(
id=snapshot_id
).to_dict()
)
except ComposeApiException:
continue
payload = self._unwrap_snapshot(snapshot_document)
Expand Down Expand Up @@ -357,12 +361,36 @@ def _apply_entity_records(self, records):
self.existing_annotation_snapshot_id = records["annotation"]["snapshot_id"]

@staticmethod
def _snapshot_md5(payload):
serialized_payload = json.dumps(
def _json_payload_default(value):
if isinstance(value, (date, datetime)):
return value.isoformat()
if isinstance(value, UUID):
return str(value)
if isinstance(value, set):
return sorted(value, key=str)
to_dict = getattr(value, "to_dict", None)
if callable(to_dict):
return to_dict()
raise TypeError(
f"Object of type {value.__class__.__name__} is not JSON serializable"
)

@classmethod
def _snapshot_json(cls, payload):
return json.dumps(
payload,
default=cls._json_payload_default,
sort_keys=True,
separators=(",", ":"),
)

@classmethod
def _json_safe_payload(cls, payload):
return json.loads(cls._snapshot_json(payload))

@classmethod
def _snapshot_md5(cls, payload):
serialized_payload = cls._snapshot_json(payload)
return hashlib.md5(serialized_payload.encode("utf-8")).hexdigest()

def _should_link_existing_snapshot(
Expand Down Expand Up @@ -587,7 +615,9 @@ def create_result_object(self):
):
kwargs[f"snapshot_{entity_name}_id"] = existing_id
else:
kwargs[f"snapshot_{entity_name}"] = live_payload
kwargs[f"snapshot_{entity_name}"] = self._json_safe_payload(
live_payload
)

self._compose_config.api_key["upload_key"] = self.nsc_key
result = self.compose_api.meta_analysis_results_post(
Expand Down
73 changes: 73 additions & 0 deletions compose_runner/tests/test_run.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from datetime import date, datetime, timezone
from uuid import UUID

import pytest
from neurosynth_compose_sdk.exceptions import ApiException as ComposeApiException

Expand Down Expand Up @@ -27,6 +30,76 @@ def test_download_bundle():
assert runner.cached_specification is not None


def test_snapshot_md5_serializes_sdk_scalars_like_api_strings():
    """A payload of native scalars must hash like its string-valued twin."""
    timestamp = datetime(2023, 6, 19, 15, 29, 59, 132810, tzinfo=timezone.utc)

    # Payload holding datetime/UUID/date objects.
    native_study = {
        "created_at": timestamp,
        "publication_date": date(2023, 6, 19),
    }
    native_payload = {
        "created_at": timestamp,
        "id": UUID("00000000-0000-0000-0000-000000000001"),
        "studies": [native_study],
    }

    # The same payload with every scalar already rendered as a string.
    string_study = {
        "created_at": "2023-06-19T15:29:59.132810+00:00",
        "publication_date": "2023-06-19",
    }
    string_payload = {
        "created_at": "2023-06-19T15:29:59.132810+00:00",
        "id": "00000000-0000-0000-0000-000000000001",
        "studies": [string_study],
    }

    native_digest = Runner._snapshot_md5(native_payload)
    string_digest = Runner._snapshot_md5(string_payload)
    assert native_digest == string_digest


def test_json_safe_payload_normalizes_datetimes_without_mutating_source():
    """Normalization yields plain JSON values and leaves the input alone."""
    timestamp = datetime(2023, 6, 19, 15, 29, 59, 132810, tzinfo=timezone.utc)
    source = {"created_at": timestamp, "tags": {"b", "a"}}
    expected = {
        "created_at": "2023-06-19T15:29:59.132810+00:00",
        "tags": ["a", "b"],
    }

    result = Runner._json_safe_payload(source)

    assert result == expected
    # The original dict must still hold the very same datetime object.
    assert source["created_at"] is timestamp


def test_create_result_object_normalizes_uploaded_snapshots():
    """Snapshots handed to the results POST contain only JSON-safe values."""
    timestamp = datetime(2023, 6, 19, 15, 29, 59, 132810, tzinfo=timezone.utc)
    posted = {}

    class FakeComposeApi:
        """Stand-in API that records what it is asked to post."""

        def meta_analysis_results_post(self, result_init):
            posted["result_init"] = result_init
            return type("Result", (), {"id": "result-id"})()

    runner = Runner(meta_analysis_id="meta-id", environment="production")
    runner.compose_api = FakeComposeApi()
    runner.cached_studyset = {
        "id": UUID("00000000-0000-0000-0000-000000000001"),
        "created_at": timestamp,
        "studies": [],
    }
    runner.cached_annotation = {
        "created_at": timestamp,
        "notes": [],
    }

    runner.create_result_object()

    expected_studyset = {
        "created_at": "2023-06-19T15:29:59.132810+00:00",
        "id": "00000000-0000-0000-0000-000000000001",
        "studies": [],
    }
    expected_annotation = {
        "created_at": "2023-06-19T15:29:59.132810+00:00",
        "notes": [],
    }
    uploaded = posted["result_init"]
    assert uploaded.snapshot_studyset == expected_studyset
    assert uploaded.snapshot_annotation == expected_annotation
    assert runner.result_id == "result-id"


@pytest.mark.vcr
def test_run_workflow():
runner = Runner(
Expand Down
Loading