Skip to content

Commit bfc7e51

Browse files
first iteration of writing in parallel
1 parent d029e06 commit bfc7e51

File tree

1 file changed

+100
-72
lines changed

1 file changed

+100
-72
lines changed

mne_bids/write.py

Lines changed: 100 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
get_bids_path_from_fname,
4444
read_raw_bids,
4545
)
46-
from mne_bids._fileio import _open_lock
46+
from mne_bids._fileio import _file_lock, _open_lock
4747
from mne_bids.config import (
4848
ALLOWED_DATATYPE_EXTENSIONS,
4949
ALLOWED_INPUT_EXTENSIONS,
@@ -97,6 +97,23 @@
9797
_BTI_SUFFIX_CACHE: dict[str | None, bool] = {}
9898

9999

100+
def _write_tsv_locked(fname: Path | str, data: OrderedDict) -> None:
    """Write TSV data while the caller holds the file lock.

    Parameters
    ----------
    fname : path-like
        Destination ``.tsv`` file; parent directories are created as needed.
    data : collections.OrderedDict
        Mapping of column name to list of cell values. All columns are
        assumed to have the same length (length of the first column is
        used as the row count).
    """
    path = Path(fname)
    header = list(data)
    row_count = 0 if not header else len(data[header[0]])

    # Assemble every output line up front so the file is written in one pass.
    rows = ["\t".join(header)]
    rows.extend(
        "\t".join(str(data[name][idx]) for name in header)
        for idx in range(row_count)
    )

    path.parent.mkdir(parents=True, exist_ok=True)
    # "utf-8-sig" emits a UTF-8 BOM before the content.
    with open(path, "w", encoding="utf-8-sig") as fid:
        fid.write("\n".join(rows))
        fid.write("\n")

    logger.info(f"Writing '{path}'...")
115+
116+
100117
def _is_numeric(n):
101118
return isinstance(n, np.integer | np.floating | int | float)
102119

@@ -574,59 +591,59 @@ def _participants_tsv(raw, subject_id, fname, overwrite=False):
574591

575592
data[key] = new_value
576593

577-
if os.path.exists(fname):
578-
orig_data = _from_tsv(fname)
579-
# whether the new data exists identically in the previous data
580-
exact_included = _contains_row(
581-
data=orig_data,
582-
row_data={
583-
"participant_id": subject_id,
584-
"age": subject_age,
585-
"sex": sex,
586-
"hand": hand,
587-
"weight": weight,
588-
"height": height,
589-
},
590-
)
591-
# whether the subject id is in the previous data
592-
sid_included = subject_id in orig_data["participant_id"]
593-
# if the subject data provided is different to the currently existing
594-
# data and overwrite is not True raise an error
595-
if (sid_included and not exact_included) and not overwrite:
596-
raise FileExistsError(
597-
f'"{subject_id}" already exists in '
598-
f"the participant list. Please set "
599-
f"overwrite to True."
594+
fname = Path(fname)
595+
with _file_lock(fname):
596+
if fname.exists():
597+
orig_data = _from_tsv(fname)
598+
# whether the new data exists identically in the previous data
599+
exact_included = _contains_row(
600+
data=orig_data,
601+
row_data={
602+
"participant_id": subject_id,
603+
"age": subject_age,
604+
"sex": sex,
605+
"hand": hand,
606+
"weight": weight,
607+
"height": height,
608+
},
600609
)
610+
# whether the subject id is in the previous data
611+
sid_included = subject_id in orig_data["participant_id"]
612+
# if the subject data provided is different to the currently
613+
# existing data and overwrite is not True raise an error
614+
if (sid_included and not exact_included) and not overwrite:
615+
raise FileExistsError(
616+
f'"{subject_id}" already exists in '
617+
f"the participant list. Please set "
618+
f"overwrite to True."
619+
)
601620

602-
# Append any columns the original data did not have, and fill them with
603-
# n/a's.
604-
for key in data.keys():
605-
if key in orig_data:
606-
continue
621+
# Append any columns the original data did not have, and fill them
622+
# with n/a's.
623+
for key in data.keys():
624+
if key in orig_data:
625+
continue
607626

608-
orig_data[key] = ["n/a"] * len(orig_data["participant_id"])
627+
orig_data[key] = ["n/a"] * len(orig_data["participant_id"])
609628

610-
# Append any additional columns that original data had.
611-
# Keep the original order of the data by looping over
612-
# the original OrderedDict keys
613-
for key in orig_data.keys():
614-
if key in data:
615-
continue
629+
# Append any additional columns that original data had.
630+
# Keep the original order of the data by looping over
631+
# the original OrderedDict keys
632+
for key in orig_data.keys():
633+
if key in data:
634+
continue
616635

617-
# add original value for any user-appended columns
618-
# that were not handled by mne-bids
619-
p_id = data["participant_id"][0]
620-
if p_id in orig_data["participant_id"]:
621-
row_idx = orig_data["participant_id"].index(p_id)
622-
data[key] = [orig_data[key][row_idx]]
636+
# add original value for any user-appended columns
637+
# that were not handled by mne-bids
638+
p_id = data["participant_id"][0]
639+
if p_id in orig_data["participant_id"]:
640+
row_idx = orig_data["participant_id"].index(p_id)
641+
data[key] = [orig_data[key][row_idx]]
623642

624-
# otherwise add the new data as new row
625-
data = _combine_rows(orig_data, data, "participant_id")
643+
# otherwise add the new data as new row
644+
data = _combine_rows(orig_data, data, "participant_id")
626645

627-
# overwrite is forced to True as all issues with overwrite == False have
628-
# been handled by this point
629-
_write_tsv(fname, data, True)
646+
_write_tsv_locked(fname, data)
630647

631648

632649
def _participants_json(fname, overwrite=False):
@@ -665,13 +682,24 @@ def _participants_json(fname, overwrite=False):
665682
# Note: mne-bids will overwrite age, sex and hand fields
666683
# if `overwrite` is True
667684
fname = Path(fname)
668-
if fname.exists():
669-
orig_data = json.loads(
670-
fname.read_text(encoding="utf-8"), object_pairs_hook=OrderedDict
671-
)
672-
new_data = {**orig_data, **new_data}
685+
with _file_lock(fname):
686+
if fname.exists():
687+
if not overwrite:
688+
raise FileExistsError(
689+
f'"{fname}" already exists. Please set overwrite to True.'
690+
)
691+
orig_data = json.loads(
692+
fname.read_text(encoding="utf-8"), object_pairs_hook=OrderedDict
693+
)
694+
new_data = {**orig_data, **new_data}
673695

674-
_write_json(fname, new_data, overwrite)
696+
fname.parent.mkdir(parents=True, exist_ok=True)
697+
json_output = json.dumps(new_data, indent=4, ensure_ascii=False)
698+
with open(fname, "w", encoding="utf-8") as fid:
699+
fid.write(json_output)
700+
fid.write("\n")
701+
702+
logger.info(f"Writing '{fname}'...")
675703

676704

677705
def _scans_tsv(raw, raw_fname, fname, keep_source, overwrite=False):
@@ -743,29 +771,29 @@ def _scans_tsv(raw, raw_fname, fname, keep_source, overwrite=False):
743771
else:
744772
_write_json(sidecar_json_path, sidecar_json)
745773

746-
if os.path.exists(fname):
747-
orig_data = _from_tsv(fname)
748-
# if the file name is already in the file raise an error
749-
if raw_fname in orig_data["filename"] and not overwrite:
750-
raise FileExistsError(
751-
f'"{raw_fname}" already exists in '
752-
f"the scans list. Please set "
753-
f"overwrite to True."
754-
)
774+
fname = Path(fname)
775+
with _file_lock(fname):
776+
if fname.exists():
777+
orig_data = _from_tsv(fname)
778+
# if the file name is already in the file raise an error
779+
if raw_fname in orig_data["filename"] and not overwrite:
780+
raise FileExistsError(
781+
f'"{raw_fname}" already exists in '
782+
f"the scans list. Please set "
783+
f"overwrite to True."
784+
)
755785

756-
for key in data.keys():
757-
if key in orig_data:
758-
continue
786+
for key in data.keys():
787+
if key in orig_data:
788+
continue
759789

760-
# add 'n/a' if any missing columns
761-
orig_data[key] = ["n/a"] * len(next(iter(data.values())))
790+
# add 'n/a' if any missing columns
791+
orig_data[key] = ["n/a"] * len(next(iter(data.values())))
762792

763-
# otherwise add the new data
764-
data = _combine_rows(orig_data, data, "filename")
793+
# otherwise add the new data
794+
data = _combine_rows(orig_data, data, "filename")
765795

766-
# overwrite is forced to True as all issues with overwrite == False have
767-
# been handled by this point
768-
_write_tsv(fname, data, True)
796+
_write_tsv_locked(fname, data)
769797

770798

771799
def _load_image(image, name="image"):

0 commit comments

Comments (0)