返回
使用深度学习对 Fashion-MNIST 服装图像进行分类:Python 代码实例
人工智能
2024-02-13 13:45:14
深度学习在计算机视觉领域取得了显著的成功,特别是用于图像分类任务。本文将演示如何使用 PyTorch 框架构建一个深度学习模型,对 Fashion-MNIST 服装图像数据集进行分类。该数据集包含 70,000 张灰度图像,每张图像代表 10 种不同的服装类别。
简介
Fashion-MNIST 数据集是 MNIST 手写数字数据集的扩展,包含更大、更复杂的图像。这使其成为测试图像分类算法的理想数据集。在本文中,我们将构建一个卷积神经网络 (CNN) 模型,该模型利用其强大的特征提取能力来对 Fashion-MNIST 图像进行分类。
模型架构
我们的 CNN 模型将基于 LeNet-5 架构,该架构是一个经典的 CNN 模型,最初用于手写数字识别。LeNet-5 包含以下层:
- 卷积层:提取图像特征
- 池化层:缩小特征图尺寸
- 全连接层:将特征映射到最终类别
数据准备
下载 Fashion-MNIST 数据集并将其加载到 PyTorch 数据加载器中。数据加载器将以小批量的方式提供图像,从而提高模型训练效率。
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# 加载 Fashion-MNIST 数据集
train_dataset = datasets.FashionMNIST(
root='./data',
train=True,
download=True,
transform=transforms.ToTensor()
)
test_dataset = datasets.FashionMNIST(
root='./data',
train=False,
download=True,
transform=transforms.ToTensor()
)
# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
模型定义
定义 CNN 模型架构,包括卷积层、池化层和全连接层。
import torch.nn as nn
import torch.nn.functional as F
# 定义 LeNet-5 模型架构
class LeNet5(nn.Module):
def __init__(self):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(16 * 4 * 4, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool1(F.relu(self.conv1(x)))
x = self.pool2(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 4 * 4)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
模型训练
使用交叉熵损失函数和 Adam 优化器训练模型。
# 实例化模型
model = LeNet5()
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
for images, labels in train_loader:
# 前向传播
outputs = model(images)
loss = criterion(outputs, labels)
# 反向传播
optimizer.zero_grad()
loss.backward()
# 更新权重
optimizer.step()
模型评估
使用测试集评估训练后的模型的性能。
# 评估模型
total = 0
correct = 0
with torch.no_grad():
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')
总结
我们成功地构建了一个 CNN 模型来对 Fashion-MNIST 服装图像进行分类。该模型基于 LeNet-5 架构,利用卷积层和池化层的强大功能来提取图像特征。通过训练该模型,我们实现了很高的分类准确度,展示了深度学习在图像分类任务中的潜力。