-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathgaussian.py
More file actions
105 lines (83 loc) · 3.42 KB
/
gaussian.py
File metadata and controls
105 lines (83 loc) · 3.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import numpy as np
from sklearn.covariance import OAS
from diffeo import DiffeomorphicMixin
class ClassConditionalGaussianPrior:
    """Class-conditional Gaussian prior: one Gaussian per class label.

    ``fit`` estimates a mean and an OAS covariance for every distinct label
    and stores the Cholesky factor of each covariance. ``sample`` then draws
    from the Gaussian belonging to each requested label.

    Attributes set by ``fit``:
        classes_: sorted unique labels seen during fitting.
        n_classes_: number of distinct labels.
        dim_: number of feature columns of the training data.
        means_: (n_classes, dim) array; row ``i`` belongs to ``classes_[i]``.
        covariances_cholesky_: (n_classes, dim, dim) array of Cholesky
            factors, ordered like ``classes_``.
    """

    def __init__(self, random_state=None):
        """
        Create an unfitted prior.
        Args:
            random_state: an existing ``np.random.RandomState`` is used as-is
                (and therefore shared with the caller); any other value is
                passed to ``np.random.RandomState`` as the seed.
        """
        if isinstance(random_state, np.random.RandomState):
            self.rng = random_state
        else:
            self.rng = np.random.RandomState(random_state)
        self.means_ = None
        self.covariances_cholesky_ = None

    def fit(self, X: np.ndarray, y: np.ndarray):
        """
        Estimate one Gaussian per class.
        Args:
            X: (n_samples, dim) features
            y: (n_samples,) labels
        Raises:
            numpy.linalg.LinAlgError: if a class covariance estimate is not
                positive definite, so its Cholesky factorisation fails.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        classes = np.unique(y)

        # Build everything in locals first so that a failure half-way through
        # never leaves the instance in a partially fitted state.
        means = []
        cholesky_factors = []
        for cls in classes:
            X_cls = X[y == cls]
            means.append(X_cls.mean(axis=0))
            cov = OAS(store_precision=False).fit(X_cls).covariance_
            cholesky_factors.append(np.linalg.cholesky(cov))

        self.classes_ = classes
        self.n_classes_ = len(classes)
        self.dim_ = X.shape[1]
        self.means_ = np.stack(means)
        self.covariances_cholesky_ = np.stack(cholesky_factors)

    def sample(self, y_cond: np.ndarray) -> np.ndarray:
        """
        Sample from the corresponding Gaussian for each class.
        Args:
            y_cond: (n_samples,) labels; every label must have been seen by
                ``fit``
        Returns:
            samples: (1, n_samples, dim) array; ``samples[0, i]`` is drawn
                from the Gaussian fitted for ``y_cond[i]``
        Raises:
            RuntimeError: if called before ``fit``.
            ValueError: if ``y_cond`` contains a label that was not fitted.
        """
        if self.means_ is None or self.covariances_cholesky_ is None:
            raise RuntimeError("Call `fit` before sampling.")

        y_cond = np.asarray(y_cond)
        dim = self.means_.shape[-1]
        labels = np.unique(y_cond)

        # Parameters are stored in the order of `classes_`, so translate each
        # label to its row position instead of indexing by the label value.
        # `classes_` comes from np.unique and is therefore sorted.
        positions = np.searchsorted(self.classes_, labels)
        positions = np.clip(positions, 0, len(self.classes_) - 1)
        unknown = labels[self.classes_[positions] != labels]
        if unknown.size:
            raise ValueError(
                f"Unknown class label(s) in y_cond: {unknown.tolist()}; "
                f"fitted classes are {self.classes_.tolist()}"
            )

        samples = np.empty((len(y_cond), dim))
        for label, pos in zip(labels, positions):
            mask = y_cond == label
            # x = z @ L.T + mean, with z ~ N(0, I) and L the Cholesky factor.
            noise = self.rng.randn(int(np.sum(mask)), dim)
            samples[mask] = noise @ self.covariances_cholesky_[pos].T + self.means_[pos]

        # Prepend a singleton axis so the output shape is (n_steps, ...).
        return samples[np.newaxis, ...]
class DiffeoGauss(DiffeomorphicMixin):
    """Gaussian baseline that fits and samples in a diffeomorphic latent space.

    Feature transforms are delegated to ``DiffeomorphicMixin`` (its behaviour
    is defined outside this file); the latent distribution is a
    ``ClassConditionalGaussianPrior``.
    """

    def __init__(self, config: dict):
        """
        Build the model from a configuration mapping.
        Args:
            config: mapping copied into ``self.config``; the optional keys
                ``"RNG"`` and ``"DIFFEO"`` are forwarded to the prior and to
                the mixin respectively.
        """
        self.config = dict(config)
        random_state = self.config.get("RNG")
        diffeo_spec = self.config.get("DIFFEO")
        super().__init__(diffeomorphism=diffeo_spec)
        self._prior = ClassConditionalGaussianPrior(random_state=random_state)

    def fit(self, X: np.ndarray, y: np.ndarray):
        """
        Fit the feature transform and then the per-class Gaussian prior.
        Args:
            X: features, converted with ``np.asarray``
            y: labels, converted with ``np.asarray``
        """
        features, labels = np.asarray(X), np.asarray(y)
        latent = self._fit_transform_features(features)
        self._prior.fit(latent, labels)

    def sample(self, y_cond: np.ndarray) -> np.ndarray:
        """
        Draw latent samples for ``y_cond`` and pass them through the inverse
        feature transform.
        Args:
            y_cond: labels to condition on
        Returns:
            whatever ``_inverse_transform_features`` returns for the latent
            samples (which carry a leading singleton axis)
        """
        latent = self._prior.sample(y_cond)
        # Drop the prior's leading axis and add it back again.
        latent = np.expand_dims(np.squeeze(latent, axis=0), axis=0)
        return self._inverse_transform_features(latent)

    def set_diffeomorphism(self, diffeomorphism: str | None) -> None:
        """
        Switch the diffeomorphism and discard the fitted prior.
        Args:
            diffeomorphism: new setting, stored under ``config["DIFFEO"]`` and
                forwarded to the mixin
        """
        self.config["DIFFEO"] = diffeomorphism
        super().set_diffeomorphism(diffeomorphism)
        # A fresh, unfitted prior replaces the old one.
        random_state = self.config.get("RNG")
        self._prior = ClassConditionalGaussianPrior(random_state=random_state)
if __name__ == "__main__":
    # Example usage: fit the prior on random two-class data, then draw one
    # sample for each of 20 random labels.
    X = np.random.rand(100, 10)
    y = np.random.randint(0, 2, size=100)
    y_cond = np.random.randint(0, 2, size=20)
    prior = ClassConditionalGaussianPrior()
    prior.fit(X, y)
    samples = prior.sample(y_cond)
    # `sample` prepends a singleton axis, so the shape is (1, n_samples, dim).
    assert samples.shape == (1, 20, 10)