注意力机制 SE-Net 原理与 TensorFlow2.0 实现
文章目錄
- 🍵 介紹
- 🍛 SE 模塊
- 🥡 SE 模塊應用分析
- 🍘 SE 模型效果對比
- 🥙 SE 模塊代碼實現
- 🍜 SE 模塊插入到 DenseNet 代碼實現
🍵 介紹
SENet 是 ImageNet 2017(ImageNet 收官賽)的冠軍模型,是由WMW團隊發(fā)布。具有復雜度低,參數少和計算量小的優(yōu)點。且SENet 思路很簡單,很容易擴展到已有網絡結構如 Inception 和 ResNet 中。
🍛 SE 模塊
已經有很多工作在空間維度上來提升網絡的性能,如 Inception 等,而 SENet 將關注點放在了特征通道之間的關系上。其具體策略為:通過學習的方式來自動獲取到每個特征通道的重要程度,然后依照這個重要程度去提升有用的特征并抑制對當前任務用處不大的特征,這又叫做“特征重標定”策略。具體的 SE 模塊如下圖所示:
給定一個輸入 xxx ,其特征通道數為 c1c_1c1?,通過一系列卷積等一般變換 FtrF_{tr}Ftr? 后得到一個特征通道數為 c2c_2c2? 的特征。與傳統的卷積神經網絡不同,我們需要通過下面三個操作來重標定前面得到的特征。
首先是 Squeeze 操作,我們順著空間維度來進行特征壓縮,將一個通道中整個空間特征編碼為一個全局特征,這個實數某種程度上具有全局的感受野,并且輸出的通道數和輸入的特征通道數相等,例如將形狀為 (1, 32, 32, 10) 的 feature map 壓縮成 (1, 1, 1, 10)。此操作通常采用采用 global average pooling 來實現。
得到了全局描述特征后,我們進行 Excitation 操作來抓取特征通道之間的關系,它是一個類似于循環(huán)神經網絡中門的機制:
這里采用包含兩個全連接層的 bottleneck 結構,即中間小兩頭大的結構:其中第一個全連接層起到降維的作用,并通過 ReLU 激活,第二個全連接層用來將其恢復至原始的維度。進行 Excitation 操作的最終目的是為每個特征通道生成權重,即學習到的各個通道的激活值(sigmoid 激活,值在 0~1 之間)。
最后是一個 Scale 的操作,我們將 Excitation 的輸出的權重看做是經過特征選擇后的每個特征通道的重要性,然后通過乘法逐通道加權到先前的特征上,完成在通道維度上的對原始特征的重標定,從而使得模型對各個通道的特征更有辨別能力,這類似于attention機制。
🥡 SE 模塊應用分析
SE模塊的靈活性在于它可以直接應用現有的網絡結構中。以 Inception 和 ResNet 為例,我們只需要在 Inception 模塊或 Residual 模塊后添加一個 SE 模塊即可。具體如下圖所示:
上圖分別是將 SE 模塊嵌入到 Inception 結構與 ResNet 中的示例,方框旁邊的維度信息代表該層的輸出,rrr 表示 Excitation 操作中的降維系數。
🍘 SE 模型效果對比
SE 模塊很容易嵌入到其它網絡中,為了驗證 SE 模塊的作用,在其它流行網絡如 ResNet 和 Inception 中引入 SE 模塊,測試其在 ImageNet 上的效果,如下表所示:
首先看一下網絡的深度對 SE 的影響。上表分別展示了 ResNet-50、ResNet-101、ResNet-152 和嵌入 SE 模型的結果。第一欄 Original 是原作者實現的結果,為了進行公平的比較,重新進行了實驗得到 Our re-implementation 的結果。最后一欄 SE-module 是指嵌入了 SE 模塊的結果,它的訓練參數和第二欄 Our re-implementation 一致。括號中的紅色數值是指相對于 Our re-implementation 的精度提升的幅值。
從上表可以看出,SE-ResNets 在各種深度上都遠遠超過了其對應的沒有SE的結構版本的精度,這說明無論網絡的深度如何,SE模塊都能夠給網絡帶來性能上的增益。值得一提的是,SE-ResNet-50 可以達到和ResNet-101 一樣的精度;更甚,SE-ResNet-101 遠遠地超過了更深的ResNet-152。
上圖展示了ResNet-50 和 ResNet-152 以及它們對應的嵌入SE模塊的網絡在ImageNet上的訓練過程,可以明顯地看出加入了SE模塊的網絡收斂到更低的錯誤率上。
🥙 SE 模塊代碼實現
import tensorflow as tfclass Squeeze_excitation_layer(tf.keras.Model):def __init__(self, filter_sq):# filter_sq 是 Excitation 中第一個卷積過程中卷積核的個數super().__init__()self.filter_sq = filter_sqself.avepool = tf.keras.layers.GlobalAveragePooling2D()self.dense = tf.keras.layers.Dense(filter_sq)self.relu = tf.keras.layers.Activation('relu')self.sigmoid = tf.keras.layers.Activation('sigmoid')def call(self, inputs):squeeze = self.avepool(inputs)excitation = self.dense(squeeze)excitation = self.relu(excitation)excitation = tf.keras.layers.Dense(inputs.shape[-1])(excitation)excitation = self.sigmoid(excitation)excitation = tf.keras.layers.Reshape((1, 1, inputs.shape[-1]))(excitation)scale = inputs * excitationreturn scaleSE = Squeeze_excitation_layer(16) inputs = np.zeros((1, 32, 32, 32), dtype=np.float32) SE(inputs).shape TensorShape([1, 32, 32, 32])🍜 SE 模塊插入到 DenseNet 代碼實現
from tensorflow.keras.models import Model from tensorflow.keras import layers from tensorflow.keras import backend def dense_block(x, blocks, name):for i in range(blocks):x = conv_block(x, 32, name=name + '_block' + str(i + 1))return x def conv_block(x, growth_rate, name):bn_axis = 3 x1 = layers.BatchNormalization(axis=bn_axis,epsilon=1.001e-5,name=name + '_0_bn')(x)x1 = layers.Activation('relu', name=name + '_0_relu')(x1)x1 = layers.Conv2D(4 * growth_rate, 1,use_bias=False,name=name + '_1_conv')(x1)x1 = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,name=name + '_1_bn')(x1)x1 = layers.Activation('relu', name=name + '_1_relu')(x1)x1 = layers.Conv2D(growth_rate, 3,padding='same',use_bias=False,name=name + '_2_conv')(x1)x = layers.Concatenate(axis=bn_axis, name=name + '_concat')([x, x1])return xdef transition_block(x, reduction, name):bn_axis = 3x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,name=name + '_bn')(x)x = layers.Activation('relu', name=name + '_relu')(x)x = layers.Conv2D(int(backend.int_shape(x)[bn_axis] * reduction), 1,use_bias=False,name=name + '_conv')(x)x = layers.AveragePooling2D(2, strides=2, name=name + '_pool')(x)return xdef DenseNet(blocks, input_shape=None, classes=1000, **kwargs):img_input = layers.Input(shape=input_shape)bn_axis = 3# 224,224,3 -> 112,112,64x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)))(img_input)x = layers.Conv2D(64, 7, strides=2, use_bias=False, name='conv1/conv')(x)x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='conv1/bn')(x)x = layers.Activation('relu', name='conv1/relu')(x)# 112,112,64 -> 56,56,64x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)))(x)x = layers.MaxPooling2D(3, strides=2, name='pool1')(x)# 56,56,64 -> 56,56,64+32*block[0]# Densenet121 56,56,64 -> 56,56,64+32*6 == 56,56,256x = dense_block(x, blocks[0], name='conv2')# 56,56,64+32*block[0] -> 28,28,32+16*block[0]# Densenet121 56,56,256 -> 28,28,32+16*6 == 28,28,128x = transition_block(x, 0.5, name='pool2')# 28,28,32+16*block[0] -> 28,28,32+16*block[0]+32*block[1]# Densenet121 28,28,128 -> 28,28,128+32*12 == 28,28,512x = dense_block(x, blocks[1], name='conv3')# Densenet121 28,28,512 -> 14,14,256x = transition_block(x, 0.5, name='pool3')# Densenet121 14,14,256 -> 14,14,256+32*block[2] == 14,14,1024x = dense_block(x, blocks[2], name='conv4')# Densenet121 14,14,1024 -> 7,7,512x = transition_block(x, 0.5, name='pool4')# Densenet121 7,7,512 -> 7,7,256+32*block[3] == 7,7,1024x = dense_block(x, blocks[3], name='conv5')# 加SE注意力機制x = Squeeze_excitation_layer(16)(x)x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='bn')(x)x = layers.Activation('relu', name='relu')(x)x = layers.GlobalAveragePooling2D(name='avg_pool')(x)x = layers.Dense(classes, activation='softmax', name='fc1000')(x)inputs = img_inputif blocks == [6, 12, 24, 16]:model = Model(inputs, x, name='densenet121')elif blocks == [6, 12, 32, 32]:model = Model(inputs, x, name='densenet169')elif blocks == [6, 12, 48, 32]:model = Model(inputs, x, name='densenet201')else:model = Model(inputs, x, name='densenet')return modeldef DenseNet121(input_shape=[224,224,3], classes=3, **kwargs):return DenseNet([6, 12, 24, 16], input_shape, classes, **kwargs)def DenseNet169(input_shape=[224,224,3], classes=3, **kwargs):return DenseNet([6, 12, 32, 32], input_shape, classes, **kwargs)def DenseNet201(input_shape=[224,224,3], classes=3, **kwargs):return DenseNet([6, 12, 48, 32], input_shape, classes, **kwargs)參考文章:
- ImageNet 2017冠軍模型SE-Net詳解
總結
以上是生活随笔為你收集整理的注意力机制 SE-Net 原理与 TensorFlow2.0 实现的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: chapter6
- 下一篇: 发票OCR识别技术太屌了,哈哈哈哈