1個GPU幾分鐘搞定強化學(xué)習(xí)訓(xùn)練,谷歌新引擎讓深度學(xué)習(xí)提速1000倍丨開源
博雯 發(fā)自 凹非寺
量子位 報道 | 公眾號 QbitAI
機器人要如何完成這樣一個動作?
我們一般會基于強化學(xué)習(xí),在仿真環(huán)境中進(jìn)行模擬訓(xùn)練。
這時,如果在一臺機器的CPU環(huán)境下進(jìn)行模擬訓(xùn)練,那么需要幾個小時到幾天。
但現(xiàn)在,只需一個TPU/GPU,就能和數(shù)千個CPU或GPU的計算集群的速度一樣快,直接將所需時間縮短到幾分鐘!
相當(dāng)于將強化學(xué)習(xí)的速度提升了1000倍!
這就是來自谷歌的科學(xué)家們開發(fā)的物理模擬引擎Brax。
三種策略避免邏輯分支
現(xiàn)在大多數(shù)的物理模擬引擎都是怎么設(shè)計的呢?
將重力、電機驅(qū)動、關(guān)節(jié)約束、物體碰撞等任務(wù)都整合在一個模擬器中,并行地進(jìn)行多個模擬,以此來逼近現(xiàn)實中的運動系統(tǒng)。
這種情況下,每個模擬器中的計算都不相同,且數(shù)據(jù)必須在數(shù)據(jù)中心內(nèi)通過網(wǎng)絡(luò)傳輸。
這種并行布局也就導(dǎo)致了較高的延遲時間——即學(xué)習(xí)者可能需要超過10000納秒的等待時間,才能從模擬器中獲得經(jīng)驗。
那么怎樣才能縮短這種延遲時間呢?
Brax選擇通過避免模擬中的分支來保證數(shù)千個并行環(huán)境中的計算完全統(tǒng)一,進(jìn)而降低整個訓(xùn)練架構(gòu)的復(fù)雜度。
直到復(fù)雜度降低到可以在單一的TPU或GPU上執(zhí)行,跨機器通信的計算開銷就隨之降低,延遲也就能被有效消除。
主要分為以下三個方法:
- 連續(xù)函數(shù)替換離散分支邏輯
比如,在計算一個小球與墻壁之間的接觸力時,就產(chǎn)生了一個分支:
如果球接觸墻壁,就執(zhí)行模擬球從墻壁反彈的獨立代碼;
否則,就執(zhí)行其他代碼;
這里就可以通過符號距離函數(shù)來避免這種if/else的離散分支邏輯的產(chǎn)生。
- 使用JAX即時編譯中評估分支
在仿真時間之前評估基于環(huán)境靜態(tài)屬性的分支,例如兩個物體是否有可能發(fā)生碰撞。
- 在模擬中只選擇需要的分支結(jié)果
在使用了這三種策略之后,我們就得到了一個模擬由剛體、關(guān)節(jié)、執(zhí)行器組成環(huán)境的物理引擎。
同時也是一種實現(xiàn)在這種環(huán)境中各類操作(如進(jìn)化策略,直接軌跡優(yōu)化等)的學(xué)習(xí)算法。
那么Brax的性能究竟如何呢?
速度最高提升1000倍
Brax測試所用的基準(zhǔn)是OpenAI Gym中Ant、HalfCheetah、Humanoid、Reacher四種。
同時也增加了三個新環(huán)境:包括對物理的靈巧操作、通用運動(例如前往周圍任何一個放置了物體的地點)、以及工業(yè)機器人手臂的模擬:
研究人員首先測試了Brax在并行模擬越來越多的環(huán)境時,可以產(chǎn)生多少次物理步驟(也即對環(huán)境狀態(tài)的更新)。
測試結(jié)果中的TPUv3 8×8曲線顯示,Brax可以在多個設(shè)備之間進(jìn)行無縫擴展,每秒可達(dá)到數(shù)億個物理步驟:
而不僅是在TPU上,從V100和P100曲線也能看出,Brax在高端GPU上同樣表現(xiàn)出色。
然后就是Brax在單個工作站(workstation)上運行一個強化學(xué)習(xí)實驗所需要的時間。
在這里,研究人員將基于Ant基準(zhǔn)環(huán)境訓(xùn)練的Brax引擎與MuJoCo物理引擎做了對比:
可以看到,相對于MuJoCo(藍(lán)線)所需的將近3小時時間,使用了Brax的加速器硬件最快只需要10秒。
使用Brax,不僅能夠提高單核訓(xùn)練的效率,還可以擴展到大規(guī)模的并行模擬訓(xùn)練。
論文地址:
https://arxiv.org/abs/2106.13281
下載:
https://github.com/google/brax
參考鏈接:
https://ai.googleblog.com/2021/07/speeding-up-reinforcement-learning-with.html
- 有道智能學(xué)習(xí)燈發(fā)布,通過“桌面學(xué)習(xí)分析引擎”實現(xiàn)全球最快指尖查詞2022-04-08
- 科學(xué)證明:狗勾真的懂你有多累,聽到聲音0.25秒后就知道你是誰,對人比對狗更親近2022-04-14
- 在M1芯片上跑原生Linux:編譯速度比macOS還快40%2022-04-05
- 小學(xué)生們在B站講算法,網(wǎng)友:我只會阿巴阿巴2022-03-28