返回

初探 PyTorch:入门手写数字识别之旅

人工智能

PyTorch 实战:用神经网络征服手写数字识别世界

踏入人工智能领域,选择合适的框架至关重要。PyTorch 以其灵活性、强大的性能和直观的模型构建方式,成为深度学习探索的理想伴侣。本文将带你踏上 PyTorch 之旅,亲身体验如何用它解决实际问题——手写数字识别!

认识 PyTorch

PyTorch 是一个开源的 Python 深度学习库,由 Facebook AI Research 开发。它以动态计算图和灵活的模型构建而著称。PyTorch 通过直接定义计算图的方式,让深度学习研究和开发变得更轻松。

搭建开发环境

为了开始我们的旅程,我们需要确保开发环境准备就绪。你需要 Python 3.6.5 及以上版本、Jupyter Lab 作为集成开发环境以及 TensorFlow 2.4.1 作为深度学习框架。另外,我们还需要下载 MNIST 手写数字数据集,该数据集可在参加训练营时获取。

PyTorch 实战:手写数字识别

1. 数据准备

MNIST 数据集包含 70,000 张手写数字图像,其中 60,000 张用于训练,10,000 张用于测试。PyTorch 内置了对 MNIST 数据集的支持,我们可以轻松地将它加载到代码中:

import torch
from torchvision import datasets, transforms

# 下载 MNIST 数据集
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

# 准备数据加载器
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=False)

2. 构建神经网络模型

手写数字识别本质上是一个分类问题。我们将构建一个简单的卷积神经网络模型,包括卷积层、池化层和全连接层:

import torch.nn as nn

# 定义模型架构
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

3. 训练模型

训练模型涉及迭代地更新模型的权重,以最小化训练数据集上的损失函数。我们将使用 Adam 优化器和交叉熵损失函数:

import torch.optim as optim
import torch.nn.functional as F

# 实例化模型
model = LeNet5()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
for epoch in range(20):
    for images, labels in train_loader:
        # 前向传播
        outputs = model(images)

        # 计算损失
        loss = criterion(outputs, labels)

        # 反向传播
        optimizer.zero_grad()
        loss.backward()

        # 更新权重
        optimizer.step()

4. 评估模型

训练完成后,我们评估模型在测试数据集上的性能:

# 评估模型
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')

结语

通过 PyTorch 的强大功能,我们成功地实现了手写数字识别。本周的旅程为我们提供了坚实的基础,了解 PyTorch 的工作原理和如何将其应用于实际问题。随着我们深入探索 PyTorch,我们将解锁更复杂的模型和应用程序。让我们一起继续这段激动人心的学习之旅!

常见问题解答

1. 为什么选择 PyTorch?

PyTorch 因其灵活性、动态计算图和直观的模型构建而受到深度学习研究人员的青睐。

2. PyTorch 和 TensorFlow 有什么区别?

PyTorch 使用动态计算图,而 TensorFlow 使用静态计算图。PyTorch 允许在训练过程中更灵活地更改模型,而 TensorFlow 则在训练前需要定义整个计算图。

3. 如何在我的计算机上安装 PyTorch?

可以在 PyTorch 网站上找到有关安装说明:https://pytorch.org/get-started/locally/

4. MNIST 数据集是什么?

MNIST 数据集包含 70,000 张手写数字图像,其中 60,000 张用于训练,10,000 张用于测试。

5. 卷积神经网络 (CNN) 是什么?

CNN 是一种专门用于处理栅格数据(如图像)的神经网络类型。它使用卷积层从数据中提取特征。