混合精度下位置編碼竟有大坑,LLaMA等主流開源模型紛紛中招,百川智能給出修復(fù)方案
還得從位置編碼算法原理說(shuō)起
位置編碼技術(shù)是一種能夠讓神經(jīng)網(wǎng)絡(luò)建模句子中Token位置信息的技術(shù)。
在Transformer大行其道的時(shí)代,由于Attention結(jié)構(gòu)無(wú)法建模每個(gè)token的位置信息,位置編碼(Position Embedding)成為Transformer非常重要的一個(gè)組件。
研究人員也提出了各種各樣的位置編碼方案來(lái)讓網(wǎng)絡(luò)建模位置信息,RoPE和 Alibi 是目前最被廣泛采納的兩種位置編碼方案。
然而最近來(lái)自百川智能的研究發(fā)現(xiàn),RoPE和Alibi位置編碼的主流實(shí)現(xiàn)在低精度(尤其是bfloat16)下存在位置編碼碰撞的bug, 這可能會(huì)影響模型的訓(xùn)練和推理。
而且目前大部分主流開源模型的實(shí)現(xiàn)都存在該問(wèn)題,連llama官方代碼也中招了。

還得從位置編碼算法說(shuō)起
為了弄清楚這個(gè)問(wèn)題,得先從位置編碼的算法原理說(shuō)起。
在Transformer結(jié)構(gòu)中,所有Attention Block的輸入都會(huì)先經(jīng)過(guò)位置編碼, 再輸入網(wǎng)絡(luò)進(jìn)行后續(xù)處理。
純粹的Attention結(jié)構(gòu)是無(wú)法精確感知到每個(gè)token的位置信息的,而對(duì)于語(yǔ)言的很多任務(wù)來(lái)說(shuō),語(yǔ)句的順序?qū)φZ(yǔ)義信息的影響是非常大的,為了建模token之間的位置關(guān)系,Transfomer原始論文中引入位置編碼來(lái)建模位置信息。
圖1-施加 Positon Embedding 示意圖
為了讓模型更好地建模句子的位置信息,研究人員提出了多種位置編碼方案,Meta開源的llama模型采用了RoPE方案,使得RoPE成為在開源社區(qū)被廣泛采納的一種位置編碼方案。Alibi編碼也因?yàn)槠淞己玫耐馔菩砸脖粡V泛應(yīng)用。
了解低精度下的位置編碼碰撞之前,先來(lái)回顧一下相關(guān)算法原理
- Sinusoidal位置編碼

