.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/03_faithful/plot-01-demo=vb_algs-model=mix_gauss.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_03_faithful_plot-01-demo=vb_algs-model=mix_gauss.py:

==============================================
Variational training for Mixtures of Gaussians
==============================================

Showcase of different models and algorithms applied to the same dataset.

In this example, we show how bnpy makes it easy to train several
mixture models (diagonal, full, and zero-mean Gaussian likelihoods)
with several variational algorithms (memoVB, VB, soVB) on the same dataset.

.. GENERATED FROM PYTHON SOURCE LINES 12-25

.. code-block:: default

    import bnpy
    import numpy as np
    import os

    from matplotlib import pylab
    import seaborn as sns

    SMALL_FIG_SIZE = (2.5, 2.5)
    FIG_SIZE = (5, 5)
    pylab.rcParams['figure.figsize'] = FIG_SIZE

    np.set_printoptions(precision=3, suppress=True, linewidth=200)

.. GENERATED FROM PYTHON SOURCE LINES 26-27

Load dataset from file

.. GENERATED FROM PYTHON SOURCE LINES 28-33

.. code-block:: default

    dataset_path = os.path.join(bnpy.DATASET_PATH, 'faithful')
    dataset = bnpy.data.XData.read_csv(
        os.path.join(dataset_path, 'faithful.csv'))

.. GENERATED FROM PYTHON SOURCE LINES 34-35

Make a simple plot of the raw data

.. GENERATED FROM PYTHON SOURCE LINES 36-44

.. code-block:: default

    pylab.plot(dataset.X[:, 0], dataset.X[:, 1], 'k.')
    pylab.xlabel(dataset.column_names[0])
    pylab.ylabel(dataset.column_names[1])
    pylab.tight_layout()
    data_ax_h = pylab.gca()

.. image-sg:: /examples/03_faithful/images/sphx_glr_plot-01-demo=vb_algs-model=mix_gauss_001.png
   :alt: plot 01 demo=vb algs model=mix gauss
   :srcset: /examples/03_faithful/images/sphx_glr_plot-01-demo=vb_algs-model=mix_gauss_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 45-47

Setup: Helper function to display the learned clusters
------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 48-72

.. code-block:: default

    def show_clusters_over_time(
            task_output_path=None,
            query_laps=[0, 1, 2, 10, 20, None],
            nrows=2):
        ''' Show 2D elliptical contours overlaid on raw data.
        '''
        # True division before ceil, so every queried lap gets its own panel
        ncols = int(np.ceil(len(query_laps) / float(nrows)))
        fig_handle, ax_handle_list = pylab.subplots(
            figsize=(SMALL_FIG_SIZE[0] * ncols, SMALL_FIG_SIZE[1] * nrows),
            nrows=nrows, ncols=ncols, sharex=True, sharey=True)
        for plot_id, lap_val in enumerate(query_laps):
            cur_model, lap_val = bnpy.load_model_at_lap(task_output_path, lap_val)
            cur_ax_handle = ax_handle_list.flatten()[plot_id]
            bnpy.viz.PlotComps.plotCompsFromHModel(
                cur_model, dataset=dataset, ax_handle=cur_ax_handle)
            cur_ax_handle.set_title("lap: %d" % lap_val)
            cur_ax_handle.set_xlabel(dataset.column_names[0])
            cur_ax_handle.set_ylabel(dataset.column_names[1])
            cur_ax_handle.set_xlim(data_ax_h.get_xlim())
            cur_ax_handle.set_ylim(data_ax_h.get_ylim())
        pylab.tight_layout()

.. GENERATED FROM PYTHON SOURCE LINES 73-79

*DiagGauss* observation model
-----------------------------

Assume diagonal covariances.

Start with too many clusters (K=20)

.. GENERATED FROM PYTHON SOURCE LINES 80-94

.. code-block:: default

    gamma = 5.0
    sF = 5.0
    K = 20

    diag_trained_model, diag_info_dict = bnpy.run(
        dataset, 'DPMixtureModel', 'DiagGauss', 'memoVB',
        output_path='/tmp/faithful/showcase-K=20-lik=DiagGauss-ECovMat=5*eye/',
        nLap=1000, nTask=1, nBatch=1, convergeThr=0.0001,
        gamma0=gamma, sF=sF, ECovMat='eye',
        K=K, initname='randexamples',
        )
    show_clusters_over_time(diag_info_dict['task_output_path'])

.. image-sg:: /examples/03_faithful/images/sphx_glr_plot-01-demo=vb_algs-model=mix_gauss_002.png
   :alt: lap: 0, lap: 1, lap: 2, lap: 10, lap: 20, lap: 79
   :srcset: /examples/03_faithful/images/sphx_glr_plot-01-demo=vb_algs-model=mix_gauss_002.png
   :class: sphx-glr-single-img

