matlab手写神经网络实现识别手写数字
實(shí)驗(yàn)說(shuō)明
一直想自己寫(xiě)一個(gè)神經(jīng)網(wǎng)絡(luò)來(lái)實(shí)現(xiàn)手寫(xiě)數(shù)字的識(shí)別,而不是套用別人的框架。恰巧前幾天,有幸從同學(xué)那拿到5000張已經(jīng)貼好標(biāo)簽的手寫(xiě)數(shù)字圖片,于是我就嘗試用matlab寫(xiě)一個(gè)網(wǎng)絡(luò)。
實(shí)驗(yàn)數(shù)據(jù):5000張手寫(xiě)數(shù)字圖片(.jpg),圖片命名為1.jpg,2.jpg…5000.jpg。還有一個(gè)放著標(biāo)簽的excel文件。
數(shù)據(jù)處理:前4000張作為訓(xùn)練樣本,后1000張作為測(cè)試樣本。
圖片處理:用matlab的imread()函數(shù)讀取圖片的灰度值矩陣(28,28),然后把每張圖片的灰度值矩陣reshape為(28*28,1),然后把前4000張圖片的灰度值矩陣合并為x_train,把后1000張圖片的灰度值矩陣合并為x_test。
神經(jīng)網(wǎng)絡(luò)設(shè)計(jì)
網(wǎng)絡(luò)層設(shè)計(jì):一層隱藏層,一層輸出層
輸入層:一張圖片的灰度值矩陣reshape后的784個(gè)數(shù),也就是x_train中的某一列
輸出層:(10,1)的列向量,其中列向量中最大的數(shù)所在的索引就是預(yù)測(cè)的數(shù)字
激勵(lì)函數(shù):sigmoid函數(shù)(公式)
更新法則:后向傳播算法(參考)
一點(diǎn)說(shuō)明:這里的訓(xùn)練我分別用了普通梯度下降法和mini_batch(batch size 為10)梯度下降法來(lái)實(shí)現(xiàn)
測(cè)試:用了兩種方式表示正確率,一是統(tǒng)計(jì)預(yù)測(cè)正確的個(gè)數(shù),而是利用matlab的plotconfusion函數(shù)
網(wǎng)絡(luò)實(shí)現(xiàn)
全部實(shí)現(xiàn)包括5個(gè)函數(shù)(gedata.m / layerout.m / mytrain.m / mytrain_mini.m / test.m)和一個(gè)main.m文件。
讀取數(shù)據(jù)(getdata.m)
function[x_train,y_train,x_test,y_test]=getdata() %把圖片變成像素矩陣 %path :圖片路徑 % x_train:訓(xùn)練樣本像素矩陣(784,4000) %y_train:訓(xùn)練樣本標(biāo)簽(10,4000) %x_test:測(cè)試樣本像素矩陣(784,1000) %y_test:測(cè)試樣本標(biāo)簽(10,1000)% photopath = './photo/'; % snames=dir([photopath '*' '.jpg'])%get all filenames in photopath % l = length(snames) % % %get x_ data % x_train = []; % x_test = []; % % for i=1:4000 % iname=[photopath snames(i).name] %the path of jpg % x = imread(iname); % the shape of x is (28,28) % x = reshape(x,784,1); %reshape x to (784,1) % x_train = [x_train,x]; % end % % for k=4001:5000 % kname=[photopath snames(k).name]; %the path of jpg % x = imread(kname); %the shape of x is (28,28) % x = reshape(x,784,1); %reshape x to (784,1) % x_test = [x_test,x]; % endx_train=[];for i=1:4000x=im2double(imread(strcat(num2str(i),'.jpg')));x=reshape(x,784,1);x_train=[x_train,x]; end x_test =[];for k=4001:5000x=im2double(imread(strcat(num2str(k),'.jpg')));x=reshape(x,784,1);x_test=[x_test,x]; end data=xlsread('label.xlsx'); y_train=data(:,1:4000); y_test = data(:,4001:5000);x_train; y_train; x_test; y_test;end這里踩了一個(gè)坑。我本來(lái)讀取圖片,是按目錄來(lái)讀取的,然后訓(xùn)練出來(lái)的效果一直不好。一度懷疑自己的更新函數(shù)寫(xiě)錯(cuò)了,改了很久,才發(fā)現(xiàn)按目錄讀取的圖片順序是錯(cuò)誤的!按目錄讀取的圖片并不是按1,2,3…這樣讀的,而是按下面的順序讀取的,這樣就和label對(duì)不上了!!!
layerout函數(shù)
function [y] = layerout(w,b,x) %output function y = w*x + b; n = length(y); for i =1:ny(i)=1.0/(1+exp(-y(i))); end y; end訓(xùn)練一(mytrain.m)
function[w,b,w_h,b_h]=mytrain(x_train,y_train) %train function:設(shè)置一個(gè)隱藏層,784-->隱藏層神經(jīng)元個(gè)數(shù)-->10 %x_train:訓(xùn)練樣本的像素?cái)?shù)據(jù) %y_train:訓(xùn)練樣本的標(biāo)簽 %w:輸出層權(quán)重 %b:輸出層偏置 %w_h:隱藏層權(quán)重 %b_h:隱藏層偏置 %step:循環(huán)步數(shù)step=input('迭代步數(shù):'); a=input('學(xué)習(xí)因子:'); in = 784; %輸入神經(jīng)元個(gè)數(shù) hid = input('隱藏層神經(jīng)元個(gè)數(shù):');%隱藏層神經(jīng)元個(gè)數(shù) out = 10; %輸出層神經(jīng)元個(gè)數(shù) o =1;w = randn(out,hid); b = randn(out,1); w_h =randn(hid,in); b_h = randn(hid,1);for i=0:step%打亂訓(xùn)練樣本r=randperm(4000);x_train = x_train(:,r);y_train = y_train(:,r);for j=1:4000x = x_train(:,j);y = y_train(:,j);hid_put = layerout(w_h,b_h,x);out_put = layerout(w,b,hid_put);%更新公式的實(shí)現(xiàn)o_update = (y-out_put).*out_put.*(1-out_put);h_update = ((w')*o_update).*hid_put.*(1-hid_put);outw_update = a*(o_update*(hid_put'));outb_update = a*o_update;hidw_update = a*(h_update*(x'));hidb_update = a*h_update;w = w + outw_update;b = b+ outb_update;w_h = w_h +hidw_update;b_h =b_h +hidb_update;end end end訓(xùn)練二(mytrain_mini.m)
function[w,b,w_h,b_h]=mytrain_mini(x_train,y_train) %train function:設(shè)置一個(gè)隱藏層,784-->隱藏層神經(jīng)元個(gè)數(shù)-->10 %x_train:訓(xùn)練樣本的像素?cái)?shù)據(jù) %y_train:訓(xùn)練樣本的標(biāo)簽 %w:輸出層權(quán)重 %b:輸出層偏置 %w_h:隱藏層權(quán)重 %b_h:隱藏層偏置 %step:循環(huán)步數(shù)step=ipout('迭代步數(shù):'); a=input('學(xué)習(xí)因子:'); in = 784; %輸入神經(jīng)元個(gè)數(shù) hid = input('隱藏層神經(jīng)元個(gè)數(shù):');%隱藏層神經(jīng)元個(gè)數(shù) out = 10; %輸出層神經(jīng)元個(gè)數(shù) o =1;w = randn(out,hid); b = randn(out,1); w_h =randn(hid,in); b_h = randn(hid,1);for i=0:step%打亂訓(xùn)練樣本r=randperm(4000);x_train = x_train(:,r);y_train = y_train(:,r);%mini_batchfor jj=0:399%取batch為10 更新取10次的平均值for j=jj*10+1:(jj+1)*10x = x_train(:,j);y = y_train(:,j);hid_put = layerout(w_h,b_h,x);out_put = layerout(w,b,hid_put);%更新公式的實(shí)現(xiàn)o_update = (y-out_put).*out_put.*(1-out_put);h_update = ((w')*o_update).*hid_put.*(1-hid_put);if j==1outw_update = (double(a)/10)*(o_update*(hid_put'));outb_update = (double(a)/10)*o_update;hidw_update = (double(a)/10)*(h_update*(x'));hidb_update = (double(a)/10)*h_update;endif j~=1outw_update = outw_update + (double(a)/10)*(o_update*(hid_put'));outb_update = outb_update -(double(a)/10)*o_update;hidw_update = hidw_update + (double(a)/10)*(h_update*(x'));hidb_update = hidb_update -(double(a)/10)*h_update;endendw = w + outw_update;b = b+ outb_update;w_h = w_h +hidw_update;b_h =b_h +hidb_update;end end end測(cè)試(mytest.m)
function[]= mytest(x_test,y_test,w,b,w_h,b_h) %x_test:測(cè)試樣本的像素?cái)?shù)據(jù) %y_test:測(cè)試樣本的標(biāo)簽 %w:輸出層權(quán)重 %b:輸出層偏置 %w_h:隱藏層權(quán)重 %b_h:隱藏層偏置test = zeros(10,1000); for k=1:1000x = x_test(:,k);hid = layerout(w_h,b_h,x);test(:,k)=layerout(w,b,hid);%正確率表示方式一:輸出正確個(gè)數(shù)[t,t_index]=max(test);[y,y_index]=max(y_test);sum = 0;for p=1:length(t_index)if t_index(p)==y_index(p)sum =sum+1;endend endfprintf('正確率: %d/1000\n',sum);%正確率表示方式二:用plotconfusion函數(shù) plotconfusion(y_test,test); endmain.m
[x_train,y_train,x_test,y_test]=getdata();%歸一化 x_train = mapminmax(x_train,0,1); x_test =mapminmax(x_test,0,1);[w1,b1,w_h1,b_h1]=mytrain(x_train,y_train); fprintf('mytrain正確率:\n'); mytest(x_test,y_test,w1,b1,w_h1,b_h1);[w2,b2,w_h2,b_h2]=mytrain(x_train,y_train); fprintf('mytrain_mini正確率:\n'); mytest(x_test,y_test,w2,b2,w_h2,b_h2);實(shí)驗(yàn)結(jié)果
直接運(yùn)行main.m,且兩個(gè)訓(xùn)練方式都輸入相同參數(shù),得到結(jié)果如下:
下面是mini_batch的plotconfusion結(jié)果,mytrain的也差不多。其中綠色的為正確率:
直觀感覺(jué)min_batch方式的訓(xùn)練會(huì)快一丟丟。由于這里數(shù)據(jù)不多,所以兩者的差別看不大出來(lái)!
總結(jié)
以上是生活随笔為你收集整理的matlab手写神经网络实现识别手写数字的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 2020hdu多校6
- 下一篇: matlab人脸追踪,求大神帮助我这个菜