Skip to content

Commit f98ac2b

Browse files
committed
added unit tests for events_file_to_annotation_kwargs
1 parent 2942010 commit f98ac2b

File tree

2 files changed

+122
-15
lines changed

2 files changed

+122
-15
lines changed

mne_bids/read.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -523,8 +523,49 @@ def _handle_info_reading(sidecar_fname, raw):
523523
return raw
524524

525525

526-
def _events_file_to_annotation_kwargs(events_fname: str) -> dict:
527-
"""Read the `events.tsv` file and extract onset, duration, and description."""
526+
def events_file_to_annotation_kwargs(events_fname: str | Path) -> dict:
527+
"""
528+
Read the `events.tsv` file and extract onset, duration, and description.
529+
530+
This function reads an events file in TSV format and extracts the onset,
531+
duration, and description of events.
532+
533+
Parameters
534+
----------
535+
events_fname : str
536+
The file path to the `events.tsv` file.
537+
538+
Returns
539+
-------
540+
dict
541+
A dictionary containing the following keys:
542+
- 'onset' : np.ndarray
543+
The onset times of the events in seconds.
544+
- 'duration' : np.ndarray
545+
The durations of the events in seconds.
546+
- 'description' : np.ndarray
547+
The descriptions of the events.
548+
- 'event_id' : dict
549+
A dictionary mapping event descriptions to integer event IDs.
550+
551+
Notes
552+
-----
553+
The function handles the following cases:
554+
- If the `trial_type` column is available, it uses it for event descriptions.
555+
- If the `stim_type` column is available, it uses it for backward compatibility.
556+
- If the `value` column is available, it uses it to create the `event_id`.
557+
- If none of the above columns are available, it defaults to using 'n/a' for
558+
descriptions and 1 for event IDs.
559+
560+
Examples (TBD REWORK THIS)
561+
--------
562+
>>> events_dict = events_file_to_annotation_kwargs('path/to/events.tsv')
563+
>>> print(events_dict['onset'])
564+
[0.1, 0.2, 0.3]
565+
>>> print(events_dict['event_id'])
566+
{'event1': 1, 'event2': 2}
567+
568+
"""
528569
logger.info(f"Reading events from {events_fname}.")
529570
events_dict = _from_tsv(events_fname)
530571

@@ -606,7 +647,7 @@ def _events_file_to_annotation_kwargs(events_fname: str) -> dict:
606647

607648
def _handle_events_reading(events_fname, raw):
608649
"""Read associated events.tsv and convert valid events to annotations on Raw."""
609-
annotations_info = _events_file_to_annotation_kwargs(events_fname)
650+
annotations_info = events_file_to_annotation_kwargs(events_fname)
610651
event_id = annotations_info["event_id"]
611652

612653
# Add events as Annotations, but keep essential Annotations present in raw file

mne_bids/tests/test_read.py

Lines changed: 78 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import mne
1616
import numpy as np
17+
import pandas as pd
1718
import pytest
1819
from mne.datasets import testing
1920
from mne.io.constants import FIFF
@@ -32,6 +33,7 @@
3233
_handle_events_reading,
3334
_handle_scans_reading,
3435
_read_raw,
36+
events_file_to_annotation_kwargs,
3537
get_head_mri_trans,
3638
read_raw_bids,
3739
)
@@ -855,9 +857,7 @@ def test_handle_chpi_reading(tmp_path):
855857
meg_json_data_freq_mismatch["HeadCoilFrequency"][0] = 123
856858
_write_json(meg_json_path, meg_json_data_freq_mismatch, overwrite=True)
857859

858-
with (
859-
pytest.warns(RuntimeWarning, match="Defaulting to .* mne.Raw object"),
860-
):
860+
with (pytest.warns(RuntimeWarning, match="Defaulting to .* mne.Raw object"),):
861861
raw_read = read_raw_bids(bids_path, extra_params=dict(allow_maxshield="yes"))
862862

863863
# cHPI "off" according to sidecar, but present in the data
@@ -1078,9 +1078,7 @@ def test_handle_ieeg_coords_reading(bids_path, tmp_path):
10781078
_to_tsv(electrodes_dict, electrodes_fname)
10791079
# popping off channels should not result in an error
10801080
# however, a warning will be raised through mne-python
1081-
with (
1082-
pytest.warns(RuntimeWarning, match="DigMontage is only a subset of info"),
1083-
):
1081+
with (pytest.warns(RuntimeWarning, match="DigMontage is only a subset of info"),):
10841082
read_raw_bids(bids_path=bids_fname, verbose=False)
10851083

