線形基底関数モデル(実装編)
前回:線形基底関数モデル(理論編)で紹介した内容を実際にプログラミングしていきます。
以前学んだ内容もバージョンアップ(より一般化)してもう一度組み直してみます。
では始めて行きましょう↓
線形基底関数モデルの実装
1.データの準備
今回採集された(と仮定する)データを適当に作ります。
データは[x座標,y座標]となるような行列で表します。
def data_log(x): y=np.log(x*10) return y #データの作成 x=np.array([]) y=np.array([]) for i in range(20): tall=random.uniform(0,15) x=np.insert(x,len(x),tall,axis=0) norm=np.random.normal(0 , 0.05) weight=data_log(x[i])+norm y=np.insert(y,len(y),weight,axis=0) data=np.c_[x,y]#データの結合 print(data)
可視化もします↓
#データの可視化 x_axis=np.arange(0,15.0,0.1) plt.plot(x_axis[:],data_log(x_axis[:])) plt.scatter(data[:,0], data[:,1],color="red") plt.show()
2.正規関数の用意
上で準備したデータに正規関数の重ね合わせで学習して行きます。
各正規関数には重みづけをしてから重ね合わせます。この時に使う重みの値が推定するパラメータになります。
#初期パラメータ param=np.array([1.0 for i in range(6)]) #正規関数 def gauss(x,mean,sigma): y=np.exp(-(x-mean)**2/(2*sigma**2)) return y #正規関数の重ね合わせ def mixgauss(x,param): y=0 for i in range(len(param)): y+=param[i]*gauss(x,i*3,5) return y
好きな範囲でプロットして確認↓
x_axis=np.arange(-15.0,15.0,0.1) for i in range(len(param)): plt.plot(x_axis[:],gauss(x_axis[:],i*3,5)) plt.plot(x_axis[:],mixgauss(x_axis[:],param)) plt.show()
3.学習に必要なメソッドの定義
根本の考え方は最小二乗法なので誤差を求める↓
(下記のプログラムはyとzの内積計算をしているが、分かりにくければfor文を使って和を取れば良い。結果は同じ。)
def min2error(param,data): y=0.0 z=[1.0 for i in range(len(data))] y+=(mixgauss(data[:,0],param)-data[:,1])**2 error2=y.dot(z)/len(data) return error2
与えられたインデックス(index)の値を収束させた極限を求める↓
(つまりは偏微分。以下のコードはパラメータの個数が可変長でもうまく動作するはず。)
def DER(param,func,data,index): epsilon=0.0001 p1=[0.0 for i in range(len(param))] p1+=param p1[index]+=epsilon p2=[0.0 for i in range(len(param))] p2+=param p2[index]-=epsilon mse=func(p1,data)-func(p2,data) der=mse/(2*epsilon) return der
全てのパラメータを1回ずつ更新する↓
def fitting(param,data): alpha=0.5 for i in range(len(param)): der=DER(param,min2error,data,i) param[i]-=der*alpha
4.実際に学習してみる
今回は1000回学習する。
for i in range(10000): fitting(param,data) print('学習結果') print(param)
最終確認↓
#データの可視化 x_axis=np.arange(0,15.0,0.1) plt.plot(x_axis[:],mixgauss(x_axis[:],param)) plt.scatter(data[:,0], data[:,1],color="red") plt.show()
なかなか良い形になったのではないでしょうか?
学習回数を増やすとさらに値が変化し、データに沿うような関数が得られます。
ただし、正規関数の平均値と分散を固定、または任意な値で固定していたので学習には限度があるかもしれません。