返回
使用PyTorch:学习编程就像打游戏一样
人工智能
2023-10-26 18:44:07
有时,学习可能很困难,就像打一场艰难的比赛一样。但如果我们能找到一种方法将学习变成一场游戏,那么我们就会更愿意花时间学习,也更有可能取得成功。
PyTorch是一个强大的深度学习框架,可以轻松构建和训练神经网络。在本文中,我们将向您展示如何使用PyTorch构建一个简单的神经网络,并对其进行训练。我们将从头开始,所以即使您是深度学习的新手,您也可以轻松跟上。
PyTorch基础
在开始构建神经网络之前,我们需要先了解一下PyTorch的基础知识。
- 张量 :张量是PyTorch中的基本数据结构。张量类似于NumPy中的数组,但它们具有更多的功能。张量可以是1D、2D或更高维的。
- 神经网络 :神经网络是一种受人类大脑启发的机器学习模型。神经网络可以学习从数据中提取特征,并将其用于预测或分类。
- 优化器 :优化器是一种算法,用于更新神经网络中的权重和偏差。优化器的目标是找到一组权重和偏差,使得神经网络在训练集上的性能最好。
- 损失函数 :损失函数衡量神经网络的预测与真实值之间的差异。优化器的目标是找到一组权重和偏差,使得损失函数的值最小。
构建一个简单的神经网络
现在我们已经了解了PyTorch的基础知识,我们可以开始构建一个简单的神经网络了。我们将构建一个神经网络来对MNIST数据集中的手写数字进行分类。
MNIST数据集包含70,000张手写数字图像,其中60,000张用于训练,10,000张用于测试。每张图像都是28x28像素的灰度图像。
加载和预处理数据
首先,我们需要加载MNIST数据集并对其进行预处理。我们可以使用以下代码来加载MNIST数据集:
import torchvision
from torchvision import transforms
# 加载MNIST数据集
train_data = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_data = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
# 将数据转换为张量
train_data = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
test_data = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=True)
构建神经网络
现在我们已经加载并预处理了数据,我们可以开始构建神经网络了。我们将构建一个简单的神经网络,由两个全连接层组成。
import torch
import torch.nn as nn
import torch.nn.functional as F
# 定义神经网络
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(784, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = x.view(-1, 784)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
# 实例化神经网络
net = Net()
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
训练神经网络
现在我们已经构建了神经网络,我们可以开始训练它了。我们将使用Adam优化器和交叉熵损失函数来训练神经网络。
# 训练神经网络
for epoch in range(10):
for batch_idx, (data, target) in enumerate(train_data):
# 前向传播
output = net(data)
# 计算损失
loss = criterion(output, target)
# 反向传播
optimizer.zero_grad()
loss.backward()
# 更新权重
optimizer.step()
# 测试神经网络
test_loss = 0
correct = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(test_data):
# 前向传播
output = net(data)
# 计算损失
test_loss += criterion(output, target).item()
# 计算准确率
_, predicted = torch.max(output, 1)
correct += (predicted == target).sum().item()
test_loss /= len(test_data)
accuracy = 100. * correct / len(test_data.dataset)
print(f'Epoch: {epoch}, Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')
经过10个epoch的训练,我们的神经网络在MNIST数据集上的准确率达到了98%。这意味着我们的神经网络能够正确地将98%的手写数字分类。