-
Notifications
You must be signed in to change notification settings - Fork 1
/
demo.py
37 lines (29 loc) · 1.02 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import numpy as np
from pyMM import (GMM, SphericalGMM, DiagonalGMM, MPPCA, MFA,
MFA_Miss)
from util import _generate_mixture_data, plot_density
def obscure_data(X, missing_frac=0.3):
    """Return a float copy of *X* with random entries replaced by NaN.

    Each entry is hidden independently with probability *missing_frac*,
    drawing from NumPy's global random state. *X* itself is not modified.

    Parameters
    ----------
    X : array_like
        Data to obscure; any shape.
    missing_frac : float, default 0.3
        Probability that an individual entry becomes NaN; must be in [0, 1].

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as *X*.

    Raises
    ------
    ValueError
        If *missing_frac* lies outside [0, 1].
    """
    if not 0.0 <= missing_frac <= 1.0:
        raise ValueError(f"missing_frac must be in [0, 1], got {missing_frac}")
    # np.array copies by default; forcing float guarantees NaN can be stored
    # (assigning NaN into an integer array raises).
    X_miss = np.array(X, dtype=float)
    # Size the mask from the data itself rather than assuming its shape.
    mask = np.random.random_sample(X_miss.shape) < missing_frac
    X_miss[mask] = np.nan
    return X_miss


def main(n_examples=700, data_dim=2, n_components=6, missing_frac=0.3,
         fit_missing=False):
    """Fit a pyMM mixture model to synthetic data and plot its density.

    Parameters
    ----------
    n_examples : int, default 700
        Number of samples passed to ``_generate_mixture_data``.
    data_dim : int, default 2
        Dimensionality passed to ``_generate_mixture_data``.
    n_components : int, default 6
        Mixture components used both to generate the data and in the model.
    missing_frac : float, default 0.3
        Fraction of entries hidden when *fit_missing* is true.
    fit_missing : bool, default False
        If true, fit ``MFA_Miss`` to a randomly obscured copy of the data
        instead of fitting ``MFA`` to the complete data.
    """
    X = _generate_mixture_data(data_dim, n_components, n_examples)

    if fit_missing:
        # Only build the obscured copy when it is actually used.
        gmm = MFA_Miss(n_components=n_components, latent_dim=1)
        gmm.fit(obscure_data(X, missing_frac), init_method='kmeans')
    else:
        # The other imported pyMM models (GMM, SphericalGMM, DiagonalGMM,
        # MPPCA) can be substituted here.
        gmm = MFA(n_components=n_components, latent_dim=2)
        gmm.fit(X, init_method='kmeans')

    # X passed here is always the complete (un-obscured) data.
    plot_density(gmm, X=X, n_grid=50)
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()