..
rst-class:: sphx-glr-script-out .. code-block:: none Dataset Summary: X Data total size: 272 units batch size: 272 units num. batches: 1 Allocation Model: DP mixture with K=0. Concentration gamma0= 5.00 Obs. Data Model: Gaussian with diagonal covariance. Obs. Data Prior: independent Gauss-Wishart prior on each dimension Wishart params nu = 4 beta = [ 10 10] Expectations E[ mean[k]] = [ 0 0] E[ covar[k]] = [[5. 0.] [0. 5.]] Initialization: initname = randexamples K = 20 (number of clusters) seed = 1607680 elapsed_time: 0.0 sec Learn Alg: memoVB | task 1/1 | alg. seed: 1607680 | data order seed: 8541952 task_output_path: /tmp/faithful/showcase-K=20-lik=DiagGauss-ECovMat=5*eye/1 1.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 3.002088161e+00 | 2.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.966309748e+00 | Ndiff 4.584 3.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.949498401e+00 | Ndiff 4.501 4.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.931792100e+00 | Ndiff 5.452 5.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.910174417e+00 | Ndiff 6.264 6.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.881973764e+00 | Ndiff 6.557 7.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.833491488e+00 | Ndiff 6.194 8.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.793876249e+00 | Ndiff 5.582 9.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.752249036e+00 | Ndiff 4.580 10.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.698456106e+00 | Ndiff 4.740 11.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.651622886e+00 | Ndiff 5.753 12.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.624108159e+00 | Ndiff 6.393 13.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.587572334e+00 | Ndiff 6.074 14.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.560219008e+00 | Ndiff 4.596 15.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.544887510e+00 | Ndiff 2.425 16.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.527562170e+00 | Ndiff 1.537 17.000/1000 after 0 sec. 
| 213.1 MiB | K 20 | loss 2.513204290e+00 | Ndiff 1.470 18.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.507574530e+00 | Ndiff 1.454 19.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.497639081e+00 | Ndiff 1.477 20.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.497002974e+00 | Ndiff 1.526 21.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.496368366e+00 | Ndiff 1.584 22.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.495732569e+00 | Ndiff 1.638 23.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.495103986e+00 | Ndiff 1.678 24.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.494497370e+00 | Ndiff 1.697 25.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.493928089e+00 | Ndiff 1.698 26.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.493406054e+00 | Ndiff 1.687 27.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.492931650e+00 | Ndiff 1.674 28.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.492495409e+00 | Ndiff 1.670 29.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.492081031e+00 | Ndiff 1.682 30.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.491669505e+00 | Ndiff 1.715 31.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.491242078e+00 | Ndiff 1.771 32.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.490781261e+00 | Ndiff 1.851 33.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.490270354e+00 | Ndiff 1.954 34.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.489692235e+00 | Ndiff 2.082 35.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.489027945e+00 | Ndiff 2.235 36.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.488255286e+00 | Ndiff 2.414 37.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.487347489e+00 | Ndiff 2.621 38.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.486272015e+00 | Ndiff 2.855 39.000/1000 after 0 sec. | 213.1 MiB | K 20 | loss 2.484989607e+00 | Ndiff 3.113 40.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.483453885e+00 | Ndiff 3.391 41.000/1000 after 1 sec. 
| 213.1 MiB | K 20 | loss 2.481612017e+00 | Ndiff 3.675 42.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.479407351e+00 | Ndiff 3.945 43.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.476785615e+00 | Ndiff 4.169 44.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.473707547e+00 | Ndiff 4.301 45.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.470173027e+00 | Ndiff 4.287 46.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.466261998e+00 | Ndiff 4.076 47.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.462175093e+00 | Ndiff 3.648 48.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.458145085e+00 | Ndiff 3.068 49.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.453753531e+00 | Ndiff 2.496 50.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.441832238e+00 | Ndiff 2.043 51.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.437414963e+00 | Ndiff 1.868 52.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.435586138e+00 | Ndiff 1.786 53.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.433555668e+00 | Ndiff 1.617 54.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.431430213e+00 | Ndiff 1.335 55.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.429483776e+00 | Ndiff 0.970 56.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.428084819e+00 | Ndiff 0.626 57.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.427098172e+00 | Ndiff 0.419 58.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.425535651e+00 | Ndiff 0.330 59.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.415976963e+00 | Ndiff 0.153 60.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412692266e+00 | Ndiff 0.026 61.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412692093e+00 | Ndiff 0.016 62.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412692005e+00 | Ndiff 0.011 63.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691958e+00 | Ndiff 0.008 64.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691932e+00 | Ndiff 0.006 65.000/1000 after 1 sec. 
| 213.1 MiB | K 20 | loss 2.412691918e+00 | Ndiff 0.005
    66.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691910e+00 | Ndiff 0.003
    67.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691906e+00 | Ndiff 0.003
    68.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691904e+00 | Ndiff 0.002
    69.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691903e+00 | Ndiff 0.001
    70.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691902e+00 | Ndiff 0.001
    71.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691901e+00 | Ndiff 0.001
    72.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691901e+00 | Ndiff 0.001
    73.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691901e+00 | Ndiff 0.000
    74.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691901e+00 | Ndiff 0.000
    75.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691901e+00 | Ndiff 0.000
    76.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691901e+00 | Ndiff 0.000
    77.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691901e+00 | Ndiff 0.000
    78.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691901e+00 | Ndiff 0.000
    79.000/1000 after 1 sec. | 213.1 MiB | K 20 | loss 2.412691901e+00 | Ndiff 0.000
    ... done. converged.

