使用fastai完成图像分类
by Wenqi Sun
1 min read
Categories
Deep Learning
Tags
FastaiCNNApplication
1. 使用現(xiàn)有數(shù)據(jù)集進行分類
圖像數(shù)據(jù)為Oxford-IIIT Pet Dataset(12類貓和25類狗,共37類),這里僅使用原始圖片集images.tar.gz
數(shù)據(jù)準(zhǔn)備
import numpy as np
from fastai.vision import *
from fastai.metrics import error_rate
path_img = 'data/pets/images'
bs = 64 #batch size
fnames = get_image_files(path_img) #get filenames(absolute path) from path_img
pat = re.compile(r'/([^/]+)_d+.jpg$') #get labels from filenames(e.g., 'american_bulldog' from 'data/pets/images/american_bulldog_20.jpg')
### ImageDataBunch
### 使用正則表達式pat從圖像文件名fnames中提取標(biāo)簽,并和圖像對應(yīng)起來
### ds_tfms: 圖像轉(zhuǎn)換(翻轉(zhuǎn)、旋轉(zhuǎn)、裁剪、放大等),用于圖像數(shù)據(jù)增強(data augmentation)
### size: 最終圖像尺寸, bs: batch size, valid_pct: train/valid split
### normalize: 使用提供的均值和標(biāo)準(zhǔn)差(每個通道對應(yīng)一個均值和標(biāo)準(zhǔn)差)對圖像數(shù)據(jù)進行歸一化
np.random.seed(2)
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=bs, valid_pct=0.2).normalize(imagenet_stats)
data.show_batch(rows=3, figsize=(7,6)) #grab a batch and display 3x3 images
模型搭建和訓(xùn)練
使用Resnet34進行遷移學(xué)習(xí),首先通過lr_find確定最大學(xué)習(xí)率,再通過fit_one_cycle(1-Cycle style)進行訓(xùn)練
lr_find: 在前面幾次的迭代中將學(xué)習(xí)率從一個很小的值逐漸增加,選擇損失函數(shù)(train loss)處于下降趨勢之中并且距離損失停止下降的拐點有一定距離的點做為模型的最大學(xué)習(xí)率max_lr
fit_one_cycle: 共分為兩個階段,在第一階段學(xué)習(xí)率從max_lr/div_factor線性增長到max_lr,momentum線性地從moms[0]降到moms[1];第二階段學(xué)習(xí)率以余弦形式從max_lr降為0,momentum也同樣按余弦形式從moms[1]增長到moms[0]。第一階段的迭代次數(shù)占總迭代次數(shù)的比例為pct_start
學(xué)習(xí)率和momentum: , , , 其中是要更新的參數(shù),G為梯度, 為學(xué)習(xí)率, 為momentum
### Use Resnet34 to classify images
learn = create_cnn(data, models.resnet34, metrics=error_rate)
print(learn.model) #model summary
learn.lr_find()
learn.recorder.plot() #由左上圖可以看出max_lr可選擇函數(shù)fit_one_cycle的默認(rèn)值0.003
learn.fit_one_cycle(4, max_lr=slice(0.003), div_factor=25.0, moms=(0.95, 0.85), pct_start=0.3) #4 epochs
learn.recorder.plot_lr(show_moms=True) #中上圖(學(xué)習(xí)率)和右上圖(momentum), x軸表示迭代次數(shù)
learn.save('stage-1') #save model
### Unfreeze all the model layers and keep training
learn.unfreeze()
learn.lr_find()
learn.recorder.plot() #左下圖
### 由左下圖可以看出max_lr可選擇1e-6, 但是模型的不同層可以設(shè)置不同的學(xué)習(xí)率加速訓(xùn)練
### 模型的前面幾層的學(xué)習(xí)率設(shè)置為max_lr, 后面幾層的學(xué)習(xí)率可以適當(dāng)增加(例如可以設(shè)置成比上一個fit_one_cycle的學(xué)習(xí)率小一個量級)
### slice(1e-6,1e-4)表示模型每層的學(xué)習(xí)率由1e-6逐漸增加過渡到1e-4
learn.fit_one_cycle(2, max_lr=slice(1e-6,1e-4), div_factor=25.0, moms=(0.95, 0.85), pct_start=0.3) #2 epochs
learn.recorder.plot_lr(show_moms=True) #中下圖(模型最后一層的學(xué)習(xí)率)和右下圖(momentum)
可視化
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix(figsize=(12,12), dpi=60) #confusion matrix
print(interp.most_confused(min_val=2)) #從大到小列出混淆矩陣中非對角線的最大的幾個元素
2. 從谷歌圖片下載數(shù)據(jù)并進行分類
獲得圖片鏈接
打開谷歌圖片,輸入想要下載的圖像類別,頁面上出現(xiàn)的圖片即為可下載的圖片
打開JavaScript Console(Windows/Linux:Ctrl+Shift+J, Mac:Cmd+Opt+J),運行下面的命令獲取圖片鏈接
<span 大專欄 使用fastai完成圖像分類 class="nx">urls = Array.from(document.querySelectorAll('.rg_di .rg_meta')).map(el=>JSON.parse(el.textContent).ou);
window.open('data:text/csv;charset=utf-8,' + escape(urls.join('n')));
分別搜索teddy bears、 black bears、 grizzly bears, 將下載的保存鏈接的文件分別命名為urls_teddys.txt、 urls_black.txt、 urls_grizzly.txt
下載圖片
import numpy as np
from fastai.vision import *
from fastai.metrics import error_rate
### 建立目錄并下載圖片
path = Path('data/bears')
folders = ['teddys', 'black', 'grizzly']
files = 'urls_teddys.txt', 'urls_black.txt', 'urls_grizzly.txt'
for i,folder in enumerate(folders):
dest = path/folder
dest.mkdir(parents=True, exist_ok=True)
download_images(files[i], dest, max_pics=200)
print(path.ls())
### 刪除不能被打開的圖片
for folder in folders:
verify_images(path/folder, delete=True, max_size=500)
訓(xùn)練模型
np.random.seed(42)
data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2, ds_tfms=get_transforms(), size=224, bs=64, num_workers=4).normalize(imagenet_stats)
print(data.classes)
learn = create_cnn(data, models.resnet34, metrics=error_rate)
learn.lr_find()
learn.recorder.plot() #左圖
learn.fit_one_cycle(4)
learn.save('stage-1')
learn.unfreeze()
learn.lr_find()
learn.recorder.plot() #右圖
learn.fit_one_cycle(2, max_lr=slice(3e-5,3e-4)) #若數(shù)據(jù)量較小,該步不一定有正效果
learn.save('stage-2')
learn.load('stage-1') #選擇stage-1
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()
根據(jù)訓(xùn)練好的模型去除錯誤圖片
模型預(yù)測效果不好不一定是因為模型本身的問題,還可能是由于圖片自身的問題(例如下載了錯誤的圖片,圖片標(biāo)簽有誤),需要進行檢查和處理
from fastai.widgets import *
### ds: 訓(xùn)練圖片集, idxs: 具有最大損失的訓(xùn)練圖片索引
ds, idxs = DatasetFormatter().from_toplosses(learn, n_imgs=200) #選出前200個具有最大損失的訓(xùn)練圖片
ImageCleaner(ds, idxs, path) #手動處理,處理好的文件被存入path/cleaned.csv(該文件僅包含經(jīng)過處理后的訓(xùn)練圖片集,不包含驗證圖片)
可根據(jù)具體情況對處理之后的數(shù)據(jù)重新進行訓(xùn)練
保存模型并預(yù)測
learn.export() #將模型存入learn.path/export.pkl
learn = load_learner(path) #從path中讀取模型
img = open_image(path/'black'/'00000021.jpg') #以訓(xùn)練集中的一個圖片為例
pred_class,pred_idx,outputs = learn.predict(img) #預(yù)測圖片
print(pred_class) #輸出類別
print(outputs) #輸出每個類的概率
總結(jié)
以上是生活随笔為你收集整理的使用fastai完成图像分类的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: MYSQL | 最左匹配原则
- 下一篇: JPA 系列教程11-复合主键-2个@I