Skip to content

Commit 46f284b

Browse files
authored
optionally return event_id dict from read_raw_bids, and use value column from events file if present (mne-tools#1349)
* optionally return event_id from read_raw_bids * drop n/a values when creating event dict * drop NA-onset events early * comments & cleanup * simplify * bug * clean up / strengthen test * revert introduced bug * changelog * docstring
1 parent 3710a1c commit 46f284b

File tree

3 files changed

+90
-73
lines changed

3 files changed

+90
-73
lines changed

doc/whats_new.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ Detailed list of changes
4646
🪲 Bug fixes
4747
^^^^^^^^^^^^
4848

49-
- Nothing yet
49+
- :func:`mne_bids.read_raw_bids` can optionally return an ``event_id`` dictionary suitable for use with :func:`mne.events_from_annotations`, and if a ``value`` column is present in ``events.tsv`` it will be used as the source of the integer event ID codes, by `Daniel McCloy`_ (:gh:`1349`)
5050

5151
⚕️ Code health
5252
^^^^^^^^^^^^^^

mne_bids/read.py

Lines changed: 51 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -527,89 +527,76 @@ def _handle_info_reading(sidecar_fname, raw):
527527

528528

529529
def _handle_events_reading(events_fname, raw):
530-
"""Read associated events.tsv and populate raw.
531-
532-
Handle onset, duration, and description of each event.
533-
"""
530+
"""Read associated events.tsv and convert valid events to annotations on Raw."""
534531
logger.info(f"Reading events from {events_fname}.")
535532
events_dict = _from_tsv(events_fname)
536533

537-
# Get the descriptions of the events
534+
# drop events where onset is n/a
535+
events_dict = _drop(events_dict, "n/a", "onset")
536+
537+
# Get event descriptions. Use `trial_type` column if available.
538538
if "trial_type" in events_dict:
539539
trial_type_col_name = "trial_type"
540-
elif "stim_type" in events_dict: # Backward-compat with old datasets.
540+
# allow `stim_type` for backward-compat with old datasets.
541+
elif "stim_type" in events_dict:
541542
trial_type_col_name = "stim_type"
542543
warn(
543-
f'The events file, {events_fname}, contains a "stim_type" '
544-
f'column. This column should be renamed to "trial_type" for '
545-
f"BIDS compatibility."
544+
f'The events file, {events_fname}, contains a "stim_type" column. This '
545+
'column should be renamed to "trial_type" for BIDS compatibility.'
546546
)
547+
# If we lack proper event descriptions, perhaps we have at least an event value?
548+
elif "value" in events_dict:
549+
trial_type_col_name = "value"
550+
# Worst case: all events will become `n/a` and all values will be `1`
547551
else:
548552
trial_type_col_name = None
549553

550554
if trial_type_col_name is not None:
551555
# Drop events unrelated to a trial type
552556
events_dict = _drop(events_dict, "n/a", trial_type_col_name)
553-
557+
trial_types = events_dict[trial_type_col_name]
558+
# handle event values (if provided); ensure pairings are 1 value per description
554559
if "value" in events_dict:
555-
# Check whether the `trial_type` <> `value` mapping is unique.
556-
trial_types = events_dict[trial_type_col_name]
557560
values = np.asarray(events_dict["value"], dtype=str)
558561
for trial_type in np.unique(trial_types):
559562
idx = np.where(trial_type == np.atleast_1d(trial_types))[0]
560563
matching_values = values[idx]
561-
562564
if len(np.unique(matching_values)) > 1:
563-
# Event type descriptors are ambiguous; create hierarchical
564-
# event descriptors.
565+
# Event type descriptors are ambiguous; create hierarchical event
566+
# descriptors (to ensure trial_type -> integerID is 1:1)
565567
logger.info(
566-
f'The event "{trial_type}" refers to multiple event '
567-
f"values. Creating hierarchical event names."
568+
f'The event "{trial_type}" refers to multiple event values.'
569+
"Creating hierarchical event names."
568570
)
569571
for ii in idx:
570572
value = values[ii]
571573
value = "na" if value == "n/a" else value
572574
new_name = f"{trial_type}/{value}"
573-
logger.info(
574-
f" Renaming event: {trial_type} -> " f"{new_name}"
575-
)
575+
logger.info(f" Renaming event: {trial_type} -> {new_name}")
576576
trial_types[ii] = new_name
577-
descriptions = np.asarray(trial_types, dtype=str)
577+
# drop rows where `value` is `n/a` & convert remaining `value` to int (only
578+
# when making our `event_id` dict; `value = n/a` doesn't prevent annotation)
579+
culled = _drop(events_dict, "n/a", "value")
580+
event_id = dict(
581+
zip(culled[trial_type_col_name], np.asarray(culled["value"], dtype=int))
582+
)
578583
else:
579-
descriptions = np.asarray(events_dict[trial_type_col_name], dtype=str)
580-
elif "value" in events_dict:
581-
# If we don't have a proper description of the events, perhaps we have
582-
# at least an event value?
583-
# Drop events unrelated to value
584-
events_dict = _drop(events_dict, "n/a", "value")
585-
descriptions = np.asarray(events_dict["value"], dtype=str)
584+
event_id = dict(zip(trial_types, np.arange(len(trial_types))))
585+
descrs = np.asarray(trial_types, dtype=str)
586586

587-
# Worst case, we go with 'n/a' for all events
587+
# Worst case: all events become `n/a` and all values become `1`
588588
else:
589-
descriptions = np.array(["n/a"] * len(events_dict["onset"]), dtype=str)
590-
589+
descrs = np.full(len(events_dict["onset"]), "n/a")
590+
event_id = {descrs[0]: 1}
591591
# Deal with "n/a" strings before converting to float
592-
onsets = np.array(
593-
[np.nan if on == "n/a" else on for on in events_dict["onset"]], dtype=float
594-
)
595-
durations = np.array(
592+
ons = np.asarray(events_dict["onset"], dtype=float)
593+
durs = np.array(
596594
[0 if du == "n/a" else du for du in events_dict["duration"]], dtype=float
597595
)
598596

599-
# Keep only events where onset is known
600-
good_events_idx = ~np.isnan(onsets)
601-
onsets = onsets[good_events_idx]
602-
durations = durations[good_events_idx]
603-
descriptions = descriptions[good_events_idx]
604-
del good_events_idx
605-
606-
# Add events as Annotations, but keep essential Annotations present in
607-
# raw file
597+
# Add events as Annotations, but keep essential Annotations present in raw file
608598
annot_from_raw = raw.annotations.copy()
609-
610-
annot_from_events = mne.Annotations(
611-
onset=onsets, duration=durations, description=descriptions
612-
)
599+
annot_from_events = mne.Annotations(onset=ons, duration=durs, description=descrs)
613600
raw.set_annotations(annot_from_events)
614601

615602
annot_idx_to_keep = [
@@ -622,7 +609,7 @@ def _handle_events_reading(events_fname, raw):
622609
if len(annot_to_keep):
623610
raw.set_annotations(raw.annotations + annot_to_keep)
624611

625-
return raw
612+
return raw, event_id
626613

627614

628615
def _get_bads_from_tsv_data(tsv_data):
@@ -756,7 +743,9 @@ def _handle_channels_reading(channels_fname, raw):
756743

757744

758745
@verbose
759-
def read_raw_bids(bids_path, extra_params=None, verbose=None):
746+
def read_raw_bids(
747+
bids_path, extra_params=None, *, return_event_dict=False, verbose=None
748+
):
760749
"""Read BIDS compatible data.
761750
762751
Will attempt to read associated events.tsv and channels.tsv files to
@@ -781,12 +770,21 @@ def read_raw_bids(bids_path, extra_params=None, verbose=None):
781770
Note that the ``exclude`` parameter, which is supported by some
782771
MNE-Python readers, is not supported; instead, you need to subset
783772
your channels **after** reading.
773+
return_event_dict : bool
774+
Whether to return a dictionary that maps annotation descriptions to integer
775+
event IDs, in addition to the :class:`~mne.io.Raw` object. If a ``value`` column
776+
is present in the ``*_events.tsv`` file, it will be used as the source of the
777+
integer event ID values (events with ``value="n/a"`` will be omitted).
784778
%(verbose)s
785779
786780
Returns
787781
-------
788782
raw : mne.io.Raw
789783
The data as MNE-Python Raw object.
784+
event_id : dict
785+
A mapping from event descriptions to integer event IDs, suitable for,
786+
e.g., passing to :func:`mne.events_from_annotations`. Only returned if
787+
``return_event_dict=True``.
790788
791789
Raises
792790
------
@@ -923,9 +921,8 @@ def read_raw_bids(bids_path, extra_params=None, verbose=None):
923921
events_fname = _find_matching_sidecar(
924922
bids_path, suffix="events", extension=".tsv", on_error=on_error
925923
)
926-
927924
if events_fname is not None:
928-
raw = _handle_events_reading(events_fname, raw)
925+
raw, event_id = _handle_events_reading(events_fname, raw)
929926

930927
# Try to find an associated channels.tsv to get information about the
931928
# status and type of present channels
@@ -989,6 +986,8 @@ def read_raw_bids(bids_path, extra_params=None, verbose=None):
989986
raw.info["subject_info"] = dict()
990987

991988
assert raw.annotations.orig_time == raw.info["meas_date"]
989+
if return_event_dict:
990+
return raw, event_id
992991
return raw
993992

994993

mne_bids/tests/test_read.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -509,8 +509,11 @@ def test_handle_events_reading(tmp_path):
509509
events_fname.parent.mkdir()
510510
_to_tsv(events, events_fname)
511511

512-
raw = _handle_events_reading(events_fname, raw)
513-
events, event_id = mne.events_from_annotations(raw)
512+
raw, event_id = _handle_events_reading(events_fname, raw)
513+
ev_arr, ev_dict = mne.events_from_annotations(raw)
514+
assert list(ev_dict.values()) == [1, 2] # auto-assigned
515+
want = len(events["onset"]) - 1 # one onset was n/a
516+
assert want == len(raw.annotations) == len(ev_arr) == len(ev_dict)
514517

515518
# Test with a `stim_type` column instead of `trial_type`.
516519
events = {
@@ -523,9 +526,24 @@ def test_handle_events_reading(tmp_path):
523526
_to_tsv(events, events_fname)
524527

525528
with pytest.warns(RuntimeWarning, match="This column should be renamed"):
526-
raw = _handle_events_reading(events_fname, raw)
529+
raw, _ = _handle_events_reading(events_fname, raw)
527530
events, event_id = mne.events_from_annotations(raw)
528531

532+
# Test with only a `value` column.
533+
events = {
534+
"onset": [11, 12, 13, 14, 15],
535+
"duration": ["n/a", "n/a", 0.1, 0.1, "n/a"],
536+
"value": [3, 1, 1, 3, "n/a"],
537+
}
538+
events_fname = tmp_path / "bids3" / "sub-01_task-test_events.json"
539+
events_fname.parent.mkdir()
540+
_to_tsv(events, events_fname)
541+
542+
raw, event_id = _handle_events_reading(events_fname, raw)
543+
ev_arr, ev_dict = mne.events_from_annotations(raw, event_id=event_id)
544+
assert len(ev_arr) == len(events["value"]) - 1 # one value was n/a
545+
assert {"1": 1, "3": 3} == event_id == ev_dict
546+
529547
# Test with same `trial_type` referring to different `value`:
530548
# The events should be renamed automatically
531549
events = {
@@ -534,32 +552,32 @@ def test_handle_events_reading(tmp_path):
534552
"trial_type": ["event1", "event1", "event2", "event3", "event3"],
535553
"value": [1, 2, 3, 4, "n/a"],
536554
}
537-
events_fname = tmp_path / "bids3" / "sub-01_task-test_events.json"
555+
events_fname = tmp_path / "bids4" / "sub-01_task-test_events.json"
538556
events_fname.parent.mkdir()
539557
_to_tsv(events, events_fname)
540558

541-
raw = _handle_events_reading(events_fname, raw)
542-
events, event_id = mne.events_from_annotations(raw)
543-
544-
assert len(events) == 5
545-
assert "event1/1" in event_id
546-
assert "event1/2" in event_id
547-
assert "event3/4" in event_id
548-
assert "event3/na" in event_id # 'n/a' value should become 'na'
549-
# The event with unique value mapping should not be renamed
550-
assert "event2" in event_id
559+
raw, event_id = _handle_events_reading(events_fname, raw)
560+
ev_arr, ev_dict = mne.events_from_annotations(raw)
561+
# `event_id` will exclude the last event, as its value is `n/a`, but `ev_dict` won't
562+
# exclude it (it's made from annotations, which don't know about missing `value`s)
563+
assert len(event_id) == len(ev_dict) - 1
564+
# check the renaming
565+
assert len(ev_arr) == 5
566+
assert "event1/1" in ev_dict
567+
assert "event1/2" in ev_dict
568+
assert "event3/4" in ev_dict
569+
assert "event3/na" in ev_dict # 'n/a' value should become 'na'
570+
assert "event2" in ev_dict # has unique value mapping; should not be renamed
551571

552572
# Test without any kind of event description.
553573
events = {"onset": [11, 12, "n/a"], "duration": ["n/a", "n/a", "n/a"]}
554-
events_fname = tmp_path / "bids4" / "sub-01_task-test_events.json"
574+
events_fname = tmp_path / "bids5" / "sub-01_task-test_events.json"
555575
events_fname.parent.mkdir()
556576
_to_tsv(events, events_fname)
557577

558-
raw = _handle_events_reading(events_fname, raw)
559-
events, event_id = mne.events_from_annotations(raw)
560-
ids = list(event_id.keys())
561-
assert len(ids) == 1
562-
assert ids == ["n/a"]
578+
raw, event_id = _handle_events_reading(events_fname, raw)
579+
ev_arr, ev_dict = mne.events_from_annotations(raw)
580+
assert event_id == ev_dict == {"n/a": 1} # fallback behavior
563581

564582

565583
@pytest.mark.filterwarnings(warning_str["channel_unit_changed"])

0 commit comments

Comments
 (0)