清華AIR ModelMerging:無需訓(xùn)練數(shù)據(jù)!合并多個模型實現(xiàn)任意場景的感知(ECCV'24)
近日,來自清華大學(xué)智能產(chǎn)業(yè)研究院(AIR)助理教授趙昊老師的團(tuán)隊,聯(lián)合戴姆勒公司,提出了一種無需訓(xùn)練的多域感知模型融合新方法。研究重點關(guān)注場景理解模型的多目標(biāo)域自適應(yīng),并提出了一個挑戰(zhàn)性的問題:如何在無需訓(xùn)練數(shù)據(jù)的條件下,合并在不同域上獨立訓(xùn)練的模型實現(xiàn)跨領(lǐng)域的感知能力?團(tuán)隊給出了“Merging Parameters + Merging Buffers”的解決方案,這一方法簡單有效,在無須訪問訓(xùn)練數(shù)據(jù)的條件下,能夠?qū)崿F(xiàn)與多目標(biāo)域數(shù)據(jù)混合訓(xùn)練相當(dāng)?shù)慕Y(jié)果。
論文題目:
Training-Free Model Merging for Multi-target Domain Adaptation
作者:Wenyi Li, Huan-ang Gao, Mingju Gao, Beiwen Tian, Rong Zhi, Hao Zhao
1 背景介紹
一個適用于世界各地自動駕駛場景的感知模型,需要能夠在各個領(lǐng)域(比如不同時間、天氣和城市)中都輸出可靠的結(jié)果。然而,典型的監(jiān)督學(xué)習(xí)方法嚴(yán)重依賴于需要大量人力標(biāo)注的像素級注釋,這嚴(yán)重阻礙了這些場景的可擴(kuò)展性。因此,多目標(biāo)域自適應(yīng)(Multi-target Domain Adaptation, MTDA)的研究變得越來越重要。多目標(biāo)域自適應(yīng)通過設(shè)計某種策略,在訓(xùn)練期間同時利用來自多個目標(biāo)域的無標(biāo)簽數(shù)據(jù)以及源域的有標(biāo)簽合成數(shù)據(jù),來增強(qiáng)這些模型在不同目標(biāo)域上的魯棒性。
與傳統(tǒng)的單目標(biāo)域自適應(yīng) (Single-target Domain Adaptation, STDA)相比,MTDA 面臨更大的挑戰(zhàn)——一個模型需要在多個目標(biāo)域中都能很好工作。為了解決這個問題,以前的方法采用了各種專家模型之間的一致性學(xué)習(xí)和在線知識蒸餾來構(gòu)建各目標(biāo)域通用的學(xué)生模型。盡管如此,這些方法的一個重大限制是它們需要同時使用所有目標(biāo)數(shù)據(jù),如圖1(b) 所示。
但是,同時訪問到所有目標(biāo)數(shù)據(jù)是不切實際的。一方面原因是數(shù)據(jù)傳輸成本限制,因為包含數(shù)千張圖像的數(shù)據(jù)集可能會達(dá)到數(shù)百 GB。另一方面,從數(shù)據(jù)隱私保護(hù)的角度出發(fā),不同地域間自動駕駛街景數(shù)據(jù)的共享或傳輸可能會受到限制。面對這些挑戰(zhàn),在本文中,我們聚焦于一個全新的問題,如圖1(c) 所示。我們的研究任務(wù)仍然是MTDA,但我們并沒有來自多個目標(biāo)域的數(shù)據(jù),而是只能獲得各自獨立訓(xùn)練的模型。我們的目標(biāo)是,通過某種融合方式,將這些模型集成為一個能夠適用于各個目標(biāo)域的模型。
圖1:不同實驗設(shè)置的對比
2 方法
如何將多個模型合并為一個,同時保留它們在各自領(lǐng)域的能力?我們提出的解決方案主要包括兩部分:Merging Parameters(即可學(xué)習(xí)層的weight和bias)和 Merging Buffers(即normalization layers的參數(shù))。在第一階段,我們從針對不同單目標(biāo)域的無監(jiān)督域自適應(yīng)模型中,得到訓(xùn)練后的感知模型。然后,在第二階段,利用我們提出的方法,在無須獲取任何訓(xùn)練數(shù)據(jù)的條件下,只對模型做合并,得到一個在多目標(biāo)域都能工作的感知模型。
圖2:整體實驗流程
下面,我們將詳細(xì)介紹這兩種合并的技術(shù)細(xì)節(jié)和研究動機(jī)。
2.1 Merging Parameters
2.1.1 Permutation-based的方法出現(xiàn)退化
事實上,如何將模型之間可學(xué)習(xí)層的 weight 和 bias 合并一直是一個前沿研究領(lǐng)域。在之前的工作中,有一種稱為基于置換 (Permutation-based) 的方法。這些方法基于這樣的假設(shè):當(dāng)考慮神經(jīng)網(wǎng)絡(luò)隱藏層的所有潛在排列對稱性時,loss landscape 通常形成單個盆地(single basin)。因此,在合并模型參數(shù) 和時,這類方法的主要目標(biāo)是找到一組置換變換 ,確保 在功能上等同于 ,同時也位于參考模型 附近的近似凸盆地(convex basin)內(nèi)。之后,通過簡單的中點合并 以獲得一個合并后的模型 ,該模型能夠表現(xiàn)出比單個模型更好的泛化能力,
在我們的實驗中,模型 和 在第一階段都使用相同的網(wǎng)絡(luò)架構(gòu)進(jìn)行訓(xùn)練,并且,源數(shù)據(jù)都使用相同的合成圖像和標(biāo)簽。我們最初嘗試采用了一種 Permutation-based 的代表性方法——Git Re-Basin,該方法將尋找置換對稱變換的問題轉(zhuǎn)化為線性分配問題 (LAP),是目前最高效實用的算法。
圖3:Git Re-basin和mid-point的實驗結(jié)果對比
但是,如圖3所示,我們的實驗結(jié)果出乎意料地表明,不同網(wǎng)絡(luò)架構(gòu)(ResNet50、ResNet101 和 MiT-B5)下 Git Re-Basin 的性能與簡單中點合并相同。進(jìn)一步的研究表明,Git Re-Basin 發(fā)現(xiàn)的排列變換在解決 LAP 的迭代中保持相同的排列,這表明在我們的領(lǐng)域適應(yīng)場景下,Git Re-Basin 退化為一種簡單的中點合并方法。
2.1.2 線性模式連通性的分析
我們從線性模式連通性(linear mode connectivity)的視角進(jìn)一步研究上述退化問題。具體來說,我們使用連續(xù)曲線 在參數(shù)空間中連接模型 和模型 。在這種特定情況下,我們考慮如下線性路徑,
接下來,我們通過對 做插值遍歷評估模型的性能。為了衡量這些模型在兩個指定目標(biāo)域(分別表示為 和 )上的有效性,我們使用調(diào)和平均值 (Harmonic Mean) 作為主要評估指標(biāo),
我們之所以選擇調(diào)和平均值作為指標(biāo),是因為它能夠賦予較小的值更大的權(quán)重,這能夠更好應(yīng)對世界各地各個城市中最差的情況。它有效地懲罰了模型在一個目標(biāo)域(例如,在發(fā)達(dá)的大城市)的表現(xiàn)異常高,而其他目標(biāo)域(例如,在第三世界鄉(xiāng)村)表現(xiàn)低的情況。不同插值的實驗結(jié)果如圖4(a)所示?!癈S”和“IDD”分別表示目標(biāo)數(shù)據(jù)集 Cityscapes 和 Indian Driving Dataset。
圖4:線性模式連通性的分析實驗
2.1.3 理解線性模式連通性的原因
在上述實驗結(jié)果的基礎(chǔ)上,我們進(jìn)一步探究:在先前域自適應(yīng)方法中觀察到的線性模式連通性,背后的根本原因是什么?為此,我們進(jìn)行了消融實驗,來研究第一階段訓(xùn)練 和 期間的幾個影響因素。
- 合成數(shù)據(jù)。使用相同的合成數(shù)據(jù)可以作為兩個域之間的橋梁。為了評估這一點,我們將合成數(shù)據(jù)集 GTA 中的訓(xùn)練數(shù)據(jù)劃分為兩個不同的非重疊子集,每個子子集包含原始訓(xùn)練樣本的 30%。在劃分過程中,我們將合成數(shù)據(jù)集提供的具有相同場景標(biāo)識的圖像分組到同一個子集中,而具有顯著差異的場景則放在單獨的子集中。我們使用這兩個不同子集分別作為源域,訓(xùn)練兩個單目標(biāo)域自適應(yīng)模型(目標(biāo)域為 CityScapes 數(shù)據(jù)集)。隨后,我們研究這兩個 STDA 模型的線性模式連通性。結(jié)果如圖 4(b) 所示,可以觀察到,在參數(shù)空間內(nèi)連接兩個模型的線性曲線上,性能沒有明顯下降。這一觀察結(jié)果表明,使用相同的合成數(shù)據(jù)并不是影響線性模式連通性的主要因素。
- 自訓(xùn)練架構(gòu)。使用教師-學(xué)生模型可能會將最后的模型限制在 loss landscape 的同一 basin 中。為了評估這種可能性,我們禁用了教師模型的指數(shù)移動平均 (EMA) 更新。相應(yīng)地,我們在每次迭代中將學(xué)生權(quán)重直接復(fù)制到教師模型中。隨后,我們繼續(xù)訓(xùn)練兩個單目標(biāo)域自適應(yīng)模型,分別利用 GTA 作為源域,Cityscapes 和 IDD 作為目標(biāo)域。然后,我們研究在參數(shù)空間內(nèi)連接兩個模型的線性曲線,結(jié)果如圖 4(c) 所示。我們可以看到線性模式連接屬性保持不變。
- 初始化和預(yù)訓(xùn)練。 使用相同的預(yù)訓(xùn)練權(quán)重初始化 backbone 的做法,可能會使模型在訓(xùn)練過程中難以擺脫的某一 basin。為了驗證這種潛在情況,我們初始化兩個具有不同權(quán)重的獨立 backbone,然后繼續(xù)針對 Cityscapes 和 IDD 進(jìn)行域自適應(yīng)。在評估兩個收斂模型之間的線性插值模型時,我們觀察到性能明顯下降,如圖 4(d) 所示。為了更深入地了解潛在因素,我們繼續(xù)探究,是相同的初始權(quán)重,還是預(yù)訓(xùn)練過程導(dǎo)致了這種影響? 我們初始化兩個具有相同權(quán)重但沒有預(yù)訓(xùn)練的主干,然后再次進(jìn)行實驗。有趣的是,我們發(fā)現(xiàn),在參數(shù)空間的線性連接曲線仍然遇到了巨大的性能障礙,如圖 4(e) 所示。這意味著預(yù)訓(xùn)練過程在模型中的線性模式連接方面起著關(guān)鍵作用。
2.1.4 關(guān)于合并參數(shù)的小結(jié)
我們通過大量實驗證明,當(dāng)領(lǐng)域自適應(yīng)模型從相同的預(yù)訓(xùn)練權(quán)重開始時,模型可以有效地過渡到不同的目標(biāo)領(lǐng)域,同時仍然保持參數(shù)空間中的線性模式連通性。因此,這些訓(xùn)練模型可以通過簡單的中點合并,得到在兩個領(lǐng)域都有效的合并模型。
2.2 Merging Buffers
Buffers,即批量歸一化 (BN) 層的均值和方差,與數(shù)據(jù)域密切相關(guān)。因為數(shù)據(jù)不同的方差和均值代表了域的某些特定特征。在合并模型時如何有效地合并 Buffers 的問題通常被忽視,因為現(xiàn)有方法主要探究如何合并在同一域內(nèi)的不同子集上訓(xùn)練的兩個模型。在這樣的前提下,之前的合并方法不考慮 Buffers 是合理的,因為來自任何給定模型的 Buffers 都可以被視為對整個總體的無偏估計,盡管它完全來自隨機(jī)數(shù)據(jù)子樣本。
但是,在我們的實驗環(huán)境中,我們正在研究如何合并在完全不同的目標(biāo)域中訓(xùn)練的兩個模型,這使得 Buffers 合并的問題不再簡單。由于我們假設(shè)在模型 A 和模型 B 的合并階段無法訪問任何形式的訓(xùn)練數(shù)據(jù),因此我們可用的信息僅限于 Buffers 集 。其中, 表示 BN 層的數(shù)量,而 、 和 分別表示第 層的平均值、標(biāo)準(zhǔn)差和 tracked 的批次數(shù)。生成 BN 層的統(tǒng)計數(shù)據(jù)如下:
以上方程背后的原理可以解釋如下:引入 BN 層是為了緩解內(nèi)部協(xié)變量偏移(internal covariate shift)問題,其中輸入的均值和方差在通過內(nèi)部可學(xué)習(xí)層時會發(fā)生變化。在這種情況下,我們的基本假設(shè)是,后續(xù)可學(xué)習(xí)層合并的 BN 層的輸出遵循正態(tài)分布。由于生成的 BN 層保持符合高斯先驗的輸入歸納偏差,我們根據(jù)從 和 得到的結(jié)果估計 和 。如圖5所示,我們獲得了從該高斯先驗中采樣的兩組數(shù)據(jù)點的均值和方差,以及這些集合的大小。我們利用這些值來估計該分布的參數(shù)。
圖5:合并BN層的示意圖
當(dāng)將 Merging Buffers 方法擴(kuò)展到 個高斯分布時,tracked 的批次數(shù) 、均值的加權(quán)平均值 和方差的加權(quán)平均值可以按如下方式計算。
3 實驗與結(jié)果
3.1 數(shù)據(jù)集
在多目標(biāo)域適應(yīng)實驗中,我們使用 GTA 和 SYNTHIA 作為合成數(shù)據(jù)集,并使用 Cityscapes 、Indian Driving Dataset 、ACDC 和 DarkZurich 的作為目標(biāo)域真實數(shù)據(jù)集。在訓(xùn)練單個領(lǐng)域自適應(yīng)模型時,使用帶有標(biāo)記的源域數(shù)據(jù)和無標(biāo)記的目標(biāo)域數(shù)據(jù)。接下來,我們采用所提出的模型融合技術(shù),直接從訓(xùn)練好的模型出發(fā)構(gòu)建混合模型,這個過程中無需使用訓(xùn)練數(shù)據(jù)。
3.2 與Baseline模型的比較
在實驗中,我們將我們的模型融合方法在 MTDA 任務(wù)上的結(jié)果與幾種 baseline 模型進(jìn)行對比。baseline 模型包括數(shù)據(jù)組合(Data Comb.)方法,其中單個域自適應(yīng)模型在來自兩個目標(biāo)域的混合數(shù)據(jù)上進(jìn)行訓(xùn)練(這個baseline僅供參考,因為它們與我們關(guān)于數(shù)據(jù)傳輸帶寬和數(shù)據(jù)隱私問題的設(shè)定相矛盾)。baseline 模型還包括單目標(biāo)域自適應(yīng)(STDA),即為單一目標(biāo)域訓(xùn)練的自適應(yīng)模型,評估其在兩個域上的泛化能力。
表1:與Baseline模型的比較
表 1 展示了基于 CNN 架構(gòu)的 ResNet101和基于 Transformer 架構(gòu)的 MiT-B5 的結(jié)果。與最好的單目標(biāo)域自適應(yīng)模型相比,當(dāng)將我們的方法分別應(yīng)用于 ResNet101 和 MiT-B5 兩種不同 Backbone 時,在兩個目標(biāo)域上性能的調(diào)和平均值分別提高 +4.2% 和 +1.2%。值得注意的是,這種性能水平(ResNet101架構(gòu)下的調(diào)和平均值為 56.3%)已經(jīng)與數(shù)據(jù)組合(Data Comb.)方法(56.2%)相當(dāng),而且我們無需訪問任何訓(xùn)練數(shù)據(jù)即可實現(xiàn)這一目標(biāo)。
此外,我們探索了一種更為寬松的條件,其中僅合并 Encoder backbone,而 decoder head 則針對各個下游域進(jìn)行分離。值得注意的是,這種條件下,分別使兩種 backbone 下的調(diào)和平均性能顯著提高 +5.6% 和 +2.5%。我們還發(fā)現(xiàn),我們提出的方法在大多數(shù)類別中能夠始終實現(xiàn)最佳調(diào)和平均,這表明它能夠增強(qiáng)全局適應(yīng)性,而不是偏向某些類別。
3.3 與SoTA模型的比較
我們首先將我們的方法與 GTACityscapes 任務(wù)上的單目標(biāo)域自適應(yīng) (STDA) 進(jìn)行比較,如表 2 所示。值得注意的是,我們的方法可以應(yīng)用于任何這些方法,只要它們使用相同的預(yù)訓(xùn)練權(quán)重適應(yīng)不同的域。這使我們能夠使用單個模型推廣到所有目標(biāo)域,同時保持 STDA 方法相對優(yōu)越的性能。
表2:與SoTA模型的比較
我們還將我們的方法與表 2 中的域泛化(DG)方法進(jìn)行了比較,域泛化旨在將在源域上訓(xùn)練的模型推廣到多個看不見的目標(biāo)域。我們的方法無需額外的技巧,只需利用參數(shù)空間的線性模式連接即可實現(xiàn)卓越的性能。在多目標(biāo)域自適應(yīng)領(lǐng)域,我們的方法也取得了領(lǐng)先。我們不需要對多個學(xué)生模型做顯式的域間一致性正則化或知識提煉,但能使 STDA 方法中的技術(shù)(如多分辨率訓(xùn)練)能夠輕松轉(zhuǎn)移到 MTDA 任務(wù)??梢杂^察到,我們對 MTDA 任務(wù)的最佳結(jié)果做出了的顯著改進(jìn),同時消除了對訓(xùn)練數(shù)據(jù)的依賴。
3.4 多目標(biāo)域拓展
我們還擴(kuò)展了我們的模型融合技術(shù),以涵蓋四個不同的目標(biāo)領(lǐng)域:Cityscapes 、IDD 、ACDC 和 DarkZurich 。每個領(lǐng)域都面臨著獨特的挑戰(zhàn)和特點:Cityscapes 主要關(guān)注歐洲城市環(huán)境,IDD 主要體現(xiàn)印度道路場景,ACDC 主要針對霧、雨或雪等惡劣天氣條件,DarkZurich 則主要處理夜間道路場景。我們對針對每個領(lǐng)域單獨訓(xùn)練后的模型,以及用我們的方法融合后的模型進(jìn)行了全面評估。
表3:在4個目標(biāo)域上的實驗結(jié)果
如表 3 所示,我們提出的模型融合技術(shù)表現(xiàn)出顯著的性能提升。雖然我們將來自單獨訓(xùn)練模型的調(diào)和平均值最高的方法作為比較的基線,但所有基于模型融合的方法都優(yōu)于它,性能增長高達(dá) +5.8%。此外,盡管合并來自多個不同領(lǐng)域模型的復(fù)雜性不斷增加,但我們觀察到所有領(lǐng)域的整體性能并沒有明顯下降。通過進(jìn)一步分析,我們發(fā)現(xiàn)我們的方法能夠簡化領(lǐng)域一致性的復(fù)雜性。現(xiàn)有的域間一致性正則化和在線知識提煉方法的復(fù)雜度為 ,而我們的方法可以將其減少到更高效的 ,其中 表示考慮的目標(biāo)域數(shù)量。
3.5 消融實驗
我們使用 ResNet101 和 MiT-B5 作為分割網(wǎng)絡(luò)中的圖像編碼器,對我們提出的 Merging Parameters 和 Merging Buffers 方法進(jìn)行了消融研究,結(jié)果如表 4 所示。我們觀察到單目標(biāo)域自適應(yīng) (STDA) 模型在不同域中的泛化能力存在差異,這主要源于所用目標(biāo)數(shù)據(jù)集的多樣性和質(zhì)量差異。盡管如此,我們還是選擇 STDA 模型中的最高的調(diào)和平均值作為比較基線。
表4:消融實驗
表 4(a) 和 4(b) 中的數(shù)據(jù)顯示,采用簡單的中點合并方法對參數(shù)進(jìn)行處理,可使模型的泛化能力提高 +2.7% 和 +0.6%。此外,當(dāng)結(jié)合 Merging Buffers 時,這種性能的增強(qiáng)會進(jìn)一步放大到 +4.2% 和+1.2%。我們還觀察到 MiT-B5 作為 backbone 時的一個有趣現(xiàn)象:在 IDD 域中進(jìn)行評估時,融合模型的表現(xiàn)優(yōu)于單目標(biāo)自適應(yīng)模型。這一發(fā)現(xiàn)意味著模型可以從其他域獲取域不變的知識。這些結(jié)果表明,我們提出的模型融合技術(shù)的每個部分都是有效的。
3.6 模型融合在分類任務(wù)上的應(yīng)用
我們還通過實驗驗證了我們所提出的模型融合方法在圖像分類任務(wù)上的有效性。通過將 CIFAR-100 分類數(shù)據(jù)集劃分為兩個不同的、不重疊的子集,我們在這些子集上獨立訓(xùn)練兩個 ResNet50 模型,標(biāo)記為 A 和 B。這種訓(xùn)練要么從一組共同的預(yù)訓(xùn)練權(quán)重中進(jìn)行,要么從兩組隨機(jī)初始化的權(quán)重中進(jìn)行。模型 A 和 B 的性能結(jié)果如圖 6 所示。結(jié)果表明,從相同的預(yù)訓(xùn)練權(quán)重進(jìn)行融合的模型優(yōu)于在任何單個子集上訓(xùn)練的模型。相反,當(dāng)從隨機(jī)初始化的權(quán)重開始時,單個模型表現(xiàn)出學(xué)習(xí)能力,而合并模型的性能類似于隨機(jī)猜測。
圖6:CIFAR-100 分類任務(wù)上的模型融合結(jié)果
隨機(jī)初始化會破壞模型線性平均性,而相同的預(yù)訓(xùn)練主干會導(dǎo)致線性模式連接。我們在另一個預(yù)訓(xùn)練權(quán)重上再次驗證了這個結(jié)論。圖 7 中的結(jié)果表明,DINO 預(yù)訓(xùn)練和 ImageNet 預(yù)訓(xùn)練在模型參數(shù)空間中具有不同的loss landscape,模型的融合必須在相同的loss landscape內(nèi)進(jìn)行。
圖7:ImageNet和DINO預(yù)訓(xùn)練權(quán)重對線性模式連接的影響
4 結(jié)論
本文介紹了一種新穎的模型融合策略,旨在解決多目標(biāo)域自適應(yīng) (MTDA)問題,同時無需依賴訓(xùn)練數(shù)據(jù)。研究結(jié)果表明,在大量數(shù)據(jù)集上進(jìn)行預(yù)訓(xùn)練時,基于 CNN 的神經(jīng)網(wǎng)絡(luò)和基于 Transformer 的視覺模型都可以將微調(diào)后模型限制在 loss landscape 的相同 basin 中。我們還強(qiáng)調(diào)了 Buffers 的合并在 MTDA 中的重要性,因為 Buffers 是捕獲各個域獨特特征的關(guān)鍵。我們所提出的模型融合方法簡單而高效,在 MTDA 基準(zhǔn)上取得了最好的評測性能。我們期待本文所提出的模型融合方法能夠激發(fā)未來更多關(guān)于這個領(lǐng)域的探索。