10861084
# make sure montage is set if there are coordinates w/ 'n/a'
@@ -1096,9 +1094,7 @@ def test_handle_ieeg_coords_reading(bids_path, tmp_path):
10961094
# electrode coordinates should be nan
10971095
# when coordinate is 'n/a'
10981096
nan_chs = [electrodes_dict["name"][i] for i in [0, 3]]
1099-
with (
1100-
pytest.warns(RuntimeWarning, match="There are channels without locations"),
1101-
):
1097+
with (pytest.warns(RuntimeWarning, match="There are channels without locations"),):
11021098
raw = read_raw_bids(bids_path=bids_fname, verbose=False)
11031099
for idx, ch in enumerate(raw.info["chs"]):
11041100
if ch["ch_name"] in nan_chs:
@@ -1226,9 +1222,7 @@ def test_handle_non_mne_channel_type(tmp_path):
12261222
channels_data["type"][ch_idx] = "FOOBAR"
12271223
_to_tsv(data=channels_data, fname=channels_tsv_path)
12281224

1229-
with (
1230-
pytest.warns(RuntimeWarning, match='will be set to "misc"'),
1231-
):
1225+
with (pytest.warns(RuntimeWarning, match='will be set to "misc"'),):
12321226
raw = read_raw_bids(bids_path)
12331227

12341228
# Should be a 'misc' channel.
@@ -1466,3 +1460,75 @@ def test_gsr_and_temp_reading():
14661460
raw = read_raw_bids(bids_path)
14671461
assert raw.get_channel_types(["GSR"]) == ["gsr"]
14681462
assert raw.get_channel_types(["Temperature"]) == ["temperature"]
1463+
1464+
1465+
def test_events_file_to_annotation_kwargs(tmp_path):
1466+
bids_path = BIDSPath(
1467+
subject="01", session="eeg", task="rest", datatype="eeg", root=tiny_bids_root
1468+
)
1469+
events_fname = _find_matching_sidecar(bids_path, suffix="events", extension=".tsv")
1470+
1471+
# ---------------- plain read --------------------------------------------
1472+
df = pd.read_csv(events_fname, sep="\t")
1473+
ev_kwargs = events_file_to_annotation_kwargs(events_fname=events_fname)
1474+
assert (ev_kwargs["onset"] == df["onset"].values).all()
1475+
assert (ev_kwargs["duration"] == df["duration"].values).all()
1476+
assert (ev_kwargs["description"] == df["trial_type"].values).all()
1477+
1478+
# ---------------- filtering out n/a values ------------------------------
1479+
tmp_tsv_file = tmp_path / "events.tsv"
1480+
dext = pd.concat(
1481+
[df.copy().assign(onset=df.onset + i) for i in range(5)]
1482+
).reset_index(drop=True)
1483+
1484+
dext = dext.assign(
1485+
ix=range(len(dext)),
1486+
value=dext.trial_type.map({"start_experiment": 1, "show_stimulus": 2}),
1487+
duration=1.0,
1488+
)
1489+
1490+
# nan values for `_drop` must be string values, `_drop` is called on
1491+
# `onset`, `value` and `trial_type`. `duration` n/a should end up as float 0
1492+
for c in ["onset", "value", "trial_type", "duration"]:
1493+
dext[c] = dext[c].astype(str)
1494+
1495+
dext.loc[0, "onset"] = "n/a"
1496+
dext.loc[1, "duration"] = "n/a"
1497+
dext.loc[4, "trial_type"] = "n/a"
1498+
dext.loc[4, "value"] = (
1499+
"n/a" # to check that filtering is also applied when we drop the `trial_type`
1500+
)
1501+
dext.to_csv(tmp_tsv_file, sep="\t", index=False)
1502+
1503+
ev_kwargs_filtered = events_file_to_annotation_kwargs(events_fname=tmp_tsv_file)
1504+
1505+
dext_f = dext[
1506+
(dext["onset"] != "n/a")
1507+
& (dext["trial_type"] != "n/a")
1508+
& (dext["value"] != "n/a")
1509+
]
1510+
1511+
assert (ev_kwargs_filtered["onset"] == dext_f["onset"].astype(float).values).all()
1512+
assert (
1513+
ev_kwargs_filtered["duration"]
1514+
== dext_f["duration"].replace("n/a", "0.0").astype(float).values
1515+
).all()
1516+
assert (ev_kwargs_filtered["description"] == dext_f["trial_type"].values).all()
1517+
assert (
1518+
ev_kwargs_filtered["duration"][0] == 0.0
1519+
) # now idx=0, as first row is filtered out
1520+
1521+
# ---------------- default if missing trial_type ------------------------
1522+
tmp_tsv_file = tmp_path / "events.tsv"
1523+
dext.drop(columns="trial_type").to_csv(tmp_tsv_file, sep="\t", index=False)
1524+
1525+
ev_kwargs_default = events_file_to_annotation_kwargs(events_fname=tmp_tsv_file)
1526+
assert (ev_kwargs_default["onset"] == dext_f["onset"].astype(float).values).all()
1527+
assert (
1528+
ev_kwargs_default["duration"]
1529+
== dext_f["duration"].replace("n/a", "0.0").astype(float).values
1530+
).all()
1531+
assert (
1532+
np.sort(np.unique(ev_kwargs_default["description"]))
1533+
== np.sort(dext_f["value"].unique())
1534+
).all()

0 commit comments

Comments
 (0)