1美元訓練BERT,教你如何薅谷歌TPU羊毛|附Colab代碼
曉查 發(fā)自 凹非寺
量子位 出品 | 公眾號 QbitAI
BERT是谷歌去年推出的NLP模型,一經推出就在各項測試中碾壓競爭對手,而且BERT是開源的。只可惜訓練BERT的價格實在太高,讓人望而卻步。
之前需要用64個TPU訓練4天才能完成,后來谷歌用并行計算優(yōu)化了到只需一個多小時,但是需要的TPU數量陡增,達到了驚人的1024個。
那么總共要多少錢呢?谷歌云TPU的使用價格是每個每小時6.5美元,訓練完成訓練完整個模型需要近4萬美元,簡直就是天價。
現在,有個羊毛告訴你,在培養(yǎng)基上有人找到了薅谷歌羊毛的辦法,只需1美元就能訓練BERT,模型還能留存在你的谷歌云盤中,留作以后使用。
準備工作
為了薅谷歌的羊毛,您需要一個Google云存儲(Google Cloud Storage)空間。按照Google云TPU快速入門指南,創(chuàng)建Google云平臺(Google Cloud Platform)帳戶和Google云存儲賬戶。新的谷歌云平臺用戶可獲得300美元的免費贈送金額。
在TPUv2上預訓練BERT-Base模型大約需要54小時.Google Colab并非設計用于執(zhí)行長時間運行的作業(yè),它會每8小時左右中斷一次訓練過程。對于不間斷的訓練,請考慮使用付費的不間斷使用TPUv2的方法。
也就是說,使用Colab TPU,你可以在以1美元的價格在谷云盤上存儲模型和數據,以幾乎可忽略成本從頭開始預訓練BERT模型。
以下是整個過程的代碼下面的代碼,可以在Colab Jupyter環(huán)境中運行。
設置訓練環(huán)境
首先,安裝訓練模型所需的包.Jupyter允許使用直接從筆記本執(zhí)行的bash命令 ‘!’:
!pip install sentencepiece
!git clone https://github.com/google-research/bert
導入包并在谷歌云中授權:
import os
import sys
import json
import nltk
import random
import logging
import tensorflow as tf
import sentencepiece as spm
from glob import glob
from google.colab import auth, drive
from tensorflow.keras.utils import Progbar
sys.path.append("bert")
from bert import modeling, optimization, tokenization
from bert.run_pretraining import input_fn_builder, model_fn_builder
auth.authenticate_user()
# configure logging
log = logging.getLogger('tensorflow')
log.setLevel(logging.INFO)
# create formatter and add it to the handlers
formatter = logging.Formatter('%(asctime)s : %(message)s')
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
log.handlers = [sh]
if 'COLAB_TPU_ADDR' in os.environ:
log.info("Using TPU runtime")
USE_TPU = True
TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']
with tf.Session(TPU_ADDRESS) as session:
log.info('TPU address is ' + TPU_ADDRESS)
# Upload credentials to TPU.
with open('/content/adc.json', 'r') as f:
auth_info = json.load(f)
tf.contrib.cloud.configure_gcs(session, credentials=auth_info)
else:
log.warning('Not connected to TPU runtime')
USE_TPU = False
下載原始文本數據
接下來從網絡上獲取文本數據語料庫。在本次實驗中,我們使用OpenSubtitles數據集,該數據集包括65種語言。
與更常用的文本數據集(如維基百科)不同,它不需要任何復雜的預處理,提供預格式化,一行一個句子。
AVAILABLE = {'af','ar','bg','bn','br','bs','ca','cs',
'da','de','el','en','eo','es','et','eu',
'fa','fi','fr','gl','he','hi','hr','hu',
'hy','id','is','it','ja','ka','kk','ko',
'lt','lv','mk','ml','ms','nl','no','pl',
'pt','pt_br','ro','ru','si','sk','sl','sq',
'sr','sv','ta','te','th','tl','tr','uk',
'ur','vi','ze_en','ze_zh','zh','zh_cn',
'zh_en','zh_tw','zh_zh'}
LANG_CODE = "en" #@param {type:"string"}
assert LANG_CODE in AVAILABLE, "Invalid language code selected"
!wget http://opus.nlpl.eu/download.php?f=OpenSubtitles/v2016/mono/OpenSubtitles.raw.'$LANG_CODE'.gz -O dataset.txt.gz
!gzip -d dataset.txt.gz
!tail dataset.txt
你可以通過設置代碼隨意選擇你需要的語言。出于演示目的,代碼只默認使用整個語料庫的一小部分。在實際訓練模型時,請務必取消選中DEMO_MODE復選框,使用大100倍的數據集。
當然,100M數據足以訓練出相當不錯的BERT基礎模型。
DEMO_MODE = True #@param {type:"boolean"}
if DEMO_MODE:
CORPUS_SIZE = 1000000
else:
CORPUS_SIZE = 100000000 #@param {type: "integer"}
!(head -n $CORPUS_SIZE dataset.txt) > subdataset.txt
!mv subdataset.txt dataset.txt
預處理文本數據
我們下載的原始文本數據包含標點符號,大寫字母和非UTF符號,我們將在繼續(xù)下一步之前將其刪除。在推理期間,我們將對新數據應用相同的過程。
如果你需要不同的預處理方式(例如在推理期間預期會出現大寫字母或標點符號),請修改以下代碼以滿足你的需求。
regex_tokenizer = nltk.RegexpTokenizer("\w+")
def normalize_text(text):
# lowercase text
text = str(text).lower()
# remove non-UTF
text = text.encode("utf-8", "ignore").decode()
# remove punktuation symbols
text = " ".join(regex_tokenizer.tokenize(text))
return text
def count_lines(filename):
count = 0
with open(filename) as fi:
for line in fi:
count += 1
return count
現在讓我們預處理整個數據集:
RAW_DATA_FPATH = "dataset.txt" #@param {type: "string"}
PRC_DATA_FPATH = "proc_dataset.txt" #@param {type: "string"}
# apply normalization to the dataset
# this will take a minute or two
total_lines = count_lines(RAW_DATA_FPATH)
bar = Progbar(total_lines)
with open(RAW_DATA_FPATH,encoding="utf-8") as fi:
with open(PRC_DATA_FPATH, "w",encoding="utf-8") as fo:
for l in fi:
fo.write(normalize_text(l)+"\n")
bar.add(1)
構建詞匯表
下一步,我們將訓練模型學習一個新的詞匯表,用于表示我們的數據集。
BERT文件使用WordPiece分詞器,在開源中不可用。我們將在單字模式下使用SentencePiece分詞器。雖然它與BERT不直接兼容,但是通過一個小的處理方法,可以使它工作。
SentencePiece需要相當多的運行內存,因此在Colab中的運行完整數據集會導致內核崩潰。
為避免這種情況,我們將隨機對數據集的一小部分進行子采樣,構建詞匯表。另一個選擇是使用更大內存的機器來執(zhí)行此步驟。
此外,SentencePiece默認情況下將BOS和EOS控制符號添加到詞匯表中。我們通過將其索引設置為-1來禁用它們。
VOC_SIZE的典型值介于32000和128000之間。如果想要更新詞匯表,并在預訓練階段結束后對模型進行微調,我們會保留NUM_PLACEHOLDERS個令牌。
MODEL_PREFIX = "tokenizer" #@param {type: "string"}
VOC_SIZE = 32000 #@param {type:"integer"}
SUBSAMPLE_SIZE = 12800000 #@param {type:"integer"}
NUM_PLACEHOLDERS = 256 #@param {type:"integer"}
SPM_COMMAND = ('--input={} --model_prefix={} '
'--vocab_size={} --input_sentence_size={} '
'--shuffle_input_sentence=true '
'--bos_id=-1 --eos_id=-1').format(
PRC_DATA_FPATH, MODEL_PREFIX,
VOC_SIZE - NUM_PLACEHOLDERS, SUBSAMPLE_SIZE)
spm.SentencePieceTrainer.Train(SPM_COMMAND)
現在,讓我們看看如何讓SentencePiece在BERT模型上工作。
下面是使用來自官方的預訓練英語BERT基礎模型的WordPiece詞匯表標記的語句。
>>> wordpiece.tokenize("Colorless geothermal substations are generating furiously")
['color',
'##less',
'geo',
'##thermal',
'sub',
'##station',
'##s',
'are',
'generating',
'furiously']
WordPiece標記器在“##”的單詞中間預置了出現的子字。在單詞開頭出現的子詞不變。如果子詞出現在單詞的開頭和中間,則兩個版本(帶和不帶” ##’)都會添加到詞匯表中。
SentencePiece創(chuàng)建了兩個文件:tokenizer.model和tokenizer.vocab讓我們來看看它學到的詞匯:
def read_sentencepiece_vocab(filepath):
voc = []
with open(filepath, encoding='utf-8') as fi:
for line in fi:
voc.append(line.split("\t")[0])
# skip the first <unk> token
voc = voc[1:]
return voc
snt_vocab = read_sentencepiece_vocab("{}.vocab".format(MODEL_PREFIX))
print("Learnt vocab size: {}".format(len(snt_vocab)))
print("Sample tokens: {}".format(random.sample(snt_vocab, 10)))
運行結果:
Learnt vocab size: 31743
Sample tokens: ['▁cafe', '▁slippery', 'xious', '▁resonate', '▁terrier', '▁feat', '▁frequencies', 'ainty', '▁punning', 'modern']
SentencePiece與WordPiece的運行結果完全相反從文檔中可以看出:SentencePiece首先使用元符號“_”將空格轉義為空格,如下所示:
Hello_World。
然后文本被分段為小塊:
[Hello] [_Wor] [ld] [.]
在空格之后出現的子詞(也是大多數詞開頭的子詞)前面加上“_”,而其他子詞不變。這排除了僅出現在句子開頭而不是其他地方的子詞。然而,這些案件應該非常罕見。
因此,為了獲得類似于WordPiece的詞匯表,我們需要執(zhí)行一個簡單的轉換,從包含它的標記中刪除“_”,并將“##”添加到不包含它的標記中。
我們還添加了一些BERT架構所需的特殊控制符號。按照慣例,我們把它們放在詞匯的開頭。
另外,我們在詞匯表中添加了一些占位符標記。
如果你希望使用新的用于特定任務的令牌來更新預先訓練的模型,那么這些方法是很有用的。
在這種情況下,占位符標記被替換為新的令牌,重新生成預訓練數據,并且對新數據進行微調。
def parse_sentencepiece_token(token):
if token.startswith("▁"):
return token[1:]
else:
return "##" + token
bert_vocab = list(map(parse_sentencepiece_token, snt_vocab))
ctrl_symbols = ["[PAD]","[UNK]","[CLS]","[SEP]","[MASK]"]
bert_vocab = ctrl_symbols + bert_vocab
bert_vocab += ["[UNUSED_{}]".format(i) for i in range(VOC_SIZE - len(bert_vocab))]
print(len(bert_vocab))
最后,我們將獲得的詞匯表寫入文件。
VOC_FNAME = "vocab.txt" #@param {type:"string"}
with open(VOC_FNAME, "w") as fo:
for token in bert_vocab:
fo.write(token+"\n")
現在,讓我們看看新詞匯在實踐中是如何運作的:
>>> testcase = "Colorless geothermal substations are generating furiously"
>>> bert_tokenizer = tokenization.FullTokenizer(VOC_FNAME)
>>> bert_tokenizer.tokenize(testcase)
['color',
'##less',
'geo',
'##ther',
'##mal',
'sub',
'##station',
'##s',
'are',
'generat',
'##ing',
'furious',
'##ly']
創(chuàng)建分片預訓練數據(生成預訓練數據)
通過手頭的詞匯表,我們可以為BERT模型生成預訓練數據。
由于我們的數據集可能非常大,我們將其拆分為碎片:
mkdir ./shards
split -a 4 -l 256000 -d $PRC_DATA_FPATH ./shards/shard_
現在,對于每個部分,我們需要從BERT倉庫調用create_pretraining_data.py腳本,需要使用xargs的命令。
在開始生成之前,我們需要設置一些參數傳遞給腳本。你可以從自述文件中找到有關它們含義的更多信息。
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param
MAX_PREDICTIONS = 20 #@param {type:"integer"}
DO_LOWER_CASE = True #@param {type:"boolean"}
PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}
# controls how many parallel processes xargs can create
PROCESSES = 2 #@param {type:"integer"}
運行此操作可能需要相當長的時間,具體取決于數據集的大小。
XARGS_CMD = ("ls ./shards/ | "
"xargs -n 1 -P {} -I{} "
"python3 bert/create_pretraining_data.py "
"--input_file=./shards/{} "
"--output_file={}/{}.tfrecord "
"--vocab_file={} "
"--do_lower_case={} "
"--max_predictions_per_seq={} "
"--max_seq_length={} "
"--masked_lm_prob={} "
"--random_seed=34 "
"--dupe_factor=5")
XARGS_CMD = XARGS_CMD.format(PROCESSES, '{}', '{}', PRETRAINING_DIR, '{}',
VOC_FNAME, DO_LOWER_CASE,
MAX_PREDICTIONS, MAX_SEQ_LENGTH, MASKED_LM_PROB)
tf.gfile.MkDir(PRETRAINING_DIR)
!$XARGS_CMD
為數據和模型設置GCS存儲,將數據和模型存儲到云端
為了保留來之不易的訓練模型,我們會將其保留在谷歌云存儲中。
在谷歌云存儲中創(chuàng)建兩個目錄,一個用于數據,一個用于模型。在模型目錄中,我們將放置模型詞匯表和配置文件。
在繼續(xù)操作之前,請配置BUCKET_NAME變量,否則將無法訓練模型。
BUCKET_NAME = "bert_resourses" #@param {type:"string"}
MODEL_DIR = "bert_model" #@param {type:"string"}
tf.gfile.MkDir(MODEL_DIR)
if not BUCKET_NAME:
log.warning("WARNING: BUCKET_NAME is not set. "
"You will not be able to train the model.")
下面是BERT基的超參數配置示例:
# use this for BERT-base
bert_base_config = {
"attention_probs_dropout_prob": 0.1,
"directionality": "bidi",
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pooler_fc_size": 768,
"pooler_num_attention_heads": 12,
"pooler_num_fc_layers": 3,
"pooler_size_per_head": 128,
"pooler_type": "first_token_transform",
"type_vocab_size": 2,
"vocab_size": VOC_SIZE
}
with open("{}/bert_config.json".format(MODEL_DIR), "w") as fo:
json.dump(bert_base_config, fo, indent=2)
with open("{}/{}".format(MODEL_DIR, VOC_FNAME), "w") as fo:
for token in bert_vocab:
fo.write(token+"\n")
現在,我們已準備好將模型和數據存儲到谷歌云當中:
if BUCKET_NAME:
!gsutil -m cp -r $MODEL_DIR $PRETRAINING_DIR gs://$BUCKET_NAME
在云TPU上訓練模型
注意,之前步驟中的某些參數在此處不用改變。請確保在整個實驗中設置的參數完全相同。
BUCKET_NAME = "bert_resourses" #@param {type:"string"}
MODEL_DIR = "bert_model" #@param {type:"string"}
PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}
VOC_FNAME = "vocab.txt" #@param {type:"string"}
# Input data pipeline config
TRAIN_BATCH_SIZE = 128 #@param {type:"integer"}
MAX_PREDICTIONS = 20 #@param {type:"integer"}
MAX_SEQ_LENGTH = 128 #@param {type:"integer"}
MASKED_LM_PROB = 0.15 #@param
# Training procedure config
EVAL_BATCH_SIZE = 64
LEARNING_RATE = 2e-5
TRAIN_STEPS = 1000000 #@param {type:"integer"}
SAVE_CHECKPOINTS_STEPS = 2500 #@param {type:"integer"}
NUM_TPU_CORES = 8
if BUCKET_NAME:
BUCKET_PATH = "gs://{}".format(BUCKET_NAME)
else:
BUCKET_PATH = "."
BERT_GCS_DIR = "{}/{}".format(BUCKET_PATH, MODEL_DIR)
DATA_GCS_DIR = "{}/{}".format(BUCKET_PATH, PRETRAINING_DIR)
VOCAB_FILE = os.path.join(BERT_GCS_DIR, VOC_FNAME)
CONFIG_FILE = os.path.join(BERT_GCS_DIR, "bert_config.json")
INIT_CHECKPOINT = tf.train.latest_checkpoint(BERT_GCS_DIR)
bert_config = modeling.BertConfig.from_json_file(CONFIG_FILE)
input_files = tf.gfile.Glob(os.path.join(DATA_GCS_DIR,'*tfrecord'))
log.info("Using checkpoint: {}".format(INIT_CHECKPOINT))
log.info("Using {} data shards".format(len(input_files)))
準備訓練運行配置,建立評估器和輸入函數,啟動BERT!
model_fn = model_fn_builder(
bert_config=bert_config,
init_checkpoint=INIT_CHECKPOINT,
learning_rate=LEARNING_RATE,
num_train_steps=TRAIN_STEPS,
num_warmup_steps=10,
use_tpu=USE_TPU,
use_one_hot_embeddings=True)
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)
run_config = tf.contrib.tpu.RunConfig(
cluster=tpu_cluster_resolver,
model_dir=BERT_GCS_DIR,
save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
tpu_config=tf.contrib.tpu.TPUConfig(
iterations_per_loop=SAVE_CHECKPOINTS_STEPS,
num_shards=NUM_TPU_CORES,
per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))
estimator = tf.contrib.tpu.TPUEstimator(
use_tpu=USE_TPU,
model_fn=model_fn,
config=run_config,
train_batch_size=TRAIN_BATCH_SIZE,
eval_batch_size=EVAL_BATCH_SIZE)
train_input_fn = input_fn_builder(
input_files=input_files,
max_seq_length=MAX_SEQ_LENGTH,
max_predictions_per_seq=MAX_PREDICTIONS,
is_training=True)
執(zhí)行!
estimator.train(input_fn=train_input_fn, max_steps=TRAIN_STEPS)
最后,使用默認參數訓練模型需要100萬步,約54小時的運行時間。如果內核由于某種原因重新啟動,可以從斷點處繼續(xù)訓練。
以上就是是在云TPU上從頭開始預訓練BERT的指南。
下一步
好的,我們已經訓練好了模型,接下來可以做什么?
如圖1所示,使用預訓練的模型作為通用的自然語言理解模塊;
2,針對某些特定的分類任務微調模型;
3,使用BERT作為構建塊,去創(chuàng)建另一個深度學習模型。
傳送門
原文地址:
https ://towardsdatascience.com/pre-training-bert-from-scratch-with-cloud-tpu-6e2f71028379
Colab代碼:
https ://colab.research.google.com/drive/1nVn6AFpQSzXBt8_ywfx6XR8ZfQXlKGAz
- 腦機接口走向現實,11張PPT看懂中國腦機接口產業(yè)現狀|量子位智庫2021-08-10
- 張朝陽開課手推E=mc2,李永樂現場狂做筆記2022-03-11
- 阿里數學競賽可以報名了!獎金增加到400萬元,題目面向大眾公開征集2022-03-14
- 英偉達遭黑客最后通牒:今天必須開源GPU驅動,否則公布1TB機密數據2022-03-05