DeepMind發(fā)布神經(jīng)網(wǎng)絡(luò)、強(qiáng)化學(xué)習(xí)庫,網(wǎng)友:推動JAX發(fā)展
十三 發(fā)自 凹非寺
量子位 報道 | 公眾號 QbitAI
DeepMind今日發(fā)布了Haiku和RLax兩個庫,都是基于JAX。
JAX由谷歌提出,是TensorFlow的簡化庫。結(jié)合了針對線性代數(shù)的編譯器XLA,和自動區(qū)分本地 Python 和 Numpy 代碼的庫Autograd,在高性能的機(jī)器學(xué)習(xí)研究中使用。
而此次發(fā)布的兩個庫,分別針對神經(jīng)網(wǎng)絡(luò)和強(qiáng)化學(xué)習(xí),大幅簡化了JAX的使用。
Haiku是基于JAX的神經(jīng)網(wǎng)絡(luò)庫,允許用戶使用熟悉的面向?qū)ο蟪绦蛟O(shè)計模型,可完全訪問 JAX 的純函數(shù)變換。
RLax是JAX頂層的庫,它提供了用于實現(xiàn)增強(qiáng)學(xué)習(xí)代理的有用構(gòu)件。
有意思的是,Reddit網(wǎng)友驚奇的發(fā)現(xiàn)Haiku這個庫的名字,竟然不以“ax”結(jié)尾。
當(dāng)然,也有網(wǎng)友對這兩個庫表示了肯定:
毫無疑問,對JAX起到了推動作用。
那么,我們就來看下Haiku和RLex的廬山真面目吧。
Haiku
Haiku是JAX的神經(jīng)網(wǎng)絡(luò)庫,它允許用戶使用熟悉的面向?qū)ο缶幊棠P?,同時允許完全訪問JAX的純函數(shù)轉(zhuǎn)換。
它提供了兩個核心工具:模塊抽象hk.Module,和一個簡單的函數(shù)轉(zhuǎn)換hk.transform。
hk.Module是Python對象,包含對其自身參數(shù)、其他模塊和對用戶輸入應(yīng)用函數(shù)方法的引用。
hk.transform允許完全訪問JAX的純函數(shù)轉(zhuǎn)換。
其實,在JAX中有許多神經(jīng)網(wǎng)絡(luò)庫,那么Haiku有什么特別之處呢?有5點(diǎn)。
1、Haiku已經(jīng)由DeepMind的研究人員進(jìn)行了大規(guī)模測試
DeepMind相對容易地在Haiku和JAX中復(fù)制了許多實驗。其中包括圖像和語言處理的大規(guī)模結(jié)果、生成模型和強(qiáng)化學(xué)習(xí)。
2、Haiku是一個庫,而不是一個框架
它的設(shè)計是為了簡化一些具體的事情,包括管理模型參數(shù)和其他模型狀態(tài)??梢耘c其他庫一起編寫,并與JAX的其他部分一起工作。
3、Haiku并不是另起爐灶
它建立在Sonnet的編程模型和API之上,Sonnet是DeepMind幾乎普遍采用的神經(jīng)網(wǎng)絡(luò)庫。它保留了Sonnet用于狀態(tài)管理的基于模塊的編程模型,同時保留了對JAX函數(shù)轉(zhuǎn)換的訪問。
4、過渡到Haiku是比較容易的
通過精心的設(shè)計,從TensorFlow和Sonnet,過渡到JAX和Haiku是比較容易的。除了新的函數(shù)(如hk.transform),Haiku的目的是Sonnet 2的API。
5、Haiku簡化了JAX
它提供了一個處理隨機(jī)數(shù)的簡單模型。在轉(zhuǎn)換后的函數(shù)中,hk.next_rng_key()返回一個唯一的rng鍵。
那么,該如何安裝Haiku呢?
Haiku是用純Python編寫的,但是通過JAX依賴于c++代碼。
首先,按照下方鏈接中的說明,安裝帶有相關(guān)加速器支持的JAX。https://github.com/google/jax#installation
然后,只需要一句簡單的pip命令就可以完成安裝。
$?pip?install?git+https://github.com/deepmind/haiku
接下來,是一個神經(jīng)網(wǎng)絡(luò)和損失函數(shù)的例子。
import?haiku?as?hk
import?jax.numpy?as?jnp
def?softmax_cross_entropy(logits,?labels):
??one_hot?=?hk.one_hot(labels,?logits.shape[-1])
??return?-jnp.sum(jax.nn.log_softmax(logits)?*?one_hot,?axis=-1)
def?loss_fn(images,?labels):
??model?=?hk.Sequential([
??????hk.Linear(1000),
??????jax.nn.relu,
??????hk.Linear(100),
??????jax.nn.relu,
??????hk.Linear(10),
??])
??logits?=?model(images)
??return?jnp.mean(softmax_cross_entropy(logits,?labels))
loss_obj?=?hk.transform(loss_fn)
RLax
RLax是JAX頂層的庫,它提供了用于實現(xiàn)增強(qiáng)學(xué)習(xí)代理的有用構(gòu)件。
它所提供的操作和函數(shù)不是完整的算法,而是強(qiáng)化學(xué)習(xí)特定數(shù)學(xué)操作的實現(xiàn)。
RLax的安裝也非常簡單,一個pip命令就可以搞定。
pip?install?git+git://github.com/deepmind/rlax.git
使用JAX的jax.jit函數(shù),所有的RLax代碼可以不同的硬件上編譯。
RLax需要注意的是它的命名規(guī)則。
許多函數(shù)在連續(xù)的時間步長中考慮策略、操作、獎勵和值,以便計算它們的輸出。在這種情況下,后綴_t和tm1通常是為了說明每個輸入是在哪個步驟上生成的,例如:
q_tm1:轉(zhuǎn)換的源狀態(tài)中的操作值。a_tm1:在源狀態(tài)下選擇的操作。r_t:在目標(biāo)狀態(tài)下收集的結(jié)果獎勵。q_t:目標(biāo)狀態(tài)下的操作值。
Haiku和RLax都已在GitHub上開源,有興趣的讀者可從“傳送門”的鏈接訪問。
傳送門
Haiku:https://github.com/deepmind/haiku
RLax:https://github.com/deepmind/rlax
- 商湯林達(dá)華萬字長文回答AGI:4層破壁,3大挑戰(zhàn)2025-08-12
- 商湯多模態(tài)大模型賦能鐵路勘察設(shè)計,讓70年經(jīng)驗“活”起來2025-08-13
- 以“具身智能基座”為核,睿爾曼攜全產(chǎn)品矩陣及新品亮相2025 WRC2025-08-11
- 哇塞,今天北京被機(jī)器人人人人人塞滿了!2025-08-08