from sklearn.metrics import roc_curve, auc, average_precision_score

from tensorflow.keras.losses import binary_crossentropy, categorical_crossentropy, sparse_categorical_crossentropy
from tensorflow.keras.losses import LogCosh, CosineSimilarity, MSE, MAE, MAPE, Dice, log_cosh

from keras.saving import register_keras_serializable
# Loss identifiers that Keras resolves directly from their string name,
# so no custom callable needs to be registered for them.
# NOTE(review): the previous entry was 'log_cosh ' with a trailing space,
# which can never match a Keras identifier — normalized here.
STRING_METRICS = [
    'categorical_crossentropy', 'binary_crossentropy', 'mean_absolute_error', 'mae',
    'mean_squared_error', 'mse', 'cosine_similarity', 'log_cosh', 'sparse_categorical_crossentropy',
]
2120
2221
@@ -48,6 +47,7 @@ def weighted_crossentropy(weights, name='anonymous'):
4847 string_fxn += '\t return loss\n '
4948 exec (string_fxn , globals (), locals ())
5049 loss_fxn = eval (name + fxn_postfix , globals (), locals ())
50+ loss_fxn = register_keras_serializable ()(loss_fxn )
5151 return loss_fxn
5252
5353
@@ -109,39 +109,39 @@ def paired_angle_between_batches(tensors):
109109
def ignore_zeros_l2(y_true, y_pred):
    """Mean squared error with positions where y_true == 0 masked out.

    Masked positions become (0, 0) pairs and therefore contribute zero
    error (they still count toward the mean's denominator).
    """
    nonzero = K.not_equal(y_true, 0)
    mask = K.cast(nonzero, K.floatx())
    return MSE(mask * y_true, mask * y_pred)
113113
114114
def ignore_zeros_logcosh(y_true, y_pred):
    """Log-cosh loss with positions where y_true == 0 masked out.

    Masked positions become (0, 0) pairs and contribute zero error.
    Uses the functional ``log_cosh`` (element-wise over the last axis):
    ``LogCosh(y_true, y_pred)`` would pass the tensors to the class
    constructor instead of computing a loss.
    """
    mask = K.cast(K.not_equal(y_true, 0), K.floatx())
    return log_cosh(y_true * mask, y_pred * mask)
118118
119119
def sentinel_logcosh_loss(sentinel: float):
    """Build a log-cosh loss that ignores entries equal to ``sentinel``.

    Returns a closure usable with ``model.compile(loss=...)``; positions
    where y_true == sentinel are zeroed in both tensors and so contribute
    no error. Uses the functional ``log_cosh`` — calling the ``LogCosh``
    class with (y_true, y_pred) would hit its constructor, not compute a
    loss.
    """
    def ignore_sentinel_logcosh(y_true, y_pred):
        mask = K.cast(K.not_equal(y_true, sentinel), K.floatx())
        return log_cosh(y_true * mask, y_pred * mask)
    return ignore_sentinel_logcosh
125125
126126
def y_true_times_mse(y_true, y_pred):
    """Mean squared error up-weighted by the target value, floored at 1."""
    weight = K.maximum(y_true, 1.0)
    return weight * MSE(y_true, y_pred)
129129
130130
def mse_10x(y_true, y_pred):
    """Mean squared error scaled by a constant factor of 10."""
    scale = 10.0
    return scale * MSE(y_true, y_pred)
133133
134134
def y_true_squared_times_mse(y_true, y_pred):
    """MSE weighted by max(1 + y_true, 1)^2, emphasizing larger targets."""
    weight = K.maximum(1.0 + y_true, 1.0)
    return weight * weight * MSE(y_true, y_pred)
137137
138138
def y_true_cubed_times_mse(y_true, y_pred):
    """MSE weighted by max(y_true, 1)^3, strongly emphasizing large targets."""
    weight = K.maximum(y_true, 1.0)
    return weight * weight * weight * MSE(y_true, y_pred)
141141
142142
def y_true_squared_times_logcosh(y_true, y_pred):
    """Log-cosh loss weighted by max(1 + y_true, 1)^2.

    Uses the functional ``log_cosh``: calling the ``LogCosh`` class with
    (y_true, y_pred) would hit its constructor rather than compute a
    loss, and the class form reduces to a scalar, which would defeat the
    element-wise weighting applied here.
    """
    weight = K.maximum(1.0 + y_true, 1.0)
    return weight * weight * log_cosh(y_true, y_pred)
145145
146146
147147def two_batch_euclidean (tensors ):
@@ -265,6 +265,7 @@ def loss(y_true, y_pred):
265265 return loss
266266
def dice(y_true, y_pred):
    """Dice loss using Keras' built-in ``Dice``.

    The unreachable second ``return Dice(laplace_smoothing=1e-05).mean_loss(...)``
    is removed: that API belonged to ``neurite``, whose import was dropped,
    and Keras' ``Dice`` accepts no ``laplace_smoothing`` argument.
    """
    return Dice()(y_true, y_pred)
