大模型分布式并行技術(shù)--數(shù)據(jù)并行優(yōu)化
通信融合
從上文知道數(shù)據(jù)并行中需要同步每一個(gè)模型梯度, 這是通過進(jìn)程間的 Allreduce 通信實(shí)現(xiàn)的。如果一個(gè)模型 有非常多的參數(shù),則數(shù)據(jù)并行訓(xùn)練的每一個(gè) step 中會(huì)有非常多次的 Allreduce 通信,下圖為融合梯度同步示例。
融合梯度同步示例
通信的耗時(shí)可以從通信延遲(lantency) 和數(shù)據(jù)傳輸時(shí)間消耗兩方面考慮。單次通信延遲時(shí)間相對(duì)固定, 而 傳輸時(shí)間由通信的數(shù)據(jù)量和帶寬決定。減少總的通信消耗, 可以通過減少通信頻率來實(shí)現(xiàn), 通信融合是一個(gè)可 行的手段,通過將 N 個(gè)梯度的 Allreduce 通信合并成一次 Allreduce 通信,可以減少 N- 1 次通信延遲時(shí)間。
常用的 Allreduce 融合實(shí)現(xiàn)方式是在通信前將多個(gè)梯度 tensors 拼接成一個(gè)內(nèi)存地址連續(xù)的大 tensor,梯度同 步時(shí)僅對(duì)拼接后的大 tensor 做一次 Allreduce 操作。參數(shù)更新時(shí)將大 tensor 切分還原回之前的多個(gè)小 tensors,完 成每個(gè)梯度對(duì)應(yīng)參數(shù)的更新。
通信計(jì)算重疊
除了降低絕對(duì)的通信耗時(shí),還可以從降低整體訓(xùn)練耗時(shí)角度來優(yōu)化,可以考慮通信和計(jì)算的異步流水實(shí)現(xiàn)。 數(shù)據(jù)并行中的梯度同步 Allreduce 通信是在訓(xùn)練的反向過程中進(jìn)行的, 而 Allreduce 后得到的同步梯度是在訓(xùn)練 的更新過程中才被使用, 在反向中并沒有被使用。也就是說上一個(gè)梯度的通信和下一個(gè)梯度的計(jì)算間并沒有依 賴,通信和計(jì)算可以并行,讓兩者的耗時(shí)相互重疊掩蓋,減少反向的耗時(shí),下圖為通信計(jì)算并行相互重疊示例。
通信計(jì)算并行相互重疊示例。
通信和計(jì)算的重疊通常是將通信和計(jì)算算子調(diào)度到不同的流 (stream) 上實(shí)現(xiàn)的。通信算子調(diào)度到通信流, 計(jì) 算算子調(diào)度到計(jì)算流, 同一個(gè)流上的算子間是順序執(zhí)行的, 不同流上的算子可以并行執(zhí)行, 從而實(shí)現(xiàn)反向中梯 度通信和計(jì)算的并行重疊。需要注意的是, 當(dāng)通信和計(jì)算被調(diào)度在不同的流上執(zhí)行時(shí), 需要考慮兩個(gè)流之間依 賴和同步關(guān)系。
- 某個(gè)梯度 Allreduce 通信進(jìn)行前,該梯度的反向計(jì)算已經(jīng)完成。
- 某個(gè)梯度對(duì)應(yīng)參數(shù)的更新計(jì)算開始前,該梯度的 Allreduce 通信已經(jīng)完成。
在梯度同步的數(shù)據(jù)并行場(chǎng)景中,開發(fā)者需要需要通過 stream 間的同步功能保證:
以上兩個(gè)方法是數(shù)據(jù)并行中常用的減少通信時(shí)間消耗, 提高并行加速比的優(yōu)化策略。如果能做到通信和計(jì) 算的重疊程度越高,那么數(shù)據(jù)并行的加速比越接近 100% ,多卡并行對(duì)訓(xùn)練吞吐提升的效率也就越高。