返回
PyTorch实战三:使用DNN实现逻辑回归对FashionMNIST进行分类
人工智能
2023-10-17 22:45:09
PyTorch实战三:使用GPU构建DNN进行FashionMNIST分类
引言
深度神经网络(DNN)在图像分类等计算机视觉任务中发挥着至关重要的作用。在本教程中,我们将深入了解如何使用PyTorch构建DNN,并在FashionMNIST数据集上应用它。值得注意的是,我们将利用GPU的并行处理能力来加速训练过程。
DNN架构
我们的DNN将是一个三层神经网络,包括:
- 输入层:接收784个输入值(FashionMNIST图像大小为28x28)
- 隐藏层:包含512个神经元,使用ReLU激活函数
- 输出层:包含10个神经元(对应于FashionMNIST中的10个类别)
数据集准备
FashionMNIST是一个包含7万张手写服饰图片的数据集。我们使用TorchVision库加载和预处理数据集:
import torch
from torchvision import datasets, transforms
train_data = datasets.FashionMNIST(
root="./data",
train=True,
download=True,
transform=transforms.ToTensor()
)
test_data = datasets.FashionMNIST(
root="./data",
train=False,
download=True,
transform=transforms.ToTensor()
)
模型构建
使用PyTorch构建DNN:
import torch.nn as nn
class DNN(nn.Module):
def __init__(self):
super(DNN, self).__init__()
self.fc1 = nn.Linear(784, 512)
self.fc2 = nn.Linear(512, 10)
def forward(self, x):
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
GPU加速
为了利用GPU的并行计算能力,我们将模型移动到GPU:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DNN().to(device)
训练过程
我们使用交叉熵损失和Adam优化器来训练模型:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(10):
# 训练
model.train()
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 评估
model.eval()
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f"Epoch {epoch + 1}: Accuracy: {100 * correct / total:.2f}%")
模型保存和加载
训练完成后,我们可以保存模型参数以便以后使用:
torch.save(model.state_dict(), "dnn_model.pt")
要加载保存的模型:
model = DNN()
model.load_state_dict(torch.load("dnn_model.pt"))
结论
我们已经成功构建了一个DNN并使用GPU加速将其训练在FashionMNIST数据集上。本教程展示了PyTorch如何使构建和训练深度学习模型变得轻而易举。凭借其强大的生态系统和GPU支持,PyTorch已成为构建高性能计算机视觉应用程序的首选工具。