Comparing models for sequential data

How to train a DP mixture model and HDP-HMMs with several observation models (DiagGauss, Gauss, AutoRegGauss) on the same dataset, then compare their loss traces.

# sphinx_gallery_thumbnail_number = 1

import bnpy
import numpy as np
import os

from matplotlib import pylab
import seaborn as sns

SMALL_FIG_SIZE = (2.5, 2.5)
FIG_SIZE = (5, 5)
pylab.rcParams['figure.figsize'] = FIG_SIZE

Load dataset from file

dataset_path = os.path.join(bnpy.DATASET_PATH, 'mocap6')
dataset = bnpy.data.GroupXData.read_npz(
    os.path.join(dataset_path, 'dataset.npz'))

Setup: Function to make a simple plot of the raw data

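def show_single_sequence(seq_id):
    ''' Plot the first 12 dimensions of one sequence over time.
    '''
    # Rows doc_range[i]:doc_range[i+1] of X belong to sequence i
    start = dataset.doc_range[seq_id]
    stop = dataset.doc_range[seq_id + 1]
    X_seq = dataset.X[start:stop]
    for dim in range(12):  # range instead of xrange, so this runs on Python 3
        pylab.plot(X_seq[:, dim], '.-')
    pylab.xlabel('time')
    pylab.ylabel('angle')
    pylab.tight_layout()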
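
The function above relies on all sequences being stacked row-wise in one array, with doc_range holding the boundary offsets. A minimal sketch of that indexing, using plain Python lists and made-up toy values rather than the actual mocap6 data (the helper name slice_sequence is hypothetical, not part of bnpy):

```python
def slice_sequence(X, doc_range, seq_id):
    """Return the rows of stacked data X that belong to sequence seq_id."""
    start = doc_range[seq_id]
    stop = doc_range[seq_id + 1]
    return X[start:stop]

# Toy data (assumption for illustration): 3 sequences of lengths 2, 3, 1
# stacked into 6 rows, one feature per row.
X_toy = [[0.0], [0.1], [1.0], [1.1], [1.2], [2.0]]
doc_range_toy = [0, 2, 5, 6]

# A doc_range with N + 1 entries describes N sequences.
n_seqs = len(doc_range_toy) - 1
seq_lengths = [doc_range_toy[i + 1] - doc_range_toy[i] for i in range(n_seqs)]
second_seq = slice_sequence(X_toy, doc_range_toy, 1)
print(n_seqs, seq_lengths, second_seq)
```

The same start/stop arithmetic is what show_single_sequence applies to dataset.X and dataset.doc_range.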

Visualization of the first sequence

show_single_sequence(0)
[Figure: angle versus time for each plotted dimension of the first sequence]

Visualization of the second sequence

show_single_sequence(1)
[Figure: angle versus time for each plotted dimension of the second sequence]

Setup: hyperparameters

K = 20            # Number of clusters/states

gamma = 5.0       # top-level Dirichlet concentration parameter
transAlpha = 0.5  # trans-level Dirichlet concentration parameter

sF = 1.0          # Set observation model prior so E[covariance] = identity
ECovMat = 'eye'
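
To get a feel for what gamma = 5.0 implies together with K = 20, here is a rough sketch of the prior expected cluster weights under the textbook stick-breaking construction of a Dirichlet process, where each stick fraction is drawn from Beta(1, gamma). This is plain arithmetic, not a call into bnpy, and it assumes the standard construction; the posterior that bnpy learns will differ from this prior expectation. The function name expected_stick_weights is hypothetical.

```python
def expected_stick_weights(gamma, K):
    """Prior expected weights of the first K sticks when v_k ~ Beta(1, gamma)."""
    mean_v = 1.0 / (1.0 + gamma)       # E[v_k]
    mean_rest = gamma / (1.0 + gamma)  # E[1 - v_k]
    # Stick fractions are independent, so E[pi_k] = E[v_k] * prod_{j<k} E[1 - v_j]
    return [mean_v * mean_rest ** k for k in range(K)]

gamma = 5.0
K = 20
weights = expected_stick_weights(gamma, K)
# Expected mass left over for all sticks beyond the first K
leftover = (gamma / (1.0 + gamma)) ** K
print('first weight: %.4f' % weights[0])
print('leftover mass: %.4f' % leftover)
```

Under these assumptions the first 20 sticks carry roughly 97% of the expected mass, which is one way to read K = 20 as a generous truncation for this concentration value.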

DP mixture with DiagGauss observation model

mixdiag_trained_model, mixdiag_info_dict = bnpy.run(
    dataset, 'DPMixtureModel', 'DiagGauss', 'memoVB',
    output_path='/tmp/mocap6/showcase-K=20-model=DP+DiagGauss-ECovMat=1*eye/',
    nLap=50, nTask=1, nBatch=1, convergeThr=0.0001,
    gamma=gamma, sF=sF, ECovMat=ECovMat,
    K=K, initname='randexamples',
    )

Out:

Dataset Summary:
GroupXData
  total size: 6 units
  batch size: 6 units
  num. batches: 1
Allocation Model:  DP mixture with K=0. Concentration gamma0= 5.00
Obs. Data  Model:  Gaussian with diagonal covariance.
Obs. Data  Prior:  independent Gauss-Wishart prior on each dimension
  Wishart params
    nu = 14  ...
  beta = [ 12  12]  ...
  Expectations
  E[  mean[k]] =
  [ 0  0] ...
  E[ covar[k]] =
  [[1. 0.]
   [0. 1.]] ...
Initialization:
  initname = randexamples
  K = 20 (number of clusters)
  seed = 1607680
  elapsed_time: 0.0 sec
Learn Alg: memoVB | task  1/1 | alg. seed: 1607680 | data order seed: 8541952
task_output_path: /tmp/mocap6/showcase-K=20-model=DP+DiagGauss-ECovMat=1*eye/1
    1.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.845424399e+00 |
    2.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.756537185e+00 | Ndiff   26.569
    3.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.723105766e+00 | Ndiff   29.336
    4.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.703275957e+00 | Ndiff   19.798
    5.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.689932307e+00 | Ndiff   22.929
    6.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.681432839e+00 | Ndiff   24.626
    7.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.676243541e+00 | Ndiff   20.961
    8.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.672159855e+00 | Ndiff   16.818
    9.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.669045283e+00 | Ndiff   13.011
   10.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.665208654e+00 | Ndiff   13.962
   11.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.661952478e+00 | Ndiff   16.751
   12.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.658115620e+00 | Ndiff   17.355
   13.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.655058982e+00 | Ndiff   12.150
   14.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.653853641e+00 | Ndiff    5.448
   15.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.652570589e+00 | Ndiff    5.226
   16.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.651988134e+00 | Ndiff    5.605
   17.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.651566377e+00 | Ndiff    5.980
   18.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.651133556e+00 | Ndiff    6.283
   19.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.650800669e+00 | Ndiff    6.427
   20.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.650523019e+00 | Ndiff    6.354
   21.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.650284547e+00 | Ndiff    6.121
   22.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.650075236e+00 | Ndiff    5.571
   23.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.649874747e+00 | Ndiff    4.678
   24.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.649684370e+00 | Ndiff    3.694
   25.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.649503048e+00 | Ndiff    2.678
   26.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.649287038e+00 | Ndiff    2.238
   27.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.648992642e+00 | Ndiff    1.703
   28.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.648274600e+00 | Ndiff    1.704
   29.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.647520307e+00 | Ndiff    1.649
   30.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.647350990e+00 | Ndiff    1.528
   31.000/50 after      0 sec. |    215.6 MiB | K   20 | loss  3.647242457e+00 | Ndiff    1.494
   32.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647171899e+00 | Ndiff    1.870
   33.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647125709e+00 | Ndiff    2.031
   34.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647094615e+00 | Ndiff    2.051
   35.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647072676e+00 | Ndiff    1.974
   36.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647056795e+00 | Ndiff    1.830
   37.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647045269e+00 | Ndiff    1.651
   38.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647036962e+00 | Ndiff    1.460
   39.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647031023e+00 | Ndiff    1.273
   40.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647026801e+00 | Ndiff    1.099
   41.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647023808e+00 | Ndiff    0.943
   42.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647021689e+00 | Ndiff    0.806
   43.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647020187e+00 | Ndiff    0.687
   44.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647019120e+00 | Ndiff    0.584
   45.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647018361e+00 | Ndiff    0.496
   46.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647017819e+00 | Ndiff    0.421
   47.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647017432e+00 | Ndiff    0.357
   48.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647017154e+00 | Ndiff    0.303
   49.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647016955e+00 | Ndiff    0.258
   50.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647016811e+00 | Ndiff    0.219
... done. not converged. max laps thru data exceeded.
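
Each progress line in the output above reports the lap, elapsed time, memory, K, loss and Ndiff. When only saved console output is available, the lap and loss can be pulled out with a small parser. The sketch below is an assumption based solely on the line format printed above; parse_progress_line is a hypothetical helper, not a bnpy function.

```python
import re

# Pattern assumed from the printed format: "<lap>/<nLap> after ... | loss <value>"
PROGRESS_RE = re.compile(r'^\s*(\d+\.\d+)/(\d+) after .*\| loss\s+([-+0-9.eE]+)')

def parse_progress_line(line):
    """Return (lap, loss) as floats, or None if the line is not a progress line."""
    match = PROGRESS_RE.match(line)
    if match is None:
        return None
    return float(match.group(1)), float(match.group(3))

# Last lines of the DP mixture run, copied from the output above.
log_lines = [
    '   49.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647016955e+00 | Ndiff    0.258',
    '   50.000/50 after      1 sec. |    215.6 MiB | K   20 | loss  3.647016811e+00 | Ndiff    0.219',
    '... done. not converged. max laps thru data exceeded.',
]
parsed = [parse_progress_line(line) for line in log_lines]
records = [p for p in parsed if p is not None]
last_drop = records[-2][1] - records[-1][1]  # loss decrease over the final lap
print(records[-1], last_drop)
```

For this run the final lap lowers the loss by about 1.4e-7, even though Ndiff is still 0.219.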

HDP-HMM with DiagGauss observation model

Assume diagonal covariances.

Start with too many clusters (K=20)

hmmdiag_trained_model, hmmdiag_info_dict = bnpy.run(
    dataset, 'HDPHMM', 'DiagGauss', 'memoVB',
    output_path='/tmp/mocap6/showcase-K=20-model=HDPHMM+DiagGauss-ECovMat=1*eye/',
    nLap=50, nTask=1, nBatch=1, convergeThr=0.0001,
    transAlpha=transAlpha, gamma=gamma, sF=sF, ECovMat=ECovMat,
    K=K, initname='randexamples',
    )

Out:

Dataset Summary:
GroupXData
  total size: 6 units
  batch size: 6 units
  num. batches: 1
Allocation Model:  None
Obs. Data  Model:  Gaussian with diagonal covariance.
Obs. Data  Prior:  independent Gauss-Wishart prior on each dimension
  Wishart params
    nu = 14  ...
  beta = [ 12  12]  ...
  Expectations
  E[  mean[k]] =
  [ 0  0] ...
  E[ covar[k]] =
  [[1. 0.]
   [0. 1.]] ...
Initialization:
  initname = randexamples
  K = 20 (number of clusters)
  seed = 1607680
  elapsed_time: 0.0 sec
Learn Alg: memoVB | task  1/1 | alg. seed: 1607680 | data order seed: 8541952
task_output_path: /tmp/mocap6/showcase-K=20-model=HDPHMM+DiagGauss-ECovMat=1*eye/1
    1.000/50 after      0 sec. |    220.2 MiB | K   20 | loss  3.717124121e+00 |
    2.000/50 after      0 sec. |    216.4 MiB | K   20 | loss  3.612329446e+00 | Ndiff   35.499
    3.000/50 after      1 sec. |    216.4 MiB | K   20 | loss  3.580489903e+00 | Ndiff   22.259
    4.000/50 after      1 sec. |    216.4 MiB | K   20 | loss  3.565890757e+00 | Ndiff   19.752
    5.000/50 after      1 sec. |    216.4 MiB | K   20 | loss  3.552817867e+00 | Ndiff   18.510
    6.000/50 after      1 sec. |    216.4 MiB | K   20 | loss  3.547626693e+00 | Ndiff    9.442
    7.000/50 after      1 sec. |    216.4 MiB | K   20 | loss  3.545715575e+00 | Ndiff    3.145
    8.000/50 after      1 sec. |    216.4 MiB | K   20 | loss  3.542710111e+00 | Ndiff    8.834
    9.000/50 after      2 sec. |    216.4 MiB | K   20 | loss  3.534107178e+00 | Ndiff   14.540
   10.000/50 after      2 sec. |    216.4 MiB | K   20 | loss  3.530991441e+00 | Ndiff    9.174
   11.000/50 after      2 sec. |    216.4 MiB | K   20 | loss  3.527520848e+00 | Ndiff    9.868
   12.000/50 after      2 sec. |    216.4 MiB | K   20 | loss  3.525056000e+00 | Ndiff   11.883
   13.000/50 after      2 sec. |    216.4 MiB | K   20 | loss  3.523781079e+00 | Ndiff    4.145
   14.000/50 after      2 sec. |    216.4 MiB | K   20 | loss  3.521694790e+00 | Ndiff    8.826
   15.000/50 after      3 sec. |    216.4 MiB | K   20 | loss  3.518502897e+00 | Ndiff    4.970
   16.000/50 after      3 sec. |    216.4 MiB | K   20 | loss  3.517515182e+00 | Ndiff    4.090
   17.000/50 after      3 sec. |    216.4 MiB | K   20 | loss  3.517412339e+00 | Ndiff    0.535
   18.000/50 after      3 sec. |    216.4 MiB | K   20 | loss  3.517396038e+00 | Ndiff    0.203
   19.000/50 after      3 sec. |    216.4 MiB | K   20 | loss  3.517386711e+00 | Ndiff    0.244
   20.000/50 after      4 sec. |    216.4 MiB | K   20 | loss  3.517380000e+00 | Ndiff    0.274
   21.000/50 after      4 sec. |    216.4 MiB | K   20 | loss  3.517375098e+00 | Ndiff    0.264
   22.000/50 after      4 sec. |    216.4 MiB | K   20 | loss  3.517371095e+00 | Ndiff    0.208
   23.000/50 after      4 sec. |    216.4 MiB | K   20 | loss  3.517367640e+00 | Ndiff    0.144
   24.000/50 after      4 sec. |    216.4 MiB | K   20 | loss  3.517364400e+00 | Ndiff    0.139
   25.000/50 after      5 sec. |    216.4 MiB | K   20 | loss  3.517361220e+00 | Ndiff    0.166
   26.000/50 after      5 sec. |    216.4 MiB | K   20 | loss  3.517357862e+00 | Ndiff    0.226
   27.000/50 after      5 sec. |    216.4 MiB | K   20 | loss  3.517353576e+00 | Ndiff    0.359
   28.000/50 after      5 sec. |    216.4 MiB | K   20 | loss  3.517347351e+00 | Ndiff    0.518
   29.000/50 after      6 sec. |    216.4 MiB | K   20 | loss  3.517338850e+00 | Ndiff    0.683
   30.000/50 after      6 sec. |    216.4 MiB | K   20 | loss  3.517329011e+00 | Ndiff    0.799
   31.000/50 after      6 sec. |    216.4 MiB | K   20 | loss  3.517319213e+00 | Ndiff    0.813
   32.000/50 after      6 sec. |    216.4 MiB | K   20 | loss  3.517310197e+00 | Ndiff    0.746
   33.000/50 after      6 sec. |    216.4 MiB | K   20 | loss  3.517301719e+00 | Ndiff    0.673
   34.000/50 after      7 sec. |    216.4 MiB | K   20 | loss  3.517293335e+00 | Ndiff    0.634
   35.000/50 after      7 sec. |    216.4 MiB | K   20 | loss  3.517285024e+00 | Ndiff    0.618
   36.000/50 after      7 sec. |    216.4 MiB | K   20 | loss  3.517277155e+00 | Ndiff    0.602
   37.000/50 after      7 sec. |    216.4 MiB | K   20 | loss  3.517270251e+00 | Ndiff    0.571
   38.000/50 after      8 sec. |    216.4 MiB | K   20 | loss  3.517264770e+00 | Ndiff    0.520
   39.000/50 after      8 sec. |    216.4 MiB | K   20 | loss  3.517260855e+00 | Ndiff    0.457
   40.000/50 after      8 sec. |    216.4 MiB | K   20 | loss  3.517258241e+00 | Ndiff    0.394
   41.000/50 after      8 sec. |    216.4 MiB | K   20 | loss  3.517256507e+00 | Ndiff    0.341
   42.000/50 after      8 sec. |    216.4 MiB | K   20 | loss  3.517255300e+00 | Ndiff    0.300
   43.000/50 after      9 sec. |    216.4 MiB | K   20 | loss  3.517254394e+00 | Ndiff    0.272
   44.000/50 after      9 sec. |    216.4 MiB | K   20 | loss  3.517253647e+00 | Ndiff    0.256
   45.000/50 after      9 sec. |    216.4 MiB | K   20 | loss  3.517252964e+00 | Ndiff    0.251
   46.000/50 after      9 sec. |    216.4 MiB | K   20 | loss  3.517252267e+00 | Ndiff    0.259
   47.000/50 after      9 sec. |    216.4 MiB | K   20 | loss  3.517251461e+00 | Ndiff    0.283
   48.000/50 after     10 sec. |    216.4 MiB | K   20 | loss  3.517250379e+00 | Ndiff    0.332
   49.000/50 after     10 sec. |    216.4 MiB | K   20 | loss  3.517248618e+00 | Ndiff    0.430
   50.000/50 after     10 sec. |    216.4 MiB | K   20 | loss  3.517244801e+00 | Ndiff    0.643
... done. not converged. max laps thru data exceeded.

HDP-HMM with Gauss observation model

Assume full covariances.

Start with too many clusters (K=20)

hmmfull_trained_model, hmmfull_info_dict = bnpy.run(
    dataset, 'HDPHMM', 'Gauss', 'memoVB',
    output_path='/tmp/mocap6/showcase-K=20-model=HDPHMM+Gauss-ECovMat=1*eye/',
    nLap=50, nTask=1, nBatch=1, convergeThr=0.0001,
    transAlpha=transAlpha, gamma=gamma, sF=sF, ECovMat=ECovMat,
    K=K, initname='randexamples',
    )

Out:

Dataset Summary:
GroupXData
  total size: 6 units
  batch size: 6 units
  num. batches: 1
Allocation Model:  None
Obs. Data  Model:  Gaussian with full covariance.
Obs. Data  Prior:  Gauss-Wishart on mean and covar of each cluster
  E[  mean[k] ] =
   [0. 0.]  ...
  E[ covar[k] ] =
  [[1. 0.]
   [0. 1.]] ...
Initialization:
  initname = randexamples
  K = 20 (number of clusters)
  seed = 1607680
  elapsed_time: 0.0 sec
Learn Alg: memoVB | task  1/1 | alg. seed: 1607680 | data order seed: 8541952
task_output_path: /tmp/mocap6/showcase-K=20-model=HDPHMM+Gauss-ECovMat=1*eye/1
    1.000/50 after      0 sec. |    216.5 MiB | K   20 | loss  3.453792348e+00 |
    2.000/50 after      1 sec. |    216.5 MiB | K   20 | loss  3.332556042e+00 | Ndiff   27.128
    3.000/50 after      1 sec. |    216.5 MiB | K   20 | loss  3.280863302e+00 | Ndiff   33.040
    4.000/50 after      1 sec. |    216.5 MiB | K   20 | loss  3.247802919e+00 | Ndiff   27.962
    5.000/50 after      1 sec. |    216.5 MiB | K   20 | loss  3.232386544e+00 | Ndiff   16.663
    6.000/50 after      2 sec. |    216.5 MiB | K   20 | loss  3.214482068e+00 | Ndiff   24.147
    7.000/50 after      2 sec. |    216.5 MiB | K   20 | loss  3.202981898e+00 | Ndiff   20.209
    8.000/50 after      2 sec. |    216.5 MiB | K   20 | loss  3.198209231e+00 | Ndiff    9.843
    9.000/50 after      2 sec. |    216.5 MiB | K   20 | loss  3.196602056e+00 | Ndiff    4.055
   10.000/50 after      3 sec. |    216.5 MiB | K   20 | loss  3.194049899e+00 | Ndiff    3.639
   11.000/50 after      3 sec. |    216.5 MiB | K   20 | loss  3.192384953e+00 | Ndiff    3.209
   12.000/50 after      3 sec. |    216.5 MiB | K   20 | loss  3.190378614e+00 | Ndiff    3.166
   13.000/50 after      3 sec. |    216.5 MiB | K   20 | loss  3.188801039e+00 | Ndiff    4.323
   14.000/50 after      4 sec. |    216.5 MiB | K   20 | loss  3.187980225e+00 | Ndiff    2.014
   15.000/50 after      4 sec. |    216.5 MiB | K   20 | loss  3.187267197e+00 | Ndiff    1.176
   16.000/50 after      4 sec. |    216.5 MiB | K   20 | loss  3.187104503e+00 | Ndiff    0.703
   17.000/50 after      4 sec. |    216.5 MiB | K   20 | loss  3.186910658e+00 | Ndiff    0.921
   18.000/50 after      4 sec. |    216.5 MiB | K   20 | loss  3.186639962e+00 | Ndiff    1.257
   19.000/50 after      5 sec. |    216.5 MiB | K   20 | loss  3.186450784e+00 | Ndiff    0.927
   20.000/50 after      5 sec. |    216.5 MiB | K   20 | loss  3.185973787e+00 | Ndiff    1.523
   21.000/50 after      5 sec. |    216.5 MiB | K   20 | loss  3.185779780e+00 | Ndiff    0.970
   22.000/50 after      5 sec. |    216.5 MiB | K   20 | loss  3.185453273e+00 | Ndiff    0.954
   23.000/50 after      6 sec. |    216.5 MiB | K   20 | loss  3.182363213e+00 | Ndiff    2.429
   24.000/50 after      6 sec. |    216.5 MiB | K   20 | loss  3.182200890e+00 | Ndiff    1.243
   25.000/50 after      6 sec. |    216.5 MiB | K   20 | loss  3.182084494e+00 | Ndiff    1.106
   26.000/50 after      6 sec. |    216.5 MiB | K   20 | loss  3.182015005e+00 | Ndiff    0.783
   27.000/50 after      7 sec. |    216.5 MiB | K   20 | loss  3.181989580e+00 | Ndiff    0.396
   28.000/50 after      7 sec. |    216.5 MiB | K   20 | loss  3.181974836e+00 | Ndiff    0.244
   29.000/50 after      7 sec. |    216.5 MiB | K   20 | loss  3.181947074e+00 | Ndiff    0.356
   30.000/50 after      7 sec. |    216.5 MiB | K   20 | loss  3.181838450e+00 | Ndiff    0.658
   31.000/50 after      8 sec. |    216.5 MiB | K   20 | loss  3.181817711e+00 | Ndiff    0.226
   32.000/50 after      8 sec. |    216.5 MiB | K   20 | loss  3.181814701e+00 | Ndiff    0.115
   33.000/50 after      8 sec. |    216.5 MiB | K   20 | loss  3.181813740e+00 | Ndiff    0.058
   34.000/50 after      8 sec. |    216.5 MiB | K   20 | loss  3.181813509e+00 | Ndiff    0.026
   35.000/50 after      8 sec. |    216.5 MiB | K   20 | loss  3.181813438e+00 | Ndiff    0.018
   36.000/50 after      9 sec. |    216.5 MiB | K   20 | loss  3.181813408e+00 | Ndiff    0.012
   37.000/50 after      9 sec. |    216.5 MiB | K   20 | loss  3.181813395e+00 | Ndiff    0.009
   38.000/50 after      9 sec. |    216.5 MiB | K   20 | loss  3.181813388e+00 | Ndiff    0.006
   39.000/50 after      9 sec. |    216.5 MiB | K   20 | loss  3.181813385e+00 | Ndiff    0.004
   40.000/50 after     10 sec. |    216.5 MiB | K   20 | loss  3.181813383e+00 | Ndiff    0.003
   41.000/50 after     10 sec. |    216.5 MiB | K   20 | loss  3.181813382e+00 | Ndiff    0.002
   42.000/50 after     10 sec. |    216.5 MiB | K   20 | loss  3.181813382e+00 | Ndiff    0.002
   43.000/50 after     10 sec. |    216.5 MiB | K   20 | loss  3.181813381e+00 | Ndiff    0.001
   44.000/50 after     11 sec. |    216.5 MiB | K   20 | loss  3.181813381e+00 | Ndiff    0.001
   45.000/50 after     11 sec. |    216.5 MiB | K   20 | loss  3.181813381e+00 | Ndiff    0.001
   46.000/50 after     11 sec. |    216.5 MiB | K   20 | loss  3.181813381e+00 | Ndiff    0.001
   47.000/50 after     11 sec. |    216.5 MiB | K   20 | loss  3.181813381e+00 | Ndiff    0.000
   48.000/50 after     11 sec. |    216.5 MiB | K   20 | loss  3.181813381e+00 | Ndiff    0.000
   49.000/50 after     12 sec. |    216.5 MiB | K   20 | loss  3.181813381e+00 | Ndiff    0.000
   50.000/50 after     12 sec. |    216.5 MiB | K   20 | loss  3.181813381e+00 | Ndiff    0.000
... done. not converged. max laps thru data exceeded.

HDP-HMM with AutoRegGauss observation model

Assume full covariances.

Start with too many clusters (K=20)

hmmar_trained_model, hmmar_info_dict = bnpy.run(
    dataset, 'HDPHMM', 'AutoRegGauss', 'memoVB',
    output_path='/tmp/mocap6/showcase-K=20-model=HDPHMM+AutoRegGauss-ECovMat=1*eye/',
    nLap=50, nTask=1, nBatch=1, convergeThr=0.0001,
    transAlpha=transAlpha, gamma=gamma, sF=sF, ECovMat=ECovMat,
    K=K, initname='randexamples',
    )

Out:

Dataset Summary:
GroupXData
  total size: 6 units
  batch size: 6 units
  num. batches: 1
Allocation Model:  None
Obs. Data  Model:  Auto-Regressive Gaussian with full covariance.
Obs. Data  Prior:  MatrixNormal-Wishart on each mean/prec matrix pair: A, Lam
  E[ A ] =
  [[1. 0.]
   [0. 1.]] ...
  E[ Sigma ] =
  [[1. 0.]
   [0. 1.]] ...
Initialization:
  initname = randexamples
  K = 20 (number of clusters)
  seed = 1607680
  elapsed_time: 0.0 sec
Learn Alg: memoVB | task  1/1 | alg. seed: 1607680 | data order seed: 8541952
task_output_path: /tmp/mocap6/showcase-K=20-model=HDPHMM+AutoRegGauss-ECovMat=1*eye/1
    1.000/50 after      0 sec. |    217.3 MiB | K   20 | loss  2.908127656e+00 |
    2.000/50 after      1 sec. |    217.3 MiB | K   20 | loss  2.712797722e+00 | Ndiff   53.407
    3.000/50 after      1 sec. |    217.3 MiB | K   20 | loss  2.623844667e+00 | Ndiff   51.736
    4.000/50 after      1 sec. |    217.3 MiB | K   20 | loss  2.579060900e+00 | Ndiff   47.749
    5.000/50 after      1 sec. |    217.3 MiB | K   20 | loss  2.548616789e+00 | Ndiff   37.879
    6.000/50 after      2 sec. |    217.3 MiB | K   20 | loss  2.518588127e+00 | Ndiff   28.737
    7.000/50 after      2 sec. |    217.3 MiB | K   20 | loss  2.504804864e+00 | Ndiff   10.223
    8.000/50 after      2 sec. |    217.3 MiB | K   20 | loss  2.496420689e+00 | Ndiff    3.905
    9.000/50 after      2 sec. |    217.3 MiB | K   20 | loss  2.492445130e+00 | Ndiff    1.966
   10.000/50 after      3 sec. |    217.3 MiB | K   20 | loss  2.490165194e+00 | Ndiff    1.965
   11.000/50 after      3 sec. |    217.3 MiB | K   20 | loss  2.485730707e+00 | Ndiff    4.873
   12.000/50 after      3 sec. |    217.3 MiB | K   20 | loss  2.484287434e+00 | Ndiff    2.107
   13.000/50 after      3 sec. |    217.3 MiB | K   20 | loss  2.482746506e+00 | Ndiff    2.765
   14.000/50 after      4 sec. |    217.3 MiB | K   20 | loss  2.480880306e+00 | Ndiff    2.160
   15.000/50 after      4 sec. |    217.3 MiB | K   20 | loss  2.480467067e+00 | Ndiff    1.079
   16.000/50 after      4 sec. |    217.3 MiB | K   20 | loss  2.479991235e+00 | Ndiff    1.090
   17.000/50 after      4 sec. |    217.3 MiB | K   20 | loss  2.479792583e+00 | Ndiff    1.188
   18.000/50 after      5 sec. |    217.3 MiB | K   20 | loss  2.478520054e+00 | Ndiff    2.398
   19.000/50 after      5 sec. |    217.3 MiB | K   20 | loss  2.477365552e+00 | Ndiff    1.471
   20.000/50 after      5 sec. |    217.3 MiB | K   20 | loss  2.476525301e+00 | Ndiff    2.496
   21.000/50 after      5 sec. |    217.3 MiB | K   20 | loss  2.476119048e+00 | Ndiff    2.484
   22.000/50 after      6 sec. |    217.3 MiB | K   20 | loss  2.474839041e+00 | Ndiff    2.706
   23.000/50 after      6 sec. |    217.3 MiB | K   20 | loss  2.473186959e+00 | Ndiff    1.637
   24.000/50 after      6 sec. |    217.3 MiB | K   20 | loss  2.471452324e+00 | Ndiff    2.002
   25.000/50 after      6 sec. |    217.3 MiB | K   20 | loss  2.471162626e+00 | Ndiff    0.496
   26.000/50 after      6 sec. |    217.3 MiB | K   20 | loss  2.470983405e+00 | Ndiff    0.705
   27.000/50 after      7 sec. |    217.3 MiB | K   20 | loss  2.470774342e+00 | Ndiff    0.502
   28.000/50 after      7 sec. |    217.3 MiB | K   20 | loss  2.469880859e+00 | Ndiff    2.126
   29.000/50 after      7 sec. |    217.3 MiB | K   20 | loss  2.468970839e+00 | Ndiff    0.799
   30.000/50 after      7 sec. |    217.3 MiB | K   20 | loss  2.467972189e+00 | Ndiff    0.604
   31.000/50 after      7 sec. |    217.3 MiB | K   20 | loss  2.467661612e+00 | Ndiff    1.243
   32.000/50 after      8 sec. |    217.3 MiB | K   20 | loss  2.467483397e+00 | Ndiff    1.372
   33.000/50 after      8 sec. |    217.3 MiB | K   20 | loss  2.467450274e+00 | Ndiff    0.804
   34.000/50 after      8 sec. |    217.3 MiB | K   20 | loss  2.467407757e+00 | Ndiff    0.662
   35.000/50 after      8 sec. |    217.3 MiB | K   20 | loss  2.467292085e+00 | Ndiff    0.699
   36.000/50 after      8 sec. |    217.3 MiB | K   20 | loss  2.467236037e+00 | Ndiff    0.388
   37.000/50 after      9 sec. |    217.3 MiB | K   20 | loss  2.467197165e+00 | Ndiff    0.399
   38.000/50 after      9 sec. |    217.3 MiB | K   20 | loss  2.467112796e+00 | Ndiff    0.632
   39.000/50 after      9 sec. |    217.3 MiB | K   20 | loss  2.466989563e+00 | Ndiff    1.476
   40.000/50 after      9 sec. |    217.3 MiB | K   20 | loss  2.466727224e+00 | Ndiff    1.857
   41.000/50 after      9 sec. |    217.3 MiB | K   20 | loss  2.466587267e+00 | Ndiff    1.298
   42.000/50 after     10 sec. |    217.3 MiB | K   20 | loss  2.466255511e+00 | Ndiff    1.694
   43.000/50 after     10 sec. |    217.3 MiB | K   20 | loss  2.466170740e+00 | Ndiff    1.741
   44.000/50 after     10 sec. |    217.3 MiB | K   20 | loss  2.466133408e+00 | Ndiff    1.208
   45.000/50 after     10 sec. |    217.3 MiB | K   20 | loss  2.466086219e+00 | Ndiff    1.084
   46.000/50 after     10 sec. |    217.3 MiB | K   20 | loss  2.466016271e+00 | Ndiff    1.161
   47.000/50 after     11 sec. |    217.3 MiB | K   20 | loss  2.465964032e+00 | Ndiff    1.311
   48.000/50 after     11 sec. |    217.3 MiB | K   20 | loss  2.465914037e+00 | Ndiff    1.176
   49.000/50 after     11 sec. |    217.3 MiB | K   20 | loss  2.465798065e+00 | Ndiff    0.767
   50.000/50 after     11 sec. |    217.3 MiB | K   20 | loss  2.465630615e+00 | Ndiff    0.558
... done. not converged. max laps thru data exceeded.

Compare loss function traces for all methods

pylab.figure()

pylab.plot(
    mixdiag_info_dict['lap_history'],
    mixdiag_info_dict['loss_history'], 'b.-',
    label='mix + diag gauss')
pylab.plot(
    hmmdiag_info_dict['lap_history'],
    hmmdiag_info_dict['loss_history'], 'k.-',
    label='hmm + diag gauss')
pylab.plot(
    hmmfull_info_dict['lap_history'],
    hmmfull_info_dict['loss_history'], 'r.-',
    label='hmm + full gauss')
pylab.plot(
    hmmar_info_dict['lap_history'],
    hmmar_info_dict['loss_history'], 'c.-',
    label='hmm + ar gauss')
pylab.legend(loc='upper right')
pylab.xlabel('num. laps')
pylab.ylabel('loss')
pylab.xlim([4, 50])  # skip early laps; every run above uses nLap=50
pylab.ylim([2.4, 3.7]) # handpicked
pylab.draw()
pylab.tight_layout()
[Figure: loss versus number of laps for the four trained models]
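
Alongside the plot, the final loss of each run can be ranked numerically. The sketch below assumes only that each info dict exposes a 'loss_history' sequence, as used in the plotting code above; rank_by_final_loss is a hypothetical helper, and the stand-in dictionaries hold just the last two loss values copied from the logs so that the snippet runs without bnpy.

```python
def rank_by_final_loss(info_dicts):
    """Return (label, final_loss) pairs sorted from lowest to highest loss."""
    finals = [(label, info['loss_history'][-1])
              for label, info in info_dicts.items()]
    return sorted(finals, key=lambda pair: pair[1])

# Stand-in dictionaries (assumption for illustration). With real runs, pass
# mixdiag_info_dict, hmmdiag_info_dict, hmmfull_info_dict and hmmar_info_dict.
example_info = {
    'mix + diag gauss': {'loss_history': [3.647016955, 3.647016811]},
    'hmm + diag gauss': {'loss_history': [3.517248618, 3.517244801]},
    'hmm + full gauss': {'loss_history': [3.181813381, 3.181813381]},
    'hmm + ar gauss': {'loss_history': [2.465798065, 2.465630615]},
}
ranking = rank_by_final_loss(example_info)
for label, loss in ranking:
    print('%-18s %.6f' % (label, loss))
```

With the values logged above, the autoregressive HMM ends with the lowest loss and the DP mixture with the highest.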

Total running time of the script: ( 0 minutes 34.599 seconds)
