07 - CBAM Block Attention Mechanism
This post implements CBAM (Convolutional Block Attention Module), an attention mechanism for CNNs, in TensorFlow 1.x, and checks the trainable variables the block creates.
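For reference, the block follows the formulation of the CBAM paper (Woo et al., ECCV 2018). With input feature map F, sigmoid σ, ⊗ denoting broadcast element-wise multiplication, and f^{7×7} a 7×7 convolution, channel attention is applied first and spatial attention second:

M_c(F) = σ(MLP(AvgPool(F)) + MLP(MaxPool(F))),    F' = M_c(F) ⊗ F
M_s(F') = σ(f^{7×7}([AvgPool(F'); MaxPool(F')])),  F'' = M_s(F') ⊗ F'

The implementation below realizes the shared MLP as two 1×1 convolutions with a reduction ratio r.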
import tensorflow as tf

"""
A CBAM (Convolutional Block Attention Module) implementation for CNNs.
"""


def cbam(inputs, reduction=8):
    """
    For variable reuse we use tf.AUTO_REUSE, so the two 1x1 convolutions
    ('cbam1', 'cbam2') are shared between the average- and max-pooled branches.
    :param inputs: input tensor of shape [N, H, W, C]
    :param reduction: channel reduction ratio of the bottleneck MLP
    :return: attention-refined tensor of shape [N, H, W, C]
    """
    with tf.variable_scope('cbam', reuse=tf.AUTO_REUSE):
        _, height, width, channels = inputs.get_shape()

        # 1. Channel attention
        x_mean = tf.reduce_mean(inputs, axis=[1, 2], keepdims=True)  # [N, 1, 1, C]
        x_mean = tf.layers.conv2d(x_mean, channels // reduction, kernel_size=1,
                                  activation=tf.nn.relu, name='cbam1')  # [N, 1, 1, C/r]
        x_mean = tf.layers.conv2d(x_mean, channels, kernel_size=1,
                                  activation=None, name='cbam2')  # [N, 1, 1, C]

        x_max = tf.reduce_max(inputs, axis=[1, 2], keepdims=True)  # [N, 1, 1, C]
        x_max = tf.layers.conv2d(x_max, channels // reduction, kernel_size=1,
                                 activation=tf.nn.relu, name='cbam1')  # [N, 1, 1, C/r]
        x_max = tf.layers.conv2d(x_max, channels, kernel_size=1,
                                 activation=None, name='cbam2')  # [N, 1, 1, C]

        x = tf.add(x_mean, x_max)
        x = tf.nn.sigmoid(x)  # [N, 1, 1, C]
        # Apply the channel attention weights to the input
        x = tf.multiply(inputs, x)  # [N, H, W, C]

        # 2. Spatial attention
        y_mean = tf.reduce_mean(x, axis=[3], keepdims=True)  # [N, H, W, 1]
        y_max = tf.reduce_max(x, axis=[3], keepdims=True)  # [N, H, W, 1]
        y = tf.concat([y_mean, y_max], axis=-1)  # [N, H, W, 2]
        y = tf.layers.conv2d(y, filters=1, kernel_size=7, padding='same',
                             activation=tf.nn.sigmoid)  # [N, H, W, 1]
        y = tf.multiply(x, y)  # [N, H, W, C]
        return y


def test():
    with tf.Graph().as_default():
        data = tf.ones(shape=[64, 32, 32, 128], dtype=tf.float32)
        cbam_out = cbam(data, reduction=8)
        vars_list = tf.trainable_variables()
        print(cbam_out)
        print(len(vars_list), '\n', vars_list)


if __name__ == '__main__':
    test()
Running the test script produces:

D:\Anaconda\python.exe D:/AI20/HJZ/04-深度學習/3-CNN/20191215__AI20_CNN/09_CBAM_block.py
Tensor("cbam/Mul_1:0", shape=(64, 32, 32, 128), dtype=float32)
6 
 [<tf.Variable 'cbam/cbam1/kernel:0' shape=(1, 1, 128, 16) dtype=float32_ref>, <tf.Variable 'cbam/cbam1/bias:0' shape=(16,) dtype=float32_ref>, <tf.Variable 'cbam/cbam2/kernel:0' shape=(1, 1, 16, 128) dtype=float32_ref>, <tf.Variable 'cbam/cbam2/bias:0' shape=(128,) dtype=float32_ref>, <tf.Variable 'cbam/conv2d/kernel:0' shape=(7, 7, 2, 1) dtype=float32_ref>, <tf.Variable 'cbam/conv2d/bias:0' shape=(1,) dtype=float32_ref>]

Process finished with exit code 0

Note that only six trainable variables are created: tf.AUTO_REUSE shares the two 1×1 convolutions ('cbam1', 'cbam2') between the average- and max-pooled channel branches, and the spatial branch adds a single 7×7 convolution (kernel plus bias).
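The code above targets TensorFlow 1.x; tf.layers and tf.variable_scope were removed in 2.x. As a point of reference, here is a minimal sketch (not part of the original script) of the same block as a Keras layer under TensorFlow 2.x, where the weight sharing is expressed by reusing the same layer objects instead of tf.AUTO_REUSE:

class CBAM(tf.keras.layers.Layer):
    def __init__(self, channels, reduction=8, **kwargs):
        super().__init__(**kwargs)
        # Shared bottleneck MLP (as 1x1 convs) for the channel branch
        self.fc1 = tf.keras.layers.Conv2D(channels // reduction, 1, activation='relu')
        self.fc2 = tf.keras.layers.Conv2D(channels, 1)
        # 7x7 convolution for the spatial branch
        self.spatial = tf.keras.layers.Conv2D(1, 7, padding='same', activation='sigmoid')

    def call(self, x):
        # Channel attention: shared MLP over average- and max-pooled descriptors
        avg = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        mx = tf.reduce_max(x, axis=[1, 2], keepdims=True)
        x = x * tf.nn.sigmoid(self.fc2(self.fc1(avg)) + self.fc2(self.fc1(mx)))
        # Spatial attention: 7x7 conv over concatenated channel-wise avg/max maps
        y = tf.concat([tf.reduce_mean(x, axis=3, keepdims=True),
                       tf.reduce_max(x, axis=3, keepdims=True)], axis=-1)
        return x * self.spatial(y)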
Summary

CBAM refines a feature map in two sequential steps: channel attention (a shared two-layer 1×1-conv MLP over globally average- and max-pooled descriptors, summed and passed through a sigmoid), followed by spatial attention (a 7×7 convolution over the concatenated channel-wise average and max maps). Because the block preserves the input shape, it can be inserted after any convolutional layer at low cost.