pytorch几种损失函数CrossEntropyLoss、NLLLoss、BCELoss、BCEWithLogitsLoss、focal_loss、heatmap_loss
分類問題常用的幾種損失,記錄下來備忘,后續不斷完善。
nn.CrossEntropyLoss()交叉熵損失
常用于多分類問題
CE = nn.CrossEntropyLoss() loss = CE(input,target)Input: (N, C) , dtype: float, N是樣本數量,在批次計算時通常就是batch_size
target: (N), dtype: long,是類別號,0 ≤ targets[i] ≤ C?1
pytorch中的交叉熵損失就是softmax和NLL損失的組合,即
nn.NLLLoss()
NLL = nn.NLLLoss() loss = NLL(input,target)Input: (N, C) , dtype: float, N是樣本數量,在批次計算時通常就是batch_size
target: (N), dtype: long,是類別號,0 ≤ targets[i] ≤ C?1
nn.BCELoss() 二元交叉熵損失
常用于二分類或多標簽分類
BCE = nn.BCELoss() loss = BCE(input,target)Input: (N, x) , dtype: float, N是樣本數量,在批次計算時通常就是batch_size,x是標簽數
target: (N, x), dtype: float,通常是標簽的獨熱碼形式,注意需改成float格式
nn.BCEWithLogitsLoss()
相當于BCE加上sigmoid
nn.BCEWithLogitsLoss()(input,target) == nn.BCELoss()(torch.sigmoid(input),target)focal_loss
focal loss在pytorch中沒有,它常用在目標檢測問題中,公式和曲線見論文中的圖:
帶平衡參數的focal loss公式如下:
代碼:(待后補)
heatmap_loss
heatmap_loss出現在anchor-free的目標檢測網絡centernet和conernet中,它在focal loss的基礎上進一步改進,加入了對熱點區域的損失減小的措施,以使模型輸出可以較容易的收斂到檢測點附件區域。(否則,必須收斂到檢測點的話,難度太大,收斂速度慢)
注意,它只是在otherwise情況下多加了一個 (1?Yxyc)β(1-Y_{xyc})^\beta(1?Yxyc?)β 除此之外,就是focal loss
總結
以上是生活随笔為你收集整理的pytorch几种损失函数CrossEntropyLoss、NLLLoss、BCELoss、BCEWithLogitsLoss、focal_loss、heatmap_loss的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 好玩的deep dream(清晰版,py
- 下一篇: 用pytorch及numpy计算成对余弦