Chinese Text Classification with a CNN in TensorFlow
During my graduate studies I used TensorFlow to implement a simple CNN for sentiment analysis, a fairly basic binary classification task, and later extended it to multi-class classification, but all of that work was trained on English corpora. While organizing my project materials recently, I picked this small text-classification project back up and spent a little time adapting the original English text-classification code to Chinese text. The results are quite good: on the THUCNews Chinese dataset the accuracy is about 93.9%. As usual, the source code comes first.
GitHub source code: nlp-learning-tutorials/THUCNews at master · PanJinquan/nlp-learning-tutorials · GitHub. Please give it a "Star".
Contents
Chinese Text Classification with a CNN in TensorFlow
1. Project Overview
1.1 Directory Structure
1.2 The THUCNews Dataset
2. CNN Model Architecture
3. Text Preprocessing
1. Chinese word segmentation with jieba
2. Training a word2vec model with gensim
3. THUCNews data processing
4. Training
5. Testing
1. Project Overview
1.1 Directory Structure
GitHub source code: nlp-learning-tutorials/THUCNews at master · PanJinquan/nlp-learning-tutorials · GitHub. Please give it a "Star".
Other resources:
- 1. The official THUCTC dataset: THUCTC, an efficient Chinese text classification toolkit
- 2. THUCTC on Baidu Netdisk, extraction code: bbpe
- 3. A pre-trained word2vec model: Baidu Netdisk, extraction code: mtrj
- 4. THUCNews data already converted to word-vector indices: Baidu Netdisk, extraction code: m9dx
1.2 The THUCNews Dataset
THUCNews was generated by filtering the historical data of Sina News RSS feeds from 2005 to 2011. It contains 740,000 news documents (2.19 GB), all in UTF-8 plain text. Based on the original Sina news taxonomy, the documents were re-organized into 14 candidate categories: finance, lottery, real estate, stocks, home, education, technology, society, fashion, politics, sports, horoscope, games, and entertainment. For more details, see the THUCTC page (THUCTC, an efficient Chinese text classification toolkit).
Download links:
1. Official dataset: http://thuctc.thunlp.org/message
2. Baidu Netdisk: https://pan.baidu.com/s/1DT5xY9m2yfu1YGaGxpWiBQ, extraction code: bbpe
2. CNN Model Architecture
The structure of the CNN network for text classification is as follows:
A brief walk-through:
(1) The input to the CNN is assumed to be a two-dimensional matrix in which each row represents one sample, i.e. one token, such as "I" or "like" in the figure. Each token has d dimensions, which is the word-vector length (the dimension of each token), denoted embedding_dim in the code.
(2) A convolution is then applied to this two-dimensional matrix. In image CNNs the kernels are typically 3*3, 5*5 and so on, but that does not work in NLP, because here every row is a complete sample. If the kernel size is [filter_height, filter_width], then filter_height can be any value such as 1, 2 or 3, while filter_width must equal embedding_dim, so that each complete token vector is covered by the kernel, as the sketch below illustrates.
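A minimal sketch of a single convolution branch, assuming TensorFlow 1.x and an input tensor of shape [batch, max_sentence_length, embedding_dim, 1]; the variable names are illustrative and not taken from the project's text_cnn.py:

import tensorflow as tf

max_sentence_length = 300
embedding_dim = 128
filter_height = 3      # the kernel spans 3 consecutive tokens
num_filters = 200

# One embedded sentence per batch element, with a single input channel.
inputs = tf.placeholder(tf.float32, [None, max_sentence_length, embedding_dim, 1])
# The filter width equals embedding_dim, so every token vector is covered in full.
W = tf.Variable(tf.truncated_normal([filter_height, embedding_dim, 1, num_filters], stddev=0.1))
b = tf.Variable(tf.constant(0.1, shape=[num_filters]))
conv = tf.nn.conv2d(inputs, W, strides=[1, 1, 1, 1], padding="VALID")
# Output shape: [batch, max_sentence_length - filter_height + 1, 1, num_filters];
# the width dimension collapses to 1 because the filter spans the whole embedding.
h = tf.nn.relu(tf.nn.bias_add(conv, b))
# Max-pool over the remaining time dimension, leaving one value per filter.
pooled = tf.nn.max_pool(h, ksize=[1, max_sentence_length - filter_height + 1, 1, 1],
                        strides=[1, 1, 1, 1], padding="VALID")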
Below are the main hyperparameters of the TextCNN text-classification network implemented with TensorFlow; a usage sketch of the TextCNN class follows the parameter list.
max_sentence_length = 300    # maximum sentence length: shorter samples are zero-padded, longer ones are truncated
embedding_dim = 128          # word-vector length, i.e. the dimension of each token
filter_sizes = [3, 4, 5, 6]  # convolution kernel sizes (heights)
num_filters = 200            # number of filters per filter size
base_lr = 0.001              # learning rate
dropout_keep_prob = 0.5
l2_reg_lambda = 0.0          # L2 regularization lambda (default: 0.0)
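The TextCNN class itself is defined in text_cnn.py in the repository and is not reproduced here. As a usage sketch only, it can be instantiated with the hyperparameters above exactly as the training script below does; labels_nums = 14 is an assumption based on the 14 THUCNews categories.

from text_cnn import TextCNN  # the project's TextCNN implementation

labels_nums = 14  # assumption: the 14 THUCNews categories
cnn = TextCNN(sequence_length=max_sentence_length,   # hyperparameters as defined above
              num_classes=labels_nums,
              embedding_size=embedding_dim,
              filter_sizes=filter_sizes,
              num_filters=num_filters,
              l2_reg_lambda=l2_reg_lambda)
# The model exposes cnn.input_x, cnn.input_y, cnn.dropout_keep_prob,
# cnn.loss, cnn.accuracy and cnn.predictions, which the training and
# prediction scripts below feed and fetch.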
3. Text Preprocessing
This post uses the jieba toolkit for Chinese word segmentation; training on words generally gives better results than training on individual characters.
This step is explained in detail in my earlier CSDN post 《使用gensim訓(xùn)練中文語料word2vec》 (Training word2vec on a Chinese corpus with gensim); please refer to that post for the full walkthrough.
1. Chinese word segmentation with jieba
jieba has to be installed separately: pip install jieba or pip3 install jieba. A minimal usage example follows.
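A minimal segmentation sketch; the sample sentence is arbitrary:

# -*- coding: utf-8 -*-
import jieba

sentence = "我爱自然语言处理"       # arbitrary sample sentence
words = jieba.lcut(sentence)        # precise mode, returns a list of words
print(words)                        # e.g. ['我', '爱', '自然语言', '处理']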
2. Training a word2vec model with gensim
Here jieba is used to segment the THUCNews dataset, and gensim is used to train a word2vec model on THUCNews. A pre-trained word2vec model is provided here: https://pan.baidu.com/s/1n4ZgiF0gbY0zsK0706wZiw, extraction code: mtrj. A minimal training sketch is shown below.
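A minimal training-and-loading sketch, assuming the corpus has already been segmented by jieba into lists of tokens; the tiny placeholder corpus and file name are illustrative, and gensim 3.x takes the vector dimension as size (renamed to vector_size in gensim 4.x):

from gensim.models import Word2Vec

# Placeholder corpus: each sentence is a list of tokens produced by jieba.
sentences = [["体育", "新闻", "报道"], ["财经", "市场", "分析"]]

# Train a 128-dimensional model to match embedding_dim above.
# gensim 3.x: size=...; gensim 4.x renames this argument to vector_size=...
model = Word2Vec(sentences, size=128, window=5, min_count=1, workers=4)
model.save("THUCNews_word2Vec_128.model")   # file name is illustrative

# The project later loads the model with Word2Vec.load(), as in create_word2vec.py:
w2vModel = Word2Vec.load("THUCNews_word2Vec_128.model")
print(w2vModel.vector_size)                 # 128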
3. THUCNews data processing
With the word2vec model available, the THUCNews data can be converted to word-vector form: jieba first splits each Chinese document into words, each word is then mapped to its index in the word2vec vocabulary, and from the index the corresponding embedding can be looked up. The index data are saved as *.npy files; during training, the CNN only has to read these *.npy files and convert the indices to embeddings.
Download link for the preprocessed THUCNews data: https://pan.baidu.com/s/12Hdf36QafQ3y6KgV_vLTsw, extraction code: m9dx
The code below does the following: it uses jieba to split the Chinese documents into words, converts the words into index matrices according to the word2vec model, and saves these index matrices as *.npy files. In the source code batchSize=20000 means that 20,000 Chinese TXT files are segmented, converted into one index matrix and saved as a single *.npy file; in other words, 20,000 TXT files are packed into one *.npy file, which compresses the data while keeping any single file from becoming too large.
# -*-coding: utf-8 -*-
"""
@Project: nlp-learning-tutorials
@File   : create_word2vec.py
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date   : 2018-11-08 17:37:21
"""
from gensim.models import Word2Vec
import random
import numpy as np
import os
import math
from utils import files_processing, segment


def info_npy(file_list):
    sizes = 0
    for file in file_list:
        data = np.load(file)
        print("data.shape:{}".format(data.shape))
        size = data.shape[0]
        sizes += size
    print("files nums:{}, data nums:{}".format(len(file_list), sizes))
    return sizes


def save_multi_file(files_list, labels_list, word2vec_path, out_dir, prefix, batchSize,
                    max_sentence_length, labels_set=None, shuffle=False):
    '''
    Map file contents to index matrices and save the data as multiple files
    :param files_list:
    :param labels_list:
    :param word2vec_path: path of the word2vec model
    :param out_dir: output directory
    :param prefix: prefix of the saved files
    :param batchSize: number of source files packed into one output file
    :param labels_set: set of labels
    :return:
    '''
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)
    # delete all files already in the output directory
    files_processing.delete_dir_file(out_dir)

    if shuffle:
        random.seed(100)
        random.shuffle(files_list)
        random.seed(100)
        random.shuffle(labels_list)

    sample_num = len(files_list)
    w2vModel = load_wordVectors(word2vec_path)
    if labels_set is None:
        labels_set = files_processing.get_labels_set(labels_list)
    labels_list, labels_set = files_processing.labels_encoding(labels_list, labels_set)
    labels_list = labels_list.tolist()
    batchNum = int(math.ceil(1.0 * sample_num / batchSize))

    for i in range(batchNum):
        start = i * batchSize
        end = min((i + 1) * batchSize, sample_num)
        batch_files = files_list[start:end]
        batch_labels = labels_list[start:end]
        # read the file contents and segment them into words
        batch_content = files_processing.read_files_list_to_segment(batch_files,
                                                                    max_sentence_length,
                                                                    padding_token='<PAD>',
                                                                    segment_type='word')
        # convert the words to an index matrix
        batch_indexMat = word2indexMat(w2vModel, batch_content, max_sentence_length)
        batch_labels = np.asarray(batch_labels)
        batch_labels = batch_labels.reshape([len(batch_labels), 1])
        # save the *.npy file
        filename = os.path.join(out_dir, prefix + '{0}.npy'.format(i))
        labels_indexMat = cat_labels_indexMat(batch_labels, batch_indexMat)
        np.save(filename, labels_indexMat)
        print('step:{}/{}, save:{}, data.shape{}'.format(i, batchNum, filename, labels_indexMat.shape))


def cat_labels_indexMat(labels, indexMat):
    indexMat_labels = np.concatenate([labels, indexMat], axis=1)
    return indexMat_labels


def split_labels_indexMat(indexMat_labels, label_index=0):
    labels = indexMat_labels[:, 0:label_index + 1]   # the first column holds the labels
    indexMat = indexMat_labels[:, label_index + 1:]  # the rest is the index matrix
    return labels, indexMat


def load_wordVectors(word2vec_path):
    w2vModel = Word2Vec.load(word2vec_path)
    return w2vModel


def word2vector_lookup(w2vModel, sentences):
    '''
    Convert words to word vectors
    :param w2vModel: word2vec model
    :param sentences: type->list[list[str]]
    :return: word vectors of the sentences, type->list[list[ndarray]]
    '''
    all_vectors = []
    embeddingDim = w2vModel.vector_size
    embeddingUnknown = [0 for i in range(embeddingDim)]
    for sentence in sentences:
        this_vector = []
        for word in sentence:
            if word in w2vModel.wv.vocab:
                v = w2vModel[word]
                this_vector.append(v)
            else:
                this_vector.append(embeddingUnknown)
        all_vectors.append(this_vector)
    all_vectors = np.array(all_vectors)
    return all_vectors


def word2indexMat(w2vModel, sentences, max_sentence_length):
    '''
    Convert words to an index matrix
    :param w2vModel:
    :param sentences:
    :param max_sentence_length:
    :return:
    '''
    nums_sample = len(sentences)
    indexMat = np.zeros((nums_sample, max_sentence_length), dtype='int32')
    rows = 0
    for sentence in sentences:
        indexCounter = 0
        for word in sentence:
            try:
                index = w2vModel.wv.vocab[word].index  # index of the word in the vocabulary
                indexMat[rows][indexCounter] = index
            except:
                indexMat[rows][indexCounter] = 0       # index for unknown words
            indexCounter = indexCounter + 1
            if indexCounter >= max_sentence_length:
                break
        rows += 1
    return indexMat


def indexMat2word(w2vModel, indexMat, max_sentence_length=None):
    '''
    Convert an index matrix back to words
    :param w2vModel:
    :param indexMat:
    :param max_sentence_length:
    :return:
    '''
    if max_sentence_length is None:
        row, col = indexMat.shape
        max_sentence_length = col
    sentences = []
    for Mat in indexMat:
        indexCounter = 0
        sentence = []
        for index in Mat:
            try:
                word = w2vModel.wv.index2word[index]  # word corresponding to this index
                sentence += [word]
            except:
                sentence += ['<PAD>']
            indexCounter = indexCounter + 1
            if indexCounter >= max_sentence_length:
                break
        sentences.append(sentence)
    return sentences


def save_indexMat(indexMat, path):
    np.save(path, indexMat)


def load_indexMat(path):
    indexMat = np.load(path)
    return indexMat


def indexMat2vector_lookup(w2vModel, indexMat):
    '''
    Map an index matrix to word vectors
    :param w2vModel:
    :param indexMat:
    :return: word vectors
    '''
    all_vectors = w2vModel.wv.vectors[indexMat]
    return all_vectors


def pos_neg_test():
    positive_data_file = "./data/ham_5000.utf8"
    negative_data_file = './data/spam_5000.utf8'
    word2vec_path = 'out/trained_word2vec.model'
    sentences, labels = files_processing.load_pos_neg_files(positive_data_file, negative_data_file)
    # embedding_test(positive_data_file, negative_data_file)
    sentences, max_document_length = segment.padding_sentences(sentences, '<PADDING>', padding_sentence_length=190)
    # train_wordVectors(sentences, embedding_size=128, word2vec_path=word2vec_path)  # train word2vec and save it to word2vec_path
    w2vModel = load_wordVectors(word2vec_path)  # load the trained word2vec model
    '''
    There are two ways to obtain the word vectors:
    [1] direct mapping: map the words straight to word vectors with word2vector_lookup
    [2] indirect mapping: convert the words to an index matrix first, then map the indices
        to word vectors: word2indexMat -> indexMat2vector_lookup
    '''
    # [1] map the words directly to word vectors
    x1 = word2vector_lookup(w2vModel, sentences)

    # [2] convert the words to an index matrix, then map the indices to word vectors
    indexMat_path = 'out/indexMat.npy'
    indexMat = word2indexMat(w2vModel, sentences, max_sentence_length=190)  # words -> index matrix
    save_indexMat(indexMat, indexMat_path)
    x2 = indexMat2vector_lookup(w2vModel, indexMat)  # index matrix -> word vectors
    print("x.shape = {}".format(x2.shape))
    # shape=(10000, 190, 128) -> (10000 samples, 190 words per sample, 128 dimensions per word)


if __name__ == '__main__':
    # THUCNews_path = '/home/ubuntu/project/tfTest/THUCNews/test'
    # THUCNews_path = '/home/ubuntu/project/tfTest/THUCNews/spam'
    THUCNews_path = '/home/ubuntu/project/tfTest/THUCNews/THUCNews'

    # read the full file list
    files_list, label_list = files_processing.gen_files_labels(THUCNews_path)
    max_sentence_length = 300
    word2vec_path = "../../word2vec/models/THUCNews_word2Vec/THUCNews_word2Vec_128.model"

    # obtain the label set and save it locally
    # labels_set = ['星座', '財經(jīng)', '教育']
    # labels_set = files_processing.get_labels_set(label_list)
    labels_file = '../data/THUCNews_labels.txt'
    # files_processing.write_txt(labels_file, labels_set)

    # split the data into train/val sets
    train_files, train_label, val_files, val_label = files_processing.split_train_val_list(files_list, label_list, facror=0.9, shuffle=True)
    # contents, labels = files_processing.read_files_labels(files_list, label_list)
    # word2vec_path = 'out/trained_word2vec.model'

    train_out_dir = '../data/train_data'
    prefix = 'train_data'
    batchSize = 20000
    labels_set = files_processing.read_txt(labels_file)
    # labels_set2 = files_processing.read_txt(labels_file)
    save_multi_file(files_list=train_files,
                    labels_list=train_label,
                    word2vec_path=word2vec_path,
                    out_dir=train_out_dir,
                    prefix=prefix,
                    batchSize=batchSize,
                    max_sentence_length=max_sentence_length,
                    labels_set=labels_set,
                    shuffle=True)
    print("*******************************************************")

    val_out_dir = '../data/val_data'
    prefix = 'val_data'
    save_multi_file(files_list=val_files,
                    labels_list=val_label,
                    word2vec_path=word2vec_path,
                    out_dir=val_out_dir,
                    prefix=prefix,
                    batchSize=batchSize,
                    max_sentence_length=max_sentence_length,
                    labels_set=labels_set,
                    shuffle=True)

4. Training
The training code is shown below. Note that large files cannot be uploaded to GitHub, so you need to download the files provided above and put them into the corresponding directories before training.
During training the *.npy files are read. They store index data, so the indices have to be converted into embedding inputs for the CNN; this is done by indexMat2vector_lookup: train_batch_data = create_word2vec.indexMat2vector_lookup(w2vModel, train_batch_data).
#! /usr/bin/env python
# encoding: utf-8
import tensorflow as tf
import numpy as np
import os
from text_cnn import TextCNN
from utils import create_batch_data, create_word2vec, files_processing


def train(train_dir, val_dir, labels_file, word2vec_path, batch_size, max_steps,
          log_step, val_step, snapshot, out_dir):
    '''
    Train the model
    :param train_dir: directory of the training data
    :param val_dir: directory of the validation data
    :param labels_file: path of the labels file
    :param word2vec_path: path of the word2vec model
    :param batch_size: batch size
    :param max_steps: maximum number of training steps
    :param log_step: logging interval
    :param val_step: validation interval
    :param snapshot: checkpoint-saving interval
    :param out_dir: output directory for checkpoints and summaries
    :return:
    '''
    max_sentence_length = 300
    embedding_dim = 128
    filter_sizes = [3, 4, 5, 6]
    num_filters = 200         # Number of filters per filter size
    base_lr = 0.001           # learning rate
    dropout_keep_prob = 0.5
    l2_reg_lambda = 0.0       # L2 regularization lambda (default: 0.0)

    allow_soft_placement = True   # allow TF to pick another device if the specified one does not exist
    log_device_placement = False  # whether to log device placement

    print("Loading data...")
    w2vModel = create_word2vec.load_wordVectors(word2vec_path)
    labels_set = files_processing.read_txt(labels_file)
    labels_nums = len(labels_set)

    train_file_list = create_batch_data.get_file_list(file_dir=train_dir, postfix='*.npy')
    train_batch = create_batch_data.get_data_batch(train_file_list, labels_nums=labels_nums, batch_size=batch_size,
                                                   shuffle=False, one_hot=True)
    val_file_list = create_batch_data.get_file_list(file_dir=val_dir, postfix='*.npy')
    val_batch = create_batch_data.get_data_batch(val_file_list, labels_nums=labels_nums, batch_size=batch_size,
                                                 shuffle=False, one_hot=True)

    print("train data info *****************************")
    train_nums = create_word2vec.info_npy(train_file_list)
    print("val data info *****************************")
    val_nums = create_word2vec.info_npy(val_file_list)
    print("labels_set info *****************************")
    files_processing.info_labels_set(labels_set)

    # Training
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(allow_soft_placement=allow_soft_placement,
                                      log_device_placement=log_device_placement)
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            cnn = TextCNN(sequence_length=max_sentence_length,
                          num_classes=labels_nums,
                          embedding_size=embedding_dim,
                          filter_sizes=filter_sizes,
                          num_filters=num_filters,
                          l2_reg_lambda=l2_reg_lambda)

            # Define Training procedure
            global_step = tf.Variable(0, name="global_step", trainable=False)
            optimizer = tf.train.AdamOptimizer(learning_rate=base_lr)
            # optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9)
            grads_and_vars = optimizer.compute_gradients(cnn.loss)
            train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in grads_and_vars:
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            print("Writing to {}\n".format(out_dir))

            # Summaries for loss and accuracy
            loss_summary = tf.summary.scalar("loss", cnn.loss)
            acc_summary = tf.summary.scalar("accuracy", cnn.accuracy)

            # Train Summaries
            train_summary_op = tf.summary.merge([loss_summary, acc_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Dev summaries
            dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
            dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
            dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

            # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            checkpoint_prefix = os.path.join(checkpoint_dir, "model")
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)

            # Initialize all variables
            sess.run(tf.global_variables_initializer())

            def train_step(x_batch, y_batch):
                """
                A single training step
                """
                feed_dict = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: dropout_keep_prob
                }
                _, step, summaries, loss, accuracy = sess.run(
                    [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy], feed_dict)
                if step % log_step == 0:
                    print("training: step {}, loss {:g}, acc {:g}".format(step, loss, accuracy))
                train_summary_writer.add_summary(summaries, step)

            def dev_step(x_batch, y_batch, writer=None):
                """
                Evaluates model on a dev set
                """
                feed_dict = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: 1.0
                }
                step, summaries, loss, accuracy = sess.run(
                    [global_step, dev_summary_op, cnn.loss, cnn.accuracy], feed_dict)
                if writer:
                    writer.add_summary(summaries, step)
                return loss, accuracy

            for i in range(max_steps):
                train_batch_data, train_batch_label = create_batch_data.get_next_batch(train_batch)
                train_batch_data = create_word2vec.indexMat2vector_lookup(w2vModel, train_batch_data)
                train_step(train_batch_data, train_batch_label)
                current_step = tf.train.global_step(sess, global_step)
                if current_step % val_step == 0:
                    val_losses = []
                    val_accs = []
                    # for k in range(int(val_nums / batch_size)):
                    for k in range(100):
                        val_batch_data, val_batch_label = create_batch_data.get_next_batch(val_batch)
                        val_batch_data = create_word2vec.indexMat2vector_lookup(w2vModel, val_batch_data)
                        val_loss, val_acc = dev_step(val_batch_data, val_batch_label, writer=dev_summary_writer)
                        val_losses.append(val_loss)
                        val_accs.append(val_acc)
                    mean_loss = np.array(val_losses, dtype=np.float32).mean()
                    mean_acc = np.array(val_accs, dtype=np.float32).mean()
                    print("--------Evaluation:step {}, loss {:g}, acc {:g}".format(current_step, mean_loss, mean_acc))
                if current_step % snapshot == 0:
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    print("Saved model checkpoint to {}\n".format(path))


def main():
    # Data preprocess
    labels_file = 'data/THUCNews_labels.txt'
    word2vec_path = "../word2vec/models/THUCNews_word2Vec/THUCNews_word2Vec_128.model"
    max_steps = 100000     # maximum number of training steps
    batch_size = 128
    out_dir = "./models"   # output directory for checkpoints and summaries
    train_dir = './data/train_data'
    val_dir = './data/val_data'
    train(train_dir=train_dir,
          val_dir=val_dir,
          labels_file=labels_file,
          word2vec_path=word2vec_path,
          batch_size=batch_size,
          max_steps=max_steps,
          log_step=50,
          val_step=500,
          snapshot=1000,
          out_dir=out_dir)


if __name__ == "__main__":
    main()

5. Testing
Two testing methods are provided:
(1) text_predict(files_list, labels_file, models_path, word2vec_path, batch_size)
This method classifies raw Chinese text files directly.
(2) batch_predict(val_dir, labels_file, models_path, word2vec_path, batch_size)
This method performs batch evaluation: val_dir holds the test data as *.npy files, i.e. the THUCNews data already converted to word2vec indices as described above.
#! /usr/bin/env python
# encoding: utf-8
import tensorflow as tf
import numpy as np
import os
from text_cnn import TextCNN
from utils import create_batch_data, create_word2vec, files_processing
import math


def text_predict(files_list, labels_file, models_path, word2vec_path, batch_size):
    '''
    Predict the labels of raw Chinese text files
    :param files_list: list of text files to classify
    :param labels_file: path of the labels file
    :param models_path: model checkpoint
    :param word2vec_path: path of the word2vec model
    :param batch_size: batch size
    :return:
    '''
    max_sentence_length = 300
    embedding_dim = 128
    filter_sizes = [3, 4, 5, 6]
    num_filters = 200    # Number of filters per filter size
    l2_reg_lambda = 0.0  # L2 regularization lambda (default: 0.0)

    print("Loading data...")
    w2vModel = create_word2vec.load_wordVectors(word2vec_path)
    labels_set = files_processing.read_txt(labels_file)
    labels_nums = len(labels_set)
    sample_num = len(files_list)
    labels_list = [-1]
    labels_list = labels_list * sample_num

    with tf.Graph().as_default():
        sess = tf.Session()
        with sess.as_default():
            cnn = TextCNN(sequence_length=max_sentence_length,
                          num_classes=labels_nums,
                          embedding_size=embedding_dim,
                          filter_sizes=filter_sizes,
                          num_filters=num_filters,
                          l2_reg_lambda=l2_reg_lambda)
            # Initialize all variables
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess, models_path)

            def pred_step(x_batch):
                """
                Predict the labels of one batch
                """
                feed_dict = {
                    cnn.input_x: x_batch,
                    cnn.dropout_keep_prob: 1.0
                }
                pred = sess.run([cnn.predictions], feed_dict)
                return pred

            batchNum = int(math.ceil(1.0 * sample_num / batch_size))
            for i in range(batchNum):
                start = i * batch_size
                end = min((i + 1) * batch_size, sample_num)
                batch_files = files_list[start:end]
                # read the file contents and segment them into words
                batch_content = files_processing.read_files_list_to_segment(batch_files,
                                                                            max_sentence_length,
                                                                            padding_token='<PAD>')
                # [1] convert the words to an index matrix, then map it to word vectors
                batch_indexMat = create_word2vec.word2indexMat(w2vModel, batch_content, max_sentence_length)
                val_batch_data = create_word2vec.indexMat2vector_lookup(w2vModel, batch_indexMat)
                # [2] map the words directly to word vectors
                # val_batch_data = create_word2vec.word2vector_lookup(w2vModel, batch_content)
                pred = pred_step(val_batch_data)
                pred = pred[0].tolist()
                pred = files_processing.labels_decoding(pred, labels_set)
                for k, file in enumerate(batch_files):
                    print("{}, pred:{}".format(file, pred[k]))


def batch_predict(val_dir, labels_file, models_path, word2vec_path, batch_size):
    '''
    Batch evaluation on the preprocessed *.npy data
    :param val_dir: directory of the val data
    :param labels_file: path of the labels file
    :param models_path: model checkpoint
    :param word2vec_path: path of the word2vec model
    :param batch_size: batch size
    :return:
    '''
    max_sentence_length = 300
    embedding_dim = 128
    filter_sizes = [3, 4, 5, 6]
    num_filters = 200    # Number of filters per filter size
    l2_reg_lambda = 0.0  # L2 regularization lambda (default: 0.0)

    print("Loading data...")
    w2vModel = create_word2vec.load_wordVectors(word2vec_path)
    labels_set = files_processing.read_txt(labels_file)
    labels_nums = len(labels_set)
    val_file_list = create_batch_data.get_file_list(file_dir=val_dir, postfix='*.npy')
    val_batch = create_batch_data.get_data_batch(val_file_list, labels_nums=labels_nums, batch_size=batch_size,
                                                 shuffle=False, one_hot=True)
    print("val data info *****************************")
    val_nums = create_word2vec.info_npy(val_file_list)
    print("labels_set info *****************************")
    files_processing.info_labels_set(labels_set)

    # Evaluation
    with tf.Graph().as_default():
        sess = tf.Session()
        with sess.as_default():
            cnn = TextCNN(sequence_length=max_sentence_length,
                          num_classes=labels_nums,
                          embedding_size=embedding_dim,
                          filter_sizes=filter_sizes,
                          num_filters=num_filters,
                          l2_reg_lambda=l2_reg_lambda)
            # Initialize all variables
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess, models_path)

            def dev_step(x_batch, y_batch):
                """
                Evaluates model on a dev set
                """
                feed_dict = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: 1.0
                }
                loss, accuracy = sess.run([cnn.loss, cnn.accuracy], feed_dict)
                return loss, accuracy

            val_losses = []
            val_accs = []
            for k in range(int(val_nums / batch_size)):
                # for k in range(int(10)):
                val_batch_data, val_batch_label = create_batch_data.get_next_batch(val_batch)
                val_batch_data = create_word2vec.indexMat2vector_lookup(w2vModel, val_batch_data)
                val_loss, val_acc = dev_step(val_batch_data, val_batch_label)
                val_losses.append(val_loss)
                val_accs.append(val_acc)
                print("--------Evaluation:step {}, loss {:g}, acc {:g}".format(k, val_loss, val_acc))
            mean_loss = np.array(val_losses, dtype=np.float32).mean()
            mean_acc = np.array(val_accs, dtype=np.float32).mean()
            print("--------Evaluation:step {}, mean loss {:g}, mean acc {:g}".format(k, mean_loss, mean_acc))


def main():
    # Data preprocess
    labels_file = 'data/THUCNews_labels.txt'
    # word2vec_path = 'word2vec/THUCNews_word2vec300.model'
    word2vec_path = "../word2vec/models/THUCNews_word2Vec/THUCNews_word2Vec_128.model"
    models_path = 'models/checkpoints/model-30000'
    batch_size = 128
    val_dir = './data/val_data'
    batch_predict(val_dir=val_dir,
                  labels_file=labels_file,
                  models_path=models_path,
                  word2vec_path=word2vec_path,
                  batch_size=batch_size)

    test_path = '/home/ubuntu/project/tfTest/THUCNews/my_test'
    files_list = files_processing.get_files_list(test_path, postfix='*.txt')
    text_predict(files_list, labels_file, models_path, word2vec_path, batch_size)


if __name__ == "__main__":
    main()