-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprobability_propagating.py
More file actions
135 lines (105 loc) · 4.83 KB
/
probability_propagating.py
File metadata and controls
135 lines (105 loc) · 4.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import numpy as np
from helper import choose_model, class_name_from_ind, get_metrics
from helper import get_prob, legend_from_ind
from plotting import plot_probability, plot_confusion_matrix
from multiprocessing import Pool
from scipy.special import expit
def _propagate_probabilities(x, t, models, num_classes, class_weight, step=1.0):
    """Build the class-probability trajectory of one series on a regular grid.

    The grid is ``np.arange(0, ceil(t[-1]), step)`` (at least one point, so a
    series whose last timestamp is <= 0 does not yield an empty grid). Row 0
    is uniform. For every later grid time ``tt`` the samples with
    ``tt - step <= t <= tt`` are averaged; ``choose_model`` picks a model from
    the time of the last such sample, ``get_prob`` gives that model's class
    probabilities, and the new row is the normalized product
    likelihood * class_weight * previous posterior (likelihood and previous
    posterior floored at 1e-6). Windows without samples carry the previous
    posterior forward unchanged.

    Args:
        x: 2-D feature array whose rows are aligned with ``t``.
        t: 1-D sample times; ``t[-1]`` is taken as the end of the series
            (presumably sorted ascending — confirm with the data loader).
        models: pair; ``models[0]`` is passed to ``choose_model`` and
            ``models[1]`` is indexed by its result.
        num_classes: number of probability columns.
        class_weight: float array, already floored and normalized.
        step: grid spacing, in the units of ``t``.

    Returns:
        (new_t, probs): the grid times and a ``(len(new_t), num_classes)``
        array of per-step posteriors.
    """
    floor = 1e-6
    new_t = np.arange(0, max(np.ceil(t[-1]), step), step)
    probs = np.zeros((new_t.shape[0], num_classes))
    prev_prob = np.ones_like(class_weight) / class_weight.shape[0]
    probs[0, :] = prev_prob
    for ii, tt in enumerate(new_t[1:], start=1):
        cur_inds = np.where((t <= tt) & (t >= tt - step))[0]
        if cur_inds.size == 0:
            # No samples in this window: keep the previous posterior.
            probs[ii, :] = prev_prob
            continue
        model_ind = choose_model(t[cur_inds[-1]], models[0])
        avg_x = x[cur_inds, :].mean(axis=0, keepdims=True)
        model_prob = get_prob(models[1][model_ind], avg_x, num_classes)
        likelihood = np.maximum(np.ravel(model_prob), floor)
        prior = np.maximum(prev_prob, floor)
        posterior = likelihood * class_weight * prior
        probs[ii, :] = posterior / np.sum(posterior)
        # The posterior just computed is the prior for the next window; copy
        # it so flooring prev_prob later cannot rewrite the stored row.
        prev_prob = probs[ii, :].copy()
    return new_t, probs


def _first_threshold_crossing(probs, thresh):
    """Find the earliest grid index at which any class exceeds the threshold.

    A class "crosses" at the first row where its probability is strictly
    greater than ``thresh`` at that row. If several classes first cross at
    the same row, the lowest class index wins.

    Args:
        probs: ``(T, C)`` probability trajectory.
        thresh: length-``T`` threshold curve.

    Returns:
        (class_index, row_index), or ``(None, -1)`` if no class ever crosses.
    """
    above = probs > thresh[:, np.newaxis]
    crossed = above.any(axis=0)
    if not crossed.any():
        return None, -1
    # Classes that never cross get an index past the end so argmin skips them.
    first = np.where(crossed, above.argmax(axis=0), probs.shape[0] + 1)
    cls = int(np.argmin(first))
    return cls, int(first[cls])


