Cross Dataset Evaluation #703
base: develop
Conversation
Please improve on spline channel interpolation and handling of events. Please add dataset augmentation and load balancing if necessary.
Thank you for your contribution!
We have some way to do transfer learning across datasets in MOABB, using the compound dataset.
I see a different use case here, however:
- The `compound_dataset` considers all subjects identically (so one split can contain a mix of subjects from different datasets). You therefore cannot run an evaluation that trains only on subjects from one dataset and tests on another dataset.
- Conversely, the new cross-dataset evaluation is agnostic of the number of subjects/sessions/runs.
I guess the main point is rather how to align with the new splitter API.
examples/cross_dataset.py
logging.getLogger("mne").setLevel(logging.ERROR)


def get_common_channels(datasets: List[Any]) -> List[str]:
There is a `match_all` method in the base paradigm (line 429 in 357cd12):

def match_all(
Please use this method, @ali-sehar.
logging.basicConfig(level=logging.WARNING)


def get_common_channels(train_dataset, test_dataset):
Same here (`match_all` method).
    return event_id


def interpolate_missing_channels(
Use the `match_all` method here as well.
@@ -780,3 +780,132 @@ def evaluate(

    def is_valid(self, dataset):
        return len(dataset.subject_list) > 1


class CrossDatasetEvaluation(BaseEvaluation):
Hm, I think there is a plan to refactor the existing evaluations.
The recommended way forward will be to use the new Splitter API (see: #612 (comment)).
@bruAristimunha can probably advise you better than me on what refactoring is necessary in this case.
Hey @gcattan, thanks for all your feedback!
@bruAristimunha - if you could comment on the best way to move forward :)
We can implement this one and migrate later.
train_dataset : Dataset or list of Dataset
    Dataset(s) to use for training
test_dataset : Dataset or list of Dataset
    Dataset(s) to use for testing
You probably want to have a cross-evaluation: provide a list of datasets, keep one for training and the others for testing, and then rotate.
@ali-sehar @EazyAl Please implement this suggestion too.
- Pass a list of datasets
- And implement cross-validation (see the rotation sketch below)
model = clone(pipeline).fit(train_X[0], train_y)
score = model.score(test_X, test_y)
OK, so you train on all subjects/sessions/runs of the first dataset, and then test on all subjects/sessions/runs of the second dataset?
# Get the list of channels from each dataset before matching
print("\nChannels before matching:")
for ds_name, ds in datasets_dict.items():
    try:
        # Load data for first subject to get channel information
        data = ds.get_data([ds.subject_list[0]])  # Get data for first subject
        first_subject = list(data.keys())[0]
        first_session = list(data[first_subject].keys())[0]
        first_run = list(data[first_subject][first_session].keys())[0]
        run_data = data[first_subject][first_session][first_run]

        if isinstance(run_data, (RawArray, RawCNT)):
            channels = run_data.info["ch_names"]
        else:
            # Assuming the channels are stored in the dataset class after loading
            channels = ds.channels
        print(f"{ds_name}: {channels}")
    except Exception as e:
        print(f"Error getting channels for {ds_name}: {str(e)}")
Remove this.
# Get channels from all datasets after matching to ensure we have the correct intersection
all_channels_after_matching = []
print("\nChannels after matching:")
for i, (ds_name, _) in enumerate(datasets_dict.items()):
    ds = all_datasets[i]  # Get the matched dataset
    try:
        data = ds.get_data([ds.subject_list[0]])
        subject = list(data.keys())[0]
        session = list(data[subject].keys())[0]
        run = list(data[subject][session].keys())[0]
        run_data = data[subject][session][run]

        if isinstance(run_data, (RawArray, RawCNT)):
            channels = run_data.info["ch_names"]
        else:
            channels = ds.channels
        all_channels_after_matching.append(set(channels))
        print(f"{ds_name}: {channels}")
    except Exception as e:
        print(f"Error getting channels for {ds_name} after matching: {str(e)}")

# Get the intersection of all channel sets
common_channels = sorted(list(set.intersection(*all_channels_after_matching)))
print(f"\nCommon channels after matching: {common_channels}")
print(f"Number of common channels: {len(common_channels)}")

# Update the datasets_dict with the matched datasets
for i, (name, _) in enumerate(datasets_dict.items()):
    datasets_dict[name] = all_datasets[i]

train_dataset = datasets_dict["train_dataset"]
test_dataset = datasets_dict["test_dataset"]

# Initialize the paradigm with common channels
paradigm = MotorImagery(channels=common_channels, n_classes=2, fmin=8, fmax=32)
Remove this. `match_all` doesn't change the number of channels in the datasets; it just automatically sets the channel filter in the paradigm.
@@ -0,0 +1,691 @@
"""
Same comments about `match_all` here; please apply them.
train_dataset : Dataset or list of Dataset
    Dataset(s) to use for training
test_dataset : Dataset or list of Dataset
    Dataset(s) to use for testing
@ali-sehar @EazyAl Please implement this suggestion too.
- Pass a list of datasets
- And implement cross-validation
This adds a new type of evaluation to validate models across several datasets. This is particularly relevant for deep learning models, as it allows MOABB to be used for benchmarking transfer learning.
Some examples are also added, one of which uses braindecode.