最尤推定・交差エントロピー誤差(実装編)
今回は3回に分けて学んだ分類問題の実装編です。
以前の記事をまだ読んでいない方はあらかじめ目を通しておくことをおすすめします。
まずは手順の再確認です。
実装に躓いたら上記の図や記事を参考にして、基本に立ち返ります。
1.モジュールのimport
以下のモジュールを使って実装します。
import matplotlib.pyplot as plt import numpy as np import math import copy import random %matplotlib inline
2.データの作成
2次元配列でデータを作成。[体重データ、性別データ]のセットで作成します。
#データの作成 x=np.array([]) y=np.array([]) for i in range(20): if i<=9: sex=0#random.randint(0,1) y=np.insert(y,len(y),sex,axis=0) weight=np.random.normal(60.0 , 5.0) x=np.insert(x,len(x),weight,axis=0) else: sex=1#random.randint(0,1) y=np.insert(y,len(y),sex,axis=0) weight=np.random.normal(70.0 , 5.0) x=np.insert(x,len(x),weight,axis=0) data=np.c_[x,y] print(data)
3.データの可視化
今回は女性と男性の体重データがそれぞれ10個ずつの想定で作ってます。
#データの可視化 plt.scatter(data[:9,0], data[:9,1],color="red") plt.scatter(data[10:19,0], data[10:19,1],color="blue") plt.show()
4.パラメータと関数の定義
4.1.学習パラメータ
param = np.array([1.0 for i in range(2)])
4.2.シグモイド関数
def sigmoid(x): e = math.e s = 1.0 / (1.0 + e**-x) return s
4.3.尤度関数
戻り値は尤度の符号を変えているので交差エントロピーに変換しています。
def likelihood(param,data): mean = sum(data[:,0])/len(data) data_copy=copy.deepcopy(data) data_copy[:,0]-=mean y=np.array([0.0 for i in range(len(data_copy))]) #直線の方程式 y = param[0]*data_copy[:,0]+param[1] #シグモイド関数へ代入 y = sigmoid(y[:]) prob = 1.0 for i in range(len(data_copy)): prob *= y[i]**data_copy[i,1]*(1- y[i])**(1-data_copy[i,1]) return prob*(-1)
4.4.微分関数
下記の関数は線基底関数モデル(実装編)でもご紹介しました。
def DER(param,func,data,index): epsilon=0.0001 p1=[0.0 for i in range(len(param))] p1+=param p1[index]+=epsilon p2=[0.0 for i in range(len(param))] p2+=param p2[index]-=epsilon mse=func(p1,data)-func(p2,data) der=mse/(2*epsilon) return der
4.5.学習部分
微分関数と同様以前ご紹介しました。
def fitting(param,data): alpha=0.5 for i in range(len(param)): der=DER(param,likelihood,data,i) param[i]-=der*alpha
5.学習結果の可視化
mean = sum(data[:,0])/len(data) x = np.arange(mean-20,mean+20,0.01) #もっと見やすく改良、正規化されたデータを元に戻して表示↓ plt.plot(x[:], sigmoid(param[0]*(x[:]-mean)+param[1])) plt.scatter(data[:9,0], data[:9,1],color="red") plt.scatter(data[10:19,0], data[10:19,1],color="blue") plt.show()
実際の学習結果の例↓
縦に赤と青の点が並んでいる場所では1と0の間に曲線があり、そうでない場所では1か0に曲線が振り切れているのがわかります。
このように確率のアイデアを取り入れることで複数の種類のデータを曖昧さも加味しながら分類することができました。