11"""Test the functions in yasa/staging.py."""
22
33import unittest
4+ from unittest .mock import MagicMock
45
56import matplotlib .pyplot as plt
67import mne
78import numpy as np
9+ import pandas as pd
810
911from yasa .fetchers import fetch_sample
1012from yasa .hypno import Hypnogram
@@ -31,6 +33,7 @@ def test_sleep_staging(self):
3133 )
3234 print (sls )
3335 print (str (sls ))
36+ assert repr (sls )
3437 sls .get_features ()
3538 y_pred = sls .predict ()
3639 assert isinstance (y_pred , Hypnogram )
@@ -56,3 +59,40 @@ def test_sleep_staging(self):
5659 SleepStaging (raw , eeg_name = "C4" , eog_name = "EOG1" ).fit ()
5760 # .. just the EEG
5861 SleepStaging (raw , eeg_name = "C4" ).fit ()
62+
63+ def test_short_data_warning (self ):
64+ """Test that a warning is raised for recordings shorter than 5 minutes."""
65+ raw_short = raw .copy ().crop (tmax = 200 )
66+ with self .assertLogs ("yasa" , level = "WARNING" ):
67+ SleepStaging (raw_short , eeg_name = "C4" )
68+
69+ def test_validate_predict_errors (self ):
70+ """Test _validate_predict raises ValueError for mismatched features."""
71+ sls = SleepStaging (raw , eeg_name = "C4" )
72+ sls .fit ()
73+
74+ # Features in clf not present in current feature set
75+ clf_mock = MagicMock ()
76+ clf_mock .feature_name_ = ["nonexistent_feature" ]
77+ with self .assertRaises (ValueError ):
78+ sls ._validate_predict (clf_mock )
79+
80+ # Features in current set not present in clf
81+ clf_mock .feature_name_ = sls .feature_name_ [:- 1 ]
82+ with self .assertRaises (ValueError ):
83+ sls ._validate_predict (clf_mock )
84+
85+ def test_plot_predict_proba_no_predict (self ):
86+ """Test that plot_predict_proba raises ValueError before predict is called."""
87+ sls = SleepStaging (raw , eeg_name = "C4" )
88+ with self .assertRaises (ValueError ):
89+ sls .plot_predict_proba ()
90+
91+ def test_predict_proba_without_prior_predict (self ):
92+ """Test that predict_proba internally calls predict when _proba is not set."""
93+ sls = SleepStaging (raw , eeg_name = "C4" )
94+ sls .fit ()
95+ with self .assertWarns (FutureWarning ):
96+ proba = sls .predict_proba ()
97+ assert isinstance (proba , pd .DataFrame )
98+ plt .close ("all" )
0 commit comments