変分混合ガウス分布(PRML 10.2)
入力にはOldFaithful間欠泉データ集合を用いた.
データは標準入力から読み込むようにした.
図はK=6の場合. 混合係数の期待値は2つ以外ほぼ0になったので満足.
初期値m=0だとおかしくなるのは, パラメータの初期値がそれぞれ等しい要素を持つようにすると, 負担率が均等になってそこで収束するからかな.
なおソース内ではαがaになっていたりΛがAになっていたりνがvになっていたりρがpになっていたりするので, 読む場合は*1心の目で適宜sedして下さい.
#!/usr/bin/python # -*- coding: utf-8 -*- import scipy as sp from scipy.linalg import det, inv from scipy.special.basic import digamma def vbgmm(X, K, a_0=1e-2, b_0=1e-2, v_0=10.0, mtol=1e-3, iter=1e4): N, D = X.shape a = a_0 * sp.ones(K) + N/K # α b = b_0 * sp.ones(K) + N/K m = sp.arange(1, K*D+1).reshape(K, D) * 00.1 W = [sp.eye(D) for i in xrange(K)] v = v_0 * sp.ones(K) + N/K # ν Epi = a / (K*a_0 + N) i = 0 xk = sp.zeros((K, D)) Sk = sp.zeros((K, D, D)) E_uA = sp.zeros((K, N)) while i < iter: # e-step E_uA.resize((K, N)) for k in xrange(K): E_uA[k] = D/b[k] + v[k]*sp.diag(sp.dot(X-m[k], sp.dot(W[k], (X-m[k]).T))) E_uA = E_uA.T lnAk = D * sp.log(2.0) + [digamma((v[k]+1-sp.arange(1, D+1)) / 2.0).sum() + sp.log(det(W[k])) for k in xrange(K)] lnpi = digamma(a) - digamma(a.sum()) p = sp.exp(lnpi + lnAk/2.0 - D*sp.log(2*sp.pi)/2.0 - E_uA/2.0) r = (p.T/p.sum(axis=1)).T # m-step Nk = r.sum(axis=0) for k in xrange(K): xk[k] = (X.T * r[:, k]).sum(axis=1)/Nk[k] Sk *= 0.0 for k in xrange(K): for n in xrange(N): Sk[k] += r[n][k] * sp.outer(X[n]-xk[k], X[n]-xk[k]) Sk[k] /= Nk[k] W = [inv(sp.eye(D) + Nk[k]*Sk[k] + (b_0*Nk[k]/(b_0+Nk[k]))*sp.outer(xk[k], xk[k])) for k in xrange(K)] a = a_0 + Nk b = b_0 + Nk v = v_0 + Nk Epi = a / (K*a_0 + N) if (sp.absolute(m - (xk.T*Nk/b).T)).sum() < mtol: break m = (xk.T*Nk/b).T return dict(a=a, b=b, m=m, W=W, v=v, Epi=Epi) if __name__ == '__main__': import sys import matplotlib.pyplot as plt d = sp.loadtxt(sys.stdin) d = d - [d[:, 0].min(), d[:, 1].min()] d = (2 * d / [d[:, 0].max(), d[:, 1].max()] - 1) Z = vbgmm(d, 6) print "α=%s\nβ=%s\nm=%s\nW=%s\nν=%s\nE[π]=%s\n" % (Z['a'], Z['b'], Z['m'], Z['W'], Z['v'], Z['Epi']) indices = Z['Epi']>1e-2 plt.scatter(d[:, 0], d[:, 1], color='w', edgecolors='g') plt.scatter(Z['m'][indices][:, 0], Z['m'][indices][:, 1], color='r', s=40) plt.show()
ちなみに今回はこの記事にお世話になりました(mの初期値を散らすようにしたのはここで気付いた). 色々と真面目に考察されているのがタメになってありがたかったです.*2
コードが動いて満足したので11章のサンプリング法へ.