Skip to content

Commit a72a9c7

Browse files
committed
[JTH] add different row dimesnsion recon in pca
1 parent ef3bfd6 commit a72a9c7

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

bluemath_tk/core/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class BlueMathModel(ABC):
3434
@abstractmethod
3535
def __init__(self) -> None:
3636
self._logger: logging.Logger = None
37-
self._exclude_attributes: List[str] = ["_logger"]
37+
self._exclude_attributes: List[str] = []
3838

3939
def __getstate__(self):
4040
"""Exclude certain attributes from being pickled."""

bluemath_tk/datamining/pca.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -482,13 +482,7 @@ def _reshape_data(self, X: np.ndarray, destandarize: bool = True) -> xr.Dataset:
482482
)
483483
}
484484

485-
return xr.Dataset(
486-
X_reshaped_vars_dict,
487-
coords={
488-
self.pca_dim_for_rows: self.data[self.pca_dim_for_rows],
489-
**{coord: self.data[coord] for coord in self.coords_to_stack},
490-
},
491-
)
485+
return X_reshaped_vars_dict
492486

493487
@validate_data_pca
494488
def fit(
@@ -671,7 +665,15 @@ def inverse_transform(self, PCs: Union[np.ndarray, xr.Dataset]) -> xr.Dataset:
671665

672666
self.logger.info("Inverse transforming data using PCA model")
673667
X_transformed = self.pca.inverse_transform(X=X)
674-
data_transformed = self._reshape_data(X=X_transformed, destandarize=True)
668+
data_reshaped_vars_dict = self._reshape_data(X=X_transformed, destandarize=True)
669+
# Create xarray Dataset with the transformed data
670+
data_transformed = xr.Dataset(
671+
data_reshaped_vars_dict,
672+
coords={
673+
self.pca_dim_for_rows: PCs[self.pca_dim_for_rows].values,
674+
**{coord: self.data[coord] for coord in self.coords_to_stack},
675+
},
676+
)
675677
# Transpose dimensions based on the original data
676678
data_transformed = data_transformed.transpose(*self.data.dims)
677679

tests/datamining/test_pca.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ def test_inverse_transform(self):
6767
coords_to_stack=["coord1", "coord2"],
6868
pca_dim_for_rows="coord3",
6969
)
70-
reconstructed_ds = self.pca.inverse_transform(PCs=pcs)
70+
reconstructed_ds = self.pca.inverse_transform(PCs=pcs.isel(coord3=slice(0, 5)))
7171
self.assertAlmostEqual(
72-
self.ds.isel(coord1=5, coord2=5, coord3=5),
73-
reconstructed_ds.isel(coord1=5, coord2=5, coord3=5),
72+
self.ds.isel(coord1=5, coord2=5, coord3=1),
73+
reconstructed_ds.isel(coord1=5, coord2=5, coord3=1),
7474
)
7575

7676
def test_incremental_fit(self):

0 commit comments

Comments
 (0)