LLM 分布式訓(xùn)練六大關(guān)鍵技術(shù)介紹 原創(chuàng) 精華
編者按: 本文聚焦于分布式去中心化神經(jīng)網(wǎng)絡(luò)訓(xùn)練技術(shù),作者系統(tǒng)闡述了在大規(guī)模模型訓(xùn)練中提高硬件使用效率的創(chuàng)新方法。
文章重點(diǎn)闡述了六種關(guān)鍵的分布式訓(xùn)練技術(shù):
- 數(shù)據(jù)并行訓(xùn)練:通過將數(shù)據(jù) mini-batches 分散到多個(gè) workers,實(shí)現(xiàn)并行梯度計(jì)算和高效訓(xùn)練。
- Butterfly All-Reduce:通過創(chuàng)新的數(shù)據(jù)分割和匯總方法,有效降低通信成本。
- Gossip-Based Averaging:去中心化的通信策略,提高系統(tǒng)的容錯(cuò)性和可擴(kuò)展性。
- Moshpit Gradient Descent:允許 workers 在小型獨(dú)立組內(nèi)進(jìn)行梯度平均,增強(qiáng)訓(xùn)練的容錯(cuò)能力。
- DiLoCo:創(chuàng)新的內(nèi)外優(yōu)化算法,結(jié)合局部和全局參數(shù)更新,平衡收斂速度和系統(tǒng)性能。
- SWARM:引入動(dòng)態(tài)任務(wù)分配和容錯(cuò)機(jī)制,優(yōu)化異構(gòu)硬件環(huán)境下的資源配置。
作者 | Robert Lange
編譯 | 岳揚(yáng)
隨著人工智能技術(shù)的發(fā)展進(jìn)步,訓(xùn)練大規(guī)模神經(jīng)網(wǎng)絡(luò)(包括大語言模型)變得越來越重要。這些模型的規(guī)模和復(fù)雜度不斷提升,不僅增加了訓(xùn)練的成本和能耗,也迫切要求我們提高硬件使用效率。為了應(yīng)對(duì)這些挑戰(zhàn),研究人員和工程師們正在探索分布式去中心化訓(xùn)練方法。本文將探討多種分布式訓(xùn)練技術(shù),例如數(shù)據(jù)并行訓(xùn)練方法和 Gossip-Based Averaging 方法,展示這些技術(shù)如何在滿足該領(lǐng)域不斷增長的需求的同時(shí)優(yōu)化模型訓(xùn)練效率。
一幅以簡約日式風(fēng)格繪制的GPU集群圖,圖中加入了很多小型 GPU(由 OpenAI 的 Dallé-3 API 生成)
01 數(shù)據(jù)并行訓(xùn)練技術(shù)、全歸約操作與節(jié)點(diǎn)同步
數(shù)據(jù)并行訓(xùn)練技術(shù)通過將數(shù)據(jù)的 mini-batches 分散到多個(gè)工作節(jié)點(diǎn)(workers)上,實(shí)現(xiàn)了高效的訓(xùn)練。這種方法不僅加快了訓(xùn)練進(jìn)程,因?yàn)槎鄠€(gè) workers 可以并行計(jì)算梯度,而且還使得我們可以處理比單個(gè)設(shè)備更大的 batch sizes。為了保持所有 workers 之間的模型更新同步,我們采用了全歸約操作。該操作會(huì)將所有 workers 的梯度匯總并求平均值,然后統(tǒng)一更新模型,確保整個(gè)分布式系統(tǒng)中的模型保持一致。
以下是用 PyTorch 在 Python 中展示這一過程的一個(gè)簡單示例:
全歸約操作之外,還有一種方法是使用參數(shù)服務(wù)器(parameter server)。在這種架構(gòu)中,中央服務(wù)器負(fù)責(zé)收集梯度信息并監(jiān)控優(yōu)化器的狀態(tài)。雖然這樣做可以簡化同步過程,但同時(shí)也存在單點(diǎn)故障的風(fēng)險(xiǎn),并有可能成為系統(tǒng)性能的瓶頸。
分布式訓(xùn)練中,Hogwild(Recht et al., 2011)[1]是另一項(xiàng)著名的技術(shù)。它采用異步更新模型參數(shù)的方法,無需所有計(jì)算節(jié)點(diǎn)同步即可進(jìn)行。這種方法不僅適用于監(jiān)督學(xué)習(xí),也適用于強(qiáng)化學(xué)習(xí)(RL)場景,如異步演員-評(píng)論家算法(A3C, Mnih et al., 2016)[2]。在 A3C 中,多個(gè)智能體可以同時(shí)與環(huán)境互動(dòng),并基于各自的經(jīng)驗(yàn)異步更新同一個(gè)模型。這樣做不僅提高了資源的使用效率,還能通過多個(gè)智能體的不同經(jīng)驗(yàn)加快收斂速度,從而提高在復(fù)雜環(huán)境中的性能。
除了數(shù)據(jù)并行訓(xùn)練方法,還有模型并行和管道并行等其他并行訓(xùn)練方法(詳見 Llian Weng 的博客[3])。模型并行是將模型分割到多個(gè)計(jì)算設(shè)備上,使得模型的不同部分可以同時(shí)處理,這對(duì)于那些單個(gè)設(shè)備無法承載的超大型模型尤其有用。而管道并行則是將模型分為幾個(gè)階段,各個(gè) mini-batches 數(shù)據(jù)依次通過這些階段進(jìn)行處理,這樣做可以實(shí)現(xiàn)計(jì)算與通信的并行,從而提高訓(xùn)練的整體效率和吞吐量。這些技術(shù)互為補(bǔ)充,共同優(yōu)化了大規(guī)模訓(xùn)練場景下的資源利用。
02 Butterfly All-Reduce
Butterfly All-Reduce(Zhao和Canny,2013)技術(shù)有效地解決了傳統(tǒng)全歸約方法所面臨的挑戰(zhàn)。在這種技術(shù)中,每個(gè)參與的節(jié)點(diǎn)(共 N 個(gè))都會(huì)將其本地?cái)?shù)據(jù)分割成 N 份。然后,第 i 個(gè)節(jié)點(diǎn)會(huì)收集所有其他節(jié)點(diǎn)發(fā)來的第 i 份數(shù)據(jù),進(jìn)行匯總后,再平均分配回各個(gè)節(jié)點(diǎn)。
這種方法大幅降低了通信的負(fù)擔(dān),并提升了系統(tǒng)的可擴(kuò)展性。 在分布式訓(xùn)練中,所謂的“world size”是指參與訓(xùn)練的總進(jìn)程或設(shè)備數(shù)。這個(gè)參數(shù)對(duì)于決定如何在各個(gè)節(jié)點(diǎn)間聚合和同步數(shù)據(jù)起到了關(guān)鍵作用。
以下是對(duì) Butterfly All-Reduce 技術(shù)的一個(gè)概念性實(shí)現(xiàn)示例:
這段代碼展示了 butterfly all-reduce 技術(shù)如何在保持分布式系統(tǒng)同步的同時(shí),有效利用并行處理的優(yōu)勢。
butterfly all-reduce 方法的優(yōu)勢在于,與傳統(tǒng)的全歸約技術(shù)相比,它能夠顯著降低通信成本,并且具有更好的可擴(kuò)展性,因此非常適合用于大規(guī)模分布式系統(tǒng)。然而,這種方法也存在一些不足之處。例如,其實(shí)現(xiàn)過程較為復(fù)雜,性能可能會(huì)受到通信網(wǎng)絡(luò)拓?fù)浣Y(jié)構(gòu)和網(wǎng)絡(luò)狀況的影響。另外,如果參與節(jié)點(diǎn)中的任何一個(gè)發(fā)生故障,可能會(huì)對(duì)整個(gè)系統(tǒng)的同步過程造成影響。
在某些特定的應(yīng)用場景,尤其是聯(lián)邦學(xué)習(xí)中,訓(xùn)練過程需要能夠適應(yīng)不穩(wěn)定的網(wǎng)絡(luò)帶寬和不可靠的工作節(jié)點(diǎn)。聯(lián)邦學(xué)習(xí)尤為復(fù)雜,因?yàn)樗婕岸鄠€(gè)持有敏感隱私數(shù)據(jù)的獨(dú)立參與節(jié)點(diǎn)。這些情況要求我們必須采用穩(wěn)健的策略,以確保模型訓(xùn)練的可靠性。接下來,我們將探討一些方法,這些方法的目的是平衡收斂速度和系統(tǒng)的容錯(cuò)能力。
03 Gossip-Based Averaging
gossip-based averaging(Boyd等人,2005年)是一種去中心化的通信策略,其中的參與節(jié)點(diǎn)構(gòu)建了一個(gè)稀疏的通信網(wǎng)絡(luò)。每個(gè)節(jié)點(diǎn)定期從鄰近節(jié)點(diǎn)獲取參數(shù),并將其與自己的本地參數(shù)進(jìn)行結(jié)合。這種方式減輕了參數(shù)服務(wù)器(parameter servers)帶來的通信壓力,但也意味著每個(gè)節(jié)點(diǎn)可能會(huì)使用不同的本地參數(shù)進(jìn)行計(jì)算。
gossip-based averaging 的收斂特性深受通信網(wǎng)絡(luò)結(jié)構(gòu)的影響。以下是一個(gè)簡單的 gossip-based averaging 實(shí)現(xiàn)示例:
gossip-based averaging 具有以下優(yōu)勢:
- 減少通信瓶頸:由于不需要集中的參數(shù)服務(wù)器,gossip averaging 大幅降低了通信擁堵,使得參數(shù)更新更加高效。
- 可擴(kuò)展性:這種方法的去中心化特點(diǎn)使得它在擴(kuò)展性方面表現(xiàn)出色,能夠輕松應(yīng)對(duì)參與節(jié)點(diǎn)數(shù)量的增加,而不會(huì)產(chǎn)生過多的額外開銷。
- 容錯(cuò)性:分布式的設(shè)計(jì)提升了系統(tǒng)的容錯(cuò)能力,即使有 worker 出現(xiàn)故障,也不會(huì)中斷整個(gè)訓(xùn)練過程;其他 workers 仍可以繼續(xù)通信和更新參數(shù)。
然而,我們也需要注意到這種方法可能帶來的幾個(gè)不足之處:
- 收斂速度降低:與集中式更新方法相比,gossip averaging 的收斂速度可能會(huì)較慢,因?yàn)閰?shù)的聚合并不頻繁,每個(gè) worker 可能需要基于不太新的數(shù)據(jù)進(jìn)行計(jì)算。
- 參數(shù)更新存在分歧:由于每個(gè)節(jié)點(diǎn)使用的是不同的本地參數(shù),這可能會(huì)導(dǎo)致參數(shù)更新存在分歧,進(jìn)而影響收斂的穩(wěn)定性和速度。
- 依賴通信圖:gossip averaging 的效果在很大程度上受制于通信圖的結(jié)構(gòu)。如果圖的連通性不佳或者結(jié)構(gòu)不平衡,可能會(huì)影響到算法的整體性能。
綜合來看,盡管 gossip-based averaging 這種去中心化的參數(shù)更新方法具有很大的潛力,但在實(shí)際應(yīng)用中,我們需要根據(jù)具體的訓(xùn)練場景,權(quán)衡其利弊。
04 Moshpit Gradient Descent
Moshpit Gradient Descent(Ryabinin et al., 2021)[4]方法進(jìn)一步發(fā)展了去中心化訓(xùn)練的理念,它允許 workers 在小型且獨(dú)立的組內(nèi)進(jìn)行梯度平均。這種設(shè)計(jì)意味著,即使某個(gè)參與節(jié)點(diǎn)出現(xiàn)問題,影響的也僅限于其所在的小組,從而提高了整個(gè)訓(xùn)練過程的容錯(cuò)性,避免了全局訓(xùn)練的中斷。
這些小組的動(dòng)態(tài)構(gòu)建對(duì)于保證訓(xùn)練的有效性至關(guān)重要。通過優(yōu)化小組結(jié)構(gòu),該方法大幅減少了達(dá)到收斂所需的步驟數(shù),因?yàn)?workers 可以在較小的團(tuán)隊(duì)內(nèi)更高效地交換和更新梯度信息。這種自適應(yīng)的分組策略有助于更好地利用現(xiàn)有資源,并在不同的網(wǎng)絡(luò)環(huán)境下實(shí)現(xiàn)更優(yōu)的性能表現(xiàn)。
以下是一個(gè)實(shí)施 moshpit gradient descent 的概念性框架:
moshpit gradient descent 的優(yōu)勢包括:
- 容錯(cuò)性:個(gè)別 worker 的故障只會(huì)影響其所在的小組,不會(huì)波及整個(gè)訓(xùn)練過程,其他小組可以繼續(xù)正常訓(xùn)練。
- 資源利用效率:在較小的小組內(nèi)進(jìn)行更新,該方法能夠靈活應(yīng)對(duì)網(wǎng)絡(luò)狀況和 worker 可用性的變化,從而提升訓(xùn)練效率。
- 降低通信負(fù)擔(dān):由于通信僅限于小組內(nèi)部,整體的通信量得以減少,這在帶寬受限的情況下尤為有利。
然而,這一方法也存在一些不足之處:
- 收斂難題:小組結(jié)構(gòu)的不斷變化可能導(dǎo)致參數(shù)更新出現(xiàn)不一致,可能會(huì)使得訓(xùn)練的收斂和穩(wěn)定性面臨挑戰(zhàn)。
- 管理復(fù)雜性增加:對(duì)小組進(jìn)行動(dòng)態(tài)管理和調(diào)整,無疑增加了訓(xùn)練流程的復(fù)雜性。為了找到最佳的小組配置,我們需要開發(fā)更復(fù)雜的機(jī)制。
- 可擴(kuò)展性問題:較小的小組雖然有助于提高系統(tǒng)的容錯(cuò)性,但如果沒有有效的管理,這種方法在大規(guī)模訓(xùn)練場景中的可擴(kuò)展性可能會(huì)受限。
綜合來看,moshpit gradient descent 作為一種去中心化訓(xùn)練的新方法,其潛力不容小覷。它在容錯(cuò)能力和資源利用效率上的優(yōu)勢,與面臨的收斂難題和實(shí)施復(fù)雜性之間,實(shí)現(xiàn)了微妙的平衡。
05 DiLoCo: Inner-Outer Optimization
DiLoCo(Douillard等人,2023年)[5]帶來了一種創(chuàng)新的 inner-outer 優(yōu)化算法,旨在提高去中心化訓(xùn)練的效率。在這種算法中,每個(gè)計(jì)算節(jié)點(diǎn)在內(nèi)部優(yōu)化階段,會(huì)利用局部的 AdamW 優(yōu)化器進(jìn)行多次參數(shù)更新。這樣的設(shè)計(jì)讓節(jié)點(diǎn)能夠基于局部數(shù)據(jù)獨(dú)立優(yōu)化參數(shù),而不必實(shí)時(shí)與其他節(jié)點(diǎn)同步。當(dāng)完成了一定量(通常是500次左右的)局部更新后,便進(jìn)入外部優(yōu)化階段,此時(shí)會(huì)同步所有節(jié)點(diǎn)的偽梯度(這些梯度是局部更新結(jié)果的匯總)。
這種做法巧妙地結(jié)合了局部和全局更新的優(yōu)勢,有望加快收斂速度并提升訓(xùn)練表現(xiàn)。DiLoCo 通過讓節(jié)點(diǎn)先在局部優(yōu)化參數(shù),再與全局模型同步,充分發(fā)揮了兩種更新策略的長處。
以下是對(duì) DiLoCo 更新過程的概念性描述:
DiLoCo 最初由 Google DeepMind 實(shí)現(xiàn),而現(xiàn)在一家新興的初創(chuàng)公司 PrimeIntellect 也成功復(fù)現(xiàn)了這一方法。OpenDiLoCo(Jaghouar等人,2024年)[6]已在 GitHub[7] 上公開,借助 Hivemind 庫[8]訓(xùn)練了一個(gè) 10 億參數(shù)的模型。最近,PrimeIntellect 推出了自家研發(fā)的定制化基礎(chǔ)設(shè)施[9],其中包含了諸多工程創(chuàng)新,如定制的 all-reduce 算法和通信協(xié)議。該公司目前正在訓(xùn)練一個(gè)名為 Intellect-1[10] 的 100 億參數(shù)模型。我相信這項(xiàng)實(shí)驗(yàn)的結(jié)果將對(duì)我們突破現(xiàn)有模式產(chǎn)生深遠(yuǎn)影響。目前,大模型的訓(xùn)練還依賴于集中的計(jì)算資源。但未來,或許每個(gè)人都能為打造下一代領(lǐng)先的基礎(chǔ)模型貢獻(xiàn)力量。
06 SWARM: Fault Tolerance and Dynamic Task Assignment
SWARM 算法(Ryabinin等,2023年)[11]引入了一種新穎的分布式訓(xùn)練方法,允許每個(gè)工作節(jié)點(diǎn)在訓(xùn)練過程的后續(xù)階段將其輸出發(fā)送給其他工作節(jié)點(diǎn)。這種靈活的任務(wù)分配方式,使得計(jì)算能力較強(qiáng)的設(shè)備能夠承擔(dān)更多任務(wù),從而在多樣化的硬件環(huán)境中實(shí)現(xiàn)資源的最優(yōu)配置。這種策略在計(jì)算資源波動(dòng)較大的場景下尤為有效,可實(shí)現(xiàn)更均衡的工作量,減少閑置時(shí)間。
面對(duì)工作節(jié)點(diǎn)的故障,SWARM 算法展現(xiàn)了其容錯(cuò)能力,能夠迅速將故障節(jié)點(diǎn)的任務(wù)轉(zhuǎn)交給其他正常運(yùn)行的節(jié)點(diǎn)。這一機(jī)制對(duì)于維持訓(xùn)練流程的連貫性至關(guān)重要,它有效減少了意外中斷的影響,并確保了處理能力的及時(shí)補(bǔ)充。工作節(jié)點(diǎn)間的通信路徑是隨機(jī)且動(dòng)態(tài)調(diào)整的,這使得算法能夠根據(jù)網(wǎng)絡(luò)狀況或節(jié)點(diǎn)狀態(tài)的變動(dòng)實(shí)時(shí)調(diào)整。
通過這種自適應(yīng)的通信方式,不僅數(shù)據(jù)流轉(zhuǎn)更加高效,訓(xùn)練過程的穩(wěn)定性也得到了加強(qiáng)。下面是 SWARM 通信實(shí)現(xiàn)方式的簡化示例:
在這個(gè)示例中,每個(gè)活躍的工作節(jié)點(diǎn)隨機(jī)選取一個(gè)相鄰節(jié)點(diǎn)作為信息傳遞的對(duì)象,這樣的去中心化交流模式能夠?qū)崟r(shí)適應(yīng)當(dāng)前系統(tǒng)的狀態(tài)。SWARM 算法以其動(dòng)態(tài)任務(wù)分配和強(qiáng)大的容錯(cuò)能力,在大規(guī)模機(jī)器學(xué)習(xí)場景中顯著提高了分布式訓(xùn)練的效率和可靠性。
07 Conclusion
分布式去中心化訓(xùn)練為高效訓(xùn)練大規(guī)模神經(jīng)網(wǎng)絡(luò)提供了一個(gè)強(qiáng)有力的支撐。借助數(shù)據(jù)并行訓(xùn)練方法、butterfly all-reduce、gossip-based averaging 等手段,從業(yè)人員能夠在各種環(huán)境中應(yīng)對(duì)模型訓(xùn)練的難題。對(duì)于任何想要優(yōu)化大規(guī)模 AI 系統(tǒng)性能的人來說,掌握這些技術(shù)至關(guān)重要。隨著該領(lǐng)域研究的不斷深入,了解這些方法將是發(fā)揮分布式訓(xùn)練全部實(shí)力的關(guān)鍵。本文并非涵蓋所有分布式訓(xùn)練方法和最新研究進(jìn)展,而是提供一個(gè)粗略的概覽——因此,還請讀者自行探索更多技術(shù)細(xì)節(jié)??。
Thanks for reading!
Hope you have enjoyed and learned new things from this blog!
About the authors
Robert Lange
Deep Learning PhD @TU Berlin. Research Scientist @Sakana.AI. ?? 2x Google DeepMind Intern
END
本期互動(dòng)內(nèi)容 ??
?在分布式訓(xùn)練中,您認(rèn)為最大的技術(shù)瓶頸是什么?是通信開銷、收斂速度、還是系統(tǒng)的容錯(cuò)性,或是其他?
??文中鏈接??
[2]??https://arxiv.org/abs/1602.01783??
[3]??https://lilianweng.github.io/posts/2021-09-25-train-large/??
[4]??https://openreview.net/pdf?id=cwWfDHYpb1z??
[5]??https://arxiv.org/abs/2311.08105??
[6]??https://arxiv.org/abs/2407.07852??
[7]??https://github.com/PrimeIntellect-ai/OpenDiLoCo??
[8]??https://github.com/learning-at-home/hivemind??
[9]??https://github.com/PrimeIntellect-ai/prime??
[10]??https://www.primeintellect.ai/blog/intellect-1??
[11]??https://arxiv.org/abs/2301.11913??
原文鏈接:
