ResNet详述
簡介
CV常見的卷積神經網絡中,和曾今的VGG一樣,ResNet的提出及其思想對卷積神經網絡的發展產生了巨大的影響。由于最近的課題需要回顧常見的里程碑式的CNN結構,以及分析這些結構的劃時代意義,所以這里簡單介紹一下ResNet。本項目著重實現使用Keras搭建ResNet的網絡結構,同時,利用其在數據集上進行效果評測。
-
論文標題
Deep residual learning for image recognition
-
論文地址
https://arxiv.org/abs/1512.03385
-
論文源碼
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py(PyTorch實現)
網絡說明
設計背景
ResNet(Residual Neural Network,殘差神經網絡)由Kaiming He(何愷明大神)等人在2015年的提出。他們通過特殊的網絡結構設計,訓練出了一個152層的深度神經網絡,并在ImageNet比賽分類任務上獲得了冠軍(top-5錯誤率3.57%)。
ResNet的提出第一次有事實地打破了深度網絡無法訓練的難題,其模型的不僅在多個數據集上準確率得到提高,而且參數量還比VGG少(其實縱觀卷積神經網絡這幾年的發展都是網絡越來越深、越來越輕量)。
自此之后,很多神經網絡的設計都借鑒了ResNet的思想,如Google InceptionV4,DenseNet等,其思想使得卷積模型的性能進一步提高。
設計思路
以往問題
- 實驗結果表明,層數的增加會提高網絡的學習效果(理想情況下),但是,單純增加網絡深度,神經網絡的學習會變得十分困難,因此反而不能得到預期效果。(論文中提出,這個原理很容易理解,隨著深度增加,以鏈式法則為基礎的反向傳播會出現難以傳播的問題,這引出很多著名的訓練問題,如梯度消失。)
- 通常認為神經網絡的深度對其性能影響較大,越深的網絡往往能得到更好的性能(網絡越深效果越好是不合理的觀念),但是隨著網絡的加深,容易出現Degradation問題,即準確率上升然后飽和繼而下降的現象,導致訓練困難,這個問題通常叫做梯度退化(gradient degradation)。注意,這并非常見的過擬合(overfit)問題,因為無論訓練集還是驗證集loss都會加大。
解決方法
- 為了解決上述的問題,ResNet提出了一種殘差模塊(residual block)用于解決這個問題。既然網絡加深到一定的程度會出現準確率飽和和準確率下降問題,那么,不妨在特征傳遞的過程中,讓后續網絡層的傳遞媒介影響降低,使用全等映射將輸入直接傳遞給輸出,保證網絡的性能至少不會下降。
- 在上面的殘差模塊中,提供了兩條道路,一條是經過residual映射得到F(x)F(x)F(x),其計算式可以理解為F(x)=relu(w2?(relu(w1?x)))F(x)=r e l u\left(w_{2} *\left(r e l u\left(w_{1} * x\right)\right)\right)F(x)=relu(w2??(relu(w1??x))),另一條則是直接傳遞xxx本身。將這兩個結果合并之后激活傳入下一個模塊,即輸出H(x)=F(x)+xH(x)=F(x)+xH(x)=F(x)+x(這個操作沒有加大網絡運算的復雜度,并且常用的深度學習框架很容易執行)。當網絡訓練已經達到最優,網絡的后續訓練將會限制residual網絡的映射,當residual被限制為0時,就只剩下全等映射的x,網絡也不會因為深度的加深造成準確率下降了。(這樣,通過限制,下一個模塊得到的就是本模塊的輸入xxx,緩解深度難以傳遞的狀況。)
ResNet的核心,擬合殘差函數F=H(x)?g(x)=H(x)?xF=H(x)-g(x)=H(x)-xF=H(x)?g(x)=H(x)?x(選擇g(x)=xg(x)=xg(x)=x是因為此時的g(x)g(x)g(x)效果最好)。其中xxx是由全等映射傳遞過去的,而F(x)F(x)F(x)是由residual映射傳遞的,網絡的訓練將對F(x)F(x)F(x)部分residual網絡的權重進行優化更新,學習的目標不再是完整的輸出H(x)H(x)H(x),而變成了輸出與輸入的差別H(x)?xH(x)-xH(x)?x,F(x)F(x)F(x)也就是殘差。
在整個ResNet中廣泛使用這種殘差模塊,該模塊使用了兩個分支的方法,其中一個分支直接將輸入傳遞到下一層,使得原始信息的更多的保留,解決了深度網絡信息丟失和信息損耗的問題,這種結構被稱之為shortcut或者skip connections。
由于使用了shortcut,原來需要學習逼近的恒等映射H(x)H(x)H(x)變成逼近F(x)=H(x)?xF(x) = H(x) - xF(x)=H(x)?x這個函數。論文作者認為這兩種表達的效果是一致的,但是優化的難度卻不同,F(x)F(x)F(x)的優化比H(x)H(x)H(x)的優化簡單很多(該想法源于圖像處理中的殘差向量編碼)。
殘差網絡
shortcuts
在之前的圖片上,恒等映射xxx使用的為identity shortcuts(同維度元素級相加);事實上,論文還研究了一種project shortcuts(y=F(x,Wi)+Wsxy=F(x,{W_i})+W_sxy=F(x,Wi?)+Ws?x),主要包含以下三種情況,且通過實驗效果是逐漸變好的。
- 維度無變化則直接相連,維度增加的連接通過補零填充后再連接。shortcuts是恒等的,這個連接并不會帶來新的參數。
- 維度無變化則直接相連,維度增加的連接通過投影連接,投影連接會增加參數。
- 所有連接均采用投影連接。
bottleneck
當研究50層以上的深層網絡時,使用了上圖右邊所示的Bottleneck網絡結構,該結構第一層使用1*1的卷積層來降維,最后一層使用1*1的卷積層來進行升維,從而保持與原來輸入同維以便于恒等映射。
網絡結構
網絡結構圖
常見的ResNet為50層的ResNet50以及ResNet101和當年比賽使用的ResNet152。 結構上,先通過一個普通的(7,7)卷積層對輸入圖片進行特征提取同時因為步長為2尺寸減半;隨即,通過(3,3)最大池化層進一步縮小feature map的尺寸;隨后,送入各個殘差模塊(論文中命名為conv2_x的網絡為一個block,同一個block有多個殘差模塊連接)。最后,將特征圖送入全局池化層進行規整,再使用softmax激活進行分類,得到概率分布向量。(這里fc層輸出1000類是因為ImageNet有1000個類別)
雖然,ResNet比起VGG19這樣的網絡深很多,但是運算量是遠少于VGG19等VGGNet的。
網絡對比
左側為經典VGG19結構圖,中間為類VGG19的34層普通網絡圖,右側為帶恒等映射的34層ResNet網絡圖。其中,黑色實線代表同一維度下(卷積核數目相同)的恒等映射,虛線代表不同維度下(卷積核數目不同)的恒等映射。
訓練效果
左側為普通網絡,右側為殘差網絡,粗線代表訓練損失,細線代表驗證損失。顯然,普通網絡34層訓練損失和驗證損失均大于18層,殘差網絡不存在這個現象,這說明殘差網絡確實有效解決了梯度退化問題(這也是ResNet的初衷)。
代碼實現
實際使用各個深度學習框架已經封裝了ResNet的幾種主要網絡結構,使用很方便,不建議自己搭建(尤其對于ResNet152這樣很深的網絡)。
下面使用Keras構建ResNet34和ResNet50,前者使用identity block作為殘差模塊,后者使用bottleneck block作為殘差模塊,同時為了防止過擬合且輸出高斯分布,自定義了緊跟BN層的卷積層Conv2D_BN。
網絡構建對照結構表及結構說明即可,這是復現論文網絡結構的主要依據。
def Conv2D_BN(x, filters, kernel_size, strides=(1, 1), padding='same', name=None):if name:bn_name = name + '_bn'conv_name = name + '_conv'else:bn_name = Noneconv_name = Nonex = Conv2D(filters, kernel_size, strides=strides, padding=padding, activation='relu', name=conv_name)(x)x = BatchNormalization(name=bn_name)(x)return xdef identity_block(input_tensor, filters, kernel_size, strides=(1, 1), is_conv_shortcuts=False):""":param input_tensor::param filters::param kernel_size::param strides::param is_conv_shortcuts: 直接連接或者投影連接:return:"""x = Conv2D_BN(input_tensor, filters, kernel_size, strides=strides, padding='same')x = Conv2D_BN(x, filters, kernel_size, padding='same')if is_conv_shortcuts:shortcut = Conv2D_BN(input_tensor, filters, kernel_size, strides=strides, padding='same')x = add([x, shortcut])else:x = add([x, input_tensor])return xdef bottleneck_block(input_tensor, filters=(64, 64, 256), strides=(1, 1), is_conv_shortcuts=False):""":param input_tensor::param filters::param strides::param is_conv_shortcuts: 直接連接或者投影連接:return:"""filters_1, filters_2, filters_3 = filtersx = Conv2D_BN(input_tensor, filters=filters_1, kernel_size=(1, 1), strides=strides, padding='same')x = Conv2D_BN(x, filters=filters_2, kernel_size=(3, 3))x = Conv2D_BN(x, filters=filters_3, kernel_size=(1, 1))if is_conv_shortcuts:short_cut = Conv2D_BN(input_tensor, filters=filters_3, kernel_size=(1, 1), strides=strides)x = add([x, short_cut])else:x = add([x, input_tensor])return xdef ResNet34(input_shape=(224, 224, 3), n_classes=1000):""":param input_shape::param n_classes::return:"""input_layer = Input(shape=input_shape)x = ZeroPadding2D((3, 3))(input_layer)# block1x = Conv2D_BN(x, filters=64, kernel_size=(7, 7), strides=(2, 2), padding='valid')x = MaxPooling2D(pool_size=(3, 3), strides=2, padding='same')(x)# block2x = identity_block(x, filters=64, kernel_size=(3, 3))x = identity_block(x, filters=64, kernel_size=(3, 3))x = identity_block(x, filters=64, kernel_size=(3, 3))# block3x = identity_block(x, filters=128, kernel_size=(3, 3), strides=(2, 2), is_conv_shortcuts=True)x = identity_block(x, filters=128, kernel_size=(3, 3))x = identity_block(x, filters=128, kernel_size=(3, 3))x = identity_block(x, filters=128, kernel_size=(3, 3))# block4x = identity_block(x, filters=256, kernel_size=(3, 3), strides=(2, 2), is_conv_shortcuts=True)x = identity_block(x, filters=256, kernel_size=(3, 3))x = identity_block(x, filters=256, kernel_size=(3, 3))x = identity_block(x, filters=256, kernel_size=(3, 3))x = identity_block(x, filters=256, kernel_size=(3, 3))x = identity_block(x, filters=256, kernel_size=(3, 3))# block5x = identity_block(x, filters=512, kernel_size=(3, 3), strides=(2, 2), is_conv_shortcuts=True)x = identity_block(x, filters=512, kernel_size=(3, 3))x = identity_block(x, filters=512, kernel_size=(3, 3))x = AveragePooling2D(pool_size=(7, 7))(x)x = Flatten()(x)x = Dense(n_classes, activation='softmax')(x)model = Model(inputs=input_layer, outputs=x)return modeldef ResNet50(input_shape=(224, 224, 3), n_classes=1000):""":param input_shape::param n_classes::return:"""input_layer = Input(shape=input_shape)x = ZeroPadding2D((3, 3))(input_layer)# block1x = Conv2D_BN(x, filters=64, kernel_size=(7, 7), strides=(2, 2), padding='valid')x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x)# block2x = bottleneck_block(x, filters=(64, 64, 256), strides=(1, 1), is_conv_shortcuts=True)x = bottleneck_block(x, filters=(64, 64, 256))x = bottleneck_block(x, filters=(64, 64, 256))# block3x = bottleneck_block(x, filters=(128, 128, 512), strides=(2, 2), is_conv_shortcuts=True)x = bottleneck_block(x, filters=(128, 128, 512))x = bottleneck_block(x, filters=(128, 128, 512))x = bottleneck_block(x, filters=(128, 128, 512))# block4x = bottleneck_block(x, filters=(256, 256, 1024), strides=(2, 2), is_conv_shortcuts=True)x = bottleneck_block(x, filters=(256, 256, 1024))x = bottleneck_block(x, filters=(256, 256, 1024))x = bottleneck_block(x, filters=(256, 256, 1024))x = bottleneck_block(x, filters=(256, 256, 1024))x = bottleneck_block(x, filters=(256, 256, 1024))# block5x = bottleneck_block(x, filters=(512, 512, 2048), strides=(2, 2), is_conv_shortcuts=True)x = bottleneck_block(x, filters=(512, 512, 2048))x = bottleneck_block(x, filters=(512, 512, 2048))x = AveragePooling2D(pool_size=(7, 7))(x)x = Flatten()(x)x = Dense(n_classes, activation='softmax')(x)model = Model(inputs=input_layer, outputs=x)return model數據集使用Caltech101數據集,比較性能,不進行數據增廣(注意刪除干擾項)。Batch大小指定為32,使用BN訓練技巧,二次封裝Conv2D。 損失函數使用經典分類的交叉熵損失函數,優化函數使用Adam,激活函數使用Relu。(這都是比較流行的選擇)
具體結果見文末Github倉庫根目錄notebook文件。
比較于我之前介紹的VGGNet,顯然,同一個數據集上ResNet訓練速度快了很多,在同樣輪次的訓練下,驗證集到達的準確率比VGGNet高很多,這正是驗證了比起VGGNet,ResNet計算量更少(表現為訓練速度快),同樣ResNet網絡模型效果好(表現為驗證集準確率高)。
補充說明
ResNet最核心的就是利用residual block通過轉換映射目標從而解決梯度退化問題,這才是ResNet的核心,至于具體的網絡結構,不同的場景可能需求的結構不同,應當明白的是其精髓。本項目源碼開源于我的Github,歡迎Star或者Fork。
總結