1111import mne
1212from functools import wraps
1313import numpy as np
14+ import warnings
1415import pandas as pd
1516
1617def tempfile_wrapper (func ):
@@ -25,40 +26,61 @@ def wrapped(*args, **kwargs):
2526 return res
2627 return wrapped
2728
29+ def disable_ssl_verify ():
30+ """will monkey-patch requests made by usleep-api to veryify=False"""
31+ # Save original method
32+ from usleep_api import USleepAPI
33+ original_request = USleepAPI ._request
34+
35+ # Define patched method
36+ def patched_request (self , endpoint , method , as_json = False ,
37+ log_response = True , headers = None , ** kwargs ):
38+ kwargs .setdefault ('verify' , False )
39+ return original_request (self , endpoint , method , as_json = as_json ,
40+ log_response = log_response , headers = headers , ** kwargs )
41+
42+ # Apply monkey patch
43+ USleepAPI ._request = patched_request
44+ warnings .warn ('patched SSL to accept insecure connections' )
45+
2846@tempfile_wrapper
2947def predict_usleep_raw (raw , api_token , eeg_chs = None , eog_chs = None ,
3048 ch_groups = None , model = 'U-Sleep v2.0' , saveto = None ,
3149 seconds_per_label = 30 , tmp_edf = None , return_proba = False ):
32- """convenience function to upload any mne.io.Raw to usleep
33- will prepare the file by downsampling to 128 Hz and discarding
34- any channels that are not used
50+ """
51+ Run U-Sleep prediction on an mne.io.Raw object.
52+
53+ Prepares the raw data by selecting specified EEG/EOG channels,
54+ downsampling to 128 Hz, and exporting to EDF before submitting
55+ to the U-Sleep API.
56+
3557 Parameters
3658 ----------
37- edf_file : str
38- link to an edf file
39- eeg_chs : list
40- list of channels that are of type EEG and should be used for prediction
41- channel, e.g. [Fz, Cz]. ch_groups will be created based on that.
42- eog_chs : list
43- list of channels that are of type EOG and should be used for prediction
44- channel, e.g. [lEOG, rEOG]. ch_groups will be created based on that.
45- ch_groups : list
46- list of channel tuple, where each tuple contains one EEG and one EOG
47- channel, e.g. [[Fz, lEOG], [Cz, lEOG], [Fz, rEOG], [Cz, lEOG]].
59+ raw : mne.io.Raw
60+ The raw EEG recording.
4861 api_token : str
49- U-Sleep API token, apply for it at https://sleep.ai.ku.dk.
62+ U-Sleep API token (https://sleep.ai.ku.dk).
63+ eeg_chs : list of str, optional
64+ EEG channel names for prediction.
65+ eog_chs : list of str, optional
66+ EOG channel names for prediction.
67+ ch_groups : list of tuple, optional
68+ Pairs of (EEG, EOG) channels.
5069 model : str
51- U-Sleep model to use, e.g. U-Sleep v1.0 or v2.0
70+ U-Sleep model version (default: ' U-Sleep v2.0').
5271 saveto : str, optional
53- save hypnogram to this file, with one entry per second of
54- the hypnogram. The default is None.
72+ Path to save the predicted hypnogram.
5573 seconds_per_label : int
56- number of seconds that each hypnogram label should span. default: 30
74+ Label duration in seconds (default: 30).
75+ tmp_edf : str, optional
76+ Optional path to use for temporary EDF export.
77+ return_proba : bool
78+ If True, return class probabilities.
5779
5880 Returns
5981 -------
60- hypno : np.array
61- list of hypnogram labels .
82+ np.ndarray or dict
83+ Hypnogram labels or probability predictions .
6284 """
6385 assert (eeg_chs is None == eog_chs is None ) ^ (ch_groups is None ), \
6486 'must either supply eeg_chs and eog_chs OR ch_groups'
@@ -96,38 +118,41 @@ def delete_all_sessions(api_token):
96118def predict_usleep (edf_file , api_token , eeg_chs = None , eog_chs = None ,
97119 ch_groups = None , model = 'U-Sleep v2.0' , saveto = None ,
98120 seconds_per_label = 30 , return_proba = False ):
99- """helper function to retrieve a hypnogram prediction from usleep
100- a valid API token is necessary to run the function.
121+ """
122+ Run U-Sleep prediction on an EDF file via the U-Sleep API.
123+
124+ Requires a valid API token. Optionally saves output and returns
125+ class probabilities.
101126
102127 Parameters
103128 ----------
104129 edf_file : str
105- link to an edf file
106- eeg_chs : list
107- list of channels that are of type EEG and should be used for prediction
108- channel, e.g. [Fz, Cz]. ch_groups will be created based on that.
109- eog_chs : list
110- list of channels that are of type EOG and should be used for prediction
111- channel, e.g. [lEOG, rEOG]. ch_groups will be created based on that.
112- ch_groups : list
113- list of channel tuple, where each tuple contains one EEG and one EOG
114- channel, e.g. [[Fz, lEOG], [Cz, lEOG], [Fz, rEOG], [Cz, lEOG]].
130+ Path to a local EDF file.
115131 api_token : str
116- U-Sleep API token, apply for it at https://sleep.ai.ku.dk.
132+ U-Sleep API token (https://sleep.ai.ku.dk).
133+ eeg_chs : list of str, optional
134+ EEG channels used for prediction.
135+ eog_chs : list of str, optional
136+ EOG channels used for prediction.
137+ ch_groups : list of tuple, optional
138+ Explicit channel pairs (EEG, EOG). Overrides eeg_chs/eog_chs.
117139 model : str
118- U-Sleep model to use, e.g. U-Sleep v1.0 or v2.0
140+ U-Sleep model version (default: ' U-Sleep v2.0').
119141 saveto : str, optional
120- save hypnogram to this file, with one entry per second of
121- the hypnogram. The default is None.
142+ Path prefix for saving hypnogram (.csv) and confidences (.confidences.csv).
122143 seconds_per_label : int
123- number of seconds that each hypnogram label should span. default: 30
144+ Seconds per hypnogram label (default: 30).
145+ return_proba : bool
146+ If True, also return class probability array.
124147
125148 Returns
126149 -------
127- hypno : np.array
128- list of hypnogram labels.
150+ hypno : np.ndarray
151+ Predicted hypnogram labels.
152+ proba : np.ndarray, optional
153+ Label probabilities (if return_proba is True).
129154 """
130- from .. import write_hypno
155+ from sleep_utils import write_hypno
131156 try :
132157 from usleep_api import USleepAPI
133158 except ModuleNotFoundError as e :
@@ -138,7 +163,7 @@ def predict_usleep(edf_file, api_token, eeg_chs=None, eog_chs=None,
138163 assert (eeg_chs is None == eog_chs is None ) ^ (ch_groups is None ), \
139164 'must either supply eeg_chs and eog_chs OR ch_groups'
140165
141- # Create an API object and (optionally) a new session.
166+ # Create an API object and a new session.
142167 try :
143168 api = USleepAPI (api_token = api_token )
144169 assert api , f'could init API: { api } , { api .content } '
@@ -158,9 +183,7 @@ def predict_usleep(edf_file, api_token, eeg_chs=None, eog_chs=None,
158183 print (f'uploading { edf_file } ' )
159184 assert (res := session .upload_file (edf_file )), f'upload failed: { res } , { res .content } '
160185
161- # Start the prediction on two channel groups:
162- # 1: EEG Fpz-Cz + EOG horizontal
163- # 2: EEG Pz-Oz + EOG horizontal
186+ # Start the prediction on channel groups
164187 # Using 30 second windows (note: U-Slep v1.0 uses 128 Hz re-sampled signals)
165188
166189 assert session .predict (data_per_prediction = 128 * seconds_per_label ,
@@ -173,10 +196,11 @@ def predict_usleep(edf_file, api_token, eeg_chs=None, eog_chs=None,
173196
174197 if success :
175198 # Fetch hypnogram
176- with tempfile .NamedTemporaryFile (suffix = '.npy' ) as tmp :
177- session .download_hypnogram (out_path = tmp .name , file_type = 'npy' ,
199+ with tempfile .TemporaryDirectory () as tmp :
200+ tmp = os .path .join (tmp , 'probas.npy' )
201+ session .download_hypnogram (out_path = tmp , file_type = 'npy' ,
178202 with_confidence_scores = True )
179- proba = np .load (tmp . name )
203+ proba = np .load (tmp )
180204
181205 res = session .get_hypnogram ()
182206 hypno = res ['hypnogram' ]
@@ -198,4 +222,4 @@ def predict_usleep(edf_file, api_token, eeg_chs=None, eog_chs=None,
198222 raise Exception (f"Prediction failed.\n \n { hypno } " )
199223
200224 # Delete session (i.e., uploaded file, prediction and logs)
201- return hypno , proba if return_proba else hypno
225+ return ( hypno , proba ) if return_proba else hypno
0 commit comments