def propogate(params):
    """Classify one series under every (A, B) threshold pair.

    ``params`` is a 15-tuple:
    ``(x, y, t, models, num_classes, A, B, class_weight, img_name, name,
    cur_series, num_series, plot_graphs, mode_label, guess_bool)``.

    The class weights are floored at 0.05 and renormalized (on a copy, so the
    caller's array is left untouched), then a recursive class-probability
    trajectory is built on a grid of spacing 1.0 (see
    ``_propagate_probabilities``). For each pair ``(aa, bb)`` the decision
    threshold is ``aa * exp(bb * new_t)`` clipped to [0.3, 1.0]; the predicted
    label is the class whose probability first exceeds it. If no class does,
    ``mode_label`` is predicted instead. With ``guess_bool`` set, thresholds
    are skipped and ``mode_label`` is predicted at grid index 0.

    Returns:
        (label, pred_labels, earliness, thresh_met) where ``label`` is
        ``int(y[-1][0])``; ``pred_labels``, ``earliness`` and ``thresh_met``
        have one entry per threshold pair. ``earliness`` is the fraction of
        the grid remaining at classification time (0 when the threshold was
        never met); ``thresh_met`` is 1 where a label was issued by crossing
        (or by guessing), else 0.
    """
    (x, y, t, models, num_classes, A, B, class_weight, img_name, name,
     cur_series, num_series, plot_graphs, mode_label, guess_bool) = params

    # Float copy: the clamp below must not modify the caller's (shared) array.
    class_weight = np.array(class_weight, dtype=float)
    class_weight[class_weight < 0.05] = 0.05
    class_weight = class_weight / np.sum(class_weight)

    new_t, probs = _propagate_probabilities(x, t, models, num_classes,
                                            class_weight)

    label = y[-1].astype('int')[0]
    pred_labels = np.empty(np.shape(A), dtype=int)
    earliness = np.empty(np.shape(A), dtype=float)
    thresh_met = np.zeros_like(A)
    end_time = new_t[-1]

    for thresh_ind, (aa, bb) in enumerate(zip(A, B)):
        # Computed in every mode because the plot below draws it.
        thresh = np.clip(aa * np.exp(bb * new_t), 0.3, 1.0)
        if guess_bool:
            pred_label = mode_label
            ind_of_classification = 0
        else:
            pred_label, ind_of_classification = _first_threshold_crossing(
                probs, thresh)

        if ind_of_classification == -1:
            # Threshold never met: fall back to mode_label.
            pred_label = mode_label
        else:
            thresh_met[thresh_ind] = 1

        if plot_graphs:
            if ind_of_classification == -1:
                count_text = "{:.0%}, {:.1%} Threshold\nClassified at: end".format(
                    aa, bb)
            else:
                count_text = "{:.0%}, {:.1%} Threshold\nClassified at: {} sec".format(
                    aa, bb, new_t[ind_of_classification])
            legend = legend_from_ind(num_classes)
            title = "{}: {}/{}\n True Label: {}, Predicted Label: {}".format(
                name, cur_series+1, num_series,
                class_name_from_ind(label, num_classes),
                class_name_from_ind(pred_label, num_classes))
            result = 'Correct' if label == pred_label else 'Incorrect'
            full_img_name = '{}_{:.4}_{:.4}_Series_{:03d}.png'.format(
                img_name, aa, bb, cur_series+1)
            plot_probability(new_t, probs, legend, title, result, count_text,
                             full_img_name, thresh)

        pred_labels[thresh_ind] = pred_label
        # Index -1 (never met) selects the last grid time, giving 0. The
        # epsilon only matters for a single-point grid, whose numerator is 0.
        elapsed = new_t[ind_of_classification]
        denom = end_time if end_time >= 1e-6 else end_time + 1e-6
        earliness[thresh_ind] = (end_time - elapsed) / denom

    return label, pred_labels, earliness, thresh_met
def run_models(X, Y, T, models, A, B, class_weight, num_classes = 3,
               plot_graphs=False, plot_confusions=False, name='test',
               img_name = '', num_workers = 3, incr=0.05, guess_bool=False,
               guess_acc_bool = False):
    """Classify every series with ``propogate`` and aggregate with ``get_metrics``.

    Args:
        X, Y, T: per-series features, labels and timestamps, accessed in
            parallel as ``X[ii]``, ``Y[ii]``, ``T[ii]`` for ``ii`` in
            ``range(len(X))``.
        models: forwarded unchanged to ``propogate``.
        A, B: threshold parameter arrays (one pair per threshold setting).
        class_weight: per-class weights; a float copy is passed on so the
            caller's array is never modified.
        num_classes, plot_graphs, name, img_name, guess_bool: forwarded to
            ``propogate``.
        guess_acc_bool: if true, the fallback label is the class with the
            largest weight; otherwise it is 10.
        plot_confusions, num_workers, incr: accepted for API compatibility
            but not used here (NOTE(review): the ``Pool`` import suggests
            parallel evaluation was planned — confirm before removing).

    Returns:
        ``(acc, early_mats, thresh_met_mats)`` as produced by
        ``get_metrics(res, A.shape[0], num_classes)``.
    """
    class_weight = np.array(class_weight, dtype=float)
    # NOTE(review): 10 looks like a sentinel label outside the class range
    # used when no threshold is met — confirm against get_metrics.
    mode_label = np.argmax(class_weight) if guess_acc_bool else 10
    num_series = len(X)
    res = [propogate((x, Y[ii], T[ii], models, num_classes, A, B,
                      class_weight, img_name, name, ii, num_series,
                      plot_graphs, mode_label, guess_bool))
           for ii, x in enumerate(X)]
    acc, early_mats, thresh_met_mats = get_metrics(res, A.shape[0], num_classes)
    return acc, early_mats, thresh_met_mats