keras 的 example 文件 mnist_siamese.py 解析
該程序是介紹了一個(gè)孿生神經(jīng)網(wǎng)絡(luò),大致就是給出兩張圖片,比較兩張圖片的相似性,比如人臉對(duì)比等
這里的數(shù)據(jù)集是mnist,代碼中首先會(huì)建立一些圖片對(duì),就是pairs,如果是同類的圖片,則把y值設(shè)置為 1,如果是不同類的圖片,則把 y 值設(shè)置為 0;y值就是相似度
?
基礎(chǔ)神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)為:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 28, 28) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 784) 0
_________________________________________________________________
dense_1 (Dense) (None, 128) 100480
_________________________________________________________________
dropout_1 (Dropout) (None, 128) 0
_________________________________________________________________
dense_2 (Dense) (None, 128) 16512
_________________________________________________________________
dropout_2 (Dropout) (None, 128) 0
_________________________________________________________________
dense_3 (Dense) (None, 128) 16512
=================================================================
Total params: 133,504
Trainable params: 133,504
Non-trainable params: 0
_________________________________________________________________
None
就是給一個(gè)數(shù)字圖片,經(jīng)過(guò)一堆計(jì)算之后,會(huì)生成一個(gè)128維的向量,
然后計(jì)算兩張圖片的運(yùn)算結(jié)果向量之間的歐氏距離,也就是函數(shù)?euclidean_distance,向量平方和再開(kāi)方,其值肯定是一個(gè)整數(shù)
我們猜想一下,如果兩張圖片,送的數(shù)據(jù)一模一樣,那么其距離結(jié)果肯定為 0,如果完全不同的兩張圖片,那么運(yùn)算結(jié)果也應(yīng)該是一個(gè)很大的數(shù)
?
?
下面是一個(gè)自定義的損失函數(shù)?contrastive_loss,損失函數(shù)的參數(shù)為 預(yù)期值 和 實(shí)際運(yùn)算結(jié)果;
這里再?gòu)?qiáng)調(diào)一下,y_pred 的值肯定不會(huì)小于 0,因?yàn)槭瞧椒胶?#xff0c;再開(kāi)方,肯定大于等于 0
損失函數(shù)的返回值為:
return y_true * square_pred + (1 - y_true) * margin_square
因?yàn)?y_true 是我們?cè)跀?shù)據(jù)集預(yù)處理中設(shè)置的,其值可能為 0 ,或 1,
如果 y_true 為 1,也就是兩張圖片預(yù)期一樣時(shí),則返回結(jié)果可以簡(jiǎn)化為?square_pred
square_pred 就是歐氏距離;
如果 y_true 為 0, 也就是兩張圖片預(yù)期不一樣時(shí),則返回結(jié)果可以簡(jiǎn)化為?margin_square;
margin_square = K.square(K.maximum(margin - y_pred, 0))
損失函數(shù)就是,在結(jié)果符合預(yù)期時(shí),損失值很小,在不符合預(yù)期時(shí),損失值很大;
margin_square 就是 (1 - 歐氏距離)的平方,比如傳了一個(gè)1,一個(gè)2,那么歐氏距離應(yīng)該是很大,比如 y_pred 值為0.9,那就是符合預(yù)期,(1 - y_pred)**2,結(jié)果為 0.01,損失就很小;
而如果傳了不同的兩個(gè)值,一個(gè)1,一個(gè)2,結(jié)果歐式距離很小,比如 y_pred 值為 0.01,那么損失函數(shù)應(yīng)該很大,這里確實(shí)很大,0.99 ** 2,差不多為 1 了,需要神經(jīng)網(wǎng)絡(luò)進(jìn)行優(yōu)化;
但如果傳入的是兩個(gè)不同的圖片,計(jì)算結(jié)果?y_pred 為 10,那其實(shí)也是符合預(yù)期,反正兩張不同的圖片 y_pred 越大越好,那么這時(shí)候計(jì)算?K.maximum(margin - y_pred, 0) 就是?K.maximum(-9, 0),也就是 0,損失為0
?
其他就沒(méi)有什么難點(diǎn)了
總結(jié)
以上是生活随笔為你收集整理的keras 的 example 文件 mnist_siamese.py 解析的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: keras 的 example 文件 m
- 下一篇: keras 的 example 文件 m