Comparing models for sequential data

How to train a DP mixture model and several HDP-HMMs, each with a different observation model, on the same dataset, and then compare their loss traces.

import bnpy
import numpy as np
import os

from matplotlib import pylab
import seaborn as sns

SMALL_FIG_SIZE = (2.5, 2.5)
FIG_SIZE = (5, 5)
pylab.rcParams['figure.figsize'] = FIG_SIZE

Load dataset from file

dataset_path = os.path.join(bnpy.DATASET_PATH, 'mocap6')
dataset = bnpy.data.GroupXData.read_npz(
    os.path.join(dataset_path, 'dataset.npz'))

Setup: Function to make a simple plot of the raw data

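def show_single_sequence(seq_id):
    ''' Plot 12 dimensions of one sequence against time.
    '''
    # Rows doc_range[i]:doc_range[i+1] of X belong to sequence i
    start = dataset.doc_range[seq_id]
    stop = dataset.doc_range[seq_id + 1]
    X_seq = dataset.X[start:stop]
    for dim in range(12):  # range (not xrange) so this runs on Python 3
        pylab.plot(X_seq[:, dim], '.-')
    pylab.xlabel('time')
    pylab.ylabel('angle')
    pylab.tight_layout()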
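
The slicing in the function above treats `doc_range` as a list of row offsets, so that sequence `i` occupies rows `doc_range[i]` up to (but not including) `doc_range[i + 1]` of `X`. A minimal sketch of that indexing convention with plain Python lists follows. The toy `X` and `doc_range` values and the helper name `get_sequence` are made up for illustration and are not taken from the mocap6 dataset or from bnpy.

```python
# Toy stand-ins (hypothetical values, not the mocap6 data).
# X holds the rows of all sequences stacked end to end.
X = [[0.0], [0.1], [0.2],          # sequence 0: 3 time steps
     [1.0], [1.1],                 # sequence 1: 2 time steps
     [2.0], [2.1], [2.2], [2.3]]   # sequence 2: 4 time steps

# One more entry than there are sequences: the start offset of each
# sequence, followed by the total number of rows.
doc_range = [0, 3, 5, 9]


def get_sequence(X, doc_range, seq_id):
    """Return the rows of X that belong to sequence seq_id."""
    start = doc_range[seq_id]
    stop = doc_range[seq_id + 1]
    return X[start:stop]


n_seqs = len(doc_range) - 1
seq_lengths = [len(get_sequence(X, doc_range, i)) for i in range(n_seqs)]
print(seq_lengths)
```

Under this layout the number of sequences is `len(doc_range) - 1`, and no per-sequence length array is needed.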

Visualization of the first sequence

show_single_sequence(0)
[Figure: angle versus time for each of the 12 dimensions of the first sequence]

Visualization of the second sequence

show_single_sequence(1)
[Figure: angle versus time for each of the 12 dimensions of the second sequence]

Setup: hyperparameters

alpha = 0.5
gamma = 5.0
sF = 1.0
K = 20

DP mixture with DiagGauss observation model

mixdiag_trained_model, mixdiag_info_dict = bnpy.run(
    dataset, 'DPMixtureModel', 'DiagGauss', 'memoVB',
    output_path='/tmp/mocap6/showcase-K=20-model=DP+DiagGauss-ECovMat=1*eye/',
    nLap=50, nTask=1, nBatch=1, convergeThr=0.0001,
    alpha=alpha, gamma=gamma, sF=sF, ECovMat='eye',
    K=K, initname='randexamples',
    )

Out:

WARNING: Found unrecognized keyword args. These are ignored.
  --alpha
Dataset Summary:
GroupXData
  total size: 6 units
  batch size: 6 units
  num. batches: 1
Allocation Model:  DP mixture with K=0. Concentration gamma0= 5.00
Obs. Data  Model:  Gaussian with diagonal covariance.
Obs. Data  Prior:  independent Gauss-Wishart prior on each dimension
  Wishart params
    nu = 14  ...
  beta = [ 12  12]  ...
  Expectations
  E[  mean[k]] =
  [ 0  0] ...
  E[ covar[k]] =
  [[ 1.  0.]
   [ 0.  1.]] ...
Initialization:
  initname = randexamples
  K = 20 (number of clusters)
  seed = 1607680
  elapsed_time: 0.0 sec
Learn Alg: memoVB | task  1/1 | alg. seed: 1607680 | data order seed: 8541952
task_output_path: /tmp/mocap6/showcase-K=20-model=DP+DiagGauss-ECovMat=1*eye/1
    1.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.845424399e+00 |
    2.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.756537185e+00 | Ndiff   26.569
    3.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.723105766e+00 | Ndiff   29.336
    4.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.703275957e+00 | Ndiff   19.798
    5.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.689932307e+00 | Ndiff   22.929
    6.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.681432839e+00 | Ndiff   24.626
    7.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.676243541e+00 | Ndiff   20.961
    8.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.672159855e+00 | Ndiff   16.818
    9.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.669045283e+00 | Ndiff   13.011
   10.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.665208654e+00 | Ndiff   13.962
   11.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.661952478e+00 | Ndiff   16.751
   12.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.658115620e+00 | Ndiff   17.355
   13.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.655058982e+00 | Ndiff   12.150
   14.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.653853641e+00 | Ndiff    5.448
   15.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.652570589e+00 | Ndiff    5.226
   16.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.651988134e+00 | Ndiff    5.605
   17.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.651566377e+00 | Ndiff    5.980
   18.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.651133556e+00 | Ndiff    6.283
   19.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.650800669e+00 | Ndiff    6.427
   20.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.650523019e+00 | Ndiff    6.354
   21.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.650284547e+00 | Ndiff    6.121
   22.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.650075236e+00 | Ndiff    5.571
   23.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.649874747e+00 | Ndiff    4.678
   24.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.649684370e+00 | Ndiff    3.694
   25.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.649503048e+00 | Ndiff    2.678
   26.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.649287038e+00 | Ndiff    2.238
   27.000/50 after      0 sec. |    178.8 MiB | K   20 | loss  3.648992642e+00 | Ndiff    1.703
   28.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.648274600e+00 | Ndiff    1.704
   29.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647520307e+00 | Ndiff    1.649
   30.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647350990e+00 | Ndiff    1.528
   31.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647242457e+00 | Ndiff    1.494
   32.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647171899e+00 | Ndiff    1.870
   33.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647125709e+00 | Ndiff    2.031
   34.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647094615e+00 | Ndiff    2.051
   35.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647072676e+00 | Ndiff    1.974
   36.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647056795e+00 | Ndiff    1.830
   37.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647045269e+00 | Ndiff    1.651
   38.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647036962e+00 | Ndiff    1.460
   39.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647031023e+00 | Ndiff    1.273
   40.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647026801e+00 | Ndiff    1.099
   41.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647023808e+00 | Ndiff    0.943
   42.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647021689e+00 | Ndiff    0.806
   43.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647020187e+00 | Ndiff    0.687
   44.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647019120e+00 | Ndiff    0.584
   45.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647018361e+00 | Ndiff    0.496
   46.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647017819e+00 | Ndiff    0.421
   47.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647017432e+00 | Ndiff    0.357
   48.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647017154e+00 | Ndiff    0.303
   49.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647016955e+00 | Ndiff    0.258
   50.000/50 after      1 sec. |    178.8 MiB | K   20 | loss  3.647016811e+00 | Ndiff    0.219
... done. not converged. max laps thru data exceeded.
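
Each `bnpy.run` call in this example writes to an `output_path` that encodes `K` and the model names. Because the calls differ only in those names, the path can be generated rather than typed by hand. The sketch below assumes the naming pattern seen in this example. The helper `make_output_path` and its `root` argument are hypothetical and are not part of bnpy.

```python
def make_output_path(K, model_label, root='/tmp/mocap6'):
    """Build an output directory name in the pattern used by this example.

    model_label is a free-form string such as 'DP+DiagGauss'.
    """
    return '%s/showcase-K=%d-model=%s-ECovMat=1*eye/' % (root, K, model_label)


# Labels as they appear in the output paths of this example.
model_labels = [
    'DP+DiagGauss',
    'HDPHMM+DiagGauss',
    'HDPHMM+Gauss',
    'HDPHMM+AutoRegGauss',
]

output_paths = [make_output_path(20, label) for label in model_labels]
for path in output_paths:
    print(path)
```

Keeping the pattern in one place makes it harder for two runs to overwrite each other's results by accident.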

HDP-HMM with DiagGauss observation model

Assume diagonal covariances.

Start with too many clusters (K=20)

hmmdiag_trained_model, hmmdiag_info_dict = bnpy.run(
    dataset, 'HDPHMM', 'DiagGauss', 'memoVB',
    output_path='/tmp/mocap6/showcase-K=20-model=HDPHMM+DiagGauss-ECovMat=1*eye/',
    nLap=50, nTask=1, nBatch=1, convergeThr=0.0001,
    alpha=alpha, gamma=gamma, sF=sF, ECovMat='eye',
    K=K, initname='randexamples',
    )

Out:

WARNING: Found unrecognized keyword args. These are ignored.
  --alpha
Dataset Summary:
GroupXData
  total size: 6 units
  batch size: 6 units
  num. batches: 1
Allocation Model:  None
Obs. Data  Model:  Gaussian with diagonal covariance.
Obs. Data  Prior:  independent Gauss-Wishart prior on each dimension
  Wishart params
    nu = 14  ...
  beta = [ 12  12]  ...
  Expectations
  E[  mean[k]] =
  [ 0  0] ...
  E[ covar[k]] =
  [[ 1.  0.]
   [ 0.  1.]] ...
Initialization:
  initname = randexamples
  K = 20 (number of clusters)
  seed = 1607680
  elapsed_time: 0.0 sec
Learn Alg: memoVB | task  1/1 | alg. seed: 1607680 | data order seed: 8541952
task_output_path: /tmp/mocap6/showcase-K=20-model=HDPHMM+DiagGauss-ECovMat=1*eye/1
    1.000/50 after      0 sec. |    180.1 MiB | K   20 | loss  3.717124121e+00 |
    2.000/50 after      0 sec. |    179.0 MiB | K   20 | loss  3.612329446e+00 | Ndiff   35.499
    3.000/50 after      1 sec. |    179.0 MiB | K   20 | loss  3.580489903e+00 | Ndiff   22.259
    4.000/50 after      1 sec. |    179.0 MiB | K   20 | loss  3.565890757e+00 | Ndiff   19.752
    5.000/50 after      1 sec. |    179.0 MiB | K   20 | loss  3.552817867e+00 | Ndiff   18.510
    6.000/50 after      1 sec. |    179.0 MiB | K   20 | loss  3.547626693e+00 | Ndiff    9.442
    7.000/50 after      1 sec. |    179.0 MiB | K   20 | loss  3.545715575e+00 | Ndiff    3.145
    8.000/50 after      2 sec. |    179.0 MiB | K   20 | loss  3.542710111e+00 | Ndiff    8.834
    9.000/50 after      2 sec. |    179.0 MiB | K   20 | loss  3.534107178e+00 | Ndiff   14.540
   10.000/50 after      2 sec. |    179.0 MiB | K   20 | loss  3.530991441e+00 | Ndiff    9.174
   11.000/50 after      2 sec. |    179.0 MiB | K   20 | loss  3.527520848e+00 | Ndiff    9.868
   12.000/50 after      2 sec. |    179.0 MiB | K   20 | loss  3.525056000e+00 | Ndiff   11.883
   13.000/50 after      3 sec. |    179.0 MiB | K   20 | loss  3.523781079e+00 | Ndiff    4.145
   14.000/50 after      3 sec. |    179.0 MiB | K   20 | loss  3.521694790e+00 | Ndiff    8.826
   15.000/50 after      3 sec. |    179.0 MiB | K   20 | loss  3.518502897e+00 | Ndiff    4.970
   16.000/50 after      3 sec. |    179.0 MiB | K   20 | loss  3.517515182e+00 | Ndiff    4.090
   17.000/50 after      3 sec. |    179.0 MiB | K   20 | loss  3.517412339e+00 | Ndiff    0.535
   18.000/50 after      3 sec. |    179.0 MiB | K   20 | loss  3.517396038e+00 | Ndiff    0.203
   19.000/50 after      4 sec. |    179.0 MiB | K   20 | loss  3.517386711e+00 | Ndiff    0.244
   20.000/50 after      4 sec. |    179.0 MiB | K   20 | loss  3.517380000e+00 | Ndiff    0.274
   21.000/50 after      4 sec. |    179.0 MiB | K   20 | loss  3.517375098e+00 | Ndiff    0.264
   22.000/50 after      4 sec. |    179.0 MiB | K   20 | loss  3.517371095e+00 | Ndiff    0.208
   23.000/50 after      4 sec. |    179.0 MiB | K   20 | loss  3.517367640e+00 | Ndiff    0.144
   24.000/50 after      5 sec. |    179.0 MiB | K   20 | loss  3.517364400e+00 | Ndiff    0.139
   25.000/50 after      5 sec. |    179.0 MiB | K   20 | loss  3.517361220e+00 | Ndiff    0.166
   26.000/50 after      5 sec. |    179.0 MiB | K   20 | loss  3.517357862e+00 | Ndiff    0.226
   27.000/50 after      5 sec. |    179.0 MiB | K   20 | loss  3.517353576e+00 | Ndiff    0.359
   28.000/50 after      5 sec. |    179.0 MiB | K   20 | loss  3.517347351e+00 | Ndiff    0.518
   29.000/50 after      5 sec. |    179.0 MiB | K   20 | loss  3.517338850e+00 | Ndiff    0.683
   30.000/50 after      6 sec. |    179.0 MiB | K   20 | loss  3.517329011e+00 | Ndiff    0.799
   31.000/50 after      6 sec. |    179.0 MiB | K   20 | loss  3.517319213e+00 | Ndiff    0.813
   32.000/50 after      6 sec. |    179.0 MiB | K   20 | loss  3.517310197e+00 | Ndiff    0.746
   33.000/50 after      6 sec. |    179.0 MiB | K   20 | loss  3.517301719e+00 | Ndiff    0.673
   34.000/50 after      6 sec. |    179.0 MiB | K   20 | loss  3.517293335e+00 | Ndiff    0.634
   35.000/50 after      6 sec. |    179.0 MiB | K   20 | loss  3.517285024e+00 | Ndiff    0.618
   36.000/50 after      7 sec. |    179.0 MiB | K   20 | loss  3.517277155e+00 | Ndiff    0.602
   37.000/50 after      7 sec. |    179.0 MiB | K   20 | loss  3.517270251e+00 | Ndiff    0.571
   38.000/50 after      7 sec. |    179.0 MiB | K   20 | loss  3.517264770e+00 | Ndiff    0.520
   39.000/50 after      7 sec. |    179.0 MiB | K   20 | loss  3.517260855e+00 | Ndiff    0.457
   40.000/50 after      7 sec. |    179.0 MiB | K   20 | loss  3.517258241e+00 | Ndiff    0.394
   41.000/50 after      7 sec. |    179.0 MiB | K   20 | loss  3.517256507e+00 | Ndiff    0.341
   42.000/50 after      7 sec. |    179.0 MiB | K   20 | loss  3.517255300e+00 | Ndiff    0.300
   43.000/50 after      8 sec. |    179.0 MiB | K   20 | loss  3.517254394e+00 | Ndiff    0.272
   44.000/50 after      8 sec. |    179.0 MiB | K   20 | loss  3.517253647e+00 | Ndiff    0.256
   45.000/50 after      8 sec. |    179.0 MiB | K   20 | loss  3.517252964e+00 | Ndiff    0.251
   46.000/50 after      8 sec. |    179.0 MiB | K   20 | loss  3.517252267e+00 | Ndiff    0.259
   47.000/50 after      8 sec. |    179.0 MiB | K   20 | loss  3.517251461e+00 | Ndiff    0.283
   48.000/50 after      8 sec. |    179.0 MiB | K   20 | loss  3.517250379e+00 | Ndiff    0.332
   49.000/50 after      9 sec. |    179.0 MiB | K   20 | loss  3.517248618e+00 | Ndiff    0.430
   50.000/50 after      9 sec. |    179.0 MiB | K   20 | loss  3.517244801e+00 | Ndiff    0.643
... done. not converged. max laps thru data exceeded.

HDP-HMM with Gauss observation model

Assume full covariances.

Start with too many clusters (K=20)

hmmfull_trained_model, hmmfull_info_dict = bnpy.run(
    dataset, 'HDPHMM', 'Gauss', 'memoVB',
    output_path='/tmp/mocap6/showcase-K=20-model=HDPHMM+Gauss-ECovMat=1*eye/',
    nLap=50, nTask=1, nBatch=1, convergeThr=0.0001,
    alpha=alpha, gamma=gamma, sF=sF, ECovMat='eye',
    K=K, initname='randexamples',
    )

Out:

WARNING: Found unrecognized keyword args. These are ignored.
  --alpha
Dataset Summary:
GroupXData
  total size: 6 units
  batch size: 6 units
  num. batches: 1
Allocation Model:  None
Obs. Data  Model:  Gaussian with full covariance.
Obs. Data  Prior:  Gauss-Wishart on mean and covar of each cluster
  E[  mean[k] ] =
   [ 0.  0.]  ...
  E[ covar[k] ] =
  [[ 1.  0.]
   [ 0.  1.]] ...
Initialization:
  initname = randexamples
  K = 20 (number of clusters)
  seed = 1607680
  elapsed_time: 0.0 sec
Learn Alg: memoVB | task  1/1 | alg. seed: 1607680 | data order seed: 8541952
task_output_path: /tmp/mocap6/showcase-K=20-model=HDPHMM+Gauss-ECovMat=1*eye/1
    1.000/50 after      0 sec. |    179.0 MiB | K   20 | loss  3.453792348e+00 |
    2.000/50 after      0 sec. |    179.0 MiB | K   20 | loss  3.332556042e+00 | Ndiff   27.128
    3.000/50 after      1 sec. |    179.0 MiB | K   20 | loss  3.280863302e+00 | Ndiff   33.040
    4.000/50 after      1 sec. |    179.0 MiB | K   20 | loss  3.247802919e+00 | Ndiff   27.962
    5.000/50 after      1 sec. |    179.0 MiB | K   20 | loss  3.232386544e+00 | Ndiff   16.663
    6.000/50 after      1 sec. |    179.0 MiB | K   20 | loss  3.214482068e+00 | Ndiff   24.147
    7.000/50 after      2 sec. |    179.0 MiB | K   20 | loss  3.202981898e+00 | Ndiff   20.209
    8.000/50 after      2 sec. |    179.0 MiB | K   20 | loss  3.198209231e+00 | Ndiff    9.843
    9.000/50 after      2 sec. |    179.0 MiB | K   20 | loss  3.196602056e+00 | Ndiff    4.055
   10.000/50 after      2 sec. |    179.0 MiB | K   20 | loss  3.194049899e+00 | Ndiff    3.639
   11.000/50 after      2 sec. |    179.0 MiB | K   20 | loss  3.192384953e+00 | Ndiff    3.209
   12.000/50 after      3 sec. |    179.0 MiB | K   20 | loss  3.190378614e+00 | Ndiff    3.166
   13.000/50 after      3 sec. |    179.0 MiB | K   20 | loss  3.188801039e+00 | Ndiff    4.323
   14.000/50 after      3 sec. |    179.0 MiB | K   20 | loss  3.187980225e+00 | Ndiff    2.014
   15.000/50 after      3 sec. |    179.0 MiB | K   20 | loss  3.187267197e+00 | Ndiff    1.176
   16.000/50 after      3 sec. |    179.0 MiB | K   20 | loss  3.187104503e+00 | Ndiff    0.703
   17.000/50 after      4 sec. |    179.0 MiB | K   20 | loss  3.186910658e+00 | Ndiff    0.921
   18.000/50 after      4 sec. |    179.0 MiB | K   20 | loss  3.186639962e+00 | Ndiff    1.257
   19.000/50 after      4 sec. |    179.0 MiB | K   20 | loss  3.186450784e+00 | Ndiff    0.927
   20.000/50 after      4 sec. |    179.0 MiB | K   20 | loss  3.185973787e+00 | Ndiff    1.523
   21.000/50 after      4 sec. |    179.0 MiB | K   20 | loss  3.185779780e+00 | Ndiff    0.970
   22.000/50 after      5 sec. |    179.0 MiB | K   20 | loss  3.185453273e+00 | Ndiff    0.954
   23.000/50 after      5 sec. |    179.0 MiB | K   20 | loss  3.182363213e+00 | Ndiff    2.429
   24.000/50 after      5 sec. |    179.0 MiB | K   20 | loss  3.182200890e+00 | Ndiff    1.243
   25.000/50 after      5 sec. |    179.0 MiB | K   20 | loss  3.182084494e+00 | Ndiff    1.106
   26.000/50 after      5 sec. |    179.0 MiB | K   20 | loss  3.182015005e+00 | Ndiff    0.783
   27.000/50 after      6 sec. |    179.0 MiB | K   20 | loss  3.181989580e+00 | Ndiff    0.396
   28.000/50 after      6 sec. |    179.0 MiB | K   20 | loss  3.181974836e+00 | Ndiff    0.244
   29.000/50 after      6 sec. |    179.0 MiB | K   20 | loss  3.181947074e+00 | Ndiff    0.356
   30.000/50 after      6 sec. |    179.0 MiB | K   20 | loss  3.181838450e+00 | Ndiff    0.658
   31.000/50 after      6 sec. |    179.0 MiB | K   20 | loss  3.181817711e+00 | Ndiff    0.226
   32.000/50 after      7 sec. |    179.0 MiB | K   20 | loss  3.181814701e+00 | Ndiff    0.115
   33.000/50 after      7 sec. |    179.0 MiB | K   20 | loss  3.181813740e+00 | Ndiff    0.058
   34.000/50 after      7 sec. |    179.0 MiB | K   20 | loss  3.181813509e+00 | Ndiff    0.026
   35.000/50 after      7 sec. |    179.0 MiB | K   20 | loss  3.181813438e+00 | Ndiff    0.018
   36.000/50 after      7 sec. |    179.0 MiB | K   20 | loss  3.181813408e+00 | Ndiff    0.012
   37.000/50 after      8 sec. |    179.0 MiB | K   20 | loss  3.181813395e+00 | Ndiff    0.009
   38.000/50 after      8 sec. |    179.0 MiB | K   20 | loss  3.181813388e+00 | Ndiff    0.006
   39.000/50 after      8 sec. |    179.0 MiB | K   20 | loss  3.181813385e+00 | Ndiff    0.004
   40.000/50 after      8 sec. |    179.0 MiB | K   20 | loss  3.181813383e+00 | Ndiff    0.003
   41.000/50 after      8 sec. |    179.0 MiB | K   20 | loss  3.181813382e+00 | Ndiff    0.002
   42.000/50 after      9 sec. |    179.0 MiB | K   20 | loss  3.181813382e+00 | Ndiff    0.002
   43.000/50 after      9 sec. |    179.0 MiB | K   20 | loss  3.181813381e+00 | Ndiff    0.001
   44.000/50 after      9 sec. |    179.0 MiB | K   20 | loss  3.181813381e+00 | Ndiff    0.001
   45.000/50 after      9 sec. |    179.0 MiB | K   20 | loss  3.181813381e+00 | Ndiff    0.001
   46.000/50 after      9 sec. |    179.0 MiB | K   20 | loss  3.181813381e+00 | Ndiff    0.001
   47.000/50 after     10 sec. |    179.0 MiB | K   20 | loss  3.181813381e+00 | Ndiff    0.000
   48.000/50 after     10 sec. |    179.0 MiB | K   20 | loss  3.181813381e+00 | Ndiff    0.000
   49.000/50 after     10 sec. |    179.0 MiB | K   20 | loss  3.181813381e+00 | Ndiff    0.000
   50.000/50 after     10 sec. |    179.0 MiB | K   20 | loss  3.181813381e+00 | Ndiff    0.000
... done. not converged. max laps thru data exceeded.
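
The trace above flattens out well before lap 50, even though the run ends with "not converged". One simple way to locate the plateau is to look at how much the loss drops from one lap to the next. The sketch below is only an illustration: it is not the criterion bnpy applies for `convergeThr`, and the helper name `first_lap_below` is hypothetical. The loss values are copied from laps 30 to 36 of the log above.

```python
# Laps and loss values copied from the HDPHMM + Gauss log above.
laps = [30, 31, 32, 33, 34, 35, 36]
losses = [
    3.181838450,
    3.181817711,
    3.181814701,
    3.181813740,
    3.181813509,
    3.181813438,
    3.181813408,
]


def first_lap_below(laps, losses, tol):
    """Return the first lap whose loss change from the previous lap is
    smaller than tol in absolute value, or None if there is no such lap."""
    for i in range(1, len(losses)):
        if abs(losses[i - 1] - losses[i]) < tol:
            return laps[i]
    return None


lap_tol_1e5 = first_lap_below(laps, losses, 1e-5)
lap_tol_1e6 = first_lap_below(laps, losses, 1e-6)
print(lap_tol_1e5, lap_tol_1e6)
```

With an actual run, the same helper could be applied to the `lap_history` and `loss_history` entries of the info dict returned by `bnpy.run`.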

HDP-HMM with AutoRegGauss observation model

Assume full covariances.

Start with too many clusters (K=20)

hmmar_trained_model, hmmar_info_dict = bnpy.run(
    dataset, 'HDPHMM', 'AutoRegGauss', 'memoVB',
    output_path='/tmp/mocap6/showcase-K=20-model=HDPHMM+AutoRegGauss-ECovMat=1*eye/',
    nLap=50, nTask=1, nBatch=1, convergeThr=0.0001,
    alpha=alpha, gamma=gamma, sF=sF, ECovMat='eye',
    K=K, initname='randexamples',
    )

Out:

WARNING: Found unrecognized keyword args. These are ignored.
  --alpha
Dataset Summary:
GroupXData
  total size: 6 units
  batch size: 6 units
  num. batches: 1
Allocation Model:  None
Obs. Data  Model:  Auto-Regressive Gaussian with full covariance.
Obs. Data  Prior:  MatrixNormal-Wishart on each mean/prec matrix pair: A, Lam
  E[ A ] =
  [[ 1.  0.]
   [ 0.  1.]] ...
  E[ Sigma ] =
  [[ 1.  0.]
   [ 0.  1.]] ...
Initialization:
  initname = randexamples
  K = 20 (number of clusters)
  seed = 1607680
  elapsed_time: 0.0 sec
Learn Alg: memoVB | task  1/1 | alg. seed: 1607680 | data order seed: 8541952
task_output_path: /tmp/mocap6/showcase-K=20-model=HDPHMM+AutoRegGauss-ECovMat=1*eye/1
    1.000/50 after      0 sec. |    179.0 MiB | K   20 | loss  2.908127656e+00 |
    2.000/50 after      1 sec. |    179.0 MiB | K   20 | loss  2.712797722e+00 | Ndiff   53.407
    3.000/50 after      1 sec. |    179.0 MiB | K   20 | loss  2.623844667e+00 | Ndiff   51.736
    4.000/50 after      1 sec. |    179.0 MiB | K   20 | loss  2.579060900e+00 | Ndiff   47.749
    5.000/50 after      1 sec. |    179.0 MiB | K   20 | loss  2.548616789e+00 | Ndiff   37.879
    6.000/50 after      2 sec. |    179.0 MiB | K   20 | loss  2.518588127e+00 | Ndiff   28.737
    7.000/50 after      2 sec. |    179.0 MiB | K   20 | loss  2.504804864e+00 | Ndiff   10.223
    8.000/50 after      2 sec. |    179.0 MiB | K   20 | loss  2.496420689e+00 | Ndiff    3.905
    9.000/50 after      2 sec. |    179.0 MiB | K   20 | loss  2.492445130e+00 | Ndiff    1.966
   10.000/50 after      3 sec. |    179.0 MiB | K   20 | loss  2.490165194e+00 | Ndiff    1.965
   11.000/50 after      3 sec. |    179.0 MiB | K   20 | loss  2.485730707e+00 | Ndiff    4.873
   12.000/50 after      3 sec. |    179.0 MiB | K   20 | loss  2.484287434e+00 | Ndiff    2.107
   13.000/50 after      3 sec. |    179.0 MiB | K   20 | loss  2.482746506e+00 | Ndiff    2.765
   14.000/50 after      4 sec. |    179.0 MiB | K   20 | loss  2.480880306e+00 | Ndiff    2.160
   15.000/50 after      4 sec. |    179.0 MiB | K   20 | loss  2.480467067e+00 | Ndiff    1.079
   16.000/50 after      4 sec. |    179.0 MiB | K   20 | loss  2.479991235e+00 | Ndiff    1.090
   17.000/50 after      5 sec. |    179.0 MiB | K   20 | loss  2.479792583e+00 | Ndiff    1.188
   18.000/50 after      5 sec. |    179.0 MiB | K   20 | loss  2.478520054e+00 | Ndiff    2.398
   19.000/50 after      5 sec. |    179.0 MiB | K   20 | loss  2.477365552e+00 | Ndiff    1.471
   20.000/50 after      5 sec. |    179.0 MiB | K   20 | loss  2.476525301e+00 | Ndiff    2.496
   21.000/50 after      6 sec. |    179.0 MiB | K   20 | loss  2.476119048e+00 | Ndiff    2.484
   22.000/50 after      6 sec. |    179.0 MiB | K   20 | loss  2.474839041e+00 | Ndiff    2.706
   23.000/50 after      6 sec. |    179.0 MiB | K   20 | loss  2.473186959e+00 | Ndiff    1.637
   24.000/50 after      6 sec. |    179.0 MiB | K   20 | loss  2.471452324e+00 | Ndiff    2.002
   25.000/50 after      7 sec. |    179.0 MiB | K   20 | loss  2.471162626e+00 | Ndiff    0.496
   26.000/50 after      7 sec. |    179.0 MiB | K   20 | loss  2.470983405e+00 | Ndiff    0.705
   27.000/50 after      7 sec. |    179.0 MiB | K   20 | loss  2.470774342e+00 | Ndiff    0.502
   28.000/50 after      7 sec. |    179.0 MiB | K   20 | loss  2.469880859e+00 | Ndiff    2.126
   29.000/50 after      8 sec. |    179.0 MiB | K   20 | loss  2.468970839e+00 | Ndiff    0.799
   30.000/50 after      8 sec. |    179.0 MiB | K   20 | loss  2.467972189e+00 | Ndiff    0.604
   31.000/50 after      8 sec. |    179.0 MiB | K   20 | loss  2.467661612e+00 | Ndiff    1.243
   32.000/50 after      9 sec. |    179.0 MiB | K   20 | loss  2.467483397e+00 | Ndiff    1.372
   33.000/50 after      9 sec. |    179.0 MiB | K   20 | loss  2.467450274e+00 | Ndiff    0.804
   34.000/50 after      9 sec. |    179.0 MiB | K   20 | loss  2.467407757e+00 | Ndiff    0.662
   35.000/50 after      9 sec. |    179.0 MiB | K   20 | loss  2.467292085e+00 | Ndiff    0.699
   36.000/50 after     10 sec. |    179.0 MiB | K   20 | loss  2.467236037e+00 | Ndiff    0.388
   37.000/50 after     10 sec. |    179.0 MiB | K   20 | loss  2.467197165e+00 | Ndiff    0.399
   38.000/50 after     10 sec. |    179.0 MiB | K   20 | loss  2.467112796e+00 | Ndiff    0.632
   39.000/50 after     10 sec. |    179.0 MiB | K   20 | loss  2.466989563e+00 | Ndiff    1.476
   40.000/50 after     11 sec. |    179.0 MiB | K   20 | loss  2.466727224e+00 | Ndiff    1.857
   41.000/50 after     11 sec. |    179.0 MiB | K   20 | loss  2.466587267e+00 | Ndiff    1.298
   42.000/50 after     11 sec. |    179.0 MiB | K   20 | loss  2.466255511e+00 | Ndiff    1.694
   43.000/50 after     11 sec. |    179.0 MiB | K   20 | loss  2.466170740e+00 | Ndiff    1.741
   44.000/50 after     12 sec. |    179.0 MiB | K   20 | loss  2.466133408e+00 | Ndiff    1.208
   45.000/50 after     12 sec. |    179.0 MiB | K   20 | loss  2.466086219e+00 | Ndiff    1.084
   46.000/50 after     12 sec. |    179.0 MiB | K   20 | loss  2.466016271e+00 | Ndiff    1.161
   47.000/50 after     12 sec. |    179.0 MiB | K   20 | loss  2.465964032e+00 | Ndiff    1.311
   48.000/50 after     13 sec. |    179.0 MiB | K   20 | loss  2.465914037e+00 | Ndiff    1.176
   49.000/50 after     13 sec. |    179.0 MiB | K   20 | loss  2.465798065e+00 | Ndiff    0.767
   50.000/50 after     13 sec. |    179.0 MiB | K   20 | loss  2.465630615e+00 | Ndiff    0.558
... done. not converged. max laps thru data exceeded.

Compare loss function traces for all methods

pylab.figure()

pylab.plot(
    mixdiag_info_dict['lap_history'],
    mixdiag_info_dict['loss_history'], 'b.-',
    label='mix + diag gauss')
pylab.plot(
    hmmdiag_info_dict['lap_history'],
    hmmdiag_info_dict['loss_history'], 'k.-',
    label='hmm + diag gauss')
pylab.plot(
    hmmfull_info_dict['lap_history'],
    hmmfull_info_dict['loss_history'], 'r.-',
    label='hmm + full gauss')
pylab.plot(
    hmmar_info_dict['lap_history'],
    hmmar_info_dict['loss_history'], 'c.-',
    label='hmm + ar gauss')
pylab.legend(loc='upper right')
pylab.xlabel('num. laps')
pylab.ylabel('loss')
pylab.xlim([4, 100]) # avoid early iterations
pylab.ylim([2.4, 3.7]) # handpicked
pylab.draw()
pylab.tight_layout()
[Figure: loss versus number of laps for the four models]
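
The plot can be complemented by a direct comparison of the final loss values. The sketch below ranks the four models by the loss reported at lap 50 in the logs above, taking those numbers at face value. The dictionary is typed in by hand from the logs, and the labels match the plot legend.

```python
# Loss at lap 50 for each run, copied from the logs above.
final_loss = {
    'mix + diag gauss': 3.647016811,
    'hmm + diag gauss': 3.517244801,
    'hmm + full gauss': 3.181813381,
    'hmm + ar gauss': 2.465630615,
}

# Sort from lowest to highest loss.
ranked = sorted(final_loss.items(), key=lambda kv: kv[1])
for label, loss in ranked:
    print('%-18s %.4f' % (label, loss))

best_label = ranked[0][0]
```

With a live session, the same numbers could be read from the last entry of each info dict's `loss_history`.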

Total running time of the script: ( 0 minutes 33.511 seconds)
