現実逃避に多クラスロジスティック回帰のコードを書いた
春合宿の…スライドに…まだ手を付けていないッ…。
・・・流石に何話すかとか話の流れとかは大体イメージできてる状態だが、いかんせん書く気が起きない。死ね俺。*1
PRML4章の多クラスロジスティック回帰を殴り書きした。多クラスのIRLSの形がよー分からんというかもうめんどうなので直接ニュートンラフソン法で。
・・・以下のコードだとパラメータの更新にヘッセ行列の対角ブロックの部分しか使ってない…んだけどこれでいいのか…?まぁでも結果上記の画像のようにそれっぽくなったのでもういいや状態。
標準入力からデータを読み込んで動作。
(x,y)がクラス2に属するなら"x y 0 0 1"みたいな感じの行が延々入力に続いてるとします。
信頼のコメント率!やる気の無い変数名!どうみても一週間で読めなくなります。
#!/usr/bin/python # -*- coding: utf-8 -*- from sys import stdin import numpy as np import numpy.linalg as npla import matplotlib.pyplot as plt from scipy.io import read_array def calcY(W, X, n, k): return np.array([[yi(i, W, X[j], k) for i in range(k)] for j in range(n)]) def yi(i, W, x, k): return (np.exp(np.dot(x, W[i])) / sum((np.exp(np.dot(x, W[j])) for j in range(k)))) def calcE(W, X, T, Y, n, m, k): E = np.zeros(m*k) for j in range(k): tmp = np.dot(X.T, Y[:,j]-T[:,j]) E[j*m:j*m+m] = tmp return E def calcH(W, X, Y, n, m, k): H = np.zeros((m*k, m*k)) for i in range(k): for j in range(k): Iij = 1 if i == j else 0 R = np.diag(Y[:,j]*(Iij-Y[:,i])) H[i*m:i*m+m:,j*m:j*m+m] = np.dot(X.T, np.dot(R, X)) return H if __name__ == '__main__': m = 3 data = read_array(stdin) k = len(data[0])-m+1 T = data[:,m-1:] X = np.ones((len(T), m)) X[:,0:m-1] = data[:,0:m-1] n = len(T) W = np.zeros((k, m)) while True: Y = calcY(W, X, n, k) E = calcE(W, X, T, Y, n, m, k) H = calcH(W, X, Y, n, m, k) Wnew = np.zeros((k, m)) # 更新これでいいのかな... for i in range(k): Wnew[i] = W[i]-np.dot(npla.inv(H[i*m:i*m+m, i*m:i*m+m]), E[i*m:i*m+m]) diff = npla.norm(Wnew-W)/npla.norm(W) print diff W = Wnew if diff < 0.1: break x, y = np.meshgrid(np.linspace(-5, 5, 200), np.linspace(-5, 5, 200)) w, h = x.shape x.resize((w*h,)) y.resize((w*h,)) z = [] for i in range(k): z.append(zip(map(lambda xi, yj: yi(i, W, [xi, yj, 1], k), x, y), [i]*len(x))) z = np.array(map(lambda x: max(x)[1], zip(*z))) x.resize((w,h)) y.resize((w,h)) z.resize((w,h)) CS = plt.contourf(x, y, z, [-0.5, 0.5, 1.5, 2.5, 3.5], cmap=plt.cm.bone, origin = 'lower') plt.colorbar(CS) colors=['r', 'g', 'b', 'k'] for i in range(k): plt.scatter(X[:,0][i*100:(i+1)*100], X[:,1][i*100:(i+1)*100], color=colors[i]) plt.show()
どうでもいいけどデータの生成は以下のコード適当に変えつつ生成。
#!/usr/bin/python # -*- coding: utf-8 -*- import scipy import matplotlib.pyplot as plt K = 4 means = [[2, 2], [2, -2], [-2, -2], [-2, 2]] covs = [[[0.1, 0], [0, 0.1]]] * K colors = ['r', 'g', 'b', 'k'] for i in range(len(means)): p = scipy.random.multivariate_normal(means[i], covs[i], 100) plt.scatter(p[:,0:1], p[:,1:2], color=colors[i]) t = [0]*K t[i] = 1 for j in p: s = "%lf %lf" % (j[0], j[1]) for k in range(K): s += ' '+str(t[k]) print s plt.show()