.. GENERATED FROM PYTHON SOURCE LINES 95-101

*Gauss* observations + VB
-------------------------

Assume full covariances.

Start with too many clusters (K=20)

.. GENERATED FROM PYTHON SOURCE LINES 102-112

.. code-block:: default

    full_trained_model, full_info_dict = bnpy.run(
        dataset, 'DPMixtureModel', 'Gauss', 'VB',
        output_path='/tmp/faithful/showcase-K=20-lik=Gauss-ECovMat=5*eye/',
        nLap=1000, nTask=1, nBatch=1, convergeThr=0.0001,
        gamma0=gamma, sF=sF, ECovMat='eye',
        K=K, initname='randexamples',
        )
    show_clusters_over_time(full_info_dict['task_output_path'])

..
image-sg:: /examples/03_faithful/images/sphx_glr_plot-01-demo=vb_algs-model=mix_gauss_003.png :alt: lap: 0, lap: 1, lap: 2, lap: 10, lap: 20, lap: 117 :srcset: /examples/03_faithful/images/sphx_glr_plot-01-demo=vb_algs-model=mix_gauss_003.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-script-out .. code-block:: none WARNING: Found unrecognized keyword args. These are ignored. --nBatch Dataset Summary: X Data num examples: 272 num dims: 2 Allocation Model: DP mixture with K=0. Concentration gamma0= 5.00 Obs. Data Model: Gaussian with full covariance. Obs. Data Prior: Gauss-Wishart on mean and covar of each cluster E[ mean[k] ] = [0. 0.] E[ covar[k] ] = [[5. 0.] [0. 5.]] Initialization: initname = randexamples K = 20 (number of clusters) seed = 1607680 elapsed_time: 0.0 sec Learn Alg: VB | task 1/1 | alg. seed: 1607680 | data order seed: 8541952 task_output_path: /tmp/faithful/showcase-K=20-lik=Gauss-ECovMat=5*eye/1 1/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.904406498e+00 | 2/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.876066302e+00 | Ndiff 2.071 3/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.864944306e+00 | Ndiff 2.664 4/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.852026479e+00 | Ndiff 3.160 5/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.836506635e+00 | Ndiff 3.522 6/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.817804673e+00 | Ndiff 3.636 7/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.787080350e+00 | Ndiff 3.495 8/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.743017730e+00 | Ndiff 3.349 9/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.714906517e+00 | Ndiff 3.494 10/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.670842807e+00 | Ndiff 4.089 11/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.640944068e+00 | Ndiff 4.696 12/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.600612644e+00 | Ndiff 5.201 13/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.582832638e+00 | Ndiff 5.436 14/1000 after 0 sec. 
| 215.6 MiB | K 20 | loss 2.569005251e+00 | Ndiff 5.209 15/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.544279312e+00 | Ndiff 4.376 16/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.523247072e+00 | Ndiff 2.925 17/1000 after 0 sec. | 215.6 MiB | K 20 | loss 2.511328327e+00 | Ndiff 1.441 18/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.503251361e+00 | Ndiff 1.441 19/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.491852524e+00 | Ndiff 1.448 20/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.490247584e+00 | Ndiff 1.447 21/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.488483904e+00 | Ndiff 1.524 22/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.486427687e+00 | Ndiff 1.578 23/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.484005616e+00 | Ndiff 1.586 24/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.481221463e+00 | Ndiff 1.527 25/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.478177971e+00 | Ndiff 1.375 26/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.475094563e+00 | Ndiff 1.194 27/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.471974138e+00 | Ndiff 1.209 28/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.467024647e+00 | Ndiff 1.213 29/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.453713427e+00 | Ndiff 1.207 30/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.445147652e+00 | Ndiff 1.191 31/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.439161802e+00 | Ndiff 1.170 32/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.438783471e+00 | Ndiff 1.143 33/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.438429230e+00 | Ndiff 1.114 34/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.438098475e+00 | Ndiff 1.082 35/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.437790556e+00 | Ndiff 1.051 36/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.437503774e+00 | Ndiff 1.023 37/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.437235314e+00 | Ndiff 1.000 38/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.436981439e+00 | Ndiff 0.984 39/1000 after 1 sec. 
| 215.6 MiB | K 20 | loss 2.436737744e+00 | Ndiff 0.977 40/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.436499373e+00 | Ndiff 0.979 41/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.436261156e+00 | Ndiff 0.992 42/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.436017666e+00 | Ndiff 1.014 43/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.435763198e+00 | Ndiff 1.048 44/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.435491737e+00 | Ndiff 1.091 45/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.435196931e+00 | Ndiff 1.145 46/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.434872198e+00 | Ndiff 1.209 47/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.434511031e+00 | Ndiff 1.279 48/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.434107639e+00 | Ndiff 1.354 49/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.433657974e+00 | Ndiff 1.428 50/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.433161085e+00 | Ndiff 1.498 51/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.432620547e+00 | Ndiff 1.556 52/1000 after 1 sec. | 215.6 MiB | K 20 | loss 2.432045498e+00 | Ndiff 1.597 53/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.431450777e+00 | Ndiff 1.617 54/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.430855677e+00 | Ndiff 1.617 55/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.430280760e+00 | Ndiff 1.600 56/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.429742357e+00 | Ndiff 1.574 57/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.429246005e+00 | Ndiff 1.547 58/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.428783131e+00 | Ndiff 1.530 59/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.428335044e+00 | Ndiff 1.531 60/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.427882144e+00 | Ndiff 1.554 61/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.427412198e+00 | Ndiff 1.598 62/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.426923365e+00 | Ndiff 1.660 63/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.426419184e+00 | Ndiff 1.735 64/1000 after 2 sec. 
| 215.6 MiB | K 20 | loss 2.425898789e+00 | Ndiff 1.819 65/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.425353548e+00 | Ndiff 1.909 66/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.424771068e+00 | Ndiff 2.006 67/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.424138280e+00 | Ndiff 2.111 68/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.423441574e+00 | Ndiff 2.224 69/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.422666132e+00 | Ndiff 2.343 70/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.421796837e+00 | Ndiff 2.464 71/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.420824917e+00 | Ndiff 2.578 72/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.419763730e+00 | Ndiff 2.674 73/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.418655856e+00 | Ndiff 2.744 74/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.417536411e+00 | Ndiff 2.794 75/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.416396660e+00 | Ndiff 2.833 76/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.415209641e+00 | Ndiff 2.867 77/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.413957301e+00 | Ndiff 2.892 78/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.412627524e+00 | Ndiff 2.908 79/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.411204929e+00 | Ndiff 2.915 80/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.409665366e+00 | Ndiff 2.917 81/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.407972796e+00 | Ndiff 2.914 82/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.406075452e+00 | Ndiff 2.906 83/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.403897086e+00 | Ndiff 2.888 84/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.401317119e+00 | Ndiff 2.847 85/1000 after 2 sec. | 215.6 MiB | K 20 | loss 2.398127190e+00 | Ndiff 2.749 86/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.393870554e+00 | Ndiff 2.514 87/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.385927886e+00 | Ndiff 2.011 88/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.373557979e+00 | Ndiff 1.320 89/1000 after 3 sec. 
| 215.6 MiB | K 20 | loss 2.370431956e+00 | Ndiff 0.732 90/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.361734581e+00 | Ndiff 0.225 91/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356420337e+00 | Ndiff 0.052 92/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356419191e+00 | Ndiff 0.040 93/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356418521e+00 | Ndiff 0.031 94/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356418120e+00 | Ndiff 0.024 95/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417878e+00 | Ndiff 0.019 96/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417731e+00 | Ndiff 0.015 97/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417640e+00 | Ndiff 0.011 98/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417584e+00 | Ndiff 0.009 99/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417550e+00 | Ndiff 0.007 100/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417528e+00 | Ndiff 0.006 101/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417515e+00 | Ndiff 0.004 102/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417507e+00 | Ndiff 0.003 103/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417502e+00 | Ndiff 0.003 104/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417499e+00 | Ndiff 0.002 105/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417497e+00 | Ndiff 0.002 106/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417495e+00 | Ndiff 0.001 107/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417495e+00 | Ndiff 0.001 108/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417494e+00 | Ndiff 0.001 109/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417494e+00 | Ndiff 0.001 110/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417494e+00 | Ndiff 0.001 111/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417493e+00 | Ndiff 0.000 112/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417493e+00 | Ndiff 0.000 113/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417493e+00 | Ndiff 0.000 114/1000 after 3 sec. 
| 215.6 MiB | K 20 | loss 2.356417493e+00 | Ndiff 0.000
    115/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417493e+00 | Ndiff 0.000
    116/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417493e+00 | Ndiff 0.000
    117/1000 after 3 sec. | 215.6 MiB | K 20 | loss 2.356417493e+00 | Ndiff 0.000
    ... done. converged.

