Skip to content

Commit 888ffb8

Browse files
committed
FIX, STY: backwards compat + type hint
1 parent 90faaf4 commit 888ffb8

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

eoglearn/datasets/eegeyenet.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pathlib import Path
2+
from typing import Literal
23

34
import pandas as pd
45

@@ -23,15 +24,22 @@ def _get_params(subject, run):
2324
dataset_name="EEGEYENET")
2425

2526

26-
def get_subjects_runs():
27+
def get_subjects_runs(task: Literal["DOTS", "AS"] = "DOTS"):
2728
"""Get dictionary of {subject: [lists of runs]}.
2829
30+
Parameters
31+
----------
32+
task :
33+
Which EEGEYENET task task to extract the subject ID's and runs for. Can be
34+
``"DOTS"``, or ``"AS"`` (antisaccade). Defaults to ``'DOTS'``.
35+
2936
Returns
3037
-------
3138
dict
3239
Dictionary of subjects with the runs as values.
3340
"""
3441
df = _get_urls_df()
42+
df = df.loc[df["task"] == task].copy()
3543
return {subject: df.run.values[df.subject == subject]
3644
for subject in df.subject.unique()}
3745

0 commit comments

Comments
 (0)