2. Poetry Generation with an RNN
Poetry generation is a bit trickier than a classification problem, and since this is my first time applying an RNN to a text task, there are quite a few new concepts to get used to.
Data download:
Link: https://pan.baidu.com/s/1uCDup7U5rGuIlIb-lnZgjQ
Extraction code: f436
data.py — data preprocessing
```python
import numpy as np
import os

def get_data(conf):
    '''
    Load the dataset.
    :param conf: configuration options, a Config object
    :return: data: each row is one poem as a sequence of character index ids
    :return: word2ix: maps each character to its index id, e.g. u'月' -> 100
    :return: ix2word: maps each index id back to its character, e.g. 100 -> u'月'
    '''
    if os.path.exists(conf.pickle_path):
        # .npz archive; allow_pickle=True is required on NumPy >= 1.16.3
        # because word2ix/ix2word are stored as pickled dict objects
        datas = np.load(conf.pickle_path, allow_pickle=True)
        data = datas['data']
        ix2word = datas['ix2word'].item()
        word2ix = datas['word2ix'].item()
        return data, word2ix, ix2word
```
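A quick sanity check of what get_data returns can save confusion later. A minimal sketch, assuming the tang.npz file from the download link above (the exact vocabulary size and padded sequence length depend on that file):

```python
from config import Config
from data import get_data

conf = Config()
data, word2ix, ix2word = get_data(conf)

print(data.shape)                # (num_poems, seq_len): every poem padded to the same length
print(len(word2ix))              # vocabulary size, also the model's output dimension
print(word2ix['<START>'])        # index of the special start-of-poem token
print(ix2word[word2ix[u'月']])   # round trip back to u'月'
```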
config.py — configuration

```python
class Config(object):
    """Base configuration class. For custom configurations, create a
    sub-class that inherits from this one and override the properties
    that need to change.
    """

    # prefix for saved model checkpoints (a checkpoint is written every few epochs)
    model_prefix = 'checkpoints/tang'

    # path of the saved model
    model_path = 'checkpoints/tang.pth'

    # start words
    start_words = '春江花月夜'

    # type of poem to generate; defaults to acrostic
    p_type = 'acrostic'

    # number of training epochs
    max_epoch = 200

    # path of the dataset
    pickle_path = 'data/tang.npz'

    # mini-batch size
    batch_size = 128

    # number of worker processes used by the DataLoader
    num_workers = 4

    # number of LSTM layers
    num_layers = 2

    # dimension of the word embeddings
    embedding_dim = 128

    # dimension of the LSTM hidden state
    hidden_dim = 256

    # save model weights and sample poems every this many epochs
    save_every = 10

    # directory where poems sampled during training are saved
    out_path = 'out'

    # path for poems generated at test time
    out_poetry_path = 'out/poetry.txt'

    # maximum length of a generated poem
    max_gen_len = 200

    use_gpu = True
```
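As the docstring suggests, per-experiment settings can be expressed as a subclass. A minimal sketch (the class name and the particular overrides here are illustrative, not part of the project):

```python
class SmallGPUConfig(Config):
    """Hypothetical config for a smaller GPU: shrink the batch and the model."""
    batch_size = 64
    hidden_dim = 128

conf = SmallGPUConfig()
print(conf.batch_size, conf.embedding_dim)  # 64 128 -- unset fields fall back to Config
```

The train() function below offers a second, kwargs-based way to override fields via setattr.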
model.py — the model

```python
import torch
import torch.nn as nn

class PoetryModel(nn.Module):
    def __init__(self, vocab_size, conf, device):
        super(PoetryModel, self).__init__()
        self.num_layers = conf.num_layers
        self.hidden_dim = conf.hidden_dim
        self.device = device
        # embedding layer: (vocabulary size, embedding dimension)
        self.embeddings = nn.Embedding(vocab_size, conf.embedding_dim)
        # 2-layer LSTM; batch_first is NOT set, so inputs are (seq_len, batch_size, ...)
        self.lstm = nn.LSTM(conf.embedding_dim, conf.hidden_dim, num_layers=self.num_layers)
        # fully connected output layer; CrossEntropyLoss applies the softmax during training
        self.linear_out = nn.Linear(conf.hidden_dim, vocab_size)

    def forward(self, input, hidden=None):
        '''
        :param input: (seq_len, batch_size) tensor of character indices
        :return: output logits and the final hidden state
        '''
        seq_len, batch_size = input.size()
        embeds = self.embeddings(input)  # (seq_len, batch_size, embedding_dim)
        if hidden is None:
            h0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(self.device)
            c0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(self.device)
        else:
            h0, c0 = hidden
        output, hidden = self.lstm(embeds, (h0, c0))  # (seq_len, batch_size, hidden_dim)
        # flatten time and batch so every position becomes one classification example
        output = self.linear_out(output.view(seq_len * batch_size, -1))  # (seq_len*batch_size, vocab_size)
        return output, hidden
```
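A shape check with a dummy batch helps confirm the (seq_len, batch_size) convention. A minimal sketch, assuming a made-up vocabulary of 8000 characters:

```python
import torch
from config import Config
from model import PoetryModel

conf = Config()
device = torch.device('cpu')             # CPU is fine for a shape check
model = PoetryModel(8000, conf, device)  # 8000 is an assumed vocab size

x = torch.randint(0, 8000, (125, 4))     # seq_len=125, batch_size=4
output, (h, c) = model(x)
print(output.shape)  # torch.Size([500, 8000]) -- seq_len*batch_size flattened
print(h.shape)       # torch.Size([2, 4, 256]) -- (num_layers, batch_size, hidden_dim)
```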
train.py — training the model

```python
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader

from data import get_data
from model import PoetryModel
from config import Config

device = torch.device('cuda:0')
conf = Config()

def generate(model, start_words, ix2word, word2ix, prefix_words=None):
    '''
    Given a few starting words, generate a complete poem that continues them.
    '''
    print(start_words)
    results = list(start_words)
    start_word_len = len(start_words)
    # manually feed <START> as the first input token
    # (the original author flagged this spot for a second look)
    input = torch.Tensor([word2ix['<START>']]).view(1, 1).long()
    if conf.use_gpu:
        input = input.cuda()
    hidden = None

    if prefix_words:
        # run the prefix words through the model only to set up the hidden state
        for word in prefix_words:
            output, hidden = model(input, hidden)
            input = input.data.new([word2ix[word]]).view(1, 1)  # keep the (1, 1) shape the model expects

    for i in range(conf.max_gen_len):
        output, hidden = model(input, hidden)  # input holds a single token, initially '<START>'
        if i < start_word_len:
            # while inside the start words, feed the given characters
            w = results[i]
            input = input.data.new([word2ix[w]]).view(1, 1)
        else:
            # afterwards, greedily pick the most probable next character
            top_index = output.data.topk(1)[1][0].item()
            w = ix2word[top_index]
            results.append(w)
            input = input.data.new([top_index]).view(1, 1)
            if w == '<EOP>':
                del results[-1]  # drop the end-of-poem marker itself
                break
    return results

def gen_acrostic(model, start_words, ix2word, word2ix, prefix_words=None):
    '''
    Generate an acrostic poem.
    start_words: u'深度學習'
    generates:
        深木通中岳,青苔半日脂。
        度山分地險,逆浪到南巴。
        學道兵猶毒,當時燕不移。
        習根通古岸,開鏡出清羸。
    '''
    results = []
    start_word_len = len(start_words)
    input = torch.Tensor([word2ix['<START>']]).view(1, 1).long()
    if conf.use_gpu:
        input = input.cuda()
    hidden = None

    index = 0  # how many of the acrostic characters have been used so far
    # the previously generated word
    pre_word = '<START>'

    if prefix_words:
        for word in prefix_words:
            output, hidden = model(input, hidden)
            input = input.data.new([word2ix[word]]).view(1, 1)

    for i in range(conf.max_gen_len):
        output, hidden = model(input, hidden)
        top_index = output.data[0].topk(1)[1][0].item()
        w = ix2word[top_index]

        if pre_word in {u'。', u'!', '<START>'}:
            # at the start of a new sentence, feed in the next acrostic character
            if index == start_word_len:
                # every acrostic character has been used, so stop
                break
            else:
                w = start_words[index]
                index += 1
                input = input.data.new([word2ix[w]]).view(1, 1)
        else:
            # otherwise, feed the previously predicted character back in
            input = input.data.new([word2ix[w]]).view(1, 1)
        results.append(w)
        pre_word = w
    return results

def train(**kwargs):
    # allow keyword overrides of the configuration
    for k, v in kwargs.items():
        setattr(conf, k, v)
    # load the data
    data, word2ix, ix2word = get_data(conf)
    # build the DataLoader
    dataloader = DataLoader(dataset=data, batch_size=conf.batch_size,
                            shuffle=True,
                            drop_last=True,
                            num_workers=conf.num_workers)
    # build the model
    model = PoetryModel(len(word2ix), conf, device).to(device)
    # Optionally resume training from a checkpoint:
    # model.load_state_dict(torch.load('checkpoints/tang_0.pth'))

    # optimizer
    optimizer = Adam(model.parameters())
    # loss function
    criterion = nn.CrossEntropyLoss()
    # training loop
    for epoch in range(conf.max_epoch):
        epoch_loss = 0
        for i, data in enumerate(dataloader):
            # DataLoader yields (batch_size, seq_len); the model wants (seq_len, batch_size)
            data = data.long().transpose(1, 0).contiguous()
            # shift by one character: the model predicts the next token at every position
            input, target = data[:-1, :], data[1:, :]
            input, target = input.to(device), target.to(device)
            optimizer.zero_grad()
            output, _ = model(input)
            loss = criterion(output, target.view(-1))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # average the loss over the number of batches, not the batch size
        print("epoch_%d_loss:%0.4f" % (epoch, epoch_loss / len(dataloader)))
        if epoch % conf.save_every == 0:
            # sample a few poems and save a checkpoint
            fout = open('%s/p%d' % (conf.out_path, epoch), 'w', encoding='utf-8')
            for word in list('春江花月夜'):
                gen_poetry = generate(model, word, ix2word, word2ix)
                fout.write("".join(gen_poetry) + "\n\n")
            fout.close()
            torch.save(model.state_dict(), "%s_%d.pth" % (conf.model_prefix, epoch))


if __name__ == '__main__':
    train()
```
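The key step in the loop is the one-character shift between input and target (teacher forcing): at every position the model is trained to predict the following character. A toy illustration of the same slicing, with made-up index values:

```python
import torch

# pretend this is one encoded poem: <START> 床 前 明 月 光 <EOP>
poem = torch.tensor([[0, 15, 27, 33, 41, 52, 1]]).transpose(1, 0)  # (seq_len, batch=1)

input, target = poem[:-1, :], poem[1:, :]
# input : <START> 床 前 明 月 光   -- what the model sees
# target: 床 前 明 月 光 <EOP>     -- what it must predict at each step
print(input.view(-1).tolist())   # [0, 15, 27, 33, 41, 52]
print(target.view(-1).tolist())  # [15, 27, 33, 41, 52, 1]
```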
Final results:

春雨,君王背日暮花枝。桂花飄雨裛芙蓉,花蕚垂紅綰芙蓉。上天高峨落不歸,中有一枝春未老。一枝香蘂紅妝結,春風吹花飄落萼。今朝今日凌風沙,今日還家花落早。東風吹落柳條生,柳色參差春水東。昨日風煙花滿樹,今日東風正如萍。杏園春色不自勝,青春忽倒春風來。春風不哢花枝落,況復春風花滿枝。江上春未央,春光照四面。一日一千里,一朝一瞬息。不如塌然云,不見巢下樹。一身一何讬,萬事皆有敵。君子不敢橫,君心若為役。嗚呼兩鬢苦,又如寒玉翦。不知何代費,所以心不殞。一朝得之愚,所以心所施。我亦我未領,我來亦未歸。始知與君子,不覺身不饑。彭澤有余事,吾君何所為。何以為我人,於今有耆誰。
花間一人家,十五日中見。一朝出門門,不見君子諾。車騎徒自媒,朱紱不能競。拜軍拜車騎,倏忽嫖姚羌。既無征鎮憤,慷慨望鄉國。一朝辭虜府,暮宿在薊壘。君子儻封侯,今人在咸朔。英英在其間,日昃不敢作。云山互相見,魏闕空懷戚。何必在沛人,裴回眇眇。所念無窮,斯人不怠。。
月白風來吹,君心不可攀。從來一字內,不覺一朝閑。未達身難棄,衰容事不閑。不憂譏孺子,不覺老農閑。寢食能供給,閑橙媿漉肱。酒闌湘口臥,窗拔峽添燈。靜譚畬茶駭,遙聞夜笛閑。蘆洲多雨霽,石火帶霜蒸。釀酒眠新熟,扁舟醉自閑。夜漁疎竹塢,春水釣漁關。石筍穿云燒,江花帶筍斑。此時多好客,不敢問山僧。
夜夜拍人笑,春風弄酒絲。花開桃李嶺,花落洞庭春。酒思同君醉,詩成是襪塵。自憐心已矣,何事夢何如。擯世才難易,傷心鏡不如。臉如銀燭薄,色映玉樓嚬。繡戶雕筵軟,鴛鴦拂枕春。相逢期洛下,夢想憶揚秦。玉匣調金鼎,金盤染髻巾。鷰人曾有什,山寺不相親。鶴毳應相毒,蠅蚊爽有真。空余襟袖下,不覺世間人。
Reference blog: https://blog.csdn.net/jiangpeng59/article/details/81003058
Reposted from: https://www.cnblogs.com/tangweijqxx/p/10608997.html