.. GENERATED FROM PYTHON SOURCE LINES 113-119

*ZeroMeanGauss* observations + VB
---------------------------------

Assume full covariances and fix all means to zero.

Start with too many clusters (K=20)

.. GENERATED FROM PYTHON SOURCE LINES 120-131

.. code-block:: default

    zm_trained_model, zm_info_dict = bnpy.run(
        dataset, 'DPMixtureModel', 'ZeroMeanGauss', 'VB',
        output_path='/tmp/faithful/showcase-K=20-lik=ZeroMeanGauss-ECovMat=5*eye/',
        nLap=1000, nTask=1, nBatch=1, convergeThr=0.0001,
        gamma0=gamma, sF=sF, ECovMat='eye',
        K=K, initname='randexamples',
        )
    show_clusters_over_time(zm_info_dict['task_output_path'])

.. image-sg:: /examples/03_faithful/images/sphx_glr_plot-01-demo=vb_algs-model=mix_gauss_004.png
   :alt: lap: 0, lap: 1, lap: 2, lap: 10, lap: 20, lap: 56
   :srcset: /examples/03_faithful/images/sphx_glr_plot-01-demo=vb_algs-model=mix_gauss_004.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    WARNING: Found unrecognized keyword args. These are ignored.
      --nBatch
    Dataset Summary:
    X Data
      num examples: 272
      num dims: 2
    Allocation Model: DP mixture with K=0. Concentration gamma0= 5.00
    Obs. Data Model: Gaussian with fixed zero means, full covariance.
    Obs. Data Prior: Wishart on prec matrix Lam
      E[ CovMat[k] ] =
      [[5. 0.]
       [0. 5.]]
    Initialization:
      initname = randexamples
      K = 20 (number of clusters)
      seed = 1607680
      elapsed_time: 0.0 sec
    Learn Alg: VB | task 1/1 | alg. seed: 1607680 | data order seed: 8541952
    task_output_path: /tmp/faithful/showcase-K=20-lik=ZeroMeanGauss-ECovMat=5*eye/1
    1/1000 after 0 sec. | 217.2 MiB | K 20 | loss 4.019419551e+00 |
    2/1000 after 0 sec.
