Skip to content

Commit 9934376

Browse files
committed
[VCF] Stable version NonStatGEV. Vectorize parametro function and remove _search function.
1 parent 24a1749 commit 9934376

File tree

1 file changed

+54
-67
lines changed

1 file changed

+54
-67
lines changed

bluemath_tk/distributions/nonstat_gev.py

Lines changed: 54 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -13,36 +13,36 @@
1313
from numba import njit, prange
1414

1515

16-
@njit(fastmath=True)
17-
def search(times: np.ndarray, values: np.ndarray, xs) -> np.ndarray:
18-
"""
19-
Function to search the nearest value of certain time to use in self._parametro function
20-
21-
Parameters
22-
----------
23-
times : np.ndarray
24-
Times when covariates are known
25-
values : np.ndarray
26-
Values of the covariates at those times
27-
"""
28-
# n = times.shape[0]
29-
# yin = np.zeros_like(xs)
30-
# pos = 0
31-
# for j in range(xs.size):
32-
# found = 0
33-
# while found == 0 and pos < n:
34-
# if xs[j] < times[pos]:
35-
# yin[j] = values[pos]
36-
# found = 1
37-
# else:
38-
# pos += 1
39-
40-
# return yin
41-
42-
idx = np.searchsorted(times, xs, side='right')
43-
mask = idx < len(times)
44-
45-
return values[idx[mask]]
16+
# @njit(fastmath=True)
17+
# def search(times: np.ndarray, values: np.ndarray, xs) -> np.ndarray:
18+
# """
19+
# Function to search the nearest value of certain time to use in self._parametro function
20+
21+
# Parameters
22+
# ----------
23+
# times : np.ndarray
24+
# Times when covariates are known
25+
# values : np.ndarray
26+
# Values of the covariates at those times
27+
# """
28+
# # n = times.shape[0]
29+
# # yin = np.zeros_like(xs)
30+
# # pos = 0
31+
# # for j in range(xs.size):
32+
# # found = 0
33+
# # while found == 0 and pos < n:
34+
# # if xs[j] < times[pos]:
35+
# # yin[j] = values[pos]
36+
# # found = 1
37+
# # else:
38+
# # pos += 1
39+
40+
# # return yin
41+
42+
# idx = np.searchsorted(times, xs, side='right')
43+
# mask = idx < len(times)
44+
45+
# return values[idx[mask]]
4646

4747

4848
class NonStatGEV(BlueMathModel):
@@ -4617,7 +4617,7 @@ def parametro(
46174617
indicesint : np.ndarray, optional
46184618
Covariate mean values in the integral interval
46194619
times : np.ndarray, optional
4620-
Times when covariates are known, used to find the nearest value using self._search
4620+
Times when covariates are known, used to find the nearest value
46214621
t : np.ndarray, optional
46224622
Specific time point to evaluate the parameters at, if None, uses the times given
46234623
@@ -4658,45 +4658,32 @@ def parametro(
46584658
if nind > 0:
46594659
if indicesint.shape[0] > 0:
46604660
if times.shape[0] == 0:
4661-
for i in prange(nind):
4662-
y = y + beta_cov[i] * indicesint[i]
4661+
# for i in prange(nind):
4662+
# y = y + beta_cov[i] * indicesint[i]
4663+
y = y + indicesint @ beta_cov
46634664
else:
4664-
for i in prange(nind):
4665-
indicesintaux = search(
4666-
times, covariates[:, i], x.flatten()
4667-
)
4668-
y = y + beta_cov[i] * indicesintaux
4665+
# for i in prange(nind):
4666+
# indicesintaux = search(
4667+
# times, covariates[:, i], x.flatten()
4668+
# )
4669+
# y = y + beta_cov[i] * indicesintaux
4670+
idx = np.searchsorted(times, x, side='right')
4671+
valid = idx < times.size
4672+
4673+
y_add = np.zeros_like(x, dtype=np.float64)
4674+
if np.any(valid):
4675+
# pick rows from covariates and do one matvec
4676+
A = covariates[idx[valid], :] # (k, nind)
4677+
y_add[valid] = A @ beta_cov # (k,)
4678+
4679+
y = y + y_add
46694680
else:
4670-
for i in prange(nind):
4671-
y = y + beta_cov[i] * covariates[:, i]
4681+
# for i in prange(nind):
4682+
# y = y + beta_cov[i] * covariates[:, i]
4683+
y = y + covariates @ beta_cov
46724684

4673-
return y
4674-
4675-
def _search(self, times: np.ndarray, values: np.ndarray, xs) -> np.ndarray:
4676-
"""
4677-
Function to search the nearest value of certain time to use in self._parametro function
4678-
4679-
Parameters
4680-
----------
4681-
times : np.ndarray
4682-
Times when covariates are known
4683-
values : np.ndarray
4684-
Values of the covariates at those times
4685-
"""
4686-
n = times.shape[0]
4687-
yin = np.zeros_like(xs)
4688-
pos = 0
4689-
for j in range(xs.size):
4690-
found = 0
4691-
while found == 0 and pos < n:
4692-
if xs[j] < times[pos]:
4693-
yin[j] = values[pos]
4694-
found = 1
4695-
else:
4696-
pos += 1
4697-
4698-
return yin
46994685

4686+
return y
47004687

47014688
def _evaluate_params(
47024689
self,
@@ -5170,7 +5157,7 @@ def plot(self, return_plot: bool = False, save: bool = False, init_year: int = 0
51705157
alpha=0.9,
51715158
)
51725159

5173-
# TODO: Add aggregated return period lines
5160+
# Aggregated return period lines
51745161
# n_years = int(np.ceil(self.t[-1]))
51755162
# rt_10 = np.zeros(n_years)
51765163
# for year in range(n_years):

0 commit comments

Comments
 (0)