tensorflow实现宝可梦数据集迁移学习
目錄
一、遷移學(xué)習(xí)簡(jiǎn)介
二、構(gòu)建預(yù)訓(xùn)練模型
1、調(diào)用內(nèi)置模型
2、修改模型
3、構(gòu)建模型
三、導(dǎo)入數(shù)據(jù)和預(yù)處理
1、設(shè)置batch size
2、讀取訓(xùn)練數(shù)據(jù)
3、讀取驗(yàn)證數(shù)據(jù)
4、讀取測(cè)試數(shù)據(jù)
5、預(yù)處理
四、模型訓(xùn)練
1、設(shè)置early_stopping
2、模型編譯
3、模型設(shè)置
4、模型評(píng)估
5、保存訓(xùn)練權(quán)重
五、模型預(yù)測(cè)
1、構(gòu)建預(yù)測(cè)模型
2、導(dǎo)入權(quán)重
3、預(yù)測(cè)
4、對(duì)比分析
一、遷移學(xué)習(xí)簡(jiǎn)介
遷移學(xué)習(xí)就是把預(yù)先定義好的模型,以及該模型在對(duì)應(yīng)數(shù)據(jù)集上訓(xùn)練得到的參數(shù)遷移到新的模型,用來(lái)幫助新模型訓(xùn)練。通過(guò)遷移學(xué)習(xí)我們可以將模型已經(jīng)學(xué)到的參數(shù),分享給新模型從而加快并優(yōu)化模型的學(xué)習(xí)效率,從而不用像大多數(shù)網(wǎng)絡(luò)那樣從零開(kāi)始學(xué)習(xí)。對(duì)于小樣本學(xué)習(xí)的也可以減少過(guò)擬合或者欠擬合問(wèn)題。
遷移學(xué)習(xí)的幾種實(shí)現(xiàn)方式:
Transfer Learning:凍結(jié)預(yù)訓(xùn)練模型的全部卷積層,只訓(xùn)練自己定制的全連接層。
Extract Feature Vector:先計(jì)算出預(yù)訓(xùn)練模型的卷積層對(duì)所有訓(xùn)練和測(cè)試數(shù)據(jù)的特征向量,然后拋開(kāi)預(yù)訓(xùn)練模型,只訓(xùn)練自己定制的簡(jiǎn)配版全連接網(wǎng)絡(luò)。
Fine-tuning:凍結(jié)預(yù)訓(xùn)練模型的部分卷積層(通常是靠近輸入的多數(shù)卷積層,因?yàn)檫@些層保留了大量底層信息)甚至不凍結(jié)任何網(wǎng)絡(luò)層,訓(xùn)練剩下的卷積層(通常是靠近輸出的部分卷積層)和全連接層。
二、構(gòu)建預(yù)訓(xùn)練模型
1、調(diào)用內(nèi)置模型
調(diào)用tensorflow內(nèi)置VGG19模型,下載該模型在"imagenet"數(shù)據(jù)集上預(yù)訓(xùn)練權(quán)重
net = keras.applications.VGG19(weights='imagenet', include_top=False,
???????????????????????????????pooling='max')
2、修改模型
凍結(jié)卷積層,將全連接層修改為自定義數(shù)據(jù)集對(duì)應(yīng)分類(lèi)數(shù)。
net.trainable = False
newnet = keras.Sequential([
???? net,
???? layers.Dense(5)
])
3、構(gòu)建模型
newnet.build(input_shape=(4,224,224,3))
newnet.summary()
三、導(dǎo)入數(shù)據(jù)和預(yù)處理
1、設(shè)置batch size
根據(jù)模型參數(shù)量和硬件環(huán)境設(shè)定batch size大小
batchsz = 128
2、讀取訓(xùn)練數(shù)據(jù)
images, labels, table = load_pokemon('pokemon',mode='train')
db_train = tf.data.Dataset.from_tensor_slices((images, labels))
db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz)
3、讀取驗(yàn)證數(shù)據(jù)
images2, labels2, table = load_pokemon('pokemon',mode='val')
db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))
db_val = db_val.map(preprocess).batch(batchsz)
4、讀取測(cè)試數(shù)據(jù)
images3, labels3, table = load_pokemon('pokemon',mode='test')
db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))
db_test = db_test.map(preprocess).batch(batchsz)
5、預(yù)處理
def preprocess(x,y):
???? # x: 圖片的路徑,y:圖片的數(shù)字編碼
???? x = tf.io.read_file(x)
???? x = tf.image.decode_jpeg(x, channels=3)
???? x = tf.image.resize(x, [244, 244])
???? x = tf.image.random_flip_up_down(x)
???? x = tf.image.random_crop(x, [224,224,3])
????x = tf.cast(x, dtype=tf.float32) / 255.
??? x = normalize(x)
???? y = tf.convert_to_tensor(y)
???? y = tf.one_hot(y, depth=5)
???? return x, y
四、模型訓(xùn)練
1、設(shè)置early_stopping
為防止過(guò)擬合,這里使用early_stopping,當(dāng)模型在驗(yàn)證集上精度變化在min_delta以?xún)?nèi),并且持續(xù)次數(shù)達(dá)到patience以后,模型訓(xùn)練即停止。
early_stopping = EarlyStopping(
???? monitor='val_accuracy',
???? min_delta=0.001,
???? patience=5
)
2、模型編譯
設(shè)置優(yōu)化器,損失函數(shù)和精度衡量標(biāo)準(zhǔn)
newnet.compile(optimizer=optimizers.Adam(lr=1e-3),
???????????????loss=losses.CategoricalCrossentropy(from_logits=True),
???????????????metrics=['accuracy'])
3、模型設(shè)置
設(shè)置訓(xùn)練集,驗(yàn)證集,驗(yàn)證頻率,迭代次數(shù)以及回調(diào)函數(shù)
newnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=20,
???????????callbacks=[early_stopping])
4、模型評(píng)估
訓(xùn)練結(jié)束后,使用evaluate函數(shù)進(jìn)行模型評(píng)估,了解模型最終精度情況。
newnet.evaluate(db_test)
5、保存訓(xùn)練權(quán)重
newnet.save_weights('weights.ckpt')
五、模型預(yù)測(cè)
1、構(gòu)建預(yù)測(cè)模型
net = keras.applications.VGG19(weights='imagenet', include_top=False,
???????????????????????????????pooling='max')
net.trainable = False
model= keras.Sequential([
???? net,
???? layers.Dense(5)
])
model.build(input_shape=(4,224,224,3))
2、導(dǎo)入權(quán)重
model.load_weights('weights.ckpt')
3、預(yù)測(cè)
logits = newnet.predict(x)
prob = tf.nn.softmax(logits, axis=1)
print(prob)
max_prob_index = np.argmax(prob, axis=-1)[0]
prob = prob.numpy()
max_prob = prob[0][max_prob_index]
print(max_prob)
max_index = np.argmax(logits, axis=-1)[0]
name = ['妙蛙種子', '小火龍', '超夢(mèng)', '皮卡丘', '杰尼龜']
print(name[max_index])
測(cè)試圖像:
?
預(yù)測(cè)結(jié)果:
tf.Tensor([[0.78470963 0.09179451 0.03650109 0.01834733 0.06864741]], shape=(1, 5), dtype=float32)
0.78470963
妙蛙種子
4、對(duì)比分析
使用同樣測(cè)試圖像在沒(méi)有進(jìn)行遷移學(xué)習(xí)訓(xùn)練的模型上進(jìn)行測(cè)試,輸出結(jié)果:
tf.Tensor([[0.46965462 0.0470721 ?0.20003504 0.11915307 0.16408516]], shape=(1, 5), dtype=float32)
0.46965462
妙蛙種子
從結(jié)果上看,兩個(gè)模型都能準(zhǔn)確預(yù)測(cè),但輸出的分類(lèi)概率(遷移學(xué)習(xí)0.7847,非遷移學(xué)習(xí)0.4696),兩者存在明顯差別,可以看出使用遷移學(xué)習(xí)能夠達(dá)到更好的擬合效果。
總結(jié)
以上是生活随笔為你收集整理的tensorflow实现宝可梦数据集迁移学习的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 深度学习tensorflow实现宝可梦图
- 下一篇: 使用OpenCV进行多边形绘制和填充