| 217.2 MiB | K 20 | loss 3.989063967e+00 | Ndiff 3.437 3/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.941416771e+00 | Ndiff 3.621 4/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.899222162e+00 | Ndiff 3.882 5/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.876371765e+00 | Ndiff 4.129 6/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.842038547e+00 | Ndiff 4.274 7/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.808536791e+00 | Ndiff 4.341 8/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.779782054e+00 | Ndiff 4.418 9/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.755287922e+00 | Ndiff 4.502 10/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.733837701e+00 | Ndiff 4.577 11/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.700117498e+00 | Ndiff 4.646 12/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.694082693e+00 | Ndiff 4.826 13/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.686765019e+00 | Ndiff 4.986 14/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.675117147e+00 | Ndiff 5.045 15/1000 after 0 sec. | 217.2 MiB | K 20 | loss 3.652297952e+00 | Ndiff 4.998 16/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.639249110e+00 | Ndiff 4.933 17/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.603588941e+00 | Ndiff 4.758 18/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.598749399e+00 | Ndiff 4.729 19/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.592583511e+00 | Ndiff 4.761 20/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.574936974e+00 | Ndiff 4.760 21/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.572459120e+00 | Ndiff 4.885 22/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.570977707e+00 | Ndiff 5.084 23/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.569306514e+00 | Ndiff 5.272 24/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.567412054e+00 | Ndiff 5.442 25/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.565249330e+00 | Ndiff 5.588 26/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.562749369e+00 | Ndiff 5.701 27/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.559772804e+00 | Ndiff 5.771 28/1000 after 1 sec. 
| 217.2 MiB | K 20 | loss 3.555894232e+00 | Ndiff 5.784 29/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.549078207e+00 | Ndiff 5.714 30/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.531303622e+00 | Ndiff 5.559 31/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.529382170e+00 | Ndiff 5.572 32/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.527237440e+00 | Ndiff 5.642 33/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.524655304e+00 | Ndiff 5.674 34/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.521314832e+00 | Ndiff 5.648 35/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.516111933e+00 | Ndiff 5.520 36/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.502294322e+00 | Ndiff 5.202 37/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.491586191e+00 | Ndiff 4.853 38/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.485039030e+00 | Ndiff 4.637 39/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.468792370e+00 | Ndiff 4.344 40/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.467909258e+00 | Ndiff 4.313 41/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.466944463e+00 | Ndiff 4.367 42/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.465874350e+00 | Ndiff 4.413 43/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.464679105e+00 | Ndiff 4.447 44/1000 after 1 sec. | 217.2 MiB | K 20 | loss 3.463333221e+00 | Ndiff 4.465 45/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.461803293e+00 | Ndiff 4.461 46/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.460044727e+00 | Ndiff 4.426 47/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.457996608e+00 | Ndiff 4.349 48/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.455573178e+00 | Ndiff 4.210 49/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.452647608e+00 | Ndiff 3.983 50/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.449011332e+00 | Ndiff 3.629 51/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.444220039e+00 | Ndiff 3.091 52/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.436689553e+00 | Ndiff 2.295 53/1000 after 2 sec. 
| 217.2 MiB | K 20 | loss 3.416217624e+00 | Ndiff 1.184
    54/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.397676918e+00 | Ndiff 0.269
    55/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.395369454e+00 | Ndiff 0.002
    56/1000 after 2 sec. | 217.2 MiB | K 20 | loss 3.395369454e+00 | Ndiff 0.000
    ... done. converged.

