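Recognizing Handwriting with Java: Introduction to Neural Networks, Chapter 6: Handwriting Recognition
Preface
Neural networks are a rather special way of solving problems. This book takes the simplest and most accessible route, starting from the very basics and working step by step with the reader toward an understanding of the fundamental neural network algorithms. It tries to avoid intimidating terminology and mathematical concepts, and puts the algorithms into practice by building runnable Java programs.
Follow the WeChat account "逻辑编程" for more information about this book.
In this chapter we solve a real problem: handwriting recognition. We will recognize handwritten digits like the ones in the figure below.
Before we start, we need to prepare the data. We cannot simply generate handwriting the way we generated data in earlier chapters; after all, we do not know how to write an algorithm that produces handwriting. Training to a useful accuracy also needs a fairly large amount of training data. Fortunately, earlier researchers planted the trees whose shade we now enjoy: they have already collected valuable training material. MNIST is one such widely used dataset. Besides using the data, we can see other people's recognition accuracy on the website, which gives us a good point of reference. MNIST contains a training set and a test set, written by different groups of people.
MNIST website: http://yann.lecun.com/exdb/mnist/
The dataset is stored in a specific binary file format, not in an ordinary image format. Each image consists of 28*28 pixels, and each pixel is one byte giving its grayscale level. The MNIST website describes the format in detail.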
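To make the binary layout more concrete, here is a minimal sketch that builds a tiny gzip-compressed image file in memory and reads its header back. The magic number 0x00000803 and the header order (magic, count, rows, columns, each a big-endian 32-bit integer) are the ones checked by the reader class listed later in this chapter; the class name IdxHeaderSketch and its methods are hypothetical and exist only for illustration.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

// Hypothetical helper, not part of the chapter's code.
class IdxHeaderSketch {

    // Builds a gzip-compressed image file in memory: a 16-byte header
    // followed by count * rows * cols pixel bytes (all zero here).
    static byte[] buildSample(int count, int rows, int cols) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(new GZIPOutputStream(bytes))) {
            out.writeInt(0x00000803); // magic number of an image file
            out.writeInt(count);      // number of images
            out.writeInt(rows);       // rows per image
            out.writeInt(cols);       // columns per image
            out.write(new byte[count * rows * cols]);
        } catch (IOException ex) {
            throw new UncheckedIOException(ex);
        }
        return bytes.toByteArray();
    }

    // Reads the header back and returns {magic, count, rows, cols}.
    static int[] readHeader(byte[] gz) {
        try (DataInputStream in = new DataInputStream(
                new GZIPInputStream(new ByteArrayInputStream(gz)))) {
            // DataInputStream.readInt() reads big-endian 32-bit integers.
            return new int[] {in.readInt(), in.readInt(), in.readInt(), in.readInt()};
        } catch (IOException ex) {
            throw new UncheckedIOException(ex);
        }
    }

    public static void main(String[] args) {
        int[] h = readHeader(buildSample(2, 28, 28));
        System.out.println("magic=0x" + Integer.toHexString(h[0])
                + " count=" + h[1] + " rows=" + h[2] + " cols=" + h[3]);
    }
}
```

A real MNIST image file would follow the same header with 28*28 bytes per image, which is why the reader can simply call readFully once per image.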
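We write a class that reads the dataset and provides methods returning the requested training or test data. We will not walk through that code; it is listed further below for readers to use. Before running it, download the data files and keep them in GZIP format. When run on its own, it randomly picks 20 images and writes them out as PNG files so that readers can inspect and verify the data themselves.
Next we write a test class to recognize handwritten digits. We use the 60000 MNIST training samples to train our neural network repeatedly. After each round of training we use the 10000 MNIST test samples to measure the recognition rate.
Here is the code: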
package com.luoxq.ann;

import java.util.Arrays;

public class MnistTest {

    public static void main(String... args) {
        // 784 inputs (one per pixel) and 10 outputs (one per digit).
        int[] shape = {28 * 28, 10};
        NeuralNetwork nn = new NeuralNetwork(shape);
        Mnist mnist = new Mnist();
        mnist.load();
        mnist.shuffle();
        System.out.println("Shape: " + Arrays.toString(shape));
        System.out.println("Initial correct rate: " + test(nn, mnist));
        int epochs = 1000;
        double rate = 0.5;
        System.out.println("Learning rate: " + rate);
        System.out.println("Epoch,Time,Correctness\n----------------------");
        long time = System.currentTimeMillis();
        Mnist.Data[] data = mnist.getTrainingSlice(0, 60000);
        for (int epoch = 1; epoch <= epochs; epoch++) {
            // One epoch: train once on every training sample.
            for (int sample = 0; sample < data.length; sample++) {
                nn.train(data[sample].input, data[sample].output, rate);
            }
            long seconds = (System.currentTimeMillis() - time) / 1000;
            System.out.println(epoch + ", " + seconds + ", " +
                    test(nn, mnist));
        }
    }

    // Returns how many of the 10000 test samples are classified correctly.
    private static int test(NeuralNetwork nn, Mnist mnist) {
        int correct = 0;
        Mnist.Data[] data = mnist.getTestSlice(0, 10000);
        for (int sample = 0; sample < data.length; sample++) {
            if (max(nn.f(data[sample].input)) == data[sample].label) {
                correct++;
            }
        }
        return correct;
    }

    // Returns the index of the largest element, i.e. the predicted digit.
    private static int max(double[] d) {
        double max = d[0];
        int idx = 0;
        for (int i = 1; i < d.length; i++) {
            if (max < d[i]) {
                max = d[i];
                idx = i;
            }
        }
        return idx;
    }
}
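Let us first try a single-layer neural network with 10 neurons. The result is surprisingly good: we quickly reach a recognition rate above 90%. A single-layer network does little more than simple statistics on the pixel distribution of each digit, so reaching such a high rate is quite remarkable. Once it passes 90%, further training has little effect; the network has saturated, and we need a different approach. In the output below, the Correctness column is the number of correctly recognized samples out of the 10000 test samples, and Time is the elapsed time in seconds.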
Shape: [784, 10]
Initial correct rate: 1373
Learning rate: 0.5
Epoch,Time,Correctness
----------------------
1, 4, 6429
2, 8, 7663
3, 13, 8963
4, 17, 9029
5, 22, 9016
6, 27, 9062
7, 31, 9063
8, 36, 9066
9, 41, 9072
10, 45, 9057
11, 50, 9084
12, 55, 9072
13, 61, 9062
14, 66, 9050
15, 70, 9077
16, 75, 9052
17, 79, 9068
18, 84, 9055
19, 88, 9060
20, 93, 9064
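To see why a single layer amounts to per-digit pixel statistics, here is a minimal sketch of what one fully connected layer computes: each output is a weighted sum over all pixels passed through an activation function, and the prediction is the index of the largest output. It assumes a sigmoid activation and one bias per neuron; the NeuralNetwork class used by MnistTest is not shown in this chapter and may differ, and the names SingleLayerSketch, forward and argMax are hypothetical.

```java
// Hypothetical illustration, not the chapter's NeuralNetwork class.
class SingleLayerSketch {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // One fully connected layer:
    // out[j] = sigmoid(bias[j] + sum over i of weights[j][i] * input[i])
    static double[] forward(double[][] weights, double[] bias, double[] input) {
        double[] out = new double[weights.length];
        for (int j = 0; j < weights.length; j++) {
            double z = bias[j];
            for (int i = 0; i < input.length; i++) {
                z += weights[j][i] * input[i];
            }
            out[j] = sigmoid(z);
        }
        return out;
    }

    // Index of the largest output, i.e. the predicted class.
    static int argMax(double[] d) {
        int idx = 0;
        for (int i = 1; i < d.length; i++) {
            if (d[i] > d[idx]) {
                idx = i;
            }
        }
        return idx;
    }

    public static void main(String[] args) {
        // Toy input with 4 "pixels" and 2 classes: class 0 rewards bright
        // pixels on the left, class 1 rewards bright pixels on the right.
        double[][] w = {{1, 1, -1, -1}, {-1, -1, 1, 1}};
        double[] b = {0, 0};
        double[] leftHeavy = {0.9, 0.8, 0.1, 0.0};
        System.out.println(argMax(forward(w, b, leftHeavy))); // prints 0
    }
}
```

With 784 inputs and 10 outputs, each row of the weight matrix plays the role of a per-digit template: pixels that are usually dark for that digit get positive weights, the others negative ones.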
So let us try a three-layer neural network. After trying several hidden-layer sizes and learning rates, I found the following reasonably good combination:
Shape: [784, 50, 10]
Initial correct rate: 944
Learning rate: 1.0
Epoch,Time,Correctness
----------------------
1, 24, 7459
2, 59, 9232
3, 99, 9313
4, 131, 9379
5, 153, 9412
6, 176, 9443
7, 200, 9412
8, 226, 9447
9, 248, 9462
10, 269, 9461
11, 290, 9465
12, 314, 9493
13, 343, 9477
14, 368, 9499
15, 392, 9502
16, 420, 9509
17, 447, 9482
18, 472, 9508
19, 496, 9491
20, 518, 9536
21, 545, 9523
22, 569, 9549
23, 593, 9527
24, 618, 9527
25, 643, 9520
26, 667, 9513
27, 689, 9507
28, 712, 9527
29, 734, 9501
30, 758, 9521
31, 781, 9508
32, 804, 9534
33, 827, 9534
34, 850, 9550
35, 875, 9569
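We quickly reach a recognition rate above 95%. Clearly a multi-layer network has an advantage over a single-layer one. This rate is still short of product quality, but for a first attempt the result is quite good.
Below is the source code for reading the MNIST files: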
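To put the two runs side by side, the following sketch counts the trainable parameters of each shape and converts the best counts in the two logs (9084 for [784, 10] and 9569 for [784, 50, 10]) into percentages. It assumes fully connected layers with one bias per neuron, which may not match the NeuralNetwork class exactly; the class name ShapeStats and its methods are hypothetical.

```java
// Hypothetical helper for comparing the two network shapes.
class ShapeStats {

    // Weights plus biases of a fully connected network with the given
    // layer sizes (assumes one bias per non-input neuron).
    static int parameterCount(int[] shape) {
        int total = 0;
        for (int layer = 1; layer < shape.length; layer++) {
            total += shape[layer - 1] * shape[layer] + shape[layer];
        }
        return total;
    }

    // Converts a count of correct answers into a percentage.
    static double percent(int correct, int total) {
        return 100.0 * correct / total;
    }

    public static void main(String[] args) {
        System.out.println(parameterCount(new int[] {784, 10}));     // prints 7850
        System.out.println(parameterCount(new int[] {784, 50, 10})); // prints 39760
        System.out.println(percent(9084, 10000)); // best single-layer result in the log
        System.out.println(percent(9569, 10000)); // best three-layer result in the log
    }
}
```

Under these assumptions the three-layer network has roughly five times as many parameters, which is in line with each epoch taking noticeably longer in the second log.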
package com.luoxq.ann;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.util.Random;
import java.util.zip.GZIPInputStream;

/**
 * Created by luoxq on 17/4/15.
 */
public class Mnist {

    static class Data {
        public byte[] data;     // raw pixels, one unsigned byte per pixel
        public int label;       // the digit, 0 to 9
        public double[] input;  // pixels scaled to the range 0..1
        public double[] output; // expected network output (1 at index label, 0 elsewhere)
    }

    // Loads the data and writes 20 random training images as PNG files.
    public static void main(String... args) throws Exception {
        Mnist mnist = new Mnist();
        mnist.load();
        System.out.println("Data loaded.");
        Random rand = new Random(System.nanoTime());
        for (int i = 0; i < 20; i++) {
            int idx = rand.nextInt(60000);
            Data d = mnist.getTrainingData(idx);
            BufferedImage img = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);
            for (int x = 0; x < 28; x++) {
                for (int y = 0; y < 28; y++) {
                    // Pixels are stored row by row.
                    img.setRGB(x, y, toRgb(d.data[y * 28 + x]));
                }
            }
            // File name: <sequence number>_<label>.png
            File output = new File(i + "_" + d.label + ".png");
            ImageIO.write(img, "png", output);
        }
    }

    // Maps a pixel byte to an inverted gray RGB value (0 becomes white).
    static int toRgb(byte bb) {
        int b = (255 - (0xff & bb));
        return (b << 16 | b << 8 | b) & 0xffffff;
    }

    Data[] trainingSet;
    Data[] testSet;

    // Randomly reorders the training set by swapping each element
    // with a randomly chosen one.
    public void shuffle() {
        Random rand = new Random();
        for (int i = 0; i < trainingSet.length; i++) {
            int x = rand.nextInt(trainingSet.length);
            Data d = trainingSet[i];
            trainingSet[i] = trainingSet[x];
            trainingSet[x] = d;
        }
    }

    public Data getTrainingData(int idx) {
        return trainingSet[idx];
    }

    public Data[] getTrainingSlice(int start, int count) {
        Data[] ret = new Data[count];
        System.arraycopy(trainingSet, start, ret, 0, count);
        return ret;
    }

    public Data getTestData(int idx) {
        return testSet[idx];
    }

    public Data[] getTestSlice(int start, int count) {
        Data[] ret = new Data[count];
        System.arraycopy(testSet, start, ret, 0, count);
        return ret;
    }

    // Reads the four GZIP files from the current working directory.
    public void load() {
        trainingSet = load("train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz");
        testSet = load("t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz");
        if (trainingSet.length != 60000 || testSet.length != 10000) {
            throw new RuntimeException("Unexpected training/test data size: " + trainingSet.length + "/" + testSet.length);
        }
    }

    private Data[] load(String imgFile, String labelFile) {
        byte[][] images = loadImages(imgFile);
        byte[] labels = loadLabels(labelFile);
        if (images.length != labels.length) {
            throw new RuntimeException("Image and label counts don't match: " + imgFile + " " + labelFile);
        }
        int len = images.length;
        Data[] data = new Data[len];
        for (int i = 0; i < len; i++) {
            data[i] = new Data();
            data[i].data = images[i];
            data[i].label = 0xff & labels[i];
            data[i].input = dataToInput(images[i]);
            data[i].output = labelToOutput(labels[i]);
        }
        return data;
    }

    // Expected output: 1 at the index of the label, 0 everywhere else.
    private double[] labelToOutput(byte label) {
        double[] o = new double[10];
        o[label] = 1;
        return o;
    }

    // Converts unsigned pixel bytes (0..255) to doubles in the range 0..1.
    private double[] dataToInput(byte[] b) {
        double[] d = new double[b.length];
        for (int i = 0; i < b.length; i++) {
            d[i] = (b[i] & 0xff) / 255.0;
        }
        return d;
    }

    // Image file layout: magic 0x00000803, image count, rows, columns,
    // then rows * cols pixel bytes per image.
    private byte[][] loadImages(String imgFile) {
        try (DataInputStream in = new DataInputStream(new GZIPInputStream(new FileInputStream(imgFile)))) {
            int magic = in.readInt();
            if (magic != 0x00000803) {
                throw new RuntimeException("wrong magic: 0x" + Integer.toHexString(magic));
            }
            int count = in.readInt();
            int rows = in.readInt();
            int cols = in.readInt();
            if (rows != 28 || cols != 28) {
                throw new RuntimeException("Unexpected row and col count: " + rows + "x" + cols);
            }
            byte[][] data = new byte[count][rows * cols];
            for (int i = 0; i < count; i++) {
                in.readFully(data[i]);
            }
            return data;
        } catch (Exception ex) {
            throw new RuntimeException("Failed to read file: " + imgFile, ex);
        }
    }

    // Label file layout: magic 0x00000801, label count, then one byte per label.
    private byte[] loadLabels(String labelFile) {
        try (DataInputStream in = new DataInputStream(new GZIPInputStream(new FileInputStream(labelFile)))) {
            int magic = in.readInt();
            if (magic != 0x00000801) {
                throw new RuntimeException("wrong magic: 0x" + Integer.toHexString(magic));
            }
            int count = in.readInt();
            byte[] data = new byte[count];
            in.readFully(data);
            return data;
        } catch (Exception ex) {
            throw new RuntimeException("Failed to read file: " + labelFile, ex);
        }
    }
}