对抗思想与强化学习的碰撞-SeqGAN模型原理和代码解析
GAN作為生成模型的一種新型訓練方法,通過discriminative model來指導generative model的訓練,并在真實數據中取得了很好的效果。盡管如此,當目標是一個待生成的非連續性序列時,該方法就會表現出其局限性。非連續性序列生成,比如說文本生成,為什么單純的使用GAN沒有取得很好的效果呢?主要的屏障有兩點:
1)在GAN中,Generator是通過隨機抽樣作為開始,然后根據模型的參數進行確定性的轉化。通過generative model G的輸出,discriminative model D計算的損失值,根據得到的損失梯度去指導generative model G做輕微改變,從而使G產生更加真實的數據。而在文本生成任務中,G通常使用的是LSTM,那么G傳遞給D的是一堆離散值序列,即每一個LSTM單元的輸出經過softmax之后再取argmax或者基于概率采樣得到一個具體的單詞,那么這使得梯度下架很難處理。
2)GAN只能評估出整個生成序列的score/loss,不能夠細化到去評估當前生成token的好壞和對后面生成的影響。
強化學習可以很好的解決上述的兩點。再回想一下Policy Gradient的基本思想,即通過reward作為反饋,增加得到reward大的動作出現的概率,減小reward小的動作出現的概率,如果我們有了reward,就可以進行梯度訓練,更新參數。如果使用Policy Gradient的算法,當G產生一個單詞時,如果我們能夠得到一個反饋的Reward,就能通過這個reward來更新G的參數,而不再需要依賴于D的反向傳播來更新參數,因此較好的解決了上面所說的第一個屏障。對于第二個屏障,當產生一個單詞時,我們可以使用蒙塔卡羅樹搜索(Alpho Go也運用了此方法)立即評估當前單詞的好壞,而不需要等到整個序列結束再來評價這個單詞的好壞。
因此,強化學習和對抗思想的結合,理論上可以解決非連續序列生成的問題,而SeqGAN模型,正是這兩種思想碰撞而產生的可用于文本序列生成的模型。
SeqGAN模型的原文地址為:https://arxiv.org/abs/1609.05473,當然在我的github鏈接中已經把下載好的原文貼進去啦。
結合代碼可以更好的理解模型的細節喲:https://github.com/princewen/tensorflow_practice/tree/master/seqgan
2、SeqGAN的原理
SeqGAN的全稱是Sequence Generative Adversarial Nets。這里打公式太麻煩了,所以我們用word打好再粘過來,沖這波手打也要給小編一個贊呀,哈哈!
整體流程
模型的示意圖如下:
Generator模型和訓練
接下來,我們分別來說一下Generator模型和Discriminator模型結構。
Generator一般選擇的是循環神經網絡結構,RNN,LSTM或者是GRU都可以。對于輸入的序列,我們首先得到序列中單詞的embedding,然后輸入每個cell中,并結合一層全鏈接隱藏層得到輸出每個單詞的概率,即:
有了這個概率,Generator可以根據它采樣一批產生的序列,比如我們生成一個只有,兩個單詞的序列,總共的單詞序列有3個,第一個cell的輸出為(0.5,0.5,0.0),第二個cell的輸出為(0.1,0.8,0.1),那么Generator產生的序列以0.4的概率是1->2,以0.05的概率是1->1。注意這里Generator產生的序列是概率采樣得到的,而不是對每個輸出進行argmax得到的固定的值。這和policy gradient的思想是一致的。
在每一個cell我們都能得到一個概率分布,我們基于它選擇了一個動作或者說一個單詞,如何判定基于這個概率分布得到的單詞的還是壞的呢?即我們需要一個reward來左右這個單詞被選擇的概率。這個reward怎么得到呢,就需要我們的Discriminator以及蒙塔卡羅樹搜索方法了。前面提到過Reward的計算依據是最大可能的Discriminator,即盡可能的讓Discriminator認為Generator產生的數據為real-world的數據。這里我們設定real-world的數據的label為1,而Generator產生的數據label為0.
如果當前的cell是最后的一個cell,即我們已經得到了一個完整的序列,那么此時很好辦,直接把這個序列扔給Discriminator,得到輸出為1的概率就可以得到reward值。如果當前的cell不是最后一個cell,即當前的單詞不是最后的單詞,我們還沒有得到一個完整的序列,如何估計當前這個單詞的reward呢?我們用到了蒙特卡羅樹搜索的方法。即使用前面已經產生的序列,從當前位置的下一個位置開始采樣,得到一堆完整的序列。在原文中,采樣策略被稱為roll-out policy,這個策略也是通過一個神經網絡實現,這個神經網絡我們可以認為就是我們的Generator。得到采樣的序列后,我們把這一堆序列扔給Discriminator,得到一批輸出為1的概率,這堆概率的平均值即我們的reward。這部分正如過程示意圖中的下面一部分:
用原文中的公式表示如下:
得到了reward,我們訓練Generator的方式就很簡單了,即通過Policy Gradient的方式進行訓練。最簡單的思想就是增加reward大的動作的選擇概率,減小reward小的動作的選擇概率。
Discriminator模型和訓練
Discriminator模型即一個分類器,對文本分類的分類器很多,原文采用的是卷積神經網絡。同時為了使模型的分類效果更好,在CNN的基礎上增加了一個highway network。有關highway network的介紹參考博客:https://blog.csdn.net/l494926429/article/details/51737883,這里就不再細講啦。
對于Discriminator來說,既然是一個分類器,輸出的又是兩個類別的概率值,我們很自然的想到使用類似邏輯回歸的對數損失函數,沒錯,論文中也是使用對數損失來訓練Discriminator的。
結合oracle模型
可以說,模型我們已經介紹完了,但是在實驗部分,論文中引入了一個新的模型中,被稱為oracle model。這里的oracle如何翻譯,我還真的是不知道,總不能翻譯為甲骨文吧。這個oracle model被用來生成真實的序列,可以認為這個model就是一個被訓練完美的lstm模型,輸出的序列都是real-world數據。論文中使用這個模型的原因有兩點:首先是可以用來產生訓練數據,另一點是可以用來評價我們Generator的真實表現。原文如下:
我們會在訓練過程中不斷通過上面的式子來評估我們的Generator與oracle model的相似性。
預訓練過程
上面我們講的其實是在對抗過程中Generator和Discriminator的訓練過程,其實在進行對抗之前,我們的Generator和Discriminator都有一個預訓練的過程,這能使我們的模型更快的收斂。
對于Generator來說,預訓練和對抗過程中使用的損失函數是不一樣的,在預訓練過程中,Generator使用的是交叉熵損失函數,而在對抗過程中,我們使用的則是Policy Gradient中的損失函數,即對數損失*獎勵值。
而對Discriminator來說,兩個過程中的損失函數都是一樣的,即我們前面介紹的對數損失函數。
SeqGAN模型流程
介紹了這么多,我們再來看一看SeqGAN的流程:
3、SeqGAN代碼解析
這里我們用到的代碼高度還原了原文中的實驗過程,本文參考的github代碼地址為:https://github.com/ChenChengKuan/SeqGAN_tensorflow
參考的代碼為python2版本的,本文將其稍作修改,改成了python3版本的。其實主要就是print和pickle兩個地方。本文代碼的github地址為:https://github.com/princewen/tensorflow_practice/tree/master/seqgan
代碼實在是太多了,我們這里只介紹一下代碼結構,具體的代碼細節大家可以參考github進行學習。
3.1 代碼結構
本文的代碼結構如下:
save:save文件夾下保存了我們的實驗日志,eval_file是由Generator產生,用來評價Generator和oracle model相似性所產生的數據。real_data是由oracle model產生的real-world數據,generator_sample是由Generator產生的數據,target_params是oracle model的參數,我們直接用里面的參數還原oracle model。
configuration : 一些配置參數
dataloader.py: 產生訓練數據,對于Generator來說,我們只在預訓練中使用dataloader來得到訓練數據,對Discriminator來說,在預訓練和對抗過程中都要使用dataloader來得到訓練數據。而在eval過程即進行Generator和oracle model相似性判定時,會用刀dataloader來產生數據。
discriminator.py:定義了我們的discriminator
generator.py :定義了我們的generator
rollout.py:計算reward時的采樣過程
target_lstm.py:定義了我們的oracle model,這個文件不用管,復制過去就好,哈哈。
train.py : 定義了我們的訓練過程,這是我們一會重點講解的文件
utils.py : 定義了一些在訓練過程中的通用過程。
下面,我們就來介紹一下每個文件。
3.2 dataloader
dataloader是我們的數據生成器。
它定義了兩個類,一個時Generator的數據生成器,主要用于Generator的預訓練以及計算Generator和Oracle model的相似性。另一個時Discriminator的數據生成器,主要用于Discriminator的訓練。
3.3 generator
generator中定義了我們的Generator,代碼結構如下:
build_input:定義了我們的預訓練模型和對抗過程中需要輸入的數據
build_pretrain_network : 定義了Generator的預訓練過程中的網絡結構,其實這個網絡結構在預訓練,對抗和采樣的過程中是一樣的,參數共享。預訓練過程中定義的損失是交叉熵損失。
build_adversarial_network: 定義了Generator的對抗過程的網絡結構,和預訓練過程共享參數,因此你可以發現代碼基本上是一樣的,只不過在對抗過程中的損失函數是policy gradient的損失函數,即 -log(p(xi) * v(xi):
self.pgen_loss_adv = - tf.reduce_sum(tf.reduce_sum(tf.one_hot(tf.to_int32(tf.reshape(self.input_seqs_adv,[-1])),self.num_emb,on_value=1.0,off_value=0.0)* tf.log(tf.clip_by_value(tf.reshape(self.softmax_list_reshape,[-1,self.num_emb]),1e-20,1.0)),1) * tf.reshape(self.rewards,[-1]))build_sample_network:定義了我們Generator采樣得到生成序列過程的網絡結構,與前兩個網絡參數是共享的。
那么這三個網絡是如何使用的呢?pretrain_network就是用來預訓練我們的Generator的,這個沒有異議。然后在對抗時的每一個epoch,首先用sample_network得到一堆采樣的序列samples,然后對采樣序列的對每一個時點,使用roll-out-policy結合Discriminator得到reward值。最后,把這些samples和reward值喂給adversarial_network進行參數更新。
3.4 discriminator
discriminator的文件結構如下:
前面的linear和highway函數實現了highway network。
在Discriminator類中,我們采用CNN建立了Discriminator的網絡結構,值得注意的是,我們這里采用的損失函數加入了正則項:
with tf.name_scope("output"):W = tf.Variable(tf.truncated_normal([num_filters_total,self.num_classes],stddev = 0.1),name="W")b = tf.Variable(tf.constant(0.1,shape=[self.num_classes]),name='b')self.l2_loss += tf.nn.l2_loss(W)self.l2_loss += tf.nn.l2_loss(b)self.scores = tf.nn.xw_plus_b(self.h_drop,W,b,name='scores') # batch * num_classesself.ypred_for_auc = tf.nn.softmax(self.scores)self.predictions = tf.argmax(self.scores,1,name='predictions')with tf.name_scope(“loss”):
losses = tf.nn.softmax_cross_entropy_with_logits(logits=self.scores,labels=self.input_y)
# 損失函數中加入了正則項
self.loss = tf.reduce_mean(losses) + self.l2_reg_lambda + self.l2_loss
3.5 rollout
這個文件實現的通過rollout-policy得到一堆完整序列的過程,前面我們提到過了,rollout-policy實現需要一個神經網絡,而我們這里用Generator當作這個神經網絡,所以它與前面提到的三個Generator的網絡的參數也是共享的。
另外需要注意的是,我們這里要得到每個序列每個時點的采樣數據,因此需要進行兩層循環:
假設我們傳過來的序列長度是20,最后一個不需要進行采樣,因為已經是完整的序列了。假設當前的step是5,那么0-4是不需要采樣的,但我們需要把0-4位置的序列輸入到網絡中得到state。得到state之后,我們再經過一層循環得到5-19位的采樣序列,然后將0-4位置的序列的和5-19位置的序列的進行拼接。
sample_rollout = tf.concat([sample_rollout_left,sample_rollout_right],axis=1)3.6 utils
utils中定義了兩個函數:
generate_samples函數用于調用Generator中的sample_network產生sample或者用于調用target-lstm中的sample_network產生real-world數據
target_loss函數用于計算Generator和oracle model的相似性。
3.7 train
終于改介紹我們的主要流程控制代碼了,先深呼吸一口,準備開始!
定義dataloader以及網絡
首先,我們獲取了configuration中定義的參數,然后基于這些參數,我們得到了三個dataloader。
隨后,我們定義了Generator和Discriminator,以及通過讀文件來創建了我們的oracle model,在代碼中叫target_lstm。
config_train = training_config() config_gen = generator_config() config_dis = discriminator_config()np.random.seed(config_train.seed)
assert config_train.start_token == 0
gen_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
likelihood_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
dis_data_loader = Dis_dataloader(config_dis.dis_batch_size)
generator = Generator(config=config_gen)
generator.build()
rollout_gen = rollout(config=config_gen)
#Build target LSTM
target_params = pickle.load(open(‘save/target_params.pkl’,‘rb’),encoding=‘iso-8859-1’)
target_lstm = TARGET_LSTM(config=config_gen, params=target_params) # The oracle model
# Build discriminator
discriminator = Discriminator(config=config_dis)
discriminator.build_discriminator()
預訓練Generator
我們首先定義了預訓練過程中Generator的優化器,即通過AdamOptimizer來最小化交叉熵損失,隨后我們通過target-lstm網絡來產生Generator的訓練數據,利用dataloader來讀取每一個batch的數據。
同時,每隔一定的步數,我們會計算Generator與target-lstm的相似性(likelihood)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
generate_samples(sess,target_lstm,config_train.batch_size,config_train.generated_num,config_train.positive_file)
gen_data_loader.create_batches(config_train.positive_file)
log = open(‘save/experiment-log.txt’,‘w’)
print(‘Start pre-training generator…’)
log.write(‘pre-training…\n’)
for epoch in range(config_train.pretrained_epoch_num):
gen_data_loader.reset_pointer()
for it in range(gen_data_loader.num_batch):
batch = gen_data_loader.next_batch()
_,g_loss = sess.run([gen_pre_update,generator.pretrained_loss],feed_dict={generator.input_seqs_pre:batch,
generator.input_seqs_mask:np.ones_like(batch)})
預訓練Discriminator
預訓練好Generator之后,我們就可以通過Generator得到一批負樣本,并結合target-lstm產生的正樣本來預訓練我們的Discriminator。
定義對抗過程中Generator的優化器
這里定義的對抗過程中Generator的優化器即最小化我們前面提到的policy gradient損失,再回顧一遍:
# Initialize global variables of optimizer for adversarial training
uninitialized_var = [e for e in tf.global_variables() if e not in tf.trainable_variables()]
init_vars_uninit_op = tf.variables_initializer(uninitialized_var)
sess.run(init_vars_uninit_op)
對抗過程中訓練Generator
對抗過程中訓練Generator,我們首先需要通過Generator得到一批序列sample,然后使用roll-out結合Dsicriminator得到每個序列中每個時點的reward,再將reward和sample喂給adversarial_network進行參數更新。
對抗過程中訓練Discriminator
對抗過程中Discriminator的訓練和預訓練過程一樣,這里就不再贅述。
for _ in range(config_train.dis_update_time_adv):generate_samples(sess,generator,config_train.batch_size,config_train.generated_num,config_train.negative_file)dis_data_loader.load_train_data(config_train.positive_file,config_train.negative_file) <span class="hljs-keyword">for</span> _ <span class="hljs-keyword">in</span> range(config_train.dis_update_time_adv):dis_data_loader.reset_pointer()<span class="hljs-keyword">for</span> it <span class="hljs-keyword">in</span> range(dis_data_loader.num_batch):x_batch,y_batch = dis_data_loader.next_batch()feed = {discriminator.input_x:x_batch,discriminator.input_y:y_batch,discriminator.dropout_keep_prob:config_dis.dis_dropout_keep_prob}_ = sess.run(discriminator.train_op,feed)3.8 訓練效果
來一發訓練效果截圖:
可以看到,我們的Generator越來越接近oracle model啦,哈哈哈!
參考文獻:
1、https://blog.csdn.net/liuyuemaicha/article/details/70161273
2、https://blog.csdn.net/yinruiyang94/article/details/77675586
3、https://www.jianshu.com/p/32e164883eab
4、https://blog.csdn.net/l494926429/article/details/51737883
總結
以上是生活随笔為你收集整理的对抗思想与强化学习的碰撞-SeqGAN模型原理和代码解析的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Android官方开发文档Trainin
- 下一篇: 咨询笔记:麦肯锡7步成诗