基于bert的阅读理解脚本(run_squad)原理梳理(从举例的角度说明)
文章目錄
- 1. 例子
- 2. 對文章進行分詞
- 3. 確定文章相關屬性
- 4. 構造example
- 5. 得到query_tokens
- 6. 將doc_tokens進行更細粒度地劃分all_doc_tokens
- 7. 獲取答案在all_doc_tokens中的起始位置
- 8. 構造doc_spans
- 9. 構造tokens,并轉化為input_ids
- 10. 更新start_position和end_position
- 11. 構造features
- 12. 保存features并作為model的輸入
- 13. 構造model
- 14. 訓練
- 15. 預測
1. 例子
假如有一篇文章paragraph_text= "Lucy, (1964-2000), from Xian."
問題:question_text = 'When was Lucy born?'
答案:answer = '1964'
本例子一篇文章只有一個問題,有的文章,存在多個問題的情況。下述所操作對象均指一篇文章的一個問題及其答案
2. 對文章進行分詞
對上例中的paragraph_text按照\s,\t,\r,\n進行分詞,最終得到結果:
doc_tokens = ['Lucy,', '(1964-2000),', 'from', 'xian.']
注意一個詞后的標點符號如逗號會和該詞分到一起
與此同時得到char_to_word_offset = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3]。
在解釋char_to_word_offset含義之前先明確兩個概念:token表示一個單詞如Lucy,字符表示token的一個字母,如Lucy中的L是一個字符。下文均采用此概念。
doc_tokens中"Lucy,"是第0個字符,而他由5個字符構成,因此這5個字符均用0表示,對應char_to_word_offset前5個元素,以此類推。
3. 確定文章相關屬性
對于一篇文章的一個問題,可能存在沒有答案的情況,因此用變量is_is_impossible表示,如果為True表示該文章的該問題缺少答案;反之則表示該文章的該問題有答案。
- 對于訓練集且答案沒有丟失:
- qas_id:表示該篇文章問題的編號。(對于一篇文章可能有多個問題,因此會對一篇的文章的多個問題進行編號),一般原始文本會標注,直接取即可。
- 用orig_answer_text表示答案的原始文本,該例即’1964’
- start_position:答案開頭的token(本例為1964)在doc_tokens中的index(本例為1,1964包含在’(1964-2000),'中,因此為1)
- end_position:答案結束的token(本例仍為1964)在doc_tokens中的index(本例為1)
- 對于訓練集答案有所丟失:
- qas_id:同上
- orig_answer_text:""空字符串
- start_position:-1
- end_position:-1
- 對于預測集:
- qas_id:同上
- orig_answer_text:None
- start_position:None
- end_position:None
4. 構造example
example是一個列表,該列表每個元素是一個SquadExample類,該類主要記錄了上述一篇文章一個問題的相關屬性:qas_id, question_text, doc_tokens, orig_answer_text, start_positions, end_position, is_impossible。
對于該例:
- qas_id:0
- question_text:‘When was Lucy born?’
- doc_tokens: [‘Lucy,’, ‘(1964-2000),’, ‘from’, ‘xian.’]
- orig_answer_text: ‘1964’
- start_positions: 1
- end_positions: 1
- is_impossible: False
5. 得到query_tokens
從現在開始,所做操縱均指一個example中的一個元素(SquadExample)
將問題(question_text)進行tokenize后按照max_query_length(問題的最大長度)進行截斷。
tokenize指tokenizer.tokenize(example.question_text)操作,
此外還需要限制問題最大的長度,即問題最多有幾個token,因此對于token超過max_query_length的,需要進行截斷。
假設max_query_length = 3
對于本例query_tokens = ['When', 'was', 'Lucy']
6. 將doc_tokens進行更細粒度地劃分all_doc_tokens
上文doc_tokens只是對文章按照\s,\t,\r,\n進行分詞,現在需要將doc_tokens中每個token再進行更細粒度劃分,如’normans’可能會劃分成’norman’和’##s’, ‘Lucy,‘劃分成’Lucy’和’,’。
對于本例,用all_doc_tokens表示細粒度劃分后的結果,all_doc_tokens = ['lucy', ',', '(', '1964', '-', '2000', ')', ',', 'from', 'xi', '##an', '.']
在構造all_doc_tokens的同時,還會得到另兩個變量:
- tok_to_orig_index:list,對于每個元素,表示all_doc_tokens中的元素在doc_tokens中的index。比如all_doc_tokens中的lucy對應doc_tokens第0個元素,','對應doc_tokens第0個元素。因此對于本例,tok_to_orig_index = [0, 0, 1, 1, 1, 1, 1, 1, 2, 3, 3, 3]
- orig_to_tok_index:list與tok_to_orig_index含義相反,表示doc_tokens中的元素在all_doc_tokens中的index(取最近的那個index)。對于本例orig_to_tok_index = [0, 2, 8, 9]
7. 獲取答案在all_doc_tokens中的起始位置
- tok_start_position:答案的第一個詞是all_doc_tokens第幾個元素
- tok_end_position:答案最后一個詞是all_doc_tokens第幾個元素
- 對于訓練集且答案沒有丟失:
- tok_start_position:如上述解釋,本例為3,得到該結果可以參考bert中tok_start_position = orig_to_tok_index[example.start_position]和_improve_answer_span函數
- tok_end_position:如上解釋,本例為3
- 對于訓練集且答案丟失:
- tok_start_position:-1
- tok_end_position:-1
- 對于預測集:
- tok_start_position:None
- tok_end_position:None
8. 構造doc_spans
由于每篇文章的token個數是不一樣的,如果簡單地按照截斷的思想去做,會存在一個問題:截斷后的文章沒有完全包含答案,而閱讀理解最終就想得到答案在文章的開始和結束位置,假如預測的時候把答案截斷了,就不能夠得到結果。所以不能簡單粗暴得進行截斷。較為合理的一個方法是設定一個文章的最大長度,假如用x表示,先取文章的前x個元素作為batch_size中的一個樣本,然后向后走doc_stride步,再取x個元素作為batch_size的另一個樣本,以此類推。因此一篇文章可以切分為好幾個樣本,有的樣本包含答案,有的樣本只包含部分答案,有的樣本完全不包含答案。
在得到本例的結果之前,先劇透下輸入到bert中的tokens構成’[CLS]’ + query + ‘[SEP]’ + 文章 + ‘[SEP]’
假如用max_seq_length表示tokens的最大長度,那么文章的最大長度max_tokens_for_doc = max_seq_length - len(query_tokens) -3。3表示’[CLS]’ 、‘[SEP]’ 、‘[SEP]’
假設max_seq_length = 9, doc_stride = 2,前述假設max_query_length=3,因此len(qeury_tokens)=3,故max_tokens_for_doc=3
對于本例:
doc_spans是一個列表,每個元素是_DocSpan = collections.namedtuple("DocSpan", ["start", "length"]), start表示開頭那個token取的是all_doc_tokens第幾個,length表示3個token占all_doc_tokens幾個元素,一般等于max_tokens_for_doc,只有構造最后一個時可能小于max_tokens_for_doc。
all_doc_tokens = [‘lucy’, ‘,’, ‘(’, ‘1964’, ‘-’, ‘2000’, ‘)’, ‘,’, ‘from’, ‘xi’, ‘##an’, ‘.’]
doc_spans第0個元素start:0, length:3,對應[‘lucy’, ‘,’, ‘(’]
doc_spans第1個元素start:2, length:3,對應[‘(’, ‘1964’, ‘-’]
doc_spans第2個元素start:4, length:3,對應[‘-’, ‘2000’, ‘)’]
doc_spans第3個元素start:6, length:3,對應[‘)’, ‘,’, ‘from’]
doc_spans第4個元素start:8, length:3,對應[ ‘from’, ‘xi’, ‘##an’]
doc_spans第5個元素start:10, length:2,對應[ ‘##an’, ‘.’]
9. 構造tokens,并轉化為input_ids
第8章中每個doc_span都會構成一個tokens,且上文也提前劇透了tokens的構成:‘[CLS]’ + query + ‘[SEP]’ + 文章 + ‘[SEP]’。因此對于doc_spans第0個元素tokens = ['[CLS]', 'When', 'was', 'Lucy', '[SEP]', 'lucy', ',', '(', '[SEP]'],其對應的segment_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1],0表示tokens的該元素位于上半句(query),1表示tokens的元素位于下半句(部分文章片段);其對應的input_mask = [1, 1, 1, 1, 1, 1, 1, 1, 1],1表示tokens的該元素為非填充值,0表示填充值(填充值接下來會講)。
得到了tokens,需要將每個元素轉化為id(詞匯字典,每個字都對應唯一的id號),input_ids = tokenizer.convert_tokens_to_ids(tokens),對于本例input_ids=[101, 2043, 2001, 7004, 102, 7004, 1010, 1006, 102]。
對于doc_spans第5個元素,其tokens = [‘[CLS]’, ‘When’, ‘was’, ‘Lucy’, ‘[SEP]’, ‘##an’, ‘.’, ‘[SEP]’]長度只有8,小于max_seq_length,因此還需要進行填充操作,即在tokens后面加上max_seq_length-len(tokens)個[‘PAD’],相對應的input_ids需要再加max_seq_length-len(tokens)個0(詞匯表中’PAD’對應的id號為0),segment_ids后面加上max_seq_length-len(tokens)個0,input_mask后面加上max_seq_length-len(tokens)個0(表示填充值)。
10. 更新start_position和end_position
- 對于訓練集且答案未丟失:
- start_position:更新為答案的開頭在tokens中的index,對于tokens = [‘[CLS]’, ‘When’, ‘was’, ‘Lucy’, ‘[SEP]’, ‘(’, ‘1964’, ‘-’, ‘[SEP]’],start_position = 6
- end_position:更新為答案的結尾在tokens中的index,對于tokens = [‘[CLS]’, ‘When’, ‘was’, ‘Lucy’, ‘[SEP]’, ‘(’, ‘1964’, ‘-’, ‘[SEP]’],end_position = 6
- 如果token中并沒有完全包含答案,那么start_position和end_position都為0
- 對于訓練集且答案丟失:
- start_position和end_position都為0
- 對于預測集:
- 任務就是預測start_position和end_position
11. 構造features
features是一個列表,每個列表是一個InputFeatures類的實例,其主要記錄了一些屬性,這些屬性針對一篇文章一個問題的一個doc_span。
- unique_id:該文章該問題的該doc_span的唯一標識號
- example_index:examples中第幾個元素(含義為第幾篇文章的第幾個問題編號)
- doc_span_index:doc_spans的第幾個元素(一篇文章一個問題的第幾個子樣本(tokens),一篇文章一個問題可以構造多個tokens)
- tokens:即上述的tokens(不包括最后填充的’PAD’)
- token_to_orig_map:字典,key為tokens中每個元素的index(只有文章,不包含query),value為當前詞在doc_tokens中的索引,以tokens = [‘[CLS]’, ‘When’, ‘was’, ‘Lucy’, ‘[SEP]’, ‘(’, ‘1964’, ‘-’, ‘[SEP]’]為例,該例子的doc_span.start = 2,因此token_to_orig_map為{5: 1, 6: 1, 7: 1}
- token_is_max_context:字典,key為tokens中每個元素的index(只有文章,不包含query),value為該詞是否是含有該詞的doc_span中得分最高的。比如doc_spans第0個元素start:0, length:3,對應[‘lucy’, ‘,’, ‘(’]和doc_spans第1個元素start:2, length:3,對應[‘(’, ‘1964’, ‘-’],兩者都含有’(‘,因此會計算這兩個doc_span中’(‘各自的得分。得分計算方式:對于doc_spans第0個元素start:0, length:3,對應[‘lucy’, ‘,’, ‘(’],’(‘左邊有2個詞,用num_left_context表示,右邊有0個詞,用num_right_context表示,因此score = min(num_left_context, num_right_context) + 0.01 * doc_span.length = 0.03;同理,對于doc_spans第1個元素start:2, length:3,對應[’(‘, ‘1964’, ‘-’],’(‘左邊有0個元素,右邊有兩個元素,因此score = 0.03。由于doc_spans第1個元素在doc_spans第0個元素之后,所以比較順序為doc_spans第1個元素的’(‘得分不大于doc_spans第0個元素的’(‘得分,故對于doc_spans第0個元素中’(‘的value為Ture,對于doc_spans第1個元素中’('的value為False
- input_ids:同上,包括填充
- segment_ids:同上,包括填充
- start_position:同上,更新后的
- end_position:同上,更新后的
- is_impossible:同上,是否缺失答案(預測集為False)
12. 保存features并作為model的輸入
將上述features保存成TFRecord,讀取后即可作為model的輸入了
13. 構造model
model有了輸入,通過bert模型,得到結果final_hidden = model.get_sequence_output(),shape:[batch_size, seq_length, hidden_size],將其reshape為[batch_size * max_seq_length, hidden_size]再經過一個全連接層,得到final_hidden_matrix,shap:[batch_size*max_seq_length, 2]。再將final_hidden_matrix reshape和transpose,得到logits,shape:[2, batch_size, max_seq_length]。logits第0個元素就是start對應的結果start_logits,logits第1個元素就是end對應的結果end_logits。start,shape:[batch_size, max_seq_length],每個樣本每個token的結果,再softmax就是每個樣本每個token的概率值
14. 訓練
正常訓練即可,無更多說明
15. 預測
預測的時候主要對一些變量進行說明。
- start_logits:列表。其對象為一篇文章一個問題的一個doc_span。列表長度為len(tokens),表示tokens中每個token做為開始位置的概率。通過操作[float(x) for x in result[“start_logits”].flat]得到,其中result就是一篇文章一個問題一個doc_span的結果,結構如下:{“unique_ids”: unique_ids, “start_logits”: start_logits, “end_logits”: end_logits},因此最終得到的start_logits是沒有softmax的。
- end_logits:和start_logits類似
- start_indexes:列表,其對象為一篇文章一個問題的一個doc_span。上述start_logits表示每個token做為開始位置的概率值,因此start_indexes記錄了概率前n_best_size大的index(也即最有可能的前n_best_size大token的位置)
- end_indexes:和start_indexes類似
- min_null_feature_index:int。其對象為一篇文章一個問題。
- 對于version_2_with_negative=True(樣本中存在有些問題沒有答案):一篇文章一個問題有多個doc_span,計算每個doc_span的start_logits[0] + end_logits[0](開始位置選擇0,結束位置選擇0,作為無答案的代表)。start_logits[0] + end_logits[0]最小的那個doc_span對應的index(一篇文章一個問題的每個doc_span的feature構成features列表,這里的index是針對features而言的)即min_null_feature_index
- 對于version_2_with_negative=True,min_null_feature_index = 0
- null_start_logit:int。其對象為一篇文章一個問題。
- 對于version_2_with_negative=True,承上,start_logits[0] + end_logits[0]最小的那個doc_span的start_logits[0]
- 對于version_2_with_negative=True,null_start_logit= 0
- null_end_logit:與null_start_logit類似
- score_null:float。其對象為一篇文章的一個問題。
- 對于version_2_with_negative=True,承上,score_null = min(start_logits[0] + end_logits[0])
- 對于version_2_with_negative=False,score_null = 1000000
- prelim_predictions:列表。其對象為一篇文章的一個問題。其元素都是 _PrelimPrediction = collections.namedtuple( “PrelimPrediction”, [“feature_index”, “start_index”, “end_index”, “start_logit”, “end_logit”])。一篇文章一個問題一個doc_span得到了start_indexes和end_indexes,遍歷每個doc_span的start_indexes與end_indexes組合(首先要遍歷doc_span,然后要遍歷組合,組合指:如start_indexes=[0,1], end_indexes=[3, 4],組合類型有start_index=0,end_index=3;start_index=0,end_index=4;start_index=1,end_index=3;start_index=1,end_index=4;),一旦發現有一組合的start_index對應那個token的token_is_max_context[start_index]為True,且end_index大于等于start_index,則成為一個 _PrelimPrediction元素,該元素的feature_index, start_index, end_index, start_logit, end_logit顧名思義。此外,對于version_2_with_negative=True,prelim_predictions中還有一個特殊的_PrelimPrediction 元素,該元素的feature_index = min_null_feature_index,start_index = 0, end_index = 0, start_logit = null_start_logit, end_logit = null_end_logit。最后還需要將所有的_PrelimPrediction元素按照start_logit + end_logit進行從大到小排序。這里注意的是在version_2_with_negative=False,prelim_predictions有可能為空(沒有一個doc_span的start_index/end_index組合是滿足要求的),當version_2_with_negative=True時,prelim_predictions有可能只有一個特殊元素
- n_best:列表。其對象為一篇文章的一個問題。其元素都是 _NbestPrediction = collections.namedtuple( “NbestPrediction”, [“text”, “start_logit”, “end_logit”])。n_best長度小于等于n_best_size(也即上述prelim_predictions前n_best_size元素經過一些操作得到n_best)。遍歷上述prelim_predictions每個元素,記每個元素用pred表示(直到len(n_best) > =n_best_size停止循環)。
- 如果pred.start_index > 0(pred.start_index=0,從上面可以看到是version_2_with_negative=True時那個特殊的_PrelimPrediction元素。正常得到的_PrelimPrediction.start_index不可能是0,因為0對應的token屬于query):final_text為最終預測的結果字符串
- 如果pred.start_index = 0:final_text = “”
故n_best的一個_NbestPrediction的text,start_logit,end_logit都可以得到了。在得到最多n_best_size個 _NbestPrediction元素后,如果version_2_with_negative=True,且final_text=““不在n_best中,n_best還需要添加一個特殊的_NbestPrediction,其 text=””,start_logit=null_start_logit, end_logit = null_end_logit(前n_best_size個_PrelimPrediction 有可能不包含那個特殊的_PrelimPrediction )。因此最終n_best的長度有可能是n_best_size+1。此外n_best在version_2_with_negative=False時,有可能為空,此時需要添加另一個特殊的_NbestPrediction,final_text = ‘empty’, start_logit = 0.0, end_logits = 0.0。n_best在version_2_with_negative=True時有可能只有那個特殊的_NbestPrediction
- total_scores:列表。其對象為一篇文章的一個問題。得到上述n_best后,每個元素都可以得到start_logit+end_logit,該值即total_socores的每個元素值
- best_non_null_entry:一個_NbestPrediction 元素。其對象為一篇文章的一個問題。得到上述n_best后(按照start_logit+end_logit從大到小排過序了),best_non_null_entry即n_best首個text有值的_NbestPrediction 元素。如果n_best只有一個text為""的元素,則best_non_null_entry=None
- probs:列表。其對象為一篇文章的一個問題。相對total_scores而言的。probs就是將total_scores的每個元素按照一定公式轉化得到的。公式為:probi=exp(scorei?max(scores))∑iexp(scorei?max(scores))prob_i = \frac{exp(score_i-max(scores))}{\sum_{i}exp(score_i-max(scores))}probi?=∑i?exp(scorei??max(scores))exp(scorei??max(scores))?
- nbest_json:列表。其對象為一篇文章的一個問題。得到上述n_best后,將每個元素轉換為一個字典,keys為text、probability、start_logit、end_logit。key的含義顧名思義。該字典就是nbest_json的一個元素。
- all_predictions:字典。其對象為所有文章的所有問題。單個元素為一篇文章一個問題的答案。對于單個元素其值如下:
- FLAGS.version_2_with_negative = False:key為example.qas_id(一篇文章的一個問題編號),value為nbest_json[0][“text”]。答案有可能是empty
- FLAGS.version_2_with_negative = True:score_diff = score_null - best_non_null_entry.start_logit - best_non_null_entry.end_logit
- 如果score_diff 大于null_score_diff_threshold(指定的超參數),key為example.qas_id, value為""
- 如果score_diff 小于等于null_score_diff_threshold,key為example.qas_id,value為best_non_null_entry.text
- all_nbest_json:字典。其對象為所有文章的所有問題。單個元素為一篇文章一個問題的n個可能答案。對于單元素,其值key為example.qas_id,value為nbest_json。
- scores_diff_json:其對象為所有文章的所有問題。
- FLAGS.version_2_with_negative = False,score_diff_json就是一個空字典
- FLAGS.version_2_with_negative = True,對于單個元素,承上all_predictions,key為example.qas_id,value為score_diff
總結
以上是生活随笔為你收集整理的基于bert的阅读理解脚本(run_squad)原理梳理(从举例的角度说明)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 网狐荣耀代码通读一----登录服务器
- 下一篇: DH参数法 例题 机器人学