RVMで分類

逐次的疎ベイジアン学習アルゴリズムの前に、普通にαの値を(7.116)で更新していくコードでも書くか→動かない... というのでハマっていた。いくらやってもαが0に収束してしまう。。。
元の論文を斜め読みするに、こちらはwの更新式がちょっと違う感じ(?) ということでwの更新を(7.82)に、(7.117)のt^を利用するコードを書いたらそれっぽく動いた。うーん、、、、ちゃんと理解できてないなコリャ...。7.23後半の式の導出をいくらかサボったツケか。

ということで以下のコードは怪しげ。あまり参考にしないように...*1

#!/usr/bin/python
# -*- coding: utf-8 -*-

# 分類

import scipy as sp
import scipy.linalg as spla
import itertools as it
import functools as fn


class RVM(object):
    u"""
    relevance vector machine
    PRML 7.2.3を参考にした
    """
    
    def __init__(self,
                 kernel=lambda x, y: sp.exp(-sp.square(spla.norm(x-y))/0.25)):
        self._kernel = kernel
        self._sig = lambda x: 1.0/(1.0+sp.exp(-x))

    def learn(self, X, t, tol=0.01, amax=1e5):
        u"""学習"""
        N = X.shape[0]
        phi = sp.ones((N, N+1)) # design matrix(*transposed*)
        phi[:,1:] = [[self._kernel(xi, xj) for xj in X] for xi in X]
        a = sp.ones(N+1)  # hyperparameter
        w = sp.zeros(N+1)
        t_est = sp.dot(phi, w)
        y = self._sig(t_est)
        b = y * (1 - y)

        diff = float('inf')
        while diff >= tol:
            # bが0に収束するとB**-1が無限大に飛んで困る
            b_ = 1.0/b
            b_[b_ >= amax] = amax
            invB = sp.diag(b_)

            sigma = spla.inv(sp.diag(a) + sp.dot(phi.T, sp.dot(sp.diag(b), phi)))
            t_hat = t_est + sp.dot(invB, t - y)
            w = sp.dot(sigma, sp.dot(phi.T, sp.dot(sp.diag(b), t_hat)))
            t_est = sp.dot(phi, w)
            y = self._sig(t_est)
            b = y * (1 - y)
            gamma = 1.0 - a * sigma.diagonal()
            anew = gamma / sp.square(w)
            anew[anew >= amax] = amax
            adiff = anew - a
            diff = sp.square(adiff).sum()
            a  = anew

        self._w = w[a < amax]
        self._rv_index = a[1:] < amax
        self._base_index = sp.arange(N+1)[a < amax]
        self._X = X

    def p1(self, x):
        ret = 0
        phi = sp.append([1.0], [self._kernel(x, xi) for xi in self._X])
        for i in range(len(self._base_index)):
            ret += self._w[i] * phi[self._base_index[i]]
        return self._sig(ret)

    def _get_w(self):
        return self._w_
    w = property(_get_w)

    def _get_rv_index(self):
        return self._rv_index

    # 関連ベクトルのインデックス
    rv_index = property(_get_rv_index)


if __name__ == '__main__':
    from sys import stdin
    import matplotlib.pyplot as plt

    data = sp.loadtxt(stdin)
    X = data[:, 0:2]
    t = data[:, 2]

    rvm = RVM()
    rvm.learn(X, t)

    # 描画
    import matplotlib.pyplot as plt
    x = sp.linspace(0, 1, 50)

    # 入力
    plt.scatter(X[:,0][t > 0], X[:,1][t > 0], color='b', marker='x')
    plt.scatter(X[:,0][t == 0], X[:, 1][t == 0], color='r', marker='x')

    # 関連ベクトルの描画
    plt.scatter(X[:,0][rvm.rv_index & (t > 0)], X[:,1][rvm.rv_index & (t > 0)], color='b', edgecolor='k', s=50, label="relevance vector")
    plt.scatter(X[:,0][rvm.rv_index & (t == 0)], X[:,1][rvm.rv_index & (t == 0)], color='r', edgecolor='k', s=50, label="relevance vector")

    meshx, meshy = sp.meshgrid(sp.linspace(-3, 3, 100), sp.linspace(-3, 3, 100))
    meshz = [[rvm.p1([meshx[j][i], meshy[j][i]]) >= 0.5
              for i in range(len(meshx[0]))] for j in range(len(meshx))]
    plt.contour(meshx, meshy, meshz, 1)
    plt.spring()

    # label表示
    plt.show()

データはPRMLのテストデータから適当に取ってきた

実行は

./rvm.py < classification.txt

時間がかかるのは殆ど描画の部分。結果は↓

何だかんだで図7.12と同じようなものが生成された。逐次的(ryはもう放置。QとSを(7.119)のCの式からWoodBuryの公式を使って適当にやれば多分何とかなるだろう(おい
ちなみに書き終えてから気付いたが、予測分布はwで周辺化するべきなのかな、4章の時のように。
まぁ上の図は境界線しか引いてないので関係無いか;

*1:このblog全般に言えゲフッゲフッ