作者丨蘇劍林
單位丨追一科技
研究方向丨NLP,神經(jīng)網(wǎng)絡(luò)
個(gè)人主頁(yè)丨kexue.fm
相信近一年來(尤其是近半年來),大家都能很頻繁地看到各種 Transformer 相關(guān)工作(比如 BERT、GPT、XLNet 等等)的報(bào)導(dǎo),連同各種基礎(chǔ)評(píng)測(cè)任務(wù)的評(píng)測(cè)指標(biāo)不斷被刷新。同時(shí),也有很多相關(guān)的博客、專欄等對(duì)這些模型做科普和解讀。
俗話說,“外行看熱鬧,內(nèi)行看門道”,我們不僅要在“是什么”這個(gè)層面去理解這些工作,我們還需要思考“為什么”。這個(gè)“為什么”不僅僅是“為什么要這樣做”,還包括“為什么可以這樣做”。比如,在談到 XLNet 的亂序語言模型時(shí),我們或許已經(jīng)從諸多介紹中明白了亂序語言模型的好處,那不妨更進(jìn)一步思考一下:
為什么 Transformer 可以實(shí)現(xiàn)亂序語言模型?是怎么實(shí)現(xiàn)的?RNN 可以實(shí)現(xiàn)嗎?
本文從對(duì) Attention 矩陣進(jìn)行 Mask 的角度,來分析為什么眾多 Transformer 模型可以玩得如此“出彩”的基本原因,正如標(biāo)題所述“Transformer 如戲,全靠 Mask”,這是各種花式 Transformer 模型的重要“門道”之一。
讀完本文,你或許可以了解到:
1. Attention 矩陣的 Mask 方式與各種預(yù)訓(xùn)練方案的關(guān)系;
2. 直接利用預(yù)訓(xùn)練的 BERT 模型來做 Seq2Seq 任務(wù)。
自 Attention is All You Need 以后,基于純 Attention 的 Transformer 類模型逐漸變得流行起來,而 BERT 的出現(xiàn)則將這股潮流推向了一個(gè)新的高度。而后,各種基于大規(guī)模預(yù)訓(xùn)練的 Transformer 模型的工作不斷出現(xiàn),有基于現(xiàn)成的模型做應(yīng)用的,有試圖更好地去解釋和可視化這些模型的,還有改進(jìn)架構(gòu)、改進(jìn)預(yù)訓(xùn)練方式等以得到更好結(jié)果的。
總的來說,這些以預(yù)訓(xùn)練為基礎(chǔ)的工作層出不窮,有種琳瑯滿目的感覺。甚至一定程度上來說,如果你還沒有微調(diào)過 BERT ,那已經(jīng)算是落后于主流的 NLP 技術(shù)了。
花式預(yù)訓(xùn)練
眾所周知,傳統(tǒng)的模型預(yù)訓(xùn)練手段就是語言模型,比如 ELMo [1] 模型就是以 BiLSTM 為基礎(chǔ)架構(gòu)、用兩個(gè)方向的語言模型分別預(yù)訓(xùn)練兩個(gè)方向的 LSTM 的,后面的 OpenAI 的 GPT、GPT-2 [2] 也是堅(jiān)定不移地堅(jiān)持著用祖?zhèn)鞯模?biāo)準(zhǔn)的、單向的)語言模型來預(yù)訓(xùn)練。
然而,還有更多花樣的預(yù)訓(xùn)練玩法。比如 BERT [3] 就用了稱之為“掩碼語言模型(Masked Language Model)”的方式來預(yù)訓(xùn)練,不過這只是普通語言模型的一種變體;還有 XLNet [4]則提出了更徹底的“Permutation Language Modeling”,我們可以稱之為“亂序語言模型”;還有 UNILM [5] 模型,直接用單個(gè) BERT 的架構(gòu)做 Seq2Seq,你可以將它作為一種預(yù)訓(xùn)練手段,又或者干脆就用它來做 Seq2Seq 任務(wù)。
如此花樣百出,讓我們不禁疑問:為什么剛好在 Transformer 流行的時(shí)代,才出現(xiàn)這種各種大型預(yù)訓(xùn)練模型“百花齊放,百家爭(zhēng)鳴”的現(xiàn)象?
Transformer專屬
事實(shí)上,除了單向語言模型及其簡(jiǎn)單變體掩碼語言模型之外,UNILM 的 Seq2Seq 預(yù)訓(xùn)練、XLNet 的亂序語言模型預(yù)訓(xùn)練,基本可以說是專為 Transformer 架構(gòu)定制的。說白了,如果是 RNN 架構(gòu),根本就不能用亂序語言模型的方式來預(yù)訓(xùn)練,至于 Seq2Seq 的預(yù)訓(xùn)練方式,則必須同時(shí)引入兩個(gè)模型(encoder 和 decoder),而無法像 Transformer 架構(gòu)一樣,可以一個(gè)模型搞定。
這其中的奧妙主要在 Attention 矩陣之上。Attention 實(shí)際上相當(dāng)于將輸入兩兩地算相似度,這構(gòu)成了一個(gè)大小的相似度矩陣(即 Attention 矩陣,n 是句子長(zhǎng)度,本文的 Attention 均指 Self Attention),這意味著它的空間占用量是量級(jí),相比之下,RNN 模型、CNN 模型只不過是 ??(n),所以實(shí)際上 Attention 通常更耗顯存。
然而,有弊也有利,更大的空間占用也意味著擁有了更多的可能性,我們可以通過往這個(gè)級(jí)別的 Attention 矩陣加入各種先驗(yàn)約束,使得它可以做更靈活的任務(wù)。說白了,也就只有純 Attention 的模型,才有那么大的“容量”去承載那么多的“花樣”。
而加入先驗(yàn)約束的方式,就是對(duì) Attention 矩陣進(jìn)行不同形式的 Mask,這便是本文要關(guān)注的焦點(diǎn)。
在一文讀懂「Attention is All You Need」| 附代碼實(shí)現(xiàn)一文中筆者已經(jīng)對(duì) Attention 做了基本介紹,這里僅做簡(jiǎn)單回顧。Attention 的數(shù)學(xué)形式為:
這里的,分別代表 query、key、value 的向量序列,其中我們可以認(rèn)為 key 和 value 是一一對(duì)應(yīng)的,而
則是將 query、key 的向量?jī)蓛勺鰞?nèi)積,然后用 softmax 歸一化,就得到一個(gè)的 Attention 矩陣,它描述的就是 query 和 key 之間任意兩個(gè)元素的關(guān)聯(lián)強(qiáng)度,后面我們要講的故事,都是在這個(gè) Attention 矩陣上下功夫。最后再與 V 相乘,相當(dāng)于按照這個(gè)關(guān)聯(lián)強(qiáng)度將 V 的各個(gè)向量加權(quán)求和,最終輸出一個(gè)的向量序列。目前最常用的 Attention 方式當(dāng)數(shù) Self Attention,即 Q, K, V 都是同一個(gè)向量序列經(jīng)過線性變換而來的,而 Transformer 則是 Self Attention 跟 Position-Wise 全連接層(相當(dāng)于 kernel size 為 1 的一維卷積)的組合。所以,Transformer 就是基于 Attention 的向量序列到向量序列的變換。
在本節(jié)中,我們將會(huì)比較詳細(xì)地分析 Attention 矩陣的 Mask 方式,這分別對(duì)應(yīng)單向語言模型、亂序語言模型、Seq2Seq 的實(shí)現(xiàn)原理。
單向語言模型
語言模型可以說是一個(gè)無條件的文本生成模型,如果讀者還不了解文本生成模型,可以自行查閱相關(guān)資料并配合玩轉(zhuǎn)Keras之Seq2Seq自動(dòng)生成標(biāo)題 | 附開源代碼一文來理解。單向語言模型相當(dāng)于把訓(xùn)練語料通過下述條件概率分布的方式“記住”了:
我們一般說的“語言模型”,就是指單向的(更狹義的只是指正向的)語言模型。語言模型的關(guān)鍵點(diǎn)是要防止看到“未來信息”。如上式,預(yù)測(cè) x1 的時(shí)候,是沒有任何外部輸入的;而預(yù)測(cè) x2 的時(shí)候,只能輸入 x1,預(yù)測(cè) x3 的時(shí)候,只能輸入 x1,x2;依此類推。
RNN 模型是天然適合做語言模型的,因?yàn)樗旧砭褪沁f歸的運(yùn)算;如果用 CNN 來做的話,則需要對(duì)卷積核進(jìn)行 Mask,即需要將卷積核對(duì)應(yīng)右邊的部分置零。如果是 Transformer 呢?那需要一個(gè)下三角矩陣形式的 Attention 矩陣:
▲ 單向(正向)語言模型的Mask方式
如圖所示,Attention 矩陣的每一行事實(shí)上代表著輸出,而每一列代表著輸入,而 Attention 矩陣就表示輸出和輸入的關(guān)聯(lián)。假定白色方格都代表 0,那么第 1 行表示“北”只能跟起始標(biāo)記 <s> 相關(guān)了,而第 2 行就表示“京”只能跟起始標(biāo)記 <s> 和“北”相關(guān)了,依此類推。
所以,只需要在 Transformer 的 Attention 矩陣中引入下三角形形式的 Mask,并將輸入輸出錯(cuò)開一位訓(xùn)練,就可以實(shí)現(xiàn)單向語言模型了。至于 Mask 的實(shí)現(xiàn)方式,可以參考“讓Keras更酷一些!”:層中層與mask的 Mask 一節(jié)。
亂序語言模型
亂序語言模型是 XLNet 提出來的概念,它主要用于 XLNet 的預(yù)訓(xùn)練上。說到 XLNet,我覺得它的亂序語言模型這種預(yù)訓(xùn)練方式是很有意思的,但是我并不喜歡它將基本架構(gòu)換成了 Transformer-XL。我覺得誰有資源可以試試“BERT+亂序語言語言模型預(yù)訓(xùn)練”的組合,或許會(huì)有意外的發(fā)現(xiàn)。
亂序語言模型跟語言模型一樣,都是做條件概率分解,但是亂序語言模型的分解順序是隨機(jī)的:
總之, x1, x2, … , xn 任意一種“出場(chǎng)順序”都有可能。原則上來說,每一種順序都對(duì)應(yīng)著一個(gè)模型,所以原則上就有 n! 個(gè)語言模型。而基于 Transformer 的模型,則可以將這所有順序都做到一個(gè)模型中去!
那怎么做到這一點(diǎn)呢?還是以“北京歡迎你”的生成為例,假設(shè)隨機(jī)的一種生成順序?yàn)椤?lt;s> → 迎 → 京 → 你 → 歡 → 北 → <e>”,那么我們只需要用下圖中第二個(gè)子圖的方式去 Mask 掉 Attention 矩陣,就可以達(dá)到目的了:
跟前面的單向語言模型類似,第 4 行只有一個(gè)藍(lán)色格,表示“迎”只能跟起始標(biāo)記 <s> 相關(guān),而第 2 行有兩個(gè)藍(lán)色格,表示“京”只能跟起始標(biāo)記 <s> 和“迎”相關(guān),依此類推。直觀來看,這就像是把單向語言模型的下三角形式的 Mask“打亂”了。
也就是說,實(shí)現(xiàn)一種順序的語言模型,就相當(dāng)于將原來的下三角形式的 Mask 以某種方式打亂。正因?yàn)?Attention 提供了這樣的一個(gè) n × n 的 Attention 矩陣,我們才有足夠多的自由度去以不同的方式去 Mask 這個(gè)矩陣,從而實(shí)現(xiàn)多樣化的效果。
說到這里,讀者可能會(huì)有一個(gè)實(shí)現(xiàn)上的疑問:打亂后的 Mask 似乎沒看出什么規(guī)律呀,難道每次都要隨機(jī)生成一個(gè)這樣的似乎沒有什么明顯概率的 Mask 矩陣?
事實(shí)上有一種更簡(jiǎn)單的、數(shù)學(xué)上等效的訓(xùn)練方案。這個(gè)訓(xùn)練方案源于純 Attention 的模型本質(zhì)上是一個(gè)無序的模型,它里邊的詞序?qū)嶋H上是通過 Position Embedding 加上去的。也就是說,我們輸入的不僅只有 token 本身,還包括 token 所在的位置 id;再換言之,你覺得你是輸入了序列“[北, 京, 歡, 迎, 你]”,實(shí)際上你輸入的是集合“{(北, 1), (京, 2), (歡, 3), (迎, 4), (你, 5)}”。
▲ 重新排序,使得正向語言模型就可以實(shí)現(xiàn)亂序語言模型
既然只是一個(gè)集合,跟順序無關(guān),那么我們完全可以換一種順序輸入,比如剛才的“<s> → 迎 → 京 → 你 → 歡 → 北 → <e>”,我們可以按“(迎, 4), (京, 2), (你, 5), (歡, 3), (北, 1)”的順序輸入,也就是說將 token 打亂為“迎,京,你,歡,北”輸入到 Transformer 中,但是第 1 個(gè) token 的 position 就不是 1 了,而是 4;依此類推。這樣換過來之后,Mask 矩陣可以恢復(fù)為下三角矩陣,所以只需要在輸入層面打亂即可,這樣操作起來就更簡(jiǎn)單了。
Seq2Seq
現(xiàn)在到我們的“重頭戲”了:將 BERT 等 Transformer 架構(gòu)跟 Seq2Seq 結(jié)合起來。為什么說重頭戲呢?因?yàn)樵瓌t上來說,任何 NLP 問題都可以轉(zhuǎn)化為 Seq2Seq 來做,它是一個(gè)真正意義上的萬能模型。所以如果能夠做到 Seq2Seq,理論上就可以實(shí)現(xiàn)任意任務(wù)了。
將 BERT 與 Seq2Seq 結(jié)合的比較知名的工作有兩個(gè):MASS [6] 和 UNILM [5],兩者都是微軟的工作,兩者還都在同一個(gè)月發(fā)的。其中 MASS 還是普通的 Seq2Seq 架構(gòu),分別用 BERT 類似的 Transformer 模型來做 encoder 和 decoder,它的主要貢獻(xiàn)就是提供了一種 Seq2Seq 思想的預(yù)訓(xùn)練方案。
真正有意思的是 UNILM,它提供了一種很優(yōu)雅的方式,能夠讓我們直接用單個(gè) BERT 模型就可以做 Seq2Seq 任務(wù),而不用區(qū)分 encoder 和 decoder。而實(shí)現(xiàn)這一點(diǎn)幾乎不費(fèi)吹灰之力——只需要一個(gè)特別的 Mask。
插曲:事實(shí)的順序是筆者前兩周自己獨(dú)立地想到了用單個(gè) BERT 模型做 Seq2Seq 的思路,然后去找資料發(fā)現(xiàn)這個(gè)思路已經(jīng)被做了,正是 UNILM。
UNILM 直接將 Seq2Seq 當(dāng)成句子補(bǔ)全來做。假如輸入是“你想吃啥”,目標(biāo)句子是“白切雞”,那 UNILM 將這兩個(gè)句子拼成一個(gè):[CLS] 你 想 吃 啥 [SEP] 白 切 雞 [SEP]。經(jīng)過這樣轉(zhuǎn)化之后,最簡(jiǎn)單的方案就是訓(xùn)練一個(gè)語言模型,然后輸入“[CLS] 你 想 吃 啥 [SEP]”來逐字預(yù)測(cè)“白 切 雞”,直到出現(xiàn)“[SEP]”為止,即如下面的左圖:
▲ UNILM做Seq2Seq模型圖示。輸入部分內(nèi)部可做雙向Attention,輸出部分只做單向Attention。
事實(shí)上,上述的這些 Mask 方案,基本上都已經(jīng)被集成在筆者寫的 bert4keras [7],讀者可以直接用 bert4keras 加載 BERT 的預(yù)訓(xùn)練權(quán)重,并且調(diào)用上述 Mask 方案來做相應(yīng)的任務(wù)。下面,我們給出一個(gè)利用 UNILM 的思路做一個(gè)快速收斂的 Seq2Seq 模型的例子。
代碼開源
這次代碼的測(cè)試任務(wù)依然是之前的標(biāo)題生成,代碼調(diào)整自玩轉(zhuǎn)Keras之Seq2Seq自動(dòng)生成標(biāo)題里邊的代碼,并且得益于 bert4keras 的封裝,模型部分的代碼實(shí)現(xiàn)非常簡(jiǎn)單清爽。這一次直接使用了 THUCNews [8] 的原始數(shù)據(jù)集,讀者可以自行下載數(shù)據(jù)集和源碼測(cè)試復(fù)現(xiàn)。
詳細(xì)請(qǐng)看:
這個(gè)效果能有多好呢?經(jīng)過實(shí)驗(yàn),在標(biāo)題生成的任務(wù)上,只要 7000 個(gè) iteration,就已經(jīng)能生成基本可讀的標(biāo)題了。相應(yīng)地,以前用 LSTM 做的時(shí)候,大概需要多 10 倍的 iteration 才有同樣的效果。
▲ 只需要7000步的訓(xùn)練,就可以得到基本可讀的生成結(jié)果
簡(jiǎn)單說明
下面對(duì)代碼的關(guān)鍵部分做簡(jiǎn)要說明。
首先,輸入格式還是以 token_id 和 segment_id 輸入,比如:
tokens = ['[ClS]', u'你', u'想', u'吃', u'啥', '[SEP]', u'白', u'切', u'雞', '[SEP]']
token_ids = [token_dict[t] for t in tokens]
segment_ids = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]
segment_ids 用來區(qū)分輸入句子和目標(biāo)句子,0 對(duì)應(yīng)的為輸入句子,1 對(duì)應(yīng)的為目標(biāo)句子,只需要自帶的 tokenizer.encode 就可以生成這種 token_id 和 segment_id 了。
至于搭建模型,就只有寥寥幾行:
model = load_pretrained_model(
config_path,
checkpoint_path,
seq2seq=True,
keep_words=keep_words
)
model.summary()
y_in = model.input[0][:, 1:] # 目標(biāo)tokens
y_mask = model.input[1][:, 1:]
y = model.output[:, :-1] # 預(yù)測(cè)tokens,預(yù)測(cè)與目標(biāo)錯(cuò)開一位
# 交叉熵作為loss,并mask掉輸入部分的預(yù)測(cè)
y = model.output[:, :-1] # 預(yù)測(cè)tokens,預(yù)測(cè)與目標(biāo)錯(cuò)開一位
cross_entropy = K.sparse_categorical_crossentropy(y_in, y)
cross_entropy = K.sum(cross_entropy * y_mask) / K.sum(y_mask)
注意 load_pretrained_model 中只要設(shè)置 seq2seq=True,就會(huì)自動(dòng)加載 BERT 的 MLM 部分,并且傳入對(duì)應(yīng)的 Mask,剩下就只需要把 loss 寫好就行了。另外還有一個(gè) keep_words,這個(gè)是用來精簡(jiǎn) Embedding 層用的,對(duì)于中文 BERT 來說,總的 tokens 大概有 2 萬個(gè),這意味著最后預(yù)測(cè)生成的 token 時(shí)是一個(gè) 2 萬分類問題。
但事實(shí)上這大多數(shù) tokens 都不會(huì)被使用到,因此這 2 萬分類浪費(fèi)了不少計(jì)算量。于是這里提供了一個(gè)選項(xiàng),我們可以自行統(tǒng)計(jì)一個(gè)字表,然后傳入對(duì)應(yīng)的 id,只保留這部分 token,這樣就可以降低計(jì)算量了(精簡(jiǎn)后一般只有 5000 個(gè)左右)。
剩下的就是通過 beam search 來解碼等步驟了,這與一般的 Seq2Seq 無異,不再贅述,大家看玩轉(zhuǎn)Keras之Seq2Seq自動(dòng)生成標(biāo)題和代碼即可。
本文相對(duì)系統(tǒng)地總結(jié)了 Transformer 中 Attention 矩陣的 Mask 技巧,并且給出了用 UNILM 方案來做 Seq2Seq 的實(shí)現(xiàn)。對(duì)于同語言的 Seq2Seq 的文本生成任務(wù)來說,采用 UNILM 的思路加載 BERT 的 MLM 預(yù)訓(xùn)練權(quán)重,能夠有效、快速地實(shí)現(xiàn)并提升生成效果,值得一試。
聯(lián)系客服