.. GENERATED FROM PYTHON SOURCE LINES 132-138

*Gauss* observations + stochastic VB
------------------------------------

Assume full covariances, with cluster means learned from data.
Split the dataset into 50 batches and run stochastic VB (soVB) for 50 laps.

Start with too many clusters (K=20)

.. GENERATED FROM PYTHON SOURCE LINES 139-152

.. code-block:: default

    stoch_trained_model, stoch_info_dict = bnpy.run(
        dataset, 'DPMixtureModel', 'Gauss', 'soVB',
        output_path=\
            '/tmp/faithful/showcase-K=20-lik=Gauss-ECovMat=5*eye-alg=soVB/',
        nLap=50, nTask=1, nBatch=50,
        rhoexp=0.51, rhodelay=1.0,
        gamma0=gamma, sF=sF, ECovMat='eye',
        K=K, initname='randexamples',
        )
    show_clusters_over_time(stoch_info_dict['task_output_path'])

.. image-sg:: /examples/03_faithful/images/sphx_glr_plot-01-demo=vb_algs-model=mix_gauss_005.png
   :alt: lap: 0, lap: 1, lap: 2, lap: 10, lap: 20, lap: 50
   :srcset: /examples/03_faithful/images/sphx_glr_plot-01-demo=vb_algs-model=mix_gauss_005.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Dataset Summary:
    X Data
      total size: 272 units
      batch size: 6 units
      num. batches: 50
    Allocation Model: DP mixture with K=0. Concentration gamma0= 5.00
    Obs. Data Model: Gaussian with full covariance.
    Obs. Data Prior: Gauss-Wishart on mean and covar of each cluster
      E[ mean[k] ] =
      [0. 0.]
      E[ covar[k] ] =
      [[5. 0.]
       [0. 5.]]
    Initialization:
      initname = randexamples
      K = 20 (number of clusters)
      seed = 1607680
      elapsed_time: 0.0 sec
    Learn Alg: soVB | task 1/1 | alg. seed: 1607680 | data order seed: 8541952
    task_output_path: /tmp/faithful/showcase-K=20-lik=Gauss-ECovMat=5*eye-alg=soVB/1
    0.020/50 after 0 sec. | 217.2 MiB | K 20 | loss 3.020105590e+01 | lrate 0.7022
    0.040/50 after 0 sec.
