Skip to content

Commit 7ce2d20

Browse files
committed
[VCF] Add plot to FitResult and small changes in gev tests
1 parent 47ce317 commit 7ce2d20

File tree

2 files changed

+146
-5
lines changed

2 files changed

+146
-5
lines changed

bluemath_tk/distributions/_base_distributions.py

Lines changed: 142 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
from scipy.optimize import minimize
6+
import matplotlib.pyplot as plt
67

78
from ..core.models import BlueMathModel
89

@@ -53,6 +54,10 @@ def __init__(self, dist, data, res):
5354
self.nll = res.fun
5455
self.res = res
5556

57+
# Auxiliar for diagnostics plots
58+
self.n = self.data.shape[0]
59+
self.ecdf = np.arange(1, self.n + 1) / (self.n + 1)
60+
5661
def summary(self):
5762
"""
5863
Print a summary of the fitting results
@@ -66,11 +71,145 @@ def summary(self):
6671
print(f"Negative Log-Likelihood value: {self.nll:.4f}")
6772
print(f"{self.message}")
6873

69-
def plot(self, ax=None, plot_type="hist"):
74+
def plot(self, ax=None, plot_type="all"):
7075
"""
71-
Plots of fitting results
76+
Plots of fitting results: PP-plot, QQ-plot, histogram with fitted distribution, and return period plot.
77+
Parameters
78+
----------
79+
ax : matplotlib.axes.Axes, optional
80+
Axes to plot on. If None, a new figure and axes will be created.
81+
plot_type : str, optional
82+
Type of plot to create. Options are "hist" for histogram, "pp" for P-P plot,
83+
"qq" for Q-Q plot, "return_period" for return period plot, or "all" for all plots.
84+
Default is "all".
85+
86+
Returns
87+
-------
88+
fig : matplotlib.figure.Figure
89+
Figure object containing the plots. If `ax` is provided, returns None.
90+
91+
Raises
92+
-------
93+
ValueError
94+
If `plot_type` is not one of the valid options ("hist", "pp", "qq", "return_period", "all").
7295
"""
73-
pass
96+
if plot_type == "all":
97+
fig, axs = plt.subplots(2, 2, figsize=(12, 10))
98+
self.hist(ax=axs[0, 0])
99+
self.pp(ax=axs[0, 1])
100+
self.qq(ax=axs[1, 0])
101+
self.return_period(ax=axs[1, 1])
102+
plt.tight_layout()
103+
return fig
104+
elif plot_type == "hist":
105+
return self.hist()
106+
elif plot_type == "pp":
107+
return self.pp()
108+
elif plot_type == "qq":
109+
return self.qq()
110+
elif plot_type == "return_period":
111+
return self.return_period()
112+
else:
113+
raise ValueError("Invalid plot type. Use 'hist', 'pp', 'qq', 'return_period', or 'all'.")
114+
115+
def pp(self, ax=None):
116+
"""
117+
Probability plot of the fitted distribution.
118+
Parameters
119+
----------
120+
ax : matplotlib.axes.Axes, optional
121+
Axes to plot on. If None, a new figure and axes will be created.
122+
"""
123+
if ax is None:
124+
fig, ax = plt.subplots(figsize=(8, 6))
125+
else:
126+
fig = None
127+
128+
probabilities = self.dist.cdf(np.sort(self.data), *self.params)
129+
ax.plot([0, 1], [0, 1], color="tab:red", linestyle="--")
130+
ax.plot(probabilities, self.ecdf, color="tab:blue", marker="o", linestyle="", alpha=0.7)
131+
ax.set_xlabel("Fitted Probability")
132+
ax.set_ylabel("Empirical Probability")
133+
ax.set_title(f"PP Plot of {self.dist().name}")
134+
ax.grid()
135+
136+
return fig
137+
138+
def qq(self, ax=None):
139+
"""
140+
Quantile-Quantile plot of the fitted distribution.
141+
Parameters
142+
----------
143+
ax : matplotlib.axes.Axes, optional
144+
Axes to plot on. If None, a new figure and axes will be created.
145+
"""
146+
if ax is None:
147+
fig, ax = plt.subplots(figsize=(8, 6))
148+
else:
149+
fig = None
150+
151+
quantiles = self.dist.qf(self.ecdf, *self.params)
152+
ax.plot([np.min(self.data), np.max(self.data)], [np.min(self.data), np.max(self.data)], color="tab:red", linestyle="--")
153+
ax.plot(quantiles, np.sort(self.data), color="tab:blue", marker="o", linestyle="", alpha=0.7)
154+
ax.set_xlabel("Theoretical Quantiles")
155+
ax.set_ylabel("Sample Quantiles")
156+
ax.set_title(f"QQ Plot of {self.dist().name}")
157+
ax.grid()
158+
159+
return fig
160+
161+
def hist(self, ax=None):
162+
"""
163+
Histogram of the data with the fitted distribution overlayed.
164+
Parameters
165+
----------
166+
ax : matplotlib.axes.Axes, optional
167+
Axes to plot on. If None, a new figure and axes will be created.
168+
"""
169+
if ax is None:
170+
fig, ax = plt.subplots(figsize=(8, 6))
171+
else:
172+
fig = None
173+
174+
ax.hist(self.data, bins=30, density=True, alpha=0.7, color='tab:blue', label='Data Histogram')
175+
x = np.linspace(np.min(self.data), np.max(self.data), 1000)
176+
ax.plot(x, self.dist.pdf(x, *self.params), color='tab:red', label='Fitted PDF')
177+
ax.set_xlabel("Data Values")
178+
ax.set_ylabel("Density")
179+
ax.set_title(f"Histogram and Fitted PDF of {self.dist().name}")
180+
ax.legend()
181+
ax.grid()
182+
183+
return fig
184+
185+
def return_period(self, ax=None):
186+
"""
187+
Return period plot of the fitted distribution.
188+
Parameters
189+
----------
190+
ax : matplotlib.axes.Axes, optional
191+
Axes to plot on. If None, a new figure and axes will be created.
192+
"""
193+
if ax is None:
194+
fig, ax = plt.subplots(figsize=(8, 6))
195+
else:
196+
fig = None
197+
198+
199+
sorted_data = np.sort(self.data)
200+
exceedance_prob = 1 - self.ecdf
201+
return_period = 1 / exceedance_prob
202+
203+
ax.plot(return_period, self.dist.qf(self.ecdf, *self.params), color='tab:red', label='Fitted Distribution')
204+
ax.plot(return_period, sorted_data, marker="o", linestyle="", color="tab:blue", alpha=0.7, label='Empirical Data')
205+
ax.set_xscale("log")
206+
ax.set_xlabel("Return Period")
207+
ax.set_ylabel("Data Values")
208+
ax.set_title(f"Return Period Plot of {self.dist().name}")
209+
ax.legend()
210+
ax.grid()
211+
212+
return fig
74213

75214

76215
def fit_dist(dist, data: np.ndarray, **kwargs) -> FitResult:

tests/distributions/test_gev.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def test_invalid_scale(self):
105105
def test_fit(self):
106106
# Generate data using specific parameters
107107
loc, scale, shape = 0.5, 1.5, 0.2
108-
data = gev.random(1000, loc, scale, shape)
108+
data = gev.random(1000, loc, scale, shape, random_state=42)
109109

110110
# Fit the GEV distribution to the data
111111
fit_result = gev.fit(data)
@@ -118,7 +118,9 @@ def test_fit(self):
118118
self.assertIsInstance(fit_result.nll, float)
119119

120120
# Verify that the fitted parameters are close to the original ones
121-
np.testing.assert_allclose(fit_result.params, [loc, scale, shape], rtol=0.25)
121+
self.assertAlmostEqual(fit_result.params[0], loc, delta=0.1)
122+
self.assertAlmostEqual(fit_result.params[1], scale, delta=0.1)
123+
self.assertAlmostEqual(fit_result.params[2], shape, delta=0.1)
122124

123125

124126
if __name__ == "__main__":

0 commit comments

Comments
 (0)