识别手写数字:在浏览器中应用CNN网络,深度学习小白必读!
2023-09-19 07:31:53
在本文中,我们将快速介绍如何在浏览器中使用 keras 训练一个简单的识别 MNIST(一个手写数字数据集)的 CNN(卷积神经网络),并且把训练好的网络应用到 web 浏览器内。
手写数字识别是计算机视觉领域的一个经典问题,也是一个非常有趣且实用的问题。我们可以利用 CNN 网络的强大功能来解决这个问题,并且可以在浏览器中部署训练好的模型,从而实现手写数字识别的功能。
准备 MNIST 数据集
MNIST 数据集是一个包含 70,000 张手写数字图像的数据集,其中包括 60,000 张训练图像和 10,000 张测试图像。每个图像都是一个 28 * 28 的二维数组,其中每个元素表示像素的灰度值。
我们可以使用 Keras 自带的 datasets
模块来加载 MNIST 数据集:
from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
其中,x_train
和 y_train
分别是训练图像和训练标签,x_test
和 y_test
分别是测试图像和测试标签。
预处理数据
在训练 CNN 网络之前,我们需要对数据进行预处理。首先,我们需要将数据归一化到 0 到 1 之间。我们可以使用 Keras 的 preprocessing
模块来进行归一化:
from keras.preprocessing.image import ImageDataGenerator
image_datagen = ImageDataGenerator(rescale=1./255)
x_train = image_datagen.flow(x_train, y_train, batch_size=32)
x_test = image_datagen.flow(x_test, y_test, batch_size=32)
然后,我们需要将数据转换为适合 CNN 网络输入的格式。CNN 网络通常需要将数据转换为四维张量,其中第一维是样本数,第二维是图像高度,第三维是图像宽度,第四维是通道数。对于 MNIST 数据集,通道数为 1,因为它是灰度图像。
我们可以使用 Keras 的 reshape
函数来将数据转换为四维张量:
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
构建 CNN 模型
接下来,我们需要构建一个 CNN 网络来识别手写数字。我们可以使用 Keras 的 Sequential
模型来构建 CNN 网络。
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dense(10, activation='softmax'))
这个 CNN 网络包含两个卷积层、两个池化层、一个全连接层和一个输出层。卷积层使用 3x3 的卷积核,池化层使用 2x2 的池化核,全连接层有 128 个神经元,输出层有 10 个神经元,因为 MNIST 数据集包含 10 个数字。
训练模型
接下来,我们需要训练 CNN 网络。我们可以使用 Keras 的 compile
函数来编译模型,使用 fit
函数来训练模型。
model.compile(loss='sparse_categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10)
我们将使用 10 个 epoch 来训练模型。在训练过程中,模型将在训练集上进行迭代,并不断更新其权重。
评估模型
在训练模型之后,我们需要评估模型的性能。我们可以使用 Keras 的 evaluate
函数来评估模型的性能。
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
输出结果如下:
Test loss: 0.0286
Test accuracy: 0.9912
这表明模型在测试集上的准确率为 99.12%。
将模型部署到浏览器
训练好模型之后,我们可以将其部署到浏览器中。我们可以使用 Keras 的 tensorflowjs_converter
模块来将模型转换为 TensorFlow.js 模型,然后使用 TensorFlow.js 来将模型部署到浏览器中。
import tensorflowjs as tfjs
tfjs.converters.save_keras_model(model, 'my_model')
这将创建一个名为 my_model
的 TensorFlow.js 模型。我们可以使用以下代码将模型加载到浏览器中:
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<script>
const model = await tf.loadLayersModel('my_model/model.json');
</script>
然后,我们可以使用以下代码来使用模型对图像进行预测:
const image = document.getElementById('image');
const predictions = await model.predict(tf.browser.fromPixels(image));
const result = predictions.argMax(-1).dataSync()[0];
console.log('Predicted label:', result);
这将使用模型对图像进行预测,并将在控制台中打印出预测结果。
结语
在本文中,我们快速地介绍了如何在浏览器中使用 keras 训练一个简单的识别 MNIST(一个手写数字数据集)的 CNN(卷积神经网络),并且把训练好的网络应用到 web 浏览器内。
我们首先准备了 MNIST 数据集,然后对数据进行了预处理。接下来,我们构建了一个 CNN 网络,并训练了模型。最后,我们将模型部署到了浏览器中。