Skip to content

Commit caee8b7

Browse files
committed
Add tests
1 parent 447e777 commit caee8b7

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

gmmx/gmm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,10 @@ def predict_proba(self, x: jax.Array) -> jax.Array:
745745
Predicted probabilities
746746
"""
747747
log_prob = self.log_prob(x)
748-
return jnp.exp(log_prob)
748+
log_prob_norm = jax.scipy.special.logsumexp(
749+
log_prob, axis=Axis.components, keepdims=True
750+
)
751+
return jnp.exp(log_prob - log_prob_norm)
749752

750753
@jax.jit
751754
def score_samples(self, x: jax.Array) -> jax.Array:

tests/test_gmm.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,24 @@ def test_against_sklearn(gmm_jax):
8080
assert gmm_jax.n_parameters == gmm._n_parameters()
8181

8282

83+
@pytest.mark.parametrize(
84+
"method", ["aic", "bic", "predict", "predict_proba", "score", "score_samples"]
85+
)
86+
def test_against_sklearn_all(gmm_jax, method):
87+
gmm = gmm_jax.to_sklearn()
88+
x = np.array([
89+
[1, 2, 3],
90+
[1, 4, 2],
91+
[1, 0, 6],
92+
[4, 2, 4],
93+
[4, 4, 4],
94+
[4, 0, 2],
95+
])
96+
result_sklearn = getattr(gmm, method)(x)
97+
result_jax = getattr(gmm_jax, method)(jnp.asarray(x))
98+
assert_allclose(np.squeeze(result_jax), result_sklearn, rtol=1e-5)
99+
100+
83101
def test_sample(gmm_jax):
84102
key = jax.random.PRNGKey(0)
85103
samples = gmm_jax.sample(key, 2)

0 commit comments

Comments
 (0)