大模型微調(diào)非得依賴(lài)人類(lèi)數(shù)據(jù)嗎?DeepMind:用帶反饋的自訓(xùn)練更好
如你我所見(jiàn),大語(yǔ)言模型(LLM)正在改變深度學(xué)習(xí)的格局,在生成人類(lèi)質(zhì)量的文本和解決各種語(yǔ)言任務(wù)方面展現(xiàn)出了卓越的能力。雖然業(yè)界通過(guò)對(duì)人類(lèi)收集的數(shù)據(jù)進(jìn)行監(jiān)督微調(diào)進(jìn)一步提升了在具體任務(wù)上的性能,但獲取高質(zhì)量人類(lèi)數(shù)據(jù)卻面臨著重大瓶頸。這對(duì)于要解決復(fù)雜問(wèn)題的任務(wù)來(lái)說(shuō)尤為明顯,需要大量資源和專(zhuān)業(yè)知識(shí)。
怎么解決呢?模型生成得合成數(shù)據(jù)是一種有潛力的替代方案,只要能保證數(shù)據(jù)的質(zhì)量,就能實(shí)現(xiàn)可擴(kuò)展性和成本效益。
雖然 LLM 能夠自我評(píng)估生成的數(shù)據(jù),但在本文中,谷歌 DeepMind 探索了一種更簡(jiǎn)單的設(shè)置,將外部標(biāo)量反饋信號(hào)用作每個(gè)生成樣本的質(zhì)量指標(biāo)。
論文地址:https://arxiv.org/pdf/2312.06585.pdf
為了研究在模型生成數(shù)據(jù)上的訓(xùn)練,研究者考慮了一種簡(jiǎn)單但強(qiáng)大的語(yǔ)言模型自訓(xùn)練方法,僅需要兩項(xiàng)功能,一是基于模型生成樣本,二是利用評(píng)分機(jī)制對(duì)這些樣本進(jìn)行評(píng)估。
為了確保清晰度和一致性,研究者采用了一種強(qiáng)化自訓(xùn)練方法 ReST^????,并證明該方法可以將期望最大化(expectation-maximization,EM)用于強(qiáng)化學(xué)習(xí)。具體來(lái)講,ReST^????在期望和最大化步驟之間交替進(jìn)行。
- 生成(E-step):語(yǔ)言模型為每個(gè)輸入上下文生成多個(gè)輸出樣本,然后使用二元獎(jiǎng)勵(lì)過(guò)濾這些樣本以收集訓(xùn)練數(shù)據(jù)集。
- 改進(jìn)(M-step):原始語(yǔ)言模型在來(lái)自前一個(gè) E-step 的訓(xùn)練數(shù)據(jù)集上進(jìn)行監(jiān)督微調(diào),然后在下一個(gè) E-step 中使用。
研究者證實(shí),ReST^????及變體在增強(qiáng)各個(gè)領(lǐng)域的語(yǔ)言模型方面取得了成功,包括機(jī)器翻譯、語(yǔ)義分析、偏好對(duì)齊和基礎(chǔ)推理。
此外,以往工作主要將 ReST^????用于相對(duì)較小的模型(最高 70 億參數(shù)),對(duì)于較大模型的可擴(kuò)展性受限。因此,本文旨在探究模型生成的合成數(shù)據(jù)與人類(lèi)生成的數(shù)據(jù)在以下兩個(gè)具有挑戰(zhàn)性但研究較少領(lǐng)域的有效性和可擴(kuò)展性,這兩個(gè)領(lǐng)域分別是競(jìng)爭(zhēng)水平數(shù)學(xué)解題(MATH)和代碼生成(APPS)。
實(shí)證結(jié)果表明,當(dāng)將 ReST^????用于不同規(guī)模的 PaLM 2 模型時(shí),在數(shù)學(xué)推理和代碼生成任務(wù)中實(shí)現(xiàn)了顯著的能力改進(jìn)。與在人類(lèi)編寫(xiě)數(shù)據(jù)上訓(xùn)練的模型相比,在模型生成的合成數(shù)據(jù)上微調(diào)的模型取得了更大的性能增益。有趣的是,超過(guò)了一定數(shù)量的 ReST^???? 迭代后,性能會(huì)降低,這表明了在少量訓(xùn)練問(wèn)題上可能會(huì)出現(xiàn)過(guò)擬合。
此外,使用 ReST^????微調(diào)的模型提升了 pass@k 指標(biāo)和多數(shù)投票性能。這些微調(diào)后的模型在相關(guān)但 held-out 的基準(zhǔn)上也表現(xiàn)出了性能增強(qiáng),包括數(shù)學(xué)題(GSM8K 和 Hungarian HS finals)、編碼(HumanEval)和 Big-Bench Hard 任務(wù)。
總之,本文研究結(jié)果表明,具有反饋的自訓(xùn)練是減少對(duì)人類(lèi)數(shù)據(jù)依賴(lài)的一種有潛力的方法。
用于強(qiáng)化自訓(xùn)練的期望最大值(EM)
首先,該研究基于 Dayan 和 Hinton 之前的研究,用語(yǔ)言模型描述了基于 EM 的強(qiáng)化學(xué)習(xí)框架。具體而言,他們先是定義了一個(gè)二進(jìn)制最優(yōu)變量 O,使得??(??= 1|??,??)∝??(??(??,??));然后對(duì)非遞減函數(shù) ?? : ? → ?+ ,實(shí)現(xiàn)最大化觀察??= 1(獲得高獎(jiǎng)勵(lì)),得到如下公式:
然而,求解上式中的序列 ?? 的和很棘手。因而本文考慮相對(duì)于參數(shù) ?? 和變分分布 ??( ??|??) 最大化其 ELBO ??( ????, ??),而不是最大化 log ??(?? = 1; ??)。具體來(lái)說(shuō):
公式(2)中的 EM 算法在 E-step(Expectation) 和 M-step(Maximization)之間交替進(jìn)行。
ReST^????:受 EM 框架的啟發(fā),接下來(lái)論文討論了 Gulcehre 等人提出的 ReST 方法的簡(jiǎn)化版本。為了清楚起見(jiàn),本文將這種方法稱(chēng)為 ReST^????,它將 RL pipeline 中的數(shù)據(jù)收集 (E-step) 和策略?xún)?yōu)化 (M-step) 進(jìn)行解耦。如算法 1 所示:
生成(E-step):在此步驟中,該研究通過(guò)從當(dāng)前策略 ???? 中采樣輸出序列來(lái)生成數(shù)據(jù)集
。在這里,輸入是從原始數(shù)據(jù)集
中重新采樣的。然后使用二元獎(jiǎng)勵(lì)函數(shù) ??(??, ??) 對(duì)
中的輸出序列進(jìn)行評(píng)分。
改進(jìn)(M-step):在第 ??步迭代中,該研究使用 E-step 中的新數(shù)據(jù)集來(lái)微調(diào)策略 ????。不同于 Gulcehre 的研究,他們微調(diào)基本預(yù)訓(xùn)練語(yǔ)言模型,以最大限度地減少特定于任務(wù)的過(guò)度擬合并最大限度地減少與基本模型的偏差。為了進(jìn)行微調(diào),該研究最小化獎(jiǎng)勵(lì)加權(quán)負(fù)對(duì)數(shù)似然損失
。一旦策略得到改進(jìn),就可以再次創(chuàng)建質(zhì)量更好樣本的新數(shù)據(jù)集。
實(shí)驗(yàn)和分析
本文進(jìn)行實(shí)驗(yàn)的主要目標(biāo)是回答以下問(wèn)題:
- 與人類(lèi)生成的數(shù)據(jù)進(jìn)行微調(diào)相比,ReST^????的效果如何?
- 需要多少次迭代才能獲得最佳性能?ReST^????多長(zhǎng)時(shí)間會(huì)導(dǎo)致訓(xùn)練集過(guò)度擬合?
- ReST^????如何影響 pass@k 和多數(shù)投票表現(xiàn)?
- 如果用戶(hù)在特定任務(wù)上使用模型生成的數(shù)據(jù)進(jìn)行微調(diào),是否會(huì)遷移到其他任務(wù)上?在廣泛的任務(wù)中評(píng)估本文的微調(diào)模型時(shí),與基本模型相比,性能是否會(huì)下降?
- 大約需要多少輸入數(shù)據(jù)才能從 ReST^???? 獲得大部分性能提升?ReST^????的一次迭代是否足夠?
該研究使用 PaLM 2 模型和 Google Cloud 上的公共 API 進(jìn)行實(shí)驗(yàn),包括 PaLM 2-S (Bison)、PaLM 2-S* (Codey) 和 PaLM 2-L (Unicorn)。訓(xùn)練數(shù)據(jù)集采用 MATH 數(shù)據(jù)集和 APPS 數(shù)據(jù)集。
圖 2 和圖 3 分別顯示了 ReST^????在 MATH 和 APPS 數(shù)據(jù)集上訓(xùn)練的性能??梢缘贸?MATH 受益于 ReST^???? 的多次迭代,無(wú)論是在 MATH 測(cè)試集上的性能還是遷移到 GSM8K 方面。另一方面可以看到 APPS 的大部分收益來(lái)自第一次迭代,而執(zhí)行更多次迭代會(huì)導(dǎo)致 APPS 和 HumanEval 的性能下降。
訓(xùn)練和測(cè)試性能的差距。圖 4 顯示,雖然訓(xùn)練集性能隨著 ReST^????迭代次數(shù)線(xiàn)性增加,但測(cè)試集性能卻沒(méi)有。對(duì)于 MATH,第一次迭代后測(cè)試性能改進(jìn)很小,而對(duì)于 APPS,在第二次迭代中觀察到性能回歸。該研究猜測(cè)性能的回歸可能是由于過(guò)度擬合造成的。由于 APPS 數(shù)據(jù)集的大小約為 MATH 數(shù)據(jù)集的三分之一,因此它更容易受到此問(wèn)題的影響。
圖 5 顯示了 Palm-2-L 模型在 pass@K 指標(biāo)上的性能。結(jié)果顯示,微調(diào)后獲得的 ReST^???? 模型對(duì)于所有 K 值都更強(qiáng),其中性能差距通常在 K=1 時(shí)最大。