自动编码器
學習目標
- 目標
- 了解自動編碼器作用
- 說明自動編碼器的結構
- 應用
- 使用自動編碼器對Mnist手寫數字進行數據降噪處理
5.2.1 自動編碼器什么用
自編碼器的應用主要有兩個方面
- 數據去噪
?
- 進行可視化而降維
- 自編碼器可以學習到比PCA等技術更好的數據投影
?
5.2.1 什么是自動編碼器(Autoencoder)
5.2.1.1 定義
?
自動編碼器是一種數據的壓縮算法,一種使用神經網絡學習數據值編碼的無監督方式。
5.2.1.2 原理作用案例
搭建一個自動編碼器需要完成下面三樣工作:
- 搭建編碼器
- 搭建解碼器
- 設定一個損失函數,用以衡量由于壓縮而損失掉的信息。
- 編碼器和解碼器一般都是參數化的方程,并關于損失函數可導,通常情況是使用神經網絡。
?
5.2.1.3 類別
- 普通自編碼器
- 編解碼網絡使用全連接層
- 多層自編碼器
- 卷積自編碼器
- 編解碼器使用卷積結構
- 正則化自編碼器
- 降噪自編碼器
5.2.2 Keras快速搭建普通自編碼器-基于Mnist手寫數字
5.2.2.1 自編碼器效果
- 迭代50次效果
?
Train on 60000 samples, validate on 10000 samples
Epoch 1/50256/60000 [..............................] - ETA: 44s - loss: 0.69571280/60000 [..............................] - ETA: 11s - loss: 0.68672560/60000 [>.............................] - ETA: 6s - loss: 0.6699 3584/60000 [>.............................] - ETA: 5s - loss: 0.6493
...
...
...
55808/60000 [==========================>...] - ETA: 0s - loss: 0.0925
57088/60000 [===========================>..] - ETA: 0s - loss: 0.0925
58112/60000 [============================>.] - ETA: 0s - loss: 0.0925
59392/60000 [============================>.] - ETA: 0s - loss: 0.0925
60000/60000 [==============================] - 3s 47us/step - loss: 0.0925 - val_loss: 0.0914
5.2.2.2 流程
- 初始化自編碼器結構
- 訓練自編碼器
- 獲取數據
- 模型輸入輸出訓練
- 顯示自編碼前后效果對比
5.2.2.3 代碼編寫
導入所需包
from keras.layers import Input, Dense
from keras.models import Model
from keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt
- 1、初始化自編碼器結構
定義編碼器:輸出32個神經元,使用relu激活函數,(32這個值可以自己制定)
定義解碼器:輸出784個神經元,使用sigmoid函數,(784這個值是輸出與原圖片大小一致)
損失:
- 每個像素值的交叉熵損失(輸出為sigmoid值(0,1),輸入圖片要進行歸一化(0,1))
class AutoEncoder(object):"""自動編碼器"""def __init__(self):self.encoding_dim = 32self.decoding_dim = 784self.model = self.auto_encoder_model()def auto_encoder_model(self):"""初始化自動編碼器模型將編碼器和解碼器放在一起作為一個模型:return: auto_encoder"""input_img = Input(shape=(784,))encoder = Dense(self.encoding_dim, activation='relu')(input_img)decoder = Dense(self.decoding_dim, activation='sigmoid')(encoder)auto_encoder = Model(inputs=input_img, outputs=decoder)auto_encoder.compile(optimizer='adam', loss='binary_crossentropy')return auto_encoder
- 2、訓練流程
- 讀取Mnist數據,并進行歸一化處理以及形狀修改
- 模型進行fit訓練
- 指定迭代次數
- 指定每批次數據大小
- 是否打亂數據
- 驗證集合
def train(self):"""訓練自編碼器:param model: 編碼器結構:return:"""(x_train, _), (x_test, _) = mnist.load_data()# 進行歸一化x_train = x_train.astype('float32') / 255.x_test = x_test.astype('float32') / 255.# 進行形狀改變x_train = np.reshape(x_train, (len(x_train), np.prod(x_train.shape[1:])))x_test = np.reshape(x_test, (len(x_test), np.prod(x_test.shape[1:])))print(x_train.shape)print(x_test.shape)# 訓練self.model.fit(x_train, x_train,epochs=5,batch_size=256,shuffle=True,validation_data=(x_test, x_test))
- 3、顯示模型生成的圖片與原始圖片對比
- 導入matplotlib包
def display(self):"""顯示前后效果對比:return:"""(x_train, _), (x_test, _) = mnist.load_data()x_test = np.reshape(x_test, (len(x_test), np.prod(x_test.shape[1:])))decoded_imgs = self.model.predict(x_test)plt.figure(figsize=(20, 4))# 顯示5張結果n = 5for i in range(n):# 顯示編碼前結果ax = plt.subplot(2, n, i + 1)plt.imshow(x_test[i].reshape(28, 28))plt.gray()ax.get_xaxis().set_visible(False)ax.get_yaxis().set_visible(False)# 顯示編解碼后結果ax = plt.subplot(2, n, i + n + 1)plt.imshow(decoded_imgs[i].reshape(28, 28))plt.gray()ax.get_xaxis().set_visible(False)ax.get_yaxis().set_visible(False)plt.show()
5.2.3 基于Mnist手寫數字-深度自編碼器
- 將多個自編碼進行重疊
input_img = Input(shape=(784,))
encoded = Dense(128, activation='relu')(input_img)
encoded = Dense(64, activation='relu')(encoded)
encoded = Dense(32, activation='relu')(encoded)decoded = Dense(64, activation='relu')(encoded)
decoded = Dense(128, activation='relu')(decoded)
decoded = Dense(784, activation='sigmoid')(decoded)auto_encoder = Model(input=input_img, output=decoded)
auto_encoder.compile(optimizer='adam', loss='binary_crossentropy')
我們可以替換原來的編碼器進行測試
59392/60000 [============================>.] - ETA: 0s - loss: 0.0860
最后的損失會較之前同樣的epoch迭代好一些
5.2.4 基于Mnist手寫數字-卷積自編碼器
- 卷積編解碼結構設計
- 編碼器
- Conv2D(32, (3, 3), activation='relu', padding='same')
- MaxPooling2D((2, 2), padding='same')
- Conv2D(32, (3, 3), activation='relu', padding='same')
- MaxPooling2D((2, 2), padding='same')
- 輸出大小為:Tensor("max_pooling2d_2/MaxPool:0", shape=(?, 7, 7, 32), dtype=float32)
- 解碼器:反卷積過程
- Conv2D(32, (3, 3), activation='relu', padding='same')
- UpSampling2D((2, 2))
- Conv2D(32, (3, 3), activation='relu', padding='same')
- UpSampling2D((2, 2))
- Conv2D(1, (3, 3), activation='sigmoid', padding='same')
- 輸出大小:Tensor("conv2d_5/Sigmoid:0", shape=(?, 28, 28, 1), dtype=float32)
- 編碼器
input_img = Input(shape=(28, 28, 1))x = Conv2D(32, (3, 3), activation='relu', padding='same')(input_img)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
encoded = MaxPooling2D((2, 2), padding='same')(x)
print(encoded)x = Conv2D(32, (3, 3), activation='relu', padding='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)
print(decoded)auto_encoder = Model(input_img, decoded)
auto_encoder.compile(optimizer='adam', loss='binary_crossentropy')
由于修改了模型的輸入輸出數據形狀,所以在訓練的地方同樣也需要修改(顯示的時候數據輸入也要修改)
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))
x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))
5.2.4 基于Mnist手寫數字-降噪自編碼器
- 降噪自編碼器效果
?
- 過程
- 對原始數據添加噪音
- 隨機加上正態分布的噪音
- x_train + np.random.normal(loc=0.0, scale=1.0, size=x_train.shape)
# 添加噪音
x_train_noisy = x_train + np.random.normal(loc=0.0, scale=3.0, size=x_train.shape)
x_test_noisy = x_test + np.random.normal(loc=0.0, scale=3.0, size=x_test.shape)# 重新進行限制每個像素值的大小在0~1之間
x_train_noisy = np.clip(x_train_noisy, 0., 1.)
x_test_noisy = np.clip(x_test_noisy, 0., 1.)
在進行顯示的時候也要進行修改
# 獲取數據改變形狀,增加噪點數據
(x_train, _), (x_test, _) = mnist.load_data()
x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))
x_test_noisy = x_test + np.random.normal(loc=3.0, scale=10.0, size=x_test.shape)# 預測結果
decoded_imgs = self.model.predict(x_test_noisy)# 修改需要顯示的圖片變量
plt.imshow(x_test_noisy[i].reshape(28, 28))
5.2.5 總結
- 掌握自動編碼器的結構
- 掌握正則化自動編碼器結構作用
總結
- 上一篇: 生成对抗网络(GAN)
- 下一篇: CapsuleNet(了解)