使用 Theano 框架进行 MNIST 数字识别(纯干货,一点水没有)

news/2024/7/4 12:58:49

概要:使用 Theano 机器学习框架,设计一个深度学习神经网络,对 MNIST 数据集进行手写数字识别。

(想自学习编程的小伙伴请搜索圈T社区,更多行业相关资讯更有行业相关免费视频教程。完全免费哦!)

Theano 是蒙特利尔理工学院开发的机器学习框架,派生出了 Keras 等深度学习 Python 软件包。Theano 是为处理深度学习的大型神经网络算法而专门设计的,它的核心是一个数学表达式的编译器,它知道如何获 取网络结构,能够使相关代码高效地运行。Keras 是一个高级、快速和模块化的 Python 神经网络库,能够在 Theano 或 TensorFlow 平台上运行。

Keras 是一个简洁、高度模块化的神经网络库,可以通过如下命令进行安装。

$> pip install keras

也可以使用如下命令在线安装 Keras 的最新版本。

$> pip install git+git://github.com/fchollet/keras.git

安装过程如下图所示。

MNIST 手写数字识别是一个分类问题,下面介绍怎样使用 Theano 和 Keras 简单地训练和测试一个神经 网络。

首先需要启动 Jupyter notebook 导入一些必要的模块。如果 Keras 是首次导入,输出信息中会显示它选 择的后端模块,毫无疑问,这里选择 Theano 作为后端引擎。

from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.utils import np_utils
import numpy as np

然后加载 MNIST 数据集,这里将使用 Keras 命令自动下载。

from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

使用 x_train.shape 可以查看数据集的大小及图像尺寸。现在需要对数据集进行预处理,使它转换成Keras 使用的格式。当处理图像数据时,Keras 需要知道图像通道数,为了加快处理速度,将数据转换为 0~1 的 32 位浮点数。

num_pixels = x_train.shape[1] * x_train.shape[2] n_channels = 1
def preprocess(matrix):
return matrix.reshape(matrix.shape[0], \ n_channels, \
matrix.shape[1], \
matrix.shape[2]
).astype('float32') / 255.
x_train, x_test = preprocess(x_train), preprocess(x_test)

现在,还需要对输出结果进行处理。由于它是一个具有 10 个类别的分类问题,因此每一类都应该有自己 的输出列,每一列都对应输出层的一个神经元。

y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
num_classes = y_train.shape[1]

现在来创建一个简单的基准模型。我们创建一个 Sequential 模型,模型的每一个层都是按顺序堆叠的。首先将输入图像拉伸为向量,即将 28×28 的单通道图像变换成 784 维的向量,然后输入层采用 784 个神经元, 紧接着的输出层有 10 个神经元。对于第一层,可以选择 ReLU 函数作为激活函数,由于是分类任务,第二层 采用 softmax 函数。最后,用 Keras 编译模型,它使用类的交叉熵作为优化损失函数,采用分类精度作为主要 性能指标。

def baseline_model():
model = Sequential()
model.add(Flatten(input_shape=(1, 28, 28)))
model.add(Dense(num_pixels, init='normal', activation='relu'))
model.add(Dense(num_classes, init='normal', activation='softmax')) model.compile(loss='categorical_crossentropy', optimizer='adam',
metrics=['accuracy'])
return model

为了展示模型提升的效果,可以构建一个稍微复杂的卷积神经网络,命名为 convolution_small,模型包含步骤如下:

  1. 操作在二维矩阵上的卷积滤波器 (Convolutional Filter):使用窗口为 5×5
    的滤波器,对二维图像进行卷积滤波操作,产生 32 维的输出向量。
  2. 最大池化层 (max-pooler):对 2×2 窗口进行最大化选择,以非线性的方式对图像进行采样。
  3. dropout 层:随机将神经元的 20% 重置为 0,这样能够防止模型过拟合。
  4. 其他步骤和基准模型一样。
def convolution_small():
model = Sequential()
model.add(Convolution2D(32, 5, 5, border_mode='valid',input_shape=(1, 28, 28), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dense(num_classes, activation='softmax')) model.compile(loss='categorical_crossentropy',optimizer='adam', metrics=['accuracy'])
return model 

为了展现神经网络的威力,还可以创建一个更复杂的神经网络,它和前面的模型类似,但是Convolution2DMaxPooling2D 的层数是原来的 2 倍。

def convolution_large():
model = Sequential()
model.add(Convolution2D(30, 5, 5, border_mode='valid', input_shape=(1, 28, 28), activation='relu')) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Convolution2D(15, 3, 3, activation='relu')) model.add(MaxPooling2D(pool_size=(2, 2)))

     model.add(Dropout(0.2))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dense(50, activation='relu'))
model.add(Dense(num_classes, activation='softmax')) model.compile(loss='categorical_crossentropy',optimizer='adam', metrics=['accuracy'])
return model 

最后,来测试一下这些模型,注意观察模型的性能和产生结果的时间。在相同的验证集上测试这些算法。训练阶段的轮数设置为 10。

