返回
用PyTorch实现逻辑回归对FashionMNIST数据集进行分类:GPU加速
人工智能
2024-02-09 19:43:08
- 导入必要的库
首先,我们需要导入必要的库。
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
#如果使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2. 加载数据
接下来,我们需要加载FashionMNIST数据集。FashionMNIST是一个由70,000张28x28像素的灰度图像组成的图像分类数据集,其中包含10种不同的服饰类别。
# 下载并加载FashionMNIST数据集
train_data = torchvision.datasets.FashionMNIST(
root="./data",
train=True,
download=True,
transform=torchvision.transforms.ToTensor(),
)
test_data = torchvision.datasets.FashionMNIST(
root="./data",
train=False,
download=True,
transform=torchvision.transforms.ToTensor(),
)
# 将数据加载到DataLoader中
train_loader = DataLoader(train_data, batch_size=64)
test_loader = DataLoader(test_data, batch_size=64)
3. 定义模型
现在,我们需要定义我们的逻辑回归模型。逻辑回归是一种简单的线性分类器,它将输入数据映射到一个概率值,表示该数据属于某个类别的概率。
class LogisticRegression(nn.Module):
def __init__(self, input_dim, output_dim):
super(LogisticRegression, self).__init__()
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, x):
out = self.linear(x)
return out
4. 定义损失函数和优化器
接下来,我们需要定义损失函数和优化器。我们将使用交叉熵损失函数和Adam优化器。
# 定义损失函数
loss_fn = nn.CrossEntropyLoss()
# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)
5. 训练模型
现在,我们可以开始训练模型了。
# 训练模型
for epoch in range(10):
for i, (inputs, labels) in enumerate(train_loader):
# 将数据移动到GPU
inputs = inputs.to(device)
labels = labels.to(device)
# 前向传播
outputs = model(inputs)
# 计算损失
loss = loss_fn(outputs, labels)
# 反向传播
optimizer.zero_grad()
loss.backward()
# 更新参数
optimizer.step()
# 打印训练信息
if i % 100 == 0:
print(f"Epoch [{epoch+1}/{10}], Step [{i}/{len(train_loader)}], Loss: {loss.item()}")
6. 评估模型
训练结束后,我们可以评估模型在测试集上的性能。
# 评估模型
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
# 将数据移动到GPU
inputs = inputs.to(device)
labels = labels.to(device)
# 前向传播
outputs = model(inputs)
# 计算预测值
_, predicted = torch.max(outputs.data, 1)
# 累加正确预测的样本数
total += labels.size(0)
correct += (predicted == labels).sum().item()
# 计算准确率
accuracy = 100 * correct / total
print(f"Accuracy: {accuracy}%")
7. 保存模型参数
如果我们对模型的性能满意,我们可以保存模型参数以供以后使用。
# 保存模型参数
torch.save(model.state_dict(), "logistic_regression_model.pt")
8. 加载模型参数
如果我们想在以后使用模型,我们可以加载模型参数。
# 加载模型参数
model = LogisticRegression(input_dim, output_dim)
model.load_state_dict(torch.load("logistic_regression_model.pt"))
9. 总结
在这篇博文中,我们使用PyTorch实现了逻辑回归模型,并将其应用于FashionMNIST数据集的分类任务。我们还探讨了如何保存和加载模型参数。希望这篇博文对您有所帮助!