當前位置:編程學習大全網 - 源碼下載 - 循環神經網絡(RNN)淺析

循環神經網絡(RNN)淺析

RNN是兩種神經網絡模型的縮寫,壹種是遞歸神經網絡(Recursive Neural Network),壹種是循環神經網絡(Recurrent Neural Network)。雖然這兩種神經網絡有著千絲萬縷的聯系,但是本文主要討論的是第二種神經網絡模型——循環神經網絡(Recurrent Neural Network)。

循環神經網絡是指壹個隨著時間的推移,重復發生的結構。在自然語言處理(NLP),語音圖像等多個領域均有非常廣泛的應用。RNN網絡和其他網絡最大的不同就在於RNN能夠實現某種“記憶功能”,是進行時間序列分析時最好的選擇。如同人類能夠憑借自己過往的記憶更好地認識這個世界壹樣。RNN也實現了類似於人腦的這壹機制,對所處理過的信息留存有壹定的記憶,而不像其他類型的神經網絡並不能對處理過的信息留存記憶。

循環神經網絡的原理並不十分復雜,本節主要從原理上分析RNN的結構和功能,不涉及RNN的數學推導和證明,整個網絡只有簡單的輸入輸出和網絡狀態參數。壹個典型的RNN神經網絡如圖所示:

由上圖可以看出:壹個典型的RNN網絡包含壹個輸入x,壹個輸出h和壹個神經網絡單元A。和普通的神經網絡不同的是,RNN網絡的神經網絡單元A不僅僅與輸入和輸出存在聯系,其與自身也存在壹個回路。這種網絡結構就揭示了RNN的實質:上壹個時刻的網絡狀態信息將會作用於下壹個時刻的網絡狀態。如果上圖的網絡結構仍不夠清晰,RNN網絡還能夠以時間序列展開成如下形式:

等號右邊是RNN的展開形式。由於RNN壹般用來處理序列信息,因此下文說明時都以時間序列來舉例,解釋。等號右邊的等價RNN網絡中最初始的輸入是x0,輸出是h0,這代表著0時刻RNN網絡的輸入為x0,輸出為h0,網絡神經元在0時刻的狀態保存在A中。當下壹個時刻1到來時,此時網絡神經元的狀態不僅僅由1時刻的輸入x1決定,也由0時刻的神經元狀態決定。以後的情況都以此類推,直到時間序列的末尾t時刻。

上面的過程可以用壹個簡單的例子來論證:假設現在有壹句話“I want to play basketball”,由於自然語言本身就是壹個時間序列,較早的語言會與較後的語言存在某種聯系,例如剛才的句子中“play”這個動詞意味著後面壹定會有壹個名詞,而這個名詞具體是什麽可能需要更遙遠的語境來決定,因此壹句話也可以作為RNN的輸入。回到剛才的那句話,這句話中的5個單詞是以時序出現的,我們現在將這五個單詞編碼後依次輸入到RNN中。首先是單詞“I”,它作為時序上第壹個出現的單詞被用作x0輸入,擁有壹個h0輸出,並且改變了初始神經元A的狀態。單詞“want”作為時序上第二個出現的單詞作為x1輸入,這時RNN的輸出和神經元狀態將不僅僅由x1決定,也將由上壹時刻的神經元狀態或者說上壹時刻的輸入x0決定。之後的情況以此類推,直到上述句子輸入到最後壹個單詞“basketball”。

接下來我們需要關註RNN的神經元結構:

上圖依然是壹個RNN神經網絡的時序展開模型,中間t時刻的網絡模型揭示了RNN的結構。可以看到,原始的RNN網絡的內部結構非常簡單。神經元A在t時刻的狀態僅僅是t-1時刻神經元狀態與t時刻網絡輸入的雙曲正切函數的值,這個值不僅僅作為該時刻網絡的輸出,也作為該時刻網絡的狀態被傳入到下壹個時刻的網絡狀態中,這個過程叫做RNN的正向傳播(forward propagation)。註:雙曲正切函數的解析式如下:

雙曲正切函數的求導如下:

雙曲正切函數的圖像如下所示:

這裏就帶來壹個問題:為什麽RNN網絡的激活函數要選用雙曲正切而不是sigmod呢?(RNN的激活函數除了雙曲正切,RELU函數也用的非常多)原因在於RNN網絡在求解時涉及時間序列上的大量求導運算,使用sigmod函數容易出現梯度消失,且sigmod的導數形式較為復雜。事實上,即使使用雙曲正切函數,傳統的RNN網絡依然存在梯度消失問題,無法“記憶”長時間序列上的信息,這個bug直到LSTM上引入了單元狀態後才算較好地解決。

這壹節主要介紹與RNN相關的數學推導,由於RNN是壹個時序模型,因此其求解過程可能和壹般的神經網絡不太相同。首先需要介紹壹下RNN完整的結構圖,上壹節給出的RNN結構圖省去了很多內部參數,僅僅作為壹個概念模型給出。

上圖表明了RNN網絡的完整拓撲結構,從圖中我們可以看到RNN網絡中的參數情況。在這裏我們只分析t時刻網絡的行為與數學推導。t時刻網絡迎來壹個輸入xt,網絡此時刻的神經元狀態st用如下式子表達:

