DBNet详解
文章目錄
- 創新點
- 算法的整體架構
- 自適應閾值(Adaptive threshhold)
- 二值化
- 標準二值化
- 可微二值(differentiable Binarization)
- 直觀展示
- 可形變卷積(Deformable convolution)
- 標簽的生成
- PSENet標簽生成
- DBNet標簽生成
- 損失函數
- 后處理
- 代碼閱讀
- 數據預處理
- 入口
- AugmentDetectionData(數據增強類)
- RandomCropData(數據裁剪類)
- MakeICDARData(數據重新組織類)
- MakeSegDetectionData(生成概率圖和對應mask類)
- MakeBorderMap(生成閾值圖和對應Mask類)
- NormalizeImage
- FilterKeys
- 模型結構
- 骨干網絡和FPN
- head部分(decoder)
- binary
- thresh
- step_function
- 損失函數
- binary loss
- thresh loss
- thresh_binary loss
- 邏輯推理
- 補充
- 語義分割中的loss function
- cross entropy loss
- weighted loss
- focal loss
- dice soft loss
- Dice系數計算
- Dice loss
- 梯度分析
- 總結
- soft IOU loss
- 總結
- 總結
- soft IOU loss
- 總結
創新點
? 本文的最大創新點。在基于分割的文本檢測網絡中,最終的二值化map都是使用的固定閾值來獲取,并且閾值不同對性能影響較大。本文中,對每一個像素點進行自適應二值化,二值化閾值由網絡學習得到,徹底將二值化這一步驟加入到網絡里一起訓練,這樣最終的輸出圖對于閾值就會非常魯棒。
和常規基于語義分割算法的區別是多了一條threshold map分支,該分支的主要目的是和分割圖聯合得到更接近二值化的二值圖,屬于輔助分支。其余操作就沒啥了。整個核心知識就這些了。
算法的整體架構
- 首先,圖像輸入特征提取主干,提取特征;
- 其次,特征金字塔上采樣到相同的尺寸,并進行特征級聯得到特征F;
- 然后,特征F用于預測概率圖(probability map P)和閾值圖(threshold map T)
- 最后,通過P和F計算近似二值圖(approximate binary map B)
在訓練期間對P,T,B進行監督訓練,P和B是用的相同的監督信號(label)。在推理時,只需要P或B就可以得到文本框。
網絡輸出:
1.probability map, w*h*1 , 代表像素點是文本的概率
2.threshhold map, w*h*1, 每個像素點的閾值
3.binary map, w*h*1, 由1,2計算得到,計算公式為DB公式
自適應閾值(Adaptive threshhold)
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-V5RaeceH-1610966579215)(C:\F\notebook\DB\20200922201346491.png)]
文中指出傳統的文本檢測算法主要是圖中藍色線,處理流程如下:
- 首先,通過設置一個固定閾值將分割網絡訓練得到的概率圖(segmentation map)轉化為二值圖(binarization map);
- 然后,使用一些啟發式技術(例如像素聚類)將像素分組為文本實例。
而DBNet使用紅色線,思路:
通過網絡去預測圖片每個位置處的閾值,而不是采用一個固定的值,這樣就可以很好將背景與前景分離出來,但是這樣的操作會給訓練帶來梯度不可微的情況,對此對于二值化提出了一個叫做Differentiable Binarization來解決不可微的問題。
? 閾值圖(threshhold map)使用流程如圖2所示,使用閾值map和不使用閾值map的效果對比如圖6所示,從圖6?中可以看到,即使沒用帶監督的閾值map,閾值map也會突出顯示文本邊界區域,這說明邊界型閾值map對最終結果是有利的。所以,本文在閾值map上選擇監督訓練,已達到更好的表現
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-4gDERWPU-1610966579221)(C:\F\notebook\DB\20200922201612829.png)]
二值化
標準二值化
? 一般使用分割網絡(segmentation network)產生的概率圖(probability map P),將P轉化為一個二值圖P,當像素為1的時候,認定其為有效的文本區域,同時二值處理過程:
i和j代表了坐標點的坐標,t是預定義的閾值;
可微二值(differentiable Binarization)
公式1是不可微的,所以沒法直接用于訓練,本文提出可微的二值化函數,如下(其實就是一個帶系數的sigmoid):
就是近似二值圖;T代表從網絡中學習到的自適應閾值圖;k是膨脹因子(經驗性設置k=50).
? 這個近似的二值化函數的表現類似于標準的二值化函數,如圖4所表示,但是因為可微,所以可以直接用于網絡訓練,基于自適應閾值可微二值化不僅可以幫助區分文本區域和背景,而且可以將連接緊密的文本實例分離出來。
? 為了說明DB模塊的引入對于聯合訓練的優勢,作者對該函數進行梯度分析,也就是對approximate
binary map進行求導分析,由于是sigmod輸出,故假設Loss是bce,對于label為0或者1的位置,其Loss函數可以重寫為:
x表示probability map-threshold map,最后一層關于x的梯度很容易計算:
? 看上圖右邊,(b)圖是當label=1,x預測值從-1到1的梯度,可以發現,當k=50時候梯度遠遠大于k=1,錯誤的區域梯度更大,對于label=0的情況分析也是一樣的。故:
(1) 通過增加參數K,就可以達到增大梯度的目的,加快收斂
(2) 在預測錯誤位置,梯度也是顯著增加
總之通過引入DB模塊,通過參數K可以達到增加梯度幅值,更加有利優化,可以使得三個輸出圖優化更好,最終分割結果會優異。而DB模塊本身就是帶參數的sigmod函數,實現如下:
直觀展示
p可以理解,就是有文字的區域有值0.9以上,沒有文字區域黑的為0 .
T是一個只有文字邊界才有值的,其他地方為0 .
? 分別是原圖,gt圖,threshold map圖。 這里再說下threshold map圖,非文字邊界處都是灰色的,這是因為統一加了0.3,所有最小值是0.3.
這里其實還看不清,我們把src+gt+threshold map看看。
可以看到:
- p的ground truth是標注縮水之后
- T的ground truth是文字塊邊緣分別向內向外收縮和擴張
- p與T是公式里面的那兩個變量。
再看這個公式與曲線圖:
P和T我們就用ground truth帶入來理解:
? P網絡學的文字塊內部, T網絡學的文字邊緣,兩者計算得到B。 B的ground truth也是標注縮水之后,和p用的同一個。 在實際操作中,作者把除了文字塊邊緣的區域置為0.3.應該就是為了當在非文字區域, P=0,T=0.3,x=p-T<0這樣拉到負半軸更有利于區分。
可形變卷積(Deformable convolution)
? 可變形卷積可以提供模型一個靈活的感受野,這對于不同縱橫比的文本很有利,本文應用可變形卷積,使用3×3卷積核在ResNet-18或者ResNet-50的conv3,conv4,conv5層。
標簽的生成
概率圖的標簽產成法類似PSENet
PSENet標簽生成
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-8eRCYRHv-1610966579234)(C:\F\notebook\DB\20200923193744225.png)]
? 網絡輸出多個分割結果(S1,Sn),因此訓練時需要有多個GY與其匹配,在本文中,通過收縮原始標簽就可以簡單高效的生成不同尺度的GT,如圖5所示,(b)代表原始的標注結果,也表示最大的分割標簽mask,即Sn,利用Vatti裁剪算法獲取其他尺度的Mask,如圖5(a),將原始多邊形pn 縮小di 像素到 pi ,收縮后的pi 轉換成0/1的二值mask作為GT,用G1,G2,,,,Gn分別代表不同尺度的GT,用數學方式表示的話,尺度比例為ri 。
di 的計算方式為:
di=Area(Pn)?(1?ri2)/Perimeter(pn)d_i=Area(P_n)*(1-r_i^2)/Perimeter(p_n) di?=Area(Pn?)?(1?ri2?)/Perimeter(pn?)
Area(·) 是計算多邊形面積的函數, Perimeter(·)是計算多邊形周長的函數,生成Gi時的尺度比例ri計算公式為:
ri=1?(1?m)?(n?i)/(n?1)r_i=1-(1-m)*(n-i)/(n-1) ri?=1?(1?m)?(n?i)/(n?1)
m代表最小的尺度比例,取值范圍是(0,1],使用上式,通過m和n兩個超參數可以計算出r1,r2,…rn,他們隨著m變現線性增加到最大值1.
DBNet標簽生成
給定一張圖片,文本區域標注的多邊形可以描述為:
G={Sk}k=1nG=\{S_k\}_{k=1}^{n} G={Sk?}k=1n?
n是每隔文本框的標注點總數,在不同數據中可能不同,然后使用vatti裁剪算法,將正樣例區域產生通過收縮polygon從G到Gs,補償公式計算
D:offset;L:周長;A:面積;r:收縮比例,設置為0.4;
損失函數
損失函數為概率map的loss、二值map的loss和閾值map的loss之和。
Ls 是概率map的loss,Lb 是二值map的loss,均使用二值交叉熵loss(BCE),為了解決正負樣本不均衡問題,使用hard negative mining, α和β分別設置為1.0和10 .
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-udDAMvqX-1610966579237)(C:\F\notebook\DB\2020092220283134.png)]
Sl 設計樣本集,其中正陽樣本和負樣本比例是1:3
Lt計算方式為擴展文本多邊形Gd內預測結果和標簽之間的L1距離之和:
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-pGt5zzIG-1610966579239)(C:\F\notebook\DB\20200922203558285.png)]
Rd是在膨脹Gd內像素的索引,y*是閾值map的標簽
后處理
(由于threshold map的存在,probability map的邊界可以學習的很好,因此可以直接按照收縮的方式(Vatti clipping algorithm)擴張回去 )
在推理時可以采用概率圖或近似二值圖來生成文本框,為了方便作者選擇了概率圖,具體步驟如下:
1、使用固定閾值0.2將概率圖做二值化得到二值化圖;
2、由二值化圖得到收縮文字區域;
3、將收縮文字區域按Vatti clipping算法的偏移系數D’通過膨脹再擴展回來。
D‘就是擴展補償,A’是收縮多邊形的面積,L‘就是收縮多邊形的周長,r’作者設置的是1.5;
(注意r‘的值在DBNet工程中不是1.5,而在我自己的數據集上,參數設置為1.3較合適,大家訓練的時候可以根據自己模型效果進行調整)
文中說明DB算法的主要優勢有以下4點:
- 在五個基準數據集上有良好的表現,其中包括水平、多個方向、彎曲的文本。
- 比之前的方法要快很多,因為DB可以提供健壯的二值化圖,從而大大簡化了后處理過程。
- 使用輕量級的backbone(ResNet18)也有很好的表現。
- DB模塊在推理過程中可以去除,因此不占用額外的內存和時間的消耗。
參考:
論文鏈接:https://arxiv.org/pdf/1911.08947.pdf
工程鏈接:https://github.com/MhLiao/DB
? https://github.com/WenmuZhou/DBNet.pytorch
- https://blog.csdn.net/qq_22764813/article/details/107785388
- https://blog.csdn.net/qq_39707285/article/details/108739010
- https://zhuanlan.zhihu.com/p/94677957
- https://mp.weixin.qq.com/s/ehbROyE-grp_F3T3YBX9CA
代碼閱讀
數據預處理
入口
在data/image_dataset.py,數據預處理邏輯非常簡單,就是讀取圖片和gt標注,解析出每張圖片poly標注,包括多邊形標注、字符內容以及是否是忽略文本,忽略文本一般是比較模糊和小的文本。
具體可以在getitem方法里面插入:
ImageDataset.__getitem__():data_process(data)預處理配置:
processes:- class: AugmentDetectionDataaugmenter_args:- ['Fliplr', 0.5]- {'cls': 'Affine', 'rotate': [-10, 10]}- ['Resize', [0.5, 3.0]]only_resize: Falsekeep_ratio: False- class: RandomCropDatasize: [640, 640]max_tries: 10- class: MakeICDARData- class: MakeSegDetectionData- class: MakeBorderMap- class: NormalizeImage- class: FilterKeyssuperfluous: ['polygons', 'filename', 'shape', 'ignore_tags', 'is_training']預處理流程:
AugmentDetectionData(數據增強類)
DB/data/processes/augment_data.py
? 其目的就是對圖片和poly標注進行數據增強,包括翻轉、旋轉和縮放三個,參數如配置所示。本文采用的增強庫是imgaug。可以看出本文訓練階段對數據是不保存比例的resize,然后再進行三種增強。
由于icdar數據,文本區域占比都是非常小的,故不能用直接resize到指定輸入大小的數據增強操作,而是使用后續的randcrop操作比較科學。但是如果自己項目的數據文本區域比較大,則可能沒必要采用RandomCropData這么復雜的數據增強操作,直接resize算了。
RandomCropData(數據裁剪類)
DB/data/processes/random_crop_data.py
因為數據裁剪涉及到比較復雜的多變形標注后處理,所以單獨列出來 。
? 其目的是對圖片進行裁剪到指定的[640, 640]。由于斜框的特點,裁剪增強沒那么容易做,本文采用的裁剪策略非常簡單: 遍歷每一個多邊形標注,只要裁剪后有至少有一個poly還在裁剪框內,則認為該次裁剪有效。這個策略主要可以保證一張圖片中至少有一個gt,且實現比較簡單。
其具體流程是:
代碼如下:
def crop_area(self, im, text_polys):h, w = im.shape[:2]h_array = np.zeros(h, dtype=np.int32)w_array = np.zeros(w, dtype=np.int32)#將poly數據進行水平和垂直方向投影,有標注的地方是1,其余地方是0for points in text_polys:points = np.round(points, decimals=0).astype(np.int32)minx = np.min(points[:, 0])maxx = np.max(points[:, 0])w_array[minx:maxx] = 1miny = np.min(points[:, 1])maxy = np.max(points[:, 1])h_array[miny:maxy] = 1# ensure the cropped area not across a text#找出沒有標注的水平和垂直坐標h_axis = np.where(h_array == 0)[0]w_axis = np.where(w_array == 0)[0]#如果所有位置都有標注,則無法裁剪,直接原圖返回if len(h_axis) == 0 or len(w_axis) == 0:return 0, 0, w, h#對水平和垂直坐標進行連續區域分離,其實就是把所有連續0坐標區域切割處理#后面進行隨機裁剪都是在每個連續區域進行,可以最大程度保證不會裁斷標注h_regions = self.split_regions(h_axis)w_regions = self.split_regions(w_axis)for i in range(self.max_tries):if len(w_regions) > 1:#先從n個區域隨機選擇2個區域,然后在兩個區域內部隨機選擇兩個點,構成x方向最大最小坐標xmin, xmax = self.region_wise_random_select(w_regions, w)else:xmin, xmax = self.random_select(w_axis, w)if len(h_regions) > 1:#h方向也是一樣處理ymin, ymax = self.region_wise_random_select(h_regions, h)else:ymin, ymax = self.random_select(h_axis, h)#不能裁剪的過小if xmax - xmin < self.min_crop_side_ratio * w or ymax - ymin < self.min_crop_side_ratio * h:# area too smallcontinuenum_poly_in_rect = 0for poly in text_polys:#如果有一個poly標注沒有出界,則直接返回,表示裁剪成功if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, ymax - ymin):num_poly_in_rect += 1breakif num_poly_in_rect > 0:return xmin, ymin, xmax - xmin, ymax - yminreturn 0, 0, w, h? 在得到裁剪區域后,就比較簡單了。先對裁剪區域圖片進行保存長寬比的resize,最長邊為網絡輸入,例如640x640, 然后從上到下pad,得到640x640的圖片
# 計算crop區域 crop_x, crop_y, crop_w, crop_h = self.crop_area(im, all_care_polys) # crop 圖片 保持比例填充 scale_w = self.size[0] / crop_w scale_h = self.size[1] / crop_h scale = min(scale_w, scale_h) h = int(crop_h * scale) w = int(crop_w * scale)padimg = np.zeros((self.size[1], self.size[0], im.shape[2]), im.dtype) padimg[:h, :w] = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h)) img = padimg如果進行可視化,會顯示如下所示:
可以看出,這種裁剪策略雖然簡單暴力,但是為了拼接成640x640的輸出,會帶來大量無關全黑像素區域。
MakeICDARData(數據重新組織類)
DB/data/processes/make_icdar_data.py
就是簡單的組織數據而已
#Making ICDAE format #返回值: OrderedDict(image=data['image'],polygons=polygons,ignore_tags=ignore_tags,shape=shape,filename=filename,is_training=data['is_training'])MakeSegDetectionData(生成概率圖和對應mask類)
DB/data/processes/make_seg_detection_data.py
功能:將多邊形數據轉化為mask格式即概率圖gt,并且標記哪些多邊形是忽略區域
#Making binary mask from detection data with ICDAR format 輸入:image,polygons,ignore_tags,filename 輸出:gt(shape:[1,h,w]),mask (shape:[h,w])(用于后面計算binary loss)
? 為了防止標注間相互粘連,不好后處理,區分實例,目前做法都是會進行shrink即沿著多邊形標注的每條邊進行向內縮減一定像素,得到縮減的gt,然后才進行訓練;在測試時候再采用相反的手動還原回來。
? 縮減做法采用的也是常規的Vatti clipping algorithm,是通過pyclipper庫實現的,縮減比例是默認0.4,公式是:
r=0.4,A是多邊形面積,L是多邊形周長,通過該公式就可以對每個不同大小的多邊形計算得到一個唯一的D,代表每條邊的向內縮放像素個數。
gt = np.zeros((1, h, w), dtype=np.float32)#shrink后得到概率圖,包括所有區域mask = np.ones((h, w), dtype=np.float32)#指示哪些區域是忽略區域,0就是忽略區域for i in range(len(polygons)):polygon = polygons[i]height = max(polygon[:, 1]) - min(polygon[:, 1])width = max(polygon[:, 0]) - min(polygon[:, 0])#如果是忽略樣本,或者高寬過小,則mask對應位置設置為0即可if ignore_tags[i] or min(height, width) < self.min_text_size:cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)ignore_tags[i] = Trueelse:#沿著每條邊進行shrinkpolygon_shape = Polygon(polygon)#多邊形分析庫#每條邊收縮距離:polygon, D=A(1-r^2)/Ldistance = polygon_shape.area * \(1 - np.power(self.shrink_ratio, 2)) / polygon_shape.lengthsubject = [tuple(l) for l in polygons[i]]#實現坐標的偏移padding = pyclipper.PyclipperOffset()padding.AddPath(subject, pyclipper.JT_ROUND,pyclipper.ET_CLOSEDPOLYGON)shrinked = padding.Execute(-distance)#得到縮放后的多邊形if shrinked == []:cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)ignore_tags[i] = Truecontinueshrinked = np.array(shrinked[0]).reshape(-1, 2)cv2.fillPoly(gt[0], [shrinked.astype(np.int32)], 1)如果進行可視化,如下所示:
? 概率圖內部全白區域就是概率圖的label,右圖是忽略區域mask,0為忽略區域,到時候該區域是不計算概率圖loss的。
MakeBorderMap(生成閾值圖和對應Mask類)
DB/data/make_border_map.py
功能:計算閾值圖和對應mask。
輸入:預處理后的image info: image,polygons,ignore_tags 輸出:thresh_map,thresh_mask (用于后面計算thresh loss)? 仔細看閾值圖的標注,首先紅線點是poly標注;然后對該多邊形先進行shrink操作,得到藍線; 然后向外反向shrink同樣的距離,得到綠色;閾值圖就是綠線和藍色區域,以紅線為起點,計算在綠線和藍線區域內的點距離紅線的距離,故為距離圖。
其代碼的處理邏輯是:
流程:
canvas = np.zeros(image.shape[:2], dtype=np.float32) mask = np.zeros(image.shape[:2], dtype=np.float32)draw_border_map(polygons[i], canvas, mask=mask) canvas = canvas * (0.7 - 0.3) + 0.3 data['thresh_map'] = canvas data['thresh_mask'] = maskdraw_border_map
#處理每條polydef draw_border_map(self, polygon, canvas, mask):polygon = np.array(polygon)assert polygon.ndim == 2assert polygon.shape[1] == 2#向外擴展polygon_shape = Polygon(polygon)distance = polygon_shape.area * \(1 - np.power(self.shrink_ratio, 2)) / polygon_shape.lengthsubject = [tuple(l) for l in polygon]padding = pyclipper.PyclipperOffset()padding.AddPath(subject, pyclipper.JT_ROUND,pyclipper.ET_CLOSEDPOLYGON)padded_polygon = np.array(padding.Execute(distance)[0])#shape:[12,2]擴大和縮減一樣的像素cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)#內部全部填充1#計算最小包圍poly矩形xmin = padded_polygon[:, 0].min()xmax = padded_polygon[:, 0].max()ymin = padded_polygon[:, 1].min()ymax = padded_polygon[:, 1].max()width = xmax - xmin + 1height = ymax - ymin + 1#裁剪掉無關區域,加快計算速度polygon[:, 0] = polygon[:, 0] - xminpolygon[:, 1] = polygon[:, 1] - ymin#最小包圍矩形的所有位置坐標xs = np.broadcast_to(np.linspace(0, width - 1, num=width).reshape(1, width), (height, width))ys = np.broadcast_to(np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width))distance_map = np.zeros((polygon.shape[0], height, width), dtype=np.float32)for i in range(polygon.shape[0]):#對每條邊進行遍歷j = (i + 1) % polygon.shape[0]#計算圖片上所有點到線上面的距離absolute_distance = self.distance(xs, ys, polygon[i], polygon[j])#僅僅保留0-1之間的位置,得到距離圖distance_map[i] = np.clip(absolute_distance / distance, 0, 1)distance_map = distance_map.min(axis=0)#繪制到原圖上xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)#如果有多個ploy實例重合,則該區域取最大值canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(1 - distance_map[ymin_valid-ymin:ymax_valid-ymax+height,xmin_valid-xmin:xmax_valid-xmax+width],canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])可視化如下所示:
采用matpoltlib繪制距離圖會更好看
NormalizeImage
DB/data/processes/normalize_image.py
圖片歸一化類
FilterKeys
DB/data/processes/filter_keys.py
字典數據過濾類,具體是把superfluous里面的key和value刪掉,不輸入網絡中
#刪除無用的圖片信息,只保留信息: dict("image","gt","mask","thresh_map","thresh_mask")模型結構
DB/structure/model.py
模型結構配置部分:
builder: class: Buildermodel: SegDetectorModelmodel_args:backbone: deformable_resnet18decoder: SegDetectordecoder_args: adaptive: Truein_channels: [64, 128, 256, 512]k: 50骨干網絡和FPN
? 骨架網絡采用的是resnet18或者resnet50,為了增加網絡特征提取能力,在layer2、layer3和layer4模塊內部引入了變形卷積dcnv2模塊。在resnet輸出的4個特征圖后面采用標準的FPN網絡結構,得到4個增強后輸出,然后cat進來,得到1/4的特征圖輸出fuse。
? 其中,resnet骨架特征提取代碼在backbones/resnet.py里,具體是輸出x2, x3, x4, x5,分別是1/4~1/32尺寸。FPN部分代碼在decoders/seg_detector.py里面.
head部分(decoder)
DB/decoders/seg_detector.py
? 輸出head在訓練時候包括三個分支,分別是probability map、threshold map和經過DB模塊計算得到的approximate binary map。三個圖通道都是1,輸出和輸入是一樣大的。要想分割精度高,高分辨率輸出是必要的。
**輸出:**binary、thresh、thresh_binary
fuse = torch.cat((p5, p4, p3, p2), 1) #推理時,只需返回binary binary = self.binarize(fuse) thresh = self.thresh(fuse) thresh_binary = self.step_function(binary, thresh)binary
? 對fuse特征圖經過一系列卷積和反卷積,擴大到和原圖一樣大的輸出,然后經過sigmod層得到0-1輸出概率圖probability map
self.binarize = nn.Sequential(nn.Conv2d(inner_channels, inner_channels //4, 3, padding=1, bias=bias),BatchNorm2d(inner_channels//4),nn.ReLU(inplace=True),nn.ConvTranspose2d(inner_channels//4, inner_channels//4, 2, 2),BatchNorm2d(inner_channels//4),nn.ReLU(inplace=True),nn.ConvTranspose2d(inner_channels//4, 1, 2, 2),nn.Sigmoid())self.binarize.apply(self.weights_init)thresh
? 同時對fuse特征圖采用類似上采樣操作,經過sigmod層的0-1輸出閾值圖threshold map
if adaptive:self.thresh = self._init_thresh(inner_channels, serial=serial, smooth=smooth, bias=bias)self.thresh.apply(self.weights_init)def _init_thresh(self, inner_channels,serial=False, smooth=False, bias=False):in_channels = inner_channelsif serial:in_channels += 1self.thresh = nn.Sequential(nn.Conv2d(in_channels, inner_channels //4, 3, padding=1, bias=bias),BatchNorm2d(inner_channels//4),nn.ReLU(inplace=True),self._init_upsample(inner_channels // 4, inner_channels//4, smooth=smooth, bias=bias),BatchNorm2d(inner_channels//4),nn.ReLU(inplace=True),self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),nn.Sigmoid())return self.threshstep_function
? 將這兩個輸出圖經過DB模塊得到approximate binary map
torch.reciprocal(1 + torch.exp(-self.k * (binary - thresh)))損失函數
DB/decoders/seg_detector_loss.py
? 輸出是單個單通道圖,probability map和approximate binary map是典型的分割輸出,故其loss就是普通的bce,但是為了平衡正負樣本,還額外采用了難負樣本采樣策略,對背景區域和前景區域采用3:1的設置。對于threshold map,其輸出不一定是0-1之間,后面會介紹其值的范圍,當前采用的是L1 loss,且僅僅計算擴展后的多邊形內部區域,其余區域忽略。
Ls是概率圖,Lt是閾值圖,Lb是近似二值化圖,
? 本文整個論文Loss的實現在decoders/seg_detector_loss.py的L1BalanceCELoss類,可以發現其實approximate binary map采用的并不是論文中的bce,而是可以克服正負樣本平衡的dice loss。一般在高度不平衡的二值分割任務中,dice loss效果會比純bce好,但是更好的策略是dice loss +bce loss。
loss = dice_loss + 10 * l1_loss + 5*bce_lossbinary loss
bce_loss = self.bce_loss(pred['binary'], batch['gt'], batch['mask'])bce_loss:
DB/decoders/balance_cross_entropy_loss.py
def forward(self,pred: torch.Tensor,gt: torch.Tensor,mask: torch.Tensor,return_origin=False):'''Args:pred: shape :math:`(N, 1, H, W)`, the prediction of networkgt: shape :math:`(N, 1, H, W)`, the targetmask: shape :math:`(N, H, W)`, the mask indicates positive regions'''positive = (gt * mask).byte()negative = ((1 - gt) * mask).byte()positive_count = int(positive.float().sum())#負樣本個數為positive_count的self.negative_ratio倍數negative_count = min(int(negative.float().sum()),int(positive_count * self.negative_ratio))loss = nn.functional.binary_cross_entropy(pred, gt, reduction='none')[:, 0, :, :]positive_loss = loss * positive.float()negative_loss = loss * negative.float()#按照loss選擇topK個negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)balance_loss = (positive_loss.sum() + negative_loss.sum()) /\(positive_count + negative_count + self.eps)if return_origin:return balance_loss, lossreturn balance_lossthresh loss
l1_loss, l1_metric = self.l1_loss(pred['thresh'], batch['thresh_map'], batch['thresh_mask'])l1_loss:
DB/decoders/l1_loss.py
class MaskL1Loss(nn.Module):def __init__(self):super(MaskL1Loss, self).__init__()def forward(self, pred: torch.Tensor, gt, mask):mask_sum = mask.sum()if mask_sum.item() == 0:return mask_sum, dict(l1_loss=mask_sum)else:loss = (torch.abs(pred[:, 0] - gt) * mask).sum() / mask_sumreturn loss, dict(l1_loss=loss)thresh_binary loss
dice_loss = self.dice_loss(pred['thresh_binary'], batch['gt'], batch['mask'])dice_loss:
DB/decoders/dice_loss.py
class DiceLoss(nn.Module):'''Loss function from https://arxiv.org/abs/1707.03237,where iou computation is introduced heatmap manner to measure thediversity bwtween tow heatmaps.'''def __init__(self, eps=1e-6):super(DiceLoss, self).__init__()self.eps = epsdef forward(self, pred: torch.Tensor, gt, mask, weights=None):'''pred: one or two heatmaps of shape (N, 1, H, W),the losses of tow heatmaps are added together.gt: (N, 1, H, W)mask: (N, H, W)'''assert pred.dim() == 4, pred.dim()return self._compute(pred, gt, mask, weights)def _compute(self, pred, gt, mask, weights):if pred.dim() == 4:pred = pred[:, 0, :, :]gt = gt[:, 0, :, :]assert pred.shape == gt.shapeassert pred.shape == mask.shapeif weights is not None:assert weights.shape == mask.shapemask = weights * maskintersection = (pred * gt * mask).sum()union = (pred * mask).sum() + (gt * mask).sum() + self.epsloss = 1 - 2.0 * intersection / unionassert loss <= 1return lossbinary與thresh_binary的標簽都是用的gt
thresh的標簽用的thresh_map
邏輯推理
配置如下:
- name: validate_dataclass: ImageDatasetdata_dir:- '/remote_workspace/ocr/public_dataset/icdar2015/'data_list:- '/remote_workspace/ocr/public_dataset/icdar2015/test_list.txt'processes:- class: AugmentDetectionDataaugmenter_args:- ['Resize', {'width': 1280, 'height': 736}]# - ['Resize', {'width': 2048, 'height': 1152}]only_resize: Truekeep_ratio: False- class: MakeICDARData- class: MakeSegDetectionData- class: NormalizeImage? 如果不考慮label,則其處理邏輯和訓練邏輯有一點不一樣,其把圖片統一resize到指定的長度進行預測。
前面說過閾值圖分支其實可以相當于輔助分支,可以聯合優化各個分支性能。故在測試時候發現概率圖預測值已經蠻好了,故在測試階段實際上把閾值圖分支移除了,只需要概率圖輸出即可。
后處理邏輯在structure/representers/seg_detector_representer.py,本文特色就是后處理比較簡單,故流程為:
采用作者提供的訓練好的權重進行預測,可視化預測結果如下所示:
論文中指標結果:
可以看出變形卷積和閾值圖對整個性能都有比較大的促進作用。
測試icdar2015數據結果:
補充
語義分割中的loss function
cross entropy loss
用于圖像語義分割任務的最常用損失函數是像素級別的交叉熵損失,這種損失會逐個檢查每個像素,將對每個像素類別的預測結果(概率分布向量)與我們的獨熱編碼標簽向量進行比較。
假設我們需要對每個像素的預測類別有5個,則預測的概率分布向量長度為5:
每個像素對應的損失函數為:
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-OeRSbkGm-1610966579269)(https://www.zhihu.com/equation?tex=%5Ctext+%7Bpixel+loss%7D+%3D±%5Csum_%7Bclasses%7D+y_%7Btrue%7D+log+%28y_%7Bpred%7D%29+%5C%5C)]
整個圖像的損失就是對每個像素的損失求平均值。
特別注意的是,binary entropy loss 是針對類別只有兩個的情況,簡稱 bce loss,損失函數公式為:
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-DOMD9pF8-1610966579270)(https://www.zhihu.com/equation?tex=%5Ctext+%7Bbce+loss%7D+%3D±+y_%7Btrue%7D+log+%28y_%7Bpred%7D%29±+%281-y_%7Btrue%7D%29+log+%281-y_%7Bpred%7D%29%5C%5C)]
weighted loss
由于交叉熵損失會分別評估每個像素的類別預測,然后對所有像素的損失進行平均,因此我們實質上是在對圖像中的每個像素進行平等地學習。如果多個類在圖像中的分布不均衡,那么這可能導致訓練過程由像素數量多的類所主導,即模型會主要學習數量多的類別樣本的特征,并且學習出來的模型會更偏向將像素預測為該類別。
FCN論文和U-Net論文中針對這個問題,對輸出概率分布向量中的每個值進行加權,即希望模型更加關注數量較少的樣本,以緩解圖像中存在的類別不均衡問題。
比如對于二分類,正負樣本比例為1: 99,此時模型將所有樣本都預測為負樣本,那么準確率仍有99%這么高,但其實該模型沒有任何使用價值。
為了平衡這個差距,就對正樣本和負樣本的損失賦予不同的權重,帶權重的二分類損失函數公式如下:
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-PtRPVOrh-1610966579271)(https://www.zhihu.com/equation?tex=%5Ctext+%7Bpos_weight%7D+%3D+%5Cfrac%7B%5Ctext+%7Bnum_neg%7D%7D%7B%5Ctext+%7Bnum_pos%7D%7D+%5C%5C+%5Ctext+%7Bloss%7D+%3D±+%5Ctext+%7Bpos_weight%7D+%5Ctimes+y_%7Btrue%7D+log+%28y_%7Bpred%7D%29±+%281-y_%7Btrue%7D%29+log+%281-y_%7Bpred%7D%29%5C%5C)]
要減少假陰性樣本的數量,可以增大 pos_weight;要減少假陽性樣本的數量,可以減小 pos_weight。
focal loss
上面針對不同類別的像素數量不均衡提出了改進方法,但有時還需要將像素分為難學習和容易學習這兩種樣本。
容易學習的樣本模型可以很輕松地將其預測正確,模型只要將大量容易學習的樣本分類正確,loss就可以減小很多,從而導致模型不怎么顧及難學習的樣本,所以我們要想辦法讓模型更加關注難學習的樣本。
對于較難學習的樣本,將 bce loss 修改為:
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-6eZmpbcM-1610966579272)(https://www.zhihu.com/equation?tex=-+%281-y_%7Bpred%7D%29%5E%5Cgamma+%5Ctimes+y_%7Btrue%7D+log+%28y_%7Bpred%7D%29±+y_%7Bpred%7D%5E%5Cgamma+%5Ctimes+%281-y_%7Btrue%7D%29+log+%281-y_%7Bpred%7D%29+%5C%5C)]
其中的 [外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-pOuXJT0r-1610966579273)(https://www.zhihu.com/equation?tex=%5Cgamma)] 通常設置為2。
舉個例子,預測一個正樣本,如果預測結果為0.95,這是一個容易學習的樣本,有 [外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-w4wQuNnV-1610966579274)(https://www.zhihu.com/equation?tex=%281-0.95%29%5E2%3D0.0025)] ,損失直接減少為原來的1/400。
而如果預測結果為0.4,這是一個難學習的樣本,有 [外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-QQu9pL2M-1610966579276)(https://www.zhihu.com/equation?tex=%281-0.5%29%5E2%3D0.25)] ,損失減小為原來的1/4,雖然也在減小,但是相對來說,減小的程度小得多。
所以通過這種修改,就可以使模型更加專注于學習難學習的樣本。
而將這個修改和對正負樣本不均衡的修改合并在一起,就是大名鼎鼎的 focal loss:
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-23Dr9Qbl-1610966579278)(https://www.zhihu.com/equation?tex=-+%5Calpha+%281-y_%7Bpred%7D%29%5E%5Cgamma+%5Ctimes+y_%7Btrue%7D+log+%28y_%7Bpred%7D%29±+%281-%5Calpha%29+y_%7Bpred%7D%5E%5Cgamma+%5Ctimes+%281-y_%7Btrue%7D%29+log+%281-y_%7Bpred%7D%29+%5C%5C)]
dice soft loss
Dice系數計算
語義分割任務中常用的還有一個基于 Dice 系數的損失函數,該系數實質上是兩個樣本之間重疊的度量。此度量范圍為 0~1,其中 Dice 系數為1表示完全重疊。Dice 系數最初是用于二進制數據的,可以計算為:
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-SGCRjBn0-1610966579279)(https://www.zhihu.com/equation?tex=Dice+%3D+%5Cfrac+%7B2+%7CA+%5Ccap+B%7C%7D%7B%7CA%7C+%2B+%7CB%7C%7D+%5C%5C)]
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-TDJx3q9u-1610966579281)(https://www.zhihu.com/equation?tex=%7CA+%5Ccap+B%7C)] 代表集合A和B之間的公共元素,并且 [外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-Y7BLfLdh-1610966579284)(https://www.zhihu.com/equation?tex=%7C+A+%7C)] 代表集合A中的元素數量(對于集合B同理)。
對于在預測的分割掩碼上評估 Dice 系數,我們可以將 [外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-wvhJvFjz-1610966579286)(https://www.zhihu.com/equation?tex=%7CA+%5Ccap+B%7C)] 近似為預測掩碼和標簽掩碼之間的逐元素乘法,然后對結果矩陣求和。
計算 Dice 系數的分子中有一個2,那是因為分母中對兩個集合的元素個數求和,兩個集合的共同元素被加了兩次。
Dice loss
為了設計一個可以最小化的損失函數,可以簡單地使用 [外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-zTwxiWv4-1610966579288)(https://www.zhihu.com/equation?tex=1-Dice+)]。 這種損失函數被稱為 soft Dice loss,這是因為我們直接使用預測出的概率,而不是使用閾值將其轉換成一個二進制掩碼。
Dice loss是針對前景比例太小的問題提出的,dice系數源于二分類,本質上是衡量兩個樣本的重疊部分。
對于二分類問題,一般預測值分為以下幾種:
- TP: true positive,真陽性,預測是陽性,預測對了,實際也是正例。
- TN: true negative,真陰性,預測是陰性,預測對了,實際也是負例。
- FP: false positive,假陽性,預測是陽性,預測錯了,實際是負例。
- FN: false negative,假陰性,預測是陰性,預測錯了,實際是正例。
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-I6TNYlF0-1610966579290)(C:\F\notebook\DB\aHR0cHM6Ly9tbWJpei5xcGljLmNuL21tYml6X3BuZy9pYVRhOHV0NkhpYXdBZWpDcGhDVGtpY3EyVlRaaWJJTTBDR0JEcEJoM2ZGMkd2cjhxbHM2eG04Z2JBUURyUHIyT1VIN2ljWGVSWGdDckVjUVJteDBMTXI4bURBLzY0MA.png)]
這里dice coefficient可以寫成如下形式:
dice=2TP2TP+FP+FNdice=\frac{2TP}{2TP+FP+FN} dice=2TP+FP+FN2TP?
而我們知道:
可見dice coefficient是等同**「F1 score」,直觀上dice coefficient是計算 與 的相似性,本質上則同時隱含precision和recall兩個指標。可見dice loss是直接優化「F1 score」**。
對于神經網絡的輸出,分子與我們的預測和標簽之間的共同激活有關,而分母分別與每個掩碼中的激活數量有關,這具有根據標簽掩碼的尺寸對損失進行歸一化的效果。
對于每個類別的mask,都計算一個 Dice 損失:
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-v2aiaP5D-1610966579294)(https://www.zhihu.com/equation?tex=1-+%5Cfrac+%7B2+%5Csum%5Climits_%7Bpixels%7D+y_%7Btrue%7D+y_%7Bpred%7D%7D%7B%5Csum%5Climits_%7Bpixels%7D+%28y_%7Btrue%7D%5E2+%2B+y_%7Bpred%7D%5E2%29%7D+%5C%5C)]
將每個類的 Dice 損失求和取平均,得到最后的 Dice soft loss。
梯度分析
從dice loss的定義可以看出,dice loss 是一種**「區域相關」**的loss。意味著某像素點的loss以及梯度值不僅和該點的label以及預測值相關,和其他點的label以及預測值也相關,這點和ce (交叉熵cross entropy) ?loss 不同。
dice loss 是應用于語義分割而不是分類任務,并且是一個區域相關的loss,因此更適合針對多點的情況進行分析。由于多點輸出的情況比較難用曲線呈現,這里使用模擬預測值的形式觀察梯度的變化。
下圖為原始圖片和對應的label:[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-xPLvYa7F-1610966579296)(C:\F\notebook\DB\aHR0cHM6Ly9tbWJpei5xcGljLmNuL21tYml6X3BuZy9pYVRhOHV0NkhpYXdBZWpDcGhDVGtpY3EyVlRaaWJJTTBDR0JGMWR3blNGU1R5VEY4VFllNHN3SHBrR1FOM3JrWnRQamtYZGhoWjBydWo3RFFyamlibmowZ3lBLzY0MA.png)]
為了便于梯度可視化,這里對梯度求絕對值操作,因為我們關注的是梯度的大小而非方向。另外梯度值都乘以 保證在容易辨認的范圍。
首先定義如下熱圖,值越大,顏色越亮,反之亦然:
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-BeCmQcqD-1610966579298)(C:\F\notebook\DB\aHR0cHM6Ly9tbWJpei5xcGljLmNuL21tYml6X3BuZy9pYVRhOHV0NkhpYXdBZWpDcGhDVGtpY3EyVlRaaWJJTTBDR0I1YXBtUDZNWGJaNDhocklkWmE3dHpGdEZKQmJwSFV6Q0tqTUhWRW5mQ3MyTmh1b2o4TTJTNVEvNjQw.png)]
預測值變化( 值,圖上的數字為預測值區間):
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-jY3HOS0H-1610966579299)(C:\F\notebook\DB\aHR0cHM6Ly9tbWJpei5xcGljLmNuL21tYml6X3BuZy9pYVRhOHV0NkhpYXdBZWpDcGhDVGtpY3EyVlRaaWJJTTBDR0JiT3FrQzRMYXZJMThMbVdxQVNXTmE3STdjR2EwMm95cnB6cVhuZTRMNWhwajJDOWRySXUyS2cvNjQw.png)]
dice loss 對應 值的梯度:
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-pXdKApFF-1610966579301)(C:\F\notebook\DB\aHR0cHM6Ly9tbWJpei5xcGljLmNuL21tYml6X3BuZy9pYVRhOHV0NkhpYXdBZWpDcGhDVGtpY3EyVlRaaWJJTTBDR0JrMXV2Sjhyem1qanMyMXdraWJzYkRBbktiSlVqTXFjaWFYSUt1VkJSaWFDd213TGZpYTMyanFUaWFuQS82NDA.png)]
ce loss 對應 值的梯度:
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-i9VOLACe-1610966579302)(C:\F\notebook\DB\aHR0cHM6Ly9tbWJpei5xcGljLmNuL21tYml6X3BuZy9pYVRhOHV0NkhpYXdBZWpDcGhDVGtpY3EyVlRaaWJJTTBDR0I1aFdVUWliVk1CaWFtRzFxbXdocnp6ZUVqWTY1dmFhZWtlV05iMGVJcGJBUkNpYkoyMFdHUmliZmJRLzY0MA.png)]
可以看出:
- 一般情況下,dice loss 正樣本的梯度大于背景樣本的,尤其是剛開始網絡預測接進0.5的時候。說明dice loss 更具有指向性,更加偏向于正樣本,保證有較低的FN。
- 負樣本(背景區域)也會產生梯度
- 極端情況下,網絡預測接進0或1時,對應點梯度值極小,dice loss 存在梯度飽和現象。此時預測失敗(FN,FP)的情況很難扭轉回來。不過該情況出現的概率較低,因為網絡初始化輸出接近0.5,此時具有較大的梯度值。而網絡通過梯度下降的方式更新參數,只會逐漸削弱預測失敗的像素點。
- 對于ce loss,當前的點的梯度僅和當前預測值與label的距離相關,預測越接近label,梯度越小。當網絡預測接近0或1時,梯度依然保持該特性。
- 對比發現,訓練前中期,dice loss 下正樣本的梯度值相對于ce loss ,顏色更亮,值更大。說明dice loss對挖掘正樣本更加有優勢。
【dice loss為何能夠解決正負樣本不平衡問題?】
因為dice loss 是一個區域相關的loss。區域相關的意思就是,當前像素的loss不光和當前像素的預測值相關,和其他點的值也相關。dice loss的求交的形式可以理解為mask掩碼操作,因此不管圖片有多大,固定大小的正樣本的區域計算的loss是一樣的,對網絡起到的監督貢獻不會隨著圖片的大小而變化。從上圖可視化也發現,訓練更傾向于挖掘前景區域,正負樣本不平衡的情況就是前景占比較小。而ce loss 會公平處理正負樣本,當出現正樣本占比較小時,就會被更多的負樣本淹沒。
【dice loss背景區域能否起到監督作用?】
可以的,但是會小于前景區域。和直觀理解不同的是,隨著訓練的進行,背景區域也能產生較為可觀的梯度。這點和單點的情況分析不同。這里求偏導,當t_i=0 時:
可以看出, 背景區域的梯度是存在的,只有預測值命中的區域極小時, 背景梯度才會很小.
【dice loss 為何訓練會很不穩定?】
在使用dice loss時,一般正樣本為小目標時會產生嚴重的震蕩。因為在只有前景和背景的情況下,小目標一旦有部分像素預測錯誤,那么就會導致loss值大幅度的變動,從而導致梯度變化劇烈。可以假設極端情況,只有一個像素為正樣本,如果該像素預測正確了,不管其他像素預測如何,loss 就接近0,預測錯誤了,loss 接近1。而對于ce loss,loss的值是總體求平均的,更多會依賴負樣本的地方。
總結
dice loss 對正負樣本嚴重不平衡的場景有著不錯的性能,訓練過程中更側重對前景區域的挖掘。但訓練loss容易不穩定,尤其是小目標的情況下。另外極端情況會導致梯度飽和現象。因此有一些改進操作,主要是結合ce loss等改進,比如: ?dice+ce loss,dice + focal loss等,
soft IOU loss
前面我們知道計算 Dice 系數的公式,其實也可以表示為:
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-2X1iB7aM-1610966579304)(https://www.zhihu.com/equation?tex=Dice+%3D+%5Cfrac+%7B2+%7CA+%5Ccap+B%7C%7D%7B%7CA%7C+%2B+%7CB%7C%7D+%3D+%5Cfrac+%7B2+TP%7D%7B2+TP+%2B+FP+%2B+FN%7D+%5C%5C)]
其中 TP 為真陽性樣本,FP 為假陽性樣本,FN 為假陰性樣本。分子和分母中的 TP 樣本都加了兩次。
IoU 的計算公式和這個很像,區別就是 TP 只計算一次:
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-sN52A4f0-1610966579306)(https://www.zhihu.com/equation?tex=IoU+%3D+%5Cfrac+%7B%7CA+%5Ccap+B%7C%7D%7B%7CA%7C+%2B+%7CB%7C±+%7CA+%5Ccap+B%7C%7D+%3D+%5Cfrac+%7BTP%7D%7BTP+%2B+FP+%2B+FN%7D+%5C%5C)]
和 Dice soft loss 一樣,通過 IoU 計算損失也是使用預測的概率值:
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-efgoYMzv-1610966579307)(https://www.zhihu.com/equation?tex=loss+%3D±+%5Cfrac+%7B1%7D%7B%7CC%7C%7D+%5Csum%5Climits_c+%5Cfrac+%7B%5Csum%5Climits_%7Bpixels%7D+y_%7Btrue%7D+y_%7Bpred%7D%7D%7B%5Csum%5Climits_%7Bpixels%7D+%28y_%7Btrue%7D+%2B+y_%7Bpred%7D±+y_%7Btrue%7D+y_%7Bpred%7D%29%7D+%5C%5C)]
其中 C 表示總的類別數。
總結
交叉熵損失把每個像素都當作一個獨立樣本進行預測,而 dice loss 和 iou loss 則以一種更“整體”的方式來看待最終的預測輸出。
預測值相關,和其他點的值也相關。dice loss的求交的形式可以理解為mask掩碼操作,因此不管圖片有多大,固定大小的正樣本的區域計算的loss是一樣的,對網絡起到的監督貢獻不會隨著圖片的大小而變化。從上圖可視化也發現,訓練更傾向于挖掘前景區域,正負樣本不平衡的情況就是前景占比較小。而ce loss 會公平處理正負樣本,當出現正樣本占比較小時,就會被更多的負樣本淹沒。
【dice loss背景區域能否起到監督作用?】
可以的,但是會小于前景區域。和直觀理解不同的是,隨著訓練的進行,背景區域也能產生較為可觀的梯度。這點和單點的情況分析不同。這里求偏導,當t_i=0 時:
可以看出, 背景區域的梯度是存在的,只有預測值命中的區域極小時, 背景梯度才會很小.
【dice loss 為何訓練會很不穩定?】
在使用dice loss時,一般正樣本為小目標時會產生嚴重的震蕩。因為在只有前景和背景的情況下,小目標一旦有部分像素預測錯誤,那么就會導致loss值大幅度的變動,從而導致梯度變化劇烈。可以假設極端情況,只有一個像素為正樣本,如果該像素預測正確了,不管其他像素預測如何,loss 就接近0,預測錯誤了,loss 接近1。而對于ce loss,loss的值是總體求平均的,更多會依賴負樣本的地方。
總結
dice loss 對正負樣本嚴重不平衡的場景有著不錯的性能,訓練過程中更側重對前景區域的挖掘。但訓練loss容易不穩定,尤其是小目標的情況下。另外極端情況會導致梯度飽和現象。因此有一些改進操作,主要是結合ce loss等改進,比如: ?dice+ce loss,dice + focal loss等,
soft IOU loss
前面我們知道計算 Dice 系數的公式,其實也可以表示為:
[外鏈圖片轉存中…(img-2X1iB7aM-1610966579304)]
其中 TP 為真陽性樣本,FP 為假陽性樣本,FN 為假陰性樣本。分子和分母中的 TP 樣本都加了兩次。
IoU 的計算公式和這個很像,區別就是 TP 只計算一次:
[外鏈圖片轉存中…(img-sN52A4f0-1610966579306)]
和 Dice soft loss 一樣,通過 IoU 計算損失也是使用預測的概率值:
[外鏈圖片轉存中…(img-efgoYMzv-1610966579307)]
其中 C 表示總的類別數。
總結
交叉熵損失把每個像素都當作一個獨立樣本進行預測,而 dice loss 和 iou loss 則以一種更“整體”的方式來看待最終的預測輸出。
這兩類損失是針對不同情況,各有優點和缺點,在實際應用中,可以同時使用這兩類損失來進行互補。
總結
- 上一篇: 基于matlab的杨氏双缝干涉模拟仿真+
- 下一篇: Logit Adjust