第五章functions.py中的交叉熵代码解释
functions中的一個交叉熵的代碼是這樣的:
def cross_entropy_error(y, t):if y.ndim == 1:t = t.reshape(1, t.size)y = y.reshape(1, y.size)if t.size == y.size:t = t.argmax(axis=1) #① batch_size = y.shape[0]#這里batch_size是100return -np.sum(np.log(y[np.arange(batch_size), t] + 1e-7)) / batch_size我們來仔細對照下這個函數到底怎么回事。
根據[1]:
L(w)=1N∑n=1NH(pn,qn)L(w)=\frac{1}{N}\sum_{n=1}^NH(p_n,q_n)L(w)=N1?n=1∑N?H(pn?,qn?)
=?1N∑n=1N[pnlogqn+(1?pn)log(1?qn)](這個形式僅僅適用于二分類)=-\frac{1}{N}\sum_{n=1}^N[p_nlog q_n+(1-p_n)log(1-q_n)](這個形式僅僅適用于二分類)=?N1?n=1∑N?[pn?logqn?+(1?pn?)log(1?qn?)](這個形式僅僅適用于二分類)
=?1N∑n=1N∑k=12[pnk?logqnk]=-\frac{1}{N}\sum_{n=1}^N\sum_{k=1}^2[p_{nk}·log\ q_{nk}]=?N1?n=1∑N?k=1∑2?[pnk??log?qnk?]
| batch_size | NNN | |
| np.sum | ∑\sum∑ | |
| pnp_npn? | 代碼①處等號右側的ttt | 因為這個ttt中的元素非0即1,所以在代碼中沒有顯式體現 |
| qnq_nqn? | yyy |
注意:
課本P87的(4.2):
E=?∑ktklogykE=-\sum_kt_klog\ y_kE=?k∑?tk?log?yk?
針對的是單條數據的預測,這里的k指的是單條數據的第kkk個元素,對應神經網絡的第kkk個輸出端口。
課本P89的(4.3):
E=?1N∑n∑ktnklogynkE=-\frac{1}{N}\sum_n \sum_k t_{nk}\ log\ y_{nk}E=?N1?n∑?k∑?tnk??log?ynk?
針對的是NNN條數據,這里的n∈[1,N]n∈[1,N]n∈[1,N]
Reference:
[1]https://blog.csdn.net/appleyuchi/article/details/86497288
總結
以上是生活随笔為你收集整理的第五章functions.py中的交叉熵代码解释的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: softmax函数上溢出和下溢出(转载+
- 下一篇: 第六章插图以及代码文件和插图之间的对应关