t時刻的網絡狀態st不僅僅要輸入到下壹個時刻t+1的網絡狀態中去,還要作為該時刻的網絡輸出。當然,st不能直接輸出,在輸出之前還要再乘上壹個系數V,而且為了誤差逆傳播時的方便通常還要對輸出進行歸壹化處理,也就是對輸出進行softmax化。因此,t時刻網絡的輸出ot表達為如下形式:

為了表達方便,筆者將上述兩個公式做如下變換:

以上,就是RNN網絡的數學表達了,接下來我們需要求解這個模型。在論述具體解法之前首先需要明確兩個問題:優化目標函數是什麽?待優化的量是什麽?

只有在明確了這兩個問題之後才能對模型進行具體的推導和求解。關於第壹個問題,筆者選取模型的損失函數作為優化目標;關於第二個問題,我們從RNN的結構圖中不難發現:只要我們得到了模型的U,V,W這三個參數就能完全確定模型的狀態。因此該優化問題的優化變量就是RNN的這三個參數。順便說壹句,RNN模型的U,V,W三個參數是全局***享的,也就是說不同時刻的模型參數是完全壹致的,這個特性使RNN得參數變得稍微少了壹些。

不做過多的討論,RNN的損失函數選用交叉熵(Cross Entropy),這是機器學習中使用最廣泛的損失函數之壹了,其通常的表達式如下所示:

上面式子是交叉熵的標量形式,y_i是真實的標簽值,y_i*是模型給出的預測值,最外面之所以有壹個累加符號是因為模型輸出的壹般都是壹個多維的向量,只有把n維損失都加和才能得到真實的損失值。交叉熵在應用於RNN時需要做壹些改變:首先,RNN的輸出是向量形式,沒有必要將所有維度都加在壹起,直接把損失值用向量表達就可以了;其次,由於RNN模型處理的是序列問題,因此其模型損失不能只是壹個時刻的損失,應該包含全部N個時刻的損失。

故RNN模型在t時刻的損失函數寫成如下形式:

全部N個時刻的損失函數(全局損失)表達為如下形式:

需要說明的是:yt是t時刻輸入的真實標簽值,ot為模型的預測值,N代表全部N個時刻。下文中為了書寫方便,將Loss簡記為L。在結束本小節之前,最後補充壹個softmax函數的求導公式:

由於RNN模型與時間序列有關,因此不能直接使用BP(back propagation)算法。針對RNN問題的特殊情況,提出了BPTT算法。BPTT的全稱是“隨時間變化的反向傳播算法”(back propagation through time)。這個方法的基礎仍然是常規的鏈式求導法則,接下來開始具體推導。雖然RNN的全局損失是與全部N個時刻有關的,但為了簡單筆者在推導時只關註t時刻的損失函數。

首先求出t時刻下損失函數關於o_t*的微分:

求出損失函數關於參數V的微分:

因此,全局損失關於參數V的微分為:

求出t時刻的損失函數關於關於st*的微分:

求出t時刻的損失函數關於s_t-1*的微分:

求出t時刻損失函數關於參數U的偏微分。註意:由於是時間序列模型,因此t時刻關於U的微分與前t-1個時刻都有關,在具體計算時可以限定最遠回溯到前n個時刻,但在推導時需要將前t-1個時刻全部帶入:

因此,全局損失關於U的偏微分為:

求t時刻損失函數關於參數W的偏微分,和上面相同的道理,在這裏仍然要計算全部前t-1時刻的情況:

因此,全局損失關於參數W的微分結果為:

至此,全局損失函數關於三個主要參數的微分都已經得到了。整理如下:

接下來進壹步化簡上述微分表達式,化簡的主要方向為t時刻的損失函數關於ot的微分以及關於st*的微分。已知t時刻損失函數的表達式,求關於ot的微分:

softmax函數求導:

因此:

又因為:

且:

有了上面的數學推導,我們可以得到全局損失關於U,V,W三個參數的梯度公式:

由於參數U和W的微分公式不僅僅與t時刻有關,還與前面的t-1個時刻都有關,因此無法寫出直接的計算公式。不過上面已經給出了t時刻的損失函數關於s_t-1的微分遞推公式,想來求解這個式子也是十分簡單的,在這裏就不贅述了。

以上就是關於BPTT算法的全部數學推導。從最終結果可以看出三個公式的偏微分結果非常簡單,在具體的優化過程中可以直接帶入進行計算。對於這種優化問題來說,最常用的方法就是梯度下降法。針對本文涉及的RNN問題,可以構造出三個參數的梯度更新公式:

依靠上述梯度更新公式就能夠叠代求解三個參數,直到三個參數的值發生收斂。

這是筆者第壹次嘗試推導RNN的數學模型,在推導過程中遇到了非常多的bug。非常感謝互聯網上的壹些公開資料和博客,給了我非常大的幫助和指引。接下來筆者將嘗試實現壹個單隱層的RNN模型用於實現壹個語義預測模型。

  • 上一篇:如何有效殺滅綠化帶內的黃蜂
  • 下一篇:html語言中的標記有哪幾類?他們的格式是如何規定的?
  • copyright 2024編程學習大全網