diff --git a/bcipy/core/stimuli.py b/bcipy/core/stimuli.py index 43ce5feb..e81a72c2 100644 --- a/bcipy/core/stimuli.py +++ b/bcipy/core/stimuli.py @@ -307,9 +307,11 @@ def __call__(self, for symbol in symbol_set: data_by_targets_dict[symbol] = [] - buffer = stimulus_duration / 5 # seconds, buffer for each inquiry + buffer = 0.5 # seconds, buffer for each inquiry # NOTE: This buffer is used to account for the screen downtime between each stimulus. - # There is a "duty cycle" of 80% for the stimuli, so we add a buffer of 20% of the stimulus length + # A better way of handling this buffer would be subtracting the flash time of the + # second symbol from the first symbol, which gives a more accurate representation of + # "stimulus duration". window_length = (stimulus_duration + buffer) * num_stimuli_per_inquiry # in seconds reshaped_data = [] diff --git a/bcipy/signal/evaluate/fusion.py b/bcipy/signal/evaluate/fusion.py index c23ba928..824b4a09 100644 --- a/bcipy/signal/evaluate/fusion.py +++ b/bcipy/signal/evaluate/fusion.py @@ -260,7 +260,6 @@ def calculate_eeg_gaze_fusion_acc( # generate a tuple that matches the index of the symbol with the symbol itself: symbol_to_index = {symbol: i for i, symbol in enumerate(symbol_set)} - # train and save the gaze model as a pkl file: reshaped_data = centralized_gaze_data_train.reshape( (len(centralized_gaze_data_train), inquiry_length * predefined_dimensions)) units = 1e4 diff --git a/bcipy/signal/model/README.md b/bcipy/signal/model/README.md index 9f0abf2d..a4189030 100644 --- a/bcipy/signal/model/README.md +++ b/bcipy/signal/model/README.md @@ -62,7 +62,7 @@ These models may be trained and evalulated, but are still being integrated into *Note*: The gaze model is currently under development and is not yet fully implemented. -These models are used to update the posterior probability of stimuli viewed by a user based on gaze data. The gaze model uses a generative model to estimate the likelihood of the gaze data given the stimuli. There are several models implemented in this module, including a Gaussian Mixture Model (GMIndividual and GMCentralized) and Gaussian Process Model (GaussianProcess). When training data via offline analysis, if the data folder contains gaze data, the gaze model will be trained and saved to the output directory. +These models are used to update the posterior probability of stimuli viewed by a user based on gaze data. The gaze model uses a generative model to estimate the likelihood of the gaze data given the stimuli. There are several models implemented in this module, including a Gaussian Mixture Model (GMIndividual) and a Gaussian Process Model (GaussianProcess). When training data via offline analysis, if the data folder contains gaze data, the gaze model will be trained and saved to the output directory. ## Fusion Analyis diff --git a/bcipy/signal/model/__init__.py b/bcipy/signal/model/__init__.py index e4559f1b..d48a32f4 100644 --- a/bcipy/signal/model/__init__.py +++ b/bcipy/signal/model/__init__.py @@ -2,7 +2,7 @@ from bcipy.signal.model.pca_rda_kde.pca_rda_kde import PcaRdaKdeModel from bcipy.signal.model.rda_kde.rda_kde import RdaKdeModel from bcipy.signal.model.gaussian_mixture.gaussian_mixture import ( - GMIndividual, GMCentralized, GaussianProcess) + GMIndividual, GaussianProcess) __all__ = [ @@ -10,7 +10,6 @@ "PcaRdaKdeModel", "RdaKdeModel", 'GMIndividual', - 'GMCentralized', 'GaussianProcess', "ModelEvaluationReport", ] diff --git a/bcipy/signal/model/gaussian_mixture/__init__.py b/bcipy/signal/model/gaussian_mixture/__init__.py index 9be2725f..f9b92f3f 100644 --- a/bcipy/signal/model/gaussian_mixture/__init__.py +++ b/bcipy/signal/model/gaussian_mixture/__init__.py @@ -1,8 +1,7 @@ -from .gaussian_mixture import GMIndividual, GMCentralized, GaussianProcess, GazeModelResolver +from .gaussian_mixture import GMIndividual, GaussianProcess, GazeModelResolver __all__ = [ 'GMIndividual', - 'GMCentralized', 'GaussianProcess', 'GazeModelResolver' ] diff --git a/bcipy/signal/model/gaussian_mixture/gaussian_mixture.py b/bcipy/signal/model/gaussian_mixture/gaussian_mixture.py index 73e64b75..d04762bc 100644 --- a/bcipy/signal/model/gaussian_mixture/gaussian_mixture.py +++ b/bcipy/signal/model/gaussian_mixture/gaussian_mixture.py @@ -1,7 +1,9 @@ +import pickle from pathlib import Path from typing import List from enum import Enum +from bcipy.exceptions import SignalException from bcipy.core.stimuli import GazeReshaper from bcipy.signal.model import SignalModel @@ -17,7 +19,6 @@ class GazeModelType(Enum): """Enum for gaze model types""" GAUSSIAN_PROCESS = "GaussianProcess" GM_INDIVIDUAL = "GMIndividual" - GM_CENTRALIZED = "GMCentralized" def __str__(self): return self.value @@ -31,8 +32,6 @@ def from_str(label: str): return GazeModelType.GAUSSIAN_PROCESS elif label == "GMIndividual": return GazeModelType.GM_INDIVIDUAL - elif label == "GMCentralized": - return GazeModelType.GM_CENTRALIZED else: raise ValueError(f"Model type {label} not recognized.") @@ -51,8 +50,6 @@ def resolve(model_type: str, *args, **kwargs) -> SignalModel: return GaussianProcess(*args, **kwargs) elif model_type == GazeModelType.GM_INDIVIDUAL: return GMIndividual(*args, **kwargs) - elif model_type == GazeModelType.GM_CENTRALIZED: - return GMCentralized(*args, **kwargs) else: raise ValueError( f"Model type {model_type} not able to resolve. Not registered in GazeModelResolver.") @@ -66,24 +63,75 @@ class GaussianProcess(SignalModel): def __init__(self, *args, **kwargs): self.ready_to_predict = False self.acc = None + self.time_average = None + self.centralized_data = None + self.model = None - def fit(self, training_data: np.ndarray): - ... + + def fit(self, time_avg: np.ndarray, cent_data: np.ndarray): + """Fit the Gaussian Process model to the training data. + Args: + time_avg Dict[(np.ndarray)]: Time average for the symbols. + mean_data (np.ndarray): Sample average for the training data. + cov_data (np.ndarray): Covariance matrix for the training data. + """ + self.time_average = time_avg + self.centralized_data = cent_data + self.ready_to_predict = True + return self def evaluate(self, test_data: np.ndarray, test_labels: np.ndarray): ... + def evaluate_likelihood(self, data: np.ndarray, symbols: List[str], + symbol_set: List[str]) -> np.ndarray: + if not self.ready_to_predict: + raise SignalException("must use model.fit() before model.evaluate_likelihood()") + + gaze_log_likelihoods = np.zeros((len(symbol_set))) + # Clip the pre-saved centralized data to the length of our test data + cent_data = self.centralized_data[:, :, :data.shape[1]] + reshaped_data = cent_data.reshape((len(cent_data), data.shape[0] * data.shape[1])) + cov_matrix = np.cov(reshaped_data, rowvar=False) + reshaped_mean = np.mean(reshaped_data, axis=0) + eps = 10e-1 # add a small value to the diagonal to make the cov matrix invertible + inv_cov_matrix = np.linalg.inv(cov_matrix + np.eye(len(cov_matrix)) * eps) + + for idx, sym in enumerate(symbol_set): + if self.time_average[sym] == []: + gaze_log_likelihoods[idx] = -100000 # set a very small value + else: + # Compute the likelihood of the data for each symbol + central_data = self.subtract_mean(data, self.time_average[sym]) + # flatten this data + flattened_data = np.reshape(central_data, (-1, )) + diff = flattened_data - reshaped_mean + numerator = -np.dot(diff.T, np.dot(inv_cov_matrix, diff)) / 2 + denominator = 0 + unnormalized_log_likelihood_gaze = numerator - denominator + gaze_log_likelihoods[idx] = unnormalized_log_likelihood_gaze + # Find the gaze_likelihoods for the symbols in the inquiry + gaze_likelihood = np.exp(gaze_log_likelihoods) + + return gaze_likelihood # used in multimodal update + def predict(self, test_data: np.ndarray, inquiry, symbol_set) -> np.ndarray: ... def predict_proba(self, test_data: np.ndarray) -> np.ndarray: ... - def save(self, path: Path): - ... + def save(self, path: Path) -> None: + """Save model weights (e.g. after training) to `path`""" + with open(path, "wb") as f: + pickle.dump(self.model, f) - def load(self, path: Path): - ... + def load(self, path: Path) -> SignalModel: + """Load pretrained model from `path`""" + with open(path, "rb") as f: + model = pickle.load(f) + + return model def centralize(self, data: np.ndarray, symbol_pos: np.ndarray) -> np.ndarray: """ Using the symbol locations in matrix, centralize all data (in Tobii units). @@ -207,9 +255,12 @@ def predict_proba(self, test_data: np.ndarray) -> np.ndarray: return likelihoods - def evaluate_likelihood(self, data: np.ndarray) -> np.ndarray: + def evaluate_likelihood(self, data: np.ndarray, symbols: List[str], + symbol_set: List[str]) -> np.ndarray: + if not self.ready_to_predict: + raise SignalException("must use model.fit() before model.evaluate_likelihood()") + data_length, _ = data.shape - likelihoods = np.zeros((data_length, self.num_components), dtype=object) # Find the likelihoods by insterting the test data into the pdf of each component @@ -222,124 +273,14 @@ def evaluate_likelihood(self, data: np.ndarray) -> np.ndarray: return likelihoods - def save(self, path: Path): - """Save model state to the provided checkpoint""" - ... - - def load(self, path: Path): - """Load model state from the provided checkpoint""" - ... - - -class GMCentralized(SignalModel): - '''Gaze model that uses all symbols to fit a single Gaussian ''' - reshaper = GazeReshaper() - name = "gaze_model_combined" - - def __init__(self, num_components=4, random_state=0, *args, **kwargs): - self.num_components = num_components # number of gaussians to fit - self.random_state = random_state - self.acc = None - self.means = None - self.covs = None - - self.ready_to_predict = False - - def fit(self, train_data: np.ndarray): - model = GaussianMixture(n_components=self.num_components, random_state=self.random_state, init_params='kmeans') - model.fit(train_data) - self.model = model - - self.means = model.means_ - self.covs = model.covariances_ - - self.ready_to_predict = True - return self - - def evaluate(self, predictions, true_labels) -> np.ndarray: - ''' - Compute performance characteristics on the provided test data and labels. - - Parameters: - ----------- - predictions: predicted labels for each test point per symbol - true_labels: true labels for each test point per symbol - Returns: - -------- - accuracy_per_symbol: accuracy per symbol - ''' - accuracy_per_symbol = np.sum(predictions == true_labels) / len(predictions) * 100 - self.acc = accuracy_per_symbol - return accuracy_per_symbol - - def predict(self, test_data: np.ndarray) -> np.ndarray: - ''' - Compute log-likelihood of each sample. - Predict the labels for the test data. - ''' - data_length, _ = test_data.shape - predictions = np.zeros(data_length, dtype=object) - likelihoods = self.model.predict_proba(test_data) - - for i in range(data_length): - # Find the argmax of the likelihoods to get the predictions - predictions[i] = np.argmax(likelihoods[i]) - - return predictions - - def predict_proba(self, test_data: np.ndarray) -> np.ndarray: - ''' - Compute log-likelihood of each sample. - Predict the labels for the test data. - - test_data: - ''' - data_length, _ = test_data.shape - - likelihoods = np.zeros((data_length, self.num_components), dtype=object) - - # Find the likelihoods by insterting the test data into the pdf of each component - for i in range(data_length): - for k in range(self.num_components): - mu = self.means[k] - sigma = self.covs[k] - - likelihoods[i, k] = stats.multivariate_normal.pdf(test_data[i], mu, sigma) - - return likelihoods - - def calculate_acc(self, predictions: int, counter: int): - ''' - Compute model performance characteristics on the provided test data and labels. - - predictions: predicted labels for each test point per symbol - counter: true labels for each test point per symbol - ''' - accuracy_per_symbol = np.sum(predictions == counter) / len(predictions) * 100 + def save(self, path: Path) -> None: + """Save model weights (e.g. after training) to `path`""" + with open(path, "wb") as f: + pickle.dump(self.model, f) - return accuracy_per_symbol + def load(self, path: Path) -> SignalModel: + """Load pretrained model from `path`""" + with open(path, "rb") as f: + model = pickle.load(f) - def save(self, path: Path): - """Save model state to the provided checkpoint""" - ... - - def load(self, path: Path): - """Load model state from the provided checkpoint""" - ... - - def centralize(self, data: np.ndarray, symbol_pos: np.ndarray) -> np.ndarray: - """ Using the symbol locations in matrix, centralize all data (in Tobii units). - This data will only be used in certain model types. - Args: - data (np.ndarray): Data in shape of num_samples x num_dimensions - symbol_pos (np.ndarray(float)): Array of the current symbol posiiton in Tobii units - Returns: - new_data (np.ndarray): Centralized data in shape of num_samples x num_dimensions - """ - new_data = np.copy(data) - for i in range(len(data)): - # new_data[i] = data[i] - symbol_pos - new_data[:2, i] = data[:2, i] - symbol_pos - new_data[2:, i] = data[2:, i] - symbol_pos - - return new_data + return model diff --git a/bcipy/signal/model/offline_analysis.py b/bcipy/signal/model/offline_analysis.py index aab4c49f..98381e81 100644 --- a/bcipy/signal/model/offline_analysis.py +++ b/bcipy/signal/model/offline_analysis.py @@ -234,7 +234,8 @@ def analyze_gaze( device_spec: DeviceSpec, data_folder: str, model_type: str = "GaussianProcess", - symbol_set: List[str] = alphabet()) -> SignalModel: + symbol_set: List[str] = alphabet(), + testing_acc: float = 0.0) -> SignalModel: """Analyze gaze data and return/save the gaze model. Extract relevant information from gaze data object. Extract timing information from trigger file. @@ -250,15 +251,18 @@ def analyze_gaze( parameters (Parameters): Parameters object retireved from parameters.json. device_spec (DeviceSpec): DeviceSpec object containing information about the device used. data_folder (str): Path to the folder containing the data to be analyzed. - model_type (str): Type of gaze model to be used. Options are: "GMIndividual", "GMCentralized", - or "GaussianProcess". + model_type (str): Type of gaze model to be used. Options are: "GMIndividual" or + "GaussianProcess". + symbol_set (List[str]): List of symbols to be used in the analysis. + testing_acc (float): Testing accuracy of the model. This is calculated during fusion analysis. + Imported to add to the metadata of the model. """ channels = gaze_data.channels type_amp = gaze_data.daq_type sample_rate = gaze_data.sample_rate flash_time = parameters.get("time_flash") # duration of each stimulus - stim_size = parameters.get("stim_length") # number of stimuli per inquiry + stim_length = parameters.get("stim_length") # number of stimuli per inquiry log.info(f"Channels read from csv: {channels}") log.info(f"Device type: {type_amp}, fs={sample_rate}") @@ -286,10 +290,10 @@ def analyze_gaze( ) ''' Trigger_timing includes PROMPT and excludes FIXATION ''' - target_symbols = trigger_symbols[0::stim_size + 1] # target symbols are the PROMPT triggers + target_symbols = trigger_symbols[0::stim_length + 1] # target symbols are the PROMPT triggers # Use trigger_timing to generate time windows for each letter flashing # Take every 10th trigger as the start point of timing. - inq_start = trigger_timing[1::stim_size + 1] # start of each inquiry (here we jump over prompts) + inq_start = trigger_timing[1::stim_length + 1] # start of each inquiry (here we jump over prompts) # Extract the inquiries dictionary with keys as target symbols and values as inquiry windows: inquiries_dict, inquiries_list, _ = model.reshaper( @@ -298,7 +302,7 @@ def analyze_gaze( gaze_data=data, sample_rate=sample_rate, stimulus_duration=flash_time, - num_stimuli_per_inquiry=stim_size, + num_stimuli_per_inquiry=stim_length, symbol_set=symbol_set, ) @@ -345,16 +349,6 @@ def analyze_gaze( preprocessed_data[sym].shape[1])) model.fit(reshaped_data) - if model_type == "GMCentralized": - # Centralize the data using symbol positions and fit a single Gaussian. - # Load json file. - with open(f"{data_folder}/{STIMULI_POSITIONS_FILENAME}", 'r') as params_file: - symbol_positions = json.load(params_file) - - # Subtract the symbol positions from the data: - for j in range(len(preprocessed_data[sym])): - centralized_data[sym].append(model.centralize(preprocessed_data[sym][j], symbol_positions[sym])) - if model_type == "GaussianProcess": # Instead of centralizing, take the time average: for j in range(len(preprocessed_data[sym])): @@ -388,26 +382,17 @@ def analyze_gaze( cov_matrix[l_ind, m_ind] = 0 reshaped_mean = np.mean(reshaped_data, axis=0) - # Save model parameters which are mean and covariance matrix - model.fit(reshaped_mean) - - if model_type == "GMCentralized": - # Fit the model parameters using the centralized data: - # flatten the dict to a np array: - cent_data = np.concatenate([centralized_data[sym] for sym in symbol_set], axis=0) - # Merge the first and third dimensions: - cent_data = cent_data.reshape((cent_data.shape[0] * cent_data.shape[2], cent_data.shape[1])) - - # cent_data = np.concatenate(centralized_data, axis=0) - model.fit(cent_data) + # Save model parameters which are time averages per symbol (time_average): Dict(np.array), + # and centralized data points (centralized_gaze_data): (np.array) of shape N_inquiries x N_dims x N_timesamples + model.fit(time_average, centralized_gaze_data) model.metadata = SignalModelMetadata(device_spec=device_spec, transform=None, - acc=model.acc) + acc=testing_acc) log.info("Training complete for Eyetracker model. Saving data...") save_model( model, - Path(data_folder, f"model_{device_spec.content_type.lower()}_{model.acc}.pkl")) + Path(data_folder, f"model_{device_spec.content_type.lower()}_{model.metadata.acc}.pkl")) return model @@ -478,6 +463,7 @@ def offline_analysis( symbol_set = alphabet() fusion = False + avg_testing_acc_gaze = 0.0 if num_devices == 2: # Ensure there is an EEG and Eyetracker device fusion = True @@ -506,6 +492,8 @@ def offline_analysis( ) log.info(f"EEG Accuracy: {eeg_acc}, Gaze Accuracy: {gaze_acc}, Fusion Accuracy: {fusion_acc}") + # The average gaze model accuracy: + avg_testing_acc_gaze = round(np.mean(gaze_acc), 3) # Ask the user if they want to proceed with full dataset model training models = [] @@ -532,7 +520,8 @@ def offline_analysis( parameters, device_spec, data_folder, - symbol_set=symbol_set) + symbol_set=symbol_set, + testing_acc=avg_testing_acc_gaze) models.append(et_model) if alert: diff --git a/bcipy/signal/tests/model/gaussian_mixture/test_gaussian_mixture.py b/bcipy/signal/tests/model/gaussian_mixture/test_gaussian_mixture.py index df373f28..ff5fe344 100644 --- a/bcipy/signal/tests/model/gaussian_mixture/test_gaussian_mixture.py +++ b/bcipy/signal/tests/model/gaussian_mixture/test_gaussian_mixture.py @@ -2,7 +2,6 @@ from bcipy.signal.model.gaussian_mixture import ( GaussianProcess, - GMCentralized, GMIndividual, GazeModelResolver ) @@ -14,10 +13,6 @@ def test_resolve(self): response = GazeModelResolver.resolve('GaussianProcess') self.assertIsInstance(response, GaussianProcess) - def test_resolve_centralized(self): - response = GazeModelResolver.resolve('GMCentralized') - self.assertIsInstance(response, GMCentralized) - def test_resolve_individual(self): response = GazeModelResolver.resolve('GMIndividual') self.assertIsInstance(response, GMIndividual) @@ -33,10 +28,6 @@ def test_gaussian_process(self): model = GaussianProcess() self.assertIsInstance(model, GaussianProcess) - def test_centrailized(self): - model = GMCentralized() - self.assertIsInstance(model, GMCentralized) - def test_individual(self): model = GMIndividual() self.assertIsInstance(model, GMIndividual) diff --git a/bcipy/task/control/evidence.py b/bcipy/task/control/evidence.py index 97b2986d..0b68f0f5 100644 --- a/bcipy/task/control/evidence.py +++ b/bcipy/task/control/evidence.py @@ -8,8 +8,9 @@ from bcipy.acquisition.multimodal import ContentType from bcipy.config import SESSION_LOG_FILENAME from bcipy.helpers.acquisition import analysis_channels -from bcipy.core.stimuli import TrialReshaper +from bcipy.core.stimuli import TrialReshaper, GazeReshaper from bcipy.signal.model import SignalModel +from bcipy.signal.process import extract_eye_info from bcipy.task.data import EvidenceType from bcipy.task.exceptions import MissingEvidenceEvaluator @@ -97,7 +98,8 @@ def preprocess(self, raw_data: np.ndarray, times: List[float], # pylint: disable=arguments-differ def evaluate(self, raw_data: np.ndarray, symbols: List[str], times: List[float], target_info: List[str], - window_length: float) -> np.ndarray: + window_length: float, flash_time: float, + stim_length: float) -> np.ndarray: """Evaluate the evidence. Parameters @@ -131,38 +133,50 @@ def __init__(self, symbol_set: List[str], signal_model: SignalModel): self.channel_map = analysis_channels(self.device_spec.channels, self.device_spec) self.transform = signal_model.metadata.transform - self.reshape = TrialReshaper() + self.reshape = GazeReshaper() def preprocess(self, raw_data: np.ndarray, times: List[float], - target_info: List[str], window_length: float) -> np.ndarray: - """Preprocess the inquiry data. + flash_time: float) -> np.ndarray: + """Preprocess the inquiry data. Parameters ---------- raw_data - C x L eeg data where C is number of channels and L is the - signal length + signal length. Includes all channels in devices.json symbols - symbols displayed in the inquiry times - timestamps associated with each symbol - target_info - target information about the stimuli; - ex. ['nontarget', 'nontarget', ...] - window_length - The length of the time between stimuli presentation + flash_time - duration (in seconds) of each stimulus + + Function + -------- + The preprocessing is functionally different than Gaze Reshaper, since + the raw data contains only one inquiry. start_idx is determined as the + start time of first symbol flashing multiplied by the sampling rate + of eye tracker. stop_idx is the index indicating the end of last + symbol flashing. """ - transformed_data, transform_sample_rate = self.transform( - raw_data, self.device_spec.sample_rate) - - # The data from DAQ is assumed to have offsets applied - reshaped_data, _lbls = self.reshape(trial_targetness_label=target_info, - timing_info=times, - eeg_data=transformed_data, - sample_rate=transform_sample_rate, - channel_map=self.channel_map, - poststimulus_length=window_length) - return reshaped_data + if self.transform: + transformed_data, transform_sample_rate = self.transform( + raw_data, self.device_spec.sample_rate) + else: + transformed_data = raw_data + transform_sample_rate = self.device_spec.sample_rate + + start_idx = int(self.device_spec.sample_rate*times[0]) + stop_idx = start_idx + int((times[-1]-times[0]+flash_time) * self.device_spec.sample_rate) + data_all_channels = transformed_data[:, start_idx:stop_idx] + + # Extract left and right eye from all channels. Remove/replace nan values + left_eye, right_eye, _, _, _, _ = extract_eye_info(data_all_channels) + reshaped_data = np.vstack((np.array(left_eye).T, np.array(right_eye).T)) + + return reshaped_data # (4, N_samples) # pylint: disable=arguments-differ def evaluate(self, raw_data: np.ndarray, symbols: List[str], times: List[float], target_info: List[str], - window_length: float) -> np.ndarray: + window_length: float, flash_time: float, + stim_length: float) -> np.ndarray: """Evaluate the evidence. Parameters @@ -175,10 +189,11 @@ def evaluate(self, raw_data: np.ndarray, symbols: List[str], ex. ['nontarget', 'nontarget', ...] window_length - The length of the time between stimuli presentation """ - data = self.preprocess(raw_data, times, target_info, window_length) + data = self.preprocess(raw_data, times, flash_time) # We need the likelihoods in the form of p(label | gaze). predict returns the argmax of the likelihoods. # Therefore we need predict_proba method to get the likelihoods. - return self.signal_model.evaluate_likelihood(data) # multiplication over the inquiry + likelihood = self.signal_model.evaluate_likelihood(data, symbols, self.symbol_set) + return likelihood def get_evaluator( diff --git a/bcipy/task/control/handler.py b/bcipy/task/control/handler.py index 3e67cdfd..492c3d6b 100644 --- a/bcipy/task/control/handler.py +++ b/bcipy/task/control/handler.py @@ -33,8 +33,8 @@ def update_and_fuse(self, dict_evidence): dict_evidence(dict{name: ndarray[float]}): dictionary of evidences (EEG (likelihood ratios) and other likelihoods) """ - # {EEG: [], GAZE: ()} - + # {ERP: [], EYE: ()} + for key in dict_evidence.keys(): tmp = dict_evidence[key][:][:] self.evidence_history[key].append(tmp) diff --git a/bcipy/task/paradigm/rsvp/copy_phrase.py b/bcipy/task/paradigm/rsvp/copy_phrase.py index b69df658..87808d4f 100644 --- a/bcipy/task/paradigm/rsvp/copy_phrase.py +++ b/bcipy/task/paradigm/rsvp/copy_phrase.py @@ -724,6 +724,8 @@ def compute_device_evidence( post_stim_buffer = int(self.parameters.get("task_buffer_length") / 2) prestim_buffer: float = self.parameters["prestim_length"] trial_window: Tuple[float, float] = self.parameters["trial_window"] + flash_time = self.parameters["time_flash"] + stim_length = self.parameters["stim_length"] window_length = trial_window[1] - trial_window[0] inquiry_timing = self.stims_for_decision(stim_times) @@ -756,6 +758,8 @@ def compute_device_evidence( times=times, target_info=filtered_labels, window_length=window_length, + flash_time=flash_time, + stim_length=stim_length, ) evidences.append((evidence_evaluator.produces, probs))