矩陣乘法可以算得更快了!港中文10頁論文證明:能源、時間均可節(jié)省
天下苦大模型矩陣乘法久矣。
畢竟不論是訓(xùn)練還是推理過程,矩陣乘法作為最主要的計算操作之一,往往都需要消耗大量的算力。
那么就沒有一種更“快、好、省”的方法來搞這事兒嗎?
有的,香港中文大學(xué)最新一篇僅10頁的論文,便提出了一種新算法:
- 能源可節(jié)?。?%-10%
- 時間可節(jié)?。?%
論文作者之一的Dmitry Rybin表示:
這項研究對數(shù)據(jù)分析、芯片設(shè)計、無線通信和LLM訓(xùn)練都有著深遠的影響!
這么算矩陣乘法,更快!
矩陣乘法是計算機科學(xué)和數(shù)值線性代數(shù)中的核心問題之一。
自從Strassen和Winograd的開創(chuàng)性工作以來,研究者們一直在探索如何減少矩陣乘法所需的計算量。
盡管這類運算在統(tǒng)計、數(shù)據(jù)分析、深度學(xué)習(xí)和無線通信等領(lǐng)域有著廣泛應(yīng)用,例如協(xié)方差矩陣的計算和線性回歸中的關(guān)鍵步驟,但對于具有特殊結(jié)構(gòu)的矩陣乘法(如計算矩陣與其轉(zhuǎn)置的乘積XXt)的研究相對較少。
從理論角度看,計算XXt與一般矩陣乘法具有相同的漸近復(fù)雜度,因此只能通過常數(shù)因子優(yōu)化來提升速度。
因此,這篇論文《XXt Can Be Faster》提出了一種名為RXTX的新算法,通過結(jié)合機器學(xué)習(xí)搜索方法和組合優(yōu)化技術(shù),顯著提升了XXt的計算效率。
我們先來了解一下RXTX。
整體來看,這個基于4×4分塊矩陣的遞歸乘法,通過機器學(xué)習(xí)搜索與組合優(yōu)化相結(jié)合的方法發(fā)現(xiàn)。
算法主要包含以下關(guān)鍵步驟:
- 分塊與遞歸調(diào)用:將矩陣X劃分為16個4×4子塊,通過8次遞歸調(diào)用處理子問題,并計算26個一般矩陣乘積m1至m26。
2.對稱乘積計算:直接計算8個子塊的對稱乘積s1至m8。
3.結(jié)果組合:通過線性組合上述乘積結(jié)果,得到最終的XXt矩陣各分塊元素C11至C44。
與此前最先進的算法(基 Strassen的遞歸分治)相比,RXTX的遞歸關(guān)系式為 R(n)=8R(n/4) + 26M(n/4),而原算法為 S(n) = 4S(n/2) + 2M(n/2)。
這一設(shè)計使得RXTX的漸近乘法常數(shù)為 26/41≈0.6341,比原算法的2/3≈0.6667降低了約5%。
接下來,我們來看下乘法次數(shù)與運算總量分析。
通過論文中的定理1的推導(dǎo),RXTX的乘法次數(shù)表達式為:
實驗數(shù)據(jù)表明,當n為4的冪次時,RXTX的乘法次數(shù)比原算法低5%,且隨著n增大,這一優(yōu)勢持續(xù)保持:
通過優(yōu)化加法步驟(利用公共子表達式減少加法次數(shù)),RXTX的總運算量表達式為:
而原算法的總運算量包含對數(shù)項,導(dǎo)致其增長更快。
實驗顯示,當n≥256時,RXTX的總運算量優(yōu)于原算法;當n≥1024時,顯著優(yōu)于樸素算法:
在6144×6144矩陣的測試中,RXTX的平均運行時間為2.524秒,比BLAS的默認實現(xiàn)快9%,且在99%的測試中表現(xiàn)更優(yōu):
盡管運行時間受硬件和內(nèi)存管理影響,但理論分析表明,當n≥256時,RXTX即可展現(xiàn)速度優(yōu)勢。
值得一提的是,RXTX的發(fā)現(xiàn)得益于機器學(xué)習(xí)與組合優(yōu)化的結(jié)合,具體流程如下:
- RL代理生成候選乘積:通過強化學(xué)習(xí)策略生成大量可能的秩-1雙線性乘積。
- MILP枚舉與篩選:
a.MILP-A:枚舉候選乘積與目標表達式(XXt的各分塊)之間的線性關(guān)系。
b.MILP-B:選擇最小的乘積子集,確保所有目標表達式可通過線性組合表示。
- 大鄰域搜索迭代:通過迭代優(yōu)化,逐步減少冗余乘積,提升算法效率。
這一方法借鑒了AlphaTensor的思路,但通過限制候選空間為二維張量,顯著降低了計算復(fù)雜度,使得MILP求解器(如 Gurobi)能夠高效處理。