Skip to content

Commit 5f2fa49

Browse files
authored
Merge pull request #109 from GeoOcean/107-add-calvalwaves-tool
107 add calvalwaves tool
2 parents d309f4a + 4611f0c commit 5f2fa49

File tree

4 files changed

+1141
-17
lines changed

4 files changed

+1141
-17
lines changed

.github/workflows/python-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,6 @@ jobs:
4848
source /usr/share/miniconda/etc/profile.d/conda.sh
4949
conda activate bluemath-tk
5050
python -m unittest discover tests/datamining/
51-
python -m unittest discover tests/downloaders/
51+
python -m unittest discover tests/distributions/
5252
python -m unittest discover tests/interpolation/
5353
python -m unittest discover tests/wrappers/

bluemath_tk/core/decorators.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,3 +437,59 @@ def wrapper(
437437
)
438438

439439
return wrapper
440+
441+
442+
def validate_data_calval(func):
443+
"""
444+
Decorator to validate data in CalVal class fit method.
445+
446+
Parameters
447+
----------
448+
func : callable
449+
The function to be decorated
450+
451+
Returns
452+
-------
453+
callable
454+
The decorated function
455+
"""
456+
457+
@functools.wraps(func)
458+
def wrapper(
459+
self,
460+
data: pd.DataFrame,
461+
data_longitude: float,
462+
data_latitude: float,
463+
data_to_calibrate: pd.DataFrame,
464+
max_time_diff: int = 2,
465+
):
466+
if not isinstance(data, pd.DataFrame):
467+
raise TypeError("Data must be a pandas DataFrame")
468+
if not isinstance(data_longitude, float):
469+
raise TypeError("Longitude must be a float")
470+
data_longitude = data_longitude % 360
471+
if not isinstance(data_latitude, float):
472+
raise TypeError("Latitude must be a float")
473+
if not isinstance(data_to_calibrate, pd.DataFrame):
474+
raise TypeError("Data to calibrate must be a pandas DataFrame")
475+
if "LONGITUDE" not in data_to_calibrate.columns:
476+
raise ValueError(
477+
"Data to calibrate must contain a column named 'LONGITUDE'"
478+
)
479+
if "LATITUDE" not in data_to_calibrate.columns:
480+
raise ValueError("Data to calibrate must contain a column named 'LATITUDE'")
481+
if "Hs_CAL" not in data_to_calibrate.columns:
482+
raise ValueError("Data to calibrate must contain a column named 'Hs_CAL'")
483+
if not isinstance(max_time_diff, int) or max_time_diff <= 0:
484+
raise ValueError("Maximum time difference must be an integer and > 0")
485+
486+
return func(
487+
self,
488+
data,
489+
data_longitude,
490+
data_latitude,
491+
data_to_calibrate,
492+
max_time_diff,
493+
)
494+
495+
return wrapper

bluemath_tk/core/plotting/scatter.py

Lines changed: 209 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,236 @@
1-
from typing import List, Tuple
1+
from typing import List, Optional, Tuple
22

33
import numpy as np
44
import pandas as pd
55
from matplotlib.axes import Axes
66
from matplotlib.figure import Figure
7+
from scipy.stats import gaussian_kde, probplot
8+
from sklearn.metrics import mean_squared_error
79

810
from .base_plotting import DefaultStaticPlotting
911
from .colors import default_colors
1012

1113

