
Commit 7e18fc0

Merge pull request #55 from GeoOcean/feature/binwaves-kps
[JTH] pcs df option and binwaves kp extraction working
2 parents: deb4ba9 + fc32e63

8 files changed: +614 -88 lines changed

bluemath_tk/datamining/pca.py

Lines changed: 59 additions & 33 deletions
@@ -2,6 +2,7 @@
 
 import cartopy.crs as ccrs
 import numpy as np
+import pandas as pd
 import xarray as xr
 from sklearn.decomposition import PCA as PCA_
 from sklearn.decomposition import IncrementalPCA as IncrementalPCA_
@@ -75,6 +76,8 @@ class PCA(BaseReduction):
         The explained variance ratio.
     cumulative_explained_variance_ratio : np.ndarray
         The cumulative explained variance ratio.
+    pcs_df : pd.DataFrame
+        The fitting PCs (self.pcs) as a pandas DataFrame.
 
     Methods
     -------
@@ -97,33 +100,41 @@ class PCA(BaseReduction):
 
     Examples
     --------
-    >>> from bluemath_tk.core.data.sample_data import get_2d_dataset
-    >>> from bluemath_tk.datamining.pca import PCA
-    >>> ds = get_2d_dataset()
-    >>> pca = PCA(
-    ...     n_components=5,
-    ...     is_incremental=False,
-    ...     debug=True,
-    ... )
-    >>> pca.fit(
-    ...     data=ds,
-    ...     vars_to_stack=["X", "Y"],
-    ...     coords_to_stack=["coord1", "coord2"],
-    ...     pca_dim_for_rows="coord3",
-    ...     windows_in_pca_dim_for_rows={"X": [1, 2, 3]},
-    ...     value_to_replace_nans={"X": 0.0},
-    ...     nan_threshold_to_drop={"X": 0.95},
-    ... )
-    >>> pcs = pca.transform(
-    ...     data=ds,
-    ... )
-    >>> reconstructed_ds = pca.inverse_transform(PCs=pcs)
-    >>> eofs = pca.eofs
-    >>> explained_variance = pca.explained_variance
-    >>> explained_variance_ratio = pca.explained_variance_ratio
-    >>> cumulative_explained_variance_ratio = pca.cumulative_explained_variance_ratio
-    >>> # Save the full class in a pickle file
-    >>> pca.save_model("pca_model.pkl")
+    .. jupyter-execute::
+
+        from bluemath_tk.core.data.sample_data import get_2d_dataset
+        from bluemath_tk.datamining.pca import PCA
+
+        ds = get_2d_dataset()
+
+        pca = PCA(
+            n_components=5,
+            is_incremental=False,
+            debug=True,
+        )
+        pca.fit(
+            data=ds,
+            vars_to_stack=["X", "Y"],
+            coords_to_stack=["coord1", "coord2"],
+            pca_dim_for_rows="coord3",
+            windows_in_pca_dim_for_rows={"X": [1, 2, 3]},
+            value_to_replace_nans={"X": 0.0},
+            nan_threshold_to_drop={"X": 0.95},
+        )
+        pcs = pca.transform(
+            data=ds,
+        )
+        reconstructed_ds = pca.inverse_transform(PCs=pcs)
+        eofs = pca.eofs
+        explained_variance = pca.explained_variance
+        explained_variance_ratio = pca.explained_variance_ratio
+        cumulative_explained_variance_ratio = pca.cumulative_explained_variance_ratio
+
+        # Save the full class in a pickle file
+        pca.save_model("pca_model.pkl")
+
+        # Plot the calculated EOFs
+        pca.plot_eofs(vars_to_plot=["X", "Y"], num_eofs=3)
 
     References
     ----------
@@ -212,10 +223,10 @@ def __init__(
 
         # Exclude attributes from beign saved with pca.save_model()
         self._exclude_attributes = [
-            # "_data",
+            "_data",
             "_window_processed_data",
-            # "_stacked_data_matrix",
-            # "_standarized_stacked_data_matrix",
+            "_stacked_data_matrix",
+            "_standarized_stacked_data_matrix",
         ]
 
     @property
@@ -254,6 +265,19 @@ def explained_variance_ratio(self) -> np.ndarray:
     def cumulative_explained_variance_ratio(self) -> np.ndarray:
         return np.cumsum(self.explained_variance_ratio)
 
+    @property
+    def pcs_df(self) -> pd.DataFrame:
+        if self.pcs is not None:
+            return pd.DataFrame(
+                data=self.pcs["PCs"].values,
+                columns=[f"PC{i + 1}" for i in range(self.pca.n_components_)],
+                index=self.pcs[self.pca_dim_for_rows].values,
+            )
+        else:
+            raise PCAError(
+                "PCA model must be fitted and transformed before calling pcs_df"
+            )
+
     def _generate_stacked_data(self, data: xr.Dataset) -> np.ndarray:
         """
         Generate stacked data matrix.
@@ -606,7 +630,7 @@ def transform(self, data: xr.Dataset, after_fitting: bool = False) -> xr.Dataset
         transformed_data = self.pca.transform(X=processed_data)
 
         # Save the Principal Components (PCs) in an xr.Dataset
-        self.pcs = xr.Dataset(
+        pcs = xr.Dataset(
             {
                 "PCs": ((self.pca_dim_for_rows, "n_component"), transformed_data),
                 "stds": (("n_component",), np.std(transformed_data, axis=0)),
@@ -616,8 +640,10 @@ def transform(self, data: xr.Dataset, after_fitting: bool = False) -> xr.Dataset
                 "n_component": np.arange(self.pca.n_components_),
             },
         )
+        if after_fitting:
+            self.pcs = pcs.copy()
 
-        return self.pcs.copy()
+        return pcs
 
     def fit_transform(
         self,
@@ -736,7 +762,7 @@ def plot_pcs(self, num_pcs: int, pcs: xr.Dataset = None) -> None:
                    "No Principal Components (PCs) found. Please transform some data first."
                )
            self.logger.info("Using the Principal Components (PCs) from the class")
-           pcs = self.pcs
+           pcs = self.pcs.copy()
 
        _ = (
            pcs["PCs"]

bluemath_tk/wrappers/_base_wrappers.py

Lines changed: 1 addition & 1 deletion
@@ -422,7 +422,7 @@ def run_cases(
         cases_dir_to_run = copy.deepcopy(self.cases_dirs)
 
         if parallel:
-            num_threads = self.get_num_processors_available()
+            num_threads = 5  # self.get_num_processors_available()
             self.logger.debug(
                 f"Running cases in parallel with launcher={launcher}. Number of threads: {num_threads}."
             )
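
The hardcoded `num_threads = 5` in `run_cases` reads like a temporary cap for local testing rather than intended long-term behaviour. A hedged sketch of one way to keep the cap configurable instead of hardcoded; the `max_threads` argument and the `BLUEMATH_NUM_THREADS` variable are hypothetical and not part of the wrapper's API:

```python
import os


def resolve_num_threads(wrapper, max_threads=None):
    """Pick a thread count: explicit cap, environment override, or all available CPUs."""
    available = wrapper.get_num_processors_available()  # helper already used in run_cases
    if max_threads is not None:
        return min(max_threads, available)
    env_cap = os.environ.get("BLUEMATH_NUM_THREADS")  # hypothetical override
    if env_cap is not None:
        return min(int(env_cap), available)
    return available
```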

bluemath_tk/wrappers/swan/swan_example.py

Lines changed: 140 additions & 37 deletions
@@ -12,69 +12,173 @@
 from bluemath_tk.wrappers.swan.swan_wrapper import SwanModelWrapper
 
 example_directions = [
+    1.5,
+    4.5,
     7.5,
+    10.5,
+    13.5,
+    16.5,
+    19.5,
     22.5,
+    25.5,
+    28.5,
+    31.5,
+    34.5,
     37.5,
+    40.5,
+    43.5,
+    46.5,
+    49.5,
     52.5,
+    55.5,
+    58.5,
+    61.5,
+    64.5,
     67.5,
+    70.5,
+    73.5,
+    76.5,
+    79.5,
     82.5,
+    85.5,
+    88.5,
+    91.5,
+    94.5,
     97.5,
+    100.5,
+    103.5,
+    106.5,
+    109.5,
     112.5,
+    115.5,
+    118.5,
+    121.5,
+    124.5,
     127.5,
+    130.5,
+    133.5,
+    136.5,
+    139.5,
     142.5,
+    145.5,
+    148.5,
+    151.5,
+    154.5,
     157.5,
+    160.5,
+    163.5,
+    166.5,
+    169.5,
     172.5,
+    175.5,
+    178.5,
+    181.5,
+    184.5,
     187.5,
+    190.5,
+    193.5,
+    196.5,
+    199.5,
     202.5,
+    205.5,
+    208.5,
+    211.5,
+    214.5,
     217.5,
+    220.5,
+    223.5,
+    226.5,
+    229.5,
     232.5,
+    235.5,
+    238.5,
+    241.5,
+    244.5,
     247.5,
+    250.5,
+    253.5,
+    256.5,
+    259.5,
     262.5,
+    265.5,
+    268.5,
+    271.5,
+    274.5,
     277.5,
+    280.5,
+    283.5,
+    286.5,
+    289.5,
     292.5,
+    295.5,
+    298.5,
+    301.5,
+    304.5,
     307.5,
+    310.5,
+    313.5,
+    316.5,
+    319.5,
     322.5,
+    325.5,
+    328.5,
+    331.5,
+    334.5,
     337.5,
+    340.5,
+    343.5,
+    346.5,
+    349.5,
     352.5,
+    355.5,
+    358.5,
 ]
 example_frequencies = [
-    0.035,
-    0.0385,
-    0.042349998,
-    0.046585,
-    0.051243503,
-    0.05636785,
-    0.062004633,
-    0.068205096,
-    0.07502561,
-    0.082528174,
-    0.090780996,
-    0.099859096,
-    0.10984501,
-    0.120829515,
-    0.13291247,
-    0.14620373,
-    0.1608241,
-    0.17690653,
-    0.19459718,
-    0.21405691,
-    0.2354626,
-    0.25900885,
-    0.28490975,
-    0.31340075,
-    0.3447408,
-    0.37921488,
-    0.4171364,
-    0.45885003,
-    0.50473505,
+    0.03,
+    0.033,
+    0.0363,
+    0.0399,
+    0.0438,
+    0.0482,
+    0.053,
+    0.0582,
+    0.064,
+    0.0704,
+    0.0774,
+    0.0851,
+    0.0935,
+    0.1028,
+    0.1131,
+    0.1243,
+    0.1367,
+    0.1503,
+    0.1652,
+    0.1816,
+    0.1997,
+    0.2195,
+    0.2413,
+    0.2653,
+    0.2917,
+    0.3207,
+    0.3526,
+    0.3876,
+    0.4262,
+    0.4685,
+    0.5151,
+    0.5663,
+    0.6226,
+    0.6845,
+    0.7525,
+    0.8273,
+    0.9096,
+    1.0,
 ]
 
 
 class BinWavesWrapper(SwanModelWrapper):
     """ """
 
     def build_case(self, case_dir: str, case_context: dict):
-        self.logger.info(f"Saving spectrum for {case_dir}")
         input_spectrum = construct_partition(
             freq_name="jonswap",
             freq_kwargs={
@@ -105,13 +209,12 @@ def build_cases(self, mode="one_by_one"):
     templates_dir = (
         "/home/tausiaj/GitHub-GeoOcean/BlueMath/bluemath_tk/wrappers/swan/templates/"
     )
-    templates_name = ["input.swn", "depth_main.dat", "buoys.loc"]
-    output_dir = "/home/tausiaj/GitHub-GeoOcean/BlueMath/test_cases/swan/javi/"
+    templates_name = ["input.swn", "depth_main_cantabria.dat", "buoys.loc"]
+    output_dir = "/home/tausiaj/GitHub-GeoOcean/BlueMath/test_cases/swan/CAN/"
     # Load swan model parameters
     model_parameters = (
         xr.open_dataset("/home/tausiaj/GitHub-GeoOcean/BlueMath/test_data/subset.nc")
         .to_dataframe()
-        .iloc[::60]
         .to_dict(orient="list")
     )
     # Create an instance of the SWAN model wrapper
@@ -126,11 +229,11 @@ def build_cases(self, mode="one_by_one"):
     # List available launchers
     print(swan_wrapper.list_available_launchers())
     # Run the model
-    swan_wrapper.run_cases(launcher="docker", parallel=True)
+    # swan_wrapper.run_cases(launcher="docker", parallel=True)
     # Post-process the output files
-    postprocessed_ds = swan_wrapper.postprocess_cases()
-    postprocessed_ds.to_netcdf(op.join(swan_wrapper.output_dir, "waves_part.nc"))
-    print(postprocessed_ds)
+    # postprocessed_ds = swan_wrapper.postprocess_cases()
+    # postprocessed_ds.to_netcdf(op.join(swan_wrapper.output_dir, "waves_part.nc"))
+    # print(postprocessed_ds)
     # # Load spectra example
     # spectra = xr.open_dataset(
     #     "/home/tausiaj/GitHub-GeoOcean/BlueMath/test_data/Waves_Cantabria_356.08_43.82.nc"
