unsupervised_predictor.py
"""Unsupervised learning methods including POS, GREEN, CHROME, ICA, LGI and PBV."""
import numpy as np
from evaluation.post_process import *
from unsupervised_methods.methods.CHROME_DEHAAN import *
from unsupervised_methods.methods.GREEN import *
from unsupervised_methods.methods.ICA_POH import *
from unsupervised_methods.methods.LGI import *
from unsupervised_methods.methods.PBV import *
from unsupervised_methods.methods.POS_WANG import *
from unsupervised_methods.methods.OMIT import *
from tqdm import tqdm
from evaluation.BlandAltmanPy import BlandAltman
def unsupervised_predict(config, data_loader, method_name):
""" Model evaluation on the testing dataset."""
if data_loader["unsupervised"] is None:
raise ValueError("No data for unsupervised method predicting")
print("===Unsupervised Method ( " + method_name + " ) Predicting ===")
predict_hr_peak_all = []
gt_hr_peak_all = []
predict_hr_fft_all = []
gt_hr_fft_all = []
SNR_all = []
MACC_all = []
sbar = tqdm(data_loader["unsupervised"], ncols=80)
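    # test_batch[0] holds the video frames, test_batch[1] the ground-truth BVP labels.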
for _, test_batch in enumerate(sbar):
batch_size = test_batch[0].shape[0]
for idx in range(batch_size):
data_input, labels_input = test_batch[0][idx].cpu().numpy(), test_batch[1][idx].cpu().numpy()
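            # Keep only the first three (RGB) channels; preprocessing may append extras.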
data_input = data_input[..., :3]
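            # Dispatch to the selected hand-crafted method to extract a BVP estimate.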
if method_name == "POS":
BVP = POS_WANG(data_input, config.UNSUPERVISED.DATA.FS)
elif method_name == "CHROM":
BVP = CHROME_DEHAAN(data_input, config.UNSUPERVISED.DATA.FS)
elif method_name == "ICA":
BVP = ICA_POH(data_input, config.UNSUPERVISED.DATA.FS)
elif method_name == "GREEN":
BVP = GREEN(data_input)
elif method_name == "LGI":
BVP = LGI(data_input)
elif method_name == "PBV":
BVP = PBV(data_input)
elif method_name == "OMIT":
BVP = OMIT(data_input)
else:
raise ValueError("unsupervised method name wrong!")
video_frame_size = test_batch[0].shape[1]
if config.INFERENCE.EVALUATION_WINDOW.USE_SMALLER_WINDOW:
window_frame_size = config.INFERENCE.EVALUATION_WINDOW.WINDOW_SIZE * config.UNSUPERVISED.DATA.FS
if window_frame_size > video_frame_size:
window_frame_size = video_frame_size
else:
window_frame_size = video_frame_size
for i in range(0, len(BVP), window_frame_size):
BVP_window = BVP[i:i+window_frame_size]
label_window = labels_input[i:i+window_frame_size]
if len(BVP_window) < 9:
print(f"Window frame size of {len(BVP_window)} is smaller than minimum pad length of 9. Window ignored!")
continue
if config.INFERENCE.EVALUATION_METHOD == "peak detection":
gt_hr, pre_hr, SNR, macc = calculate_metric_per_video(BVP_window, label_window, diff_flag=False,
fs=config.UNSUPERVISED.DATA.FS, hr_method='Peak')
gt_hr_peak_all.append(gt_hr)
predict_hr_peak_all.append(pre_hr)
SNR_all.append(SNR)
MACC_all.append(macc)
elif config.INFERENCE.EVALUATION_METHOD == "FFT":
gt_fft_hr, pre_fft_hr, SNR, macc = calculate_metric_per_video(BVP_window, label_window, diff_flag=False,
fs=config.UNSUPERVISED.DATA.FS, hr_method='FFT')
gt_hr_fft_all.append(gt_fft_hr)
predict_hr_fft_all.append(pre_fft_hr)
SNR_all.append(SNR)
MACC_all.append(macc)
else:
raise ValueError("Inference evaluation method name wrong!")
print("Used Unsupervised Method: " + method_name)
# Filename ID to be used in any results files (e.g., Bland-Altman plots) that get saved
if config.TOOLBOX_MODE == 'unsupervised_method':
filename_id = method_name + "_" + config.UNSUPERVISED.DATA.DATASET
else:
raise ValueError('unsupervised_predictor.py evaluation only supports unsupervised_method!')
if config.INFERENCE.EVALUATION_METHOD == "peak detection":
predict_hr_peak_all = np.array(predict_hr_peak_all)
gt_hr_peak_all = np.array(gt_hr_peak_all)
SNR_all = np.array(SNR_all)
MACC_all = np.array(MACC_all)
num_test_samples = len(predict_hr_peak_all)
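        # Report each configured metric as mean +/- standard error over test windows.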
for metric in config.UNSUPERVISED.METRICS:
if metric == "MAE":
MAE_PEAK = np.mean(np.abs(predict_hr_peak_all - gt_hr_peak_all))
standard_error = np.std(np.abs(predict_hr_peak_all - gt_hr_peak_all)) / np.sqrt(num_test_samples)
print("Peak MAE (Peak Label): {0} +/- {1}".format(MAE_PEAK, standard_error))
elif metric == "RMSE":
# Calculate the squared errors, then RMSE, in order to allow
# for a more robust and intuitive standard error that won't
# be influenced by abnormal distributions of errors.
squared_errors = np.square(predict_hr_peak_all - gt_hr_peak_all)
RMSE_PEAK = np.sqrt(np.mean(squared_errors))
standard_error = np.sqrt(np.std(squared_errors) / np.sqrt(num_test_samples))
print("PEAK RMSE (Peak Label): {0} +/- {1}".format(RMSE_PEAK, standard_error))
elif metric == "MAPE":
MAPE_PEAK = np.mean(np.abs((predict_hr_peak_all - gt_hr_peak_all) / gt_hr_peak_all)) * 100
standard_error = np.std(np.abs((predict_hr_peak_all - gt_hr_peak_all) / gt_hr_peak_all)) / np.sqrt(num_test_samples) * 100
print("PEAK MAPE (Peak Label): {0} +/- {1}".format(MAPE_PEAK, standard_error))
elif metric == "Pearson":
Pearson_PEAK = np.corrcoef(predict_hr_peak_all, gt_hr_peak_all)
correlation_coefficient = Pearson_PEAK[0][1]
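                # Standard error of Pearson's r: sqrt((1 - r^2) / (n - 2)).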
standard_error = np.sqrt((1 - correlation_coefficient**2) / (num_test_samples - 2))
print("PEAK Pearson (Peak Label): {0} +/- {1}".format(correlation_coefficient, standard_error))
elif metric == "SNR":
SNR_FFT = np.mean(SNR_all)
standard_error = np.std(SNR_all) / np.sqrt(num_test_samples)
print("FFT SNR (FFT Label): {0} +/- {1} (dB)".format(SNR_FFT, standard_error))
elif metric == "MACC":
MACC_avg = np.mean(MACC_all)
standard_error = np.std(MACC_all) / np.sqrt(num_test_samples)
print("MACC (avg): {0} +/- {1}".format(MACC_avg, standard_error))
elif "BA" in metric:
compare = BlandAltman(gt_hr_peak_all, predict_hr_peak_all, config, averaged=True)
compare.scatter_plot(
x_label='GT PPG HR [bpm]',
y_label='rPPG HR [bpm]',
show_legend=True, figure_size=(5, 5),
the_title=f'{filename_id}_Peak_BlandAltman_ScatterPlot',
file_name=f'{filename_id}_Peak_BlandAltman_ScatterPlot.pdf')
compare.difference_plot(
x_label='Difference between rPPG HR and GT PPG HR [bpm]',
y_label='Average of rPPG HR and GT PPG HR [bpm]',
show_legend=True, figure_size=(5, 5),
the_title=f'{filename_id}_Peak_BlandAltman_DifferencePlot',
file_name=f'{filename_id}_Peak_BlandAltman_DifferencePlot.pdf')
else:
raise ValueError("Wrong Test Metric Type")
elif config.INFERENCE.EVALUATION_METHOD == "FFT":
predict_hr_fft_all = np.array(predict_hr_fft_all)
gt_hr_fft_all = np.array(gt_hr_fft_all)
SNR_all = np.array(SNR_all)
MACC_all = np.array(MACC_all)
num_test_samples = len(predict_hr_fft_all)
for metric in config.UNSUPERVISED.METRICS:
if metric == "MAE":
MAE_FFT = np.mean(np.abs(predict_hr_fft_all - gt_hr_fft_all))
standard_error = np.std(np.abs(predict_hr_fft_all - gt_hr_fft_all)) / np.sqrt(num_test_samples)
print("FFT MAE (FFT Label): {0} +/- {1}".format(MAE_FFT, standard_error))
elif metric == "RMSE":
# Calculate the squared errors, then RMSE, in order to allow
# for a more robust and intuitive standard error that won't
# be influenced by abnormal distributions of errors.
squared_errors = np.square(predict_hr_fft_all - gt_hr_fft_all)
RMSE_FFT = np.sqrt(np.mean(squared_errors))
standard_error = np.sqrt(np.std(squared_errors) / np.sqrt(num_test_samples))
print("FFT RMSE (FFT Label): {0} +/- {1}".format(RMSE_FFT, standard_error))
elif metric == "MAPE":
MAPE_FFT = np.mean(np.abs((predict_hr_fft_all - gt_hr_fft_all) / gt_hr_fft_all)) * 100
standard_error = np.std(np.abs((predict_hr_fft_all - gt_hr_fft_all) / gt_hr_fft_all)) / np.sqrt(num_test_samples) * 100
print("FFT MAPE (FFT Label): {0} +/- {1}".format(MAPE_FFT, standard_error))
elif metric == "Pearson":
Pearson_FFT = np.corrcoef(predict_hr_fft_all, gt_hr_fft_all)
correlation_coefficient = Pearson_FFT[0][1]
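                # Standard error of Pearson's r: sqrt((1 - r^2) / (n - 2)).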
standard_error = np.sqrt((1 - correlation_coefficient**2) / (num_test_samples - 2))
print("FFT Pearson (FFT Label): {0} +/- {1}".format(correlation_coefficient, standard_error))
elif metric == "SNR":
                SNR_FFT = np.mean(SNR_all)
                standard_error = np.std(SNR_all) / np.sqrt(num_test_samples)
                print("FFT SNR (FFT Label): {0} +/- {1} (dB)".format(SNR_FFT, standard_error))
elif metric == "MACC":
MACC_avg = np.mean(MACC_all)
standard_error = np.std(MACC_all) / np.sqrt(num_test_samples)
print("MACC (avg): {0} +/- {1}".format(MACC_avg, standard_error))
elif "BA" in metric:
compare = BlandAltman(gt_hr_fft_all, predict_hr_fft_all, config, averaged=True)
compare.scatter_plot(
x_label='GT PPG HR [bpm]',
y_label='rPPG HR [bpm]',
show_legend=True, figure_size=(5, 5),
the_title=f'{filename_id}_FFT_BlandAltman_ScatterPlot',
file_name=f'{filename_id}_FFT_BlandAltman_ScatterPlot.pdf')
compare.difference_plot(
x_label='Difference between rPPG HR and GT PPG HR [bpm]',
y_label='Average of rPPG HR and GT PPG HR [bpm]',
show_legend=True, figure_size=(5, 5),
the_title=f'{filename_id}_FFT_BlandAltman_DifferencePlot',
file_name=f'{filename_id}_FFT_BlandAltman_DifferencePlot.pdf')
else:
raise ValueError("Wrong Test Metric Type")
else:
raise ValueError("Inference evaluation method name wrong!")