GMM_gen.ipynb

by tirthajyoti

notebooks/GMM_gen.ipynb

Synthesizing a Gaussian Mixture Model (GMM) dataset

Dr. Tirthajyoti Sarkar, Fremont, CA 94536

A Gaussian mixture model (GMM) is a probabilistic model for representing normally distributed subpopulations within an overall population. Mixture models in general do not require knowing which subpopulation a data point belongs to, allowing the model to learn the subpopulations automatically. Since subpopulation assignment is not known, this constitutes a form of unsupervised learning. This notebook works in the opposite direction: it synthesizes data from a mixture whose components are known.

import numpy as np
import matplotlib.pyplot as plt
import random
import seaborn as sns

The generating function

def gen_GMM(N=1000, n_comp=3, mu=(-1, 0, 1), sigma=(1, 1, 1), mult=None):
    """
    Generate samples from a one-dimensional Gaussian mixture.

    For every sample, one of the ``n_comp`` components is picked with equal
    probability, a value is drawn from N(mu[j], sigma[j]) with
    ``random.gauss`` and that value is multiplied by ``mult[j]``.  Because the
    multiplier scales the draw itself, component ``j`` effectively has mean
    ``mult[j] * mu[j]`` and standard deviation ``abs(mult[j]) * sigma[j]``;
    it is not a mixture weight.

    N: Number of total samples (data points)
    n_comp: Number of Gaussian components
    mu: Sequence of mean values of the Gaussian components
    sigma: Sequence of sigma (std. dev) values of the Gaussian components
    mult: (Optional) sequence of multipliers for the Gaussian components;
        when omitted every component uses a multiplier of 1.

    Returns a 1-D numpy array with N samples.

    Raises AssertionError if ``mu``, ``sigma`` or ``mult`` does not contain
    exactly ``n_comp`` entries.
    """
    if mult is None:
        mult = [1] * n_comp
    # Explicit raises instead of ``assert`` statements so the checks are not
    # stripped under ``python -O``; AssertionError is kept for callers.
    if n_comp != len(mu):
        raise AssertionError("The length of the list of mean values does not match number of Gaussian components")
    if n_comp != len(sigma):
        raise AssertionError("The length of the list of sigma values does not match number of Gaussian components")
    if n_comp != len(mult):
        raise AssertionError("The length of the list of multiplier values does not match number of Gaussian components")
    rand_samples = []
    for _ in range(N):
        # random.uniform(a, b) may return b itself due to floating-point
        # rounding; clamp so the component index never reaches n_comp.
        j = min(int(random.uniform(0, n_comp)), n_comp - 1)
        rand_samples.append(mult[j] * random.gauss(mu[j], sigma[j]))

    return np.array(rand_samples)

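Testing the function's input checks: each call below passes a parameter list whose length does not match `n_comp`, so the function raises an AssertionError.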

# Deliberately inconsistent: n_comp=4 but only three means are supplied,
# so the mean-length check raises AssertionError.
gen_GMM(
    N=10,
    n_comp=4,
    mu=[1, 2, 0],
    sigma=[1, 1, 1, 2],
)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-7-812a110f17f3> in <module>
----> 1 gen_GMM(N=10,n_comp=4,mu=[1,2,0],sigma=[1,1,1,2])

<ipython-input-6-13828bf507a2> in gen_GMM(N, n_comp, mu, sigma, mult)
      8     mult: (Optional) list of multiplier for the Gaussian components
      9     """
---> 10     assert n_comp == len(mu), "The length of the list of mean values does not match number of Gaussian components"
     11     assert n_comp == len(sigma), "The length of the list of sigma values does not match number of Gaussian components"
     12     assert n_comp == len(mult), "The length of the list of multiplier values does not match number of Gaussian components"

AssertionError: The length of the list of mean values does not match number of Gaussian components
# Deliberately inconsistent: four means but only three sigmas, so the
# sigma-length check raises AssertionError.
gen_GMM(
    N=10,
    n_comp=4,
    mu=[1, 2, 0, -1],
    sigma=[1, 1, 2],
)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-8-f59c57e7c6fc> in <module>
----> 1 gen_GMM(N=10,n_comp=4,mu=[1,2,0,-1],sigma=[1,1,2])

<ipython-input-6-13828bf507a2> in gen_GMM(N, n_comp, mu, sigma, mult)
      9     """
     10     assert n_comp == len(mu), "The length of the list of mean values does not match number of Gaussian components"
---> 11     assert n_comp == len(sigma), "The length of the list of sigma values does not match number of Gaussian components"
     12     assert n_comp == len(mult), "The length of the list of multiplier values does not match number of Gaussian components"
     13     rand_samples = []

AssertionError: The length of the list of sigma values does not match number of Gaussian components

Data and plot examples

def _plot_hist_kde(data):
    """Show a density-normalised histogram of *data* with a KDE curve on top."""
    # sns.distplot is deprecated (seaborn >= 0.11); reproduce its look with
    # plt.hist + sns.kdeplot. alpha=0.4 mirrors distplot's default bar alpha.
    plt.hist(data, bins=50, density=True, color='blue', edgecolor='k', alpha=0.4)
    sns.kdeplot(data, lw=3, color='k')
    plt.show()


# Keyword overrides for gen_GMM, one entry per figure; parameters that are
# not listed fall back to gen_GMM's defaults.
gmm_examples = [
    {'mu': [-6, 0, 6]},
    {'mu': [-3, 0, 3]},
    {'mu': [-5, 0, 5], 'sigma': [1, 2, 1.5]},
    {'mu': [-5, 0, 5], 'sigma': [1.8, 0.3, 1.1], 'mult': [0.7, 1.8, 1.1]},
]
for params in gmm_examples:
    data = gen_GMM(N=10000, **params)
    _plot_hist_kde(data)