使用keras-bert进行中文文本分类+Google colab运行源码
前文介紹了BERT的原理。在實(shí)際應(yīng)用中,BERT要比其理論本身要簡(jiǎn)單的多。這里我們利用Github的中文BERT預(yù)訓(xùn)練的結(jié)果(地址),進(jìn)行實(shí)際的文檔分類。
數(shù)據(jù)集
為了便于進(jìn)行比較,文檔分類的數(shù)據(jù)集來自Github的這個(gè)地址。該數(shù)據(jù)集使用了THUCNews的一個(gè)子集,使用了其中的10個(gè)分類,每個(gè)分類6500條數(shù)據(jù)。
類別如下:
體育, 財(cái)經(jīng), 房產(chǎn), 家居, 教育, 科技, 時(shí)尚, 時(shí)政, 游戲, 娛樂
數(shù)據(jù)集劃分如下:
- 訓(xùn)練集: 5000*10
- 驗(yàn)證集: 500*10
Github上該數(shù)據(jù)集的地址在百度網(wǎng)盤上,可以用multicloud直接將百度網(wǎng)盤的文件傳到google drive上。
基本結(jié)構(gòu)
利用BERT的預(yù)訓(xùn)練模型進(jìn)行文檔分類,其原理如下:
其原理是在BERT的Transform Layer的最高一層的第一個(gè)輸出,添加Dense+Softmax對(duì)文檔進(jìn)行分類,loss函數(shù)取cross entropy函數(shù)。因此,真正的核心代碼也就如下幾行:
預(yù)訓(xùn)練的模型采用了小參量的模型RBT3,RBT3實(shí)際上只有3個(gè)transform layer,而不是完整的12個(gè)transform layer。因此,只能算是“低配窮人版”的BERT,但從后面的結(jié)果來看表現(xiàn)已經(jīng)相當(dāng)不錯(cuò)了。
框架
為了實(shí)現(xiàn)上述功能,本文采用了Keras框架。Keras框架的好處是糙快猛,便于深度學(xué)習(xí)的入門者快速上手,而針對(duì)BERT模型,也已經(jīng)有開源的keras-bert可以直接拿來使用。
平臺(tái)
源代碼在Google colab平臺(tái)上,使用Google免費(fèi)提供的GPU跑通。本來本人想試TPU的,但是發(fā)現(xiàn)依據(jù)keras-bert建議的TPU訓(xùn)練代碼,模型無法收斂,因此還有待牛人進(jìn)一步解決了…
代碼
接下來直接上代碼。
下載數(shù)據(jù)到Colab目錄
首先將RBT3對(duì)應(yīng)的google drive的地址,直接add到自己的google drive賬號(hào)里,這樣就可以不用下載,直接在colab上將文件解壓就可以。
!cp drive/My\ Drive/chinese_rbt3_L-3_H-768_A-12.zip . !mkdir cnews !mkdir model !cp drive/My\ Drive/cnews/* cnews/ !rm -rf model !unzip chinese_rbt3_L-3_H-768_A-12.zip -d model/輸出
Archive: chinese_rbt3_L-3_H-768_A-12.zipinflating: model/bert_config_rbt3.json inflating: model/bert_model.ckpt.data-00000-of-00001 inflating: model/__MACOSX/._bert_model.ckpt.data-00000-of-00001 inflating: model/bert_model.ckpt.index inflating: model/__MACOSX/._bert_model.ckpt.index inflating: model/bert_model.ckpt.meta inflating: model/__MACOSX/._bert_model.ckpt.meta inflating: model/vocab.txt inflating: model/__MACOSX/._vocab.txt安裝庫(kù)
!pip install --upgrade --force-reinstall keras-bert keras-rectified-adam導(dǎo)入模塊
import numpy as np import os import keras # from tensorflow.python import keras from keras_bert import load_trained_model_from_checkpoint #用于加載預(yù)訓(xùn)練的bert from keras_bert import get_pretrained, PretrainedList, get_checkpoint_paths from keras_bert import Tokenizer #用于對(duì)輸入的文章“分詞” import codecs from keras.preprocessing.sequence import pad_sequences注意到keras-bert引用了兩種keras庫(kù):
- tensorflow.python.keras:
- keras:
keras-bert默認(rèn)使用的是keras庫(kù),因此模型也要對(duì)應(yīng)import keras。如果在python里執(zhí)行
那么接下來模型要import tensorflow.python.keras。如果弄混了,在build model的時(shí)候就會(huì)有千奇百怪的錯(cuò)誤。
如果采用colab的TPU,就需要用tensorflow.python.keras包了,因?yàn)樵趉eras本身對(duì)TPU支持目前還有限,對(duì)TPU的操作主要依賴于tensorflow,而且還必須是tensorflow 1.x的版本,tensorflow 2.x版本目前對(duì)TPU的支持不好。目前colab默認(rèn)的還是tensorflow 1.x的版本。
配置數(shù)據(jù)的路徑
# 與訓(xùn)練的bert model的路徑 model_path="model/" paths = get_checkpoint_paths(model_path) print(paths.config, paths.checkpoint, paths.vocab) # 下載的數(shù)據(jù)集的路徑 base_dir = 'cnews' train_dir = os.path.join(base_dir, 'cnews.train.txt') test_dir = os.path.join(base_dir, 'cnews.test.txt') val_dir = os.path.join(base_dir, 'cnews.val.txt') vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')數(shù)據(jù)預(yù)處理
def read_file(filename):"""讀取文件數(shù)據(jù)"""contents, labels = [], []with open(filename, encoding='utf-8') as f:for line in f:try:label, content = line.strip().split('\t')if content:contents.append(content)labels.append(label)except:passreturn contents, labels # 從文件中讀取文本內(nèi)容和對(duì)應(yīng)的分類標(biāo)簽 train_contents, train_labels=read_file(train_dir) validate_contents, validate_labels=read_file(val_dir)# 讀取預(yù)訓(xùn)練模型的“分詞器” token_dict = {} with codecs.open(paths.vocab, 'r', 'utf8') as reader:for line in reader:token = line.strip()token_dict[token] = len(token_dict) tokenizer = Tokenizer(token_dict) tokenizer.tokenize(u'今天天氣不錯(cuò)')注意到“分詞器”針對(duì)“今天天氣不錯(cuò)”會(huì)輸出:
['[CLS]', '今', '天', '天', '氣', '不', '錯(cuò)', '[SEP]']這里也可以看到BERT的“分詞”實(shí)際上是將句子中的每一個(gè)漢字都拆成了單獨(dú)的一個(gè)字符。這實(shí)際上也是前面提到的“基于字符的”神經(jīng)網(wǎng)絡(luò)模型。2019年已經(jīng)有文獻(xiàn)證明,針對(duì)漢字,基于字符的NLP要比基于分詞的NLP訓(xùn)練效果要更好,因此BERT的中文NLP也采用了這種技術(shù)。
def get_id_segments(contents):ids, segments = [],[]for sent in contents:id, segment = tokenizer.encode(sent)ids.append(id)segments.append(segment)return ids, segments train_ids,train_segments = get_id_segments(train_contents) validate_ids, validate_segments = get_id_segments(validate_contents) print(train_contents[0], train_ids[0])輸出
馬曉旭意外受傷讓國(guó)奧警惕 無奈大雨格外青睞殷家軍記者傅亞雨沈陽(yáng)報(bào)道 來到沈陽(yáng),國(guó)奧隊(duì)依然沒有擺脫雨水的困擾。7月31日下午6點(diǎn),國(guó)奧隊(duì)的日常訓(xùn)練再度受到大雨的干擾,無奈之下隊(duì)員們只慢跑了25分鐘就草草收?qǐng)觥?1日上午10點(diǎn),國(guó)奧隊(duì)在奧體中心外場(chǎng)訓(xùn)練的時(shí)候,天就是陰沉沉的,氣象預(yù)報(bào)顯示當(dāng)天下午沈陽(yáng)就有大雨,但幸好隊(duì)伍上午的訓(xùn)練并沒有受到任何干擾。下午6點(diǎn),當(dāng)球隊(duì)抵達(dá)訓(xùn)練場(chǎng)時(shí),大雨已經(jīng)下了幾個(gè)小時(shí),而且絲毫沒有停下來的意思。抱著試一試的態(tài)度,球隊(duì)開始了當(dāng)天下午的例行訓(xùn)練,25分鐘過去了,天氣沒有任何轉(zhuǎn)好的跡象,為了保護(hù)球員們,國(guó)奧隊(duì)決定中止當(dāng)天的訓(xùn)練,全隊(duì)立即返回酒店。在雨中訓(xùn)練對(duì)足球隊(duì)來說并不是什么稀罕事,但在奧運(yùn)會(huì)即將開始之前,全隊(duì)變得“嬌貴”了。在沈陽(yáng)最后一周的訓(xùn)練,國(guó)奧隊(duì)首先要保證現(xiàn)有的球員不再出現(xiàn)意外的傷病情況以免影響正式比賽,因此這一階段控制訓(xùn)練受傷、控制感冒等疾病的出現(xiàn)被隊(duì)伍放在了相當(dāng)重要的位置。而抵達(dá)沈陽(yáng)之后,中后衛(wèi)馮蕭霆就一直沒有訓(xùn)練,馮蕭霆是7月27日在長(zhǎng)春患上了感冒,因此也沒有參加29日跟塞爾維亞的熱身賽。隊(duì)伍介紹說,馮蕭霆并沒有出現(xiàn)發(fā)燒癥狀,但為了安全起見,這兩天還是讓他靜養(yǎng)休息,等感冒徹底好了之后再恢復(fù)訓(xùn)練。由于有了馮蕭霆這個(gè)例子,因此國(guó)奧隊(duì)對(duì)雨中訓(xùn)練就顯得特別謹(jǐn)慎,主要是擔(dān)心球員們受涼而引發(fā)感冒,造成非戰(zhàn)斗減員。而女足隊(duì)員馬曉旭在熱身賽中受傷導(dǎo)致無緣奧運(yùn)的前科,也讓在沈陽(yáng)的國(guó)奧隊(duì)現(xiàn)在格外警惕,“訓(xùn)練中不斷囑咐隊(duì)員們要注意動(dòng)作,我們可不能再出這樣的事情了。”一位工作人員表示。從長(zhǎng)春到沈陽(yáng),雨水一路伴隨著國(guó)奧隊(duì),“也邪了,我們走到哪兒雨就下到哪兒,在長(zhǎng)春幾次訓(xùn)練都被大雨給攪和了,沒想到來沈陽(yáng)又碰到這種事情。”一位國(guó)奧球員也對(duì)雨水的“青睞”有些不解。 [ 101 7716 3236 3195 2692 1912 1358 839 6375 1744 1952 6356 2664 31871937 1920 7433 3419 1912 7471 4712 3668 2157 1092 6381 5442 987 7627433 3755 7345 2845 6887 3341 1168 3755 7345 8024 1744 1952 7339 8984197 3766 3300 3030 5564 7433 3717 4638 1737 2817 511 128 3299 81763189 678 1286 127 4157 8024 1744 1952 7339 4638 3189 2382 6378 52981086 2428 1358 1168 1920 7433 4638 2397 2817 8024 3187 1937 722 6787339 1447 812 1372 2714 6651 749 8132 1146 7164 2218 5770 5770 31191767 511 8176 3189 677 1286 8108 4157 8024 1744 1952 7339 1762 1952860 704 2552 1912 1767 6378 5298 4638 3198 952 8024 1921 2218 32217346 3756]這里主要是將讀取到的文章的文本轉(zhuǎn)換成id,這個(gè)id才是BERT模型需要的真正的輸入。
同時(shí)每一個(gè)文章長(zhǎng)度都是不一樣的,這里簡(jiǎn)單的畫了下各文章經(jīng)過tokenizer后長(zhǎng)度的分布:
其中橫軸是文章數(shù),縱軸是樣本的個(gè)數(shù)。這里可以看到有相當(dāng)?shù)奈恼碌?#34;字“數(shù)都超過了1000字。但是BERT模型最長(zhǎng)也就只能輸入512,同時(shí)考慮到Colab上GPU的內(nèi)存的限制,真正輸入的文章長(zhǎng)度要更短。
2019年已經(jīng)有論文對(duì)文本分類要截取的“字?jǐn)?shù)”進(jìn)行了討論(原文),論文針對(duì)IMDB上的影評(píng),考慮了三種情況:
- 截取文章頭部510個(gè)token
- 截取文章尾部510個(gè)token
- 截取文章頭部128個(gè)token和尾部382個(gè)token。
發(fā)現(xiàn)第三種情況效果是最好的。
由于我們處理的新聞文章,新聞文章的特點(diǎn)一般都是開宗明義,所以我這里只取了文章頭部128個(gè)token,因此代碼:
這里用了keras自帶的pad_sequences函數(shù)。這個(gè)函數(shù)對(duì)超過指定長(zhǎng)度的會(huì)截?cái)?#xff0c;沒到指定長(zhǎng)度的會(huì)補(bǔ)0,返回numpy數(shù)組。這樣就不用自己hard code了。
此外,還要將讀取到的文章分類也要轉(zhuǎn)為數(shù)字:
加載BERT預(yù)訓(xùn)練模型
加載BERT預(yù)訓(xùn)練模型,實(shí)際上調(diào)用的就是keras-bert的函數(shù)就可以。
bert_model = load_trained_model_from_checkpoint(paths.config.replace('.json','_rbt3.json'), paths.checkpoint, seq_len=None) for i,layer in enumerate(bert_model.layers):layer.trainable = True調(diào)用完后,BERT各層默認(rèn)是constant,即不參與接下來的訓(xùn)練。從實(shí)際使用的情況看,如果只訓(xùn)練自己添加的那幾層,幾乎達(dá)不到分類的效果,因此這里我們采用了所有層都參與訓(xùn)練。訓(xùn)練的層越多,理論上效果越好,但也要考慮到過擬合、以及對(duì)硬件資源占用的問題。(之前嘗試采用“高配旗艦版”的24個(gè)transformer layer的BERT,把gpu搞崩了好幾次)由于我們的這個(gè)BERT的模型比較精簡(jiǎn),因此所有層參與訓(xùn)練的問題不大。
修改BERT模型
修改BERT模型,即前面提到的代碼
inputs = bert_model.inputs[:2] x=bert_model.layers[-1].output # if returns tuple, then we are using keras lib # if returns KerasHistory, then we are using tensorflow.python.keras lib print(type(x._keras_history)) x=keras.layers.Lambda(lambda x: x[:, 0], name='slice')(x) x=keras.layers.Dense(units=3072, activation=keras.backend.tanh)(x) x=keras.layers.Dropout(rate=0.1,seed=2019)(x) x=keras.layers.Dense(units=num_labels, activation=keras.backend.softmax)(x) model=keras.Model(inputs, x) model.compile(optimizer=keras.optimizers.Adam(1e-4),loss=keras.losses.sparse_categorical_crossentropy,metrics=[keras.metrics.sparse_categorical_accuracy]) model.summary()會(huì)輸出
__________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== Input-Token (InputLayer) (None, None) 0 __________________________________________________________________________________________________ Input-Segment (InputLayer) (None, None) 0 __________________________________________________________________________________________________ Embedding-Token (TokenEmbedding [(None, None, 768), 16226304 Input-Token[0][0] __________________________________________________________________________________________________ Embedding-Segment (Embedding) (None, None, 768) 1536 Input-Segment[0][0] __________________________________________________________________________________________________ Embedding-Token-Segment (Add) (None, None, 768) 0 Embedding-Token[0][0] Embedding-Segment[0][0] __________________________________________________________________________________________________ Embedding-Position (PositionEmb (None, None, 768) 393216 Embedding-Token-Segment[0][0] __________________________________________________________________________________________________ Embedding-Dropout (Dropout) (None, None, 768) 0 Embedding-Position[0][0] __________________________________________________________________________________________________ Embedding-Norm (LayerNormalizat (None, None, 768) 1536 Embedding-Dropout[0][0] __________________________________________________________________________________________________ Encoder-1-MultiHeadSelfAttentio (None, None, 768) 2362368 Embedding-Norm[0][0] __________________________________________________________________________________________________ Encoder-1-MultiHeadSelfAttentio (None, None, 768) 0 Encoder-1-MultiHeadSelfAttention[ __________________________________________________________________________________________________ Encoder-1-MultiHeadSelfAttentio (None, None, 768) 0 Embedding-Norm[0][0] Encoder-1-MultiHeadSelfAttention- __________________________________________________________________________________________________ Encoder-1-MultiHeadSelfAttentio (None, None, 768) 1536 Encoder-1-MultiHeadSelfAttention- __________________________________________________________________________________________________ Encoder-1-FeedForward (FeedForw (None, None, 768) 4722432 Encoder-1-MultiHeadSelfAttention- __________________________________________________________________________________________________ Encoder-1-FeedForward-Dropout ( (None, None, 768) 0 Encoder-1-FeedForward[0][0] __________________________________________________________________________________________________ Encoder-1-FeedForward-Add (Add) (None, None, 768) 0 Encoder-1-MultiHeadSelfAttention-Encoder-1-FeedForward-Dropout[0][ __________________________________________________________________________________________________ Encoder-1-FeedForward-Norm (Lay (None, None, 768) 1536 Encoder-1-FeedForward-Add[0][0] __________________________________________________________________________________________________ Encoder-2-MultiHeadSelfAttentio (None, None, 768) 2362368 Encoder-1-FeedForward-Norm[0][0] __________________________________________________________________________________________________ Encoder-2-MultiHeadSelfAttentio (None, None, 768) 0 Encoder-2-MultiHeadSelfAttention[ __________________________________________________________________________________________________ Encoder-2-MultiHeadSelfAttentio (None, None, 768) 0 Encoder-1-FeedForward-Norm[0][0] Encoder-2-MultiHeadSelfAttention- __________________________________________________________________________________________________ Encoder-2-MultiHeadSelfAttentio (None, None, 768) 1536 Encoder-2-MultiHeadSelfAttention- __________________________________________________________________________________________________ Encoder-2-FeedForward (FeedForw (None, None, 768) 4722432 Encoder-2-MultiHeadSelfAttention- __________________________________________________________________________________________________ Encoder-2-FeedForward-Dropout ( (None, None, 768) 0 Encoder-2-FeedForward[0][0] __________________________________________________________________________________________________ Encoder-2-FeedForward-Add (Add) (None, None, 768) 0 Encoder-2-MultiHeadSelfAttention-Encoder-2-FeedForward-Dropout[0][ __________________________________________________________________________________________________ Encoder-2-FeedForward-Norm (Lay (None, None, 768) 1536 Encoder-2-FeedForward-Add[0][0] __________________________________________________________________________________________________ Encoder-3-MultiHeadSelfAttentio (None, None, 768) 2362368 Encoder-2-FeedForward-Norm[0][0] __________________________________________________________________________________________________ Encoder-3-MultiHeadSelfAttentio (None, None, 768) 0 Encoder-3-MultiHeadSelfAttention[ __________________________________________________________________________________________________ Encoder-3-MultiHeadSelfAttentio (None, None, 768) 0 Encoder-2-FeedForward-Norm[0][0] Encoder-3-MultiHeadSelfAttention- __________________________________________________________________________________________________ Encoder-3-MultiHeadSelfAttentio (None, None, 768) 1536 Encoder-3-MultiHeadSelfAttention- __________________________________________________________________________________________________ Encoder-3-FeedForward (FeedForw (None, None, 768) 4722432 Encoder-3-MultiHeadSelfAttention- __________________________________________________________________________________________________ Encoder-3-FeedForward-Dropout ( (None, None, 768) 0 Encoder-3-FeedForward[0][0] __________________________________________________________________________________________________ Encoder-3-FeedForward-Add (Add) (None, None, 768) 0 Encoder-3-MultiHeadSelfAttention-Encoder-3-FeedForward-Dropout[0][ __________________________________________________________________________________________________ Encoder-3-FeedForward-Norm (Lay (None, None, 768) 1536 Encoder-3-FeedForward-Add[0][0] __________________________________________________________________________________________________ slice (Lambda) (None, 768) 0 Encoder-3-FeedForward-Norm[0][0] __________________________________________________________________________________________________ dense_11 (Dense) (None, 3072) 2362368 slice[0][0] __________________________________________________________________________________________________ dropout_6 (Dropout) (None, 3072) 0 dense_11[0][0] __________________________________________________________________________________________________ dense_12 (Dense) (None, 10) 30730 dropout_6[0][0] ================================================================================================== Total params: 40,279,306 Trainable params: 40,279,306 Non-trainable params: 0注意最后三行的提示,即所有的參數(shù)都參與了訓(xùn)練。
這里還有個(gè)tip,構(gòu)建模型時(shí),如果直接將代碼
替換成
x=bert_model(inputs)這樣不會(huì)對(duì)model實(shí)際訓(xùn)練造成影響,但是在打印model.summary()時(shí),bert內(nèi)部的各層只會(huì)顯示為一層bert_model。這樣也就不方便查看了。
開始訓(xùn)練
batch_size = 64 train_size = (train_ids.shape[0] // batch_size) * batch_size validate_size = (validate_ids.shape[0] // batch_size) * batch_sizeclass TrainHistory(keras.callbacks.Callback):def on_train_begin(self, logs={}):self.train_loss = []self.train_acc = []def on_batch_end(self, batch, logs={}):self.train_loss.append(logs.get('loss'))self.train_acc.append(logs.get('sparse_categorical_accuracy'))history = TrainHistory()model.fit(x=[train_ids[:train_size], train_segments[:train_size]],y=np.array(train_labelids[:train_size]),validation_data=[[validate_ids[:validate_size], validate_segments[:validate_size]],np.array(validate_labelids[:validate_size])],batch_size = batch_size,callbacks=[history])Keras最讓人感到激動(dòng)的一點(diǎn)是,訓(xùn)練只需要一行代碼fit就解決了。這里實(shí)際上可以直接將訓(xùn)練集和測(cè)試集扔進(jìn)去也沒有問題。之所以我把輸入截?cái)喑蒪atch_size的整數(shù)倍,是由于TPU的輸入是這樣要求的。但由于TPU試驗(yàn)失敗,但是代碼還是保留了。
另外,為了記錄訓(xùn)練時(shí)的loss和accuracy,自定義了TransHistory類。在完成訓(xùn)練后,就可以畫圖表示loss和accuracy的收斂過程了。
訓(xùn)練結(jié)果輸出
Train on 49984 samples, validate on 4992 samples Epoch 1/1 49984/49984 [==============================] - 199s 4ms/step - loss: 0.1682 - sparse_categorical_accuracy: 0.9495 - val_loss: 0.1362 - val_sparse_categorical_accuracy: 0.9619 <keras.callbacks.callbacks.History at 0x7f6e82390080>這里為了節(jié)省時(shí)間,只訓(xùn)練了一個(gè)epoch,使用google colab的GPU 3分多鐘就跑完了,測(cè)試集上準(zhǔn)確率達(dá)到96.19%,而github上采用cnn訓(xùn)練了3個(gè)epoch,才最多達(dá)到94.12%,效果還是很好的。
同時(shí)圖像上也做了比較:
從圖像上看,模型在100個(gè)batch左右,即輸入6400個(gè)樣本的時(shí)候,就可以訓(xùn)練到準(zhǔn)確率90%左右,這樣也說明了BERT在下游訓(xùn)練的優(yōu)勢(shì)還是很明顯的。這也是最近兩年bert如此火的原因吧。
總結(jié)
以上是生活随笔為你收集整理的使用keras-bert进行中文文本分类+Google colab运行源码的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【C++复习总结回顾】—— 【五】数组与
- 下一篇: 新旧电脑文件转移