| 217.2 MiB | K 20 | loss 1.693550049e+01 | lrate 0.5710 0.060/50 after 0 sec. | 217.2 MiB | K 20 | loss 1.182159020e+01 | lrate 0.4931 1.000/50 after 1 sec. | 217.2 MiB | K 20 | loss 1.314590658e+01 | lrate 0.1346 2.000/50 after 1 sec. | 217.2 MiB | K 20 | loss 2.522804913e+00 | lrate 0.0950 3.000/50 after 2 sec. | 217.2 MiB | K 20 | loss 2.507513303e+00 | lrate 0.0774 4.000/50 after 3 sec. | 217.2 MiB | K 20 | loss 2.498366509e+00 | lrate 0.0669 5.000/50 after 3 sec. | 217.2 MiB | K 20 | loss 2.475120548e+00 | lrate 0.0597 6.000/50 after 4 sec. | 217.2 MiB | K 20 | loss 2.469641662e+00 | lrate 0.0544 7.000/50 after 5 sec. | 217.2 MiB | K 20 | loss 2.470932762e+00 | lrate 0.0503 8.000/50 after 5 sec. | 217.2 MiB | K 20 | loss 2.466158362e+00 | lrate 0.0470 9.000/50 after 6 sec. | 217.2 MiB | K 20 | loss 2.465517619e+00 | lrate 0.0443 10.000/50 after 7 sec. | 217.2 MiB | K 20 | loss 2.465435382e+00 | lrate 0.0420 11.000/50 after 7 sec. | 217.2 MiB | K 20 | loss 2.467696036e+00 | lrate 0.0400 12.000/50 after 8 sec. | 217.2 MiB | K 20 | loss 2.461949856e+00 | lrate 0.0383 13.000/50 after 8 sec. | 217.2 MiB | K 20 | loss 2.461043110e+00 | lrate 0.0367 14.000/50 after 9 sec. | 217.2 MiB | K 20 | loss 2.461264382e+00 | lrate 0.0354 15.000/50 after 10 sec. | 217.2 MiB | K 20 | loss 2.458851451e+00 | lrate 0.0342 16.000/50 after 10 sec. | 217.2 MiB | K 20 | loss 2.454655643e+00 | lrate 0.0330 17.000/50 after 11 sec. | 217.2 MiB | K 20 | loss 2.451234188e+00 | lrate 0.0320 18.000/50 after 12 sec. | 217.2 MiB | K 20 | loss 2.446281746e+00 | lrate 0.0311 19.000/50 after 12 sec. | 217.2 MiB | K 20 | loss 2.440793273e+00 | lrate 0.0303 20.000/50 after 13 sec. | 217.2 MiB | K 20 | loss 2.434143198e+00 | lrate 0.0295 21.000/50 after 14 sec. | 217.2 MiB | K 20 | loss 2.423545849e+00 | lrate 0.0288 22.000/50 after 14 sec. | 217.2 MiB | K 20 | loss 2.414924985e+00 | lrate 0.0281 23.000/50 after 15 sec. | 217.2 MiB | K 20 | loss 2.405741175e+00 | lrate 0.0275 24.000/50 after 15 sec. 
| 217.2 MiB | K 20 | loss 2.396273286e+00 | lrate 0.0269 25.000/50 after 16 sec. | 217.2 MiB | K 20 | loss 2.388169784e+00 | lrate 0.0263 26.000/50 after 17 sec. | 217.2 MiB | K 20 | loss 2.383086462e+00 | lrate 0.0258 27.000/50 after 17 sec. | 217.2 MiB | K 20 | loss 2.378566216e+00 | lrate 0.0253 28.000/50 after 18 sec. | 217.2 MiB | K 20 | loss 2.374468977e+00 | lrate 0.0248 29.000/50 after 19 sec. | 217.2 MiB | K 20 | loss 2.369573743e+00 | lrate 0.0244 30.000/50 after 19 sec. | 217.2 MiB | K 20 | loss 2.365030091e+00 | lrate 0.0240 31.000/50 after 20 sec. | 217.2 MiB | K 20 | loss 2.361840492e+00 | lrate 0.0236 32.000/50 after 21 sec. | 217.2 MiB | K 20 | loss 2.358180297e+00 | lrate 0.0232 33.000/50 after 21 sec. | 217.2 MiB | K 20 | loss 2.355370718e+00 | lrate 0.0229 34.000/50 after 22 sec. | 217.2 MiB | K 20 | loss 2.354001069e+00 | lrate 0.0225 35.000/50 after 23 sec. | 217.2 MiB | K 20 | loss 2.352655966e+00 | lrate 0.0222 36.000/50 after 23 sec. | 217.2 MiB | K 20 | loss 2.351706756e+00 | lrate 0.0219 37.000/50 after 24 sec. | 217.2 MiB | K 20 | loss 2.351346242e+00 | lrate 0.0216 38.000/50 after 24 sec. | 217.2 MiB | K 20 | loss 2.351470554e+00 | lrate 0.0213 39.000/50 after 25 sec. | 217.2 MiB | K 20 | loss 2.351530696e+00 | lrate 0.0210 40.000/50 after 26 sec. | 217.2 MiB | K 20 | loss 2.351708300e+00 | lrate 0.0207 41.000/50 after 26 sec. | 217.2 MiB | K 20 | loss 2.351489520e+00 | lrate 0.0205 42.000/50 after 27 sec. | 217.2 MiB | K 20 | loss 2.351625569e+00 | lrate 0.0202 43.000/50 after 28 sec. | 217.2 MiB | K 20 | loss 2.351357552e+00 | lrate 0.0200 44.000/50 after 28 sec. | 217.2 MiB | K 20 | loss 2.351536177e+00 | lrate 0.0197 45.000/50 after 29 sec. | 217.2 MiB | K 20 | loss 2.351543540e+00 | lrate 0.0195 46.000/50 after 30 sec. | 217.2 MiB | K 20 | loss 2.351346627e+00 | lrate 0.0193 47.000/50 after 30 sec. | 217.2 MiB | K 20 | loss 2.351293738e+00 | lrate 0.0191 48.000/50 after 31 sec. 
| 217.2 MiB | K 20 | loss 2.351507196e+00 | lrate 0.0189 49.000/50 after 31 sec. | 217.2 MiB | K 20 | loss 2.351496294e+00 | lrate 0.0187 50.000/50 after 32 sec. | 217.2 MiB | K 20 | loss 2.351523756e+00 | lrate 0.0185 ... active. not converged. .. GENERATED FROM PYTHON SOURCE LINES 153-156 Compare loss function traces for all methods -------------------------------------------- .. GENERATED FROM PYTHON SOURCE LINES 157-183 .. code-block:: default pylab.figure() pylab.plot( zm_info_dict['lap_history'], zm_info_dict['loss_history'], 'b.-', label='full_covar zero_mean') pylab.plot( full_info_dict['lap_history'], full_info_dict['loss_history'], 'k.-', label='full_covar') pylab.plot( diag_info_dict['lap_history'], diag_info_dict['loss_history'], 'r.-', label='diag_covar') pylab.plot( stoch_info_dict['lap_history'], stoch_info_dict['loss_history'], 'c.:', label='full_covar stochastic') pylab.legend(loc='upper right') pylab.xlabel('num. laps') pylab.ylabel('loss') pylab.xlim([4, 100]) # avoid early iterations pylab.ylim([2.34, 4.0]) # handpicked pylab.draw() pylab.tight_layout() .. image-sg:: /examples/03_faithful/images/sphx_glr_plot-01-demo=vb_algs-model=mix_gauss_006.png :alt: plot 01 demo=vb algs model=mix gauss :srcset: /examples/03_faithful/images/sphx_glr_plot-01-demo=vb_algs-model=mix_gauss_006.png :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 184-186 Inspect the learned distribution over appearance probabilities -------------------------------------------------------------- .. GENERATED FROM PYTHON SOURCE LINES 187-196 .. code-block:: default # E_proba_K : 1D array, size n_clusters # Each entry gives expected probability of that cluster E_proba_K = stoch_trained_model.allocModel.get_active_comp_probs() print("probability of each cluster:") print(E_proba_K) .. rst-class:: sphx-glr-script-out .. 
code-block:: none probability of each cluster: [0.004 0.004 0.004 0.004 0.004 0.004 0.627 0.003 0.003 0.003 0.003 0.003 0.003 0.003 0.312 0.003 0.002 0.002 0.002 0.001] .. GENERATED FROM PYTHON SOURCE LINES 197-209 Inspect the learned means and covariance distributions ------------------------------------------------------ Recall that each cluster has the following approximate posterior over its mean vector :math:`\mu` and covariance matrix :math:`\Sigma`: .. math:: q(\mu, \Sigma) = \mathrm{Normal}(\mu | m, \kappa \Sigma) \mathrm{Wishart}(\Sigma | \nu, S) We show here how to compute the expected values of :math:`\mu` and :math:`\Sigma` under this posterior from a trained model. Clusters that received almost no data appear to stay at their prior values, which shows up as a zero mean and a covariance of 5 times the identity. .. GENERATED FROM PYTHON SOURCE LINES 210-220 .. code-block:: default for k in range(K): E_mu_k = stoch_trained_model.obsModel.get_mean_for_comp(k) E_Sigma_k = stoch_trained_model.obsModel.get_covar_mat_for_comp(k) print("") print("mean[k=%d]" % k) print(E_mu_k) print("covar[k=%d]" % k) print(E_Sigma_k) .. rst-class:: sphx-glr-script-out .. code-block:: none mean[k=0] [0. 0.] covar[k=0] [[5. 0.] [0. 5.]] mean[k=1] [0. 0.] covar[k=1] [[5. 0.] [0. 5.]] mean[k=2] [0. 0.] covar[k=2] [[5. 0.] [0. 5.]] mean[k=3] [0. 0.] covar[k=3] [[5. 0.] [0. 5.]] mean[k=4] [0. 0.] covar[k=4] [[5. 0.] [0. 5.]] mean[k=5] [0. 0.] covar[k=5] [[5. 0.] [0. 5.]] mean[k=6] [ 4.304 79.93 ] covar[k=6] [[ 0.194 0.92 ] [ 0.92 36.103]] mean[k=7] [0. 0.] covar[k=7] [[5. 0.] [0. 5.]] mean[k=8] [0. 0.] covar[k=8] [[5. 0.] [0. 5.]] mean[k=9] [0. 0.] covar[k=9] [[5. 0.] [0. 5.]] mean[k=10] [0. 0.] covar[k=10] [[5. 0.] [0. 5.]] mean[k=11] [0. 0.] covar[k=11] [[5. 0.] [0. 5.]] mean[k=12] [0. 0.] covar[k=12] [[5. 0.] [0. 5.]] mean[k=13] [0. 0.] covar[k=13] [[5. 0.] [0. 5.]] mean[k=14] [ 2.034 54.415] covar[k=14] [[ 0.12 0.453] [ 0.453 33.606]] mean[k=15] [0. 0.] covar[k=15] [[5. 0.] [0. 5.]] mean[k=16] [0. 0.] covar[k=16] [[5. 0.] [0. 5.]] mean[k=17] [0. 0.] covar[k=17] [[5. 0.] [0. 5.]] mean[k=18] [0. 0.] covar[k=18] [[5. 0.] [0. 5.]] mean[k=19] [0. 0.] covar[k=19] [[5. 0.] [0. 5.]] .. 
rst-class:: sphx-glr-timing **Total running time of the script:** ( 0 minutes 42.005 seconds) .. _sphx_glr_download_examples_03_faithful_plot-01-demo=vb_algs-model=mix_gauss.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: plot-01-demo=vb_algs-model=mix_gauss.py ` .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: plot-01-demo=vb_algs-model=mix_gauss.ipynb ` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_