EMアルゴリズムで混合ガウス分布
numpyのufuncの説明を読んで,何かちょろっと書きたい気分になったので書いた.相変わらずの生産性0.
実行すると適当に乱数でデータを生成し,混合ガウス分布のパラメータを推定した後,負担率に応じて色分け.
ufuncのルールをある程度理解したので,3次以上のarrayの扱いがマシになり,ループが減ったものの,弊害として一行に170文字以上書くというPEP8(笑)なコードが出来あがった*1.だからどうしたんだという気もするが.
#!/usr/bin/python # -*- coding: utf-8 -*- # PRML chapter 9 # Gaussian Mixture Model import scipy as sp from scipy.linalg import det, inv def multivariate_normal_pdf(x, u, sigma): D = len(x) x, u = sp.asarray(x), sp.asarray(u) y = x-u return sp.exp(-(sp.dot(y, sp.dot(inv(sigma), y)))/2.0) / (((2*sp.pi)**(D/2.0)) * (det(sigma) ** 0.5)) def gmm(X, K, iter=1000, tol=1e-6): """ Gaussian Mixture Model Arguments: - `X`: Input data (2D array, [[x11, x12, ..., x1D], ..., [xN1, ... xND]]). - `K`: Number of clusters. - `iter`: Number of iterations to run. - `tol`: Tolerance. """ X = sp.asarray(X) N, D = X.shape pi = sp.ones(K) * 1.0/K mu = sp.rand(K, D) sigma = sp.array([sp.eye(D) for i in xrange(K)]) L = sp.inf for i in xrange(iter): # E-step gamma = sp.apply_along_axis(lambda x: sp.fromiter((pi[k] * multivariate_normal_pdf(x, mu[k], sigma[k]) for k in xrange(K)), dtype=float), 1, X) gamma /= sp.sum(gamma, 1)[:, sp.newaxis] # M-step Nk = sp.sum(gamma, 0) mu = sp.sum(X*gamma.T[..., sp.newaxis], 1) / Nk[..., sp.newaxis] xmu = X[:, sp.newaxis, :] - mu sigma = sp.sum(gamma[..., sp.newaxis, sp.newaxis] * xmu[:, :, sp.newaxis, :] * xmu[:, :, :, sp.newaxis], 0) / Nk[..., sp.newaxis, sp.newaxis] pi = Nk / N # Likelihood Lnew = sp.sum(sp.log2(sp.sum(sp.apply_along_axis(lambda x: sp.fromiter((pi[k] * multivariate_normal_pdf(x, mu[k], sigma[k]) for k in xrange(K)), dtype=float), 1, X), 1))) if abs(L-Lnew) < tol: break L = Lnew print "L=%s" % L return dict(pi=pi, mu=mu, sigma=sigma, gamma=gamma) if __name__ == '__main__': data = sp.append(sp.random.multivariate_normal([-3.5, 5.0], sp.eye(2)*4, 50), sp.random.multivariate_normal([-8.2, 10.0], sp.eye(2)*2, 70)).reshape(50+70, 2) K = 2 d = gmm(data, K) print "π=%s\nμ=%s\nΣ=%s" % (d['pi'], d['mu'], d['sigma']) gamma = d['gamma'] import matplotlib.pyplot as plt plt.scatter(data[:, 0][gamma[:, 0] >= 0.5], data[:, 1][gamma[:, 0] >= 0.5], color='r') plt.scatter(data[:, 0][gamma[:, 1] > 0.5 ], data[:, 1][gamma[:, 1] > 0.5 ], color='g') plt.show()
"Rで学ぶクラスタ解析"を今読んでおり,それに習ってirisとかのデータを試してエントロピーとか適当に指標とってみようかと思ったけど,やる気が無かったのでまた後日 ;-p
*1:まぁ適当に括弧の中で改行すればいいのだが