How to train a DP mixture of multinomials.
import bnpy
import numpy as np
import os
from matplotlib import pylab
import seaborn as sns
FIG_SIZE = (3, 3)
SMALL_FIG_SIZE = (1.5, 1.5)
pylab.rcParams['figure.figsize'] = FIG_SIZE
Read dataset from file.
dataset_path = os.path.join(bnpy.DATASET_PATH, 'bars_one_per_doc')
dataset = bnpy.data.BagOfWordsData.read_npz(
    os.path.join(dataset_path, 'dataset.npz'))
Make a simple plot of the raw data
X_csr_DV = dataset.getSparseDocTypeCountMatrix()
bnpy.viz.BarsViz.show_square_images(
    X_csr_DV[:10].toarray(), vmin=0, vmax=5)
pylab.tight_layout()
Setup: Function to show bars from start to end of training run
def show_bars_over_time(
        task_output_path=None,
        query_laps=[0, 1, 2, 5, None],
        ncols=10):
    ''' Show the learned topics at several laps of one training run.

    Draws one row of square images per entry of query_laps.
    '''
    nrows = len(query_laps)
    fig_handle, ax_handles_RC = pylab.subplots(
        figsize=(SMALL_FIG_SIZE[0] * ncols, SMALL_FIG_SIZE[1] * nrows),
        nrows=nrows, ncols=ncols, sharex=True, sharey=True)
    for row_id, lap_val in enumerate(query_laps):
        # Load the model saved at the requested lap
        cur_model, lap_val = bnpy.load_model_at_lap(task_output_path, lap_val)
        cur_topics_KV = cur_model.obsModel.getTopics()
        # Plot the current model
        cur_ax_list = ax_handles_RC[row_id].flatten().tolist()
        bnpy.viz.BarsViz.show_square_images(
            cur_topics_KV,
            vmin=0.0, vmax=0.1,
            ax_list=cur_ax_list)
        cur_ax_list[0].set_ylabel("lap: %d" % lap_val)
    pylab.tight_layout()
Train from K=2 initial clusters, using random initialization
initname = 'randomlikewang'
K = 2
K2_trained_model, K2_info_dict = bnpy.run(
    dataset, 'DPMixtureModel', 'Mult', 'memoVB',
    output_path='/tmp/bars_one_per_doc/trymoves-K=%d-initname=%s/' % (
        K, initname),
    nTask=1, nLap=50, convergeThr=0.0001, nBatch=1,
    K=K, initname=initname,
    gamma0=50.0, lam=0.1,
    moves='birth,merge,shuffle,delete',
    b_startLap=2,
    m_startLap=5,
    d_startLap=10)
show_bars_over_time(K2_info_dict['task_output_path'])
Dataset Summary:
BagOfWordsData
total size: 2000 units
batch size: 2000 units
num. batches: 1
Allocation Model: DP mixture with K=0. Concentration gamma0= 50.00
Obs. Data Model: Multinomial over finite vocabulary.
Obs. Data Prior: Dirichlet over finite vocabulary
lam = [0.1 0.1] ...
Initialization:
initname = randomlikewang
K = 2 (number of clusters)
seed = 1607680
elapsed_time: 0.0 sec
Learn Alg: memoVB | task 1/1 | alg. seed: 1607680 | data order seed: 8541952
task_output_path: /tmp/bars_one_per_doc/trymoves-K=2-initname=randomlikewang/1
BIRTH @ lap 1.00: Disabled. Waiting for lap >= 2 (--b_startLap).
MERGE @ lap 1.00: Disabled. Cannot plan merge on first lap. Need valid SS that represent whole dataset.
DELETE @ lap 1.00: Disabled. Cannot delete before first complete lap, because SS that represents whole dataset is required.
1.000/50 after 0 sec. | 218.8 MiB | K 2 | loss 4.852914927e+00 |
MERGE @ lap 2.00: Disabled. Waiting for lap >= 5 (--m_startLap).
DELETE @ lap 2.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 2.00 : Added 6 states. 2/2 succeeded. 0/2 failed eval phase. 0/2 failed build phase.
2.000/50 after 0 sec. | 219.7 MiB | K 8 | loss 4.125791748e+00 |
MERGE @ lap 3.00: Disabled. Waiting for lap >= 5 (--m_startLap).
DELETE @ lap 3.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 3.00 : Added 0 states. 0/6 succeeded. 0/6 failed eval phase. 6/6 failed build phase.
3.000/50 after 1 sec. | 219.7 MiB | K 8 | loss 4.125791567e+00 | Ndiff 0.000
MERGE @ lap 4.00: Disabled. Waiting for lap >= 5 (--m_startLap).
DELETE @ lap 4.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 4.000 : None attempted. 6 past failures. 2 too small. 0 too busy.
4.000/50 after 1 sec. | 219.7 MiB | K 8 | loss 4.125791567e+00 | Ndiff 0.000
DELETE @ lap 5.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 5.000 : None attempted. 0 past failures. 0 too small. 8 too busy.
MERGE @ lap 5.00 : 1/16 accepted. Ndiff 0.00. 0 skipped.
5.000/50 after 1 sec. | 219.7 MiB | K 7 | loss 4.125791567e+00 | Ndiff 0.000
DELETE @ lap 6.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 6.000 : None attempted. 1 past failures. 0 too small. 6 too busy.
MERGE @ lap 6.00 : 1/5 accepted. Ndiff 0.00. 0 skipped.
6.000/50 after 1 sec. | 219.7 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
DELETE @ lap 7.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 7.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
MERGE @ lap 7.00 : 0/5 accepted. Ndiff 0.00. 0 skipped.
7.000/50 after 1 sec. | 219.7 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
MERGE @ lap 8.00: No promising candidates, so no attempts.
DELETE @ lap 8.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 8.00 : Added 0 states. 0/1 succeeded. 0/1 failed eval phase. 1/1 failed build phase.
8.000/50 after 1 sec. | 219.7 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
MERGE @ lap 9.00: No promising candidates, so no attempts.
DELETE @ lap 9.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 9.000 : None attempted. 6 past failures. 0 too small. 0 too busy.
9.000/50 after 2 sec. | 219.7 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
BIRTH @ lap 10.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 10.00: 0/1 accepted. Ndiff 0.00.
MERGE @ lap 10.00 : 0/10 accepted. Ndiff 0.00. 0 skipped.
10.000/50 after 2 sec. | 219.7 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
MERGE @ lap 11.00: No promising candidates, so no attempts.
BIRTH @ lap 11.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 11.00: 0/1 accepted. Ndiff 0.00.
11.000/50 after 2 sec. | 219.7 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
BIRTH @ lap 12.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 12.00: 0/1 accepted. Ndiff 0.00.
MERGE @ lap 12.00 : 0/5 accepted. Ndiff 0.00. 0 skipped.
12.000/50 after 2 sec. | 219.7 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
MERGE @ lap 13.00: No promising candidates, so no attempts.
BIRTH @ lap 13.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 13.00: 0/1 accepted. Ndiff 0.00.
13.000/50 after 2 sec. | 219.7 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
MERGE @ lap 14.00: No promising candidates, so no attempts.
BIRTH @ lap 14.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 14.00: 0/1 accepted. Ndiff 0.00.
14.000/50 after 2 sec. | 219.7 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
BIRTH @ lap 15.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 15.00: 0/1 accepted. Ndiff 0.00.
MERGE @ lap 15.00 : 0/10 accepted. Ndiff 0.00. 0 skipped.
15.000/50 after 2 sec. | 219.7 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
MERGE @ lap 16.00: No promising candidates, so no attempts.
BIRTH @ lap 16.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 16.00: 0/1 accepted. Ndiff 0.00.
16.000/50 after 2 sec. | 219.7 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
BIRTH @ lap 17.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 17.00: 0/1 accepted. Ndiff 0.00.
MERGE @ lap 17.00 : 0/5 accepted. Ndiff 0.00. 0 skipped.
17.000/50 after 2 sec. | 219.7 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
MERGE @ lap 18.00: No promising candidates, so no attempts.
BIRTH @ lap 18.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 18.00: 0/1 accepted. Ndiff 0.00.
18.000/50 after 2 sec. | 219.7 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
MERGE @ lap 19.00: No promising candidates, so no attempts.
BIRTH @ lap 19.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 19.00: 0/1 accepted. Ndiff 0.00.
19.000/50 after 2 sec. | 219.7 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
BIRTH @ lap 20.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 20.00: 0/1 accepted. Ndiff 0.00.
MERGE @ lap 20.00 : 0/10 accepted. Ndiff 0.00. 0 skipped.
20.000/50 after 3 sec. | 219.7 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
MERGE @ lap 21.00: No promising candidates, so no attempts.
BIRTH @ lap 21.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 21.00: 0/1 accepted. Ndiff 0.00.
21.000/50 after 3 sec. | 219.7 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
... done. converged.
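The progress lines in the log above (the ones starting with a lap counter such as `5.000/50`) carry the number of active clusters K and the loss at each lap. As a rough sketch, they can be pulled out of captured log text with a regular expression. The helper name `parse_progress_line` and the pattern are illustrative assumptions based only on the line format printed above; they are not part of bnpy.

```python
import re

# Hypothetical helper (not a bnpy function): the pattern below is inferred
# from the progress lines printed in the log above.
PROGRESS_RE = re.compile(
    r"^\s*(?P<lap>\d+\.\d+)/(?P<n_lap>\d+) after\s+(?P<sec>\d+) sec\. \|"
    r"\s*(?P<mem>[\d.]+) MiB \|"
    r"\s*K\s+(?P<K>\d+) \|"
    r"\s*loss\s+(?P<loss>[-+\d.eE]+)")


def parse_progress_line(line):
    """Return dict with lap, K, loss for a progress line, else None."""
    match = PROGRESS_RE.match(line)
    if match is None:
        return None  # e.g. BIRTH / MERGE / DELETE status lines
    return {
        "lap": float(match.group("lap")),
        "K": int(match.group("K")),
        "loss": float(match.group("loss")),
    }


# A few lines copied from the log above
log_lines = [
    "BIRTH @ lap 2.00 : Added 6 states. 2/2 succeeded. "
    "0/2 failed eval phase. 0/2 failed build phase.",
    "2.000/50 after 0 sec. | 219.7 MiB | K 8 | loss 4.125791748e+00 |",
    "5.000/50 after 1 sec. | 219.7 MiB | K 7 | loss 4.125791567e+00 "
    "| Ndiff 0.000",
]

records = [rec for rec in map(parse_progress_line, log_lines)
           if rec is not None]
for rec in records:
    print("lap %5.1f | K %2d | loss %.9f" % (
        rec["lap"], rec["K"], rec["loss"]))
```

Collecting these records over a whole run gives a quick trace of how K changes as birth, merge, and delete moves are accepted.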
Train from K=10 initial clusters, using random initialization
initname = 'randomlikewang'
K = 10
K10_trained_model, K10_info_dict = bnpy.run(
    dataset, 'DPMixtureModel', 'Mult', 'memoVB',
    output_path='/tmp/bars_one_per_doc/trymoves-K=%d-initname=%s/' % (
        K, initname),
    nTask=1, nLap=50, convergeThr=0.0001, nBatch=1,
    K=K, initname=initname,
    gamma0=50.0, lam=0.1,
    moves='birth,merge,shuffle,delete',
    b_startLap=2,
    m_startLap=5,
    d_startLap=10)
show_bars_over_time(K10_info_dict['task_output_path'])
Dataset Summary:
BagOfWordsData
total size: 2000 units
batch size: 2000 units
num. batches: 1
Allocation Model: DP mixture with K=0. Concentration gamma0= 50.00
Obs. Data Model: Multinomial over finite vocabulary.
Obs. Data Prior: Dirichlet over finite vocabulary
lam = [0.1 0.1] ...
Initialization:
initname = randomlikewang
K = 10 (number of clusters)
seed = 1607680
elapsed_time: 0.0 sec
Learn Alg: memoVB | task 1/1 | alg. seed: 1607680 | data order seed: 8541952
task_output_path: /tmp/bars_one_per_doc/trymoves-K=10-initname=randomlikewang/1
BIRTH @ lap 1.00: Disabled. Waiting for lap >= 2 (--b_startLap).
MERGE @ lap 1.00: Disabled. Cannot plan merge on first lap. Need valid SS that represent whole dataset.
DELETE @ lap 1.00: Disabled. Cannot delete before first complete lap, because SS that represents whole dataset is required.
1.000/50 after 0 sec. | 224.2 MiB | K 10 | loss 4.712365207e+00 |
MERGE @ lap 2.00: Disabled. Waiting for lap >= 5 (--m_startLap).
DELETE @ lap 2.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 2.00 : Added 2 states. 1/6 succeeded. 0/6 failed eval phase. 5/6 failed build phase.
2.000/50 after 1 sec. | 224.2 MiB | K 12 | loss 4.128419662e+00 |
MERGE @ lap 3.00: Disabled. Waiting for lap >= 5 (--m_startLap).
DELETE @ lap 3.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 3.00 : Added 0 states. 0/4 succeeded. 0/4 failed eval phase. 4/4 failed build phase.
3.000/50 after 1 sec. | 224.2 MiB | K 12 | loss 4.127318928e+00 | Ndiff 17.626
MERGE @ lap 4.00: Disabled. Waiting for lap >= 5 (--m_startLap).
DELETE @ lap 4.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 4.00 : Added 0 states. 0/2 succeeded. 0/2 failed eval phase. 2/2 failed build phase.
4.000/50 after 1 sec. | 224.2 MiB | K 12 | loss 4.127075221e+00 | Ndiff 11.299
DELETE @ lap 5.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 5.000 : None attempted. 0 past failures. 0 too small. 12 too busy.
MERGE @ lap 5.00 : 3/18 accepted. Ndiff 0.00. 12 skipped.
5.000/50 after 2 sec. | 224.2 MiB | K 9 | loss 4.126920998e+00 | Ndiff 11.299
DELETE @ lap 6.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 6.000 : None attempted. 1 past failures. 0 too small. 8 too busy.
MERGE @ lap 6.00 : 3/5 accepted. Ndiff 21.29. 10 skipped.
6.000/50 after 2 sec. | 224.2 MiB | K 6 | loss 4.125791567e+00 | Ndiff 11.299
DELETE @ lap 7.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 7.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
MERGE @ lap 7.00 : 0/9 accepted. Ndiff 0.00. 0 skipped.
7.000/50 after 2 sec. | 224.2 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
MERGE @ lap 8.00: No promising candidates, so no attempts.
DELETE @ lap 8.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 8.00 : Added 0 states. 0/2 succeeded. 0/2 failed eval phase. 2/2 failed build phase.
8.000/50 after 2 sec. | 224.2 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
MERGE @ lap 9.00: No promising candidates, so no attempts.
DELETE @ lap 9.00: Disabled. Waiting for lap >= 10 (--d_startLap).
BIRTH @ lap 9.000 : None attempted. 6 past failures. 0 too small. 0 too busy.
9.000/50 after 2 sec. | 224.2 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
BIRTH @ lap 10.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 10.00: 0/1 accepted. Ndiff 0.00.
MERGE @ lap 10.00 : 0/6 accepted. Ndiff 0.00. 0 skipped.
10.000/50 after 2 sec. | 224.2 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
MERGE @ lap 11.00: No promising candidates, so no attempts.
BIRTH @ lap 11.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 11.00: 0/1 accepted. Ndiff 0.00.
11.000/50 after 2 sec. | 224.2 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
BIRTH @ lap 12.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 12.00: 0/1 accepted. Ndiff 0.00.
MERGE @ lap 12.00 : 0/9 accepted. Ndiff 0.00. 0 skipped.
12.000/50 after 2 sec. | 224.2 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
MERGE @ lap 13.00: No promising candidates, so no attempts.
BIRTH @ lap 13.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 13.00: 0/1 accepted. Ndiff 0.00.
13.000/50 after 3 sec. | 224.2 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
MERGE @ lap 14.00: No promising candidates, so no attempts.
BIRTH @ lap 14.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 14.00: 0/1 accepted. Ndiff 0.00.
14.000/50 after 3 sec. | 224.2 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
BIRTH @ lap 15.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 15.00: 0/1 accepted. Ndiff 0.00.
MERGE @ lap 15.00 : 0/6 accepted. Ndiff 0.00. 0 skipped.
15.000/50 after 3 sec. | 224.2 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
MERGE @ lap 16.00: No promising candidates, so no attempts.
BIRTH @ lap 16.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 16.00: 0/1 accepted. Ndiff 0.00.
16.000/50 after 3 sec. | 224.2 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
BIRTH @ lap 17.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 17.00: 0/1 accepted. Ndiff 0.00.
MERGE @ lap 17.00 : 0/9 accepted. Ndiff 0.00. 0 skipped.
17.000/50 after 3 sec. | 224.2 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
MERGE @ lap 18.00: No promising candidates, so no attempts.
BIRTH @ lap 18.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 18.00: 0/1 accepted. Ndiff 0.00.
18.000/50 after 3 sec. | 224.2 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
MERGE @ lap 19.00: No promising candidates, so no attempts.
BIRTH @ lap 19.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 19.00: 0/1 accepted. Ndiff 0.00.
19.000/50 after 3 sec. | 224.2 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
BIRTH @ lap 20.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 20.00: 0/1 accepted. Ndiff 0.00.
MERGE @ lap 20.00 : 0/6 accepted. Ndiff 0.00. 0 skipped.
20.000/50 after 3 sec. | 224.2 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
MERGE @ lap 21.00: No promising candidates, so no attempts.
BIRTH @ lap 21.000 : None attempted. 0 past failures. 0 too small. 6 too busy.
DELETE @ lap 21.00: 0/1 accepted. Ndiff 0.00.
21.000/50 after 3 sec. | 224.2 MiB | K 6 | loss 4.125791567e+00 | Ndiff 0.000
... done. converged.
# Create random number generator with fixed seed
prng = np.random.RandomState(54321)
# Preallocate space for 5 generated docs
n_docs_to_sample = 5
V = X_csr_DV.shape[1]
test_x_DV = np.zeros((n_docs_to_sample, V))
for doc_id in range(n_docs_to_sample):
    # Step 1: Pick cluster index *k* that current example is assigned to
    proba_K = K10_trained_model.allocModel.get_active_comp_probs()
    k = prng.choice(proba_K.size, p=proba_K / np.sum(proba_K))
    # Step 2: Draw probability-over-vocab from cluster *k*'s Dirichlet posterior
    lam_k_V = K10_trained_model.obsModel.Post.lam[k]
    phi_k_V = prng.dirichlet(lam_k_V)
    # Step 3: Draw a document with 50 words using phi_k_V
    x_d_V = prng.multinomial(50, phi_k_V)
    test_x_DV[doc_id] = x_d_V
bnpy.viz.BarsViz.show_square_images(
    test_x_DV, vmin=0, vmax=5)
pylab.tight_layout()
pylab.show(block=False)
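The three sampling steps in the comments above do not depend on bnpy itself. The sketch below repeats them with plain NumPy on a toy model, so the generative process can be tried without a trained model. The cluster weights `proba_K` and Dirichlet parameters `lam_KV` here are made-up illustrative values, not numbers taken from the trained model.

```python
import numpy as np

prng = np.random.RandomState(0)

# Toy stand-ins (assumed values for illustration only)
proba_K = np.array([0.5, 0.3, 0.2])   # mixture weights over K=3 clusters
lam_KV = np.array([                   # Dirichlet parameters, one row per cluster
    [10.0, 10.0, 0.1, 0.1],
    [0.1, 0.1, 10.0, 10.0],
    [5.0, 0.1, 5.0, 0.1]])

n_docs = 5
n_words_per_doc = 50
K, V = lam_KV.shape

z_D = np.zeros(n_docs, dtype=np.int64)        # sampled cluster per doc
x_DV = np.zeros((n_docs, V), dtype=np.int64)  # sampled word counts
for d in range(n_docs):
    # Step 1: pick a cluster index according to the mixture weights
    z_D[d] = prng.choice(K, p=proba_K)
    # Step 2: draw a distribution over the vocabulary from that cluster
    phi_V = prng.dirichlet(lam_KV[z_D[d]])
    # Step 3: draw the word counts of one document
    x_DV[d] = prng.multinomial(n_words_per_doc, phi_V)

print(z_D)
print(x_DV)
```

Each row of `x_DV` sums to `n_words_per_doc`, and most of its mass falls on the vocabulary entries favored by the sampled cluster's Dirichlet parameters.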
Total running time of the script: ( 0 minutes 9.967 seconds)