269270
270271def per_class_dice (labels ):
@@ -273,12 +274,13 @@ def per_class_dice(labels):
273274 label_idx = labels [label_key ]
274275 fxn_name = label_key .replace ('-' , '_' ).replace (' ' , '_' )
275276 string_fxn = 'def ' + fxn_name + '_dice(y_true, y_pred):\n '
276- string_fxn += '\t dice = Dice(laplace_smoothing=1e-05).dice (y_true, y_pred)\n '
277- string_fxn += '\t dice = K.mean(dice, axis=0)[' + str (label_idx )+ ']\n '
277+ string_fxn += '\t dice = tf.keras.losses. Dice() (y_true, y_pred)\n '
278+ # string_fxn += '\tdice = K.mean(dice, axis=0)['+str(label_idx)+']\n'
278279 string_fxn += '\t return dice'
279280
280281 exec (string_fxn )
281282 dice_fxn = eval (fxn_name + '_dice' )
283+ dice_fxn = register_keras_serializable ()(dice_fxn )
282284 dice_fxns .append (dice_fxn )
283285
284286 return dice_fxns
@@ -299,6 +301,7 @@ def per_class_recall(labels):
299301
300302 exec (string_fxn )
301303 recall_fxn = eval (fxn_name + '_recall' )
304+ recall_fxn = register_keras_serializable ()(recall_fxn )
302305 recall_fxns .append (recall_fxn )
303306
304307 return recall_fxns
@@ -317,6 +320,7 @@ def per_class_precision(labels):
317320
318321 exec (string_fxn )
319322 precision_fxn = eval (fxn_name + '_precision' )
323+ precision_fxn = register_keras_serializable ()(precision_fxn )
320324 precision_fxns .append (precision_fxn )
321325
322326 return precision_fxns
@@ -335,6 +339,7 @@ def per_class_recall_3d(labels):
335339
336340 exec (string_fxn )
337341 recall_fxn = eval (fxn_prefix + '_recall' )
342+ recall_fxn = register_keras_serializable ()(recall_fxn )
338343 recall_fxns .append (recall_fxn )
339344
340345 return recall_fxns
@@ -353,6 +358,7 @@ def per_class_precision_3d(labels):
353358
354359 exec (string_fxn )
355360 precision_fxn = eval (fxn_prefix + '_precision' )
361+ precision_fxn = register_keras_serializable ()(precision_fxn )
356362 precision_fxns .append (precision_fxn )
357363
358364 return precision_fxns
@@ -371,6 +377,7 @@ def per_class_recall_4d(labels):
371377
372378 exec (string_fxn )
373379 recall_fxn = eval (fxn_prefix + '_recall' )
380+ recall_fxn = register_keras_serializable ()(recall_fxn )
374381 recall_fxns .append (recall_fxn )
375382
376383 return recall_fxns
@@ -389,6 +396,8 @@ def per_class_precision_4d(labels):
389396
390397 exec (string_fxn )
391398 precision_fxn = eval (fxn_prefix + '_precision' )
399+ precision_fxn = register_keras_serializable ()(precision_fxn )
400+
392401 precision_fxns .append (precision_fxn )
393402
394403 return precision_fxns
@@ -407,6 +416,7 @@ def per_class_recall_5d(labels):
407416
408417 exec (string_fxn )
409418 recall_fxn = eval (fxn_prefix + '_recall' )
419+ recall_fxn = register_keras_serializable ()(recall_fxn )
410420 recall_fxns .append (recall_fxn )
411421
412422 return recall_fxns
@@ -425,6 +435,7 @@ def per_class_precision_5d(labels):
425435
426436 exec (string_fxn )
427437 precision_fxn = eval (fxn_prefix + '_precision' )
438+ precision_fxn = register_keras_serializable ()(precision_fxn )
428439 precision_fxns .append (precision_fxn )
429440
430441 return precision_fxns
@@ -449,15 +460,15 @@ def get_metric_dict(output_tensor_maps):
449460 elif tm .loss == 'binary_crossentropy' :
450461 losses .append (binary_crossentropy )
451462 elif tm .loss == 'mean_absolute_error' or tm .loss == 'mae' :
452- losses .append (mean_absolute_error )
463+ losses .append (MAE )
453464 elif tm .loss == 'mean_squared_error' or tm .loss == 'mse' :
454- losses .append (mean_squared_error )
465+ losses .append (MSE )
455466 elif tm .loss == 'cosine_similarity' :
456- losses .append (cosine_similarity )
457- elif tm .loss == 'logcosh ' :
458- losses .append (logcosh )
467+ losses .append (CosineSimilarity )
468+ elif tm .loss == 'log_cosh' :
469+ losses .append (LogCosh ())
459470 elif tm .loss == 'mape' :
460- losses .append (mean_absolute_percentage_error )
471+ losses .append (MAPE )
461472 elif hasattr (tm .loss , '__name__' ):
462473 metrics [tm .loss .__name__ ] = tm .loss
463474 losses .append (tm .loss )
@@ -857,4 +868,12 @@ def result(self):
857868 def reset_state (self ):
858869 # Reset the metric state variables
859870 self .total_ssim .assign (0.0 )
860- self .count .assign (0.0 )
871+ self .count .assign (0.0 )
872+
873+
874+ def _register_all (module_globals ):
875+ for name , obj in module_globals .items ():
876+ if callable (obj ) and not name .startswith ("_" ):
877+ module_globals [name ] = register_keras_serializable ()(obj )
878+
879+ _register_all (globals ())
0 commit comments