@@ -234,7 +234,8 @@ def analyze_gaze(
234
234
device_spec : DeviceSpec ,
235
235
data_folder : str ,
236
236
model_type : str = "GaussianProcess" ,
237
- symbol_set : List [str ] = alphabet ()) -> SignalModel :
237
+ symbol_set : List [str ] = alphabet (),
238
+ testing_acc : float = 0.0 ) -> SignalModel :
238
239
"""Analyze gaze data and return/save the gaze model.
239
240
Extract relevant information from gaze data object.
240
241
Extract timing information from trigger file.
@@ -252,6 +253,9 @@ def analyze_gaze(
252
253
data_folder (str): Path to the folder containing the data to be analyzed.
253
254
model_type (str): Type of gaze model to be used. Options are: "GMIndividual", "GMCentralized",
254
255
or "GaussianProcess".
256
+ symbol_set (List[str]): List of symbols to be used in the analysis.
257
+ testing_acc (float): Testing accuracy of the model. This is calculated during fusion analysis.
258
+ Imported to add to the metadata of the model.
255
259
"""
256
260
channels = gaze_data .channels
257
261
type_amp = gaze_data .daq_type
@@ -403,11 +407,11 @@ def analyze_gaze(
403
407
404
408
model .metadata = SignalModelMetadata (device_spec = device_spec ,
405
409
transform = None ,
406
- acc = model . acc )
410
+ acc = testing_acc )
407
411
log .info ("Training complete for Eyetracker model. Saving data..." )
408
412
save_model (
409
413
model ,
410
- Path (data_folder , f"model_{ device_spec .content_type .lower ()} _{ model .acc } .pkl" ))
414
+ Path (data_folder , f"model_{ device_spec .content_type .lower ()} _{ model .metadata . acc } .pkl" ))
411
415
return model
412
416
413
417
@@ -478,6 +482,7 @@ def offline_analysis(
478
482
479
483
symbol_set = alphabet ()
480
484
fusion = False
485
+ avg_testing_acc_gaze = 0.0
481
486
if num_devices == 2 :
482
487
# Ensure there is an EEG and Eyetracker device
483
488
fusion = True
@@ -506,6 +511,8 @@ def offline_analysis(
506
511
)
507
512
508
513
log .info (f"EEG Accuracy: { eeg_acc } , Gaze Accuracy: { gaze_acc } , Fusion Accuracy: { fusion_acc } " )
514
+ # The average gaze model accuracy:
515
+ avg_testing_acc_gaze = round (np .mean (gaze_acc ), 3 )
509
516
510
517
# Ask the user if they want to proceed with full dataset model training
511
518
models = []
@@ -532,7 +539,8 @@ def offline_analysis(
532
539
parameters ,
533
540
device_spec ,
534
541
data_folder ,
535
- symbol_set = symbol_set )
542
+ symbol_set = symbol_set ,
543
+ testing_acc = avg_testing_acc_gaze )
536
544
models .append (et_model )
537
545
538
546
if alert :
0 commit comments