B站文生視頻模型工程實(shí)踐
一、前言
近年來,AI 內(nèi)容生成(AIGC)領(lǐng)域的快速發(fā)展令人雀躍,OpenAI 在 2023 年初推出大型語言模型(LLM)GPT-4 受到了學(xué)術(shù)界和工業(yè)界的極大關(guān)注。OpenAI 隨后在 2024 年初推出文生視頻(T2V)模型Sora,能夠根據(jù)文本指令制作出具有現(xiàn)實(shí)風(fēng)格和富有想象力的場景視頻,更是展示了令人驚喜的“世界模擬器”能力。
B站作為UGC內(nèi)容豐富的視頻網(wǎng)站,在視頻生成模型領(lǐng)域有著天然數(shù)據(jù)優(yōu)勢和廣泛應(yīng)用場景。在此之前我們已經(jīng)有了一段時(shí)間的LLM模型訓(xùn)練經(jīng)驗(yàn),文生視頻模型結(jié)構(gòu)、語料以及訓(xùn)練過程有一定的差異性,本文重點(diǎn)介紹B站TTV團(tuán)隊(duì)在文生視頻模型上積極探索后的經(jīng)驗(yàn)及感悟。
二、TTV model
在OpenAI提供的公開信息中,Sora模型實(shí)際上是一個(gè)Diffusion Model+Transformer架構(gòu)。本文基于B站在生成式TTV在自研道路上的探索、結(jié)合行業(yè)進(jìn)展和工程實(shí)踐,先后嘗試了幾種TTV(Text to video,后面將簡稱為TTV)模型。文中重點(diǎn)介紹由colossal-ai發(fā)布的類Sora模型Open-Sora,以及由智譜AI發(fā)布的CogVideoX模型。
2.1 OpenSora
OpenSora的核心是Stdit(Spatial Temporal Diffusion Transformer)結(jié)構(gòu)。
DiT(Diffusion Transformer)模型是一種結(jié)合了去噪擴(kuò)散概率模型(DDPM)和Transformer架構(gòu)的擴(kuò)散模型,通過模擬數(shù)據(jù)的逐步去噪過程來生成新的文本,在此基礎(chǔ)上發(fā)展出了STDiT,STDiT模型就是一種使用Cross-Attention的DiT變體。
STDiT模型 的特點(diǎn)包括:
- 融合空間 - 時(shí)間注意力機(jī)制:結(jié)構(gòu)巧妙地串聯(lián)起二維空間注意力模塊和一維時(shí)間注意力模塊,能夠在捕捉視頻幀內(nèi)的空間特征的同時(shí),精準(zhǔn)模擬視頻幀與幀間的時(shí)序關(guān)聯(lián)。
- 交叉注意力集中模塊:緊隨在時(shí)間注意力模塊后,確保文本語義與生成視頻的深度對齊。
- 計(jì)算資源需求降低:相較于全注意力機(jī)制顯著減少了計(jì)算資源需求。
- 利用預(yù)訓(xùn)練權(quán)重:能夠更好地利用預(yù)訓(xùn)練好的圖像 DiT 權(quán)重遷移學(xué)習(xí)至視頻場景。
圖2-1 opensora模型架構(gòu)
如圖2-1所示,OpenSora的Transformer Block的部分包括了Spatial Attention(空間維度上的注意力計(jì)算)、Temporal Attention(時(shí)間維度上的注意力計(jì)算),以及通過Cross Attention將兩個(gè)維度信息進(jìn)行交叉計(jì)算,此外整個(gè)架構(gòu)還包括一個(gè)預(yù)訓(xùn)練好的視頻編碼器(VAE),一個(gè)文本編碼器(T5)。整個(gè)模型的訓(xùn)練階段如圖2-2所示,首先采用預(yù)訓(xùn)練好的 Variational Autoencoder (VAE) 的編碼器將視頻、圖片數(shù)據(jù)進(jìn)行壓縮,將對應(yīng)的文本描述通過Text Encoder進(jìn)行embedding化,然后在壓縮之后的隱層空間中與文本嵌入(text embedding)一起訓(xùn)練 STDiT擴(kuò)散模型。
圖2-2 opensora訓(xùn)練流程示意圖
Open-Sora 模型的訓(xùn)練采用多階段訓(xùn)練方法,包括大規(guī)模圖像預(yù)訓(xùn)練、大規(guī)模視頻預(yù)訓(xùn)練和高質(zhì)量視頻數(shù)據(jù)微調(diào)等。在數(shù)據(jù)收集和預(yù)處理方面,涉及到對視頻美學(xué)、動(dòng)態(tài)性、運(yùn)鏡、字幕等信息的處理,構(gòu)建高質(zhì)量的訓(xùn)練數(shù)據(jù)集。
其功能特點(diǎn)有:支持視頻數(shù)據(jù)預(yù)處理、加速訓(xùn)練、推理等全套流程;提供視頻切割和字幕工具,支持剪輯和 T5 文本調(diào)節(jié);實(shí)現(xiàn)可變長寬比、可變分辨率和可變時(shí)長等功能;能夠生成各種風(fēng)格的視頻內(nèi)容,支持多種生成任務(wù)等。
2.2 CogVideoX
CogVideoX【8】使用了由智譜提出的3D Causal VAE以及專家Transformer模塊,訓(xùn)練流程與STDIT基本一致,都是通過VAE與Text Encoder將視頻數(shù)據(jù)與文本描述進(jìn)行embedding化后,輸入Transformer Block進(jìn)行訓(xùn)練。與STDiT不同的是,視頻embedding與文本embedding會直接拼接起來做 Full Attention 計(jì)算,以便更好地對齊視覺和語義信息。但是這兩種模態(tài)的特征空間差異顯著,它們的嵌入甚至可能具有不同的數(shù)值范圍。為了在同一序列中更好地處理它們,Transformer模塊采用了專家自適應(yīng)層歸一化(Expert Adaptive Layernorm)來獨(dú)立處理每種模態(tài),如下圖所示。處理后的模態(tài)信息會繼續(xù)拼接到一起進(jìn)行3D Full Attention進(jìn)行計(jì)算,在Attention計(jì)算前使用了 3D-RoPE將相對位置信息添加到視覺模態(tài)上。
圖2-3 cogvideo模型結(jié)構(gòu)
三、工程實(shí)踐
3.1 數(shù)據(jù)存儲與加載
TTV模型訓(xùn)練相較于LLM模型,訓(xùn)練數(shù)據(jù)集有較大差異性。處理后的訓(xùn)練數(shù)據(jù)包括視頻切片、圖片和對應(yīng)的文字描述。相比于LLM語料庫,TTV單個(gè)數(shù)據(jù)文件較小、但整體數(shù)據(jù)量大,給訓(xùn)練造成了一些不便。
此前的LLM模型訓(xùn)練數(shù)據(jù)均以5G一個(gè)分片整塊存儲在boss上,每個(gè)訓(xùn)練epoch開始前會拉取一個(gè)5G的分片到訓(xùn)練機(jī)器上,IO效率較高。但視頻切片的數(shù)據(jù)平均大小在1M左右,如果在一個(gè)epoch的開始拉取所需的訓(xùn)練數(shù)據(jù),boss拉取小文件的效率無法滿足我們的訓(xùn)練需求,會造成GPU資源閑置。
結(jié)合上述特點(diǎn),我們和B站基礎(chǔ)架構(gòu)存儲團(tuán)隊(duì)討論過多個(gè)方案:
方案一:HDFS+文件打包
相較于boss,HDFS能夠提供更大的數(shù)據(jù)訪問的帶寬上限。HDFS 的負(fù)載通常較為平穩(wěn),但也存在抖動(dòng)場景,且DN和NN受任務(wù)影響明顯,當(dāng)前HDFS數(shù)據(jù)99%為2副本存儲,當(dāng)有2臺DN受到任務(wù)影響時(shí),數(shù)據(jù)讀寫性能會有明顯下降。此外因?yàn)樾∥募奶攸c(diǎn),拉取效率較難滿足訓(xùn)練要求,基于這個(gè)特點(diǎn),我們考慮將多個(gè)小文件打包在一起,形成一個(gè)大chunk file,加速HDFS拉取效率,此外在訓(xùn)練的腳本中,加入一個(gè)dataset reader,解析相應(yīng)的chunk file,重新還原成小視頻、圖片文件。顯然這個(gè)辦法可以解決文件拉取效率問題,并且還節(jié)約了一定的存儲空間和文件句柄,對存儲媒介友好。但是存在對代碼的侵入,以及chunk內(nèi)容調(diào)整不靈活的問題,我們短期沒有采用。
方案二:HDFS->backend & Alluxio->frontend
Alluxio 是一個(gè)開源的分布式內(nèi)存文件系統(tǒng),能大幅提升數(shù)據(jù)訪問性能,加速大數(shù)據(jù)分析和處理任務(wù)。在機(jī)器學(xué)習(xí)訓(xùn)練中經(jīng)常使用alluxio,快速提供數(shù)據(jù)給模型。并且他能靈活適配,可與多種存儲系統(tǒng)(如 HDFS、S3、Azure Blob Storage 等)和計(jì)算框架(如 Hadoop、Spark、Flink 等)集成,適應(yīng)不同的大數(shù)據(jù)架構(gòu)。經(jīng)基架同學(xué)一同努力,最終我們采用了Alluxion Fused+HDFS backend方案,加載數(shù)據(jù)的方式近似于訪問本地文件系統(tǒng),訓(xùn)練框架無感。alluxio定期同步HDFS上的數(shù)據(jù)到本地,能較好的滿足我們海量小文件的讀寫以及快速上線的訴求。
3.2 數(shù)據(jù)預(yù)處理優(yōu)化
TTV數(shù)據(jù)的預(yù)處理過程更為繁瑣,除了對視頻的文字描述進(jìn)行tokenizer外,還需要對視頻數(shù)據(jù)進(jìn)行抽幀、歸一化和vae編碼,尤其是vae編碼這一步在單epoch訓(xùn)練中會占到總時(shí)長的18%,并且極為消耗顯存。為了優(yōu)化這部分的速度,我們做了兩部優(yōu)化。
- 數(shù)據(jù)并行
參考訓(xùn)練中的數(shù)據(jù)并行策略(Data Parallel),在訓(xùn)練任務(wù)啟動(dòng)的初始階段,初始化一個(gè)包含全部顯卡的全局通信組,將一個(gè)epoch訓(xùn)練中最耗算力的編碼部分拆分到所有顯卡上,待每張卡處理完成后再將所有的結(jié)果gather到0號卡上進(jìn)行聚合和保存,后續(xù)的訓(xùn)練階段只需讀取處理好的embedding數(shù)據(jù)即可。
- VAE離線化
考慮到一份訓(xùn)練數(shù)據(jù)可能會在后續(xù)下訓(xùn)練任務(wù)中被反復(fù)使用,并且數(shù)據(jù)處理非常消耗算力和顯存,因此在第一步優(yōu)化的基礎(chǔ)上,我們會把聚合好的訓(xùn)練結(jié)果持久化,后續(xù)直接從HDFS\Alluxio中讀取,整體顯存和訓(xùn)練時(shí)間都有較大優(yōu)化。
上述過程如圖3-1所示:
圖3-1 數(shù)據(jù)讀取、預(yù)處理示意
3.3 模型并行優(yōu)化
雖然TTV模型通常參數(shù)量較小,但視頻訓(xùn)練數(shù)據(jù)更為復(fù)雜,序列長度較長。因此TTV模型訓(xùn)練過程中的激活內(nèi)存成本極高,導(dǎo)致訓(xùn)練速度明顯更慢,且限制了訓(xùn)練數(shù)據(jù)的清晰度與抽幀率的提高。
在Transformer Block中,激活顯存占用的大頭為Attention的計(jì)算,在不使用Flash-attention的情況下,attention激活占用與序列長度的平方成正比,而使用了時(shí)空Cross-Attention設(shè)計(jì)的STDIT模型,訓(xùn)練過程中更是包含了3次attention的計(jì)算,使得激活占用更是大大增加。針對Transformer訓(xùn)練過程中,序列長度過長導(dǎo)致的激活占用高的問題,業(yè)界已有一些序列并行方案提被出,序列并行是一種專門為跨多個(gè)設(shè)備分發(fā)長序列和激活而設(shè)計(jì)的技術(shù),以下四種是主要的序列并行技術(shù), Ring-Attention、Megatron-SP、DeepSpeed-Ulysses和Dyncamic SP,首先將介紹下各方案的設(shè)計(jì)。
業(yè)界方案
- Ring Attention
Ring Attention 的核心思想是將輸入序列分割成多個(gè)塊,并將這些塊分布在多個(gè)計(jì)算設(shè)備上進(jìn)行并行處理,通過使用 Online Softmax 機(jī)制,在不保留完整序列長度的情況下計(jì)算注意力分?jǐn)?shù)。每個(gè)設(shè)備首先對自己的數(shù)據(jù)塊進(jìn)行局部的自注意力計(jì)算,然后將關(guān)鍵信息(key-value pair)傳遞給下一個(gè)設(shè)備,通過一個(gè)環(huán)形的數(shù)據(jù)傳輸策略,實(shí)現(xiàn)在不增加單個(gè)設(shè)備內(nèi)存負(fù)擔(dān)的情況下處理超長序列。然而,環(huán)形注意力對 P2P 通信的依賴在高延遲環(huán)境中可能效率較低。
- Megatron-SP
Megatron框架中的的序列并行,是在張量并行(Tensor Parallelism)的基礎(chǔ)上,將Transformer Block中的LayerNorm以及Dropout層的輸入按Sequence Length維度進(jìn)行了切分,使得各個(gè)設(shè)備上面只需要做一部分的Dropout和LayerNorm。雖然減少了每張卡上的激活占用,但在通信過程中引入了額外的all-gather和reduce-scater操作。因在設(shè)計(jì)上依托切分注意力頭來實(shí)現(xiàn)并行,使用上也會受到注意力頭數(shù)量的限制。
圖3-2 megatron-sp
- DeepSpeed-Ulysses
DeepSpeed-Ulysses 【7】引入了一種創(chuàng)新的方法,通過利用全對全(all-to-all)集體通信來訓(xùn)練長序列。在處理長序列時(shí),它將查詢(query)、鍵(key)和值(value)矩陣在注意力頭之間進(jìn)行切分,但保留原始注意力計(jì)算結(jié)構(gòu)。這個(gè)過程通過兩組全對全通信來實(shí)現(xiàn),這兩組通信在序列分割和注意力頭分割之間交替進(jìn)行。這樣的設(shè)計(jì)使得在處理長序列時(shí),能夠在保持計(jì)算結(jié)構(gòu)的同時(shí),有效地在多個(gè) GPU 之間分配數(shù)據(jù),減少通信開銷。
圖3-3 deepspeed-ulysses
- Dyncamic SP
與以上三種針對單序列維度內(nèi)的并行性方案不同,DSP(Dynamic Sequence Parallel)【6】方案是針對多維序列的并行問題所設(shè)計(jì)。在多維Transformer模塊中,每個(gè)序列維度的計(jì)算其實(shí)是獨(dú)立進(jìn)行的,因此可以分別在不同緯度切分和計(jì)算,僅在維度計(jì)算交換的時(shí)間點(diǎn)使用高效的全對全操作(all-to-all)來為中間序列切換并行維度并重新進(jìn)行動(dòng)態(tài)切分。這種方法使 DSP 能夠獨(dú)立于模塊內(nèi)的計(jì)算邏輯外,消除了模塊內(nèi)許多不必要的通信。
圖3-4 dsp
相關(guān)實(shí)踐
- 基于OpenSora SP實(shí)現(xiàn)
OpenSora模型使用了時(shí)空交叉注意力機(jī)制,會分別計(jì)算視頻空間維度的Attention、時(shí)間維度的Attention,并最后和文本embedding進(jìn)行交叉Attention計(jì)算,其設(shè)計(jì)更適合使用DSP方案。具體實(shí)現(xiàn)上,我們會在空間注意力計(jì)算前,先在空間維度進(jìn)行切分,SP并行組內(nèi)的各rank分別計(jì)算不同的seq片段。隨后在時(shí)間維度注意力計(jì)算前,進(jìn)行一次all-to-all通信進(jìn)行同步,并交換切分緯度到時(shí)間維度上。因?yàn)榕c文本embedding的交叉注意力計(jì)算,只與空間信息有關(guān),因此可在交叉Attention計(jì)算后再進(jìn)行一次all-to-all通信來同步計(jì)算結(jié)果。
- 基于CogvideoX的SP實(shí)現(xiàn)
考慮到CogVideoX模型雖然使用了視頻信息與文字信息,但兩種embedding是在單一維度進(jìn)行拼接,并進(jìn)行全局Attention計(jì)算的,因此本身屬于單維Transformer。對于單維Transformer,我們選擇實(shí)現(xiàn)了DS-UIysses方案,在實(shí)現(xiàn)上:
a. 先在Transformer Block之前,沿sequence維度進(jìn)行切分,但只對視頻的hidden state的seq維進(jìn)行切分。文字embedding的部分不切分,并與切分后的視頻embedding拼在一起。
b. 相對位置編碼是需要對全序列長度進(jìn)行計(jì)算的,因此在每個(gè)3D-Rope計(jì)算前,進(jìn)行一次all-to-all通信,回收sequence維度的信息。
c. 因?yàn)槊總€(gè)seq段都拼接了文字embedding,因此首先需要通過remove_extra_encoder方法,移除每段seq冗余的text embedding。在計(jì)算完Attention后,進(jìn)行一次all-to-all通信,在attention head維度回收信息,并回到sequence切分的狀態(tài)。在MLP Block計(jì)算前,通過 add_text_encoder補(bǔ)上每段seq都有的文字embedding部分。
d. 在所有的Transformer Block后,進(jìn)行g(shù)ather sequence操作,合并SP組各組上的計(jì)算結(jié)果。
通過使用SP并行策略,CogVideoX由單機(jī)16卡只能訓(xùn)練45幀1080p的數(shù)據(jù),提升至可訓(xùn)練221幀1080p的訓(xùn)練數(shù)據(jù)。
圖3-5 基于CogVideoX的SP實(shí)現(xiàn)示意圖
四、文生視頻模型在NPU架構(gòu)上的工程實(shí)踐
目前我們的訓(xùn)練算力構(gòu)成為GPU + NPU,但GPU與NPU在芯片設(shè)計(jì)上差異較大,軟件棧和生態(tài)也存在較大區(qū)別,因此需要做相應(yīng)的適配和優(yōu)化才能充分利用NPU資源進(jìn)行訓(xùn)練。
4.1 基礎(chǔ)適配
- 模型適配
目的是讓模型能夠在npu環(huán)境下開箱啟用(訓(xùn)練),保證pipeline可運(yùn)行。其主要工作為檢查依賴cuda硬件及精度在npu下受限的算子,查詢?nèi)雲(yún)⑵ヅ涞膎pu-wrapped算子,進(jìn)行替換與輕量的代碼適配,例如替換T5模塊中的LayerNorm為NPURmsNorm、在nn.Conv3d算子中顯式指定使用torch.float16精度等。
- 框架移植
我們的訓(xùn)練框架中部分模塊開采用megatron-core作為加速手段,需要移植到對應(yīng)NPU的版本,主要參考華為Megatron-NPU倉庫的范例和實(shí)現(xiàn),進(jìn)行移植。
- 精度驗(yàn)證
可訓(xùn)練后,需要進(jìn)行精度的驗(yàn)證。精度驗(yàn)證需注意保證GPU版本與NPU版本的第三方依賴保持一致,并固定代碼中隨機(jī)的部分,例如隨機(jī)數(shù)設(shè)置、數(shù)據(jù)的抽樣、Vision encoder中的加躁信息等??梢越柚A為的精度驗(yàn)證工具,dump主要算子的輸入輸出,并逐步進(jìn)行API級別、模塊級別與整網(wǎng)級別的精度驗(yàn)證。
不過值得注意的是,因?yàn)榈讓訉?shí)現(xiàn)有所區(qū)別,配置不同,以及fusion的使用與否等等,loss是無法完全對齊的。下面列出幾個(gè)常見的,會引起loss差異的參數(shù)以供參考。
- transformer_impl: local -> transformer_engine
- attention_softmax_in_fp32: False -> True
- apply_rope_fusion: False -> True
- rotary_fusion: False -> True
- swiglu_fusion: False -> True
案例
基于內(nèi)部基座版本:global batch size 160,Run 1曲線代表NPU,Run 2曲線代表GPU。
2940-steps:
Loss diff max = 0.46540455520153046, diff mean = 0.006955942064507858
圖片
1500-steps:
Loss diff max = 0.2761184275150299, diff mean = 0.008173752022286256
圖片
對比趨勢和diff,基本一致,可以確認(rèn)模型參數(shù)基本可用,更細(xì)粒度則需要做到分層精度對比以及生成視頻benchmark對比。
基于CogVideo-5B基座版本:8卡,25幀,480P
圖片
4.2 優(yōu)化
- profiler性能分析調(diào)優(yōu)
A.后向耗時(shí)問題定位
下圖profiler中可以發(fā)現(xiàn),在后向過程中執(zhí)行了一次前向,占用接近20%,但其實(shí)我們并未顯式地配置重計(jì)算策略。
后續(xù)分析代碼發(fā)現(xiàn),訓(xùn)練中會默認(rèn)使用torch.utils.checkpoint對全層數(shù)的transformer進(jìn)行重計(jì)算,但基于如下幾點(diǎn)原因可以對其優(yōu)化:
a.顯存仍未用盡,有tradeoff的余地
b.需要重計(jì)算的層數(shù)應(yīng)該可控
c.重計(jì)算作用在一個(gè)融合算子上,其中涉及MLP,GN,QKV計(jì)算,attention計(jì)算等,但其實(shí)只有attention計(jì)算使用該策略的價(jià)值最高
經(jīng)過一系列的代碼調(diào)整后,最終訓(xùn)練速度提升約12%。
B.融合算子替換
根據(jù)profiling分析,紅框中g(shù)elu算?實(shí)際執(zhí)?時(shí)是以多個(gè)?算?拼接的形式下發(fā)和執(zhí)?的。
圖片
可以使?融合算?F.gelu 進(jìn)?替換,優(yōu)化下發(fā)和執(zhí)?。跟據(jù)profiling中的API調(diào)???梢远ㄎ坏皆撍?主要出現(xiàn)在T5模塊和megatron中。
C.使用連續(xù)張量減少切分操作
圖片
圖片
如上圖profiling所示,耗時(shí)占比最大的3個(gè)op中,第二第三都是計(jì)算密集型的,可以從策略或算法的角度優(yōu)化。第一位StridedSlice是訪存密集型,改變使用StridedSlice 的上層算子的輸入tensor的存儲方式(使用torch.contiguous)即可連續(xù)分配,優(yōu)化原理說明如下:
torch.contiguous操作將非連續(xù)張量轉(zhuǎn)換為連續(xù)張量,后續(xù)的切片和訪問操作大幅簡化,甚至避免StridedSlice的調(diào)用。
主要體現(xiàn)在三種場景:
其一,連續(xù)張量的切片不需要依賴復(fù)雜的stride信息。硬件可以更高效地預(yù)取數(shù)據(jù),提高計(jì)算速度。
其二,非連續(xù)張量的 StridedSlice 需要?jiǎng)討B(tài)計(jì)算目標(biāo)地址,頻繁調(diào)用可能帶來性能瓶頸。連續(xù)張量的切片是直接基于線性偏移量完成的,減少了計(jì)算需求。
其三,某些操作(如 view)要求張量是連續(xù)的。如果張量已經(jīng)是連續(xù)的,相關(guān)操作無需隱式調(diào)用 StridedSlice 或創(chuàng)建新的張量拷貝,直接提升性能。
在我們的訓(xùn)練中得到加速的主要是第三種場景。上右profiling是優(yōu)化后的結(jié)果。
- FlashAttention(NPU)
在torch2.0之前的時(shí)代,業(yè)界都會使用樸素的standard_attention進(jìn)行注意力機(jī)制的計(jì)算,但當(dāng)其attention_mask為精度f32時(shí)會引起巨大的顯存占用及 NaN-bug。后續(xù)torch推出了一個(gè)改良算子scaled_dot_product_attention(sdpa),進(jìn)行算子融合,優(yōu)化內(nèi)存使用,適合長序列,內(nèi)置支持 Dropout 和 Causal Mask?,F(xiàn)今,主流的大模型結(jié)構(gòu)中都會采用FlashAttention,是一種sdpa的實(shí)現(xiàn),相較于原版scaled_dot_product_attention,其對QKT進(jìn)行了分塊處理,避 免存儲完整的注意力矩陣,可處理更長的序列。
圖片
在遷移至NPU架構(gòu)后,我們逐步將原有的注意力機(jī)制代碼,改進(jìn)為基于 NPU底層實(shí)現(xiàn)的FlashAttention,其中涉及的抽象、封裝,引入必要的設(shè)計(jì)模式等。以下僅對flashAttention作簡單介紹。
FlashAttention本質(zhì)是算子和數(shù)據(jù)的融合:
- 將多個(gè)算子合并為一個(gè),簡化計(jì)算過程,減少計(jì)算量,提高計(jì)算效率。
- 將多個(gè)中間結(jié)果合并為一個(gè),減少內(nèi)存占用,提高內(nèi)存利用率;減少不同算子之間的傳輸,提高數(shù)據(jù)處理效率。
- 簡化代碼實(shí)現(xiàn),減少代碼量,提高代碼可讀性和可維護(hù)性。
FlashAttention實(shí)現(xiàn)原理【1】:三基石
- Tiling切片:利用高速SRAM代替內(nèi)存,但SRAM內(nèi)存小,無法一次性完成所有數(shù)據(jù)的注意力計(jì)算,需要進(jìn)行分塊計(jì)算,對應(yīng)上文中的QKT分塊。
- 重計(jì)算:放棄中間結(jié)果寫回,需要使用時(shí)重新計(jì)算,用計(jì)算換訪存。
- Kernel Fusion:將多個(gè)操作融合為一個(gè)操作,基于Tiling利用一個(gè)kernel完成整個(gè)計(jì)算,對應(yīng)上文中的算子融合。
以下是前向、反向過程的公式描述,具體細(xì)節(jié)不在此討論。
前向Forward: FlashAttentionScore[2]
后向Backward: FlashAttentionScoreGrad[3]
- 虛擬內(nèi)存特性:expandable_segments【4】
一般情況下,由PyTorch自己管理虛擬地址與物理地址映射,降低內(nèi)存碎片。其原理大致如下:對于大于 2MB 的分配,分配器會調(diào)用 aclrtMalloc,以獲取與用戶請求大小完全相同的內(nèi)存分配。后續(xù)計(jì)算中,如果這些分配中的某些部分空閑,它們可以被重新用于其他請求。這種方式在程序多次請求相同大小或者是該大小整數(shù)倍的內(nèi)存時(shí)效果很好。許多深度學(xué)習(xí)模型的行為符合這一特點(diǎn)。
然而,有一種常見的例外情況是,批次大小在每次迭代中會略有變化,例如在批量推理中。當(dāng)程序最初以批次大小N運(yùn)行時(shí),會為該大小進(jìn)行合適的內(nèi)存分配。如果后續(xù)運(yùn)行的批次大小變?yōu)镹?1,現(xiàn)有的內(nèi)存分配仍然足夠使用。然而,如果批次大小變?yōu)镹+1,則需要進(jìn)行新的內(nèi)存分配,這些后續(xù)分配并非所有張量的大小都相同。一些張量的大小可能是(N+1)×A,而另一些可能是(N+1)×A×B,其中A和B是模型中與批次無關(guān)的維度。由于分配器會在現(xiàn)有分配足夠大時(shí)重用它們,因此某些(N+1)×A 的分配可能會勉強(qiáng)適應(yīng)已經(jīng)存在的N×B×A 段,盡管不完全匹配。當(dāng)模型運(yùn)行時(shí),這些段會被部分填充,導(dǎo)致在段末尾留下不可用的空閑內(nèi)存切片。最終,分配器可能需要調(diào)用 aclrtMalloc 為新的(N+1)×A×B 段分配內(nèi)存。如果沒有足夠的內(nèi)存,就會拋出異常,結(jié)束程序。對于有 50 層以上的模型,這種模式可能會重復(fù) 50 多次,從而產(chǎn)生許多小的內(nèi)存碎片。
通過分析訓(xùn)練時(shí)采集的profiling文件,發(fā)現(xiàn)我們研發(fā)中常見的顯存瓶頸符合上述場景描述:
a.動(dòng)態(tài)shape場景,比如VAE中的升降采樣,shape隨step增加而增大,從而導(dǎo)致顯存塊不能復(fù)用,碎片上升
b.transformer一般由多層組成,其帶來的大量激活數(shù)據(jù)與激活數(shù)據(jù)size不統(tǒng)一,會給顯存復(fù)用帶來困難,不僅會影響性能,嚴(yán)重時(shí)導(dǎo)致OOM
采用NPU自研的“內(nèi)存池?cái)U(kuò)展段”功能,有助于優(yōu)化上述場景。
“內(nèi)存池?cái)U(kuò)展段”在最初創(chuàng)建一個(gè)內(nèi)存段后,可按需擴(kuò)展其大小。與每次分配都創(chuàng)建一個(gè)新的內(nèi)存段不同,它嘗試為每個(gè)流(stream)創(chuàng)建一個(gè)內(nèi)存段,并根據(jù)需求動(dòng)態(tài)增長。當(dāng)程序運(yùn)行到批次大小N+1的情況時(shí),這些分配會很好地排列到單個(gè)大型內(nèi)存段中,直到段被填滿。然后,分配器會請求更多內(nèi)存,并將其追加到該段的末尾。這種方式不會產(chǎn)生太多無法使用的內(nèi)存碎片,因此更有可能成功找到所需的內(nèi)存。
五、后續(xù)工作方向
1.引入流水線并行(Pipeline Parallelism)
如上文所述,現(xiàn)在已采用的優(yōu)化方案多集中在tensor計(jì)算的某些維度,比如sp/cp針對的是tensor的序列維度,flash_attention針對的tensor的特征維度。在業(yè)界主流優(yōu)化方案中,以layer或模塊作為優(yōu)化對象的Pipeline Parallelism,PP也是值得探索的方向。
優(yōu)點(diǎn):
- 將模型分成多個(gè)階段,每個(gè)階段只在一個(gè)NPU上運(yùn)行,每個(gè)NPU只需要存儲它負(fù)責(zé)的部分模型,而不是整個(gè)模型。這大大減少了顯存消耗。類似FSDP與ZERO機(jī)制。
- 模型并行MP以及張量并行TP會導(dǎo)致頻繁的跨設(shè)備通信,而PP通過流水線操作僅傳遞每個(gè)階段最終結(jié)果,減少了通信負(fù)擔(dān)。
- 擴(kuò)展性強(qiáng),比如transformer這類的堆疊結(jié)構(gòu)
挑戰(zhàn):
- 流水線填充與同步延遲:前向傳播和反向傳播之間存在同步延遲(Pipeline Bubble)。增加batch size是常見的優(yōu)化策略,但可能導(dǎo)致內(nèi)存壓力增大。
- 負(fù)載均衡:如果模型切分不均勻,會導(dǎo)致某些NPU過載,而其他NPU閑置。
- 實(shí)現(xiàn)復(fù)雜性:需要在代碼層面將模型和數(shù)據(jù)流切分為多個(gè)階段,并設(shè)計(jì)高效的通信方案。
2.突破torch.nn.GroupNorm的限制
在處理VAE encoding時(shí),雖然我們在序列維度進(jìn)行了切分,但在處理groupnorm時(shí),常規(guī)做法需要統(tǒng)計(jì)全序列的數(shù)值才可計(jì)算。但全序列g(shù)roupnorm,計(jì)算時(shí)的激活tensor過大,單份顯存~26.9G = 128 * 9 * 720 * 1088 * 16 * 2 / 1024 / 1024 / 1024,如果后續(xù)增加視頻幀數(shù)或是視頻分辨率,則將對NPU的64G顯存上限是一個(gè)不小的挑戰(zhàn)。
一種方案是對group維度進(jìn)行切分,有2個(gè)關(guān)注點(diǎn):
- 序列維度合并后才能進(jìn)行g(shù)roup維度的切分,在此過程中是否會發(fā)生用于存儲中間結(jié)果的tensor無法申請到顯存的情況
- 對預(yù)訓(xùn)練的groupnorm權(quán)重進(jìn)行提取,并新建對應(yīng)的batchnorm算子,計(jì)算后,還需還原序列維度的切分狀態(tài)
另一種方案,全序列g(shù)roupnorm計(jì)算需傳輸高維張量。但從原理上,只需要獲得序列切分后的統(tǒng)計(jì)值(低維)即可,單卡獲取全量統(tǒng)計(jì)值后,可以處理自身的序列切分?jǐn)?shù)據(jù),其中有3個(gè)難點(diǎn):
- 本質(zhì)上是使用理論公式對groupnorm進(jìn)行函數(shù)級別的分拆實(shí)現(xiàn),涉及eps,有偏無偏參數(shù),統(tǒng)計(jì)值精度等細(xì)節(jié)對齊,以保證與torch.nn.GroupNorm的誤差可接受
- 因?yàn)槭且粋€(gè)FusedOp-reverse的過程,會造成速度的損失,需考量可接受度
- 與第一種方案類似,需對預(yù)訓(xùn)練的groupnorm權(quán)重進(jìn)行提?。涣硗庠谀承﹫鼍跋?,還需要實(shí)現(xiàn)backward代碼,實(shí)現(xiàn)難度較高
3.分層ZERO3
Deepspeed 的 zero-3:是一種用于深度學(xué)習(xí)優(yōu)化的技術(shù)。在分布式訓(xùn)練框架下,zero-3 將訓(xùn)練狀態(tài)(包括權(quán)重、梯度和優(yōu)化器狀態(tài))分布到不同的顯卡上,以優(yōu)化顯存利用。與 zero-2 相比,它還對模型參數(shù)進(jìn)行了分區(qū),顯存減少幅度與使用的 GPU 數(shù)量成正比。在TTV這種顯存為瓶頸的場景,可以支持更大參數(shù)量的模型,提高batch的大小,優(yōu)化訓(xùn)練效率。
參考文獻(xiàn)
- https://arxiv.org/pdf/2205.14135
- https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FlashAttentionScore.md
- https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FlashAttentionScoreGrad.md#https://gitee.com/link?target=https%3A%2F%2Fcreativecommons.org%2Flicenses%2Fby%2F4.0%2Flegalcode
- https://gitee.com/ascend/pytorch/blob/master/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp#L240
- https://arxiv.org/abs/2205.05198
- https://arxiv.org/abs/2309.14509
- https://arxiv.org/abs/2403.10266
- https://www.alphaxiv.org/abs/2408.06072v1