-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
68 lines (55 loc) · 1.9 KB
/
main.py
File metadata and controls
68 lines (55 loc) · 1.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import numpy as np
import kmeans
import common
import naive_em
import em
from matplotlib import pyplot as plt
# Load the toy dataset and define the experiment grid.
X = np.loadtxt("toy_data.txt")
K = [1, 2, 3, 4]  # numbers of mixture components to try
k = len(K)  # number of K values; per-K results below are stored at index k-1 for k in K
seeds = [0, 1, 2, 3, 4]  # random-initialization seeds tried for every K

# Per-K bookkeeping for the best run found so far.
# Each slot gets its own list object (k*[[]] would make every slot alias the
# same list), and the running best starts at +/-inf so the first run for each
# K is always recorded. A finite sentinel (e.g. 10000) would leave a slot
# unfilled whenever every run scores worse than it.
mixtures_k = [[] for _ in range(k)]
posts_k = [[] for _ in range(k)]
costs_k = [np.inf] * k      # K-means: the lowest cost is kept
seeds_k = [0] * k
mixtures_em = [[] for _ in range(k)]
posts_em = [[] for _ in range(k)]
costs_em = [-np.inf] * k    # EM: the highest returned cost is kept (presumably a log-likelihood; confirm in naive_em)
seeds_em = [0] * k
#===============================
# (2) K-Means
#===============================
# For each K, run K-means from every seed and keep the lowest-cost run.
print('=================\n K-MEANS\n================= ')
for k in K:
    for seed in seeds:
        mix, post = common.init(X, k, seed)  # random initialization
        mix_k, post_k, cost_k = kmeans.run(X, mix, post)  # run kmeans
        print(f'k={k}, seed={seed}, cost_k={cost_k}')
        # Only keep the lowest-cost run seen so far for this K.
        if cost_k < costs_k[k-1]:
            mixtures_k[k-1] = mix_k
            posts_k[k-1] = post_k
            costs_k[k-1] = cost_k
            seeds_k[k-1] = seed
    # Uncomment to plot the best K-means run for this K:
    # common.plot(X, mixtures_k[k-1], posts_k[k-1], f'K-means K={k}, seed={seeds_k[k-1]}')
#===============================
# (4) K-means vs. EM
#===============================
# For each K, run naive EM from every seed and keep the run with the largest
# returned cost (presumably the log-likelihood -- confirm in naive_em.run).
# The K-means results from section (2) are deliberately left intact here so
# they can be compared against the EM results.
print('=================\n EM\n================= ')
for k in K:
    for seed in seeds:
        mix, post = common.init(X, k, seed)  # random initialization
        mix_em, post_em, cost_em = naive_em.run(X, mix, post)
        # Report the EM score of this run (not the stale K-means cost).
        print(f'k={k}, seed={seed}, cost_em={cost_em}')
        # Only keep the highest-scoring run seen so far for this K.
        if cost_em > costs_em[k-1]:
            mixtures_em[k-1] = mix_em
            posts_em[k-1] = post_em
            costs_em[k-1] = cost_em
            seeds_em[k-1] = seed
    # Uncomment to plot the best EM run for this K:
    # common.plot(X, mixtures_em[k-1], posts_em[k-1], f'Naive EM K={k}, seed={seeds_em[k-1]}')
# ============================
# (5) BIC
# ============================
# Score the best EM mixture found for each K with common.bic and plot the
# score against K.
print('=================\n BAYESIAN CRITERION\n================= ')
bic = [common.bic(X, mixtures_em[k-1], costs_em[k-1]) for k in K]
# Print before plt.show(), which typically blocks until the plot window is
# closed, so the numbers are visible while inspecting the plot.
print(bic)
plt.plot(K, bic, marker='o')
plt.xlabel('K')
plt.ylabel('BIC')
plt.show()