PyTorch之torch.nn.CrossEntropyLoss()
簡(jiǎn)介
信息熵: 按照真實(shí)分布p來(lái)衡量識(shí)別一個(gè)樣本所需的編碼長(zhǎng)度的期望,即平均編碼長(zhǎng)度
交叉熵: 使用擬合分布q來(lái)表示來(lái)自真實(shí)分布p的編碼長(zhǎng)度的期望,即平均編碼長(zhǎng)度
多分類任務(wù)中的交叉熵?fù)p失函數(shù)
代碼
1)導(dǎo)入包
import torch import torch.nn as nn2)準(zhǔn)備數(shù)據(jù)
在圖片單標(biāo)簽分類時(shí),輸入m張圖片,輸出一個(gè)m x N的Tensor,其中N是分類個(gè)數(shù)。比如輸入3張圖片,分三類,最后的輸出是一個(gè)3 x 3的Tensor,舉個(gè)例子:
3)計(jì)算概率分布
第123行分別是第123張圖片的結(jié)果,假設(shè)第123列分別是貓、狗和豬的分類得分。
然后對(duì)每一行使用Softmax,這樣可以得到每張圖片的概率分布。
這里dim的意思是計(jì)算Softmax的維度,這里設(shè)置dim=1,可以看到每一行的加和為1。比如第一行0.1022+0.3831+0.5147=1。
4)對(duì)Softmax的結(jié)果取自然對(duì)數(shù)
log_output=torch.log(soft_output) print('log_output:\n',log_output)
對(duì)比softmax與log的結(jié)合與nn.LogSoftmaxloss(負(fù)對(duì)數(shù)似然損失)的輸出結(jié)果,兩者是一致的。
5)NLLLoss
NLLLoss的結(jié)果就是把上面的輸出與y_label對(duì)應(yīng)的那個(gè)值拿出來(lái),再去掉負(fù)號(hào),再求均值。
y_target中[1, 2, 0]對(duì)應(yīng)上述第一行的第二個(gè),第二行的第三個(gè),第三行的第1個(gè):
(0.9594+0.4241+0.5265)/3=0.6367
6) CrossEntropyLoss()
參考鏈接:
https://blog.csdn.net/qq_22210253/article/details/85229988
https://zhuanlan.zhihu.com/p/98785902
https://zhuanlan.zhihu.com/p/56638625
總結(jié)
以上是生活随笔為你收集整理的PyTorch之torch.nn.CrossEntropyLoss()的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Chrome浏览器密码框自动填充的bug
- 下一篇: Servlet的认识