1+ """A Gaussian mixture model class"""
12import numpy as np
2- from numpy .testing import assert_allclose
3+
4+ from numpy_ml .utils .misc import logsumexp , log_gaussian_pdf
35
46
57class GMM (object ):
@@ -31,48 +33,52 @@ def __init__(self, C=3, seed=None):
         sigma : :py:class:`ndarray <numpy.ndarray>` of shape `(C, d, d)`
             The cluster covariance matrices.
         """
-        self.C = C  # number of clusters
-        self.N = None  # number of objects
-        self.d = None  # dimension of each object
-        self.seed = seed
+        self.elbo = None
+        self.parameters = {}
+        self.hyperparameters = {
+            "C": C,
+            "seed": seed,
+        }
 
-        if self.seed:
-            np.random.seed(self.seed)
+        self.is_fit = False
+
+        if seed:
+            np.random.seed(seed)
+
+    def _initialize_params(self, X):
+        """Randomly initialize the starting GMM parameters."""
+        N, d = X.shape
+        C = self.hyperparameters["C"]
 
-    def _initialize_params(self):
-        """
-        Randomly initialize the starting GMM parameters.
-        """
-        C, d = self.C, self.d
         rr = np.random.rand(C)
 
-        self.pi = rr / rr.sum()  # cluster priors
-        self.Q = np.zeros((self.N, C))  # variational distribution q(T)
-        self.mu = np.random.uniform(-5, 10, C * d).reshape(C, d)  # cluster means
-        self.sigma = np.array([np.identity(d) for _ in range(C)])  # cluster covariances
+        self.parameters = {
+            "pi": rr / rr.sum(),  # cluster priors
+            "Q": np.zeros((N, C)),  # variational distribution q(T)
+            "mu": np.random.uniform(-5, 10, C * d).reshape(C, d),  # cluster means
+            "sigma": np.array([np.eye(d) for _ in range(C)]),  # cluster covariances
+        }
 
-        self.best_pi = None
-        self.best_mu = None
-        self.best_sigma = None
-        self.best_elbo = -np.inf
+        self.elbo = None
+        self.is_fit = False
 
-    def likelihood_lower_bound(self):
-        """
-        Compute the LLB under the current GMM parameters.
-        """
-        N = self.N
-        C = self.C
+    def likelihood_lower_bound(self, X):
+        """Compute the LLB under the current GMM parameters."""
+        N = X.shape[0]
+        P = self.parameters
+        C = self.hyperparameters["C"]
+        pi, Q, mu, sigma = P["pi"], P["Q"], P["mu"], P["sigma"]
 
         eps = np.finfo(float).eps
         expec1, expec2 = 0.0, 0.0
         for i in range(N):
-            x_i = self.X[i]
+            x_i = X[i]
 
             for c in range(C):
-                pi_k = self.pi[c]
-                z_nk = self.Q[i, c]
-                mu_k = self.mu[c, :]
-                sigma_k = self.sigma[c, :, :]
+                pi_k = pi[c]
+                z_nk = Q[i, c]
+                mu_k = mu[c, :]
+                sigma_k = sigma[c, :, :]
 
                 log_pi_k = np.log(pi_k + eps)
                 log_p_x_i = log_gaussian_pdf(x_i, mu_k, sigma_k)
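For reference, the quantity this method accumulates is the standard evidence lower bound for a mixture model with variational distribution $q(T)$ over the latent cluster assignments (the rows of `Q`). Writing $q_{ic} = q(t_i = c)$, the bound is

$$
\mathcal{L}(q, \pi, \mu, \Sigma)
  = \sum_{i=1}^{N} \sum_{c=1}^{C} q_{ic}\,\bigl[\log \pi_c + \log \mathcal{N}(x_i \mid \mu_c, \Sigma_c)\bigr]
  \;-\; \sum_{i=1}^{N} \sum_{c=1}^{C} q_{ic} \log q_{ic},
$$

with `expec1` and `expec2` presumably accumulating the first and second double sums; the remainder of the method falls outside this hunk, so that pairing is inferred from the variable names.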
@@ -110,49 +116,83 @@ def fit(self, X, max_iter=100, tol=1e-3, verbose=False):
             mixture components collapsed and training was halted prematurely
             (-1).
         """
-        self.X = X
-        self.N = X.shape[0]  # number of objects
-        self.d = X.shape[1]  # dimension of each object
-
-        self._initialize_params()
         prev_vlb = -np.inf
+        self._initialize_params(X)
 
         for _iter in range(max_iter):
             try:
-                self._E_step()
-                self._M_step()
-                vlb = self.likelihood_lower_bound()
+                self._E_step(X)
+                self._M_step(X)
+                vlb = self.likelihood_lower_bound(X)
 
                 if verbose:
-                    print("{}. Lower bound: {}".format(_iter + 1, vlb))
+                    print(f"{_iter + 1}. Lower bound: {vlb}")
 
                 converged = _iter > 0 and np.abs(vlb - prev_vlb) <= tol
                 if np.isnan(vlb) or converged:
                     break
 
                 prev_vlb = vlb
 
-                # retain best parameters across fits
-                if vlb > self.best_elbo:
-                    self.best_elbo = vlb
-                    self.best_mu = self.mu
-                    self.best_pi = self.pi
-                    self.best_sigma = self.sigma
-
             except np.linalg.LinAlgError:
                 print("Singular matrix: components collapsed")
                 return -1
+
+        self.elbo = vlb
+        self.is_fit = True
         return 0
 
-    def _E_step(self):
-        for i in range(self.N):
-            x_i = self.X[i, :]
+    def predict(self, X, soft_labels=True):
+        """
+        Return the log probability of each data point in `X` under each
+        mixture component.
+
+        Parameters
+        ----------
+        X : :py:class:`ndarray <numpy.ndarray>` of shape `(M, d)`
+            A collection of `M` data points, each with dimension `d`.
+        soft_labels : bool
+            If True, return the log probabilities of the `M` data points in `X`
+            under each mixture component. If False, return only the ID of the
+            most probable mixture component. Default is True.
+
+        Returns
+        -------
+        y : :py:class:`ndarray <numpy.ndarray>` of shape `(M, C)` or `(M,)`
+            If `soft_labels` is True, `y` is a 2D array where index (i, j)
+            gives the log probability of the `i`th data point under the `j`th
+            mixture component. If `soft_labels` is False, `y` is a 1D array
+            where the `i`th index contains the ID of the most probable mixture
+            component.
+        """
+        assert self.is_fit, "Must call the `.fit` method before making predictions"
+
+        P = self.parameters
+        C = self.hyperparameters["C"]
+        mu, sigma = P["mu"], P["sigma"]
+
+        y = []
+        for x_i in X:
+            cprobs = [log_gaussian_pdf(x_i, mu[c, :], sigma[c, :, :]) for c in range(C)]
 
+            if not soft_labels:
+                y.append(np.argmax(cprobs))
+            else:
+                y.append(cprobs)
+
+        return np.array(y)
+
+    def _E_step(self, X):
+        P = self.parameters
+        C = self.hyperparameters["C"]
+        pi, Q, mu, sigma = P["pi"], P["Q"], P["mu"], P["sigma"]
+
+        for i, x_i in enumerate(X):
             denom_vals = []
-            for c in range(self.C):
-                pi_c = self.pi[c]
-                mu_c = self.mu[c, :]
-                sigma_c = self.sigma[c, :, :]
+            for c in range(C):
+                pi_c = pi[c]
+                mu_c = mu[c, :]
+                sigma_c = sigma[c, :, :]
 
                 log_pi_c = np.log(pi_c)
                 log_p_x_i = log_gaussian_pdf(x_i, mu_c, sigma_c)
@@ -163,63 +203,38 @@ def _E_step(self):
             # log \sum_c exp{ log N(X_i | mu_c, Sigma_c) + log pi_c } ]
             log_denom = logsumexp(denom_vals)
             q_i = np.exp([num - log_denom for num in denom_vals])
-            assert_allclose(np.sum(q_i), 1, err_msg="{}".format(np.sum(q_i)))
+            np.testing.assert_allclose(np.sum(q_i), 1, err_msg="{}".format(np.sum(q_i)))
 
-            self.Q[i, :] = q_i
+            Q[i, :] = q_i
 
-    def _M_step(self):
-        C, N, X = self.C, self.N, self.X
-        denoms = np.sum(self.Q, axis=0)
+    def _M_step(self, X):
+        N, d = X.shape
+        P = self.parameters
+        C = self.hyperparameters["C"]
+        pi, Q, mu, sigma = P["pi"], P["Q"], P["mu"], P["sigma"]
+
+        denoms = np.sum(Q, axis=0)
 
         # update cluster priors
-        self.pi = denoms / N
+        pi = P["pi"] = denoms / N  # write back so the E-step and lower bound use the updated priors
 
         # update cluster means
-        nums_mu = [np.dot(self.Q[:, c], X) for c in range(C)]
+        nums_mu = [np.dot(Q[:, c], X) for c in range(C)]
         for ix, (num, den) in enumerate(zip(nums_mu, denoms)):
-            self.mu[ix, :] = num / den if den > 0 else np.zeros_like(num)
+            mu[ix, :] = num / den if den > 0 else np.zeros_like(num)
 
         # update cluster covariances
         for c in range(C):
-            mu_c = self.mu[c, :]
+            mu_c = mu[c, :]
             n_c = denoms[c]
 
-            outer = np.zeros((self.d, self.d))
+            outer = np.zeros((d, d))
             for i in range(N):
-                wic = self.Q[i, c]
-                xi = self.X[i, :]
+                wic = Q[i, c]
+                xi = X[i, :]
                 outer += wic * np.outer(xi - mu_c, xi - mu_c)
 
             outer = outer / n_c if n_c > 0 else outer
-            self.sigma[c, :, :] = outer
-
-        assert_allclose(np.sum(self.pi), 1, err_msg="{}".format(np.sum(self.pi)))
-
-
-#######################################################################
-#                                Utils                                #
-#######################################################################
-
-
-def log_gaussian_pdf(x_i, mu, sigma):
-    """
-    Compute log N(x_i | mu, sigma)
-    """
-    n = len(mu)
-    a = n * np.log(2 * np.pi)
-    _, b = np.linalg.slogdet(sigma)
-
-    y = np.linalg.solve(sigma, x_i - mu)
-    c = np.dot(x_i - mu, y)
-    return -0.5 * (a + b + c)
-
+            sigma[c, :, :] = outer
 
-def logsumexp(log_probs, axis=None):
-    """
-    Redefine scipy.special.logsumexp
-    see: http://bayesjumping.net/log-sum-exp-trick/
-    """
-    _max = np.max(log_probs)
-    ds = log_probs - _max
-    exp_sum = np.exp(ds).sum(axis=axis)
-    return _max + np.log(exp_sum)
+        np.testing.assert_allclose(np.sum(pi), 1, err_msg="{}".format(np.sum(pi)))
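For completeness, a minimal usage sketch of the refactored API; the import path `numpy_ml.gmm.GMM` and the synthetic data are assumptions for illustration, not part of the change above.

import numpy as np

from numpy_ml.gmm import GMM  # assumed import path; adjust to the actual package layout

# two well-separated 2D Gaussian blobs
X = np.vstack([
    np.random.randn(100, 2) + 5.0,
    np.random.randn(100, 2) - 5.0,
])

gmm = GMM(C=2, seed=12345)
gmm.fit(X, max_iter=100, tol=1e-3, verbose=True)  # returns 0 on success, -1 if a component collapses

hard = gmm.predict(X, soft_labels=False)  # shape (200,): ID of the most probable component per point
soft = gmm.predict(X, soft_labels=True)   # shape (200, 2): per-component log densities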