Skip to content

Commit 8a39178

Browse files
committed
implement MAPE for pytorch
1 parent 7d99801 commit 8a39178

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

pinnicle/utils/data_misfit.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ def mean_squared_log_error_tf(y_true, y_pred):
1313
""" use tensorflow function to compute mean squared log error
1414
"""
1515
return tf.keras.losses.MeanSquaredLogarithmicError()(y_true, y_pred)
16+
17+
def mean_absolute_percentage_error_tf(y_true, y_pred):
18+
"""Calculates the Mean Absolute Percentage Error (MAPE) in PyTorch.
19+
"""
20+
return tf.keras.losses.MAPE()(y_true, y_pred)
1621
#}}}
1722
# ---- jax {{{
1823
def surface_log_vel_misfit_jax(v_true, v_pred):
@@ -25,6 +30,13 @@ def mean_squared_log_error_jax(y_true, y_pred):
2530
""" use jax/numpy function to compute mean squared log error
2631
"""
2732
return bkd.reduce_mean(bkd.square(jax.numpy.log(y_true+1.0) - jax.numpy.log(y_pred+1.0)))
33+
def mean_absolute_percentage_error_jax(y_true, y_pred):
34+
"""Calculates the Mean Absolute Percentage Error (MAPE) in PyTorch.
35+
"""
36+
# Ensure y_true is not zero to avoid division by zero
37+
# Add a small epsilon to the denominator to prevent NaN values
38+
epsvel=2.220446049250313e-16
39+
return bkd.reduce_mean(jax.numpy.abs((y_true - y_pred) / (y_true + epsvel))) * 100
2840
#}}}
2941
# ---- pytorch {{{
3042
def surface_log_vel_misfit_pytorch(v_true, v_pred):
@@ -37,24 +49,35 @@ def mean_squared_log_error_pytorch(y_true, y_pred):
3749
""" use jax/numpy function to compute mean squared log error
3850
"""
3951
return bkd.reduce_mean(bkd.square(torch.log(y_true+1.0) - torch.log(y_pred+1.0)))
52+
53+
def mean_absolute_percentage_error_pytorch(y_true, y_pred):
54+
"""Calculates the Mean Absolute Percentage Error (MAPE) in PyTorch.
55+
"""
56+
# Ensure y_true is not zero to avoid division by zero
57+
# Add a small epsilon to the denominator to prevent NaN values
58+
epsvel=2.220446049250313e-16
59+
return bkd.reduce_mean(bkd.abs((y_true - y_pred) / (y_true + epsvel))) * 100
4060
#}}}
4161
# ---------------
4262
def loss_dict_tf():
4363
return {
4464
"VEL_LOG": surface_log_vel_misfit_tf,
45-
"MEAN_SQUARE_LOG": mean_squared_log_error_tf
65+
"MEAN_SQUARE_LOG": mean_squared_log_error_tf,
66+
"MAPE":mean_absolute_percentage_error_tf,
4667
}
4768

4869
def loss_dict_jax():
4970
return {
5071
"VEL_LOG": surface_log_vel_misfit_jax,
51-
"MEAN_SQUARE_LOG": mean_squared_log_error_jax
72+
"MEAN_SQUARE_LOG": mean_squared_log_error_jax,
73+
"MAPE":mean_absolute_percentage_error_jax,
5274
}
5375

5476
def loss_dict_pytorch():
5577
return {
5678
"VEL_LOG": surface_log_vel_misfit_pytorch,
57-
"MEAN_SQUARE_LOG": mean_squared_log_error_pytorch
79+
"MEAN_SQUARE_LOG": mean_squared_log_error_pytorch,
80+
"MAPE":mean_absolute_percentage_error_pytorch,
5881
}
5982

6083
if backend_name == "tensorflow":

tests/test_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,17 @@ def test_save_and_load_dict(tmp_path):
1818
def test_data_misfit():
1919
with pytest.raises(Exception):
2020
data_misfit.get("not defined")
21-
dde_loss = ["mean absolute error", "MAE", "mae", "mean squared error", "mse", "mean absolute percentage error", "MAPE", "mape", "mean l2 relative error", "softmax cross entropy", "zero"]
21+
dde_loss = ["mean absolute error", "MAE", "mae", "mean squared error", "mse", "mean l2 relative error", "softmax cross entropy", "zero"]
2222
for l in dde_loss:
2323
assert data_misfit.get(l) == l
2424

2525
def test_data_misfit_functions():
2626
assert data_misfit.get("VEL_LOG") != None
2727
assert data_misfit.get("MEAN_SQUARE_LOG") != None
28+
assert data_misfit.get("MAPE") != None
2829
assert data_misfit.get("VEL_LOG")(backend.as_tensor([1.0]),backend.as_tensor([1.0])) == 0.0
2930
assert data_misfit.get("MEAN_SQUARE_LOG")(backend.as_tensor([1.0]),backend.as_tensor([1.0])) == 0.0
31+
assert data_misfit.get("MAPE")(backend.as_tensor([1.0]),backend.as_tensor([1.0])) == 0.0
3032

3133
def test_loadmat():
3234
filename = "flightTracks.mat"

0 commit comments

Comments
 (0)