ニューラルネットワークを理解する #15 「勾配降下法以外のパラメータ更新方法について③」 Adam:計算編

 前回で理論的なところをまとめたので実際に計算してみよう!

Adamのアルゴリズムは以下のものでしたね。


Adamのアルゴリズム

変数

α:学習率
β1, β2 : モーメントを調節するための値(0~1)
ε: 限りなく小さな値(0除算を防ぐための変数)
θt : 目的関数のパラメータ
gt : 勾配
mt : 一次モーメント
vt : 二次モーメント
(tは試行回数、なのでmt-1は一つ前のmt) 

定義

mt ← β1 * mt-1 + (1 - β1) * gt
vt ←  β2 * vt-1 + (1 - β2) * gt ^ 2
^mt ← mt / (1 - β1 ^ t)
^vt ← vt / (1 - β2 ^ t)
θt ← θt-1 - α * ^mt / (√^vt + ε)

計算してみよう!

  • f(x,y) = (1/20) * x^2 + y^2 の最小値をAdamで求める
  • 初期値=(-3.0, 4.0), 学習率=0.001, step=3, β1=0.9, β2=0.999,ε=10 ^ -8
  • 初期値は論文の推奨値を使用

Step1

偏微分の計算
xについて微分すると ∂f / ∂x = 10 / x
yについて微分すると ∂f / ∂x1 = 2 * y

gt_x = -3/10 = -0.3
gt_y = 8

mt_x ← 0.9 * 0 + (1 - 0.9) * -0.3 = -0.03
vt_x ←  0.999 * 0 + (1 - 0.999) * (-0.3) ^ 2 = 0.00009
^mt_x ← -0.03 /  (1 - 0.9 ^ 1) = -0.30
^vt_x ← 0.00009 / (1 - 0.999 ^ 1) = 0.09

mt_y ← 0.9 * 0 + (1 - 0.9) * 8 = 0.8
vt_y ←  0.999 * 0 + (1 - 0.999) * 8 ^ 2 = 0.064
^mt_y ← 0.8 / (1 - 0.9 ^ 1) = 8
^vt_y ← 0.064 / (1 - 0.999 ^ 1) = 64

x= -3.0 - 0.001 * -0.03 / (√0.09 + 0.00000001) = -2.9999
y= 4 - 0.001 * 8 / (√64 + 0.00000001) = 3.999

Step2

gt_x = -2.9999/10 = -0.29999
gt_y = 3.999*2=7.998

mt_x ← 0.9 * -0.03 + (1 - 0.9) * -0.29999 = -0.056999
vt_x ←  0.999 * 0.00009 + (1 - 0.999) * (-0.29999) ^ 2 = 0.0001799
^mt_x ← -0.056999 /  (1 - 0.9 ^ 2) = -0.29999
^vt_x ← 0.0001799 / (1 - 0.999 ^ 2) = 0.08999

mt_y ← 0.9 * 0.8 + (1 - 0.9) * 7.998 = 1.5198
vt_y ←  0.999 * 0.064 + (1 - 0.999) * 7.998 ^ 2 = 0.1279
^mt_y ← 1.5198 / (1 - 0.9 ^ 2) = 7.9989
^vt_y ← 0.1279 / (1 - 0.999 ^ 2) = 63.981

x= -2.9999 - 0.001 * -0.29999 / (√0.08999 + 0.00000001) = -2.9988
y= 3.999 - 0.001 * 7.9989 / (√63.981 + 0.00000001) = 3.997


Step3

gt_x = -2.9988/10 = -0.29988
gt_y = 3.997*2=7.994

mt_x ← 0.9 * -0.056999 + (1 - 0.9) * -0.29988 = -0.081287
vt_x ←  0.999 * 0.0001799+ (1 - 0.999) * (-0.29988) ^ 2 = 0.0002696
^mt_x ← -0.081287 /  (1 - 0.9 ^ 3) = -0.29995
^vt_x ← 0.0002696/ (1 - 0.999 ^ 3) = 0.08995

mt_y ← 0.9 * 1.5198 + (1 - 0.9) * 7.994 = 2.16722
vt_y ←  0.999 * 0.1279 + (1 - 0.999) * 7.994 ^ 2 = 0.191676
^mt_y ← 2.16722 / (1 - 0.9 ^ 3) = 7.99712
^vt_y ← 0.191676 / (1 - 0.999 ^ 3) = 63.9559

x= -2.9988 - 0.001 * -0.29995 / (√0.08995 + 0.00000001) = -2.9977
y=  3.997 - 0.001 * 7.99712 / (√63.9559 + 0.00000001) = 3.996


全然進みませんね。。
もっと回数重ねると他の手法よりも早いのかもしれないです。

コメント