np.random.seed(101)
models = [('baseline', baseline_model()),
('small', convolution_small()),
('large', convolution_large())]
for name, model in models:
print("With model:", name)
# Fit the model
model.fit(X_train, y_train, validation_data=(X_test, y_test), nb_epoch=10, batch_size=100, verbose=2)
# Final evaluation of the model
scores = model.evaluate(X_test, y_test, verbose=0) print("Baseline Error: %.2f%%" % (100-scores[1]*100))
print()
Out: With model: baseline
Train on 60000 samples, validate on 10000 samples
Epoch 1/10
3s - loss: 0.2332 - acc: 0.9313 - val_loss: 0.1113 - val_acc: 0.9670 Epoch 2/10
3s - loss: 0.0897 - acc: 0.9735 - val_loss: 0.0864 - val_acc: 0.9737 [...]
Epoch 10/10
2s - loss: 0.0102 - acc: 0.9970 - val_loss: 0.0724 - val_acc: 0.9796 Baseline Error: 2.04%
With model: small
Train on 60000 samples, validate on 10000 samples
Epoch 1/10
17s - loss: 0.1878 - acc: 0.9449 - val_loss: 0.0600 - val_acc: 0.9806 Epoch 2/10
16s - loss: 0.0631 - acc: 0.9808 - val_loss: 0.0424 - val_acc: 0.9850 [...]
Epoch 10/10
16s - loss: 0.0110 - acc: 0.9965 - val_loss: 0.0410 - val_acc: 0.9894 Baseline Error: 1.06%
With model: large
Train on 60000 samples, validate on 10000 samples
Epoch 1/10

   26s - loss: 0.2920 - acc: 0.9087 - val_loss: 0.0738 - val_acc: 0.9749 Epoch 2/10
25s - loss: 0.0816 - acc: 0.9747 - val_loss: 0.0454 - val_acc: 0.9857 [...]
Epoch 10/10
27s - loss: 0.0253 - acc: 0.9921 - val_loss: 0.0253 - val_acc: 0.9919
Baseline Error: 0.81%

从上面结果可以看出,深度越深(convolution_large 越大)神经网络模型误差越小,同时花费的训练时间也越长。相反,基准模型速度很快,但最终的精确性却最低。
(文章来源公众号Python中文社区)


http://www.niftyadmin.cn/n/4788549.html

相关文章

Java中两变量三方法基础判断与理解

在这里首先先进行一个举例 package Text;//包名一般文件被其包裹在一个包里 public class Car { String color;//实例变量1 String paizi;//实例变量2 static int lunzi4;//静态变量 储存在指定地方。为一个类的基本信息如:车轮子4个 //构造方法不会产生返回值 如何…

一条SQL语句执行得很慢的原因到底有哪些?你不知道很有可能当误你的大事

说实话,这个问题可以涉及到 MySQL 的很多核心知识,可以扯出一大堆,就像要考你计算机网络的知识时,问你“输入URL回车之后,究竟发生了什么”一样,看看你能说出多少了。 之前腾讯面试的实话,也问…

数据结构中动态分配空间

1&#xff1a;就链式而言&#xff08;我个人比较喜欢链式因为链式缺少空间可以以加结点的形式来进行增加比较方便&#xff09; #include <iostream> #include <stdlib.h> using namespace std; void dongtaifenpei(){int *p;pnew int;&#xff08;new一个新的类型…

5个重要的人工智能预测(2019年)每个人都应该阅读

人工智能 - 特别是机器学习和深度学习 - 在2018年无处不在&#xff0c;并且预计未来12个月的炒作将不会消失。 当然&#xff0c;炒作最终会消亡&#xff0c;人工智能将成为我们生活中的另一个连贯的线索&#xff0c;就像互联网&#xff0c;电力和燃烧在过去几天一样。 但至少…

VMware 菜鸟教程

VMware下载与安装 一、 虚拟机的下载 进入VMware官网&#xff08;https://www.vmware.com/cn.html&#xff09;&#xff0c;可能会有一点慢&#xff0c;耐心等待。 点击下载&#xff0c;进入到如下页面 点击下载产品 可以看到这里有两个版本&#xff0c;windows和Linux版本&…

VMware tools详细教程 解决安装失败等问题

1、打开虚拟机VMware Workstation&#xff0c;启动Ubuntu系统&#xff0c;菜单栏 - 虚拟机 - 安装VMware Tools&#xff0c;不启动Ubuntu系统是无法点击“安装VMware Tools”选项的&#xff0c;如下图&#xff1a; 必须在虚拟机内部进行安装&#xff01;&#xff01;&#xff0…

人工智能创造了一个假世界 - 这对人类意味着什么?

“眼见为实”或是吗&#xff1f;曾经有一段时间我们可以确信我们在照片和视频中看到的内容是真实的。即使Photoshopping图像变得流行&#xff0c;我们仍然知道图像是作为原始图像开始的。现在&#xff0c;随着人工智能的进步&#xff0c;世界变得越来越虚化&#xff0c;你不能确…

使用人工智能自定义无偏见的内容

&#xff08;想自学习编程的小伙伴请搜索圈T社区&#xff0c;更多行业相关资讯更有行业相关免费视频教程。完全免费哦!&#xff09; 使用人工智能&#xff08;AI&#xff09;定制内容正在改变今天内容消费和货币化的方式。从媒体到制造&#xff0c;人工智能在公司如何创建更加…