11from pathlib import Path
2+ from typing import Literal
23
34import pandas as pd
45
@@ -14,23 +15,31 @@ def _get_params(subject, run):
1415 row = df .loc [(df .subject == subject .upper ()) & (df .run == int (run ))]
1516 assert len (row ) == 1
1617 row = row .T .squeeze ()
18+ task = row ["task" ]
1719 return dict (
1820 url = row ["url" ],
19- archive_name = f"{ subject } _DOTS { run } _EEG.mat" ,
20- folder_name = f"EEGEYENET-Data/dots /{ subject } " ,
21+ archive_name = f"{ subject } _ { task } { run } _EEG.mat" ,
22+ folder_name = f"EEGEYENET-Data/{ task } /{ subject } " ,
2123 hash = row ["hash" ],
2224 dataset_name = "EEGEYENET" )
2325
2426
25- def get_subjects_runs ():
27+ def get_subjects_runs (task : Literal [ "DOTS" , "AS" ] = "DOTS" ):
2628 """Get dictionary of {subject: [lists of runs]}.
2729
30+ Parameters
31+ ----------
32+ task :
33+ Which EEGEYENET task task to extract the subject ID's and runs for. Can be
34+ ``"DOTS"``, or ``"AS"`` (antisaccade). Defaults to ``'DOTS'``.
35+
2836 Returns
2937 -------
3038 dict
3139 Dictionary of subjects with the runs as values.
3240 """
3341 df = _get_urls_df ()
42+ df = df .loc [df ["task" ] == task ].copy ()
3443 return {subject : df .run .values [df .subject == subject ]
3544 for subject in df .subject .unique ()}
3645
@@ -54,13 +63,14 @@ def fetch_eegeyenet(subject="EP10", run=1, fetch_dataset_kwargs=None):
5463 pathlib.Path
5564 Path to the downloaded file.
5665 """
66+ task = _get_task_from_subject_id (subject )
5767 if not fetch_dataset_kwargs :
5868 fetch_dataset_kwargs = dict ()
5969 run = int (run )
60- runs = get_subjects_runs ()
70+ runs = get_subjects_runs (task = task )
6171 if subject not in runs or run not in runs [subject ]:
62- raise ValueError ("subject and run not available. See "
63- "get_subjects_runs() for information on "
72+ raise ValueError (f "subject { subject } and run { run } not available. "
73+ "See get_subjects_runs() for information on "
6474 "available subjects and runs." )
6575
6676 fetch_dataset_kwargs ["dataset_params" ] = _get_params (subject , run )
@@ -72,5 +82,14 @@ def fetch_eegeyenet(subject="EP10", run=1, fetch_dataset_kwargs=None):
7282 if not fpath .exists ():
7383 fetch_dataset_kwargs ["force_update" ] = True
7484 _fetch_dataset (fetch_dataset_kwargs = fetch_dataset_kwargs )
75-
7685 return fpath
86+
87+
88+ def _get_task_from_subject_id (subject ):
89+ if subject .startswith ("EP" ):
90+ return "DOTS"
91+ if subject .startswith (("A" , "B" )):
92+ return "AS"
93+ raise ValueError (
94+ f"Can't determine task for { subject } . Is this subject in eegeyenet_urls.csv?"
95+ )
0 commit comments