14+
def rmse(pred: np.ndarray, tar: np.ndarray) -> float:
15+
"""
16+
Calculate the Root Mean Square Error between predicted and target values.
17+
18+
Parameters
19+
----------
20+
pred : np.ndarray
21+
Array of predicted values.
22+
tar : np.ndarray
23+
Array of target/actual values.
24+
25+
Returns
26+
-------
27+
float
28+
The Root Mean Square Error value.
29+
"""
30+
31+
if len(pred) != len(tar):
32+
raise ValueError("pred and tar must have the same length")
33+
34+
return np.sqrt(((pred - tar) ** 2).mean())
35+
36+
37+
def bias(pred: np.ndarray, tar: np.ndarray) -> float:
38+
"""
39+
Calculate the bias between predicted and target values.
40+
41+
Parameters
42+
----------
43+
pred : np.ndarray
44+
Array of predicted values.
45+
tar : np.ndarray
46+
Array of target/actual values.
47+
48+
Returns
49+
-------
50+
float
51+
The bias value (mean difference between predictions and targets).
52+
"""
53+
54+
if len(pred) != len(tar):
55+
raise ValueError("pred and tar must have the same length")
56+
57+
return sum(pred - tar) / len(pred)
58+
59+
60+
def si(pred: np.ndarray, tar: np.ndarray) -> float:
61+
"""
62+
Calculate the Scatter Index between predicted and target values.
63+
64+
Parameters
65+
----------
66+
pred : np.ndarray
67+
Array of predicted values.
68+
tar : np.ndarray
69+
Array of target/actual values.
70+
71+
Returns
72+
-------
73+
float
74+
The Scatter Index value.
75+
"""
76+
77+
if len(pred) != len(tar):
78+
raise ValueError("pred and tar must have the same length")
79+
80+
pred_mean = pred.mean()
81+
tar_mean = tar.mean()
82+
83+
return np.sqrt(sum(((pred - pred_mean) - (tar - tar_mean)) ** 2) / (sum(tar**2)))
84+
85+
86+
def density_scatter(
87+
x: np.ndarray, y: np.ndarray
88+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
89+
"""
90+
Compute a density scatter for two arrays using gaussian KDE.
91+
92+
Parameters
93+
----------
94+
x : np.ndarray
95+
X values for the scatter plot.
96+
y : np.ndarray
97+
Y values for the scatter plot.
98+
99+
Returns
100+
-------
101+
Tuple[np.ndarray, np.ndarray, np.ndarray]
102+
A tuple containing:
103+
- Sorted x values
104+
- Sorted y values
105+
- Density values corresponding to each point
106+
"""
107+
108+
if len(x) != len(y):
109+
raise ValueError("x and y must have the same length")
110+
111+
xy = np.vstack([x, y])
112+
z = gaussian_kde(xy)(xy)
113+
idx = z.argsort()
114+
x1, y1, z = x[idx], y[idx], z[idx]
115+
116+
return x1, y1, z
117+
118+
119+
def validation_scatter(
120+
axs: Axes,
121+
x: np.ndarray,
122+
y: np.ndarray,
123+
xlabel: str,
124+
ylabel: str,
125+
title: str,
126+
) -> None:
127+
"""
128+
Plot a density scatter and Q-Q plot for validation.
129+
130+
Parameters
131+
----------
132+
axs : Axes
133+
Matplotlib axes to plot on.
134+
x : np.ndarray
135+
X values for the scatter plot.
136+
y : np.ndarray
137+
Y values for the scatter plot.
138+
xlabel : str
139+
Label for the X-axis.
140+
ylabel : str
141+
Label for the Y-axis.
142+
title : str
143+
Title for the plot.
144+
"""
145+
146+
x2, y2, z = density_scatter(x, y)
147+
148+
# plot
149+
axs.scatter(x2, y2, c=z, s=5, cmap="rainbow")
150+
151+
# labels
152+
axs.set_xlabel(xlabel)
153+
axs.set_ylabel(ylabel)
154+
axs.set_title(title)
155+
156+
# axis limits
157+
maxt = np.ceil(max(max(x) + 0.1, max(y) + 0.1))
158+
axs.set_xlim(0, maxt)
159+
axs.set_ylim(0, maxt)
160+
axs.plot([0, maxt], [0, maxt], "-r")
161+
axs.set_xticks(np.linspace(0, maxt, 5))
162+
axs.set_yticks(np.linspace(0, maxt, 5))
163+
axs.set_aspect("equal")
164+
165+
# qq-plot
166+
xq = probplot(x, dist="norm")
167+
yq = probplot(y, dist="norm")
168+
axs.plot(xq[0][1], yq[0][1], "o", markersize=0.5, color="k", label="Q-Q plot")
169+
170+
# diagnostic errors
171+
props = dict(
172+
boxstyle="round", facecolor="w", edgecolor="grey", linewidth=0.8, alpha=0.5
173+
)
174+
mse = mean_squared_error(x2, y2)
175+
rmse_e = rmse(x2, y2)
176+
BIAS = bias(x2, y2)
177+
SI = si(x2, y2)
178+
label = "\n".join(
179+
(
180+
r"RMSE = %.2f" % (rmse_e,),
181+
r"mse = %.2f" % (mse,),
182+
r"BIAS = %.2f" % (BIAS,),
183+
R"SI = %.2f" % (SI,),
184+
)
185+
)
186+
axs.text(
187+
0.05,
188+
0.95,
189+
label,
190+
transform=axs.transAxes,
191+
fontsize=9,
192+
verticalalignment="top",
193+
bbox=props,
194+
)
195+
196+
12197
def plot_scatters_in_triangle(
13198
dataframes: List[pd.DataFrame],
14-
data_colors: List[str] = default_colors,
199+
data_colors: Optional[List[str]] = None,
15200
**kwargs,
16-
) -> Tuple[Figure, Axes]:
201+
) -> Tuple[Figure, np.ndarray]:
17202
"""
18-
Plot a scatter plot of the dataframes with axes in a triangle.
203+
Plot scatter plots of the dataframes with axes in a triangle arrangement.
19204
20205
Parameters
21206
----------
22207
dataframes : List[pd.DataFrame]
23-
List of dataframes to plot.
24-
data_colors : List[str], optional
25-
List of colors for the dataframes.
208+
List of dataframes to plot. Each dataframe should contain the same columns.
209+
data_colors : Optional[List[str]], optional
210+
List of colors for the dataframes. If None, uses default_colors.
26211
**kwargs : dict, optional
27-
Keyword arguments for the scatter plot. Will be passed to the
28-
DefaultStaticPlotting.plot_scatter method, which is the same
29-
as the one in matplotlib.pyplot.scatter.
30-
For example, to change the marker size, you can use:
31-
``plot_scatters_in_triangle(dataframes, s=10)``
212+
Additional keyword arguments for the scatter plot. These will be passed to
213+
matplotlib.pyplot.scatter. Common parameters include:
214+
- s : float, marker size
215+
- alpha : float, transparency
216+
- marker : str, marker style
32217
33218
Returns
34219
-------
35-
fig : Figure
36-
Figure object.
37-
axes : Axes
38-
Axes object.
220+
Tuple[Figure, np.ndarray]
221+
A tuple containing:
222+
- Figure object
223+
- 2D array of Axes objects
224+
225+
Raises
226+
------
227+
ValueError
228+
If the variables in the first dataframe are not present in all other dataframes.
39229
"""
40230

231+
if data_colors is None:
232+
data_colors = default_colors
233+
41234
# Get the number and names of variables from the first dataframe
42235
variables_names = list(dataframes[0].columns)
43236
num_variables = len(variables_names)

0 commit comments

Comments
 (0)