Skip to content

Commit 03ec4ac

Browse files
Merge pull request #21 from scott-huberty/antisaccade
ENH: Extend dataset fetcher for the Antisaccade task
2 parents 7a1df0e + 156698e commit 03ec4ac

File tree

3 files changed

+547
-188
lines changed

3 files changed

+547
-188
lines changed

eoglearn/datasets/eegeyenet.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pathlib import Path
2+
from typing import Literal
23

34
import pandas as pd
45

@@ -14,23 +15,31 @@ def _get_params(subject, run):
1415
row = df.loc[(df.subject == subject.upper()) & (df.run == int(run))]
1516
assert len(row) == 1
1617
row = row.T.squeeze()
18+
task = row["task"]
1719
return dict(
1820
url=row["url"],
19-
archive_name=f"{subject}_DOTS{run}_EEG.mat",
20-
folder_name=f"EEGEYENET-Data/dots/{subject}",
21+
archive_name=f"{subject}_{task}{run}_EEG.mat",
22+
folder_name=f"EEGEYENET-Data/{task}/{subject}",
2123
hash=row["hash"],
2224
dataset_name="EEGEYENET")
2325

2426

25-
def get_subjects_runs():
27+
def get_subjects_runs(task: Literal["DOTS", "AS"] = "DOTS"):
2628
"""Get dictionary of {subject: [lists of runs]}.
2729
30+
Parameters
31+
----------
32+
task :
33+
Which EEGEYENET task to extract the subject IDs and runs for. Can be
34+
``"DOTS"`` or ``"AS"`` (antisaccade). Defaults to ``"DOTS"``.
35+
2836
Returns
2937
-------
3038
dict
3139
Dictionary of subjects with the runs as values.
3240
"""
3341
df = _get_urls_df()
42+
df = df.loc[df["task"] == task].copy()
3443
return {subject: df.run.values[df.subject == subject]
3544
for subject in df.subject.unique()}
3645

@@ -54,13 +63,14 @@ def fetch_eegeyenet(subject="EP10", run=1, fetch_dataset_kwargs=None):
5463
pathlib.Path
5564
Path to the downloaded file.
5665
"""
66+
task = _get_task_from_subject_id(subject)
5767
if not fetch_dataset_kwargs:
5868
fetch_dataset_kwargs = dict()
5969
run = int(run)
60-
runs = get_subjects_runs()
70+
runs = get_subjects_runs(task=task)
6171
if subject not in runs or run not in runs[subject]:
62-
raise ValueError("subject and run not available. See "
63-
"get_subjects_runs() for information on "
72+
raise ValueError(f"subject {subject} and run {run} not available. "
73+
"See get_subjects_runs() for information on "
6474
"available subjects and runs.")
6575

6676
fetch_dataset_kwargs["dataset_params"] = _get_params(subject, run)
@@ -72,5 +82,14 @@ def fetch_eegeyenet(subject="EP10", run=1, fetch_dataset_kwargs=None):
7282
if not fpath.exists():
7383
fetch_dataset_kwargs["force_update"] = True
7484
_fetch_dataset(fetch_dataset_kwargs=fetch_dataset_kwargs)
75-
7685
return fpath
86+
87+
88+
def _get_task_from_subject_id(subject):
89+
if subject.startswith("EP"):
90+
return "DOTS"
91+
if subject.startswith(("A", "B")):
92+
return "AS"
93+
raise ValueError(
94+
f"Can't determine task for {subject}. Is this subject in eegeyenet_urls.csv?"
95+
)

0 commit comments

Comments
 (0)