谷歌大腦重磅研究:首個具有O(nlogn)時間、O(n)空間復(fù)雜度可微分排序算法,速度快出一個數(shù)量級
魚羊 十三 發(fā)自 凹非寺
量子位 報(bào)道 | 公眾號 QbitAI
快排堆排冒泡排。排序,在計(jì)算機(jī)中是再常見不過的算法。
在機(jī)器學(xué)習(xí)中,排序也經(jīng)常用于統(tǒng)計(jì)數(shù)據(jù)、信息檢索等領(lǐng)域。
那么問題來了,排序算法在函數(shù)角度上是分段線性的,也就是說,在幾個分段的“節(jié)點(diǎn)”處是不可微的。這樣,就給反向傳播造成了困難。
現(xiàn)在,谷歌大腦針對這一問題,提出了一種快速可微分排序算法,并且,時間復(fù)雜度達(dá)到了O(nlogn),空間復(fù)雜度達(dá)為O(n)。
速度比現(xiàn)有方法快出一個數(shù)量級!
代碼的PyTorch、TensorFlow和JAX版本即將開源。
快速可微分排序算法
現(xiàn)代深度學(xué)習(xí)架構(gòu)通常是通過組合參數(shù)化功能塊來構(gòu)建,并使用梯度反向傳播進(jìn)行端到端的訓(xùn)練。
這也就激發(fā)了像LeCun提出的可微分編程?(differentiable programming)的概念。
雖然在經(jīng)驗(yàn)上取得了較大的成功,但是許多操作仍舊存在不可微分的問題,這就限制了可以計(jì)算梯度的體系結(jié)構(gòu)集。
諸如此類的操作就包括排序?(sorting)和排名?(ranking)。
從函數(shù)角度來看都是分段線性函數(shù),排序的問題在于,它的向量包含許多不可微分的“節(jié)點(diǎn)”,而排名的秩要比排序還要麻煩。
首先將排序和排名操作轉(zhuǎn)換為在排列多面體(permutahedron)上的線性過程,如下圖所示。
△排列多面體說明
在這一過程后,可以發(fā)現(xiàn)對于r(θ),若是θ出現(xiàn)微小“擾動”,就會導(dǎo)致線性程序跳轉(zhuǎn)到另外一個排序,使得r(θ)不連續(xù)。
也就意味著導(dǎo)數(shù)要么為null,要么就是“未定義”,這就阻礙了梯度反向傳播。
為了解決上述的問題,就需要對排序和排名運(yùn)算符,進(jìn)行有效可計(jì)算的近似設(shè)計(jì)。
谷歌大腦團(tuán)隊(duì)提出的方法,就是通過在線性規(guī)劃公式中引入強(qiáng)凸正則化來實(shí)現(xiàn)這一目標(biāo)。
這就讓它們轉(zhuǎn)換成高效可計(jì)算的投影算子(projection operator),可微分,且服從于形式分析(formal analysis)。
在投影到排列多面體之后,可以根據(jù)這些投影來定義軟排序(soft sorting)和軟排名(soft ranking)操作符。
△軟排序和軟排名操作符
在此基礎(chǔ)上,要想完成快速計(jì)算和微分,一個關(guān)鍵步驟就是將投影簡化為保序優(yōu)化?(isotonic optimization)。
接下來是將保序優(yōu)化進(jìn)行微分,此處采用的是雅可比矩陣(Jacobian),因?yàn)樗唵蔚膲K級結(jié)構(gòu),使得導(dǎo)數(shù)很容易分析。
而后,結(jié)合命題3和引理2,可以描述投影到排列多面體上的雅可比矩陣。
需要強(qiáng)調(diào)的是,與保序優(yōu)化的雅可比矩陣不同,投影的雅可比矩陣不是塊對角的,因?yàn)槲覀冃枰獙λ男泻土羞M(jìn)行轉(zhuǎn)置。
最終,可以用O(n)時間和空間中的軟算子雅可比矩陣相乘。
實(shí)驗(yàn)結(jié)果
研究人員在CIFAR-10和CIFAR-100數(shù)據(jù)集上進(jìn)行了實(shí)驗(yàn)。
實(shí)驗(yàn)使用的CNN,包含4個具有2個最大池化層的Conv2D,RelU激活,2個完全連接層;ADAM優(yōu)化器的步長恒定為10-4,k=1。
與之比較的是O(Tn2)的OT方法,以及O(n2)的All-pairs方法。
△rQ及rE為新算法
結(jié)果表明,在CIFAR-10和CIFAR-100上,新算法都達(dá)到了與OT方法相當(dāng)?shù)木?,并且速度明顯更快。
在CIFAR-100上訓(xùn)練600個epoch,OT耗費(fèi)的時間為29小時,rQ為21小時,rE為23小時,All-pairs為16小時。在CIFAR-10上結(jié)果差不多。
在驗(yàn)證輸入尺寸對運(yùn)行時間的影響時,研究人員使用的是64GB RAM的6核Intel Xeon W-2135,以及GeForce GTX 1080Ti。
禁用反向傳播的情況下,進(jìn)行1個batch的計(jì)算,OT和All-pairs分別在n=2000和n=3000的時候出現(xiàn)內(nèi)存不足。
啟用反向傳播時,OT和All-pairs分別在n=1000和n=2500的時候出現(xiàn)內(nèi)存不足。
開啟新的可能性
曾就職于谷歌、NASA的機(jī)器學(xué)習(xí)工程師Brad Neuberg認(rèn)為,從機(jī)器學(xué)習(xí)的角度來說,快速可微分排序、排名算法看上去十分重要。
而谷歌的這一新排序算法,也在reddit和hacker news等平臺上引起了熱烈的討論。
有網(wǎng)友對其帶來的“新可能性”做出了更為詳細(xì)的討論:
我想,可微分排序生成的梯度信息量更大,使得梯度下降的速度更快,從而能夠進(jìn)一步提升訓(xùn)練速度。
我認(rèn)為,這意味著某些基于排名的指標(biāo),以后可以用可微分的形式來表示。也就是說,神經(jīng)網(wǎng)絡(luò)可以輕松地針對這些結(jié)果直接進(jìn)行優(yōu)化。
對于谷歌而言,這很顯然會應(yīng)用于網(wǎng)絡(luò)搜索,以及諸如標(biāo)簽分配之類的東西問題。
也有網(wǎng)友指出,雖然該算法并不是第一個解決了排序不可微問題的方法,但它的效率無疑更高。
傳送門
論文:https://arxiv.org/pdf/2002.08871.pdf
討論:https://news.ycombinator.com/item?id=22393790https://www.reddit.com/r/MachineLearning/comments/f85yp4/r_fast_differentiable_sorting_and_ranking/
— 完 —