OpenOpt使ってSVM書いた
追記(5/19):ガウスカーネル2乗してなかった。コード書き忘れ訂正--); ついでに画像も変更
SMO法使った前のエントリは、殆どpureにpythonでコード書いてたせいか、結構時間がかかっててイライラ。ということでOpenOptの二次計画のソルバー使って手抜きに疎な解を求めてみたの巻。
結果はテストデータ200個の↓の図だと200倍の差が…。scipy+OpenOptぱない
コーディングもあっと言う間だし…その…何というか…一昨日の努力は…一体…。まぁデータ200個と少なきゃメモリにのるしね…。
以下適当に書いたpythonのコード。相変わらずグラフの描画とかのコードの筋が悪い気がしてもにょいぜ。
#!/usr/bin/python # -*- coding: utf-8 -*- from scipy import * from scipy.linalg import norm from openopt import QP class svm(object): """ SVM """ def __init__(self, c=10000, kernel=lambda x,y:dot(x,y)): """ Constructor Arguments: - `c`: param - `kernel`: kernel func """ self._c = c self._kernel = kernel def _get_S(self): return self._S S = property(_get_S) def learn(self, x, t): """ Learning SVM (using openopt.QP) Arguments: - `x`: inputs - `t`: targets """ # making Gram matrix N = len(t) K = array([[t[i]*t[j]*self._kernel(x[i], x[j]) for j in range(N)] for i in range(N)]) p = QP(H=K, f=-ones(N), lb=zeros(N), ub=ones(N)*self._c, Aeq=t, beq=0) r = p.solve('nlp:ralg') self._a = r.xf self._S = [i for i in range(N) if 0 < self._a[i]] self._M = [i for i in range(N) if 0 < self._a[i] < self._c] b = 0.0 for i in self._M: b += t[i] for j in self._S: b -= self._a[j]*t[j]*self._kernel(x[i], x[j]) self._b = b/len(self._M) self._x = x self._t = t def calc(self, x): """ Arguments: - `x`: input """ return self._b + sum((self._a[i] * self._t[i] * self._kernel(x, self._x[i]) for i in self._S)) if __name__ == '__main__': from scipy.io import read_array import matplotlib.pyplot as plt import psyco psyco.full() s = svm(c=0.5, kernel=lambda x,y:exp(-square(norm(x-y))/0.45)) data = read_array(open("classification.txt")) p = data[:,0:2] t = data[:,2]*2-1.0 s.learn(p, t) for i in range(len(p)): c = 'r' if t[i] > 0 else 'b' plt.scatter([p[i][0]], [p[i][1]], color=c) X, Y = meshgrid(arange(-2.5, 2.5, 00.1), arange(-2.5, 2.5, 00.1)) w, h = X.shape X.resize(X.size) Y.resize(Y.size) Z = array([s.calc([x, y]) for (x, y) in zip(X, Y)]) X.resize((w, h)) Y.resize((w, h)) Z.resize((w, h)) CS = plt.contour(X, Y, Z, [0.0], colors = ('k'), linewidths = (3,), origin = 'lower') plt.xlim(-2.5, 2.5) plt.ylim(-2.5, 2.5) plt.show()