這是Transformer原始論文中提出的位置編碼方法。它通過(guò)使用不同頻率的正弦和余弦函數(shù)來(lái)為每個(gè)位置產(chǎn)生一個(gè)獨(dú)特的編碼。選擇三角函數(shù)來(lái)生成位置編碼有兩個(gè)良好的性質(zhì):
1)編碼相對(duì)位置信息,數(shù)學(xué)上可以證明 PE(pos+k) 可以被 PE(pos) 線性表示, 這意味著位置編碼中蘊(yùn)含了相對(duì)位置信息。
圖2-句子長(zhǎng)度為50的位置編碼,編碼維度128,每行代表一個(gè)Position Embedding
2)遠(yuǎn)程衰減:不同位置的position embedding點(diǎn)乘結(jié)果會(huì)隨著相對(duì)位置的增加而遞減。
圖3-不同位置的位置編碼點(diǎn)積可視化
RoPE
RoPE是目前開源社區(qū)應(yīng)用最廣泛的一種位置編碼方案, 通過(guò)絕對(duì)位置編碼的方式實(shí)現(xiàn)相對(duì)位置編碼,在引入相對(duì)位置信息的同時(shí)保持了絕對(duì)位置編碼的優(yōu)勢(shì)(不需要像相對(duì)位置編碼一樣去操作Attention matrix)。令f_q, f_k 為 位置編碼的函數(shù),m表示位置, x_m 表示該位置token對(duì)應(yīng)的Embedding,希望經(jīng)過(guò)位置編碼后的Embedding 點(diǎn)積僅和相對(duì)位置有關(guān),則可以有公式
上面公式中g(shù)是某個(gè)函數(shù),表示內(nèi)積的結(jié)果只和x_m 和 x_n的值,以及二者位置的相對(duì)關(guān)系(m-n)有關(guān)在2維的情況下可以推導(dǎo)出(詳細(xì)推導(dǎo)過(guò)程可參考原論文):
因?yàn)榫仃嚦朔ň€性累加的性質(zhì),可以拓展到多維的情況可得:
為了引入遠(yuǎn)程衰減的特性,Rope中theta的選取選擇了Transformer 原始論文中 sinusoidal 公式。
Alibi
- Alibi是谷歌發(fā)表在ICLR2022的一篇工作,Alibi主要解決了位置編碼外推效果差的痛點(diǎn),算法思想非常的簡(jiǎn)單,而且非常直觀。與直接加在Embedding 上的絕對(duì)位置編碼不同,Alibi的思想是在 Attention matrix上施加一個(gè)與距離成正比的懲罰偏置,懲罰偏置隨著相對(duì)距離的增加而增加。在具體實(shí)現(xiàn)時(shí),對(duì)于每個(gè)head會(huì)有一個(gè)超參m 來(lái)控制懲罰偏置隨著相對(duì)距離增加的幅度(斜率)。
圖4-Alibi attention bias示意圖
- 論文結(jié)果顯示Alibi 極大的提升了模型的外推性能,16k token 的輸入依然可以很好的支持
圖5-Alibi 外推效果對(duì)比
混合精度下位置編碼的bug
- 從上面的算法原理中,不管是RoPE 的 cos(m theta) 還是alibi 的 i-1(m, i 代表postion id), 都需要為每個(gè)位置生成一個(gè)整型的position_id, 在上下文窗口比較大的時(shí)候,百川智能發(fā)現(xiàn)目前主流的位置編碼實(shí)現(xiàn)在混合精度下都存在因?yàn)榈途龋╢loat16/bfloat16)浮點(diǎn)數(shù)表示精度不足導(dǎo)致位置編碼碰撞的問(wèn)題。尤其當(dāng)模型訓(xùn)練(推理)時(shí)上下文長(zhǎng)度越來(lái)越長(zhǎng),低精度表示帶來(lái)的位置編碼碰撞問(wèn)題越來(lái)越嚴(yán)重,進(jìn)而影響模型的效果,下面以bfloat16為例來(lái)說(shuō)明這個(gè) bug
浮點(diǎn)數(shù)表示精度
- 浮點(diǎn)數(shù)在計(jì)算機(jī)中表示由符號(hào)位(sign),指數(shù)位(exponent),尾數(shù)位(fraction) 三部分組成, 對(duì)于一個(gè)常規(guī)的數(shù)值表示,可以由如下公式來(lái)計(jì)算其代表的數(shù)值(其中offset是指數(shù)位的偏置):
- 由公式可知,尾數(shù)位的長(zhǎng)度決定了浮點(diǎn)數(shù)的表示精度。深度學(xué)習(xí)中常用的 float32/float16/bfloat16 內(nèi)存中的表示分別如下圖所示:
圖6-bfloat16 的表示格式
圖7-float16 的表示格式
圖8-float32 的表示格式
可以看到可以看到float16和bfloat16相比于float32都犧牲了表示的精度,后續(xù)以bfloat16為例說(shuō)明位置編碼中存在的問(wèn)題(float16同理)。 下表展示了bfloat16在不同數(shù)值范圍(只截取整數(shù)部分)內(nèi)的表示精度
可以看到當(dāng)整數(shù)范圍超過(guò)256,bfloat16就無(wú)法精確表示每一個(gè)整數(shù),我們可以用代碼驗(yàn)證一下表示精度帶來(lái)的問(wèn)題
RoPE& Alibi 編碼的問(wèn)題
- Meta開源的llama模型采用了RoPE的位置編碼方式,官方的實(shí)現(xiàn)(以及大部分的第三方llama系列模型)在bfloat16下存在精度問(wèn)題帶來(lái)的位置編碼碰撞(不同位置的token在bfloat16下變成同一個(gè)數(shù))。llama官方代碼如下:
- 上面第18行核心一句根據(jù)輸入序列長(zhǎng)度生成每個(gè)位置的 positon idx在bfloat16 下產(chǎn)生位置碰撞
- 在實(shí)際訓(xùn)練時(shí)如果開了bfloat16, self.inv_freq的 dtype會(huì)被轉(zhuǎn)為bfloat16, 我們可以通過(guò)簡(jiǎn)單的代碼來(lái)看一下位置碰撞的問(wèn)題
圖9-bfloat16位置碰撞示意圖
- 根據(jù)bfloat16的表示精度可知,訓(xùn)練(推理)時(shí)上下文長(zhǎng)度越長(zhǎng),位置編碼碰撞的情況越嚴(yán)重,長(zhǎng)度為8192的上下文推理中,僅有大約10%的token位置編碼是精確的,好在位置編碼碰撞有局域性的特質(zhì),只有若干個(gè)相鄰的token才會(huì)共享同一個(gè)position Embedding, 在更大的尺度上,不同位置的token 還是有一定的區(qū)分性。
圖10-不同上下文窗口下位置編碼精確token所占比例
除了RoPE位置編碼方案,百川智能發(fā)現(xiàn) Alibi 位置編碼也存在上述問(wèn)題,原因依然在于生成整數(shù)的位置索引時(shí)會(huì)在低精度下產(chǎn)生碰撞問(wèn)題。
修復(fù)方案
RoPE修復(fù)
- RoPE 的修復(fù)相對(duì)簡(jiǎn)單,只需要保證在生成 position_id的時(shí)候一定在float32的精度上即可。注意:
- float32的tensor register_buffer后在訓(xùn)練時(shí)如果開啟了bfloat16, 也會(huì)被轉(zhuǎn)為bfloat16
Alibi修復(fù)
- Alibi位置編碼修復(fù)思路和RoPE的修復(fù)思路一致,但因?yàn)锳libi的 attention bias直接加在 attention matrix上面,如果按照上面的修復(fù)思路,attention matrix的類型必須和attention bias 一致,導(dǎo)致整個(gè)attention的計(jì)算都在float32類型上計(jì)算,這會(huì)極大的拖慢訓(xùn)練速度
- 目前主流的attention加速方法flashattention不支持 attention bias參數(shù), 而 xformers要求attention bias類型必須與query.dtype相同,因此像RoPE那樣簡(jiǎn)單的將attention bias類型提升到float32將會(huì)極大的拖慢訓(xùn)練速度
- 針對(duì)該問(wèn)題百川智能提出了一種新的Alibi attention方案, 整個(gè)attention bias依然在bfloat16類型上,類似于sinusoidal的遠(yuǎn)程衰減特質(zhì),我們盡量保證臨近token位置編碼的精確性,對(duì)于相對(duì)距離過(guò)遠(yuǎn)的的token我們則可以容忍其產(chǎn)生一定的位置碰撞。原本的Alibi實(shí)現(xiàn)則相反,相對(duì)距離越遠(yuǎn)的token表示越精確,相對(duì)距離越近的token 則會(huì)碰撞
圖11- 修復(fù)前后alibi attention_bias對(duì)照
修復(fù)效果
- 此處僅在推理階段對(duì)位置編碼的精度問(wèn)題進(jìn)行修復(fù)【注:訓(xùn)練階段可能也存在問(wèn)題,取決于訓(xùn)練的具體配置和方法】,可以看到:
- 在長(zhǎng)上下文的推理中,模型的ppl 要顯著優(yōu)于修復(fù)前的ppl
- Benchmark上測(cè)試結(jié)果顯示修復(fù)前后區(qū)別不大,可能是因?yàn)閎enchmark上測(cè)試文本長(zhǎng)度有限,很少觸發(fā)Position embedding的碰撞
Benchmark對(duì)比
Perplexity對(duì)比
在通用的文本數(shù)據(jù)上對(duì)修改前后模型在中英文文本上的困惑度進(jìn)行測(cè)試,效果如下:
參考資料:
Dongxu Zhang, & Dong Wang. (2015). Relation Classification via Recurrent Neural Network.
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, & Illia Polosukhin. (2023). Attention Is All You Need.
Zihang Dai, Zhilin Yang, Yiming Yang, Jaime Carbonell, Quoc V. Le, & Ruslan Salakhutdinov. (2019). Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.
Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, & Peter J. Liu. (2020). Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer.
Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, & Guillaume Lample. (2023). LLaMA: Open and Efficient Foundation Language Models.
Jianlin Su, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, & Yunfeng Liu. (2022). RoFormer: Enhanced Transformer with Rotary Position Embedding.
Ofir Press, Noah A. Smith, & Mike Lewis. (2022). Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation.
Yutao Sun, Li Dong, Barun Patra, Shuming Ma, Shaohan Huang, Alon Benhaim, Vishrav Chaudhary, Xia Song, & Furu Wei. (2022). A Length-Extrapolatable Transformer.
https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
Shouyuan Chen, Sherman Wong, Liangjian Chen, & Yuandong Tian. (2023). Extending Context Window of Large Language Models via Positional Interpolation.
https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/