Skip to content

Commit 716d62e

Browse files
author
Alexander Ororbia
committed
cleaned up density structure, use parent mixture class to organize model variations
1 parent fa9822f commit 716d62e

File tree

5 files changed

+56
-6
lines changed

5 files changed

+56
-6
lines changed

docs/source/ngclearn.utils.density.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ ngclearn.utils.density.gmm module
2020
:undoc-members:
2121
:show-inheritance:
2222

23+
ngclearn.utils.density.mixture module
24+
-------------------------------------
25+
26+
.. automodule:: ngclearn.utils.density.mixture
27+
:members:
28+
:undoc-members:
29+
:show-inheritance:
30+
2331
Module contents
2432
---------------
2533

ngclearn/utils/density/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1+
from .mixture import Mixture ## general mixture template parent class
12
## point to supported density estimator models
2-
from .gmm import GMM ## Gaussian mixture
3-
from .bmm import BMM ## Bernoulli mixture
3+
from .gmm import GMM ## Gaussian mixture model
4+
from .bmm import BMM ## Bernoulli mixture model
45

ngclearn/utils/density/bmm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import time, sys
44
import numpy as np
55

6+
from ngclearn.utils.density.mixture import Mixture
7+
68
########################################################################################################################
79
## internal routines for mixture model
810
########################################################################################################################
@@ -58,7 +60,7 @@ def _sample_component(dkey, n_samples, mu): ## samples a component (of mixture)
5860

5961
########################################################################################################################
6062

61-
class BMM: ## Bernoulli mixture model (mixture-of-Bernoullis)
63+
class BMM(Mixture): ## Bernoulli mixture model (mixture-of-Bernoullis)
6264
"""
6365
Implements a Bernoulli mixture model (BMM) -- or mixture of Bernoullis (MoB).
6466
Adaptation of parameters is conducted via the Expectation-Maximization (EM)
@@ -74,7 +76,8 @@ class BMM: ## Bernoulli mixture model (mixture-of-Bernoullis)
7476
init_kmeans: <Unsupported>
7577
"""
7678

77-
def __init__(self, K, max_iter=50, init_kmeans=False, key=None):
79+
def __init__(self, K, max_iter=50, init_kmeans=False, key=None, **kwargs):
80+
super().__init__(K, max_iter, **kwargs)
7881
self.K = K
7982
self.max_iter = int(max_iter)
8083
self.init_kmeans = init_kmeans ## Unsupported currently
@@ -204,3 +207,4 @@ def sample(self, n_samples, mode_j=-1):
204207
Xs.append(x_s)
205208
Xs = jnp.concat(Xs, axis=0)
206209
return Xs
210+

ngclearn/utils/density/gmm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import time, sys
44
import numpy as np
55

6+
from ngclearn.utils.density.mixture import Mixture
7+
68
########################################################################################################################
79
## internal routines for mixture model
810
########################################################################################################################
@@ -99,7 +101,7 @@ def _sample_component(dkey, n_samples, mu, Sigma, assume_diag_cov=False): ## sam
99101

100102
########################################################################################################################
101103

102-
class GMM: ## Gaussian mixture model (mixture-of-Gaussians)
104+
class GMM(Mixture): ## Gaussian mixture model (mixture-of-Gaussians)
103105
"""
104106
Implements a Gaussian mixture model (GMM) -- or mixture of Gaussians (MoG).
105107
Adaptation of parameters is conducted via the Expectation-Maximization (EM)
@@ -119,7 +121,8 @@ class GMM: ## Gaussian mixture model (mixture-of-Gaussians)
119121
# init_kmeans: if True, first learn use the K-Means algorithm to initialize
120122
# the component Gaussians of this GMM (Default = False)
121123

122-
def __init__(self, K, max_iter=50, assume_diag_cov=False, init_kmeans=False, key=None):
124+
def __init__(self, K, max_iter=50, assume_diag_cov=False, init_kmeans=False, key=None, **kwargs):
125+
super().__init__(K, max_iter, **kwargs)
123126
self.K = K
124127
self.max_iter = int(max_iter)
125128
self.assume_diag_cov = assume_diag_cov
@@ -265,3 +268,4 @@ def sample(self, n_samples, mode_j=-1):
265268
Xs.append(x_s)
266269
Xs = jnp.concat(Xs, axis=0)
267270
return Xs
271+

ngclearn/utils/density/mixture.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
2+
3+
class Mixture: ## General mixture structure
4+
"""
5+
Implements a general mixture model template/structure. Effectively, this is the parent
6+
class/template for mixtures of distributions.
7+
8+
Args:
9+
K: the number of components/latent variables within this mixture model
10+
11+
max_iter: the maximum number of iterations to fit parameters to data (Default = 50)
12+
13+
"""
14+
15+
def __init__(self, K, max_iter=50, **kwargs):
16+
self.K = K
17+
self.max_iter = max_iter
18+
19+
def init(self, X): ## model data-dependent initialization function
20+
pass
21+
22+
def calc_log_likelihood(self, X): ## log-likelihood calculation routine
23+
pass
24+
25+
def fit(self, X, tol=1e-3, verbose=False): ## outer fitting process
26+
pass
27+
28+
def update(self, X): ## inner/iterative adjustment/update step
29+
pass
30+
31+
def sample(self, n_samples, mode_j=-1): ## model sampling routine
32+
pass
33+

0 commit comments

Comments
 (0)