技術は使ってなんぼ

自分が得たものを誰かの役に立てたい

PyTorchで作るLSTM なぜRNNではなくLSTMなのか

こんにちは。

GWいかがお過ごしでしょうか?

私はここぞとばかりにブログを更新するぞ!と躍起になっております。

何か深く学ぶのにじっくり時間をかけられるので、連休を有意義に過ごしたいものですね。

さてRNNの発展形とされているLSTMについて実装してみました。

今回は実装の前段階として、「なぜRNNではなくLSTMが良いされるのか」について解説してみたいと思います。

RNNの課題

そもそもRNNでは何が課題なのでしょうか?

大きく2点考えられます。

1.長期記憶に向いていない
 →データが多くなるほど、直近のデータしか反映できず、古いデータの情報が残らない
2.勾配消失・勾配爆発の可能性がある
 →途中で上手く学習できなくなる


LSTMは上記2点の改善を狙って作られたものと考えられます。

具体的にどんな対策をLSTMでやっているかというと、

1.記憶セルcの追加
 →重要な情報を残し、重要でない情報を忘れる(sigmoid)
2.逆伝搬時に活性化関数(特にtanh)を経由しない
 →次章の図で解説します


勾配爆発に対しては、一般的には入力データの正規化や勾配クリッピング等でケアするのが基本のため、LSTM内部にはケアされてないように見えます。

RNNとLSTMのモデル比較

では先ほどの1,2の課題や対策内容を具体的に絵を書いてみていきましょう。

実際に絵で書いてみるのが一番わかりやすいかと思います。

理解を深めたいと思う方は、自分でアルゴリズムソースコードを見ながら絵を書いてみると理解が深まります。

まずはRNNを書いてみます。

RNNモデル


逆伝搬時にtanhを経由しているのがわかるかと思います。

tanh微分値は 0~1未満の値を取るため、行列積を続けた結果、重みは限りなく0に近づいていってしまいます。これが勾配消失です。

またイメージしてもらうとわかるかと思いますが、もしこの逆伝搬がより多くのデータに繋がるとどうなるか。

おそらく最初に読み込ませたデータまで逆伝搬するころは、最後に読み込ませたデータに比べてほとんど情報が反映されないと想像できます。

従って、長期のデータを記憶させることにもあまり向いている仕組みではないことが予想できます。

では、LSTMはどうでしょうか?

LSTMモデル


逆伝搬の経路を見ると、活性化関数を経由していないことがわかります。これが勾配消失が起きない仕組みです。

また各RNNCellも4つ配置されており、sigmoidがかけられたものが存在することが見受けられます。

これは重要な情報をできるだけ残し、あまり重要でない情報を捨てるという狙いがあることが予想できます。

つまりただ順番に情報を覚えさせていくRNNと違って、全体から重要な情報だけを残そうとするLSTMは長期のデータ記憶により強い仕組みであると考えられます。


今回はここまで!

次回、いよいよLSTMの実装とその推論精度について紹介します。