Skip to content

Commit d0d19d7

Browse files
committed
fix getting probabilities from usleep
1 parent 768e6fb commit d0d19d7

File tree

2 files changed

+74
-50
lines changed

2 files changed

+74
-50
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
long_description = fh.read()
55

66
setup(name='sleep_utils',
7-
version='1.23',
7+
version='1.24',
88
description='A collection of tools for sleep research',
99
long_description=long_description,
1010
long_description_content_type="text/markdown",

sleep_utils/usleep_utils.py

Lines changed: 73 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import mne
1212
from functools import wraps
1313
import numpy as np
14+
import warnings
1415
import pandas as pd
1516

1617
def tempfile_wrapper(func):
@@ -25,40 +26,61 @@ def wrapped(*args, **kwargs):
2526
return res
2627
return wrapped
2728

29+
def disable_ssl_verify():
30+
"""will monkey-patch requests made by usleep-api to veryify=False"""
31+
# Save original method
32+
from usleep_api import USleepAPI
33+
original_request = USleepAPI._request
34+
35+
# Define patched method
36+
def patched_request(self, endpoint, method, as_json=False,
37+
log_response=True, headers=None, **kwargs):
38+
kwargs.setdefault('verify', False)
39+
return original_request(self, endpoint, method, as_json=as_json,
40+
log_response=log_response, headers=headers, **kwargs)
41+
42+
# Apply monkey patch
43+
USleepAPI._request = patched_request
44+
warnings.warn('patched SSL to accept insecure connections')
45+
2846
@tempfile_wrapper
2947
def predict_usleep_raw(raw, api_token, eeg_chs=None, eog_chs=None,
3048
ch_groups=None, model='U-Sleep v2.0', saveto=None,
3149
seconds_per_label=30, tmp_edf=None, return_proba=False):
32-
"""convenience function to upload any mne.io.Raw to usleep
33-
will prepare the file by downsampling to 128 Hz and discarding
34-
any channels that are not used
50+
"""
51+
Run U-Sleep prediction on an mne.io.Raw object.
52+
53+
Prepares the raw data by selecting specified EEG/EOG channels,
54+
downsampling to 128 Hz, and exporting to EDF before submitting
55+
to the U-Sleep API.
56+
3557
Parameters
3658
----------
37-
edf_file : str
38-
link to an edf file
39-
eeg_chs : list
40-
list of channels that are of type EEG and should be used for prediction
41-
channel, e.g. [Fz, Cz]. ch_groups will be created based on that.
42-
eog_chs : list
43-
list of channels that are of type EOG and should be used for prediction
44-
channel, e.g. [lEOG, rEOG]. ch_groups will be created based on that.
45-
ch_groups : list
46-
list of channel tuple, where each tuple contains one EEG and one EOG
47-
channel, e.g. [[Fz, lEOG], [Cz, lEOG], [Fz, rEOG], [Cz, lEOG]].
59+
raw : mne.io.Raw
60+
The raw EEG recording.
4861
api_token : str
49-
U-Sleep API token, apply for it at https://sleep.ai.ku.dk.
62+
U-Sleep API token (https://sleep.ai.ku.dk).
63+
eeg_chs : list of str, optional
64+
EEG channel names for prediction.
65+
eog_chs : list of str, optional
66+
EOG channel names for prediction.
67+
ch_groups : list of tuple, optional
68+
Pairs of (EEG, EOG) channels.
5069
model : str
51-
U-Sleep model to use, e.g. U-Sleep v1.0 or v2.0
70+
U-Sleep model version (default: 'U-Sleep v2.0').
5271
saveto : str, optional
53-
save hypnogram to this file, with one entry per second of
54-
the hypnogram. The default is None.
72+
Path to save the predicted hypnogram.
5573
seconds_per_label : int
56-
number of seconds that each hypnogram label should span. default: 30
74+
Label duration in seconds (default: 30).
75+
tmp_edf : str, optional
76+
Optional path to use for temporary EDF export.
77+
return_proba : bool
78+
If True, return class probabilities.
5779
5880
Returns
5981
-------
60-
hypno : np.array
61-
list of hypnogram labels.
82+
np.ndarray or dict
83+
Hypnogram labels or probability predictions.
6284
"""
6385
assert (eeg_chs is None == eog_chs is None) ^ (ch_groups is None), \
6486
'must either supply eeg_chs and eog_chs OR ch_groups'
@@ -96,38 +118,41 @@ def delete_all_sessions(api_token):
96118
def predict_usleep(edf_file, api_token, eeg_chs=None, eog_chs=None,
97119
ch_groups=None, model='U-Sleep v2.0', saveto=None,
98120
seconds_per_label=30, return_proba=False):
99-
"""helper function to retrieve a hypnogram prediction from usleep
100-
a valid API token is necessary to run the function.
121+
"""
122+
Run U-Sleep prediction on an EDF file via the U-Sleep API.
123+
124+
Requires a valid API token. Optionally saves output and returns
125+
class probabilities.
101126
102127
Parameters
103128
----------
104129
edf_file : str
105-
link to an edf file
106-
eeg_chs : list
107-
list of channels that are of type EEG and should be used for prediction
108-
channel, e.g. [Fz, Cz]. ch_groups will be created based on that.
109-
eog_chs : list
110-
list of channels that are of type EOG and should be used for prediction
111-
channel, e.g. [lEOG, rEOG]. ch_groups will be created based on that.
112-
ch_groups : list
113-
list of channel tuple, where each tuple contains one EEG and one EOG
114-
channel, e.g. [[Fz, lEOG], [Cz, lEOG], [Fz, rEOG], [Cz, lEOG]].
130+
Path to a local EDF file.
115131
api_token : str
116-
U-Sleep API token, apply for it at https://sleep.ai.ku.dk.
132+
U-Sleep API token (https://sleep.ai.ku.dk).
133+
eeg_chs : list of str, optional
134+
EEG channels used for prediction.
135+
eog_chs : list of str, optional
136+
EOG channels used for prediction.
137+
ch_groups : list of tuple, optional
138+
Explicit channel pairs (EEG, EOG). Overrides eeg_chs/eog_chs.
117139
model : str
118-
U-Sleep model to use, e.g. U-Sleep v1.0 or v2.0
140+
U-Sleep model version (default: 'U-Sleep v2.0').
119141
saveto : str, optional
120-
save hypnogram to this file, with one entry per second of
121-
the hypnogram. The default is None.
142+
Path prefix for saving hypnogram (.csv) and confidences (.confidences.csv).
122143
seconds_per_label : int
123-
number of seconds that each hypnogram label should span. default: 30
144+
Seconds per hypnogram label (default: 30).
145+
return_proba : bool
146+
If True, also return class probability array.
124147
125148
Returns
126149
-------
127-
hypno : np.array
128-
list of hypnogram labels.
150+
hypno : np.ndarray
151+
Predicted hypnogram labels.
152+
proba : np.ndarray, optional
153+
Label probabilities (if return_proba is True).
129154
"""
130-
from .. import write_hypno
155+
from sleep_utils import write_hypno
131156
try:
132157
from usleep_api import USleepAPI
133158
except ModuleNotFoundError as e:
@@ -138,7 +163,7 @@ def predict_usleep(edf_file, api_token, eeg_chs=None, eog_chs=None,
138163
assert (eeg_chs is None == eog_chs is None) ^ (ch_groups is None), \
139164
'must either supply eeg_chs and eog_chs OR ch_groups'
140165

141-
# Create an API object and (optionally) a new session.
166+
# Create an API object and a new session.
142167
try:
143168
api = USleepAPI(api_token=api_token)
144169
assert api, f'could init API: {api}, {api.content}'
@@ -158,9 +183,7 @@ def predict_usleep(edf_file, api_token, eeg_chs=None, eog_chs=None,
158183
print(f'uploading {edf_file}')
159184
assert (res:=session.upload_file(edf_file)), f'upload failed: {res}, {res.content}'
160185

161-
# Start the prediction on two channel groups:
162-
# 1: EEG Fpz-Cz + EOG horizontal
163-
# 2: EEG Pz-Oz + EOG horizontal
186+
# Start the prediction on channel groups
164187
# Using 30 second windows (note: U-Slep v1.0 uses 128 Hz re-sampled signals)
165188

166189
assert session.predict(data_per_prediction=128*seconds_per_label,
@@ -173,10 +196,11 @@ def predict_usleep(edf_file, api_token, eeg_chs=None, eog_chs=None,
173196

174197
if success:
175198
# Fetch hypnogram
176-
with tempfile.NamedTemporaryFile(suffix='.npy') as tmp:
177-
session.download_hypnogram(out_path=tmp.name, file_type='npy',
199+
with tempfile.TemporaryDirectory() as tmp:
200+
tmp = os.path.join(tmp, 'probas.npy')
201+
session.download_hypnogram(out_path=tmp, file_type='npy',
178202
with_confidence_scores=True)
179-
proba = np.load(tmp.name)
203+
proba = np.load(tmp)
180204

181205
res = session.get_hypnogram()
182206
hypno = res['hypnogram']
@@ -198,4 +222,4 @@ def predict_usleep(edf_file, api_token, eeg_chs=None, eog_chs=None,
198222
raise Exception(f"Prediction failed.\n\n{hypno}")
199223

200224
# Delete session (i.e., uploaded file, prediction and logs)
201-
return hypno, proba if return_proba else hypno
225+
return (hypno, proba) if return_proba else hypno

0 commit comments

Comments
 (0)