返回
Keras: 模型保存与加载持续训练指南
人工智能
2023-12-26 09:37:51
引言
Keras 是一个用于构建和训练神经网络的高级神经网络 API,由 TensorFlow 后端提供支持。它提供了广泛的功能,包括模型保存和加载,允许您在需要时停止和恢复训练过程。在本指南中,我们将详细介绍如何使用 Keras 保存、加载和持续训练模型,以便在中断或完成训练后恢复并继续训练模型。
保存 Keras 模型
要保存 Keras 模型,您可以使用 model.save()
方法。此方法将模型及其权重保存到指定的 HDF5 文件中。HDF5 是一种用于存储大型数据集和复杂数据结构的二进制文件格式。
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.optimizers import SGD
# 创建一个序贯模型
model = Sequential()
model.add(Dense(32, activation='relu', input_shape=(784,)))
model.add(Dense(10, activation='softmax'))
# 编译模型
model.compile(optimizer=SGD(), loss='categorical_crossentropy', metrics=['accuracy'])
# 训练模型
model.fit(x_train, y_train, epochs=10)
# 保存模型到 HDF5 文件
model.save('my_model.h5')
加载 Keras 模型
要加载保存的 Keras 模型,可以使用 keras.models.load_model()
方法。此方法从指定的 HDF5 文件中加载模型及其权重。
# 加载保存的模型
new_model = keras.models.load_model('my_model.h5')
持续训练 Keras 模型
要继续训练加载的模型,可以使用 model.fit()
方法。该方法从上一次训练停止的地方继续训练模型。
# 从上一次训练停止的地方继续训练模型
new_model.fit(x_train, y_train, epochs=10)
示例:手写数字识别
为了演示模型保存、加载和持续训练,我们使用 Keras 构建了一个用于手写数字识别的模型。
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.utils import to_categorical
from keras.callbacks import ModelCheckpoint
# 加载 MNIST 数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 归一化输入数据
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
# 转换为独热编码
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
# 创建模型
model = Sequential([
Flatten(input_shape=(28, 28)),
Dense(128, activation='relu'),
Dropout(0.2),
Dense(10, activation='softmax')
])
# 编译模型
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# 创建模型检查点回调函数
checkpoint = ModelCheckpoint('mnist_model.h5', save_best_only=True)
# 训练模型
model.fit(x_train, y_train, epochs=10, validation_data=(x_test, y_test), callbacks=[checkpoint])
# 加载最佳模型
best_model = keras.models.load_model('mnist_model.h5')
# 评估模型
best_model.evaluate(x_test, y_test)
在这个示例中,我们创建了一个卷积神经网络来识别手写数字。我们使用 ModelCheckpoint
回调函数在训练期间保存最佳模型。然后,我们可以加载最佳模型并继续训练它,这在我们的示例中是不必要的,但可以演示持续训练过程。
结论
本指南详细介绍了如何使用 Keras 保存、加载和持续训练模型。通过利用这些功能,您可以停止和恢复训练过程,在需要时中断训练,并在完成训练后继续训练。这对于大型和复杂模型非常有用,需要分阶段或并行训练。