RVMで分類
逐次的疎ベイジアン学習アルゴリズムの前に、普通にαの値を(7.116)で更新していくコードでも書くか→動かない... というのでハマっていた。いくらやってもαが0に収束してしまう。。。
元の論文を斜め読みするに、こちらはwの更新式がちょっと違う感じ(?) ということでwの更新を(7.82)に、(7.117)のt^を利用するコードを書いたらそれっぽく動いた。うーん、、、、ちゃんと理解できてないなコリャ...。7.23後半の式の導出をいくらかサボったツケか。
ということで以下のコードは怪しげ。あまり参考にしないように...*1
#!/usr/bin/python # -*- coding: utf-8 -*- # 分類 import scipy as sp import scipy.linalg as spla import itertools as it import functools as fn class RVM(object): u""" relevance vector machine PRML 7.2.3を参考にした """ def __init__(self, kernel=lambda x, y: sp.exp(-sp.square(spla.norm(x-y))/0.25)): self._kernel = kernel self._sig = lambda x: 1.0/(1.0+sp.exp(-x)) def learn(self, X, t, tol=0.01, amax=1e5): u"""学習""" N = X.shape[0] phi = sp.ones((N, N+1)) # design matrix(*transposed*) phi[:,1:] = [[self._kernel(xi, xj) for xj in X] for xi in X] a = sp.ones(N+1) # hyperparameter w = sp.zeros(N+1) t_est = sp.dot(phi, w) y = self._sig(t_est) b = y * (1 - y) diff = float('inf') while diff >= tol: # bが0に収束するとB**-1が無限大に飛んで困る b_ = 1.0/b b_[b_ >= amax] = amax invB = sp.diag(b_) sigma = spla.inv(sp.diag(a) + sp.dot(phi.T, sp.dot(sp.diag(b), phi))) t_hat = t_est + sp.dot(invB, t - y) w = sp.dot(sigma, sp.dot(phi.T, sp.dot(sp.diag(b), t_hat))) t_est = sp.dot(phi, w) y = self._sig(t_est) b = y * (1 - y) gamma = 1.0 - a * sigma.diagonal() anew = gamma / sp.square(w) anew[anew >= amax] = amax adiff = anew - a diff = sp.square(adiff).sum() a = anew self._w = w[a < amax] self._rv_index = a[1:] < amax self._base_index = sp.arange(N+1)[a < amax] self._X = X def p1(self, x): ret = 0 phi = sp.append([1.0], [self._kernel(x, xi) for xi in self._X]) for i in range(len(self._base_index)): ret += self._w[i] * phi[self._base_index[i]] return self._sig(ret) def _get_w(self): return self._w_ w = property(_get_w) def _get_rv_index(self): return self._rv_index # 関連ベクトルのインデックス rv_index = property(_get_rv_index) if __name__ == '__main__': from sys import stdin import matplotlib.pyplot as plt data = sp.loadtxt(stdin) X = data[:, 0:2] t = data[:, 2] rvm = RVM() rvm.learn(X, t) # 描画 import matplotlib.pyplot as plt x = sp.linspace(0, 1, 50) # 入力 plt.scatter(X[:,0][t > 0], X[:,1][t > 0], color='b', marker='x') plt.scatter(X[:,0][t == 0], X[:, 1][t == 0], color='r', marker='x') # 関連ベクトルの描画 plt.scatter(X[:,0][rvm.rv_index & (t > 0)], X[:,1][rvm.rv_index & (t > 0)], color='b', edgecolor='k', s=50, label="relevance vector") plt.scatter(X[:,0][rvm.rv_index & (t == 0)], X[:,1][rvm.rv_index & (t == 0)], color='r', edgecolor='k', s=50, label="relevance vector") meshx, meshy = sp.meshgrid(sp.linspace(-3, 3, 100), sp.linspace(-3, 3, 100)) meshz = [[rvm.p1([meshx[j][i], meshy[j][i]]) >= 0.5 for i in range(len(meshx[0]))] for j in range(len(meshx))] plt.contour(meshx, meshy, meshz, 1) plt.spring() # label表示 plt.show()
データはPRMLのテストデータから適当に取ってきた
実行は
./rvm.py < classification.txt
時間がかかるのは殆ど描画の部分。結果は↓
何だかんだで図7.12と同じようなものが生成された。逐次的(ryはもう放置。QとSを(7.119)のCの式からWoodBuryの公式を使って適当にやれば多分何とかなるだろう(おい
ちなみに書き終えてから気付いたが、予測分布はwで周辺化するべきなのかな、4章の時のように。
まぁ上の図は境界線しか引いてないので関係無いか;
*1:このblog全般に言えゲフッゲフッ