當前位置:編程學習大全網 - 編程語言 - Transformer解讀(附pytorch代碼)

Transformer解讀(附pytorch代碼)

Transformer早在2017年就出現了,直到BERT問世,Transformer開始在NLP大放光彩,目前比較好的推進就是Transformer-XL(後期附上)。這裏主要針對論文和程序進行解讀,如有不詳實之處,歡迎指出交流,如需了解更多細節之處,推薦知乎上 川陀學者 寫的。本文程序的git地址在 這裏 。程序如果有不詳實之處,歡迎指出交流~

2017年6月,Google發布了壹篇論文《Attention is All You Need》,在這篇論文中,提出了 Transformer 的模型,其旨在全部利用Attention方式來替代掉RNN的循環機制,從而通過實現並行化計算提速。在Transformer出現之前,RNN系列網絡以及seq2seq+attention架構基本上鑄就了所有NLP任務的鐵桶江山。由於Attention模型本身就可以看到全局的信息, Transformer實現了完全不依賴於RNN結構僅利用Attention機制,在其並行性和對全局信息的有效處理上獲得了比之前更好的效果。

縱觀圖1整個Transformer的結構,其核心模塊其實就是三個:Multi-Head attention、Feed Forward 以及 Add&Norm。這裏關於Multi-Head attention部分只講程序的實現,關於更多細節原理,請移至開頭推薦的知乎鏈接。

Transformer中的attention采用的是多頭的self-attention結構,並且在編碼器中,由於不同的輸入mask的部分不壹樣,因此在softmax之前采用了mask操作,並且解碼時由於不能看到t時刻之後的數據,同樣在解碼器的第壹個Multi-Head attention中采用了mask操作,但是二者是不同的。因為編碼器被mask的部分是需要在輸入到Transformer之前事先確定好,而解碼器第壹個Multi-Head attention被mask的部分其實就是從t=1時刻開始壹直到t=seq_len結束,對應於圖2。在圖2中,橫坐標表示解碼器壹個batch上的輸入序列長度(也就是t),紫色部分為被mask的部分,黃色部分為未被mask的部分,可以看出,隨著t的增加,被mask的部分逐壹減少。而解碼器第二個Multi-Head attention的mask操作和編碼器中是壹樣的。

mask+softmax程序如下:

mask操作其實就是對於無效的輸入,用壹個負無窮的值代替這個輸入,這樣在softmax的時候其值就是0。而在attention中(attention操作見下式),softmax的操作出來的結果其實就是attention weights,當attention weights為0時,表示不需要attention該位置的信息。

對於Multi-Head attention的實現,其實並沒有像論文原文寫的那樣,逐壹實現多個attention,再將最後的結果concat,並且通過壹個輸出權重輸出。下面通過程序和公式講解壹下實際的實現過程,這裏假設 , , 的來源是壹樣的,都是 ,其維度為[batch_size, seq_len, input_size]。(需要註意的是在解碼器中第二個Multi-Head的輸入中 與 的來源不壹樣)

首先,對於輸入 ,通過三個權重變量得到 , , ,此時三者維度相同,都是[batch, seq_len, d_model],然後對其進行維度變換:[batch, seq_len, h, d_model//h]==>[batch, h, seq_len, d]==>[batch×h, seq_len, d],其中d=d_model//h,因此直接將變換後的 , , 直接做DotProductAttention就可以實現Multi-Head attention,最後只需要將DotProductAttention輸出的維度依次變換回去,然後乘以輸出權重就可以了。關於程序中的參數valid_length已在程序中做了詳細的解讀,這裏不再贅述,註意的是輸入的valid_length是針對batch這個維度的,而實際操作中由於X的batch維度發生了改變(由batch變成了batch×h),因此需要對valid_length進行復制。

FFN的實現是很容易的,其實就是對輸入進行第壹個線性變換,其輸出加上ReLU激活函數,然後在進行第二個線性變換就可以了。

Add&norm的實現就是利用殘差網絡進行連接,最後將連接的結果接上LN,值得註意的是,程序在Y的輸出中加入了dropout正則化。同樣的正則化技術還出現在masked softmax之後和positional encoding之後。

positional encoding的實現很簡單,其實就是對輸入序列給定壹個唯壹的位置,采用sin和cos的方式給了壹個位置編碼,其中sin處理的是偶數位置,cos處理的是奇數位置。但是,這壹塊的工作確實非常重要的,因為對於序列而言最主要的就是位置信息,顯然BERT是沒有去采用positional encoding(盡管在BERT的論文裏有壹個Position Embeddings的輸入,但是顯然描述的不是Transformer中要描述的位置信息),後續BERT在這壹方面的改進工作體現在了XLNet中(其采用了Transformer-XL的結構),後續的中再介紹該部分的內容。

無論是編碼器還是解碼器,其實都是用上面說的三個基本模塊堆疊而成,具體的實現細節大家可以看開頭的git地址,這裏需要強調的是以下幾點:

中出現的程序都在開頭的git中了,直接執行main.ipynb就可以運行程序,如有不詳實之處,還請指出~~~

  • 上一篇:重慶汽車學院的學科建設
  • 下一篇:Android應用開發和Android軟件測試工程師哪個好?
  • copyright 2024編程學習大全網