Skip to content

Commit ad0be44

Browse files
drpedapatiscott-hubertychristian-oreilly
authored
adding format parameter to save function (lina-usc#159)
* Update pipeline.py added format parameter to pipeline save * Update pipeline.py updated documentation * Update pipeline.py added event-id * add method to get event_ids * Update pylossless/pipeline.py Co-authored-by: Scott Huberty <[email protected]> * Update pylossless/pipeline.py Co-authored-by: Scott Huberty <[email protected]> * Update pylossless/pipeline.py Co-authored-by: Scott Huberty <[email protected]> * Update pylossless/pipeline.py Co-authored-by: Scott Huberty <[email protected]> * Update pylossless/pipeline.py Co-authored-by: Scott Huberty <[email protected]> * Update pylossless/pipeline.py Co-authored-by: Scott Huberty <[email protected]> * Update pylossless/pipeline.py Co-authored-by: Scott Huberty <[email protected]> * Update pylossless/pipeline.py Co-authored-by: Scott Huberty <[email protected]> * Update pylossless/pipeline.py Co-authored-by: Scott Huberty <[email protected]> * FIX: flake * Update pylossless/pipeline.py - stylistic change --------- Co-authored-by: Scott Huberty <[email protected]> Co-authored-by: Scott Huberty <[email protected]> Co-authored-by: Christian O'Reilly <[email protected]>
1 parent fe61bc3 commit ad0be44

File tree

1 file changed

+50
-2
lines changed

1 file changed

+50
-2
lines changed

pylossless/pipeline.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,7 +1028,7 @@ def flag_noisy_ics(self):
10281028

10291029
# icsd_epoch_flags=padflags(raw, icsd_epoch_flags,1,'value',.5);
10301030

1031-
def save(self, derivatives_path, overwrite=False):
1031+
def save(self, derivatives_path, overwrite=False, format="EDF", event_id=None):
10321032
"""Save the file at the end of the pipeline.
10331033
10341034
Parameters
@@ -1037,13 +1037,19 @@ def save(self, derivatives_path, overwrite=False):
10371037
path of the derivatives folder to save the file to.
10381038
overwrite : bool (default False)
10391039
whether to overwrite existing files with the same name.
1040+
format : str (default "EDF")
1041+
The format to use for saving the raw data. Can be ``"auto"``,
1042+
``"FIF"``, ``"EDF"``, ``"BrainVision"``, ``"EEGLAB"``.
1043+
event_id : dict | None (default None)
1044+
Dictionary mapping annotation descriptions to event codes.
10401045
"""
10411046
mne_bids.write_raw_bids(
10421047
self.raw,
10431048
derivatives_path,
10441049
overwrite=overwrite,
1045-
format="EDF",
1050+
format=format,
10461051
allow_preload=True,
1052+
event_id=event_id,
10471053
)
10481054
# TODO: address derivatives support in MNE bids.
10491055
# use shutils ( or pathlib?) to rename file with ll suffix
@@ -1234,3 +1240,45 @@ def get_derivative_path(self, bids_path, derivative_name="pylossless"):
12341240
return bids_path.copy().update(
12351241
suffix=lossless_suffix, root=lossless_root, check=False
12361242
)
1243+
1244+
def get_all_event_ids(self):
1245+
"""
1246+
Get a combined event ID dictionary from existing markers and raw annotations.
1247+
1248+
Returns
1249+
-------
1250+
dict or None
1251+
A combined dictionary of event IDs, including both existing markers
1252+
and new ones from annotations.
1253+
Returns ``None`` if no events or annotations are found.
1254+
"""
1255+
try:
1256+
# Get existing events and their IDs
1257+
event_id = mne.events_from_annotations(self.raw)[1]
1258+
except ValueError as e:
1259+
warn(f"Warning: No events found in raw data. Error: {e}")
1260+
event_id = {}
1261+
1262+
# Check if there are any annotations
1263+
if len(self.raw.annotations) == 0 and not event_id:
1264+
warn("Warning: No events or annotations found in the raw data.")
1265+
return
1266+
1267+
# Initialize the combined event ID dictionary with existing events
1268+
combined_event_id = event_id.copy()
1269+
1270+
# Determine the starting ID for new annotations
1271+
start_id = max(combined_event_id.values()) + 1 if combined_event_id else 1
1272+
1273+
# Get unique annotations and add new event IDs
1274+
for desc in set(self.raw.annotations.description):
1275+
if desc not in combined_event_id:
1276+
combined_event_id[desc] = start_id
1277+
start_id += 1
1278+
1279+
# Final check to ensure we have at least one event
1280+
if not combined_event_id:
1281+
warn("Warning: No valid events or annotations could be processed.")
1282+
return
1283+
1284+
return combined_event_id

0 commit comments

Comments
 (0)