Skip to content

Commit a6f4c70

Browse files
Fix SetRawAnnotations when no stim channel (#838)
* First try to fix * Add test * [pre-commit.ci] auto fixes from pre-commit.com hooks * Whats new * enforce desired event ids when creating events from annotations --------- Signed-off-by: Pierre Guetschel <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 77acb23 commit a6f4c70

File tree

3 files changed

+32
-4
lines changed

3 files changed

+32
-4
lines changed

docs/source/whats_new.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ Enhancements
4545
Bugs
4646
- Fixes the management of include/exclude datasets in :func:`moabb.benchmark`, adds additional verifications (:gh:`834` by `Anton Andreev`_)
4747
- Fixing pagination issue with figshare (:gh:`839` by `Bruno Aristimunha`_)
48+
- Fixes :class:`moabb.datasets.preprocessing.SetRawAnnotations` in case no STIM channel is present (:gh:`838` by `Pierre Guetschel`_ and `Simon Kojima`_)
4849

4950
~~~~
5051
API changes

moabb/datasets/preprocessing.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,20 @@ def transform(self, raw, y=None):
116116
offset = int(self.interval[0] * raw.info["sfreq"])
117117
stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False)
118118
if len(stim_channels) == 0:
119-
log.warning(
120-
"No stim channel nor annotations found, skipping setting annotations."
119+
if raw.annotations is None:
120+
log.warning(
121+
"No stim channel nor annotations found, skipping setting annotations."
122+
)
123+
return raw
124+
if not all(isinstance(mrk, int) for mrk in self.event_id.values()):
125+
raise ValueError(
126+
"When no stim channel is present, event_id values must be integers (not lists)."
127+
)
128+
events, _ = mne.events_from_annotations(
129+
raw, event_id=self.event_id, verbose=False
121130
)
122-
return raw
123-
events = mne.find_events(raw, shortest_event=0, verbose=False)
131+
else:
132+
events = mne.find_events(raw, shortest_event=0, verbose=False)
124133
events = _unsafe_pick_events(events, include=_get_event_id_values(self.event_id))
125134
events[:, 0] += offset
126135
if len(events) != 0:

moabb/tests/test_paradigms.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,24 @@ def used_events(self, dataset):
4343

4444

4545
class TestMotorImagery(unittest.TestCase):
46+
47+
def test_paradigm_interval(self):
48+
paradigm = SimpleMotorImagery()
49+
for use_annotations in [True, False]:
50+
dataset = FakeDataset(
51+
paradigm="imagery", annotations=use_annotations, seed=12
52+
)
53+
dataset.interval = (0, 4)
54+
# checking that the random data generation is consistent for the same interval
55+
X1, _, _ = paradigm.get_data(dataset, subjects=[1])
56+
X2, _, _ = paradigm.get_data(dataset, subjects=[1])
57+
np.testing.assert_array_equal(X1, X2)
58+
59+
dataset.interval = (1, 5)
60+
# checking that changing the interval changes the data selected
61+
X3, _, _ = paradigm.get_data(dataset, subjects=[1])
62+
assert not np.array_equal(X1, X3)
63+
4664
def test_BaseImagery_paradigm(self):
4765
paradigm = SimpleMotorImagery()
4866
dataset = FakeDataset(paradigm="imagery")

0 commit comments

Comments
 (0)