変分混合ガウス分布(PRML 10.2)

入力にはOldFaithful間欠泉データ集合を用いた.
データは標準入力から読み込むようにした.
図はK=6の場合. 混合係数の期待値は2つ以外ほぼ0になったので満足.

初期値m=0だとおかしくなるのは, パラメータの初期値がそれぞれ等しい要素を持つようにすると, 負担率が均等になってそこで収束するからかな.


なおソース内ではαがaになっていたりΛがAになっていたりνがvになっていたりρがpになっていたりするので, 読む場合は*1心の目で適宜sedして下さい.

#!/usr/bin/python
# -*- coding: utf-8 -*-

import scipy as sp
from scipy.linalg import det, inv
from scipy.special.basic import digamma

def vbgmm(X, K, a_0=1e-2, b_0=1e-2, v_0=10.0, mtol=1e-3, iter=1e4):
    N, D = X.shape
    a = a_0 * sp.ones(K) + N/K # α
    b = b_0 * sp.ones(K) + N/K
    m = sp.arange(1, K*D+1).reshape(K, D) * 00.1
    W = [sp.eye(D) for i in xrange(K)]
    v = v_0 * sp.ones(K) + N/K # ν
    Epi = a / (K*a_0 + N)

    i = 0
    xk = sp.zeros((K, D))
    Sk = sp.zeros((K, D, D))
    E_uA = sp.zeros((K, N))

    while i < iter:
        # e-step
        E_uA.resize((K, N))
        for k in xrange(K):
            E_uA[k] = D/b[k] + v[k]*sp.diag(sp.dot(X-m[k], sp.dot(W[k], (X-m[k]).T)))
        E_uA = E_uA.T

        lnAk = D * sp.log(2.0) + [digamma((v[k]+1-sp.arange(1, D+1)) / 2.0).sum() + sp.log(det(W[k])) for k in xrange(K)]
        lnpi = digamma(a) - digamma(a.sum())

        p = sp.exp(lnpi + lnAk/2.0 - D*sp.log(2*sp.pi)/2.0 - E_uA/2.0)
        r = (p.T/p.sum(axis=1)).T

        # m-step
        Nk = r.sum(axis=0)
        for k in xrange(K):
            xk[k] = (X.T * r[:, k]).sum(axis=1)/Nk[k]
        Sk *= 0.0
        for k in xrange(K):
            for n in xrange(N):
                Sk[k] += r[n][k] * sp.outer(X[n]-xk[k], X[n]-xk[k])
            Sk[k] /= Nk[k]

        W = [inv(sp.eye(D) + Nk[k]*Sk[k] + (b_0*Nk[k]/(b_0+Nk[k]))*sp.outer(xk[k], xk[k])) for k in xrange(K)]
        a = a_0 + Nk
        b = b_0 + Nk
        v = v_0 + Nk
        Epi = a / (K*a_0 + N)

        if (sp.absolute(m - (xk.T*Nk/b).T)).sum() < mtol: break
        m = (xk.T*Nk/b).T

    return dict(a=a, b=b, m=m, W=W, v=v, Epi=Epi)

if __name__ == '__main__':
    import sys
    import matplotlib.pyplot as plt
    d = sp.loadtxt(sys.stdin)
    d = d - [d[:, 0].min(), d[:, 1].min()]
    d = (2 * d / [d[:, 0].max(), d[:, 1].max()] - 1)

    Z = vbgmm(d, 6)
    print "α=%s\nβ=%s\nm=%s\nW=%s\nν=%s\nE[π]=%s\n" % (Z['a'], Z['b'], Z['m'], Z['W'], Z['v'], Z['Epi'])

    indices = Z['Epi']>1e-2
    plt.scatter(d[:, 0], d[:, 1], color='w', edgecolors='g')
    plt.scatter(Z['m'][indices][:, 0], Z['m'][indices][:, 1], color='r', s=40)
    plt.show()

ちなみに今回はこの記事にお世話になりました(mの初期値を散らすようにしたのはここで気付いた). 色々と真面目に考察されているのがタメになってありがたかったです.*2

コードが動いて満足したので11章のサンプリング法へ.

*1:無いと思うけど

*2:というかここまで色々書かれてると自分で実装する必要を感じなくなってしまゲフッ