Understanding the HAN Model
I. The HAN model has two important features. The first is its hierarchy: a word-level layer and a sentence-level layer, matching the structure of a document. The second is its use of attention, which assigns dynamic, content-dependent weights when aggregating.
II. The HAN model works as follows:
First, each word's one-hot representation is mapped to a dense embedding (the embedding layer).
The embeddings then pass through the word-level encoder. Many encoders would work here; the paper chooses a bidirectional GRU, which produces an encoding for every word.
Next, an attention layer computes a weight for each word encoding, and the encodings are combined by a weighted sum. The key point is the query Q of this attention layer, the context vector u_w: it is randomly initialized and does not correspond to any input. V is the word encodings (the hidden states output by the GRU layer), and K is the output of passing V through a one-layer feed-forward network. The formulation from the original paper is given below:
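For the t-th word of sentence i, with BiGRU hidden state h_{it}, the paper's word-level attention is:

u_{it} = \tanh(W_w h_{it} + b_w)
\alpha_{it} = \frac{\exp(u_{it}^\top u_w)}{\sum_t \exp(u_{it}^\top u_w)}
s_i = \sum_t \alpha_{it} h_{it}

Here the one-layer feed-forward network (W_w, b_w with tanh) produces K = u_{it}, the randomly initialized context vector u_w plays the role of Q, and the sentence vector s_i is the attention-weighted sum of the hidden states V = h_{it}.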
This completes the encoding of a single sentence into a sentence vector.
The sentence vectors of a document then form the sentence-level input: they pass through a sentence-level encoder (a bidirectional GRU) and a sentence-level attention layer that is essentially identical to the word-level one (with its own context vector u_s), producing a single document vector.
Finally, a linear transformation turns the document vector into class scores, and a softmax layer outputs the classification probabilities, as shown below (v is the document vector):
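p = \mathrm{softmax}(W_c v + b_c)

Training minimizes the negative log-likelihood of the correct label, L = -\sum_d \log p_{d, j_d}, where j_d is the true class of document d.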
III. Model settings and training:
a. First preprocess the text and tokenize it (splitting each document into sentences and each sentence into words).
b. Train a word2vec model to obtain a word2vec matrix, which is used to initialize the embedding layer of the HAN model (see the sketch after this list). The embedding layer output dimension is 200, the encoder output dimension is 100 (50 for each direction), and the context vector dimension is also 100.
c. The batch size is 64, the momentum is 0.9, and the learning rate is obtained by grid search.
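A minimal sketch of steps a and b, assuming gensim (>= 4.0) and NLTK are installed; the helper name build_embedding_matrix and the vocabulary handling are illustrative, not taken from the linked repository.

# Illustrative: tokenize the corpus, train word2vec, and build a matrix that
# can be used to initialize the HAN embedding layer (hypothetical helper).
import numpy as np
from nltk.tokenize import sent_tokenize, word_tokenize   # requires nltk 'punkt' data
from gensim.models import Word2Vec

def build_embedding_matrix(documents, embedding_size=200):
    # a. split each document into sentences and each sentence into words
    tokenized = [word_tokenize(sent)
                 for doc in documents
                 for sent in sent_tokenize(doc)]

    # b. train word2vec on the tokenized corpus (gensim >= 4.0 API)
    w2v = Word2Vec(tokenized, vector_size=embedding_size,
                   window=5, min_count=5, workers=4)

    # vocabulary with index 0 reserved for padding / unknown words
    vocab = {'UNKNOWN_TOKEN': 0}
    for word in w2v.wv.index_to_key:
        vocab[word] = len(vocab)

    # matrix used to initialize the embedding layer of the HAN model
    embedding_matrix = np.zeros((len(vocab), embedding_size), dtype=np.float32)
    for word, idx in vocab.items():
        if word in w2v.wv:
            embedding_matrix[idx] = w2v.wv[word]
    return vocab, embedding_matrix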
IV. HAN model definition code:
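The model definition itself is not included in this post. The following is a minimal TensorFlow 1.x sketch written to match the interface the training script below expects (a HAN class exposing input_x, input_y, max_sentence_num, max_sentence_length, batch_size and out); it illustrates the two-level BiGRU-plus-attention structure described above and is not necessarily identical to the code in the linked repository.

#coding=utf-8
import tensorflow as tf
from tensorflow.contrib import rnn, layers

class HAN(object):
    """Word-level BiGRU + attention, then sentence-level BiGRU + attention."""

    def __init__(self, vocab_size, num_classes, embedding_size=200, hidden_size=50):
        self.vocab_size = vocab_size
        self.num_classes = num_classes
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size

        with tf.name_scope('placeholder'):
            self.max_sentence_num = tf.placeholder(tf.int32, name='max_sentence_num')
            self.max_sentence_length = tf.placeholder(tf.int32, name='max_sentence_length')
            self.batch_size = tf.placeholder(tf.int32, name='batch_size')
            # input_x: [batch, sentences per doc, words per sentence]; input_y: one-hot labels
            self.input_x = tf.placeholder(tf.int32, [None, None, None], name='input_x')
            self.input_y = tf.placeholder(tf.float32, [None, num_classes], name='input_y')

        word_embedded = self.word2vec()           # [batch, sent, word, embedding]
        sent_vec = self.sent2vec(word_embedded)   # [batch, sent, 2*hidden]
        doc_vec = self.doc2vec(sent_vec)          # [batch, 2*hidden]
        self.out = self.classifier(doc_vec)       # [batch, num_classes] logits

    def word2vec(self):
        # embedding lookup; this matrix could instead be initialized from the
        # pre-trained word2vec matrix of section III.b
        with tf.name_scope('embedding'):
            embedding_mat = tf.Variable(
                tf.truncated_normal((self.vocab_size, self.embedding_size)))
            return tf.nn.embedding_lookup(embedding_mat, self.input_x)

    def sent2vec(self, word_embedded):
        # fold the sentence axis into the batch so every sentence is encoded independently
        with tf.name_scope('sent2vec'):
            word_embedded = tf.reshape(
                word_embedded, [-1, self.max_sentence_length, self.embedding_size])
            word_encoded = self.BidirectionalGRUEncoder(word_embedded, name='word_encoder')
            sent_vec = self.AttentionLayer(word_encoded, name='word_attention')
            return tf.reshape(sent_vec, [-1, self.max_sentence_num, self.hidden_size * 2])

    def doc2vec(self, sent_vec):
        # same structure at the sentence level, yielding one document vector
        with tf.name_scope('doc2vec'):
            doc_encoded = self.BidirectionalGRUEncoder(sent_vec, name='sent_encoder')
            return self.AttentionLayer(doc_encoded, name='sent_attention')

    def classifier(self, doc_vec):
        # linear transformation to class scores; softmax is applied in the loss
        with tf.name_scope('doc_classification'):
            return layers.fully_connected(doc_vec, self.num_classes, activation_fn=None)

    def BidirectionalGRUEncoder(self, inputs, name):
        # [batch, time, dim] -> [batch, time, 2*hidden_size]
        with tf.variable_scope(name):
            fw_cell = rnn.GRUCell(self.hidden_size)
            bw_cell = rnn.GRUCell(self.hidden_size)
            (fw_out, bw_out), _ = tf.nn.bidirectional_dynamic_rnn(
                fw_cell, bw_cell, inputs, dtype=tf.float32)
            return tf.concat((fw_out, bw_out), 2)

    def AttentionLayer(self, inputs, name):
        # u_context is the randomly initialized context vector (u_w / u_s)
        with tf.variable_scope(name):
            u_context = tf.Variable(
                tf.truncated_normal([self.hidden_size * 2]), name='u_context')
            h = layers.fully_connected(inputs, self.hidden_size * 2,
                                       activation_fn=tf.nn.tanh)
            alpha = tf.nn.softmax(
                tf.reduce_sum(tf.multiply(h, u_context), axis=2, keepdims=True), axis=1)
            return tf.reduce_sum(tf.multiply(inputs, alpha), axis=1)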
V. Training code:
#coding=utf-8
import tensorflow as tf
import time
import os
from data_helper import load_dataset
from HAN_model import HAN

# Data loading params
tf.flags.DEFINE_string("yelp_json_path", 'data/yelp_academic_dataset_review.json', "data directory")
tf.flags.DEFINE_integer("vocab_size", 46960, "vocabulary size")
tf.flags.DEFINE_integer("num_classes", 5, "number of classes")
tf.flags.DEFINE_integer("embedding_size", 200, "Dimensionality of word embedding (default: 200)")
tf.flags.DEFINE_integer("hidden_size", 50, "Dimensionality of GRU hidden layer (default: 50)")
tf.flags.DEFINE_integer("batch_size", 32, "Batch Size (default: 32)")
tf.flags.DEFINE_integer("num_epochs", 10, "Number of training epochs (default: 10)")
tf.flags.DEFINE_integer("checkpoint_every", 100, "Save model after this many steps (default: 100)")
tf.flags.DEFINE_integer("num_checkpoints", 5, "Number of checkpoints to store (default: 5)")
tf.flags.DEFINE_integer("max_sent_in_doc", 30, "Max number of sentences per document (default: 30)")
tf.flags.DEFINE_integer("max_word_in_sent", 30, "Max number of words per sentence (default: 30)")
tf.flags.DEFINE_integer("evaluate_every", 100, "evaluate every this many batches")
tf.flags.DEFINE_float("learning_rate", 0.01, "learning rate")
tf.flags.DEFINE_float("grad_clip", 5, "grad clip to prevent gradient explode")

FLAGS = tf.flags.FLAGS

train_x, train_y, dev_x, dev_y = load_dataset(FLAGS.yelp_json_path, FLAGS.max_sent_in_doc, FLAGS.max_word_in_sent)
print("data load finished")

with tf.Session() as sess:
    han = HAN(vocab_size=FLAGS.vocab_size,
              num_classes=FLAGS.num_classes,
              embedding_size=FLAGS.embedding_size,
              hidden_size=FLAGS.hidden_size)

    with tf.name_scope('loss'):
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=han.input_y,
                                                                      logits=han.out,
                                                                      name='loss'))
    with tf.name_scope('accuracy'):
        predict = tf.argmax(han.out, axis=1, name='predict')
        label = tf.argmax(han.input_y, axis=1, name='label')
        acc = tf.reduce_mean(tf.cast(tf.equal(predict, label), tf.float32))

    timestamp = str(int(time.time()))
    out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
    print("Writing to {}\n".format(out_dir))

    global_step = tf.Variable(0, trainable=False)
    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
    # gradient clipping, commonly used with RNNs to prevent exploding gradients
    tvars = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars), FLAGS.grad_clip)
    grads_and_vars = tuple(zip(grads, tvars))
    train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

    # Keep track of gradient values and sparsity (optional)
    grad_summaries = []
    for g, v in grads_and_vars:
        if g is not None:
            grad_hist_summary = tf.summary.histogram("{}/grad/hist".format(v.name), g)
            grad_summaries.append(grad_hist_summary)
    grad_summaries_merged = tf.summary.merge(grad_summaries)

    loss_summary = tf.summary.scalar('loss', loss)
    acc_summary = tf.summary.scalar('accuracy', acc)

    train_summary_op = tf.summary.merge([loss_summary, acc_summary, grad_summaries_merged])
    train_summary_dir = os.path.join(out_dir, "summaries", "train")
    train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

    dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
    dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
    dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

    checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)

    sess.run(tf.global_variables_initializer())

    def train_step(x_batch, y_batch):
        feed_dict = {
            han.input_x: x_batch,
            han.input_y: y_batch,
            han.max_sentence_num: 30,
            han.max_sentence_length: 30,
            han.batch_size: 64
        }
        _, step, summaries, cost, accuracy = sess.run([train_op, global_step, train_summary_op, loss, acc], feed_dict)
        time_str = str(int(time.time()))
        print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, cost, accuracy))
        train_summary_writer.add_summary(summaries, step)
        return step

    def dev_step(x_batch, y_batch, writer=None):
        feed_dict = {
            han.input_x: x_batch,
            han.input_y: y_batch,
            han.max_sentence_num: 30,
            han.max_sentence_length: 30,
            han.batch_size: 64
        }
        step, summaries, cost, accuracy = sess.run([global_step, dev_summary_op, loss, acc], feed_dict)
        time_str = str(int(time.time()))
        print("++++++++++++++++++dev++++++++++++++{}: step {}, loss {:g}, acc {:g}".format(time_str, step, cost, accuracy))
        if writer:
            writer.add_summary(summaries, step)

    for epoch in range(FLAGS.num_epochs):
        print('current epoch %s' % (epoch + 1))
        for i in range(0, 200000, FLAGS.batch_size):
            x = train_x[i:i + FLAGS.batch_size]
            y = train_y[i:i + FLAGS.batch_size]
            step = train_step(x, y)
            if step % FLAGS.evaluate_every == 0:
                dev_step(dev_x, dev_y, dev_summary_writer)

Code source: https://github.com/Irvinglove/HAN-text-classification/blob/master
Summary
HAN encodes a document hierarchically: words are encoded and attention-pooled into sentence vectors, sentences are encoded and attention-pooled into a document vector, and a final linear layer with softmax turns that vector into class probabilities. The attention at both levels assigns dynamic, content-dependent weights to the words and sentences.