返回

使用深度学习对 Fashion-MNIST 服装图像进行分类:Python 代码实例

人工智能

深度学习在计算机视觉领域取得了显著的成功,特别是用于图像分类任务。本文将演示如何使用 PyTorch 框架构建一个深度学习模型,对 Fashion-MNIST 服装图像数据集进行分类。该数据集包含 70,000 张灰度图像,每张图像代表 10 种不同的服装类别。

简介

Fashion-MNIST 数据集是 MNIST 手写数字数据集的扩展,包含更大、更复杂的图像。这使其成为测试图像分类算法的理想数据集。在本文中,我们将构建一个卷积神经网络 (CNN) 模型,该模型利用其强大的特征提取能力来对 Fashion-MNIST 图像进行分类。

模型架构

我们的 CNN 模型将基于 LeNet-5 架构,该架构是一个经典的 CNN 模型,最初用于手写数字识别。LeNet-5 包含以下层:

  • 卷积层:提取图像特征
  • 池化层:缩小特征图尺寸
  • 全连接层:将特征映射到最终类别

数据准备

下载 Fashion-MNIST 数据集并将其加载到 PyTorch 数据加载器中。数据加载器将以小批量的方式提供图像,从而提高模型训练效率。

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# 加载 Fashion-MNIST 数据集
train_dataset = datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor()
)
test_dataset = datasets.FashionMNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.ToTensor()
)

# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

模型定义

定义 CNN 模型架构,包括卷积层、池化层和全连接层。

import torch.nn as nn
import torch.nn.functional as F

# 定义 LeNet-5 模型架构
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

模型训练

使用交叉熵损失函数和 Adam 优化器训练模型。

# 实例化模型
model = LeNet5()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
    for images, labels in train_loader:
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播
        optimizer.zero_grad()
        loss.backward()

        # 更新权重
        optimizer.step()

模型评估

使用测试集评估训练后的模型的性能。

# 评估模型
total = 0
correct = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')

总结

我们成功地构建了一个 CNN 模型来对 Fashion-MNIST 服装图像进行分类。该模型基于 LeNet-5 架构,利用卷积层和池化层的强大功能来提取图像特征。通过训练该模型,我们实现了很高的分类准确度,展示了深度学习在图像分类任务中的潜力。