Variational training for Mixtures of Gaussians

A showcase of different models and algorithms applied to the same dataset.

In this example, we show how bnpy makes it easy to apply different models and algorithms to the same dataset.

import bnpy
import numpy as np
import os

from matplotlib import pylab
import seaborn as sns

SMALL_FIG_SIZE = (2.5, 2.5)
FIG_SIZE = (5, 5)
pylab.rcParams['figure.figsize'] = FIG_SIZE

Load dataset from file

dataset_path = os.path.join(bnpy.DATASET_PATH, 'faithful')
dataset = bnpy.data.XData.read_csv(
    os.path.join(dataset_path, 'faithful.csv'))

Make a simple plot of the raw data

pylab.plot(dataset.X[:, 0], dataset.X[:, 1], 'k.')
pylab.xlabel(dataset.column_names[0])
pylab.ylabel(dataset.column_names[1])
pylab.tight_layout()
data_ax_h = pylab.gca()
[Figure: scatter plot of the raw data]

Setup: Helper function to display the learned clusters

def show_clusters_over_time(
        task_output_path=None,
        query_laps=(0, 1, 2, 10, 20, None),
        nrows=2):
    ''' Show 2D elliptical contours overlaid on raw data, one panel per queried lap.
    '''
    # True division before rounding up, so leftover panels still get a column.
    ncols = int(np.ceil(len(query_laps) / float(nrows)))
    fig_handle, ax_handle_list = pylab.subplots(
        figsize=(SMALL_FIG_SIZE[0] * ncols, SMALL_FIG_SIZE[1] * nrows),
        nrows=nrows, ncols=ncols, sharex=True, sharey=True)
    for plot_id, lap_val in enumerate(query_laps):
        cur_model, lap_val = bnpy.load_model_at_lap(task_output_path, lap_val)
        cur_ax_handle = ax_handle_list.flatten()[plot_id]
        bnpy.viz.PlotComps.plotCompsFromHModel(
            cur_model, dataset=dataset, ax_handle=cur_ax_handle)
        cur_ax_handle.set_title("lap: %d" % lap_val)
        cur_ax_handle.set_xlabel(dataset.column_names[0])
        cur_ax_handle.set_ylabel(dataset.column_names[1])
        cur_ax_handle.set_xlim(data_ax_h.get_xlim())
        cur_ax_handle.set_ylim(data_ax_h.get_ylim())
    pylab.tight_layout()
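
The helper draws one panel per queried lap on a grid with `nrows` rows. The sizing rule can be sketched in isolation as below; `grid_shape` is a hypothetical standalone function written for illustration and is not part of bnpy.

```python
import math


def grid_shape(n_panels, nrows=2):
    """Return (nrows, ncols) with enough cells to hold n_panels plots."""
    # Divide first, then round up, so a partly filled column is still counted.
    ncols = int(math.ceil(n_panels / float(nrows)))
    return nrows, ncols


print(grid_shape(6))  # six queried laps on two rows
print(grid_shape(5))  # five queried laps on two rows
```

Both calls give a grid with three columns; with five laps the last cell is simply left empty.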

DiagGauss observations + memoized VB

Assume diagonal covariances.
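
With a diagonal covariance, the Gaussian log density is a sum of independent one-dimensional terms, so each cluster's contour ellipse is aligned with the axes. The sketch below illustrates only that factorization; `diag_gauss_logpdf` and the example numbers are hypothetical and do not reproduce bnpy's internal computation.

```python
import math


def diag_gauss_logpdf(x, mean, var):
    """Log density of a Gaussian whose covariance is diag(var)."""
    total = 0.0
    for x_d, m_d, v_d in zip(x, mean, var):
        # Each dimension contributes an independent 1D Gaussian term.
        total += -0.5 * (math.log(2.0 * math.pi * v_d) + (x_d - m_d) ** 2 / v_d)
    return total


# Hypothetical cluster parameters, chosen only for illustration.
print(diag_gauss_logpdf([3.0, 60.0], mean=[3.5, 70.0], var=[1.0, 100.0]))
```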

Start with too many clusters (K=20)

gamma = 5.0
sF = 5.0
K = 20

diag_trained_model, diag_info_dict = bnpy.run(
    dataset, 'DPMixtureModel', 'DiagGauss', 'memoVB',
    output_path='/tmp/faithful/showcase-K=20-lik=DiagGauss-ECovMat=5*eye/',
    nLap=1000, nTask=1, nBatch=1, convergeThr=0.0001,
    gamma0=gamma, sF=sF, ECovMat='eye',
    K=K, initname='randexamples',
    )
show_clusters_over_time(diag_info_dict['task_output_path'])
[Figure: DiagGauss clusters at the queried laps]

Out:

Dataset Summary:
X Data
  total size: 272 units
  batch size: 272 units
  num. batches: 1
Allocation Model:  DP mixture with K=0. Concentration gamma0= 5.00
Obs. Data  Model:  Gaussian with diagonal covariance.
Obs. Data  Prior:  independent Gauss-Wishart prior on each dimension
  Wishart params
    nu = 4
  beta = [ 10  10]
  Expectations
  E[  mean[k]] =
  [ 0  0]
  E[ covar[k]] =
  [[ 5.  0.]
   [ 0.  5.]]
Initialization:
  initname = randexamples
  K = 20 (number of clusters)
  seed = 1607680
  elapsed_time: 0.0 sec
Learn Alg: memoVB | task  1/1 | alg. seed: 1607680 | data order seed: 8541952
task_output_path: /tmp/faithful/showcase-K=20-lik=DiagGauss-ECovMat=5*eye/1
    1.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  3.002088161e+00 |
    2.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.966309748e+00 | Ndiff    4.584
    3.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.949498401e+00 | Ndiff    4.501
    4.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.931792100e+00 | Ndiff    5.452
    5.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.910174417e+00 | Ndiff    6.264
    6.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.881973764e+00 | Ndiff    6.557
    7.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.833491488e+00 | Ndiff    6.194
    8.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.793876249e+00 | Ndiff    5.582
    9.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.752249036e+00 | Ndiff    4.580
   10.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.698456106e+00 | Ndiff    4.740
   11.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.651622886e+00 | Ndiff    5.753
   12.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.624108159e+00 | Ndiff    6.393
   13.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.587572334e+00 | Ndiff    6.074
   14.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.560219008e+00 | Ndiff    4.596
   15.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.544887510e+00 | Ndiff    2.425
   16.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.527562170e+00 | Ndiff    1.537
   17.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.513204290e+00 | Ndiff    1.470
   18.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.507574530e+00 | Ndiff    1.454
   19.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.497639081e+00 | Ndiff    1.477
   20.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.497002974e+00 | Ndiff    1.526
   21.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.496368366e+00 | Ndiff    1.584
   22.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.495732569e+00 | Ndiff    1.638
   23.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.495103986e+00 | Ndiff    1.678
   24.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.494497370e+00 | Ndiff    1.697
   25.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.493928089e+00 | Ndiff    1.698
   26.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.493406054e+00 | Ndiff    1.687
   27.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.492931650e+00 | Ndiff    1.674
   28.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.492495409e+00 | Ndiff    1.670
   29.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.492081031e+00 | Ndiff    1.682
   30.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.491669505e+00 | Ndiff    1.715
   31.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.491242078e+00 | Ndiff    1.771
   32.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.490781261e+00 | Ndiff    1.851
   33.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.490270354e+00 | Ndiff    1.954
   34.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.489692235e+00 | Ndiff    2.082
   35.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.489027945e+00 | Ndiff    2.235
   36.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.488255286e+00 | Ndiff    2.414
   37.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.487347489e+00 | Ndiff    2.621
   38.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.486272015e+00 | Ndiff    2.855
   39.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.484989607e+00 | Ndiff    3.113
   40.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.483453885e+00 | Ndiff    3.391
   41.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.481612017e+00 | Ndiff    3.675
   42.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.479407351e+00 | Ndiff    3.945
   43.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.476785615e+00 | Ndiff    4.169
   44.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.473707547e+00 | Ndiff    4.301
   45.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.470173027e+00 | Ndiff    4.287
   46.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.466261998e+00 | Ndiff    4.076
   47.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.462175093e+00 | Ndiff    3.648
   48.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.458145085e+00 | Ndiff    3.068
   49.000/1000 after      0 sec. |    158.6 MiB | K   20 | loss  2.453753531e+00 | Ndiff    2.496
   50.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.441832238e+00 | Ndiff    2.043
   51.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.437414963e+00 | Ndiff    1.868
   52.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.435586138e+00 | Ndiff    1.786
   53.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.433555668e+00 | Ndiff    1.617
   54.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.431430213e+00 | Ndiff    1.335
   55.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.429483776e+00 | Ndiff    0.970
   56.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.428084819e+00 | Ndiff    0.626
   57.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.427098172e+00 | Ndiff    0.419
   58.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.425535651e+00 | Ndiff    0.330
   59.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.415976963e+00 | Ndiff    0.153
   60.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.412692266e+00 | Ndiff    0.026
   61.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.412692093e+00 | Ndiff    0.016
   62.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.412692005e+00 | Ndiff    0.011
   63.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.412691958e+00 | Ndiff    0.008
   64.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.412691932e+00 | Ndiff    0.006
   65.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.412691918e+00 | Ndiff    0.005
   66.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.412691910e+00 | Ndiff    0.003
   67.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.412691906e+00 | Ndiff    0.003
   68.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.412691904e+00 | Ndiff    0.002
   69.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.412691903e+00 | Ndiff    0.001
   70.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.412691902e+00 | Ndiff    0.001
   71.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.412691901e+00 | Ndiff    0.001
   72.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.412691901e+00 | Ndiff    0.001
   73.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.412691901e+00 | Ndiff    0.000
   74.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.412691901e+00 | Ndiff    0.000
   75.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.412691901e+00 | Ndiff    0.000
   76.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.412691901e+00 | Ndiff    0.000
   77.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.412691901e+00 | Ndiff    0.000
   78.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.412691901e+00 | Ndiff    0.000
   79.000/1000 after      1 sec. |    158.6 MiB | K   20 | loss  2.412691901e+00 | Ndiff    0.000
... done. converged.

Gauss observations + VB

Assume full covariances.
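
A full covariance adds an off-diagonal term, which lets each cluster's contour ellipse tilt to follow correlation between the two columns. As a rough illustration for the 2D case, the log density can be written with the closed-form inverse of a 2x2 matrix; `full_gauss_logpdf_2d` is a hypothetical helper, not a bnpy function.

```python
import math


def full_gauss_logpdf_2d(x, mean, cov):
    """Log density of a 2D Gaussian with covariance [[a, b], [b, c]]."""
    (a, b), (_, c) = cov
    det = a * c - b * b  # must be positive for a valid covariance
    dx, dy = x[0] - mean[0], x[1] - mean[1]
    # Quadratic form (x - mean)^T cov^{-1} (x - mean), using the 2x2 inverse.
    quad = (c * dx * dx - 2.0 * b * dx * dy + a * dy * dy) / det
    return -0.5 * (2.0 * math.log(2.0 * math.pi) + math.log(det) + quad)


point, mean = [1.0, 1.0], [0.0, 0.0]
print(full_gauss_logpdf_2d(point, mean, [[1.0, 0.0], [0.0, 1.0]]))  # uncorrelated
print(full_gauss_logpdf_2d(point, mean, [[1.0, 0.8], [0.8, 1.0]]))  # correlated
```

A point lying along the direction of positive correlation scores higher under the correlated covariance than under the uncorrelated one, which is the flexibility the diagonal model lacks.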

Start with too many clusters (K=20)
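
Starting with more clusters than needed is reasonable under a Dirichlet process prior, because the prior places geometrically decaying expected mass on later clusters. The sketch below assumes the standard stick-breaking construction with v_k ~ Beta(1, gamma0); that parameterization is an assumption made for illustration, and `expected_stick_weights` is a hypothetical helper rather than part of bnpy.

```python
def expected_stick_weights(gamma0, K):
    """Prior mean of the first K stick-breaking weights, v_k ~ Beta(1, gamma0)."""
    mean_v = 1.0 / (1.0 + gamma0)  # E[v_k]
    remaining = 1.0                # E[prod_{j<k} (1 - v_j)], by independence
    weights = []
    for _ in range(K):
        weights.append(mean_v * remaining)
        remaining *= 1.0 - mean_v
    return weights


w = expected_stick_weights(gamma0=5.0, K=20)
print(w[0], w[-1], sum(w))
```

With gamma0 = 5 the first expected weight is 1/6 and each later one shrinks by a factor of 5/6, so most of the prior mass sits on the early clusters.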

full_trained_model, full_info_dict = bnpy.run(
    dataset, 'DPMixtureModel', 'Gauss', 'VB',
    output_path='/tmp/faithful/showcase-K=20-lik=Gauss-ECovMat=5*eye/',
    nLap=1000, nTask=1, nBatch=1, convergeThr=0.0001,
    gamma0=gamma, sF=sF, ECovMat='eye',
    K=K, initname='randexamples',
    )
show_clusters_over_time(full_info_dict['task_output_path'])
[Figure: Gauss (full covariance) clusters at the queried laps]

Out:

WARNING: Found unrecognized keyword args. These are ignored.
  --nBatch
Dataset Summary:
X Data
  num examples: 272
  num dims: 2
Allocation Model:  DP mixture with K=0. Concentration gamma0= 5.00
Obs. Data  Model:  Gaussian with full covariance.
Obs. Data  Prior:  Gauss-Wishart on mean and covar of each cluster
  E[  mean[k] ] =
   [ 0.  0.]
  E[ covar[k] ] =
  [[ 5.  0.]
   [ 0.  5.]]
Initialization:
  initname = randexamples
  K = 20 (number of clusters)
  seed = 1607680
  elapsed_time: 0.0 sec
Learn Alg: VB | task  1/1 | alg. seed: 1607680 | data order seed: 8541952
task_output_path: /tmp/faithful/showcase-K=20-lik=Gauss-ECovMat=5*eye/1
        1/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.904406498e+00 |
        2/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.876066302e+00 | Ndiff    2.071
        3/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.864944306e+00 | Ndiff    2.664
        4/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.852026479e+00 | Ndiff    3.160
        5/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.836506635e+00 | Ndiff    3.522
        6/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.817804673e+00 | Ndiff    3.636
        7/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.787080350e+00 | Ndiff    3.495
        8/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.743017730e+00 | Ndiff    3.349
        9/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.714906517e+00 | Ndiff    3.494
       10/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.670842807e+00 | Ndiff    4.089
       11/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.640944068e+00 | Ndiff    4.696
       12/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.600612644e+00 | Ndiff    5.201
       13/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.582832638e+00 | Ndiff    5.436
       14/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.569005251e+00 | Ndiff    5.209
       15/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.544279312e+00 | Ndiff    4.376
       16/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.523247072e+00 | Ndiff    2.925
       17/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.511328327e+00 | Ndiff    1.441
       18/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.503251361e+00 | Ndiff    1.441
       19/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.491852524e+00 | Ndiff    1.448
       20/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.490247584e+00 | Ndiff    1.447
       21/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.488483904e+00 | Ndiff    1.524
       22/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.486427687e+00 | Ndiff    1.578
       23/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.484005616e+00 | Ndiff    1.586
       24/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.481221463e+00 | Ndiff    1.527
       25/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.478177971e+00 | Ndiff    1.375
       26/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.475094563e+00 | Ndiff    1.194
       27/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.471974138e+00 | Ndiff    1.209
       28/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.467024647e+00 | Ndiff    1.213
       29/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.453713427e+00 | Ndiff    1.207
       30/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.445147652e+00 | Ndiff    1.191
       31/1000 after      0 sec. |    154.6 MiB | K   20 | loss  2.439161802e+00 | Ndiff    1.170
       32/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.438783471e+00 | Ndiff    1.143
       33/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.438429230e+00 | Ndiff    1.114
       34/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.438098475e+00 | Ndiff    1.082
       35/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.437790556e+00 | Ndiff    1.051
       36/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.437503774e+00 | Ndiff    1.023
       37/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.437235314e+00 | Ndiff    1.000
       38/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.436981439e+00 | Ndiff    0.984
       39/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.436737744e+00 | Ndiff    0.977
       40/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.436499373e+00 | Ndiff    0.979
       41/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.436261156e+00 | Ndiff    0.992
       42/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.436017666e+00 | Ndiff    1.014
       43/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.435763198e+00 | Ndiff    1.048
       44/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.435491737e+00 | Ndiff    1.091
       45/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.435196931e+00 | Ndiff    1.145
       46/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.434872198e+00 | Ndiff    1.209
       47/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.434511031e+00 | Ndiff    1.279
       48/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.434107639e+00 | Ndiff    1.354
       49/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.433657974e+00 | Ndiff    1.428
       50/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.433161085e+00 | Ndiff    1.498
       51/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.432620547e+00 | Ndiff    1.556
       52/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.432045498e+00 | Ndiff    1.597
       53/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.431450777e+00 | Ndiff    1.617
       54/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.430855677e+00 | Ndiff    1.617
       55/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.430280760e+00 | Ndiff    1.600
       56/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.429742357e+00 | Ndiff    1.574
       57/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.429246005e+00 | Ndiff    1.547
       58/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.428783131e+00 | Ndiff    1.530
       59/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.428335044e+00 | Ndiff    1.531
       60/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.427882144e+00 | Ndiff    1.554
       61/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.427412198e+00 | Ndiff    1.598
       62/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.426923365e+00 | Ndiff    1.660
       63/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.426419184e+00 | Ndiff    1.735
       64/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.425898789e+00 | Ndiff    1.819
       65/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.425353548e+00 | Ndiff    1.909
       66/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.424771068e+00 | Ndiff    2.006
       67/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.424138280e+00 | Ndiff    2.111
       68/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.423441574e+00 | Ndiff    2.224
       69/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.422666132e+00 | Ndiff    2.343
       70/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.421796837e+00 | Ndiff    2.464
       71/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.420824917e+00 | Ndiff    2.578
       72/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.419763730e+00 | Ndiff    2.674
       73/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.418655856e+00 | Ndiff    2.744
       74/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.417536411e+00 | Ndiff    2.794
       75/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.416396660e+00 | Ndiff    2.833
       76/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.415209641e+00 | Ndiff    2.867
       77/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.413957301e+00 | Ndiff    2.892
       78/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.412627524e+00 | Ndiff    2.908
       79/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.411204929e+00 | Ndiff    2.915
       80/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.409665366e+00 | Ndiff    2.917
       81/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.407972796e+00 | Ndiff    2.914
       82/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.406075452e+00 | Ndiff    2.906
       83/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.403897086e+00 | Ndiff    2.888
       84/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.401317119e+00 | Ndiff    2.847
       85/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.398127190e+00 | Ndiff    2.749
       86/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.393870554e+00 | Ndiff    2.514
       87/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.385927886e+00 | Ndiff    2.011
       88/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.373557979e+00 | Ndiff    1.320
       89/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.370431956e+00 | Ndiff    0.732
       90/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.361734581e+00 | Ndiff    0.225
       91/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.356420337e+00 | Ndiff    0.052
       92/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.356419191e+00 | Ndiff    0.040
       93/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.356418521e+00 | Ndiff    0.031
       94/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.356418120e+00 | Ndiff    0.024
       95/1000 after      1 sec. |    154.6 MiB | K   20 | loss  2.356417878e+00 | Ndiff    0.019
       96/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417731e+00 | Ndiff    0.015
       97/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417640e+00 | Ndiff    0.011
       98/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417584e+00 | Ndiff    0.009
       99/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417550e+00 | Ndiff    0.007
      100/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417528e+00 | Ndiff    0.006
      101/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417515e+00 | Ndiff    0.004
      102/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417507e+00 | Ndiff    0.003
      103/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417502e+00 | Ndiff    0.003
      104/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417499e+00 | Ndiff    0.002
      105/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417497e+00 | Ndiff    0.002
      106/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417495e+00 | Ndiff    0.001
      107/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417495e+00 | Ndiff    0.001
      108/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417494e+00 | Ndiff    0.001
      109/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417494e+00 | Ndiff    0.001
      110/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417494e+00 | Ndiff    0.001
      111/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417493e+00 | Ndiff    0.000
      112/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417493e+00 | Ndiff    0.000
      113/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417493e+00 | Ndiff    0.000
      114/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417493e+00 | Ndiff    0.000
      115/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417493e+00 | Ndiff    0.000
      116/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417493e+00 | Ndiff    0.000
      117/1000 after      2 sec. |    154.6 MiB | K   20 | loss  2.356417493e+00 | Ndiff    0.000
... done. converged.

ZeroMeanGauss observations + VB

Assume full covariances and fix all means to zero.
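
With the mean fixed at zero, a cluster's covariance is measured around the origin rather than around the cluster's own center. The sketch below shows the maximum-likelihood version of that estimate (the average outer product, with no centering step). It ignores the Wishart prior used in the actual run, and both `zero_mean_cov_2d` and the sample points are hypothetical.

```python
def zero_mean_cov_2d(points):
    """Average outer product x x^T: the ML covariance when the mean is fixed at 0."""
    n = float(len(points))
    sxx = sum(x * x for x, _ in points) / n
    sxy = sum(x * y for x, y in points) / n
    syy = sum(y * y for _, y in points) / n
    return [[sxx, sxy], [sxy, syy]]


# Hypothetical points that lie far from the origin.
pts = [(2.0, 50.0), (4.0, 80.0)]
print(zero_mean_cov_2d(pts))  # [[10.0, 210.0], [210.0, 4450.0]]
```

When the data sit far from the origin, the estimated covariance has to absorb that offset, so the fitted ellipses are stretched toward the origin.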

Start with too many clusters (K=20)

zm_trained_model, zm_info_dict = bnpy.run(
    dataset, 'DPMixtureModel', 'ZeroMeanGauss', 'VB',
    output_path='/tmp/faithful/showcase-K=20-lik=ZeroMeanGauss-ECovMat=5*eye/',
    nLap=1000, nTask=1, nBatch=1, convergeThr=0.0001,
    gamma0=gamma, sF=sF, ECovMat='eye',
    K=K, initname='randexamples',
    )
show_clusters_over_time(zm_info_dict['task_output_path'])
[Figure: ZeroMeanGauss clusters at the queried laps]

Out:

WARNING: Found unrecognized keyword args. These are ignored.
  --nBatch
Dataset Summary:
X Data
  num examples: 272
  num dims: 2
Allocation Model:  DP mixture with K=0. Concentration gamma0= 5.00
Obs. Data  Model:  Gaussian with fixed zero means, full covariance.
Obs. Data  Prior:  Wishart on prec matrix Lam
  E[ CovMat[k] ] =
  [[ 5.  0.]
   [ 0.  5.]]
Initialization:
  initname = randexamples
  K = 20 (number of clusters)
  seed = 1607680
  elapsed_time: 0.0 sec
Learn Alg: VB | task  1/1 | alg. seed: 1607680 | data order seed: 8541952
task_output_path: /tmp/faithful/showcase-K=20-lik=ZeroMeanGauss-ECovMat=5*eye/1
        1/1000 after      0 sec. |    154.6 MiB | K   20 | loss  4.019419551e+00 |
        2/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.989063967e+00 | Ndiff    3.437
        3/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.941416771e+00 | Ndiff    3.621
        4/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.899222162e+00 | Ndiff    3.882
        5/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.876371765e+00 | Ndiff    4.129
        6/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.842038547e+00 | Ndiff    4.274
        7/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.808536791e+00 | Ndiff    4.341
        8/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.779782054e+00 | Ndiff    4.418
        9/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.755287922e+00 | Ndiff    4.502
       10/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.733837701e+00 | Ndiff    4.577
       11/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.700117498e+00 | Ndiff    4.646
       12/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.694082693e+00 | Ndiff    4.826
       13/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.686765019e+00 | Ndiff    4.986
       14/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.675117147e+00 | Ndiff    5.045
       15/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.652297952e+00 | Ndiff    4.998
       16/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.639249110e+00 | Ndiff    4.933
       17/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.603588941e+00 | Ndiff    4.758
       18/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.598749399e+00 | Ndiff    4.729
       19/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.592583511e+00 | Ndiff    4.761
       20/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.574936974e+00 | Ndiff    4.760
       21/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.572459120e+00 | Ndiff    4.885
       22/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.570977707e+00 | Ndiff    5.084
       23/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.569306514e+00 | Ndiff    5.272
       24/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.567412054e+00 | Ndiff    5.442
       25/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.565249330e+00 | Ndiff    5.588
       26/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.562749369e+00 | Ndiff    5.701
       27/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.559772804e+00 | Ndiff    5.771
       28/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.555894232e+00 | Ndiff    5.784
       29/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.549078207e+00 | Ndiff    5.714
       30/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.531303622e+00 | Ndiff    5.559
       31/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.529382170e+00 | Ndiff    5.572
       32/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.527237440e+00 | Ndiff    5.642
       33/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.524655304e+00 | Ndiff    5.674
       34/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.521314832e+00 | Ndiff    5.648
       35/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.516111933e+00 | Ndiff    5.520
       36/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.502294322e+00 | Ndiff    5.202
       37/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.491586191e+00 | Ndiff    4.853
       38/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.485039030e+00 | Ndiff    4.637
       39/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.468792370e+00 | Ndiff    4.344
       40/1000 after      0 sec. |    154.6 MiB | K   20 | loss  3.467909258e+00 | Ndiff    4.313
       41/1000 after      1 sec. |    154.6 MiB | K   20 | loss  3.466944463e+00 | Ndiff    4.367
       42/1000 after      1 sec. |    154.6 MiB | K   20 | loss  3.465874350e+00 | Ndiff    4.413
       43/1000 after      1 sec. |    154.6 MiB | K   20 | loss  3.464679105e+00 | Ndiff    4.447
       44/1000 after      1 sec. |    154.6 MiB | K   20 | loss  3.463333221e+00 | Ndiff    4.465
       45/1000 after      1 sec. |    154.6 MiB | K   20 | loss  3.461803293e+00 | Ndiff    4.461
       46/1000 after      1 sec. |    154.6 MiB | K   20 | loss  3.460044727e+00 | Ndiff    4.426
       47/1000 after      1 sec. |    154.6 MiB | K   20 | loss  3.457996608e+00 | Ndiff    4.349
       48/1000 after      1 sec. |    154.6 MiB | K   20 | loss  3.455573178e+00 | Ndiff    4.210
       49/1000 after      1 sec. |    154.6 MiB | K   20 | loss  3.452647608e+00 | Ndiff    3.983
       50/1000 after      1 sec. |    154.6 MiB | K   20 | loss  3.449011332e+00 | Ndiff    3.629
       51/1000 after      1 sec. |    154.6 MiB | K   20 | loss  3.444220039e+00 | Ndiff    3.091
       52/1000 after      1 sec. |    154.6 MiB | K   20 | loss  3.436689553e+00 | Ndiff    2.295
       53/1000 after      1 sec. |    154.6 MiB | K   20 | loss  3.416217624e+00 | Ndiff    1.184
       54/1000 after      1 sec. |    154.6 MiB | K   20 | loss  3.397676918e+00 | Ndiff    0.269
       55/1000 after      1 sec. |    154.6 MiB | K   20 | loss  3.395369454e+00 | Ndiff    0.002
       56/1000 after      1 sec. |    154.6 MiB | K   20 | loss  3.395369454e+00 | Ndiff    0.000
... done. converged.

Gauss observations + stochastic VB

Assume full covariances, as in the Gauss + VB run, but train with stochastic VB on 50 batches per lap.

Start with too many clusters (K=20)

stoch_trained_model, stoch_info_dict = bnpy.run(
    dataset, 'DPMixtureModel', 'Gauss', 'soVB',
    output_path=(
        '/tmp/faithful/showcase-K=20-lik=Gauss-ECovMat=5*eye-alg=soVB/'),
    nLap=50, nTask=1, nBatch=50,
    rhoexp=0.51, rhodelay=1.0,
    gamma0=gamma, sF=sF, ECovMat='eye',
    K=K, initname='randexamples',
    )
show_clusters_over_time(stoch_info_dict['task_output_path'])
[Figure: Gauss clusters at the queried laps, trained with soVB]
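
The `rhoexp` and `rhodelay` arguments control how the step size decays across batch updates. The `lrate` column in this run's training log is consistent with a schedule of the form rho_t = (t + rhodelay)^(-rhoexp), where t counts batch updates starting at 1. That formula is inferred from the logged values, not taken from bnpy's source, and `learning_rate` below is a hypothetical helper.

```python
def learning_rate(step, rhoexp=0.51, rhodelay=1.0):
    """Assumed decay schedule: rho_t = (t + rhodelay) ** (-rhoexp)."""
    return (step + rhodelay) ** (-rhoexp)


# Steps 1 to 3 are the first three batches; step 50 ends lap 1 when nBatch=50.
for step in (1, 2, 3, 50):
    print("%3d  %.4f" % (step, learning_rate(step)))
```

Under this assumption the first three updates and the end of lap 1 give 0.7022, 0.5710, 0.4931 and 0.1346, matching the logged `lrate` entries at 0.020, 0.040, 0.060 and 1.000 laps.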

Out:

Dataset Summary:
X Data
  total size: 272 units
  batch size: 6 units
  num. batches: 50
Allocation Model:  DP mixture with K=0. Concentration gamma0= 5.00
Obs. Data  Model:  Gaussian with full covariance.
Obs. Data  Prior:  Gauss-Wishart on mean and covar of each cluster
  E[  mean[k] ] =
   [ 0.  0.]
  E[ covar[k] ] =
  [[ 5.  0.]
   [ 0.  5.]]
Initialization:
  initname = randexamples
  K = 20 (number of clusters)
  seed = 1607680
  elapsed_time: 0.0 sec
Learn Alg: soVB | task  1/1 | alg. seed: 1607680 | data order seed: 8541952
task_output_path: /tmp/faithful/showcase-K=20-lik=Gauss-ECovMat=5*eye-alg=soVB/1
    0.020/50 after      0 sec. |    154.8 MiB | K   20 | loss  3.020105590e+01 |  lrate 0.7022
    0.040/50 after      0 sec. |    154.8 MiB | K   20 | loss  1.693550049e+01 |  lrate 0.5710
    0.060/50 after      0 sec. |    154.8 MiB | K   20 | loss  1.182159020e+01 |  lrate 0.4931
    1.000/50 after      1 sec. |    154.8 MiB | K   20 | loss  1.314590658e+01 |  lrate 0.1346
    2.000/50 after      1 sec. |    154.8 MiB | K   20 | loss  2.522804913e+00 |  lrate 0.0950
    3.000/50 after      2 sec. |    154.8 MiB | K   20 | loss  2.507513303e+00 |  lrate 0.0774
    4.000/50 after      3 sec. |    154.8 MiB | K   20 | loss  2.498366509e+00 |  lrate 0.0669
    5.000/50 after      4 sec. |    154.8 MiB | K   20 | loss  2.475120548e+00 |  lrate 0.0597
    6.000/50 after      4 sec. |    154.8 MiB | K   20 | loss  2.469641662e+00 |  lrate 0.0544
    7.000/50 after      5 sec. |    154.8 MiB | K   20 | loss  2.470932762e+00 |  lrate 0.0503
    8.000/50 after      6 sec. |    154.8 MiB | K   20 | loss  2.466158362e+00 |  lrate 0.0470
    9.000/50 after      6 sec. |    154.8 MiB | K   20 | loss  2.465517619e+00 |  lrate 0.0443
   10.000/50 after      7 sec. |    154.8 MiB | K   20 | loss  2.465435382e+00 |  lrate 0.0420
   11.000/50 after      8 sec. |    154.8 MiB | K   20 | loss  2.467696036e+00 |  lrate 0.0400
   12.000/50 after      8 sec. |    154.8 MiB | K   20 | loss  2.461949856e+00 |  lrate 0.0383
   13.000/50 after      9 sec. |    154.8 MiB | K   20 | loss  2.461043110e+00 |  lrate 0.0367
   14.000/50 after     10 sec. |    154.8 MiB | K   20 | loss  2.461264382e+00 |  lrate 0.0354
   15.000/50 after     11 sec. |    154.8 MiB | K   20 | loss  2.458851451e+00 |  lrate 0.0342
   16.000/50 after     11 sec. |    154.8 MiB | K   20 | loss  2.454655643e+00 |  lrate 0.0330
   17.000/50 after     12 sec. |    154.8 MiB | K   20 | loss  2.451234188e+00 |  lrate 0.0320
   18.000/50 after     13 sec. |    154.8 MiB | K   20 | loss  2.446281746e+00 |  lrate 0.0311
   19.000/50 after     13 sec. |    154.8 MiB | K   20 | loss  2.440793273e+00 |  lrate 0.0303
   20.000/50 after     14 sec. |    154.8 MiB | K   20 | loss  2.434143198e+00 |  lrate 0.0295
   21.000/50 after     15 sec. |    154.8 MiB | K   20 | loss  2.423545849e+00 |  lrate 0.0288
   22.000/50 after     15 sec. |    154.8 MiB | K   20 | loss  2.414924985e+00 |  lrate 0.0281
   23.000/50 after     16 sec. |    154.8 MiB | K   20 | loss  2.405741175e+00 |  lrate 0.0275
   24.000/50 after     17 sec. |    154.8 MiB | K   20 | loss  2.396273286e+00 |  lrate 0.0269
   25.000/50 after     17 sec. |    154.8 MiB | K   20 | loss  2.388169784e+00 |  lrate 0.0263
   26.000/50 after     18 sec. |    154.8 MiB | K   20 | loss  2.383086462e+00 |  lrate 0.0258
   27.000/50 after     19 sec. |    154.8 MiB | K   20 | loss  2.378566216e+00 |  lrate 0.0253
   28.000/50 after     20 sec. |    154.8 MiB | K   20 | loss  2.374468977e+00 |  lrate 0.0248
   29.000/50 after     20 sec. |    154.8 MiB | K   20 | loss  2.369573743e+00 |  lrate 0.0244
   30.000/50 after     21 sec. |    154.8 MiB | K   20 | loss  2.365030091e+00 |  lrate 0.0240
   31.000/50 after     22 sec. |    154.8 MiB | K   20 | loss  2.361840492e+00 |  lrate 0.0236
   32.000/50 after     22 sec. |    154.8 MiB | K   20 | loss  2.358180297e+00 |  lrate 0.0232
   33.000/50 after     23 sec. |    154.8 MiB | K   20 | loss  2.355370718e+00 |  lrate 0.0229
   34.000/50 after     24 sec. |    154.8 MiB | K   20 | loss  2.354001069e+00 |  lrate 0.0225
   35.000/50 after     24 sec. |    154.8 MiB | K   20 | loss  2.352655966e+00 |  lrate 0.0222
   36.000/50 after     25 sec. |    154.8 MiB | K   20 | loss  2.351706756e+00 |  lrate 0.0219
   37.000/50 after     26 sec. |    154.8 MiB | K   20 | loss  2.351346242e+00 |  lrate 0.0216
   38.000/50 after     27 sec. |    154.8 MiB | K   20 | loss  2.351470554e+00 |  lrate 0.0213
   39.000/50 after     27 sec. |    154.8 MiB | K   20 | loss  2.351530696e+00 |  lrate 0.0210
   40.000/50 after     28 sec. |    154.8 MiB | K   20 | loss  2.351708300e+00 |  lrate 0.0207
   41.000/50 after     29 sec. |    154.8 MiB | K   20 | loss  2.351489520e+00 |  lrate 0.0205
   42.000/50 after     29 sec. |    154.8 MiB | K   20 | loss  2.351625569e+00 |  lrate 0.0202
   43.000/50 after     30 sec. |    154.8 MiB | K   20 | loss  2.351357552e+00 |  lrate 0.0200
   44.000/50 after     31 sec. |    154.8 MiB | K   20 | loss  2.351536177e+00 |  lrate 0.0197
   45.000/50 after     32 sec. |    154.8 MiB | K   20 | loss  2.351543540e+00 |  lrate 0.0195
   46.000/50 after     32 sec. |    154.8 MiB | K   20 | loss  2.351346627e+00 |  lrate 0.0193
   47.000/50 after     33 sec. |    154.8 MiB | K   20 | loss  2.351293738e+00 |  lrate 0.0191
   48.000/50 after     34 sec. |    154.8 MiB | K   20 | loss  2.351507196e+00 |  lrate 0.0189
   49.000/50 after     34 sec. |    154.8 MiB | K   20 | loss  2.351496294e+00 |  lrate 0.0187
   50.000/50 after     35 sec. |    154.8 MiB | K   20 | loss  2.351523756e+00 |  lrate 0.0185
... active. not converged.
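
The `lrate` column in the log above is consistent with a decaying step size of the form (t + rhodelay)^(-rhoexp), where t counts the batches processed so far. The sketch below checks that reading against a few logged values. The formula and the helper name `learning_rate` are inferred from the printed numbers; they are assumptions, not taken from bnpy's source.

```python
def learning_rate(n_batches_seen, rhoexp=0.51, rhodelay=1.0):
    """Step size after n_batches_seen batches (assumed schedule, see above)."""
    return (n_batches_seen + rhodelay) ** (-rhoexp)


n_batch = 50  # nBatch from the bnpy.run call

# (lap, lrate) pairs copied from the soVB log.
logged = [(0.02, 0.7022), (0.04, 0.5710), (0.06, 0.4931),
          (1.0, 0.1346), (50.0, 0.0185)]

for lap, lrate_in_log in logged:
    # Convert a (possibly fractional) lap count to a batch count.
    t = int(round(lap * n_batch))
    print("lap %6.3f | batch %4d | formula %.4f | log %.4f" % (
        lap, t, learning_rate(t), lrate_in_log))
```

With rhoexp=0.51 the step size shrinks slowly, so later batches still move the model; a larger exponent would make the run settle sooner at the cost of adapting less to later batches.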

Compare loss function traces for all methods

pylab.figure()

pylab.plot(
    zm_info_dict['lap_history'],
    zm_info_dict['loss_history'], 'b.-',
    label='full_covar zero_mean')
pylab.plot(
    full_info_dict['lap_history'],
    full_info_dict['loss_history'], 'k.-',
    label='full_covar')
pylab.plot(
    diag_info_dict['lap_history'],
    diag_info_dict['loss_history'], 'r.-',
    label='diag_covar')
pylab.plot(
    stoch_info_dict['lap_history'],
    stoch_info_dict['loss_history'], 'c.:',
    label='full_covar stochastic')
pylab.legend(loc='upper right')
pylab.xlabel('num. laps')
pylab.ylabel('loss')
pylab.xlim([4, 100])  # skip the first few laps, where the loss drops steeply
pylab.ylim([2.34, 4.0])  # hand-picked to show where the traces level off
pylab.draw()
pylab.tight_layout()
../../_images/sphx_glr_plot-01-demo=vb_algs-model=mix_gauss_006.png
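
The stochastic trace is not monotone: in the soVB log the loss rises between laps 6 and 7, for example. A trailing mean over the last few recorded values gives a steadier summary of where that trace ends up. The sketch below is one way to compute it; `trailing_mean` is a hypothetical helper, not part of bnpy, and the input values are copied from laps 46 to 50 of the log.

```python
import numpy as np


def trailing_mean(loss_history, window=5):
    """Mean of the last `window` entries of a loss trace."""
    tail = np.asarray(loss_history, dtype=float)[-window:]
    return float(tail.mean())


# Loss values for laps 46 to 50, copied from the soVB log.
sovb_tail = [2.351346627, 2.351293738, 2.351507196,
             2.351496294, 2.351523756]

smoothed = trailing_mean(sovb_tail, window=5)
print("soVB loss, mean over last 5 recorded values: %.6f" % smoothed)
```

After a real run, the same helper could be applied to `stoch_info_dict['loss_history']`. Note that the window counts recorded entries, which need not be whole laps, since the log also records fractional laps early on.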

Total running time of the script: ( 0 minutes 42.505 seconds)
