Commit ee87db7

Merge pull request #1018 from peterhcharlton/add_msptdfastv1.1
Add MSPTDfast algorithm
2 parents 8195d2f + 5d3e044 commit ee87db7

1 file changed (+245 −1 lines changed)

neurokit2/ppg/ppg_findpeaks.py (+245 −1)
@@ -68,20 +68,26 @@ def ppg_findpeaks(
     * Bishop, S. M., & Ercole, A. (2018). Multi-scale peak and trough detection optimised for
       periodic and quasi-periodic neuroscience data. In Intracranial Pressure & Neuromonitoring XVI
       (pp. 189-195). Springer International Publishing.
+    * Charlton, P. H. et al. (2024). MSPTDfast: An Efficient Photoplethysmography Beat Detection
+      Algorithm. Proc CinC.
 
     """
     method = method.lower()
     if method in ["elgendi"]:
         peaks = _ppg_findpeaks_elgendi(ppg_cleaned, sampling_rate, show=show, **kwargs)
     elif method in ["msptd", "bishop2018", "bishop"]:
         peaks, _ = _ppg_findpeaks_bishop(ppg_cleaned, show=show, **kwargs)
+    elif method in ["msptdfast", "msptdfastv1", "charlton2024", "charlton"]:
+        peaks, onsets = _ppg_findpeaks_charlton(ppg_cleaned, sampling_rate, show=show, **kwargs)
     else:
         raise ValueError(
-            "`method` not found. Must be one of the following: 'elgendi', 'bishop'."
+            "`method` not found. Must be one of the following: 'elgendi', 'bishop', 'charlton'."
         )
 
     # Prepare output.
     info = {"PPG_Peaks": peaks}
+    if 'onsets' in locals():
+        info["PPG_Onsets"] = onsets
 
     return info
 
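For context, the new dispatch branch above can be exercised through the public ppg_findpeaks() API. A minimal usage sketch (illustrative only, not part of this diff; it assumes neurokit2 is installed and uses a simulated recording from nk.ppg_simulate()):

    import neurokit2 as nk

    # Simulate and clean a 30 s PPG recording at 1000 Hz
    ppg = nk.ppg_simulate(duration=30, sampling_rate=1000, heart_rate=70)
    ppg_cleaned = nk.ppg_clean(ppg, sampling_rate=1000)

    # Any of the new aliases ("msptdfast", "msptdfastv1", "charlton2024", "charlton") selects MSPTDfast
    info = nk.ppg_findpeaks(ppg_cleaned, sampling_rate=1000, method="msptdfast")
    print(info["PPG_Peaks"])   # systolic peak indices
    print(info["PPG_Onsets"])  # pulse onset indices (only returned by this method)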

@@ -239,3 +245,241 @@ def _ppg_findpeaks_bishop(
         ax0.set_title("PPG Peaks (Method by Bishop et al., 2018)")
 
     return peaks, onsets
+
+
+def _ppg_findpeaks_charlton(
+    signal,
+    sampling_rate=1000,
+    show=False,
+):
+    """Implementation of Charlton et al. (2024) MSPTDfast: An Efficient Photoplethysmography
+    Beat Detection Algorithm. 2024 Computing in Cardiology (CinC), Karlsruhe, Germany,
+    doi:10.1101/2024.07.18.24310627.
+    """
+
+    # Inner functions
+
+    def find_m_max(x, N, max_scale, m_max):
+        """Find local maxima scalogram for peaks
+        """
+
+        for k in range(1, max_scale + 1):  # scalogram scales
+            for i in range(k + 2, N - k + 2):
+                if x[i - 2] > x[i - k - 2] and x[i - 2] > x[i + k - 2]:
+                    m_max[k - 1, i - 2] = True
+
+        return m_max
+
+    def find_m_min(x, N, max_scale, m_min):
+        """Find local minima scalogram for onsets
+        """
+
+        for k in range(1, max_scale + 1):  # scalogram scales
+            for i in range(k + 2, N - k + 2):
+                if x[i - 2] < x[i - k - 2] and x[i - 2] < x[i + k - 2]:
+                    m_min[k - 1, i - 2] = True
+
+        return m_min
+
+    def find_lms_using_msptd_approach(max_scale, x, options):
+        """Find local maxima (or minima) scalogram(s) using the
+        MSPTD approach
+        """
+
+        # Setup
+        N = len(x)
+
+        # Find local maxima scalogram (if required)
+        if options["find_pks"]:
+            m_max = np.full((max_scale, N), False)  # matrix for maxima
+            m_max = find_m_max(x, N, max_scale, m_max)
+        else:
+            m_max = None
+
+        # Find local minima scalogram (if required)
+        if options["find_trs"]:
+            m_min = np.full((max_scale, N), False)  # matrix for minima
+            m_min = find_m_min(x, N, max_scale, m_min)
+        else:
+            m_min = None
+
+        return m_max, m_min
+
+    def downsample(win_sig, ds_factor):
+        """Downsamples signal by picking out every nth sample, where n is
+        specified by ds_factor
+        """
+
+        return win_sig[::ds_factor]
+
+    def detect_peaks_and_onsets_using_msptd(signal, fs, options):
+        """Detect peaks and onsets in a PPG signal using a modified MSPTD approach
+        (where the modifications are those specified in Charlton et al. 2024)
+        """
+
+        # Setup
+        N = len(signal)
+        L = int(np.ceil(N / 2) - 1)
+
+        # Step 0: Don't calculate scales outside the range of plausible HRs
+
+        plaus_hr_hz = np.array(options['plaus_hr_bpm']) / 60  # in Hz
+        init_scales = np.arange(1, L + 1)
+        durn_signal = len(signal) / fs
+        init_scales_fs = (L / init_scales) / durn_signal
+        if options['use_reduced_lms_scales']:
+            init_scales_inc_log = init_scales_fs >= plaus_hr_hz[0]
+        else:
+            init_scales_inc_log = np.ones_like(init_scales_fs, dtype=bool)  # include all scales
+
+        max_scale_index = np.where(init_scales_inc_log)[0]  # indices of the scales to be included
+        if max_scale_index.size > 0:
+            max_scale = max_scale_index[-1] + 1  # Add 1 to convert from 0-based to 1-based index
+        else:
+            max_scale = None  # Or handle the case where no scales are valid
+
+        # Step 1: calculate local maxima and local minima scalograms
+
+        # - detrend
+        x = scipy.signal.detrend(signal, type="linear")
+
+        # - populate LMS matrices
+        [m_max, m_min] = find_lms_using_msptd_approach(max_scale, x, options)
+
+        # Step 2: find the scale with the most local maxima (or local minima)
+
+        # - row-wise summation (i.e. sum each row)
+        if options["find_pks"]:
+            gamma_max = np.sum(m_max, axis=1)  # the "axis=1" option makes it row-wise
+        if options["find_trs"]:
+            gamma_min = np.sum(m_min, axis=1)
+        # - find scale with the most local maxima (or local minima)
+        if options["find_pks"]:
+            lambda_max = np.argmax(gamma_max)
+        if options["find_trs"]:
+            lambda_min = np.argmax(gamma_min)
+
+        # Step 3: Use lambda to remove all elements of m for which k > lambda
+        first_scale_to_include = np.argmax(init_scales_inc_log)
+        if options["find_pks"]:
+            m_max = m_max[first_scale_to_include:lambda_max + 1, :]
+        if options["find_trs"]:
+            m_min = m_min[first_scale_to_include:lambda_min + 1, :]
+
+        # Step 4: Find peaks (and onsets)
+        # - column-wise summation
+        if options["find_pks"]:
+            m_max_sum = np.sum(m_max == False, axis=0)
+            peaks = np.where(m_max_sum == 0)[0].astype(int)
+        else:
+            peaks = []
+
+        if options["find_trs"]:
+            m_min_sum = np.sum(m_min == False, axis=0)
+            onsets = np.where(m_min_sum == 0)[0].astype(int)
+        else:
+            onsets = []
+
+        return peaks, onsets
+
+    # ~~~ Main function ~~~
+
+    # Specify settings
+    # - version: optimal selection (CinC 2024)
+    options = {
+        'find_trs': True,  # whether or not to find onsets
+        'find_pks': True,  # whether or not to find peaks
+        'do_ds': True,  # whether or not to do downsampling
+        'ds_freq': 20,  # the target downsampling frequency
+        'use_reduced_lms_scales': True,  # whether or not to reduce the number of scales (default 30 bpm)
+        'win_len': 8,  # duration of individual windows for analysis
+        'win_overlap': 0.2,  # proportion of window overlap
+        'plaus_hr_bpm': [30, 200]  # range of plausible HRs (only the lower bound is used)
+    }
+
+    # Split into overlapping windows
+    no_samps_in_win = options["win_len"] * sampling_rate
+    if len(signal) <= no_samps_in_win:
+        win_starts = [0]
+        win_ends = [len(signal)]
+    else:
+        win_offset = round(no_samps_in_win * (1 - options["win_overlap"]))
+        win_starts = list(range(0, len(signal) - no_samps_in_win + 1, win_offset))
+        win_ends = [start + 1 + no_samps_in_win for start in win_starts]
+        if win_ends[-1] < len(signal):
+            win_starts.append(len(signal) - 1 - no_samps_in_win)
+            win_ends.append(len(signal))
+        # this ensures that the windows include the entire signal duration
+
+    # Set up downsampling if the sampling frequency is particularly high
+    if options["do_ds"]:
+        min_fs = options["ds_freq"]
+        if sampling_rate > min_fs:
+            ds_factor = int(np.floor(sampling_rate / min_fs))
+            ds_fs = sampling_rate / np.floor(sampling_rate / min_fs)
+        else:
+            options["do_ds"] = False
+
+    # detect peaks and onsets in each window
+    peaks = []
+    onsets = []
+
+    # cycle through each window
+    for win_no in range(len(win_starts)):
+        # Extract this window's data
+        win_sig = signal[win_starts[win_no]:win_ends[win_no]]
+
+        # Downsample signal
+        if options['do_ds']:
+            rel_sig = downsample(win_sig, ds_factor)
+            rel_fs = ds_fs
+        else:
+            rel_sig = win_sig
+            rel_fs = sampling_rate
+
+        # Detect peaks and onsets
+        p, t = detect_peaks_and_onsets_using_msptd(rel_sig, rel_fs, options)
+
+        # Resample peaks and onsets back to the original sampling rate
+        if options['do_ds']:
+            p = [peak * ds_factor for peak in p]
+            t = [onset * ds_factor for onset in t]
+
+        # Correct peak indices by finding highest point within tolerance either side of detected peaks
+        tol_durn = 0.05
+        if rel_fs < 10:
+            tol_durn = 0.2
+        elif rel_fs < 20:
+            tol_durn = 0.1
+        tol = int(np.ceil(rel_fs * tol_durn))
+
+        for pk_no in range(len(p)):
+            segment = win_sig[(p[pk_no] - tol):(p[pk_no] + tol + 1)]
+            temp = np.argmax(segment)
+            p[pk_no] = p[pk_no] - tol + temp
+
+        # Correct onset indices by finding lowest point within tolerance either side of detected onsets
+        for onset_no in range(len(t)):
+            segment = win_sig[(t[onset_no] - tol):(t[onset_no] + tol + 1)]
+            temp = np.argmin(segment)
+            t[onset_no] = t[onset_no] - tol + temp
+
+        # Store peaks and onsets
+        win_peaks = [peak + win_starts[win_no] for peak in p]
+        peaks.extend(win_peaks)
+        win_onsets = [onset + win_starts[win_no] for onset in t]
+        onsets.extend(win_onsets)
+
+    # Tidy up detected peaks and onsets (by ordering them and only retaining unique ones)
+    peaks = sorted(set(peaks))
+    onsets = sorted(set(onsets))
+
+    # Plot results (optional)
+    if show:
+        _, ax0 = plt.subplots(nrows=1, ncols=1, sharex=True)
+        ax0.plot(signal, label="signal")
+        ax0.scatter(peaks, signal[peaks], c="r")
+        ax0.scatter(onsets, signal[onsets], c="b")
+        ax0.set_title("PPG Peaks and Onsets (Method by Charlton et al., 2024)")
+
+    return peaks, onsets
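To make the windowing and downsampling arithmetic above concrete, here is a small worked example. The input (a 30-second signal sampled at 1000 Hz) is an assumption chosen for illustration; the values follow directly from the settings in the code above:

    import numpy as np

    sampling_rate = 1000        # assumed original sampling frequency (Hz)
    n = 30 * sampling_rate      # assumed signal length: 30 s -> 30000 samples

    # Downsampling (options["ds_freq"] = 20 Hz)
    ds_factor = int(np.floor(sampling_rate / 20))          # 50
    ds_fs = sampling_rate / np.floor(sampling_rate / 20)   # 20.0 Hz

    # Windowing (options["win_len"] = 8 s, options["win_overlap"] = 0.2)
    no_samps_in_win = 8 * sampling_rate                    # 8000 samples per window
    win_offset = round(no_samps_in_win * (1 - 0.2))        # 6400 samples between window starts
    win_starts = list(range(0, n - no_samps_in_win + 1, win_offset))  # [0, 6400, 12800, 19200]
    win_ends = [start + 1 + no_samps_in_win for start in win_starts]  # [8001, 14401, 20801, 27201]

    # The last window (ending at sample 27201) does not reach the end of the signal,
    # so a final window covering the last 8 s is appended: start 21999, end 30000.
    if win_ends[-1] < n:
        win_starts.append(n - 1 - no_samps_in_win)         # 21999
        win_ends.append(n)                                 # 30000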
