02: Training a DP mixture model with birth and merge proposals

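How to train a Dirichlet process (DP) mixture of multinomials with memoized variational inference (memoVB). Birth, merge and delete proposals adapt the number of clusters K during training.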
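The DP prior places weights on a potentially unbounded set of clusters. The concentration parameter gamma0, passed to bnpy.run below, controls how that prior mass is spread. The sketch below computes the prior expected cluster weights under the standard stick-breaking construction, with v_k ~ Beta(1, gamma0). It is textbook DP arithmetic in plain Python, not bnpy's implementation.

```python
def expected_stick_breaking_weights(gamma0, K):
    """Prior expected weights E[pi_1], ..., E[pi_K] and the leftover mass.

    Assumes the standard stick-breaking construction of the DP:
        v_k ~ Beta(1, gamma0),   pi_k = v_k * prod_{j<k} (1 - v_j).
    The v_k are independent, so E[pi_k] = E[v] * (1 - E[v])**(k - 1),
    where E[v] = 1 / (1 + gamma0).
    """
    mean_v = 1.0 / (1.0 + gamma0)
    weights = [mean_v * (1.0 - mean_v) ** (k - 1) for k in range(1, K + 1)]
    leftover = (1.0 - mean_v) ** K  # expected mass on all clusters k > K
    return weights, leftover


weights, leftover = expected_stick_breaking_weights(gamma0=50.0, K=6)
print(["%.4f" % w for w in weights], "leftover mass: %.4f" % leftover)
```

With gamma0=50.0 the expected weights decay slowly. About 0.89 of the expected prior mass lies beyond the first six clusters, so the prior alone leaves plenty of room for more clusters.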

import bnpy
import numpy as np
import os

from matplotlib import pylab
import seaborn as sns

FIG_SIZE = (3, 3)
SMALL_FIG_SIZE = (1.5, 1.5)
pylab.rcParams['figure.figsize'] = FIG_SIZE

Read the dataset from file.

dataset_path = os.path.join(bnpy.DATASET_PATH, 'bars_one_per_doc')
dataset = bnpy.data.BagOfWordsData.read_npz(
    os.path.join(dataset_path, 'dataset.npz'))
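Bag-of-words data is stored sparsely. The sketch below assumes a CSR-style layout: flat word_id and word_count arrays, indexed by doc_range, so that document d owns positions doc_range[d] through doc_range[d+1] - 1. Under that assumption, the document-by-vocabulary count matrix that getSparseDocTypeCountMatrix returns in sparse form could be rebuilt densely as shown. The helper name and the toy corpus are hypothetical.

```python
def to_dense_doc_word_counts(word_id, word_count, doc_range, vocab_size):
    """Hypothetical helper: expand flat bag-of-words arrays into dense rows.

    Assumes the tokens of document d occupy positions
    doc_range[d] .. doc_range[d + 1] - 1 of word_id and word_count.
    """
    n_docs = len(doc_range) - 1
    dense = [[0] * vocab_size for _ in range(n_docs)]
    for d in range(n_docs):
        for pos in range(doc_range[d], doc_range[d + 1]):
            dense[d][word_id[pos]] += word_count[pos]
    return dense


# Toy corpus: 2 documents over a 4-word vocabulary.
rows = to_dense_doc_word_counts(
    word_id=[0, 2, 1, 3], word_count=[5, 1, 2, 2],
    doc_range=[0, 2, 4], vocab_size=4)
print(rows)  # [[5, 0, 1, 0], [0, 2, 0, 2]]
```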

Make a simple plot of the raw data: the word counts of the first 10 documents, each shown as a square image.

X_csr_DV = dataset.getSparseDocTypeCountMatrix()
bnpy.viz.BarsViz.show_square_images(
    X_csr_DV[:10].toarray(), vmin=0, vmax=5)
pylab.tight_layout()
[Figure: word counts of the first 10 documents, shown as square images]
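show_square_images draws each length-V row as a square picture, which is what makes the "bars" visible. The sketch below shows that reshaping. It assumes row-major order and a vocabulary size V that is a perfect square. The helper is hypothetical, not bnpy's code.

```python
import math


def row_to_square_image(row):
    """Reshape a length-V vector into a sqrt(V) x sqrt(V) grid (row-major).

    Assumes each vocabulary word is one pixel, so V must be a perfect
    square; raises ValueError otherwise.
    """
    side = math.isqrt(len(row))
    if side * side != len(row):
        raise ValueError("vocabulary size %d is not a perfect square" % len(row))
    return [row[r * side:(r + 1) * side] for r in range(side)]


# A 3 x 3 "document" whose counts form a vertical bar in the middle column.
img = row_to_square_image([0, 5, 0, 0, 5, 0, 0, 5, 0])
for pixel_row in img:
    print(pixel_row)
```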

Setup: Function to show the learned bars from the start to the end of a training run

def show_bars_over_time(
        task_output_path=None,
        query_laps=(0, 1, 2, 5, None),
        ncols=10):
    '''Show the learned topics (bars) at several laps of one training run.

    Each row shows the model saved at one lap in query_laps (None loads
    the final saved model). Each column shows one cluster's topic-word
    distribution as a square image.
    '''
    nrows = len(query_laps)
    fig_handle, ax_handles_RC = pylab.subplots(
        figsize=(SMALL_FIG_SIZE[0] * ncols, SMALL_FIG_SIZE[1] * nrows),
        nrows=nrows, ncols=ncols, sharex=True, sharey=True)
    for row_id, lap_val in enumerate(query_laps):
        # Load the saved model; lap_val becomes the lap actually loaded
        cur_model, lap_val = bnpy.load_model_at_lap(task_output_path, lap_val)
        cur_topics_KV = cur_model.obsModel.getTopics()
        # Plot the current model's topics, one per axis in this row
        cur_ax_list = ax_handles_RC[row_id].flatten().tolist()
        bnpy.viz.BarsViz.show_square_images(
            cur_topics_KV,
            vmin=0.0, vmax=0.1,
            ax_list=cur_ax_list)
        cur_ax_list[0].set_ylabel("lap: %d" % lap_val)
    pylab.tight_layout()
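The topics plotted by show_bars_over_time come from cur_model.obsModel.getTopics(), a K x V matrix of word probabilities. The runs below use a Multinomial likelihood with a symmetric Dirichlet prior (lam=0.1). Under that prior, the posterior mean of each cluster's word distribution has a closed form. The sketch assumes getTopics returns this posterior mean; bnpy's exact computation may differ.

```python
def posterior_mean_topics(word_counts_KV, lam):
    """Posterior mean of each cluster's word distribution.

    Assumes a symmetric Dirichlet(lam) prior over V words (lam > 0) and
    summed word counts n[k][v] assigned to each cluster k:
        E[phi_kv] = (lam + n_kv) / (V * lam + sum_w n_kw)
    """
    topics = []
    for counts in word_counts_KV:
        V = len(counts)
        denom = V * lam + sum(counts)
        topics.append([(lam + c) / denom for c in counts])
    return topics


# One cluster with data and one empty cluster, over a 4-word vocabulary.
topics = posterior_mean_topics([[9, 1, 0, 0], [0, 0, 0, 0]], lam=0.1)
for row in topics:
    print(["%.3f" % p for p in row])
```

An empty cluster falls back to the uniform distribution, while a cluster with data concentrates its mass on the words it has seen.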

From K=2 initial clusters

Using random initialization

initname = 'randomlikewang'
K = 2

K2_trained_model, K2_info_dict = bnpy.run(
    dataset, 'DPMixtureModel', 'Mult', 'memoVB',
    output_path='/tmp/bars_one_per_doc/trymoves-K=%d-initname=%s/' % (
        K, initname),
    nTask=1, nLap=50, convergeThr=0.0001, nBatch=1,
    K=K, initname=initname,
    gamma0=50.0, lam=0.1,
    moves='birth,merge,shuffle,delete',
    b_startLap=2,
    m_startLap=5,
    d_startLap=10)

show_bars_over_time(K2_info_dict['task_output_path'])
[Figure: topics learned from K=2 initial clusters, one row per queried lap and one column per cluster]

Out:

Dataset Summary:
BagOfWordsData
  total size: 2000 units
  batch size: 2000 units
  num. batches: 1
Allocation Model:  DP mixture with K=0. Concentration gamma0= 50.00
Obs. Data  Model:  Multinomial over finite vocabulary.
Obs. Data  Prior:  Dirichlet over finite vocabulary
  lam = [ 0.1  0.1] ...
Initialization:
  initname = randomlikewang
  K = 2 (number of clusters)
  seed = 1607680
  elapsed_time: 0.0 sec
Learn Alg: memoVB | task  1/1 | alg. seed: 1607680 | data order seed: 8541952
task_output_path: /tmp/bars_one_per_doc/trymoves-K=2-initname=randomlikewang/1
BIRTH @ lap 1.00: Disabled. Waiting for lap >= 2 (--b_startLap).
MERGE @ lap 1.00: Disabled. Cannot plan merge on first lap. Need valid SS that represent whole dataset.
DELETE @ lap 1.00: Disabled. Cannot delete before first complete lap, because SS that represents whole dataset is required.
    1.000/50 after      0 sec. |    166.4 MiB | K    2 | loss  4.852914927e+00 |
MERGE @ lap 2.00: Disabled. Waiting for lap >= 5 (--m_startLap).
DELETE @ lap 2.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 2.00 : Added 6 states. 2/2 succeeded. 0/2 failed eval phase. 0/2 failed build phase.
    2.000/50 after      0 sec. |    166.4 MiB | K    8 | loss  4.125791748e+00 |
MERGE @ lap 3.00: Disabled. Waiting for lap >= 5 (--m_startLap).
DELETE @ lap 3.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 3.00 : Added 0 states. 0/6 succeeded. 0/6 failed eval phase. 6/6 failed build phase.
    3.000/50 after      1 sec. |    166.4 MiB | K    8 | loss  4.125791567e+00 | Ndiff    0.000
MERGE @ lap 4.00: Disabled. Waiting for lap >= 5 (--m_startLap).
DELETE @ lap 4.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 4.000 : None attempted. 6 past failures. 2 too small. 0 too busy.
    4.000/50 after      1 sec. |    166.4 MiB | K    8 | loss  4.125791567e+00 | Ndiff    0.000
DELETE @ lap 5.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 5.000 : None attempted. 0 past failures. 0 too small. 8 too busy.
MERGE @ lap 5.00 : 1/16 accepted. Ndiff 0.00. 0 skipped.
    5.000/50 after      1 sec. |    166.4 MiB | K    7 | loss  4.125791567e+00 | Ndiff    0.000
DELETE @ lap 6.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 6.000 : None attempted. 1 past failures. 0 too small. 6 too busy.
MERGE @ lap 6.00 : 1/5 accepted. Ndiff 0.00. 0 skipped.
    6.000/50 after      1 sec. |    166.4 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
DELETE @ lap 7.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 7.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
MERGE @ lap 7.00 : 0/5 accepted. Ndiff 0.00. 0 skipped.
    7.000/50 after      1 sec. |    166.4 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
MERGE @ lap 8.00: No promising candidates, so no attempts.
DELETE @ lap 8.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 8.00 : Added 0 states. 0/1 succeeded. 0/1 failed eval phase. 1/1 failed build phase.
    8.000/50 after      2 sec. |    166.4 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
MERGE @ lap 9.00: No promising candidates, so no attempts.
DELETE @ lap 9.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 9.000 : None attempted. 6 past failures. 0 too small. 0 too busy.
    9.000/50 after      2 sec. |    166.4 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
BIRTH @ lap 10.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 10.00: 0/1 accepted. Ndiff 0.00.
MERGE @ lap 10.00 : 0/10 accepted. Ndiff 0.00. 0 skipped.
   10.000/50 after      2 sec. |    166.4 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
MERGE @ lap 11.00: No promising candidates, so no attempts.
BIRTH @ lap 11.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 11.00: 0/1 accepted. Ndiff 0.00.
   11.000/50 after      2 sec. |    166.4 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
BIRTH @ lap 12.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 12.00: 0/1 accepted. Ndiff 0.00.
MERGE @ lap 12.00 : 0/5 accepted. Ndiff 0.00. 0 skipped.
   12.000/50 after      2 sec. |    166.4 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
MERGE @ lap 13.00: No promising candidates, so no attempts.
BIRTH @ lap 13.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 13.00: 0/1 accepted. Ndiff 0.00.
   13.000/50 after      2 sec. |    166.4 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
MERGE @ lap 14.00: No promising candidates, so no attempts.
BIRTH @ lap 14.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 14.00: 0/1 accepted. Ndiff 0.00.
   14.000/50 after      2 sec. |    166.4 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
BIRTH @ lap 15.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 15.00: 0/1 accepted. Ndiff 0.00.
MERGE @ lap 15.00 : 0/10 accepted. Ndiff 0.00. 0 skipped.
   15.000/50 after      2 sec. |    166.4 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
MERGE @ lap 16.00: No promising candidates, so no attempts.
BIRTH @ lap 16.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 16.00: 0/1 accepted. Ndiff 0.00.
   16.000/50 after      2 sec. |    166.4 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
BIRTH @ lap 17.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 17.00: 0/1 accepted. Ndiff 0.00.
MERGE @ lap 17.00 : 0/5 accepted. Ndiff 0.00. 0 skipped.
   17.000/50 after      3 sec. |    166.4 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
MERGE @ lap 18.00: No promising candidates, so no attempts.
BIRTH @ lap 18.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 18.00: 0/1 accepted. Ndiff 0.00.
   18.000/50 after      3 sec. |    166.4 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
MERGE @ lap 19.00: No promising candidates, so no attempts.
BIRTH @ lap 19.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 19.00: 0/1 accepted. Ndiff 0.00.
   19.000/50 after      3 sec. |    166.4 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
BIRTH @ lap 20.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 20.00: 0/1 accepted. Ndiff 0.00.
MERGE @ lap 20.00 : 0/10 accepted. Ndiff 0.00. 0 skipped.
   20.000/50 after      3 sec. |    166.4 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
MERGE @ lap 21.00: No promising candidates, so no attempts.
BIRTH @ lap 21.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 21.00: 0/1 accepted. Ndiff 0.00.
   21.000/50 after      3 sec. |    166.4 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
... done. converged.
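The BIRTH, MERGE and DELETE messages above follow from the settings b_startLap=2, m_startLap=5 and d_startLap=10. They also reflect a further rule: merge and delete need summary statistics that cover the whole dataset, so neither can run on lap 1. The sketch below is a simplified, hypothetical helper for that eligibility schedule. The shuffle move is left out because the log does not report it.

```python
def moves_enabled_at_lap(lap, b_startLap=2, m_startLap=5, d_startLap=10):
    """Which proposal moves are eligible to run on a given lap.

    A simplified reading of the log messages: each move waits for its
    start lap, and merge/delete additionally need one complete pass over
    the data (lap > 1) so the summary statistics cover the whole dataset.
    """
    enabled = []
    if lap >= b_startLap:
        enabled.append("birth")
    if lap > 1 and lap >= m_startLap:
        enabled.append("merge")
    if lap > 1 and lap >= d_startLap:
        enabled.append("delete")
    return enabled


for lap in [1, 2, 5, 10]:
    print(lap, moves_enabled_at_lap(lap))
```

Being eligible is not the same as being attempted. The log also shows eligible moves being skipped, for example "No promising candidates" for merge, or clusters marked "too busy" for birth.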

From K=10 initial clusters

Using random initialization

initname = 'randomlikewang'
K = 10

K10_trained_model, K10_info_dict = bnpy.run(
    dataset, 'DPMixtureModel', 'Mult', 'memoVB',
    output_path='/tmp/bars_one_per_doc/trymoves-K=%d-initname=%s/' % (
        K, initname),
    nTask=1, nLap=50, convergeThr=0.0001, nBatch=1,
    K=K, initname=initname,
    gamma0=50.0, lam=0.1,
    moves='birth,merge,shuffle,delete',
    b_startLap=2,
    m_startLap=5,
    d_startLap=10)

show_bars_over_time(K10_info_dict['task_output_path'])
[Figure: topics learned from K=10 initial clusters, one row per queried lap and one column per cluster]

Out:

Dataset Summary:
BagOfWordsData
  total size: 2000 units
  batch size: 2000 units
  num. batches: 1
Allocation Model:  DP mixture with K=0. Concentration gamma0= 50.00
Obs. Data  Model:  Multinomial over finite vocabulary.
Obs. Data  Prior:  Dirichlet over finite vocabulary
  lam = [ 0.1  0.1] ...
Initialization:
  initname = randomlikewang
  K = 10 (number of clusters)
  seed = 1607680
  elapsed_time: 0.0 sec
Learn Alg: memoVB | task  1/1 | alg. seed: 1607680 | data order seed: 8541952
task_output_path: /tmp/bars_one_per_doc/trymoves-K=10-initname=randomlikewang/1
BIRTH @ lap 1.00: Disabled. Waiting for lap >= 2 (--b_startLap).
MERGE @ lap 1.00: Disabled. Cannot plan merge on first lap. Need valid SS that represent whole dataset.
DELETE @ lap 1.00: Disabled. Cannot delete before first complete lap, because SS that represents whole dataset is required.
    1.000/50 after      0 sec. |    173.4 MiB | K   10 | loss  4.712365207e+00 |
MERGE @ lap 2.00: Disabled. Waiting for lap >= 5 (--m_startLap).
DELETE @ lap 2.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 2.00 : Added 2 states. 1/6 succeeded. 0/6 failed eval phase. 5/6 failed build phase.
    2.000/50 after      1 sec. |    174.0 MiB | K   12 | loss  4.128419662e+00 |
MERGE @ lap 3.00: Disabled. Waiting for lap >= 5 (--m_startLap).
DELETE @ lap 3.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 3.00 : Added 0 states. 0/4 succeeded. 0/4 failed eval phase. 4/4 failed build phase.
    3.000/50 after      1 sec. |    174.0 MiB | K   12 | loss  4.127318928e+00 | Ndiff   17.626
MERGE @ lap 4.00: Disabled. Waiting for lap >= 5 (--m_startLap).
DELETE @ lap 4.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 4.00 : Added 0 states. 0/2 succeeded. 0/2 failed eval phase. 2/2 failed build phase.
    4.000/50 after      1 sec. |    174.0 MiB | K   12 | loss  4.127075221e+00 | Ndiff   11.299
DELETE @ lap 5.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 5.000 : None attempted. 0 past failures. 0 too small. 12 too busy.
MERGE @ lap 5.00 : 3/18 accepted. Ndiff 0.00. 12 skipped.
    5.000/50 after      2 sec. |    174.0 MiB | K    9 | loss  4.126920998e+00 | Ndiff   11.299
DELETE @ lap 6.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 6.000 : None attempted. 1 past failures. 0 too small. 8 too busy.
MERGE @ lap 6.00 : 3/5 accepted. Ndiff 21.29. 10 skipped.
    6.000/50 after      2 sec. |    174.0 MiB | K    6 | loss  4.125791567e+00 | Ndiff   11.299
DELETE @ lap 7.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 7.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
MERGE @ lap 7.00 : 0/9 accepted. Ndiff 0.00. 0 skipped.
    7.000/50 after      2 sec. |    174.0 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
MERGE @ lap 8.00: No promising candidates, so no attempts.
DELETE @ lap 8.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 8.00 : Added 0 states. 0/2 succeeded. 0/2 failed eval phase. 2/2 failed build phase.
    8.000/50 after      2 sec. |    174.0 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
MERGE @ lap 9.00: No promising candidates, so no attempts.
DELETE @ lap 9.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 9.000 : None attempted. 6 past failures. 0 too small. 0 too busy.
    9.000/50 after      2 sec. |    174.0 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
BIRTH @ lap 10.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 10.00: 0/1 accepted. Ndiff 0.00.
MERGE @ lap 10.00 : 0/6 accepted. Ndiff 0.00. 0 skipped.
   10.000/50 after      2 sec. |    174.0 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
MERGE @ lap 11.00: No promising candidates, so no attempts.
BIRTH @ lap 11.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 11.00: 0/1 accepted. Ndiff 0.00.
   11.000/50 after      2 sec. |    174.0 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
BIRTH @ lap 12.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 12.00: 0/1 accepted. Ndiff 0.00.
MERGE @ lap 12.00 : 0/9 accepted. Ndiff 0.00. 0 skipped.
   12.000/50 after      3 sec. |    174.0 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
MERGE @ lap 13.00: No promising candidates, so no attempts.
BIRTH @ lap 13.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 13.00: 0/1 accepted. Ndiff 0.00.
   13.000/50 after      3 sec. |    174.0 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
MERGE @ lap 14.00: No promising candidates, so no attempts.
BIRTH @ lap 14.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 14.00: 0/1 accepted. Ndiff 0.00.
   14.000/50 after      3 sec. |    174.0 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
BIRTH @ lap 15.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 15.00: 0/1 accepted. Ndiff 0.00.
MERGE @ lap 15.00 : 0/6 accepted. Ndiff 0.00. 0 skipped.
   15.000/50 after      3 sec. |    174.0 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
MERGE @ lap 16.00: No promising candidates, so no attempts.
BIRTH @ lap 16.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 16.00: 0/1 accepted. Ndiff 0.00.
   16.000/50 after      3 sec. |    174.0 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
BIRTH @ lap 17.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 17.00: 0/1 accepted. Ndiff 0.00.
MERGE @ lap 17.00 : 0/9 accepted. Ndiff 0.00. 0 skipped.
   17.000/50 after      3 sec. |    174.0 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
MERGE @ lap 18.00: No promising candidates, so no attempts.
BIRTH @ lap 18.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 18.00: 0/1 accepted. Ndiff 0.00.
   18.000/50 after      3 sec. |    174.0 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
MERGE @ lap 19.00: No promising candidates, so no attempts.
BIRTH @ lap 19.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 19.00: 0/1 accepted. Ndiff 0.00.
   19.000/50 after      3 sec. |    174.0 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
BIRTH @ lap 20.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 20.00: 0/1 accepted. Ndiff 0.00.
MERGE @ lap 20.00 : 0/6 accepted. Ndiff 0.00. 0 skipped.
   20.000/50 after      3 sec. |    174.0 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
MERGE @ lap 21.00: No promising candidates, so no attempts.
BIRTH @ lap 21.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 21.00: 0/1 accepted. Ndiff 0.00.
   21.000/50 after      3 sec. |    174.0 MiB | K    6 | loss  4.125791567e+00 | Ndiff    0.000
... done. converged.
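Both runs end at K=6 with the same final loss (4.125791567e+00), even though one started from K=2 and the other from K=10. To compare runs like these programmatically, the sketch below extracts (lap, K, loss) from the printed progress lines with a regular expression. It assumes the exact line format shown in the output above. It is not a bnpy API, and the results bnpy saves under task_output_path may be a more robust source.

```python
import re

# Matches progress lines such as:
#     2.000/50 after      1 sec. |    174.0 MiB | K   12 | loss  4.128419662e+00 |
PROGRESS_RE = re.compile(
    r"^\s*(?P<lap>\d+\.\d+)/\d+ after\s+\d+ sec\. \|\s+[\d.]+ MiB \|"
    r" K\s+(?P<K>\d+) \| loss\s+(?P<loss>[-+\d.eE]+)")


def parse_progress(log_text):
    """Return a list of (lap, K, loss) tuples found in a text log."""
    trace = []
    for line in log_text.splitlines():
        match = PROGRESS_RE.match(line)
        if match:  # BIRTH/MERGE/DELETE message lines do not match
            trace.append((float(match.group("lap")),
                          int(match.group("K")),
                          float(match.group("loss"))))
    return trace


# A few lines copied from the K=10 output above.
K10_log = """\
    1.000/50 after      0 sec. |    173.4 MiB | K   10 | loss  4.712365207e+00 |
BIRTH @ lap 2.00 : Added 2 states. 1/6 succeeded. 0/6 failed eval phase. 5/6 failed build phase.
    2.000/50 after      1 sec. |    174.0 MiB | K   12 | loss  4.128419662e+00 |
MERGE @ lap 6.00 : 3/5 accepted. Ndiff 21.29. 10 skipped.
    6.000/50 after      2 sec. |    174.0 MiB | K    6 | loss  4.125791567e+00 | Ndiff   11.299
"""

for lap, K, loss in parse_progress(K10_log):
    print("lap %5.1f  K %2d  loss %.9f" % (lap, K, loss))
```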

Total running time of the script: (0 minutes 15.776 seconds)

Generated by Sphinx-Gallery