Skip to content

Commit a01ed57

Browse files
committed
Offline analysis bugfix: models saved
1 parent b49318b commit a01ed57

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

bcipy/signal/model/offline_analysis.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,8 @@ def analyze_gaze(
234234
device_spec: DeviceSpec,
235235
data_folder: str,
236236
model_type: str = "GaussianProcess",
237-
symbol_set: List[str] = alphabet()) -> SignalModel:
237+
symbol_set: List[str] = alphabet(),
238+
testing_acc: float = 0.0) -> SignalModel:
238239
"""Analyze gaze data and return/save the gaze model.
239240
Extract relevant information from gaze data object.
240241
Extract timing information from trigger file.
@@ -252,6 +253,9 @@ def analyze_gaze(
252253
data_folder (str): Path to the folder containing the data to be analyzed.
253254
model_type (str): Type of gaze model to be used. Options are: "GMIndividual", "GMCentralized",
254255
or "GaussianProcess".
256+
symbol_set (List[str]): List of symbols to be used in the analysis.
257+
testing_acc (float): Testing accuracy of the model. This is calculated during fusion analysis.
258+
Imported to add to the metadata of the model.
255259
"""
256260
channels = gaze_data.channels
257261
type_amp = gaze_data.daq_type
@@ -403,11 +407,11 @@ def analyze_gaze(
403407

404408
model.metadata = SignalModelMetadata(device_spec=device_spec,
405409
transform=None,
406-
acc=model.acc)
410+
acc=testing_acc)
407411
log.info("Training complete for Eyetracker model. Saving data...")
408412
save_model(
409413
model,
410-
Path(data_folder, f"model_{device_spec.content_type.lower()}_{model.acc}.pkl"))
414+
Path(data_folder, f"model_{device_spec.content_type.lower()}_{model.metadata.acc}.pkl"))
411415
return model
412416

413417

@@ -478,6 +482,7 @@ def offline_analysis(
478482

479483
symbol_set = alphabet()
480484
fusion = False
485+
avg_testing_acc_gaze = 0.0
481486
if num_devices == 2:
482487
# Ensure there is an EEG and Eyetracker device
483488
fusion = True
@@ -506,6 +511,8 @@ def offline_analysis(
506511
)
507512

508513
log.info(f"EEG Accuracy: {eeg_acc}, Gaze Accuracy: {gaze_acc}, Fusion Accuracy: {fusion_acc}")
514+
# The average gaze model accuracy:
515+
avg_testing_acc_gaze = round(np.mean(gaze_acc), 3)
509516

510517
# Ask the user if they want to proceed with full dataset model training
511518
models = []
@@ -532,7 +539,8 @@ def offline_analysis(
532539
parameters,
533540
device_spec,
534541
data_folder,
535-
symbol_set=symbol_set)
542+
symbol_set=symbol_set,
543+
testing_acc=avg_testing_acc_gaze)
536544
models.append(et_model)
537545

538546
if alert:

0 commit comments

Comments
 (0)