.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/01_asterisk_K8/plot-02-demo=vb_single_run-model=dp_mix+gauss.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_examples_01_asterisk_K8_plot-02-demo=vb_single_run-model=dp_mix+gauss.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_01_asterisk_K8_plot-02-demo=vb_single_run-model=dp_mix+gauss.py:

============================================================
Variational coordinate descent for Mixture of Gaussians
============================================================

How to do Variational Bayes (VB) coordinate descent for a Gaussian mixture
model (GMM).

Here, we train a finite mixture of Gaussians with full covariances.

We'll consider a mixture model with a symmetric Dirichlet prior:

.. math::

    \pi \sim \mbox{Dir}(1/K, 1/K, \ldots, 1/K)

as well as a standard conjugate Gauss-Wishart prior on the mean and
covariance of each cluster, such that

.. math::

    \mathbb{E}[\mu_k] = 0, \qquad \mathbb{E}[\Sigma_k] = 0.1 I_D

We will initialize the approximate variational posterior using K=10 randomly
chosen examples (the 'randexamples' procedure), and then perform coordinate
descent updates (alternating local and global steps) until convergence.
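For reference, each lap of coordinate descent alternates two updates. The
*local step* assigns every example a vector of soft cluster responsibilities
under the current global posterior; the *global step* re-estimates the
Dirichlet and Gauss-Wishart posterior parameters from the resulting soft
counts. A standard mean-field presentation of these updates (illustrative
textbook notation, not bnpy's internal variable names) is:

.. math::

    r_{nk} \propto \exp\Big( \mathbb{E}_q[\log \pi_k]
        + \mathbb{E}_q[\log \mathcal{N}(x_n \mid \mu_k, \Sigma_k)] \Big),
    \qquad
    N_k = \sum_{n=1}^N r_{nk},

where the global step uses the counts :math:`N_k` (plus the corresponding
weighted data statistics) to update each cluster's posterior parameters.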
.. GENERATED FROM PYTHON SOURCE LINES 29-42

.. code-block:: default


    # SPECIFY WHICH PLOT CREATED BY THIS SCRIPT IS THE THUMBNAIL IMAGE
    # sphinx_gallery_thumbnail_number = 3

    import bnpy
    import numpy as np
    import os

    from matplotlib import pylab
    import seaborn as sns

    FIG_SIZE = (3, 3)
    pylab.rcParams['figure.figsize'] = FIG_SIZE


.. GENERATED FROM PYTHON SOURCE LINES 43-44

Read bnpy's built-in "AsteriskK8" dataset from file.

.. GENERATED FROM PYTHON SOURCE LINES 44-49

.. code-block:: default


    dataset_path = os.path.join(bnpy.DATASET_PATH, 'AsteriskK8')
    dataset = bnpy.data.XData.read_npz(
        os.path.join(dataset_path, 'x_dataset.npz'))


.. GENERATED FROM PYTHON SOURCE LINES 50-51

Make a simple plot of the raw data.

.. GENERATED FROM PYTHON SOURCE LINES 52-59

.. code-block:: default


    pylab.plot(dataset.X[:, 0], dataset.X[:, 1], 'k.')
    pylab.gca().set_xlim([-2, 2])
    pylab.gca().set_ylim([-2, 2])
    pylab.tight_layout()


.. image-sg:: /examples/01_asterisk_K8/images/sphx_glr_plot-02-demo=vb_single_run-model=dp_mix+gauss_001.png
   :alt: plot 02 demo=vb single run model=dp mix+gauss
   :srcset: /examples/01_asterisk_K8/images/sphx_glr_plot-02-demo=vb_single_run-model=dp_mix+gauss_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 60-65

Training the model
------------------
Let's do a single run of the VB algorithm, using 10 clusters and the
'randexamples' initialization procedure.

.. GENERATED FROM PYTHON SOURCE LINES 66-75

.. code-block:: default


    trained_model, info_dict = bnpy.run(
        dataset,
        'FiniteMixtureModel', 'Gauss', 'VB',
        output_path='/tmp/AsteriskK8/helloworld-K=10/',
        nLap=100,
        sF=0.1, ECovMat='eye',
        K=10, initname='randexamples')


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Dataset Summary:
    X Data
      num examples: 5000
      num dims: 2
    Allocation Model:  Finite mixture model. Dir prior param 1.00
    Obs. Data  Model:  Gaussian with full covariance.
    Obs. Data  Prior:  Gauss-Wishart on mean and covar of each cluster
      E[ mean[k] ]  = [0. 0.]
      E[ covar[k] ] =
      [[0.1 0. ]
       [0.  0.1]]
    Initialization:
      initname = randexamples
      K = 10 (number of clusters)
      seed = 1607680
      elapsed_time: 0.0 sec
    Learn Alg: VB | task  1/1 | alg. seed: 1607680 | data order seed: 8541952
    task_output_path: /tmp/AsteriskK8/helloworld-K=10/1
        1/100 after      0 sec. | 200.1 MiB | K   10 | loss  6.582634775e-01 |
        2/100 after      0 sec. | 201.3 MiB | K   10 | loss  4.350235353e-01 | Ndiff   68.926
        3/100 after      0 sec. | 201.3 MiB | K   10 | loss  3.454096950e-01 | Ndiff  193.565
        4/100 after      0 sec. | 201.3 MiB | K   10 | loss  3.049230819e-01 | Ndiff  175.237
        5/100 after      0 sec. | 201.3 MiB | K   10 | loss  2.732439109e-01 | Ndiff  108.630
        6/100 after      0 sec. | 201.3 MiB | K   10 | loss  2.326372999e-01 | Ndiff   42.961
        7/100 after      0 sec. | 201.3 MiB | K   10 | loss  2.100254570e-01 | Ndiff    9.814
        8/100 after      0 sec. | 201.3 MiB | K   10 | loss  2.097453779e-01 | Ndiff   10.216
        9/100 after      0 sec. | 201.3 MiB | K   10 | loss  2.094741199e-01 | Ndiff   10.608
       10/100 after      0 sec. | 201.3 MiB | K   10 | loss  2.091535256e-01 | Ndiff   11.009
       11/100 after      0 sec. | 201.3 MiB | K   10 | loss  2.087740779e-01 | Ndiff   11.403
       12/100 after      0 sec. | 201.3 MiB | K   10 | loss  2.083262051e-01 | Ndiff   11.781
       13/100 after      0 sec. | 201.3 MiB | K   10 | loss  2.078022512e-01 | Ndiff   12.135
       14/100 after      0 sec. | 201.3 MiB | K   10 | loss  2.072001819e-01 | Ndiff   12.449
       15/100 after      0 sec. | 201.3 MiB | K   10 | loss  2.065288318e-01 | Ndiff   12.706
       16/100 after      0 sec. | 201.3 MiB | K   10 | loss  2.058118708e-01 | Ndiff   12.883
       17/100 after      0 sec. | 201.3 MiB | K   10 | loss  2.050844132e-01 | Ndiff   12.951
       18/100 after      0 sec. | 201.3 MiB | K   10 | loss  2.043795824e-01 | Ndiff   12.873
       19/100 after      1 sec. | 201.3 MiB | K   10 | loss  2.037146182e-01 | Ndiff   12.598
       20/100 after      1 sec. | 201.3 MiB | K   10 | loss  2.030678434e-01 | Ndiff   12.069
       21/100 after      1 sec. | 201.3 MiB | K   10 | loss  2.022621299e-01 | Ndiff   11.232
       22/100 after      1 sec. | 201.3 MiB | K   10 | loss  2.008648784e-01 | Ndiff   10.057
       23/100 after      1 sec. | 201.3 MiB | K   10 | loss  2.002168334e-01 | Ndiff    8.536
       24/100 after      1 sec. | 201.3 MiB | K   10 | loss  1.995420809e-01 | Ndiff    8.587
       25/100 after      1 sec. | 201.3 MiB | K   10 | loss  1.988594991e-01 | Ndiff    9.312
       26/100 after      1 sec. | 201.3 MiB | K   10 | loss  1.982035545e-01 | Ndiff    9.984
       27/100 after      1 sec. | 201.3 MiB | K   10 | loss  1.974254707e-01 | Ndiff   10.598
       28/100 after      1 sec. | 201.3 MiB | K   10 | loss  1.966834121e-01 | Ndiff   11.131
       29/100 after      1 sec. | 201.3 MiB | K   10 | loss  1.964267444e-01 | Ndiff   11.549
       30/100 after      1 sec. | 201.3 MiB | K   10 | loss  1.961250819e-01 | Ndiff   11.802
       31/100 after      1 sec. | 201.3 MiB | K   10 | loss  1.957730227e-01 | Ndiff   11.830
       32/100 after      1 sec. | 201.3 MiB | K   10 | loss  1.953666781e-01 | Ndiff   11.565
       33/100 after      1 sec. | 201.3 MiB | K   10 | loss  1.949058031e-01 | Ndiff   10.930
       34/100 after      1 sec. | 201.3 MiB | K   10 | loss  1.943977009e-01 | Ndiff    9.860
       35/100 after      1 sec. | 201.3 MiB | K   10 | loss  1.938623040e-01 | Ndiff    8.334
       36/100 after      1 sec. | 201.3 MiB | K   10 | loss  1.933345511e-01 | Ndiff    6.449
       37/100 after      1 sec. | 201.3 MiB | K   10 | loss  1.928570776e-01 | Ndiff    4.458
       38/100 after      1 sec. | 201.3 MiB | K   10 | loss  1.924667495e-01 | Ndiff    2.727
       39/100 after      1 sec. | 201.3 MiB | K   10 | loss  1.921899371e-01 | Ndiff    1.530
       40/100 after      1 sec. | 201.3 MiB | K   10 | loss  1.919922352e-01 | Ndiff    0.873
       41/100 after      1 sec. | 201.3 MiB | K   10 | loss  1.916267958e-01 | Ndiff    0.492
       42/100 after      1 sec. | 201.3 MiB | K   10 | loss  1.909787395e-01 | Ndiff    0.045
    ... done. converged.
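The second return value, ``info_dict``, records diagnostics from the run,
including the loss after each lap under the ``'loss_history'`` and
``'lap_history'`` keys used in the plots below. Since VB coordinate descent
should never increase the loss, we can sanity-check the run numerically.
This is a minimal sketch that assumes only those two keys and reuses the
``np`` and ``info_dict`` objects defined above:

.. code-block:: python

    # Verify the recorded loss is non-increasing after the first lap.
    # A tiny tolerance guards against floating-point round-off.
    loss_history = np.asarray(info_dict['loss_history'])
    loss_diffs = np.diff(loss_history[1:])
    assert np.all(loss_diffs <= 1e-8), "loss increased between laps"
    print("loss decreased monotonically across %d recorded laps" % (
        loss_history.size))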
.. GENERATED FROM PYTHON SOURCE LINES 76-87

Loss function trace plot
------------------------
We can plot the value of the loss function over iterations, starting after
the first full pass over the dataset (first lap).

As expected, we see a monotonic decrease in the loss function's value after
every subsequent iteration.

Remember that the VB algorithm for GMMs is *guaranteed* to decrease this
loss function after every step.

.. GENERATED FROM PYTHON SOURCE LINES 88-94

.. code-block:: default


    pylab.plot(
        info_dict['lap_history'][1:],
        info_dict['loss_history'][1:], 'k.-')
    pylab.xlabel('num. laps')
    pylab.ylabel('loss')
    pylab.tight_layout()


.. image-sg:: /examples/01_asterisk_K8/images/sphx_glr_plot-02-demo=vb_single_run-model=dp_mix+gauss_002.png
   :alt: plot 02 demo=vb single run model=dp mix+gauss
   :srcset: /examples/01_asterisk_K8/images/sphx_glr_plot-02-demo=vb_single_run-model=dp_mix+gauss_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 95-98

Visualization of learned clusters
---------------------------------
Here's a short function to show the learned clusters over time.

.. GENERATED FROM PYTHON SOURCE LINES 99-124

.. code-block:: default


    def show_clusters_over_time(
            task_output_path=None,
            query_laps=[0, 1, 2, 5, 10, None],
            nrows=2):
        ''' Read model snapshots from provided folder and make visualizations

        Post Condition
        --------------
        New matplotlib plot with some nice pictures.
        '''
        # Round up, so all panels fit even when len(query_laps)
        # is not an exact multiple of nrows.
        ncols = int(np.ceil(len(query_laps) / float(nrows)))
        fig_handle, ax_handle_list = pylab.subplots(
            figsize=(FIG_SIZE[0] * ncols, FIG_SIZE[1] * nrows),
            nrows=nrows, ncols=ncols, sharex=True, sharey=True)
        for plot_id, lap_val in enumerate(query_laps):
            # Load the model snapshot saved at this lap
            # (also returns the actual lap value that was loaded)
            cur_model, lap_val = bnpy.load_model_at_lap(
                task_output_path, lap_val)
            # Plot the current model
            cur_ax_handle = ax_handle_list.flatten()[plot_id]
            bnpy.viz.PlotComps.plotCompsFromHModel(
                cur_model, Data=dataset, ax_handle=cur_ax_handle)
            cur_ax_handle.set_xticks([-2, -1, 0, 1, 2])
            cur_ax_handle.set_yticks([-2, -1, 0, 1, 2])
            cur_ax_handle.set_xlabel("lap: %d" % lap_val)
        pylab.tight_layout()


.. GENERATED FROM PYTHON SOURCE LINES 125-126

Show the estimated clusters over time.

.. GENERATED FROM PYTHON SOURCE LINES 127-128

.. code-block:: default


    show_clusters_over_time(info_dict['task_output_path'])


.. image-sg:: /examples/01_asterisk_K8/images/sphx_glr_plot-02-demo=vb_single_run-model=dp_mix+gauss_003.png
   :alt: plot 02 demo=vb single run model=dp mix+gauss
   :srcset: /examples/01_asterisk_K8/images/sphx_glr_plot-02-demo=vb_single_run-model=dp_mix+gauss_003.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    SKIPPED 3 comps with size below 0.00

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 1.888 seconds)

.. _sphx_glr_download_examples_01_asterisk_K8_plot-02-demo=vb_single_run-model=dp_mix+gauss.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot-02-demo=vb_single_run-model=dp_mix+gauss.py <plot-02-demo=vb_single_run-model=dp_mix+gauss.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot-02-demo=vb_single_run-model=dp_mix+gauss.ipynb <plot-02-demo=vb_single_run-model=dp_mix+gauss.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_