Skip to content

Commit 7668ade

Browse files
committed
Add example on moon dataset
1 parent b8c7a77 commit 7668ade

File tree

1 file changed

+90
-0
lines changed

1 file changed

+90
-0
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Density Estimation of Moon Data. This exampled is adapted from "In Depth: Gaussian Mixture Models" chapter of
2+
# the Python Data Science Handbook by Jake VanderPlas. The original code can be found
3+
# at https://jakevdp.github.io/PythonDataScienceHandbook/05.12-gaussian-mixtures.html
4+
5+
import matplotlib.pyplot as plt
6+
import numpy as np
7+
from matplotlib.patches import Ellipse
8+
from sklearn.datasets import make_moons
9+
10+
from gmmx import EMFitter, GaussianMixtureModelJax
11+
12+
13+
def draw_ellipse(position, covariance, ax=None, **kwargs):
14+
"""Draw an ellipse with a given position and covariance"""
15+
ax = ax or plt.gca()
16+
17+
# Convert covariance to principal axes
18+
if covariance.shape == (2, 2):
19+
U, s, Vt = np.linalg.svd(covariance)
20+
angle = np.degrees(np.arctan2(U[1, 0], U[0, 0]))
21+
width, height = 2 * np.sqrt(s)
22+
else:
23+
angle = 0
24+
width, height = 2 * np.sqrt(covariance)
25+
26+
# Draw the Ellipse
27+
for nsig in range(1, 4):
28+
ax.add_patch(
29+
Ellipse(
30+
xy=position,
31+
width=nsig * width,
32+
height=nsig * height,
33+
angle=angle,
34+
**kwargs,
35+
)
36+
)
37+
38+
39+
def plot_gmm(gmm, X, label=True, ax=None):
40+
"""Plot the GMM"""
41+
ax = ax or plt.gca()
42+
43+
labels = gmm.predict(X)
44+
45+
if label:
46+
ax.scatter(X[:, 0], X[:, 1], c=labels, s=10, cmap="viridis", zorder=2)
47+
else:
48+
ax.scatter(X[:, 0], X[:, 1], s=10, zorder=2)
49+
ax.axis("equal")
50+
51+
w_factor = 0.2 / gmm.weights_numpy.max()
52+
for pos, covar, w in zip(
53+
gmm.means_numpy, gmm.covariances.values_numpy, gmm.weights_numpy
54+
):
55+
draw_ellipse(pos, covar, alpha=w * w_factor, ax=ax)
56+
57+
58+
def fit_and_plot_gmm(n_components, ax=None):
59+
"""Fit and plot a GMM"""
60+
ax = ax or plt.gca()
61+
x, y = make_moons(200, noise=0.05, random_state=0)
62+
ax.scatter(x[:, 0], x[:, 1])
63+
ax.text(
64+
0.95,
65+
0.9,
66+
f"N Components: {n_components}",
67+
ha="right",
68+
va="bottom",
69+
transform=ax.transAxes,
70+
)
71+
ax.set_xticks([])
72+
ax.set_yticks([])
73+
74+
gmm = GaussianMixtureModelJax.from_k_means(x, n_components=n_components)
75+
76+
fitter = EMFitter(tol=1e-4, max_iter=100)
77+
result = fitter.fit(x=x, gmm=gmm)
78+
79+
plot_gmm(result.gmm, x, ax=ax)
80+
return ax
81+
82+
83+
if __name__ == "__main__":
84+
fig, axes = plt.subplots(4, 4, figsize=(9, 9))
85+
86+
for idx, ax in enumerate(axes.flat):
87+
ax = fit_and_plot_gmm(idx + 1, ax=ax)
88+
89+
plt.tight_layout()
90+
plt.show()

0 commit comments

Comments
 (0)