Skip to content

Commit 27678e8

Browse files
authored
[ENH] make loading large studysets faster (#1007)
* make loading large studysets faster
* respond to review
1 parent f03b762 commit 27678e8

3 files changed

Lines changed: 152 additions & 78 deletions

File tree

nimare/_studyset_store.py

Lines changed: 99 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -136,6 +136,14 @@ def _cached_default_masker(target):
136136

137137
def _apply_annotation_payloads(source_dict, annotation_payloads):
138138
"""Apply top-level annotation notes into analysis-level annotation dictionaries."""
139+
if annotation_payloads is None:
140+
return source_dict
141+
142+
annotation_payloads = _coerce_annotation_payloads(annotation_payloads)
143+
if not annotation_payloads:
144+
source_dict["annotations"] = []
145+
return source_dict
146+
139147
analysis_map = {}
140148
for study in source_dict.get("studies", []):
141149
for analysis in study.get("analyses", []):
@@ -247,12 +255,27 @@ def _build_tables_from_source(source_dict):
247255
studies_rows = []
248256
analyses_rows = []
249257
ids = []
250-
coordinate_rows = []
251258
image_rows = []
252259
metadata_rows = []
253260
annotation_rows = []
254261
text_rows = []
255262

263+
# Coordinate rows: collected as parallel column-arrays for fast DataFrame construction.
264+
# POINT_RELATIONSHIP_COLUMNS are collected as lists and only added when non-all-None.
265+
# The resulting DataFrame is canonicalized by a stable sort on 'id', making
266+
# coordinate ordering explicit while preserving original order for rows with identical ids.
267+
coord_ids_acc: list = []
268+
coord_study_ids_acc: list = []
269+
coord_contrast_ids_acc: list = []
270+
coord_xs: list = []
271+
coord_ys: list = []
272+
coord_zs: list = []
273+
coord_spaces: list = []
274+
prc_lists: dict = {col: [] for col in POINT_RELATIONSHIP_COLUMNS}
275+
prc_seen: dict = {col: False for col in POINT_RELATIONSHIP_COLUMNS}
276+
# Truly sparse extras (point values, coordinate_metadata): (row_index, {col: val})
277+
coord_sparse_extras: list = []
278+
256279
for study in source_dict.get("studies", []):
257280
study_id = str(study["id"])
258281
studies_rows.append(
@@ -286,29 +309,30 @@ def _build_tables_from_source(source_dict):
286309
"journal": study.get("publication", ""),
287310
"name": f"{study_name}-{analysis_name}",
288311
}
289-
study_metadata = study.get("metadata", {}) or {}
290-
analysis_metadata = analysis.get("metadata", {}) or {}
312+
study_metadata = study.get("metadata") or {}
313+
analysis_metadata = analysis.get("metadata") or {}
291314
coordinate_metadata, coordinate_metadata_keys = _extract_coordinate_row_metadata(
292315
analysis_metadata,
293316
len(analysis.get("points", []) or []),
294317
)
295-
combined_metadata = copy.deepcopy(study_metadata)
296-
combined_metadata.update(copy.deepcopy(analysis_metadata))
297-
combined_metadata.pop("sample_sizes", None)
298-
combined_metadata.pop("sample_size", None)
299-
for key in coordinate_metadata_keys:
300-
combined_metadata.pop(key, None)
301-
sample_sizes = _extract_coerced_sample_sizes(
302-
[
303-
("sample_sizes", analysis_metadata.get("sample_sizes")),
304-
("sample_size", analysis_metadata.get("sample_size")),
305-
("sample_sizes", study_metadata.get("sample_sizes")),
306-
("sample_size", study_metadata.get("sample_size")),
307-
]
308-
)
309-
if sample_sizes:
310-
combined_metadata["sample_sizes"] = sample_sizes
311-
metadata_row.update(combined_metadata)
318+
if study_metadata or analysis_metadata:
319+
combined_metadata = copy.deepcopy(study_metadata)
320+
combined_metadata.update(copy.deepcopy(analysis_metadata))
321+
combined_metadata.pop("sample_sizes", None)
322+
combined_metadata.pop("sample_size", None)
323+
for key in coordinate_metadata_keys:
324+
combined_metadata.pop(key, None)
325+
sample_sizes = _extract_coerced_sample_sizes(
326+
[
327+
("sample_sizes", analysis_metadata.get("sample_sizes")),
328+
("sample_size", analysis_metadata.get("sample_size")),
329+
("sample_sizes", study_metadata.get("sample_sizes")),
330+
("sample_size", study_metadata.get("sample_size")),
331+
]
332+
)
333+
if sample_sizes:
334+
combined_metadata["sample_sizes"] = sample_sizes
335+
metadata_row.update(combined_metadata)
312336
metadata_rows.append(metadata_row)
313337

314338
annotation_row = dict(base_row)
@@ -320,7 +344,9 @@ def _build_tables_from_source(source_dict):
320344
annotation_rows.append(annotation_row)
321345

322346
text_row = dict(base_row)
323-
text_row.update(copy.deepcopy(analysis.get("texts", {}) or {}))
347+
texts = analysis.get("texts") or {}
348+
if texts:
349+
text_row.update(copy.deepcopy(texts))
324350
text_rows.append(text_row)
325351

326352
image_row = dict(base_row)
@@ -340,37 +366,68 @@ def _build_tables_from_source(source_dict):
340366

341367
for i_point, point in enumerate(analysis.get("points", []) or []):
342368
coords = point.get("coordinates", [None, None, None])
343-
coordinate_row = {
344-
**base_row,
345-
"x": float(coords[0]),
346-
"y": float(coords[1]),
347-
"z": float(coords[2]),
348-
"space": point.get("space"),
349-
}
350-
for column in POINT_RELATIONSHIP_COLUMNS:
351-
value = point.get(column)
352-
if value is not None:
353-
coordinate_row[column] = value
354-
369+
coord_ids_acc.append(full_id)
370+
coord_study_ids_acc.append(study_id)
371+
coord_contrast_ids_acc.append(contrast_id)
372+
coord_xs.append(coords[0])
373+
coord_ys.append(coords[1])
374+
coord_zs.append(coords[2])
375+
coord_spaces.append(point.get("space"))
376+
377+
for col in POINT_RELATIONSHIP_COLUMNS:
378+
val = point.get(col)
379+
prc_lists[col].append(val)
380+
if val is not None:
381+
prc_seen[col] = True
382+
383+
extra: dict = {}
355384
for point_value in point.get("values", []) or []:
356385
if not isinstance(point_value, dict):
357386
continue
358387
column = _point_value_kind_to_coordinate_column(point_value.get("kind"))
359388
value = point_value.get("value")
360389
if column is not None and value is not None:
361-
coordinate_row[column] = value
362-
390+
extra[column] = value
363391
for column, values in coordinate_metadata.items():
364-
coordinate_row[column] = values[i_point]
365-
366-
coordinate_rows.append(coordinate_row)
392+
extra[column] = values[i_point]
393+
if extra:
394+
coord_sparse_extras.append((len(coord_ids_acc) - 1, extra))
395+
396+
n_coord = len(coord_ids_acc)
397+
if n_coord:
398+
coord_frame: dict = {
399+
"id": coord_ids_acc,
400+
"study_id": coord_study_ids_acc,
401+
"contrast_id": coord_contrast_ids_acc,
402+
"x": np.asarray(coord_xs, dtype=float),
403+
"y": np.asarray(coord_ys, dtype=float),
404+
"z": np.asarray(coord_zs, dtype=float),
405+
"space": coord_spaces,
406+
}
407+
for col in POINT_RELATIONSHIP_COLUMNS:
408+
if prc_seen[col]:
409+
coord_frame[col] = prc_lists[col]
410+
if coord_sparse_extras:
411+
extra_cols: dict = {}
412+
for row_idx, extra in coord_sparse_extras:
413+
for col, val in extra.items():
414+
if col not in extra_cols:
415+
extra_cols[col] = [None] * n_coord
416+
extra_cols[col][row_idx] = val
417+
coord_frame.update(extra_cols)
418+
id_arr = np.asarray(coord_ids_acc, dtype=str)
419+
coord_frame["id"] = id_arr
420+
sort_order = np.argsort(id_arr, kind="stable")
421+
coord_df = pd.DataFrame(coord_frame).iloc[sort_order].reset_index(drop=True)
422+
else:
423+
coord_df = pd.DataFrame(columns=_ID_COLS + ["x", "y", "z", "space"])
367424

368425
ids = np.sort(np.asarray(ids, dtype=str))
369426
return {
370427
"studies": _rows_to_df(studies_rows, ["study_id", "name", "authors", "publication"]),
371428
"analyses": _rows_to_df(analyses_rows, _ID_COLS + ["name"]),
372429
"ids": ids,
373-
"coordinates": _rows_to_df(coordinate_rows, _ID_COLS + ["x", "y", "z", "space"]),
430+
"coordinates": coord_df,
374431
"images": _rows_to_df(image_rows, _ID_COLS, normalize_none_strings=True),
375432
"metadata": _rows_to_df(metadata_rows, _ID_COLS, normalize_none_strings=True),
376433
"annotations": _rows_to_df(annotation_rows, _ID_COLS, normalize_none_strings=True),
@@ -497,11 +554,8 @@ def from_source_dict(
497554
target,
498555
harmonize_coordinates=harmonize_coordinates,
499556
)
500-
annotation_payloads = (
501-
_coerce_annotation_payloads(annotation_payloads)
502-
if annotation_payloads is not None
503-
else _coerce_annotation_payloads(source_dict.get("annotations", []))
504-
)
557+
if annotation_payloads is None:
558+
annotation_payloads = source_dict.get("annotations", [])
505559
source_dict = _apply_annotation_payloads(source_dict, annotation_payloads)
506560
return cls(
507561
source_dict["id"],
@@ -683,7 +737,7 @@ def selected_source_dict(self, selected_full_ids=None):
683737
}
684738

685739
if selected_full_ids is None:
686-
return copy.deepcopy(source_dict)
740+
return _structural_copy_source_dict(source_dict)
687741

688742
selected_ids = set(self.selected_ids(selected_full_ids).tolist())
689743
if not selected_ids:

nimare/nimads.py

Lines changed: 36 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -166,9 +166,11 @@ def __init__(
166166
if target is _UNSET:
167167
target = "mni152_2mm"
168168

169-
# load source as json
169+
# load source as json; track ownership so the structural copy can be skipped
170+
_owned = False
170171
if isinstance(source, str):
171172
source = load_json(source)
173+
_owned = True # freshly parsed — no other reference exists
172174

173175
_validate_studyset_source(source)
174176

@@ -184,6 +186,7 @@ def __init__(
184186
annotation_payloads=annotation_payloads,
185187
target=target,
186188
harmonize_coordinates=harmonize_coordinates,
189+
_owned=_owned,
187190
)
188191
execution_profile = StudysetExecutionProfile(
189192
target=target,
@@ -929,21 +932,23 @@ def from_sleuth(cls, sleuth_file):
929932

930933
def combine_analyses(self):
931934
"""Combine analyses in Studyset."""
932-
studyset = self.copy()
933-
for study in studyset.studies:
934-
if len(study.analyses) > 1:
935-
source_lst = [analysis.to_dict() for analysis in study.analyses]
936-
ids = [source["id"] for source in source_lst]
937-
names = [source["name"] for source in source_lst]
938-
conditions = [source.get("conditions", []) for source in source_lst]
939-
images = [source.get("images", []) for source in source_lst]
940-
points = [source.get("points", []) for source in source_lst]
941-
weights = [source.get("weights", []) for source in source_lst]
942-
metadata = [source.get("metadata", {}) for source in source_lst]
943-
annotations = [source.get("annotations", {}) for source in source_lst]
944-
texts = [source.get("texts", {}) for source in source_lst]
945-
946-
new_source = {
935+
from nimare._studyset_store import StudysetStore
936+
937+
source_dict = self.to_dict()
938+
for study in source_dict.get("studies", []):
939+
analyses = study.get("analyses", [])
940+
if len(analyses) > 1:
941+
ids = [a["id"] for a in analyses]
942+
names = [a.get("name", "") for a in analyses]
943+
conditions = [a.get("conditions", []) for a in analyses]
944+
images = [a.get("images", []) for a in analyses]
945+
points = [a.get("points", []) for a in analyses]
946+
weights = [a.get("weights", []) for a in analyses]
947+
metadata = [a.get("metadata", {}) for a in analyses]
948+
annotations = [a.get("annotations", {}) for a in analyses]
949+
texts = [a.get("texts", {}) for a in analyses]
950+
951+
new_analysis = {
947952
"id": "_".join(ids),
948953
"name": "; ".join(names),
949954
"conditions": [cond for c_list in conditions for cond in c_list],
@@ -957,16 +962,23 @@ def combine_analyses(self):
957962
}
958963
combined_texts = {k: v for text_dict in texts for k, v in text_dict.items()}
959964
if combined_annotations:
960-
new_source["annotations"] = combined_annotations
965+
new_analysis["annotations"] = combined_annotations
961966
if combined_texts:
962-
new_source["texts"] = combined_texts
963-
study.analyses = [Analysis(new_source)]
967+
new_analysis["texts"] = combined_texts
968+
study["analyses"] = [new_analysis]
964969

965-
# Old Analysis objects are gone; Annotation notes hold dead weak references.
966-
# Clear top-level annotations so touch() can rebuild cleanly.
967-
studyset._annotations = []
968-
studyset.touch()
969-
return studyset
970+
# Drop annotation payloads: they reference pre-merge analysis IDs and
971+
# can't be mapped to the new merged IDs, so clear them (matching the
972+
# original copy+touch behaviour that set _annotations = []).
973+
source_dict.pop("annotations", None)
974+
store = StudysetStore.from_source_dict(
975+
source_dict,
976+
annotation_payloads=[],
977+
target=None,
978+
harmonize_coordinates=False,
979+
_owned=True,
980+
)
981+
return self.__class__._from_store(store, self._copy_execution_profile())
970982

971983
def to_nimads(self, filename):
972984
"""Write the Studyset to a NIMADS JSON file."""

nimare/results.py

Lines changed: 17 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -233,15 +233,23 @@ def save_tables(self, output_dir=".", prefix="", prefix_sep="_", names=None):
233233
else:
234234
LGR.warning(f"Table {tabletype} is None. Not saving.")
235235

236+
def _set_description(self, desc):
237+
self.__description = desc
238+
self.bibtex_ = "" if not desc else get_description_references(desc)
239+
236240
def copy(self):
237241
"""Return copy of result object."""
238-
new = MetaResult(
239-
estimator=self.estimator,
240-
corrector=self.corrector,
241-
diagnostics=self.diagnostics,
242-
mask=self.masker,
243-
maps=copy.deepcopy(self.maps),
244-
tables=copy.deepcopy(self.tables),
245-
description=self.description_,
246-
)
242+
new = object.__new__(MetaResult)
243+
# Deep copy the estimator so that corrected results can update estimator state
244+
# without mutating the original MetaResult or estimator.
245+
new.estimator = copy.deepcopy(self.estimator)
246+
new.corrector = self.corrector
247+
new.diagnostics = self.diagnostics
248+
new.masker = self.masker
249+
new.maps = copy.deepcopy(self.maps)
250+
new.tables = copy.deepcopy(self.tables)
251+
new.metadata = {}
252+
# Bypass the description_ setter (which re-parses bibtex on every call).
253+
# Both attributes are already computed and neither changes after fit.
254+
new._set_description(self.description_)
247255
return new

0 commit comments

Comments
 (0)