
Commit 36bba1f

refactor gmm to have more consistent API with other models
1 parent 56fb4a6 commit 36bba1f

File tree

1 file changed (+112, -97 lines)


numpy_ml/gmm/gmm.py

Lines changed: 112 additions & 97 deletions
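The refactor replaces the scattered `self.*` attributes with `parameters` / `hyperparameters` dicts and moves the data `X` out of instance state and into the method signatures, matching the other numpy_ml models. A minimal usage sketch of the resulting API, inferred from the signatures in the diff below (the import path and the toy data are assumptions, not part of the commit):

import numpy as np

# Assumed import path; the class itself lives in numpy_ml/gmm/gmm.py.
from numpy_ml.gmm import GMM

X = np.random.RandomState(0).randn(100, 2)  # illustrative 2-D data, not from the commit

gmm = GMM(C=3, seed=12345)
flag = gmm.fit(X, max_iter=100, tol=1e-3, verbose=False)  # returns 0 on success, -1 if components collapse

if flag == 0:
    # learned parameters are now kept in a dict rather than as individual attributes
    pi, mu, sigma = (gmm.parameters[k] for k in ("pi", "mu", "sigma"))
    ids = gmm.predict(X, soft_labels=False)        # (N,) most-probable component IDs
    log_probs = gmm.predict(X, soft_labels=True)   # (N, C) per-component log densities

The diff itself follows.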
@@ -1,5 +1,7 @@
+"""A Gaussian mixture model class"""
 import numpy as np
-from numpy.testing import assert_allclose
+
+from numpy_ml.utils.misc import logsumexp, log_gaussian_pdf


 class GMM(object):
@@ -31,48 +33,52 @@ def __init__(self, C=3, seed=None):
         sigma : :py:class:`ndarray <numpy.ndarray>` of shape `(C, d, d)`
             The cluster covariance matrices.
         """
-        self.C = C  # number of clusters
-        self.N = None  # number of objects
-        self.d = None  # dimension of each object
-        self.seed = seed
+        self.elbo = None
+        self.parameters = {}
+        self.hyperparameters = {
+            "C": C,
+            "seed": seed,
+        }

-        if self.seed:
-            np.random.seed(self.seed)
+        self.is_fit = False
+
+        if seed:
+            np.random.seed(seed)
+
+    def _initialize_params(self, X):
+        """Randomly initialize the starting GMM parameters."""
+        N, d = X.shape
+        C = self.hyperparameters["C"]

-    def _initialize_params(self):
-        """
-        Randomly initialize the starting GMM parameters.
-        """
-        C, d = self.C, self.d
         rr = np.random.rand(C)

-        self.pi = rr / rr.sum()  # cluster priors
-        self.Q = np.zeros((self.N, C))  # variational distribution q(T)
-        self.mu = np.random.uniform(-5, 10, C * d).reshape(C, d)  # cluster means
-        self.sigma = np.array([np.identity(d) for _ in range(C)])  # cluster covariances
+        self.parameters = {
+            "pi": rr / rr.sum(),  # cluster priors
+            "Q": np.zeros((N, C)),  # variational distribution q(T)
+            "mu": np.random.uniform(-5, 10, C * d).reshape(C, d),  # cluster means
+            "sigma": np.array([np.eye(d) for _ in range(C)]),  # cluster covariances
+        }

-        self.best_pi = None
-        self.best_mu = None
-        self.best_sigma = None
-        self.best_elbo = -np.inf
+        self.elbo = None
+        self.is_fit = False

-    def likelihood_lower_bound(self):
-        """
-        Compute the LLB under the current GMM parameters.
-        """
-        N = self.N
-        C = self.C
+    def likelihood_lower_bound(self, X):
+        """Compute the LLB under the current GMM parameters."""
+        N = X.shape[0]
+        P = self.parameters
+        C = self.hyperparameters["C"]
+        pi, Q, mu, sigma = P["pi"], P["Q"], P["mu"], P["sigma"]

         eps = np.finfo(float).eps
         expec1, expec2 = 0.0, 0.0
         for i in range(N):
-            x_i = self.X[i]
+            x_i = X[i]

             for c in range(C):
-                pi_k = self.pi[c]
-                z_nk = self.Q[i, c]
-                mu_k = self.mu[c, :]
-                sigma_k = self.sigma[c, :, :]
+                pi_k = pi[c]
+                z_nk = Q[i, c]
+                mu_k = mu[c, :]
+                sigma_k = sigma[c, :, :]

                 log_pi_k = np.log(pi_k + eps)
                 log_p_x_i = log_gaussian_pdf(x_i, mu_k, sigma_k)
@@ -110,49 +116,83 @@ def fit(self, X, max_iter=100, tol=1e-3, verbose=False):
             mixture components collapsed and training was halted prematurely
             (-1).
         """
-        self.X = X
-        self.N = X.shape[0]  # number of objects
-        self.d = X.shape[1]  # dimension of each object
-
-        self._initialize_params()
         prev_vlb = -np.inf
+        self._initialize_params(X)

         for _iter in range(max_iter):
             try:
-                self._E_step()
-                self._M_step()
-                vlb = self.likelihood_lower_bound()
+                self._E_step(X)
+                self._M_step(X)
+                vlb = self.likelihood_lower_bound(X)

                 if verbose:
-                    print("{}. Lower bound: {}".format(_iter + 1, vlb))
+                    print(f"{_iter + 1}. Lower bound: {vlb}")

                 converged = _iter > 0 and np.abs(vlb - prev_vlb) <= tol
                 if np.isnan(vlb) or converged:
                     break

                 prev_vlb = vlb

-                # retain best parameters across fits
-                if vlb > self.best_elbo:
-                    self.best_elbo = vlb
-                    self.best_mu = self.mu
-                    self.best_pi = self.pi
-                    self.best_sigma = self.sigma
-
             except np.linalg.LinAlgError:
                 print("Singular matrix: components collapsed")
                 return -1
+
+        self.elbo = vlb
+        self.is_fit = True
         return 0

-    def _E_step(self):
-        for i in range(self.N):
-            x_i = self.X[i, :]
+    def predict(self, X, soft_labels=True):
+        """
+        Return the log probability of each data point in `X` under each
+        mixture components.
+
+        Parameters
+        ----------
+        X : :py:class:`ndarray <numpy.ndarray>` of shape `(M, d)`
+            A collection of `M` data points, each with dimension `d`.
+        soft_labels : bool
+            If True, return the log probabilities of the M data points in X
+            under each mixture component. If False, return only the ID of the
+            most probable mixture. Default is True.
+
+        Returns
+        -------
+        y : :py:class:`ndarray <numpy.ndarray>` of shape `(M, C)` or `(M,)`
+            If `soft_labels` is True, `y` is a 2D array where index (i,j) gives
+            the log probability of the `i` th data point under the `j` th
+            mixture component. If `soft_labels` is False, `y` is a 1D array
+            where the `i` th index contains the ID of the most probable mixture
+            component.
+        """
+        assert self.is_fit, "Must call the `.fit` method before making predictions"
+
+        P = self.parameters
+        C = self.hyperparameters["C"]
+        mu, sigma = P["mu"], P["sigma"]
+
+        y = []
+        for x_i in X:
+            cprobs = [log_gaussian_pdf(x_i, mu[c, :], sigma[c, :, :]) for c in range(C)]

+            if not soft_labels:
+                y.append(np.argmax(cprobs))
+            else:
+                y.append(cprobs)
+
+        return np.array(y)
+
+    def _E_step(self, X):
+        P = self.parameters
+        C = self.hyperparameters["C"]
+        pi, Q, mu, sigma = P["pi"], P["Q"], P["mu"], P["sigma"]
+
+        for i, x_i in enumerate(X):
             denom_vals = []
-            for c in range(self.C):
-                pi_c = self.pi[c]
-                mu_c = self.mu[c, :]
-                sigma_c = self.sigma[c, :, :]
+            for c in range(C):
+                pi_c = pi[c]
+                mu_c = mu[c, :]
+                sigma_c = sigma[c, :, :]

                 log_pi_c = np.log(pi_c)
                 log_p_x_i = log_gaussian_pdf(x_i, mu_c, sigma_c)
@@ -163,63 +203,38 @@ def _E_step(self):
                 # log \sum_c exp{ log N(X_i | mu_c, Sigma_c) + log pi_c } ]
             log_denom = logsumexp(denom_vals)
             q_i = np.exp([num - log_denom for num in denom_vals])
-            assert_allclose(np.sum(q_i), 1, err_msg="{}".format(np.sum(q_i)))
+            np.testing.assert_allclose(np.sum(q_i), 1, err_msg="{}".format(np.sum(q_i)))

-            self.Q[i, :] = q_i
+            Q[i, :] = q_i

-    def _M_step(self):
-        C, N, X = self.C, self.N, self.X
-        denoms = np.sum(self.Q, axis=0)
+    def _M_step(self, X):
+        N, d = X.shape
+        P = self.parameters
+        C = self.hyperparameters["C"]
+        pi, Q, mu, sigma = P["pi"], P["Q"], P["mu"], P["sigma"]
+
+        denoms = np.sum(Q, axis=0)

         # update cluster priors
-        self.pi = denoms / N
+        pi = denoms / N

         # update cluster means
-        nums_mu = [np.dot(self.Q[:, c], X) for c in range(C)]
+        nums_mu = [np.dot(Q[:, c], X) for c in range(C)]
         for ix, (num, den) in enumerate(zip(nums_mu, denoms)):
-            self.mu[ix, :] = num / den if den > 0 else np.zeros_like(num)
+            mu[ix, :] = num / den if den > 0 else np.zeros_like(num)

         # update cluster covariances
         for c in range(C):
-            mu_c = self.mu[c, :]
+            mu_c = mu[c, :]
             n_c = denoms[c]

-            outer = np.zeros((self.d, self.d))
+            outer = np.zeros((d, d))
             for i in range(N):
-                wic = self.Q[i, c]
-                xi = self.X[i, :]
+                wic = Q[i, c]
+                xi = X[i, :]
                 outer += wic * np.outer(xi - mu_c, xi - mu_c)

             outer = outer / n_c if n_c > 0 else outer
-            self.sigma[c, :, :] = outer
-
-        assert_allclose(np.sum(self.pi), 1, err_msg="{}".format(np.sum(self.pi)))
-
-
-#######################################################################
-#                                  Utils                              #
-#######################################################################
-
-
-def log_gaussian_pdf(x_i, mu, sigma):
-    """
-    Compute log N(x_i | mu, sigma)
-    """
-    n = len(mu)
-    a = n * np.log(2 * np.pi)
-    _, b = np.linalg.slogdet(sigma)
-
-    y = np.linalg.solve(sigma, x_i - mu)
-    c = np.dot(x_i - mu, y)
-    return -0.5 * (a + b + c)
-
+            sigma[c, :, :] = outer

-def logsumexp(log_probs, axis=None):
-    """
-    Redefine scipy.special.logsumexp
-    see: http://bayesjumping.net/log-sum-exp-trick/
-    """
-    _max = np.max(log_probs)
-    ds = log_probs - _max
-    exp_sum = np.exp(ds).sum(axis=axis)
-    return _max + np.log(exp_sum)
+        np.testing.assert_allclose(np.sum(pi), 1, err_msg="{}".format(np.sum(pi)))
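
The two module-level helpers deleted at the bottom of the file are not removed from the project: the new import at the top pulls `logsumexp` and `log_gaussian_pdf` from `numpy_ml.utils.misc` instead. For reference, a small standalone sketch of the log-sum-exp normalization the E-step relies on (the numeric values are illustrative only, not from the commit):

import numpy as np

def logsumexp(log_probs, axis=None):
    """Numerically stable log(sum(exp(log_probs))) via the max-shift trick."""
    _max = np.max(log_probs)
    return _max + np.log(np.exp(log_probs - _max).sum(axis=axis))

# E-step responsibility for one data point: q_ic is proportional to
# pi_c * N(x_i | mu_c, sigma_c), computed in log space so that very small
# densities do not underflow to zero.
log_joint = np.array([-1050.0, -1047.0, -1052.0])   # log pi_c + log N(x_i | mu_c, sigma_c)
q_i = np.exp(log_joint - logsumexp(log_joint))       # normalized responsibilities
assert np.isclose(q_i.sum(), 1.0)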
