Cross Dataset Evaluation #703

Open · wants to merge 20 commits into develop

Conversation

ali-sehar

This adds a new type of evaluation to validate models across several datasets. This is particularly relevant for deep learning models, as it allows MOABB to be used for benchmarking transfer learning.

Some examples are also added, one of which uses braindecode.

@gcattan (Collaborator) left a comment

Thank you for your contribution!

We already have a way to do transfer learning across datasets in MOABB, using the compound dataset.

However, I see a different use case here:

  1. The compound_dataset treats all subjects identically (so a single split can contain a mix of subjects from different datasets). You therefore cannot run an evaluation that trains only on subjects from one dataset and tests on another dataset.
  2. Conversely, the new cross-dataset evaluation is agnostic to the number of subjects/sessions/runs.

I guess the main point is rather how to align with the new splitter API.

logging.getLogger("mne").setLevel(logging.ERROR)


def get_common_channels(datasets: List[Any]) -> List[str]:
Collaborator:

There is a match_all method in the base paradigm:

def match_all(

Collaborator:

please use this method @ali-sehar
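
For reference, a minimal sketch of how match_all could replace the manual channel handling (the dataset choices are illustrative, and the call assumes MOABB's BaseParadigm.match_all(datasets) with default arguments):

```python
from moabb.datasets import BNCI2014_001, Zhou2016
from moabb.paradigms import MotorImagery

# Illustrative datasets; any compatible motor-imagery datasets would do.
datasets = [BNCI2014_001(), Zhou2016()]

paradigm = MotorImagery(n_classes=2, fmin=8, fmax=32)

# match_all aligns the paradigm with what the datasets have in common
# (channel selection, resampling), so no manual intersection is needed.
paradigm.match_all(datasets)
```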

logging.basicConfig(level=logging.WARNING)


def get_common_channels(train_dataset, test_dataset):
Collaborator:

Same here (match_all method)

return event_id


def interpolate_missing_channels(
Collaborator:

match_all method.

@@ -780,3 +780,132 @@ def evaluate(

def is_valid(self, dataset):
return len(dataset.subject_list) > 1


class CrossDatasetEvaluation(BaseEvaluation):
Collaborator:

Hm, I think there is a plan to refactor the existing evaluations.
The recommended way forward will be to use the new Splitter API (see: #612 (comment)).

@bruAristimunha can probably advise you better than me what refactoring is necessary in this case.

Author:

Hey @gcattan, thanks for all your feedback!

@bruAristimunha - if you could comment on the best way to move forward :)

Collaborator:

We can implement this one and migrate later.

Comment on lines +790 to +793
train_dataset : Dataset or list of Dataset
    Dataset(s) to use for training
test_dataset : Dataset or list of Dataset
    Dataset(s) to use for testing
Collaborator:

Probably you want to have a cross-evaluation: provide a list of datasets, keep one for training and the others for testing, and then rotate.

Collaborator:

@ali-sehar @EazyAl Please implement this suggestion too.

  1. Pass a list of datasets
  2. And implement cross-validation over them (see the sketch below)
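
A hypothetical sketch of that rotation, assuming the train_dataset/test_dataset parameters from the PR's docstring and the usual evaluation.process(pipelines) entry point (datasets, paradigm, and pipelines are placeholders defined elsewhere):

```python
# Hypothetical leave-one-dataset-out rotation: each dataset is held out
# for testing once, and the model is trained on all the others.
all_results = []
for held_out in datasets:
    train_sets = [d for d in datasets if d is not held_out]
    evaluation = CrossDatasetEvaluation(
        paradigm=paradigm,
        train_dataset=train_sets,  # parameter names taken from the PR docstring
        test_dataset=held_out,
    )
    all_results.append(evaluation.process(pipelines))
```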

Comment on lines +890 to +891
model = clone(pipeline).fit(train_X[0], train_y)
score = model.score(test_X, test_y)
Collaborator:

OK, so you train on all subjects/sessions/runs of the first dataset, and then test on all subjects/sessions/runs of the second dataset?

Comment on lines +94 to +112
# Get the list of channels from each dataset before matching
print("\nChannels before matching:")
for ds_name, ds in datasets_dict.items():
    try:
        # Load data for first subject to get channel information
        data = ds.get_data([ds.subject_list[0]])  # Get data for first subject
        first_subject = list(data.keys())[0]
        first_session = list(data[first_subject].keys())[0]
        first_run = list(data[first_subject][first_session].keys())[0]
        run_data = data[first_subject][first_session][first_run]

        if isinstance(run_data, (RawArray, RawCNT)):
            channels = run_data.info["ch_names"]
        else:
            # Assuming the channels are stored in the dataset class after loading
            channels = ds.channels
        print(f"{ds_name}: {channels}")
    except Exception as e:
        print(f"Error getting channels for {ds_name}: {str(e)}")
Collaborator:

remove this

Comment on lines +122 to +156
# Get channels from all datasets after matching to ensure we have the correct intersection
all_channels_after_matching = []
print("\nChannels after matching:")
for i, (ds_name, _) in enumerate(datasets_dict.items()):
    ds = all_datasets[i]  # Get the matched dataset
    try:
        data = ds.get_data([ds.subject_list[0]])
        subject = list(data.keys())[0]
        session = list(data[subject].keys())[0]
        run = list(data[subject][session].keys())[0]
        run_data = data[subject][session][run]

        if isinstance(run_data, (RawArray, RawCNT)):
            channels = run_data.info["ch_names"]
        else:
            channels = ds.channels
        all_channels_after_matching.append(set(channels))
        print(f"{ds_name}: {channels}")
    except Exception as e:
        print(f"Error getting channels for {ds_name} after matching: {str(e)}")

# Get the intersection of all channel sets
common_channels = sorted(list(set.intersection(*all_channels_after_matching)))
print(f"\nCommon channels after matching: {common_channels}")
print(f"Number of common channels: {len(common_channels)}")

# Update the datasets_dict with the matched datasets
for i, (name, _) in enumerate(datasets_dict.items()):
    datasets_dict[name] = all_datasets[i]

train_dataset = datasets_dict["train_dataset"]
test_dataset = datasets_dict["test_dataset"]

# Initialize the paradigm with common channels
paradigm = MotorImagery(channels=common_channels, n_classes=2, fmin=8, fmax=32)
Collaborator:

Remove this.
match_all doesn't change the number of channels in the dataset;
it just automatically sets the channel filter in the paradigm.
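
As a sketch, and assuming match_all behaves as described above, the whole block could then reduce to something like:

```python
# Sketch of the suggested simplification: no manual channel intersection.
all_datasets = [train_dataset, test_dataset]

paradigm = MotorImagery(n_classes=2, fmin=8, fmax=32)
paradigm.match_all(all_datasets)  # sets the channel filter on the paradigm itself
```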

@@ -0,0 +1,691 @@
"""
Collaborator:

Same comments about match_all here; please apply them.

