|
1 | 1 | from typing import Dict, List |
2 | 2 |
|
3 | 3 | import numpy as np |
4 | | -from scipy.special import gamma |
| 4 | +from scipy.stats import genextreme |
5 | 5 |
|
6 | 6 | from ._base_distributions import BaseDistribution, FitResult, fit_dist |
7 | 7 |
|
@@ -117,22 +117,7 @@ def pdf( |
117 | 117 | if scale <= 0: |
118 | 118 | raise ValueError("Scale parameter must be > 0") |
119 | 119 |
|
120 | | - y = (x - loc) / scale |
121 | | - |
122 | | - # Gumbel case (shape = 0) |
123 | | - if shape == 0.0: |
124 | | - pdf = (1 / scale) * (np.exp(-y) * np.exp(-np.exp(-y))) |
125 | | - |
126 | | - # General case (Weibull and Frechet, shape != 0) |
127 | | - else: |
128 | | - pdf = np.full_like(x, 0, dtype=float) # 0 |
129 | | - yy = 1 + shape * y |
130 | | - yymask = yy > 0 |
131 | | - pdf[yymask] = (1 / scale) * ( |
132 | | - yy[yymask] ** (-1 - (1 / shape)) * np.exp(-(yy[yymask] ** (-1 / shape))) |
133 | | - ) |
134 | | - |
135 | | - return pdf |
| 120 | + return genextreme.pdf(x, -shape, loc=loc, scale=scale) |
136 | 121 |
|
137 | 122 | @staticmethod |
138 | 123 | def cdf( |
@@ -167,17 +152,7 @@ def cdf( |
167 | 152 | if scale <= 0: |
168 | 153 | raise ValueError("Scale parameter must be > 0") |
169 | 154 |
|
170 | | - y = (x - loc) / scale |
171 | | - |
172 | | - # Gumbel case (shape = 0) |
173 | | - if shape == 0.0: |
174 | | - p = np.exp(-np.exp(-y)) |
175 | | - |
176 | | - # General case (Weibull and Frechet, shape != 0) |
177 | | - else: |
178 | | - p = np.exp(-(np.maximum(1 + shape * y, 0) ** (-1 / shape))) |
179 | | - |
180 | | - return p |
| 155 | + return genextreme.cdf(x, -shape, loc=loc, scale=scale) |
181 | 156 |
|
182 | 157 | @staticmethod |
183 | 158 | def sf( |
@@ -212,9 +187,7 @@ def sf( |
212 | 187 | if scale <= 0: |
213 | 188 | raise ValueError("Scale parameter must be > 0") |
214 | 189 |
|
215 | | - sp = 1 - GEV.cdf(x, loc=loc, scale=scale, shape=shape) |
216 | | - |
217 | | - return sp |
| 190 | + return genextreme.sf(x, -shape, loc=loc, scale=scale) |
218 | 191 |
|
219 | 192 | @staticmethod |
220 | 193 | def qf( |
@@ -255,15 +228,7 @@ def qf( |
255 | 228 | if scale <= 0: |
256 | 229 | raise ValueError("Scale parameter must be > 0") |
257 | 230 |
|
258 | | - # Gumbel case (shape = 0) |
259 | | - if shape == 0.0: |
260 | | - q = loc - scale * np.log(-np.log(p)) |
261 | | - |
262 | | - # General case (Weibull and Frechet, shape != 0) |
263 | | - else: |
264 | | - q = loc + scale * ((-np.log(p)) ** (-shape) - 1) / shape |
265 | | - |
266 | | - return q |
| 231 | + return genextreme.ppf(p, -shape, loc=loc, scale=scale) |
267 | 232 |
|
268 | 233 | @staticmethod |
269 | 234 | def nll( |
@@ -291,35 +256,12 @@ def nll( |
291 | 256 | """ |
292 | 257 |
|
293 | 258 | if scale <= 0: |
294 | | - nll = np.inf # Return a large value for invalid scale |
| 259 | + return np.inf # Return a large value for invalid scale |
295 | 260 |
|
296 | 261 | else: |
297 | | - y = (data - loc) / scale |
298 | | - |
299 | | - # # Gumbel case (shape = 0) |
300 | | - # if shape == 0.0: |
301 | | - # pass |
302 | | - # nll = data.shape[0] * np.log(scale) + np.sum( |
303 | | - # np.exp(-y) + np.sum(-y) |
304 | | - # ) # Gumbel case |
305 | | - |
306 | | - # # General case (Weibull and Frechet, shape != 0) |
307 | | - # else: |
308 | | - |
309 | | - shape = ( |
310 | | - np.maximum(shape, 1e-8) if shape > 0 else np.minimum(shape, -1e-8) |
311 | | - ) # Avoid division by zero |
312 | | - y = 1 + shape * y |
313 | | - if any(y <= 0): |
314 | | - nll = np.inf # Return a large value for invalid y |
315 | | - else: |
316 | | - nll = ( |
317 | | - data.shape[0] * np.log(scale) |
318 | | - + np.sum(y ** (-1 / shape)) |
319 | | - + (1 / shape + 1) * np.sum(np.log(y)) |
320 | | - ) |
321 | | - |
322 | | - return nll |
| 262 | + return -np.sum( |
| 263 | + genextreme.logpdf(data, -shape, loc=loc, scale=scale), axis=0 |
| 264 | + ) |
323 | 265 |
|
324 | 266 | @staticmethod |
325 | 267 | def fit(data: np.ndarray, **kwargs) -> FitResult: |
@@ -385,22 +327,9 @@ def random( |
385 | 327 | if scale <= 0: |
386 | 328 | raise ValueError("Scale parameter must be > 0") |
387 | 329 |
|
388 | | - # Set random state if provided |
389 | | - if random_state is not None: |
390 | | - np.random.seed(random_state) |
391 | | - |
392 | | - # Generate uniform random numbers |
393 | | - u = np.random.uniform(0, 1, size) |
394 | | - |
395 | | - # Gumbel case (shape = 0) |
396 | | - if shape == 0.0: |
397 | | - x = loc - scale * np.log(-np.log(u)) |
398 | | - |
399 | | - # General case (Weibull and Frechet, shape != 0) |
400 | | - else: |
401 | | - x = loc + scale * ((-np.log(u)) ** (-shape) - 1) / shape |
402 | | - |
403 | | - return x |
| 330 | + return genextreme.rvs( |
| 331 | + -shape, loc=loc, scale=scale, size=size, random_state=random_state |
| 332 | + ) |
404 | 333 |
|
405 | 334 | @staticmethod |
406 | 335 | def mean(loc: float = 0.0, scale: float = 1.0, shape: float = 0.0) -> float: |
@@ -431,21 +360,7 @@ def mean(loc: float = 0.0, scale: float = 1.0, shape: float = 0.0) -> float: |
431 | 360 | if scale <= 0: |
432 | 361 | raise ValueError("Scale parameter must be > 0") |
433 | 362 |
|
434 | | - eu_cons = np.euler_gamma # Euler-Mascheroni constant |
435 | | - |
436 | | - # Gumbel case (shape = 0) |
437 | | - if shape == 0.0: |
438 | | - mean = loc + scale * eu_cons |
439 | | - |
440 | | - # General case (Weibull and Frechet, shape != 0 and shape < 1) |
441 | | - elif shape != 0.0 and shape < 1: |
442 | | - mean = loc + scale * (gamma(1 - shape) - 1) / shape |
443 | | - |
444 | | - # Shape >= 1 case |
445 | | - else: |
446 | | - mean = np.inf |
447 | | - |
448 | | - return mean |
| 363 | + return genextreme.mean(-shape, loc=loc, scale=scale) |
449 | 364 |
|
450 | 365 | @staticmethod |
451 | 366 | def median(loc: float = 0.0, scale: float = 1.0, shape: float = 0.0) -> float: |
@@ -476,13 +391,7 @@ def median(loc: float = 0.0, scale: float = 1.0, shape: float = 0.0) -> float: |
476 | 391 | if scale <= 0: |
477 | 392 | raise ValueError("Scale parameter must be > 0") |
478 | 393 |
|
479 | | - if shape == 0.0: |
480 | | - median = loc - scale * np.log(np.log(2)) |
481 | | - |
482 | | - else: |
483 | | - median = loc + scale * ((np.log(2)) ** (-shape) - 1) / shape |
484 | | - |
485 | | - return median |
| 394 | + return genextreme.median(-shape, loc=loc, scale=scale) |
486 | 395 |
|
487 | 396 | @staticmethod |
488 | 397 | def variance(loc: float = 0.0, scale: float = 1.0, shape: float = 0.0) -> float: |
@@ -513,21 +422,7 @@ def variance(loc: float = 0.0, scale: float = 1.0, shape: float = 0.0) -> float: |
513 | 422 | if scale <= 0: |
514 | 423 | raise ValueError("Scale parameter must be > 0") |
515 | 424 |
|
516 | | - # Gumbel case (shape = 0) |
517 | | - if shape == 0.0: |
518 | | - var = (np.pi**2 / 6) * scale**2 |
519 | | - |
520 | | - elif shape != 0.0 and shape < 0.5: |
521 | | - var = ( |
522 | | - (scale**2) |
523 | | - * (gamma(1 - 2 * shape) - (gamma(1 - shape) ** 2)) |
524 | | - / (shape**2) |
525 | | - ) |
526 | | - |
527 | | - else: |
528 | | - var = np.inf |
529 | | - |
530 | | - return var |
| 425 | + return genextreme.var(-shape, loc=loc, scale=scale) |
531 | 426 |
|
532 | 427 | @staticmethod |
533 | 428 | def std(loc: float = 0.0, scale: float = 1.0, shape: float = 0.0) -> float: |
@@ -559,9 +454,7 @@ def std(loc: float = 0.0, scale: float = 1.0, shape: float = 0.0) -> float: |
559 | 454 | if scale <= 0: |
560 | 455 | raise ValueError("Scale parameter must be > 0") |
561 | 456 |
|
562 | | - std = np.sqrt(GEV.variance(loc, scale, shape)) |
563 | | - |
564 | | - return std |
| 457 | + return genextreme.std(-shape, loc=loc, scale=scale) |
565 | 458 |
|
566 | 459 | @staticmethod |
567 | 460 | def stats( |
|
0 commit comments