LIama 3+Mamba強(qiáng)強(qiáng)聯(lián)手!蒸餾到線(xiàn)性RNN,推理速度提升1.6倍
把Llama 3蒸餾到Mamba,推理速度最高可提升1.6倍!
而且性能不減,甚至表現(xiàn)比原始模型還要優(yōu)異。
這是來(lái)自Together AI的新作,通過(guò)蒸餾將Transformer和Mamba模型結(jié)合到了一起,同時(shí)還為混合模型涉及了推理加速算法。
提出Mamba架構(gòu)的大神、FlashAttention作者Tri Dao,也參與了這一項(xiàng)目。
Together AI創(chuàng)始人兼CEO表示,Transformer和Mamba的混合,是未來(lái)大模型的一大發(fā)展方向。
將Transformer蒸餾進(jìn)Mamba
在蒸餾正式開(kāi)始之前,需要先進(jìn)行從Transformer到線(xiàn)性RNN的初始化。
作者觀察到,Transformer的注意力機(jī)制與RNN的計(jì)算之間存在一定的相似性。
因此可以將Transformer的注意力線(xiàn)性化,從而建立二者的聯(lián)系。
利用這種對(duì)應(yīng)關(guān)系,可以將預(yù)訓(xùn)練的Transformer模型的參數(shù)復(fù)制到Mamba模型中。
在完成參數(shù)初始化后,作者采用了一個(gè)三階段的蒸餾流程進(jìn)一步提升Mamba模型的性能,使其更好地學(xué)習(xí)Transformer的知識(shí)。
第一階段是基于偽標(biāo)簽的蒸餾——使用預(yù)訓(xùn)練的Transformer教師模型在無(wú)標(biāo)簽數(shù)據(jù)上生成偽標(biāo)簽,然后讓Mamba學(xué)生模型在這些偽標(biāo)簽上訓(xùn)練。
這一過(guò)程的損失函數(shù)結(jié)合了KL散度損失和交叉熵?fù)p失,分別用于模仿教師模型輸出分布以及偽標(biāo)簽的擬合。
第二階段是在指令數(shù)據(jù)集上進(jìn)行的監(jiān)督微調(diào),使用帶標(biāo)簽的指令數(shù)據(jù)集(如OpenHermes 2.5)進(jìn)行訓(xùn)練。
最后一個(gè)階段,是用人類(lèi)反饋數(shù)據(jù),通過(guò)基于獎(jiǎng)勵(lì)模型進(jìn)行優(yōu)化。
作者收集了人類(lèi)對(duì)模型輸出的反饋數(shù)據(jù),然后據(jù)此構(gòu)建一個(gè)獎(jiǎng)勵(lì)模型并使用 RL 算法(如 PPO)來(lái)優(yōu)化模型在該獎(jiǎng)勵(lì)模型下的表現(xiàn)。
在8塊80G A100 GPU上,每個(gè)混合模型的整個(gè)蒸餾過(guò)程,只需不到五天的時(shí)間。
通過(guò)以上的蒸餾過(guò)程,作者得到了Transformer-Mamba混合模型,之后又提出了Speculative Decoding(推測(cè)解碼)算法來(lái)加速推理過(guò)程。
混合模型推理加速算法
推測(cè)解碼算法的基本思想是使用一個(gè)輕量級(jí)的Draft模型來(lái)預(yù)測(cè)多個(gè)token,然后再用驗(yàn)證模型(Verifier)來(lái)驗(yàn)證這些預(yù)測(cè)。
這樣可以顯著提高解碼的并行性,加速生成過(guò)程。
Draft模型通常是一個(gè)小的Transformer,根據(jù)當(dāng)前的上下文預(yù)測(cè)出接下來(lái)的K個(gè)token。
對(duì)于預(yù)測(cè)出的K個(gè)token,Transformer層可以直接并行地處理這K個(gè)token,計(jì)算它們的隱狀態(tài);
Mamba層則需要按照順序依次處理每個(gè)token,首先計(jì)算當(dāng)前token的隱狀態(tài),并將其與之前的隱狀態(tài)進(jìn)行比較。
- 如果當(dāng)前token是正確的,則將其添加到已接受的序列中,并更新最新的隱狀態(tài)(但不保存中間狀態(tài))。
- 如果當(dāng)前token是錯(cuò)誤的,則停止處理后續(xù)token,并將最新的隱狀態(tài)回退到上一個(gè)已接受的token處。
如果序列中的所有K個(gè)token都被接受,則將它們添加到輸出序列中,并繼續(xù)預(yù)測(cè)下一組token。
如果有token被拒絕,則從第一個(gè)被拒絕的token處截?cái)囝A(yù)測(cè)序列,并返回初始步驟從該位置開(kāi)始重新預(yù)測(cè)。
Llama 3推理速度提升1.6倍
測(cè)試結(jié)果表明,混合模型在單論(AlpacaEval)和多輪(MT-Bench)聊天對(duì)話(huà)任務(wù)上與Llama-3相當(dāng)甚至更優(yōu)。
并且還對(duì)不同混合比例的模型表現(xiàn)進(jìn)行了測(cè)試,發(fā)現(xiàn)其中按照1:1比例混合的模型表現(xiàn)最佳。
在零樣本的通用 NLP 任務(wù)評(píng)測(cè)中,混合模型的平均成績(jī)優(yōu)于同等規(guī)模的RNN模型。
在少樣本的OpenLLM Leaderboard榜單上,混合模型的表現(xiàn)與最好的開(kāi)源RNN模型相當(dāng),并在GSM8K和CRUX任務(wù)上超過(guò)了對(duì)應(yīng)的Instruct模型。
除了模型性能,作者也對(duì)推測(cè)解碼算法帶來(lái)的加速效果進(jìn)行了測(cè)試。
首先測(cè)試的是純Mamba模型,結(jié)果在2.8B和7B的模型上,相比原來(lái)的解碼方式,推理速度提升了1.7-2.6倍。
進(jìn)一步地,作者在蒸餾的Zephyr和Llama混合模型上進(jìn)行了測(cè)試,結(jié)果Zephyr混合模型的推理速度提升了1.8倍以上,Llama混合模型也有1.6倍左右的加速。
論文地址:https://www.together.ai/blog/the-mamba-in-the-llama-distilling-and-accelerating-hybrid-models