@@ -13,6 +13,11 @@ def mean_squared_log_error_tf(y_true, y_pred):
1313 """ use tensorflow function to compute mean squared log error
1414 """
1515 return tf .keras .losses .MeanSquaredLogarithmicError ()(y_true , y_pred )
16+
17+ def mean_absolute_percentage_error_tf (y_true , y_pred ):
18+ """Calculates the Mean Absolute Percentage Error (MAPE) in PyTorch.
19+ """
20+ return tf .keras .losses .MAPE ()(y_true , y_pred )
1621#}}}
1722# ---- jax {{{
1823def surface_log_vel_misfit_jax (v_true , v_pred ):
@@ -25,6 +30,13 @@ def mean_squared_log_error_jax(y_true, y_pred):
2530 """ use jax/numpy function to compute mean squared log error
2631 """
2732 return bkd .reduce_mean (bkd .square (jax .numpy .log (y_true + 1.0 ) - jax .numpy .log (y_pred + 1.0 )))
33+ def mean_absolute_percentage_error_jax (y_true , y_pred ):
34+ """Calculates the Mean Absolute Percentage Error (MAPE) in PyTorch.
35+ """
36+ # Ensure y_true is not zero to avoid division by zero
37+ # Add a small epsilon to the denominator to prevent NaN values
38+ epsvel = 2.220446049250313e-16
39+ return bkd .reduce_mean (jax .numpy .abs ((y_true - y_pred ) / (y_true + epsvel ))) * 100
2840#}}}
2941# ---- pytorch {{{
3042def surface_log_vel_misfit_pytorch (v_true , v_pred ):
@@ -37,24 +49,35 @@ def mean_squared_log_error_pytorch(y_true, y_pred):
3749 """ use jax/numpy function to compute mean squared log error
3850 """
3951 return bkd .reduce_mean (bkd .square (torch .log (y_true + 1.0 ) - torch .log (y_pred + 1.0 )))
52+
53+ def mean_absolute_percentage_error_pytorch (y_true , y_pred ):
54+ """Calculates the Mean Absolute Percentage Error (MAPE) in PyTorch.
55+ """
56+ # Ensure y_true is not zero to avoid division by zero
57+ # Add a small epsilon to the denominator to prevent NaN values
58+ epsvel = 2.220446049250313e-16
59+ return bkd .reduce_mean (bkd .abs ((y_true - y_pred ) / (y_true + epsvel ))) * 100
4060#}}}
4161# ---------------
4262def loss_dict_tf ():
4363 return {
4464 "VEL_LOG" : surface_log_vel_misfit_tf ,
45- "MEAN_SQUARE_LOG" : mean_squared_log_error_tf
65+ "MEAN_SQUARE_LOG" : mean_squared_log_error_tf ,
66+ "MAPE" :mean_absolute_percentage_error_tf ,
4667 }
4768
4869def loss_dict_jax ():
4970 return {
5071 "VEL_LOG" : surface_log_vel_misfit_jax ,
51- "MEAN_SQUARE_LOG" : mean_squared_log_error_jax
72+ "MEAN_SQUARE_LOG" : mean_squared_log_error_jax ,
73+ "MAPE" :mean_absolute_percentage_error_jax ,
5274 }
5375
5476def loss_dict_pytorch ():
5577 return {
5678 "VEL_LOG" : surface_log_vel_misfit_pytorch ,
57- "MEAN_SQUARE_LOG" : mean_squared_log_error_pytorch
79+ "MEAN_SQUARE_LOG" : mean_squared_log_error_pytorch ,
80+ "MAPE" :mean_absolute_percentage_error_pytorch ,
5881 }
5982
6083if backend_name == "tensorflow" :
0 commit comments