Skip to content

Commit ad7e097

Browse files
committed
[JTH] add pca and rbf optimized wit threads and working
1 parent 49a7bfa commit ad7e097

File tree

5 files changed

+157
-70
lines changed

5 files changed

+157
-70
lines changed

bluemath_tk/core/decorators.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -226,18 +226,11 @@ def wrapper(
226226
"PCA dimension for rows must be a string and found in the data dimensions"
227227
)
228228
for variable, windows in windows_in_pca_dim_for_rows.items():
229-
if variable not in vars_to_stack:
230-
raise ValueError(f"Variable {variable} not found in vars_to_stack")
231229
if not isinstance(windows, list):
232230
raise TypeError("Windows must be a list")
233231
if not all([isinstance(window, int) and window > 0 for window in windows]):
234232
raise ValueError("Windows must be a list of integers > 0")
235-
for variable, _ in value_to_replace_nans.items():
236-
if variable not in vars_to_stack:
237-
raise ValueError(f"Variable {variable} not found in vars_to_stack")
238233
for variable, threshold in nan_threshold_to_drop.items():
239-
if variable not in vars_to_stack:
240-
raise ValueError(f"Variable {variable} not found in vars_to_stack")
241234
if not isinstance(threshold, float) or threshold < 0 or threshold > 1:
242235
raise ValueError("Threshold must be a float between 0 and 1")
243236
return func(

bluemath_tk/datamining/pca.py

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,20 @@ class PCA(BaseReduction):
3434
The PCA or Incremental PCA model.
3535
is_fitted : bool
3636
Indicates whether the PCA model has been fitted.
37-
_data : xr.Dataset
37+
data : xr.Dataset
3838
The original dataset.
39-
_postprocessed_data : xr.Dataset
40-
The postprocessed dataset.
41-
_stacked_data_matrix : np.ndarray
39+
window_processed_data : xr.Dataset
40+
The windows processed dataset.
41+
stacked_data_matrix : np.ndarray
4242
The stacked data matrix.
43-
_standarized_stacked_data_matrix : np.ndarray
43+
standarized_stacked_data_matrix : np.ndarray
4444
The standardized stacked data matrix.
4545
scaler : StandardScaler
4646
The scaler used for standardizing the data.
4747
vars_to_stack : List[str]
4848
The list of variables to stack.
49+
window_stacked_vars : List[str]
50+
The list of variables with windows.
4951
coords_to_stack : List[str]
5052
The list of coordinates to stack.
5153
pca_dim_for_rows : str
@@ -93,12 +95,19 @@ class PCA(BaseReduction):
9395
>>> from bluemath_tk.core.data.sample_data import get_2d_dataset
9496
>>> from bluemath_tk.datamining.pca import PCA
9597
>>> ds = get_2d_dataset()
96-
>>> pca = PCA(n_components=5)
98+
>>> pca = PCA(
99+
... n_components=5,
100+
... is_incremental=False,
101+
... debug=True,
102+
... )
97103
>>> pca.fit(
98104
... data=ds,
99105
... vars_to_stack=["X", "Y"],
100106
... coords_to_stack=["coord1", "coord2"],
101107
... pca_dim_for_rows="coord3",
108+
... windows_in_pca_dim_for_rows={"X": [1, 2, 3]},
109+
... value_to_replace_nans={"X": 0.0},
110+
... nan_threshold_to_drop={"X": 0.95},
102111
... )
103112
>>> pcs = pca.transform(
104113
... data=ds,
@@ -108,6 +117,8 @@ class PCA(BaseReduction):
108117
>>> explained_variance = pca.explained_variance
109118
>>> explained_variance_ratio = pca.explained_variance_ratio
110119
>>> cumulative_explained_variance_ratio = pca.cumulative_explained_variance_ratio
120+
>>> # Save the full class in a pickle file
121+
>>> pca.save_model("pca_model.pkl")
111122
112123
References
113124
----------
@@ -260,18 +271,31 @@ def _generate_stacked_data(self, data: xr.Dataset) -> np.ndarray:
260271
cleaned_vars_to_stack = []
261272
for var_to_clean in self.window_stacked_vars:
262273
var_to_clean_values = tmp_stacked_data[var_to_clean].values
274+
# Drop variables with more than 90% of NaNs if not specified
275+
var_to_clean_threshold = self.nan_threshold_to_drop.get(
276+
var_to_clean,
277+
self.nan_threshold_to_drop.get(
278+
var_to_clean[:-2],
279+
self.nan_threshold_to_drop.get(var_to_clean[:-3], 0.90),
280+
),
281+
)
263282
not_nan_positions = np.where(
264-
np.mean(np.isnan(var_to_clean_values), axis=0)
265-
< self.nan_threshold_to_drop.get(
266-
var_to_clean, 0.05
267-
) # TODO: Add to docstring
283+
np.mean(~np.isnan(var_to_clean_values), axis=0) > var_to_clean_threshold
268284
)[0]
285+
# Replace NaNs with the value specified in value_to_replace_nans
286+
# If not specified, try to get the value from the variable name, deleting window suffixes
287+
var_value_to_replace_nans = self.value_to_replace_nans.get(
288+
var_to_clean,
289+
self.value_to_replace_nans.get(
290+
var_to_clean[:-2], self.value_to_replace_nans.get(var_to_clean[:-3])
291+
),
292+
)
269293
self.logger.debug(
270-
f"Replacing NaNs for variable: {var_to_clean} with value: {self.value_to_replace_nans.get(var_to_clean)}"
294+
f"Replacing NaNs for variable: {var_to_clean} with value: {var_value_to_replace_nans}"
271295
)
272296
cleaned_var = self.check_nans(
273297
data=var_to_clean_values[:, not_nan_positions],
274-
replace_value=self.value_to_replace_nans.get(var_to_clean),
298+
replace_value=var_value_to_replace_nans,
275299
)
276300
cleaned_vars_to_stack.append(cleaned_var)
277301
self.not_nan_positions[var_to_clean] = not_nan_positions
@@ -496,8 +520,15 @@ def fit(
496520
The value to replace NaNs for each variable. Default is {}.
497521
nan_threshold_to_drop : dict, optional
498522
The threshold percentage to drop NaNs for each variable.
499-
By default, variables with more than 95% of NaNs are dropped.
523+
By default, variables with more than 90% of NaNs are dropped.
500524
Default is {}.
525+
526+
Notes
527+
-----
528+
For both value_to_replace_nans and nan_threshold_to_drop, the keys are the variables,
529+
and the suffixes for the windows are considered.
530+
Example: if you have variable "X", and apply windows [1, 2, 3], you can use "X_1", "X_2", "X_3".
531+
Nevertheless, you can also use the original variable name "X" to apply the same value for all windows.
501532
"""
502533

503534
self.vars_to_stack = vars_to_stack.copy()
@@ -585,13 +616,20 @@ def fit_transform(
585616
The value to replace NaNs for each variable. Default is {}.
586617
nan_threshold_to_drop : dict, optional
587618
The threshold percentage to drop NaNs for each variable.
588-
By default, variables with more than 95% of NaNs are dropped.
619+
By default, variables with more than 90% of NaNs are dropped.
589620
Default is {}.
590621
591622
Returns
592623
-------
593624
xr.Dataset
594-
The transformed data.
625+
The transformed data representing the Principal Components (PCs).
626+
627+
Notes
628+
-----
629+
For both value_to_replace_nans and nan_threshold_to_drop, the keys are the variables,
630+
and the suffixes for the windows are considered.
631+
Example: if you have variable "X", and apply windows [1, 2, 3], you can use "X_1", "X_2", "X_3".
632+
Nevertheless, you can also use the original variable name "X" to apply the same value for all windows.
595633
"""
596634

597635
self.fit(
@@ -626,6 +664,8 @@ def inverse_transform(self, PCs: Union[np.ndarray, xr.Dataset]) -> xr.Dataset:
626664

627665
if isinstance(PCs, xr.Dataset):
628666
X = PCs["PCs"].values
667+
elif isinstance(PCs, xr.DataArray):
668+
X = PCs.values
629669
elif isinstance(PCs, np.ndarray):
630670
X = PCs
631671

bluemath_tk/interpolation/rbf.py

Lines changed: 94 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -669,14 +669,63 @@ def _calc_opt_sigma(
669669

670670
return rbf_coeff, opt_sigma
671671

672-
def _rbf_interpolate(self, dataset: pd.DataFrame) -> pd.DataFrame:
672+
def _rbf_variable_interpolation(
673+
self,
674+
normalized_dataset: pd.DataFrame,
675+
opt_sigma: float,
676+
rbf_coeff: np.ndarray,
677+
num_points_subset: int,
678+
num_vars_subset: int,
679+
) -> np.ndarray:
680+
"""
681+
Interpolates the surface for a variable.
682+
683+
normalized_dataset : pd.DataFrame
684+
The normalized dataset.
685+
opt_sigma : float
686+
The optimal sigma calculated for variable.
687+
rbf_coeff : np.ndarray
688+
The fitted coefficients for variable.
689+
num_points_subset : int
690+
The number of points used in the fitting.
691+
num_vars_subset : int
692+
The number of variables used in the fitting.
693+
694+
np.ndarray
695+
The interpolated variable.
696+
"""
697+
698+
r = np.linalg.norm(
699+
normalized_dataset.values[:, np.newaxis, :]
700+
- self.normalized_subset_data.values[np.newaxis, :, :],
701+
axis=2,
702+
)
703+
kernel_values = self.kernel_func(r, opt_sigma)
704+
linear_part = np.dot(
705+
normalized_dataset.values,
706+
rbf_coeff[
707+
num_points_subset + 1 : num_points_subset + 1 + num_vars_subset
708+
].T,
709+
)
710+
711+
return (
712+
rbf_coeff[num_points_subset]
713+
+ np.dot(kernel_values, rbf_coeff[:num_points_subset])
714+
+ linear_part
715+
)
716+
717+
def _rbf_interpolate(
718+
self, dataset: pd.DataFrame, num_threads: int = None
719+
) -> pd.DataFrame:
673720
"""
674721
This function interpolates the dataset.
675722
676723
Parameters
677724
----------
678725
dataset : pd.DataFrame
679726
The dataset to interpolate (must have same variables as subset).
727+
num_threads : int, optional
728+
The number of threads to use for the interpolation. Default is None.
680729
681730
Returns
682731
-------
@@ -698,28 +747,42 @@ def _rbf_interpolate(self, dataset: pd.DataFrame) -> pd.DataFrame:
698747
)
699748

700749
# Loop through the target variables
701-
for i_var, target_var in enumerate(self.target_processed_variables):
702-
self.logger.info(f"Interpolating target variable {target_var}")
703-
rbf_coeff = self._rbf_coeffs[target_var].values
704-
opt_sigma = self._opt_sigmas[target_var]
705-
r = np.linalg.norm(
706-
normalized_dataset.values[:, np.newaxis, :]
707-
- self.normalized_subset_data.values[np.newaxis, :, :],
708-
axis=2,
709-
)
710-
kernel_values = self.kernel_func(r, opt_sigma)
711-
linear_part = np.dot(
712-
normalized_dataset.values,
713-
rbf_coeff[
714-
num_points_subset + 1 : num_points_subset + 1 + num_vars_subset
715-
].T,
716-
)
717-
s = (
718-
rbf_coeff[num_points_subset]
719-
+ np.dot(kernel_values, rbf_coeff[:num_points_subset])
720-
+ linear_part
721-
)
722-
interpolated_array[:, i_var] = s
750+
if num_threads is not None:
751+
# self.set_num_processors_to_use(num_processors=num_threads)
752+
num_threads = min(num_threads, self.get_num_processors_available())
753+
self.logger.info(f"Using {num_threads} threads for interpolation.")
754+
with ThreadPoolExecutor(max_workers=num_threads) as executor:
755+
rbf_variable_calculation = {
756+
executor.submit(
757+
self._rbf_variable_interpolation,
758+
normalized_dataset,
759+
self._opt_sigmas[target_var],
760+
self._rbf_coeffs[target_var].values,
761+
num_points_subset,
762+
num_vars_subset,
763+
): (i_var, target_var)
764+
for i_var, target_var in enumerate(self.target_processed_variables)
765+
}
766+
for future in as_completed(rbf_variable_calculation):
767+
i_rbf_var, rbf_variable = rbf_variable_calculation[future]
768+
try:
769+
interpolated_var = future.result()
770+
interpolated_array[:, i_rbf_var] = interpolated_array
771+
except Exception as exc:
772+
self.logger.error(
773+
f"Job for {rbf_variable} generated an exception: {exc}."
774+
)
775+
else:
776+
for i_var, target_var in enumerate(self.target_processed_variables):
777+
self.logger.info(f"Interpolating target variable {target_var}")
778+
interpolated_var = self._rbf_variable_interpolation(
779+
normalized_dataset=normalized_dataset,
780+
opt_sigma=self._opt_sigmas[target_var],
781+
rbf_coeff=self._rbf_coeffs[target_var].values,
782+
num_points_subset=num_points_subset,
783+
num_vars_subset=num_vars_subset,
784+
)
785+
interpolated_array[:, i_var] = interpolated_var
723786

724787
return pd.DataFrame(interpolated_array, columns=self.target_processed_variables)
725788

@@ -826,14 +889,16 @@ def fit(
826889
# Set the is_fitted attribute to True
827890
self.is_fitted = True
828891

829-
def predict(self, dataset: pd.DataFrame) -> pd.DataFrame:
892+
def predict(self, dataset: pd.DataFrame, num_threads: int = None) -> pd.DataFrame:
830893
"""
831894
Predicts the data for the provided dataset.
832895
833896
Parameters
834897
----------
835898
dataset : pd.DataFrame
836899
The dataset to predict (must have same variables than subset).
900+
num_threads : int, optional
901+
The number of threads to use for the interpolation. Default is None.
837902
838903
Returns
839904
-------
@@ -857,7 +922,9 @@ def predict(self, dataset: pd.DataFrame) -> pd.DataFrame:
857922
raise RBFError("RBF model must be fitted before predicting.")
858923

859924
self.logger.info("Reconstructing data using fitted coefficients.")
860-
interpolated_target = self._rbf_interpolate(dataset=dataset)
925+
interpolated_target = self._rbf_interpolate(
926+
dataset=dataset, num_threads=num_threads
927+
)
861928
if self.is_target_normalized:
862929
self.logger.info("Denormalizing target data")
863930
interpolated_target = self.denormalize(
@@ -870,6 +937,7 @@ def predict(self, dataset: pd.DataFrame) -> pd.DataFrame:
870937
xu=interpolated_target[f"{directional_variable}_u"].values,
871938
xv=interpolated_target[f"{directional_variable}_v"].values,
872939
)
940+
873941
return interpolated_target
874942

875943
def fit_predict(
@@ -933,4 +1001,4 @@ def fit_predict(
9331001
iteratively_update_sigma=iteratively_update_sigma,
9341002
)
9351003

936-
return self.predict(dataset=dataset)
1004+
return self.predict(dataset=dataset, num_threads=num_threads)

tests/datamining/test_pca.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,38 +6,23 @@
66
class TestPCA(unittest.TestCase):
77
def setUp(self):
88
self.ds = get_2d_dataset()
9-
self.pca = PCA(n_components=5)
9+
self.pca = PCA(n_components=5, debug=True)
1010
self.ipca = PCA(n_components=5, is_incremental=True)
1111

12-
def test_fit(self):
13-
self.pca.fit(
14-
data=self.ds,
15-
vars_to_stack=["X", "Y"],
16-
coords_to_stack=["coord1", "coord2"],
17-
pca_dim_for_rows="coord3",
18-
)
19-
self.assertEqual(self.pca.is_fitted, True)
20-
21-
def test_transform(self):
22-
self.pca.fit(
23-
data=self.ds,
24-
vars_to_stack=["X", "Y"],
25-
coords_to_stack=["coord1", "coord2"],
26-
pca_dim_for_rows="coord3",
27-
)
28-
pcs = self.pca.transform(
29-
data=self.ds,
30-
)
31-
self.assertEqual(pcs.PCs.shape[1], 5)
32-
3312
def test_fit_transform(self):
3413
pcs = self.pca.fit_transform(
3514
data=self.ds,
3615
vars_to_stack=["X", "Y"],
3716
coords_to_stack=["coord1", "coord2"],
3817
pca_dim_for_rows="coord3",
18+
windows_in_pca_dim_for_rows={"X": [3], "Y": [1]},
19+
value_to_replace_nans={"X": 0.0, "X_3": 1.0, "Y": 0.0},
20+
nan_threshold_to_drop={"X": 0.5, "Y": 0.5},
3921
)
22+
self.assertEqual(self.pca.is_fitted, True)
4023
self.assertEqual(pcs.PCs.shape[1], 5)
24+
self.assertEqual(pcs.PCs.shape[0], self.ds.sizes["coord3"])
25+
self.assertCountEqual(self.pca.eofs.data_vars, ["X", "X_3", "Y", "Y_1"])
4126

4227
def test_inverse_transform(self):
4328
pcs = self.pca.fit_transform(

tests/interpolation/test_rbf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def test_fit(self):
2929
target_data=self.target,
3030
target_directional_variables=["DirPred"],
3131
normalize_target_data=True,
32+
num_threads=4,
3233
)
3334
self.assertTrue(self.rbf.is_fitted)
3435
self.assertTrue(self.rbf.is_target_normalized)

0 commit comments

Comments
 (0)