【论文笔记】Multi-Content Complementation Network for Salient Object Detection in Optical RSI
論文?
?論文:Multi-Content Complementation Network for Salient Object Detection in Optical Remote Sensing Images
發(fā)表:?IEEE TGRS, vol. 60, pp. 1-13, 2022
地址:Multi-Content Complementation Network for Salient Object Detection in Optical Remote Sensing Images | IEEE Journals & Magazine | IEEE Xplore
https://arxiv.org/abs/2112.01932
代碼: https://github.com/mathlee/mccnetGitHub - MathLee/MCCNet: [TGRS2022] [MCCNet] Multi-Content Complementation Network for Salient Object Detection in Optical Remote Sensing Images
正文
動機
光學遙感圖像顯著性目標檢測(RSI-SOD),很具有挑戰(zhàn)性。現(xiàn)有的SOD方法多是自然場景(NSI),但兩者間存在較大差異。(獲取方式差異很大,使得兩種圖像差異很大,NSI使用手機、相機等設配拍攝,RSI使用衛(wèi)星或航空器拍攝)。直接將NSI-SOD的方式用于RSI-SOD可能不合適,以前的工作借鑒NSI-SOD和結合RSI的特點提出解決方案證明是可行的,本文結合前人的工作(前景特征、邊緣特征、背景特征單獨使用都是有效的,BCE損失、IoU損失、度量感知F-m損失也能work),提出自己的方法。
做法
- 提出多內(nèi)容互補網(wǎng)絡( Multi-Content Complementation Network,MCCNet)來探索RSI-SOD多內(nèi)容的互補性。在多尺度特征上使用MCCM模塊,利用前景特征、邊緣特征、背景特征和全局圖像級特征間的內(nèi)容互補性,通過注意力機制來突出RSI特征在不同尺度上的顯著區(qū)域。
- 結合三種損失構成綜合損失,并加入邊緣損失,共同監(jiān)督模型的訓練。
網(wǎng)絡架構
?MCCNet由三個部分組成:編碼器網(wǎng)絡、5個MCCM組件、解碼器網(wǎng)絡。
- 編碼器網(wǎng)絡,用vgg16提取基本特征;
- 5個MCCM組件,對前景、邊緣、背景和全局圖像特征間的互補信息進行建模;
- 解碼器網(wǎng)絡,逐級上采樣推斷出顯著目標。
訓練時對5層進行監(jiān)督,采用三種損失。?同時利用邊緣損失監(jiān)督MCCM中的產(chǎn)生的邊緣。
Multi-Content Complementation Module,MCCM
?設計動機: 前景特征、背景特征、邊緣特征都有助于顯著性檢測,于是提出多內(nèi)容互補模塊(MCCM)結合它們,并添加全局信息。
輸入:編碼器提取的特征;輸出:多內(nèi)容互補特征。 中間過程:產(chǎn)生4種不同類型特征,并進行聚合。(看圖或代碼即可,后面附有代碼)
前景和邊緣特征,都與顯著區(qū)域相關,相輔相成,求和聚集。 背景特征,由前者取反得到,關注到非顯著區(qū)域。 前面三者包含了局部細節(jié)。 全局信息,丟失細節(jié)信息,捕捉特征整體基調(diào)。
4種特征聚合方式:拼接后卷積,再相加。?
?MCCM 特征可視化
a^3_fe表示前景+邊緣特征;a^3_b表示背景特征;a^3_g表示整體基調(diào)。?損失函數(shù)
實驗
?23個對比方法在兩個數(shù)據(jù)集上的實驗
23個對比方法在兩個數(shù)據(jù)集上的實驗。?
?不同場景不同方法可視化效果對比
消融實驗
驗證MCCM中不同特征都能work,相互間存在互補性?
消融的MCCM具體結構?
?MCCM中殘差路徑的效果提升
?使用不同損失組合的性能比較
關鍵代碼 MCCM
# https://github.com/MathLee/MCCNet/blob/main/model/MCCNet_models.pyimport torch import torch.nn as nn import torch.nn.functional as F import numpy as np import os# 定義一個卷積操作:卷積+BN+ReLU class BasicConv2d(nn.Module):def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):super(BasicConv2d, self).__init__()self.conv = nn.Conv2d(in_planes, out_planes,kernel_size=kernel_size, stride=stride,padding=padding, dilation=dilation, bias=False)self.bn = nn.BatchNorm2d(out_planes)self.relu = nn.ReLU(inplace=True)def forward(self, x):x = self.conv(x)x = self.bn(x)x = self.relu(x)return x# 通道注意力(SE) class ChannelAttention(nn.Module):def __init__(self, in_planes, ratio=16):super(ChannelAttention, self).__init__()self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False)self.relu1 = nn.ReLU()self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))out = max_outreturn self.sigmoid(out)# 空間注意力 SA class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()assert kernel_size in (3, 7), 'kernel size must be 3 or 7'padding = 3 if kernel_size == 7 else 1self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):max_out, _ = torch.max(x, dim=1, keepdim=True)x = max_outx = self.conv1(x)return self.sigmoid(x)# 空間注意力,不帶sigmoid class SpatialAttention_no_s(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention_no_s, self).__init__()assert kernel_size in (3, 7), 'kernel size must be 3 or 7'padding = 3 if kernel_size == 7 else 1self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False)# self.sigmoid = nn.Sigmoid()def forward(self, x):max_out, _ = torch.max(x, dim=1, keepdim=True)x = max_outx = self.conv1(x)return x# Multi-Content Complementation Module,MCCM class MCCM(nn.Module):def __init__(self, cur_channel):super(MCCM, self).__init__()self.relu = nn.ReLU(True)self.ca = ChannelAttention(cur_channel)self.sa_fg = SpatialAttention_no_s()self.sa_edge = SpatialAttention_no_s()self.sigmoid = nn.Sigmoid()self.FE_conv = BasicConv2d(cur_channel, cur_channel, 3, padding=1)self.BG_conv = BasicConv2d(cur_channel, cur_channel, 3, padding=1)self.global_avg_pool = nn.AdaptiveAvgPool2d(1)self.conv1 = BasicConv2d(cur_channel, cur_channel, 1)self.sa_ic = SpatialAttention()self.IC_conv = BasicConv2d(cur_channel, cur_channel, 3, padding=1)self.FE_B_I_conv = BasicConv2d(3 * cur_channel, cur_channel, 3, padding=1)def forward(self, x):x_ca = x.mul(self.ca(x))# Foreground attentionx_sa_fg = self.sa_fg(x_ca)# Edge attentionx_edge = self.sa_edge(x_ca)# Foreground and Edge (FE) featurex_fg_edge = self.FE_conv(x_ca.mul(self.sigmoid(x_sa_fg) + self.sigmoid(x_edge)))# Background featurex_bg = self.BG_conv(x_ca.mul(1 - self.sigmoid(x_sa_fg) - self.sigmoid(x_edge)))# Image-level contentin_size = x.shape[2:]x_gap = self.conv1(self.global_avg_pool(x))x_up = F.interpolate(x_gap, size=in_size, mode="bilinear", align_corners=True)x_ic = self.IC_conv(x.mul(self.sa_ic(x_up)))x_RE_B_I = self.FE_B_I_conv(torch.cat((x_fg_edge, x_bg, x_ic), 1))return (x + x_RE_B_I), x_edge總結
以上是生活随笔為你收集整理的【论文笔记】Multi-Content Complementation Network for Salient Object Detection in Optical RSI的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 完美解决织梦CMS加入lian666自动
- 下一篇: 小程序picker标题_微信小程序实现自