Showcase of different models and algorithms applied to the same dataset.
In this example, we show how bnpy makes it easy to apply different models and algorithms to the same dataset.
import bnpy
import numpy as np
import os
from matplotlib import pylab
import seaborn as sns
SMALL_FIG_SIZE = (2.5, 2.5)
FIG_SIZE = (5, 5)
pylab.rcParams['figure.figsize'] = FIG_SIZE
np.set_printoptions(precision=3, suppress=1, linewidth=200)
Load dataset from file
dataset_path = os.path.join(bnpy.DATASET_PATH, 'faithful')
dataset = bnpy.data.XData.read_csv(
os.path.join(dataset_path, 'faithful.csv'))
Make a simple plot of the raw data
pylab.plot(dataset.X[:, 0], dataset.X[:, 1], 'k.')
pylab.xlabel(dataset.column_names[0])
pylab.ylabel(dataset.column_names[1])
pylab.tight_layout()
data_ax_h = pylab.gca()
def show_clusters_over_time(
        task_output_path=None,
        query_laps=[0, 1, 2, 10, 20, None],
        nrows=2):
    ''' Show 2D elliptical contours overlaid on raw data.
    '''
    # True division, so that np.ceil rounds up to enough columns
    ncols = int(np.ceil(len(query_laps) / float(nrows)))
    fig_handle, ax_handle_list = pylab.subplots(
        figsize=(SMALL_FIG_SIZE[0] * ncols, SMALL_FIG_SIZE[1] * nrows),
        nrows=nrows, ncols=ncols, sharex=True, sharey=True)
    for plot_id, lap_val in enumerate(query_laps):
        # load_model_at_lap also returns the lap that was actually loaded
        cur_model, lap_val = bnpy.load_model_at_lap(task_output_path, lap_val)
        cur_ax_handle = ax_handle_list.flatten()[plot_id]
        bnpy.viz.PlotComps.plotCompsFromHModel(
            cur_model, dataset=dataset, ax_handle=cur_ax_handle)
        cur_ax_handle.set_title("lap: %d" % lap_val)
        cur_ax_handle.set_xlabel(dataset.column_names[0])
        cur_ax_handle.set_ylabel(dataset.column_names[1])
        # Match the axis limits of the raw-data plot made earlier
        cur_ax_handle.set_xlim(data_ax_h.get_xlim())
        cur_ax_handle.set_ylim(data_ax_h.get_ylim())
    pylab.tight_layout()
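The helper draws one panel per queried lap, so the subplot grid needs at least as many cells as there are laps. As a minimal sketch of that layout arithmetic in isolation (the name `grid_shape` is hypothetical and not part of bnpy), the column count is the ceiling of the panel count divided by the row count:

```python
import math


def grid_shape(n_panels, nrows):
    """Return (nrows, ncols) with enough cells to hold n_panels subplots."""
    # Ceiling of n_panels / nrows; true division must happen before rounding up.
    ncols = math.ceil(n_panels / nrows)
    return nrows, ncols


print(grid_shape(6, 2))  # six query laps on two rows
print(grid_shape(5, 2))  # an odd count still needs a third column
```

With floor division in place of true division, five panels on two rows would get only two columns, which is one cell short.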
Assume diagonal covariances.
Start with too many clusters (K=20)
gamma = 5.0
sF = 5.0
K = 20
diag_trained_model, diag_info_dict = bnpy.run(
dataset, 'DPMixtureModel', 'DiagGauss', 'memoVB',
output_path='/tmp/faithful/showcase-K=20-lik=DiagGauss-ECovMat=5*eye/',
nLap=1000, nTask=1, nBatch=1, convergeThr=0.0001,
gamma0=gamma, sF=sF, ECovMat='eye',
K=K, initname='randexamples',
)
show_clusters_over_time(diag_info_dict['task_output_path'])
Dataset Summary:
X Data
total size: 272 units
batch size: 272 units
num. batches: 1
Allocation Model: DP mixture with K=0. Concentration gamma0= 5.00
Obs. Data Model: Gaussian with diagonal covariance.
Obs. Data Prior: independent Gauss-Wishart prior on each dimension
Wishart params
nu = 4
beta = [ 10 10]
Expectations
E[ mean[k]] =
[ 0 0]
E[ covar[k]] =
[[5. 0.]
[0. 5.]]
Initialization:
initname = randexamples
K = 20 (number of clusters)
seed = 1607680
elapsed_time: 0.0 sec
Learn Alg: memoVB | task 1/1 | alg. seed: 1607680 | data order seed: 8541952
task_output_path: /tmp/faithful/showcase-K=20-lik=DiagGauss-ECovMat=5*eye/1
1.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 3.002088161e+00 |
2.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.966309748e+00 | Ndiff 4.584
3.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.949498401e+00 | Ndiff 4.501
4.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.931792100e+00 | Ndiff 5.452
5.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.910174417e+00 | Ndiff 6.264
6.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.881973764e+00 | Ndiff 6.557
7.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.833491488e+00 | Ndiff 6.194
8.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.793876249e+00 | Ndiff 5.582
9.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.752249036e+00 | Ndiff 4.580
10.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.698456106e+00 | Ndiff 4.740
11.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.651622886e+00 | Ndiff 5.753
12.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.624108159e+00 | Ndiff 6.393
13.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.587572334e+00 | Ndiff 6.074
14.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.560219008e+00 | Ndiff 4.596
15.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.544887510e+00 | Ndiff 2.425
16.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.527562170e+00 | Ndiff 1.537
17.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.513204290e+00 | Ndiff 1.470
18.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.507574530e+00 | Ndiff 1.454
19.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.497639081e+00 | Ndiff 1.477
20.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.497002974e+00 | Ndiff 1.526
21.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.496368366e+00 | Ndiff 1.584
22.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.495732569e+00 | Ndiff 1.638
23.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.495103986e+00 | Ndiff 1.678
24.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.494497370e+00 | Ndiff 1.697
25.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.493928089e+00 | Ndiff 1.698
26.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.493406054e+00 | Ndiff 1.687
27.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.492931650e+00 | Ndiff 1.674
28.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.492495409e+00 | Ndiff 1.670
29.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.492081031e+00 | Ndiff 1.682
30.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.491669505e+00 | Ndiff 1.715
31.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.491242078e+00 | Ndiff 1.771
32.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.490781261e+00 | Ndiff 1.851
33.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.490270354e+00 | Ndiff 1.954
34.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.489692235e+00 | Ndiff 2.082
35.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.489027945e+00 | Ndiff 2.235
36.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.488255286e+00 | Ndiff 2.414
37.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.487347489e+00 | Ndiff 2.621
38.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.486272015e+00 | Ndiff 2.855
39.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.484989607e+00 | Ndiff 3.113
40.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.483453885e+00 | Ndiff 3.391
41.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.481612017e+00 | Ndiff 3.675
42.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.479407351e+00 | Ndiff 3.945
43.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.476785615e+00 | Ndiff 4.169
44.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.473707547e+00 | Ndiff 4.301
45.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.470173027e+00 | Ndiff 4.287
46.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.466261998e+00 | Ndiff 4.076
47.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.462175093e+00 | Ndiff 3.648
48.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.458145085e+00 | Ndiff 3.068
49.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.453753531e+00 | Ndiff 2.496
50.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.441832238e+00 | Ndiff 2.043
51.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.437414963e+00 | Ndiff 1.868
52.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.435586138e+00 | Ndiff 1.786
53.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.433555668e+00 | Ndiff 1.617
54.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.431430213e+00 | Ndiff 1.335
55.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.429483776e+00 | Ndiff 0.970
56.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.428084819e+00 | Ndiff 0.626
57.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.427098172e+00 | Ndiff 0.419
58.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.425535651e+00 | Ndiff 0.330
59.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.415976963e+00 | Ndiff 0.153
60.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412692266e+00 | Ndiff 0.026
61.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412692093e+00 | Ndiff 0.016
62.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412692005e+00 | Ndiff 0.011
63.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691958e+00 | Ndiff 0.008
64.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691932e+00 | Ndiff 0.006
65.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691918e+00 | Ndiff 0.005
66.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691910e+00 | Ndiff 0.003
67.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691906e+00 | Ndiff 0.003
68.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691904e+00 | Ndiff 0.002
69.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691903e+00 | Ndiff 0.001
70.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691902e+00 | Ndiff 0.001
71.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691901e+00 | Ndiff 0.001
72.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691901e+00 | Ndiff 0.001
73.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691901e+00 | Ndiff 0.000
74.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691901e+00 | Ndiff 0.000
75.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691901e+00 | Ndiff 0.000
76.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691901e+00 | Ndiff 0.000
77.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691901e+00 | Ndiff 0.000
78.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691901e+00 | Ndiff 0.000
79.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691901e+00 | Ndiff 0.000
... done. converged.
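If only the printed log is available (for example, copied from a terminal), the progress lines can be turned back into numbers. The sketch below assumes the line layout shown above; the pattern and the helper name `parse_progress_line` are illustrative and not part of bnpy.

```python
import re

# Pattern for one progress line, inferred from the log format printed above.
LINE_RE = re.compile(
    r"^\s*(?P<lap>[\d.]+)/(?P<n_lap>\d+) after\s+(?P<sec>\d+) sec\."
    r".*?\| K\s+(?P<K>\d+)"
    r"\s*\| loss\s+(?P<loss>[-+\d.eE]+)"
    r"(?:\s*\|\s*Ndiff\s+(?P<ndiff>[\d.]+))?")


def parse_progress_line(line):
    """Return a dict with lap, K, loss, Ndiff, or None for non-progress lines."""
    m = LINE_RE.match(line)
    if m is None:
        return None
    ndiff = m.group("ndiff")
    return {
        "lap": float(m.group("lap")),
        "K": int(m.group("K")),
        "loss": float(m.group("loss")),
        # The first lap prints no Ndiff value, so it may be missing.
        "Ndiff": None if ndiff is None else float(ndiff),
    }


print(parse_progress_line(
    "2.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.966309748e+00 | Ndiff 4.584"))
print(parse_progress_line("... done. converged."))
```

In a live session this is unnecessary, since the info dict returned by `bnpy.run` already carries `lap_history` and `loss_history`.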
Assume full covariances.
Start with too many clusters (K=20)
full_trained_model, full_info_dict = bnpy.run(
dataset, 'DPMixtureModel', 'Gauss', 'VB',
output_path='/tmp/faithful/showcase-K=20-lik=Gauss-ECovMat=5*eye/',
nLap=1000, nTask=1, nBatch=1, convergeThr=0.0001,
gamma0=gamma, sF=sF, ECovMat='eye',
K=K, initname='randexamples',
)
show_clusters_over_time(full_info_dict['task_output_path'])
WARNING: Found unrecognized keyword args. These are ignored.
--nBatch
Dataset Summary:
X Data
num examples: 272
num dims: 2
Allocation Model: DP mixture with K=0. Concentration gamma0= 5.00
Obs. Data Model: Gaussian with full covariance.
Obs. Data Prior: Gauss-Wishart on mean and covar of each cluster
E[ mean[k] ] =
[0. 0.]
E[ covar[k] ] =
[[5. 0.]
[0. 5.]]
Initialization:
initname = randexamples
K = 20 (number of clusters)
seed = 1607680
elapsed_time: 0.0 sec
Learn Alg: VB | task 1/1 | alg. seed: 1607680 | data order seed: 8541952
task_output_path: /tmp/faithful/showcase-K=20-lik=Gauss-ECovMat=5*eye/1
1/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.904406498e+00 |
2/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.876066302e+00 | Ndiff 2.071
3/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.864944306e+00 | Ndiff 2.664
4/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.852026479e+00 | Ndiff 3.160
5/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.836506635e+00 | Ndiff 3.522
6/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.817804673e+00 | Ndiff 3.636
7/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.787080350e+00 | Ndiff 3.495
8/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.743017730e+00 | Ndiff 3.349
9/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.714906517e+00 | Ndiff 3.494
10/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.670842807e+00 | Ndiff 4.089
11/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.640944068e+00 | Ndiff 4.696
12/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.600612644e+00 | Ndiff 5.201
13/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.582832638e+00 | Ndiff 5.436
14/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.569005251e+00 | Ndiff 5.209
15/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.544279312e+00 | Ndiff 4.376
16/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.523247072e+00 | Ndiff 2.925
17/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.511328327e+00 | Ndiff 1.441
18/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.503251361e+00 | Ndiff 1.441
19/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.491852524e+00 | Ndiff 1.448
20/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.490247584e+00 | Ndiff 1.447
21/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.488483904e+00 | Ndiff 1.524
22/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.486427687e+00 | Ndiff 1.578
23/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.484005616e+00 | Ndiff 1.586
24/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.481221463e+00 | Ndiff 1.527
25/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.478177971e+00 | Ndiff 1.375
26/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.475094563e+00 | Ndiff 1.194
27/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.471974138e+00 | Ndiff 1.209
28/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.467024647e+00 | Ndiff 1.213
29/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.453713427e+00 | Ndiff 1.207
30/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.445147652e+00 | Ndiff 1.191
31/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.439161802e+00 | Ndiff 1.170
32/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.438783471e+00 | Ndiff 1.143
33/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.438429230e+00 | Ndiff 1.114
34/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.438098475e+00 | Ndiff 1.082
35/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.437790556e+00 | Ndiff 1.051
36/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.437503774e+00 | Ndiff 1.023
37/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.437235314e+00 | Ndiff 1.000
38/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.436981439e+00 | Ndiff 0.984
39/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.436737744e+00 | Ndiff 0.977
40/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.436499373e+00 | Ndiff 0.979
41/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.436261156e+00 | Ndiff 0.992
42/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.436017666e+00 | Ndiff 1.014
43/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.435763198e+00 | Ndiff 1.048
44/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.435491737e+00 | Ndiff 1.091
45/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.435196931e+00 | Ndiff 1.145
46/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.434872198e+00 | Ndiff 1.209
47/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.434511031e+00 | Ndiff 1.279
48/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.434107639e+00 | Ndiff 1.354
49/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.433657974e+00 | Ndiff 1.428
50/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.433161085e+00 | Ndiff 1.498
51/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.432620547e+00 | Ndiff 1.556
52/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.432045498e+00 | Ndiff 1.597
53/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.431450777e+00 | Ndiff 1.617
54/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.430855677e+00 | Ndiff 1.617
55/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.430280760e+00 | Ndiff 1.600
56/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.429742357e+00 | Ndiff 1.574
57/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.429246005e+00 | Ndiff 1.547
58/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.428783131e+00 | Ndiff 1.530
59/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.428335044e+00 | Ndiff 1.531
60/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.427882144e+00 | Ndiff 1.554
61/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.427412198e+00 | Ndiff 1.598
62/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.426923365e+00 | Ndiff 1.660
63/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.426419184e+00 | Ndiff 1.735
64/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.425898789e+00 | Ndiff 1.819
65/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.425353548e+00 | Ndiff 1.909
66/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.424771068e+00 | Ndiff 2.006
67/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.424138280e+00 | Ndiff 2.111
68/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.423441574e+00 | Ndiff 2.224
69/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.422666132e+00 | Ndiff 2.343
70/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.421796837e+00 | Ndiff 2.464
71/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.420824917e+00 | Ndiff 2.578
72/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.419763730e+00 | Ndiff 2.674
73/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.418655856e+00 | Ndiff 2.744
74/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.417536411e+00 | Ndiff 2.794
75/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.416396660e+00 | Ndiff 2.833
76/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.415209641e+00 | Ndiff 2.867
77/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.413957301e+00 | Ndiff 2.892
78/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.412627524e+00 | Ndiff 2.908
79/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.411204929e+00 | Ndiff 2.915
80/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.409665366e+00 | Ndiff 2.917
81/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.407972796e+00 | Ndiff 2.914
82/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.406075452e+00 | Ndiff 2.906
83/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.403897086e+00 | Ndiff 2.888
84/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.401317119e+00 | Ndiff 2.847
85/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.398127190e+00 | Ndiff 2.749
86/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.393870554e+00 | Ndiff 2.514
87/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.385927886e+00 | Ndiff 2.011
88/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.373557979e+00 | Ndiff 1.320
89/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.370431956e+00 | Ndiff 0.732
90/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.361734581e+00 | Ndiff 0.225
91/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356420337e+00 | Ndiff 0.052
92/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356419191e+00 | Ndiff 0.040
93/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356418521e+00 | Ndiff 0.031
94/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356418120e+00 | Ndiff 0.024
95/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417878e+00 | Ndiff 0.019
96/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417731e+00 | Ndiff 0.015
97/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417640e+00 | Ndiff 0.011
98/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417584e+00 | Ndiff 0.009
99/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417550e+00 | Ndiff 0.007
100/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417528e+00 | Ndiff 0.006
101/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417515e+00 | Ndiff 0.004
102/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417507e+00 | Ndiff 0.003
103/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417502e+00 | Ndiff 0.003
104/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417499e+00 | Ndiff 0.002
105/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417497e+00 | Ndiff 0.002
106/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417495e+00 | Ndiff 0.001
107/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417495e+00 | Ndiff 0.001
108/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417494e+00 | Ndiff 0.001
109/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417494e+00 | Ndiff 0.001
110/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417494e+00 | Ndiff 0.001
111/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417493e+00 | Ndiff 0.000
112/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417493e+00 | Ndiff 0.000
113/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417493e+00 | Ndiff 0.000
114/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417493e+00 | Ndiff 0.000
115/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417493e+00 | Ndiff 0.000
116/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417493e+00 | Ndiff 0.000
117/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417493e+00 | Ndiff 0.000
... done. converged.
Assume full covariances and fix all means to zero.
Start with too many clusters (K=20)
zm_trained_model, zm_info_dict = bnpy.run(
dataset, 'DPMixtureModel', 'ZeroMeanGauss', 'VB',
output_path='/tmp/faithful/showcase-K=20-lik=ZeroMeanGauss-ECovMat=5*eye/',
nLap=1000, nTask=1, nBatch=1, convergeThr=0.0001,
gamma0=gamma, sF=sF, ECovMat='eye',
K=K, initname='randexamples',
)
show_clusters_over_time(zm_info_dict['task_output_path'])
WARNING: Found unrecognized keyword args. These are ignored.
--nBatch
Dataset Summary:
X Data
num examples: 272
num dims: 2
Allocation Model: DP mixture with K=0. Concentration gamma0= 5.00
Obs. Data Model: Gaussian with fixed zero means, full covariance.
Obs. Data Prior: Wishart on prec matrix Lam
E[ CovMat[k] ] =
[[5. 0.]
[0. 5.]]
Initialization:
initname = randexamples
K = 20 (number of clusters)
seed = 1607680
elapsed_time: 0.0 sec
Learn Alg: VB | task 1/1 | alg. seed: 1607680 | data order seed: 8541952
task_output_path: /tmp/faithful/showcase-K=20-lik=ZeroMeanGauss-ECovMat=5*eye/1
1/1000 after 0 sec. | 217.2 MiB | K 20 | loss 4.019419551e+00 |
2/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.989063967e+00 | Ndiff 3.437
3/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.941416771e+00 | Ndiff 3.621
4/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.899222162e+00 | Ndiff 3.882
5/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.876371765e+00 | Ndiff 4.129
6/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.842038547e+00 | Ndiff 4.274
7/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.808536791e+00 | Ndiff 4.341
8/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.779782054e+00 | Ndiff 4.418
9/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.755287922e+00 | Ndiff 4.502
10/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.733837701e+00 | Ndiff 4.577
11/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.700117498e+00 | Ndiff 4.646
12/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.694082693e+00 | Ndiff 4.826
13/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.686765019e+00 | Ndiff 4.986
14/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.675117147e+00 | Ndiff 5.045
15/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.652297952e+00 | Ndiff 4.998
16/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.639249110e+00 | Ndiff 4.933
17/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.603588941e+00 | Ndiff 4.758
18/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.598749399e+00 | Ndiff 4.729
19/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.592583511e+00 | Ndiff 4.761
20/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.574936974e+00 | Ndiff 4.760
21/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.572459120e+00 | Ndiff 4.885
22/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.570977707e+00 | Ndiff 5.084
23/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.569306514e+00 | Ndiff 5.272
24/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.567412054e+00 | Ndiff 5.442
25/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.565249330e+00 | Ndiff 5.588
26/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.562749369e+00 | Ndiff 5.701
27/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.559772804e+00 | Ndiff 5.771
28/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.555894232e+00 | Ndiff 5.784
29/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.549078207e+00 | Ndiff 5.714
30/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.531303622e+00 | Ndiff 5.559
31/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.529382170e+00 | Ndiff 5.572
32/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.527237440e+00 | Ndiff 5.642
33/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.524655304e+00 | Ndiff 5.674
34/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.521314832e+00 | Ndiff 5.648
35/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.516111933e+00 | Ndiff 5.520
36/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.502294322e+00 | Ndiff 5.202
37/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.491586191e+00 | Ndiff 4.853
38/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.485039030e+00 | Ndiff 4.637
39/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.468792370e+00 | Ndiff 4.344
40/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.467909258e+00 | Ndiff 4.313
41/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.466944463e+00 | Ndiff 4.367
42/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.465874350e+00 | Ndiff 4.413
43/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.464679105e+00 | Ndiff 4.447
44/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.463333221e+00 | Ndiff 4.465
45/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.461803293e+00 | Ndiff 4.461
46/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.460044727e+00 | Ndiff 4.426
47/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.457996608e+00 | Ndiff 4.349
48/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.455573178e+00 | Ndiff 4.210
49/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.452647608e+00 | Ndiff 3.983
50/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.449011332e+00 | Ndiff 3.629
51/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.444220039e+00 | Ndiff 3.091
52/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.436689553e+00 | Ndiff 2.295
53/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.416217624e+00 | Ndiff 1.184
54/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.397676918e+00 | Ndiff 0.269
55/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.395369454e+00 | Ndiff 0.002
56/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.395369454e+00 | Ndiff 0.000
... done. converged.
Assume full covariances, and train with the stochastic algorithm (soVB) over 50 batches.
Start with too many clusters (K=20)
stoch_trained_model, stoch_info_dict = bnpy.run(
dataset, 'DPMixtureModel', 'Gauss', 'soVB',
output_path=\
'/tmp/faithful/showcase-K=20-lik=Gauss-ECovMat=5*eye-alg=soVB/',
nLap=50, nTask=1, nBatch=50,
rhoexp=0.51, rhodelay=1.0,
gamma0=gamma, sF=sF, ECovMat='eye',
K=K, initname='randexamples',
)
show_clusters_over_time(stoch_info_dict['task_output_path'])
Dataset Summary:
X Data
total size: 272 units
batch size: 6 units
num. batches: 50
Allocation Model: DP mixture with K=0. Concentration gamma0= 5.00
Obs. Data Model: Gaussian with full covariance.
Obs. Data Prior: Gauss-Wishart on mean and covar of each cluster
E[ mean[k] ] =
[0. 0.]
E[ covar[k] ] =
[[5. 0.]
[0. 5.]]
Initialization:
initname = randexamples
K = 20 (number of clusters)
seed = 1607680
elapsed_time: 0.0 sec
Learn Alg: soVB | task 1/1 | alg. seed: 1607680 | data order seed: 8541952
task_output_path: /tmp/faithful/showcase-K=20-lik=Gauss-ECovMat=5*eye-alg=soVB/1
0.020/50 after 0 sec. | 217.2 MiB | K 20 | loss 3.020105590e+01 | lrate 0.7022
0.040/50 after 0 sec. | 217.2 MiB | K 20 | loss 1.693550049e+01 | lrate 0.5710
0.060/50 after 0 sec. | 217.2 MiB | K 20 | loss 1.182159020e+01 | lrate 0.4931
1.000/50 after 1 sec. | 217.2 MiB | K 20 | loss 1.314590658e+01 | lrate 0.1346
2.000/50 after 1 sec. | 217.2 MiB | K 20 | loss 2.522804913e+00 | lrate 0.0950
3.000/50 after 2 sec. | 217.2 MiB | K 20 | loss 2.507513303e+00 | lrate 0.0774
4.000/50 after 3 sec. | 217.2 MiB | K 20 | loss 2.498366509e+00 | lrate 0.0669
5.000/50 after 3 sec. | 217.2 MiB | K 20 | loss 2.475120548e+00 | lrate 0.0597
6.000/50 after 4 sec. | 217.2 MiB | K 20 | loss 2.469641662e+00 | lrate 0.0544
7.000/50 after 5 sec. | 217.2 MiB | K 20 | loss 2.470932762e+00 | lrate 0.0503
8.000/50 after 5 sec. | 217.2 MiB | K 20 | loss 2.466158362e+00 | lrate 0.0470
9.000/50 after 6 sec. | 217.2 MiB | K 20 | loss 2.465517619e+00 | lrate 0.0443
10.000/50 after 7 sec. | 217.2 MiB | K 20 | loss 2.465435382e+00 | lrate 0.0420
11.000/50 after 7 sec. | 217.2 MiB | K 20 | loss 2.467696036e+00 | lrate 0.0400
12.000/50 after 8 sec. | 217.2 MiB | K 20 | loss 2.461949856e+00 | lrate 0.0383
13.000/50 after 8 sec. | 217.2 MiB | K 20 | loss 2.461043110e+00 | lrate 0.0367
14.000/50 after 9 sec. | 217.2 MiB | K 20 | loss 2.461264382e+00 | lrate 0.0354
15.000/50 after 10 sec. | 217.2 MiB | K 20 | loss 2.458851451e+00 | lrate 0.0342
16.000/50 after 10 sec. | 217.2 MiB | K 20 | loss 2.454655643e+00 | lrate 0.0330
17.000/50 after 11 sec. | 217.2 MiB | K 20 | loss 2.451234188e+00 | lrate 0.0320
18.000/50 after 12 sec. | 217.2 MiB | K 20 | loss 2.446281746e+00 | lrate 0.0311
19.000/50 after 12 sec. | 217.2 MiB | K 20 | loss 2.440793273e+00 | lrate 0.0303
20.000/50 after 13 sec. | 217.2 MiB | K 20 | loss 2.434143198e+00 | lrate 0.0295
21.000/50 after 14 sec. | 217.2 MiB | K 20 | loss 2.423545849e+00 | lrate 0.0288
22.000/50 after 14 sec. | 217.2 MiB | K 20 | loss 2.414924985e+00 | lrate 0.0281
23.000/50 after 15 sec. | 217.2 MiB | K 20 | loss 2.405741175e+00 | lrate 0.0275
24.000/50 after 15 sec. | 217.2 MiB | K 20 | loss 2.396273286e+00 | lrate 0.0269
25.000/50 after 16 sec. | 217.2 MiB | K 20 | loss 2.388169784e+00 | lrate 0.0263
26.000/50 after 17 sec. | 217.2 MiB | K 20 | loss 2.383086462e+00 | lrate 0.0258
27.000/50 after 17 sec. | 217.2 MiB | K 20 | loss 2.378566216e+00 | lrate 0.0253
28.000/50 after 18 sec. | 217.2 MiB | K 20 | loss 2.374468977e+00 | lrate 0.0248
29.000/50 after 19 sec. | 217.2 MiB | K 20 | loss 2.369573743e+00 | lrate 0.0244
30.000/50 after 19 sec. | 217.2 MiB | K 20 | loss 2.365030091e+00 | lrate 0.0240
31.000/50 after 20 sec. | 217.2 MiB | K 20 | loss 2.361840492e+00 | lrate 0.0236
32.000/50 after 21 sec. | 217.2 MiB | K 20 | loss 2.358180297e+00 | lrate 0.0232
33.000/50 after 21 sec. | 217.2 MiB | K 20 | loss 2.355370718e+00 | lrate 0.0229
34.000/50 after 22 sec. | 217.2 MiB | K 20 | loss 2.354001069e+00 | lrate 0.0225
35.000/50 after 23 sec. | 217.2 MiB | K 20 | loss 2.352655966e+00 | lrate 0.0222
36.000/50 after 23 sec. | 217.2 MiB | K 20 | loss 2.351706756e+00 | lrate 0.0219
37.000/50 after 24 sec. | 217.2 MiB | K 20 | loss 2.351346242e+00 | lrate 0.0216
38.000/50 after 24 sec. | 217.2 MiB | K 20 | loss 2.351470554e+00 | lrate 0.0213
39.000/50 after 25 sec. | 217.2 MiB | K 20 | loss 2.351530696e+00 | lrate 0.0210
40.000/50 after 26 sec. | 217.2 MiB | K 20 | loss 2.351708300e+00 | lrate 0.0207
41.000/50 after 26 sec. | 217.2 MiB | K 20 | loss 2.351489520e+00 | lrate 0.0205
42.000/50 after 27 sec. | 217.2 MiB | K 20 | loss 2.351625569e+00 | lrate 0.0202
43.000/50 after 28 sec. | 217.2 MiB | K 20 | loss 2.351357552e+00 | lrate 0.0200
44.000/50 after 28 sec. | 217.2 MiB | K 20 | loss 2.351536177e+00 | lrate 0.0197
45.000/50 after 29 sec. | 217.2 MiB | K 20 | loss 2.351543540e+00 | lrate 0.0195
46.000/50 after 30 sec. | 217.2 MiB | K 20 | loss 2.351346627e+00 | lrate 0.0193
47.000/50 after 30 sec. | 217.2 MiB | K 20 | loss 2.351293738e+00 | lrate 0.0191
48.000/50 after 31 sec. | 217.2 MiB | K 20 | loss 2.351507196e+00 | lrate 0.0189
49.000/50 after 31 sec. | 217.2 MiB | K 20 | loss 2.351496294e+00 | lrate 0.0187
50.000/50 after 32 sec. | 217.2 MiB | K 20 | loss 2.351523756e+00 | lrate 0.0185
... active. not converged.
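The `lrate` column in this log is consistent with a decaying step size of the form `(t + rhodelay) ** (-rhoexp)`, where `t` counts batch updates. That form is inferred from the printed values, not taken from the bnpy source, so treat it as an assumption; `so_vb_lrate` is a hypothetical helper name.

```python
def so_vb_lrate(n_updates, rhoexp=0.51, rhodelay=1.0):
    """Step size after n_updates batch updates (schedule inferred from the log)."""
    return (n_updates + rhodelay) ** (-rhoexp)


n_batch = 50  # one lap visits all 50 batches, so lap 0.020 is the first update
for lap in (0.02, 0.04, 0.06, 1.0, 2.0, 50.0):
    n_updates = round(lap * n_batch)
    print("lap %6.3f | lrate %.4f" % (lap, so_vb_lrate(n_updates)))
```

The printed values can be compared against the `lrate` column above. With `rhoexp` just above 0.5 the step size shrinks slowly, which fits the way the loss keeps fluctuating in its fourth decimal place during the last laps.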
pylab.figure()
pylab.plot(
zm_info_dict['lap_history'],
zm_info_dict['loss_history'], 'b.-',
label='full_covar zero_mean')
pylab.plot(
full_info_dict['lap_history'],
full_info_dict['loss_history'], 'k.-',
label='full_covar')
pylab.plot(
diag_info_dict['lap_history'],
diag_info_dict['loss_history'], 'r.-',
label='diag_covar')
pylab.plot(
stoch_info_dict['lap_history'],
stoch_info_dict['loss_history'], 'c.:',
label='full_covar stochastic')
pylab.legend(loc='upper right')
pylab.xlabel('num. laps')
pylab.ylabel('loss')
pylab.xlim([4, 100]) # avoid early iterations
pylab.ylim([2.34, 4.0]) # handpicked
pylab.draw()
pylab.tight_layout()
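The plot can be complemented by a numeric comparison of where each run ended. The sketch below copies the last loss value printed in each log above into a plain dictionary; in a live session the same numbers should be available as the last entry of each `loss_history`.

```python
# Final loss printed at the last lap of each run above.
final_loss = {
    'diag_covar': 2.412691901,
    'full_covar': 2.356417493,
    'full_covar zero_mean': 3.395369454,
    'full_covar stochastic': 2.351523756,
}

# Sort from lowest (best) to highest loss.
ranked = sorted(final_loss.items(), key=lambda kv: kv[1])
for label, loss in ranked:
    print("%-22s %.6f" % (label, loss))
```

Note that the stochastic run was stopped after 50 laps without converging, so its value is a snapshot, not a fixed point.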
# E_proba_K : 1D array, size n_clusters
# Each entry gives expected probability of that cluster
E_proba_K = stoch_trained_model.allocModel.get_active_comp_probs()
print("probability of each cluster:")
print(E_proba_K)
probability of each cluster:
[0.004 0.004 0.004 0.004 0.004 0.004 0.627 0.003 0.003 0.003 0.003 0.003 0.003 0.003 0.312 0.003 0.002 0.002 0.002 0.001]
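Although the run started with K=20 clusters, nearly all of the probability mass sits on two of them. A minimal sketch for picking out those clusters from the printed vector follows; the cutoff `min_proba` is an illustrative choice, not a bnpy default.

```python
import numpy as np

# Values copied from the printed output above (rounded to 3 decimals).
E_proba_K = np.array([
    0.004, 0.004, 0.004, 0.004, 0.004, 0.004, 0.627, 0.003, 0.003, 0.003,
    0.003, 0.003, 0.003, 0.003, 0.312, 0.003, 0.002, 0.002, 0.002, 0.001])

min_proba = 0.05  # illustrative cutoff for calling a cluster "dominant"
dominant_ids = np.flatnonzero(E_proba_K > min_proba)
print("dominant clusters:", dominant_ids)
print("mass covered: %.3f" % E_proba_K[dominant_ids].sum())
```

The indices found here are the same two clusters whose means and covariances differ from the prior in the per-cluster printout.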
Remember that each cluster has the following approximate posterior over its mean vector $\mu$ and covariance matrix $\Sigma$:
$$ q(\mu, \Sigma) = \mathrm{Normal}(\mu \mid m, \kappa \Sigma) \, \mathrm{Wishart}(\Sigma \mid \nu, S) $$
We show here how to compute the expected values of $\mu$ and $\Sigma$ from a trained model.
for k in range(K):
    E_mu_k = stoch_trained_model.obsModel.get_mean_for_comp(k)
    E_Sigma_k = stoch_trained_model.obsModel.get_covar_mat_for_comp(k)
    print("")
    print("mean[k=%d]" % k)
    print(E_mu_k)
    print("covar[k=%d]" % k)
    print(E_Sigma_k)
mean[k=0]
[0. 0.]
covar[k=0]
[[5. 0.]
[0. 5.]]
mean[k=1]
[0. 0.]
covar[k=1]
[[5. 0.]
[0. 5.]]
mean[k=2]
[0. 0.]
covar[k=2]
[[5. 0.]
[0. 5.]]
mean[k=3]
[0. 0.]
covar[k=3]
[[5. 0.]
[0. 5.]]
mean[k=4]
[0. 0.]
covar[k=4]
[[5. 0.]
[0. 5.]]
mean[k=5]
[0. 0.]
covar[k=5]
[[5. 0.]
[0. 5.]]
mean[k=6]
[ 4.304 79.93 ]
covar[k=6]
[[ 0.194 0.92 ]
[ 0.92 36.103]]
mean[k=7]
[0. 0.]
covar[k=7]
[[5. 0.]
[0. 5.]]
mean[k=8]
[0. 0.]
covar[k=8]
[[5. 0.]
[0. 5.]]
mean[k=9]
[0. 0.]
covar[k=9]
[[5. 0.]
[0. 5.]]
mean[k=10]
[0. 0.]
covar[k=10]
[[5. 0.]
[0. 5.]]
mean[k=11]
[0. 0.]
covar[k=11]
[[5. 0.]
[0. 5.]]
mean[k=12]
[0. 0.]
covar[k=12]
[[5. 0.]
[0. 5.]]
mean[k=13]
[0. 0.]
covar[k=13]
[[5. 0.]
[0. 5.]]
mean[k=14]
[ 2.034 54.415]
covar[k=14]
[[ 0.12 0.453]
[ 0.453 33.606]]
mean[k=15]
[0. 0.]
covar[k=15]
[[5. 0.]
[0. 5.]]
mean[k=16]
[0. 0.]
covar[k=16]
[[5. 0.]
[0. 5.]]
mean[k=17]
[0. 0.]
covar[k=17]
[[5. 0.]
[0. 5.]]
mean[k=18]
[0. 0.]
covar[k=18]
[[5. 0.]
[0. 5.]]
mean[k=19]
[0. 0.]
covar[k=19]
[[5. 0.]
[0. 5.]]
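Only clusters 6 and 14 moved away from the prior expectations shown in the run summary (mean `[0. 0.]`, covariance `5 * eye`); the other eighteen still report those prior values. As a sketch of how to read the two fitted covariance matrices, the code below converts them into per-dimension standard deviations and a correlation coefficient. The matrices are copied from the printed output, and `stddev_and_corr` is a hypothetical helper.

```python
import numpy as np

# Expected covariance matrices copied from the printed output above.
covar_by_cluster = {
    6: np.array([[0.194, 0.92], [0.92, 36.103]]),
    14: np.array([[0.12, 0.453], [0.453, 33.606]]),
}


def stddev_and_corr(covar):
    """Return (per-dimension std devs, correlation) for a 2x2 covariance."""
    std = np.sqrt(np.diag(covar))
    corr = covar[0, 1] / (std[0] * std[1])
    return std, corr


summary = {k: stddev_and_corr(S) for k, S in covar_by_cluster.items()}
for k, (std, corr) in summary.items():
    print("k=%d  std=%s  corr=%.3f" % (k, np.round(std, 3), corr))
```

Both clusters show a positive correlation between the two data columns, which a diagonal-covariance model cannot represent.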
Total running time of the script: ( 0 minutes 42.005 seconds)