@christian-oreilly what do you think about copying (or moving) the updated model class we defined in the paper_2024 subdirectory to the eoglearn/models subdirectory?
basically this:
eog-learn/paper_2024/run_eog_lstm_regression.py
Lines 20 to 104 in 7a1df0e
```python
class EOGRegressor(nn.Module):
    def __init__(self, n_input_features, n_output_features, hidden_size=64,
                 num_layers=1, dropout=0.5):
        super(EOGRegressor, self).__init__()
        self.input_size = n_input_features
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = nn.Dropout(dropout)
        self.rnn = nn.LSTM(n_input_features, hidden_size,
                           num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, n_output_features)

    def forward(self, input):
        # input shape: (batch_size, seq_len, input_size)
        batch_size = input.size(0)  # same as input.shape[0]

        # Initialize hidden state & cell states
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size)

        # Forward propagate RNN
        out, (h0, c0) = self.rnn(input, (h0, c0))

        # Decode the hidden state of the last time step
        out = self.dropout(out)
        out = self.fc(out)
        return out


def train_the_model(X, Y, num_epochs=1000, hidden_size=64, num_layers=1, dropout=0.5):
    """Train the Pytorch model."""
    # Instantiate the model
    if X.ndim == 3:
        assert Y.ndim == 3
        input_features = X.shape[2]  # Assuming (batch_size, seq_len, input_size)
        output_features = Y.shape[2]
    else:
        raise ValueError("Input data must have 3 dimensions: (batch_size, seq_len, input_size)")

    model = EOGRegressor(input_features, output_features, hidden_size=hidden_size,
                         num_layers=num_layers, dropout=dropout)

    # Loss function (Mean Squared Error)
    criterion = nn.MSELoss()
    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    losses = np.zeros(num_epochs)

    # Training loop
    model.train()
    for i, epoch in enumerate(range(num_epochs)):
        # Forward pass
        outputs = model(X)
        # Compute loss
        loss = criterion(outputs, Y)
        losses[i] = loss.detach().numpy()

        # Zero gradients, backward pass, and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print the loss every 100 iterations
        if i % 100 == 0:
            print(f'Epoch: {epoch} Loss: {loss.item():.4f}')

    # Set model to eval mode to turn off dropout
    model.eval()
    with torch.no_grad():
        predicted_noise = model(X)
    denoised_output = (Y - predicted_noise).numpy()
    return losses, predicted_noise, denoised_output


def prep_data(subject="EP10", run=1):
    fpath = eoglearn.datasets.fetch_eegeyenet(subject=subject, run=run)
    raw = eoglearn.io.read_raw_eegeyenet(fpath)
    raw.set_montage("GSN-HydroCel-129")
    raw.filter(1, 30, picks="eeg").resample(100)  # DO NOT filter eyetrack channels
    raw.set_eeg_reference("average")
    return raw
```
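For context, here is a minimal sketch of how these pieces could fit together end to end. The channel selection is an assumption (eyetracking channels as predictors, EEG channels as targets), not necessarily what run_eog_lstm_regression.py actually does:

```python
import torch

# Hypothetical usage sketch -- the picks below are assumptions, not
# necessarily the channel selection used in run_eog_lstm_regression.py.
raw = prep_data(subject="EP10", run=1)

# Assumed split: eyetracking channels as predictors, EEG channels as targets.
# raw.get_data() returns (n_channels, n_times); transpose and add a batch
# dimension to get the (batch_size, seq_len, n_features) shape that
# train_the_model() expects.
X = torch.tensor(raw.get_data(picks=["eyegaze", "pupil"]).T,
                 dtype=torch.float32).unsqueeze(0)
Y = torch.tensor(raw.get_data(picks="eeg").T,
                 dtype=torch.float32).unsqueeze(0)

losses, predicted_noise, denoised = train_the_model(X, Y, num_epochs=1000)
```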
to replace this:
eog-learn/eoglearn/models/model.py
Line 18 in 7a1df0e
```python
class EOGDenoiser:
```
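If we go this route, one option would be to keep a class with the same name in eoglearn/models and have it delegate to the new PyTorch pieces. A rough sketch, with a hypothetical interface that may not match the actual EOGDenoiser API:

```python
class EOGDenoiser:
    """Hypothetical wrapper sketch; the real EOGDenoiser API may differ."""

    def __init__(self, hidden_size=64, num_layers=1, dropout=0.5):
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout

    def fit(self, X, Y, num_epochs=1000):
        # Delegate to the paper_2024 training loop quoted above and store
        # its outputs as fitted attributes.
        self.losses_, self.predicted_noise_, self.denoised_ = train_the_model(
            X, Y,
            num_epochs=num_epochs,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            dropout=self.dropout,
        )
        return self
```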