MNIST 手写数字识别实践

问题介绍

手写数字识别为模式识别的一个重要分支,其主要研究如何让计算机通过手写体图片识别出图片中的数字。手写数字识别可以被应用于需要大量数字识别处理的场景,如学校试卷成绩统计、银行票据统计等,应用场景广泛。同时手写数字识别由计算机自动化操作,相比于人工处理,效率更高。

机器学习及深度学习方法常用于手写数字识别任务,本文将采用多种模型完成手写数字识别10分类任务,并对不同模型进行对比分析。

相关问题内容调研

处理手写数字识别任务的模式识别算法主要包括机器学习方法如K近邻算法(K Nearest Neighbor,KNN)、支持向量机(Support Vector Machine,SVM)、贝叶斯网络和深度学习方法如深度神经网络(Deep Neural Networks,DNN)、卷积神经网络(Convolutional Neural Network,CNN)、循环神经网络(Recurrent Neural Network,RNN)等。

用于手写数字识别的数据集主要由MNIST,DBRHD和USPS。MNIST数据集常用于深度学习领域,共包括70000个带标签的手写数字样本。其中60000张为训练集,10000张为测试集。每张图像尺寸为28*28像素。DBRHD是UCI的机器学习中心提供的数字手写体数据库。训练集共有7494个,来自40位手写者;测试集共有3498个,来自14位手写者。USPS数据集是美国邮政服务手写数字识别库,共包含9282个手写数字图像,每张图像均为灰度图,尺寸为16*16像素。

实现方法

本文将分别使用DNN、CNN、KNN、SVM四种方法对MNIST数据集进行手写数字识别实验。

  • DNN。该模型网络结构包含一个64维的隐含层,损失函数采用交叉熵,优化器选用随机梯度下降法,激活函数为ReLU。

  • CNN。CNN方法包含两个模型。模型1:该模型网络结构仅包含一个卷积层,卷积核大小为3*3,包含32个通道。损失函数采用交叉熵,优化器选用随机梯度下降法,激活函数为ReLU;模型2:采用Pytorch官方示例模型,该模型增加了卷积层及全连接层,并新添池化层及Dropout策略,详见https://github.com/pytorch/examples/blob/master/mnist/main.py。

  • KNN。KNN算法原理为通过寻找待分类样本在样本空间中K个最相似的样本,K个样本中大多数样本同属于某个类别,则该样本也属于这个类别。本次实验中K选择10。

  • SVM。SVM算法是一类按照监督学习方式对数据进行二元分类的广义线性分类器,目标在于寻找划分学习样本的最大边距超平面,并可通过核方法实现非线性分类。

本次实验基本流程包括数据集的获取、预处理,模型设计及实现、模型训练及调优。

具体实验代码详见https://github.com/LanceZhu/deep-learning。

实验数据及设置

本次实验采用MNIST数据集,60000个样本作为训练集,10000个样本作为测试集。

软硬件环境:

​ 操作系统:Windows 10,WSL(Windows Subsystem Linux)

​ CPU:Intel(R) Core(TM) i7-6500U

​ RAM:12GB

​ 编程语言及框架:Python,Pytorch、Scikit-learn

实验结果和分析

在划分的训练集上进行模型训练,测试集上验证模型效果。模型训练过程见图,各模型准确率见下表:

模型 模型准确率
DNN 95%
CNN(1) 98%
CNN(2) 99%
KNN 90%
SVM 93%

表1:各模型准确率

由表1可看出,

  • 几种模型准确率均大于等于90%,最高可达99%,结果较为乐观。

  • 传统机器学习方法均劣于深度学习方法。

  • 深度学习方法中,CNN方法优于普通DNN方法。原因在于CNN中引入了感受野、权值共享的方法减少了网络参数,缓解了网络过拟合问题,CNN的卷积操作使得其相比于DNN获得了某种程度的尺度、位移、形变不变性,具有更好的泛化和鲁棒性。

结论

本文使用DNN、CNN、KNN、SVM四种方法对MNIST数据集进行了手写数字识别实验。实验结果表明,各方法准确率高于90%,其中CNN模型准确率最高(99%),可认为达到了正确识别手写数字的效果。其中深度学习方法在准确率上均高于传统机器学习方法,可认为深度学习方法在手写数字识别问题上优于传统机器学习方法。

参考文献

[1] 陈庭轩. 基于集成卷积神经网络的手写体数字识别研究[D].华中师范大学,2020.

[2] 何帅.卷积神经网络在手写数字识别中的应用[J].电脑知识与技术,2020,16(21):13-15.

[3] 唐子清,姚俭.基于深度学习的数字识别方法研究[J].软件导刊,2020,19(09):228-232.

[4] 汤晓武.一种基于KNN算法的手写数字分类器的设计与实现[J].信息通信,2020(10):53-55.

[5] Justin Johnson.LEARNING PYTORCH WITH EXAMPLES[EB/OL].https://pytorch.org/tutorials/beginner/pytorch_with_examples.html,2017.