新智元報(bào)道
最近谷歌又發(fā)布了全新的文本-圖像生成Muse模型,沒(méi)有采用當(dāng)下大火的擴(kuò)散(diffusion)模型,而是采用了經(jīng)典的Transformer模型就實(shí)現(xiàn)了最先進(jìn)的圖像生成性能,相比擴(kuò)散或自回歸(autoregressive)模型,Muse模型的效率也提升非常多。
論文鏈接:https://arxiv.org/pdf/2301.00704.pdf
項(xiàng)目鏈接:https://muse-model.github.io/
Muse以masked modeling任務(wù)在離散token空間上進(jìn)行訓(xùn)練:給定從預(yù)訓(xùn)練的大型語(yǔ)言模型(LLM)中提取的文本嵌入,Muse的訓(xùn)練過(guò)程就是預(yù)測(cè)隨機(jī)masked掉的圖像token。
與像素空間的擴(kuò)散模型(如Imagen和DALL-E 2)相比,由于Muse使用了離散的token,只需要較少的采樣迭代,所以效率得到了明顯提高;
與自回歸模型(如Parti)相比,由于Muse使用了并行解碼,所以效率更高。
使用預(yù)訓(xùn)練好的LLM可以實(shí)現(xiàn)細(xì)粒度的語(yǔ)言理解,從而轉(zhuǎn)化為高保真的圖像生成和對(duì)視覺(jué)概念的理解,如物體、空間關(guān)系、姿態(tài)、cardinality等。
在實(shí)驗(yàn)結(jié)果中,只有900M參數(shù)的Muse模型在CC3M上實(shí)現(xiàn)了新的SOTA性能,F(xiàn)ID分?jǐn)?shù)為6.06。
Muse 3B參數(shù)模型在zero-shot COCO評(píng)估中實(shí)現(xiàn)了7.88的FID,同時(shí)還有0.32的CLIP得分。
Muse還可以在不對(duì)模型進(jìn)行微調(diào)或反轉(zhuǎn)(invert)直接實(shí)現(xiàn)一些圖像編輯應(yīng)用:修復(fù)(inpainting)、擴(kuò)展(outpainting)和無(wú)遮罩編輯(mask-free editing)。
Muse模型
Muse模型的框架包含多個(gè)組件,訓(xùn)練pipeline由T5-XXL預(yù)訓(xùn)練文本編碼器,基礎(chǔ)模型(base model)和超分辨率模型組成。
1. 預(yù)訓(xùn)練文本編碼器
與之前研究中得出的結(jié)論類似,研究人員發(fā)現(xiàn)利用預(yù)訓(xùn)練的大型語(yǔ)言模型(LLM)有利于提升高質(zhì)量圖像的生成結(jié)果。
比如從語(yǔ)言模型T5-XXL中提取的嵌入(embedding)帶有關(guān)于物體(名詞)、行動(dòng)(動(dòng)詞)、視覺(jué)屬性(形容詞)、空間關(guān)系(介詞)以及其他屬性(如卡片性和組成)的豐富信息。
所以研究人員提出假設(shè)(hypothesis):Muse模型學(xué)會(huì)將LLM嵌入中的這些豐富的視覺(jué)和語(yǔ)義概念映射到生成的圖像上。
最近也有一些工作已經(jīng)證明了,由LLM學(xué)習(xí)到的概念表征與由視覺(jué)任務(wù)訓(xùn)練的模型學(xué)習(xí)的概念表征大致上是可以「線性映射」的。
給定一個(gè)輸入的文本標(biāo)題,將其傳遞給凍結(jié)參數(shù)的T5-XXL編碼器,可以得到一個(gè)4096維的語(yǔ)言嵌入向量,然后將這些向量線性地投射到Transformer模型(base和超分辨率)的hidden size維度上。
2. 使用VQGAN進(jìn)行Semantic Tokenization
VQGAN模型由一個(gè)編碼器和一個(gè)解碼器組成,其中的量化層(quantization layer)將輸入圖像映射成來(lái)自一個(gè)學(xué)習(xí)過(guò)的codebook的token序列。
然后完全用卷積層建立編碼器和解碼器,以支持對(duì)不同分辨率的圖像進(jìn)行編碼。
編碼器中包括幾個(gè)下采樣塊來(lái)減少輸入的空間維度,而解碼器中則是有相應(yīng)數(shù)量的上采樣塊來(lái)將latents映射回原始圖像大小。
研究人員訓(xùn)練了兩個(gè)VQGAN模型:一個(gè)是下采樣率f=16,模型在256×256像素的圖像上獲得基本模型的標(biāo)記,從而得到空間尺寸為16×16的標(biāo)記;另一個(gè)是下采樣率f=8,在512×512的圖像上獲得超分辨率模型的token,相應(yīng)的的空間尺寸為64×64。
編碼后得到的離散token可以捕捉圖像的高層次語(yǔ)義,同時(shí)也可以消除低層次的噪聲,并且根據(jù)token的離散性可以在輸出端使用交叉熵?fù)p失來(lái)預(yù)測(cè)下一階段的masked token
3. Base Model
Muse的基礎(chǔ)模型是一個(gè)masked Transformer,其中輸入是映射的T5嵌入和圖像token.
研究人員將所有的文本嵌入設(shè)置為unmasked,隨機(jī)mask掉一部分不同的圖像token后,用一個(gè)特殊的[MASK]標(biāo)記來(lái)代替原token.
然后將圖像token線性地映射到所需的Transformer輸入或hidden size維度的圖像輸入embedding中,并同時(shí)學(xué)習(xí)2D position embedding
和原始的Transformer架構(gòu)一樣,包括幾個(gè)transformer層,使用自注意塊、交叉注意力塊和MLP塊來(lái)提取特征。
在輸出層,使用一個(gè)MLP將每個(gè)masked圖像嵌入轉(zhuǎn)換為一組logits(對(duì)應(yīng)于VQGAN codebook的大?。?,并以ground truth的token為目標(biāo)使用交叉熵?fù)p失。
在訓(xùn)練階段,基礎(chǔ)模型的訓(xùn)練目標(biāo)為預(yù)測(cè)每一步的所有msked tokens;但在推理階段,mask預(yù)測(cè)是以迭代的方式進(jìn)行的,這種方式可以極大提高質(zhì)量。
4. 超分辨率模型
研究人員發(fā)現(xiàn),直接預(yù)測(cè)512×512分辨率的圖像會(huì)導(dǎo)致模型專注于低層次的細(xì)節(jié)而非高層次的語(yǔ)義。
使用級(jí)聯(lián)模型(cascade of models)則可以改善這種情況:
首先使用一個(gè)生成16×16 latent map(對(duì)應(yīng)256×256的圖像)的基礎(chǔ)模型;然后是一個(gè)超分辨率模型,將基礎(chǔ)latent map上采樣為64×64(對(duì)應(yīng)512×512的圖像)。其中超分辨率模型是在基礎(chǔ)模型訓(xùn)練完成后再進(jìn)行訓(xùn)練的。
如前所述,研究人員總共訓(xùn)練了兩個(gè)VQGAN模型,一個(gè)是16×16潛分辨率和256×256空間分辨率,另一個(gè)是64×64潛伏分辨率和512×512空間分辨率。
由于基礎(chǔ)模型輸出對(duì)應(yīng)于16×16 latent map的token,所以超分辨率模塊學(xué)會(huì)了將低分辨率的latent map 「翻譯」成高分辨率的latent map,然后通過(guò)高分辨率的VQGAN解碼,得到最終的高分辨率圖像;該翻譯模型也是以類似于基礎(chǔ)模型的方式進(jìn)行text conditioning和交叉注意力的訓(xùn)練。
5. 解碼器微調(diào)
為了進(jìn)一步提高模型生成細(xì)節(jié)的能力,研究人員選擇通過(guò)增加VQGAN解碼器的容量,添加更多的殘差層(residual layer)和通道的同時(shí)保持編碼器的容量不變。
然后對(duì)新的解碼器進(jìn)行微調(diào),同時(shí)保持VQGAN編碼器的權(quán)重、codebook和Transformers(即基礎(chǔ)模型和超分辨率模型)不變。這種方式能夠提高生成圖像的視覺(jué)質(zhì)量,而不需要重新訓(xùn)練任何其他的模型組件(因?yàn)橐曈X(jué)token保持固定)。
可以看到,經(jīng)過(guò)微調(diào)的解碼器以重建更多更清晰的細(xì)節(jié)。
6. 可變掩碼率(Masking Rate)
研究人員使用基于Csoine scheduling的可變掩碼率來(lái)訓(xùn)練模型:對(duì)于每個(gè)訓(xùn)練例子,從截?cái)嗟腶rccos分布中抽出一個(gè)掩碼率r∈[0,1],其密度函數(shù)如下.
掩碼率的期望值為0.64,也就是說(shuō)更偏向于選擇更高的掩碼率,使得預(yù)測(cè)問(wèn)題更加困難。
隨機(jī)的掩碼率不僅對(duì)并行采樣方案至關(guān)重要,而且還能實(shí)現(xiàn)一些零散的、開(kāi)箱即用的編輯功能。
7. Classifier Free Guidance(CFG)
研究人員采用無(wú)分類指導(dǎo)(CFG)來(lái)提高圖像的生成質(zhì)量和文本-圖像對(duì)齊。
在訓(xùn)練時(shí),在隨機(jī)選擇的10%的樣本上去除文本條件,注意力機(jī)制降為圖像token本身的自注意力。
在推理階段,為每個(gè)被mask的token計(jì)算一個(gè)條件logit lc和一個(gè)無(wú)條件logit lu,然后通過(guò)從無(wú)條件logit中移出一個(gè)量t作為指導(dǎo)尺度,形成最終的logit lg:
直觀來(lái)看,CFG是以多樣性換取保真度,但與以前方法不同的是,Muse通過(guò)采樣過(guò)程線性地增加指導(dǎo)尺度t來(lái)減少多樣性的損失,使得early token可以在低引導(dǎo)或無(wú)引導(dǎo)的情況下更自由地被取樣,不過(guò)也增加了對(duì)later tokens條件提示的影響。
研究人員還利用這一機(jī)制,通過(guò)將無(wú)條件的logit lu替換為以negative prompt為條件的logit,促進(jìn)了生成圖像具有與postive prompt相關(guān)的特征。
8. 推理時(shí)迭代并行解碼
在提升模型推理時(shí)間效率的一個(gè)關(guān)鍵部分是使用并行解碼來(lái)預(yù)測(cè)單個(gè)前向通道中的多個(gè)輸出token,其中一個(gè)關(guān)鍵假設(shè)是馬爾科夫?qū)傩?,即許多token是有條件地獨(dú)立于給定的其他token的。
其中解碼是根據(jù)cosine schedule進(jìn)行的,選擇固定比例中最高置信度的掩碼進(jìn)行預(yù)測(cè),其中token在剩余的步中被設(shè)定為unmasked,并且適當(dāng)減少masked tokens。
根據(jù)上述過(guò)程,就可以在基本模型中只用24個(gè)解碼步(step)實(shí)現(xiàn)對(duì)256個(gè)token的推理,在超分辨率模型中用8個(gè)解碼步對(duì)4096個(gè)token進(jìn)行推理,相比之下,自回歸模型需要256或4096步,擴(kuò)散模型需要數(shù)百步。
雖然最近的一些研究包括progressive distillation、better ODE solver大大減少了擴(kuò)散模型的采樣步驟,但這些方法還沒(méi)有在大規(guī)模的文本到圖像生成中得到廣泛驗(yàn)證。
實(shí)驗(yàn)結(jié)果
研究人員以不同的參數(shù)量(從600M到3B),基于T5-XXL訓(xùn)練了一系列基礎(chǔ)Transformer模型。
生成圖像的質(zhì)量
實(shí)驗(yàn)中測(cè)試了Muse模型對(duì)于不同屬性的文本提示的能力,包括對(duì)cardinality的基本理解,對(duì)于非單數(shù)的物體,Muse并沒(méi)有多次生成相同的物體像素,而是增加了上下文的變化,使整個(gè)圖像更加真實(shí)。
例如,大象的大小和方向、酒瓶包裝紙的顏色以及網(wǎng)球的旋轉(zhuǎn)等等。
定量比較
研究人員在CC3M和COCO數(shù)據(jù)集上與其他研究方法進(jìn)行了實(shí)驗(yàn)對(duì)比,指標(biāo)包括衡量樣本質(zhì)量和多樣性的Frechet Inception Distance(FID),以及衡量圖像/文本對(duì)齊的CLIP得分。
實(shí)驗(yàn)結(jié)果證明了632M的Muse模型在CC3M上取得了SOTA結(jié)果,在FID得分方面得到了改善,同時(shí)也取得了最先進(jìn)的CLIP得分。
在MS-COCO數(shù)據(jù)集上,3B模型取得了7.88分的FID得分,略好于相似參數(shù)量的Parti-3B模型取得的8.1分。
聯(lián)系客服