返回
PyTorch实操:利用逻辑回归对鸢尾花分类的实战指南
人工智能
2023-09-19 01:34:47
在机器学习的浩瀚世界中,探索实践至关重要。而PyTorch是一个功能强大的深度学习框架,为实践提供了广阔的天地。在这篇文章中,我们将踏上PyTorch的实战之旅,探索如何利用逻辑回归对鸢尾花进行分类。
了解逻辑回归
逻辑回归是一种广为人知的分类算法,特别适用于二分类问题。它的核心思想是通过一个线性函数将输入特征映射到概率分布中,其中0表示某个类的概率,而1表示另一类的概率。
获取数据集
我们的目标是训练一个逻辑回归模型来区分三种鸢尾花品种:山鸢尾、变色鸢尾和弗吉尼亚鸢尾。为此,我们将使用著名的鸢尾花数据集,该数据集包含150个样本,每个样本有4个特征和一个目标类别。
使用PyTorch实现逻辑回归
导入必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
加载和预处理数据
# 加载鸢尾花数据集
df_iris = pd.read_csv('iris.csv')
# 将数据转换为PyTorch张量
np_iris = df_iris.values
X_train, X_test, y_train, y_test = train_test_split(np_iris[:, :-1], np_iris[:, -1], test_size=0.2)
X_train = torch.FloatTensor(X_train)
X_test = torch.FloatTensor(X_test)
y_train = torch.LongTensor(y_train)
y_test = torch.LongTensor(y_test)
# 创建数据加载器
train_loader = DataLoader(dataset=TensorDataset(X_train, y_train), batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=TensorDataset(X_test, y_test), batch_size=64, shuffle=False)
定义逻辑回归模型
class LogisticRegression(nn.Module):
def __init__(self, input_size, output_size):
super(LogisticRegression, self).__init__()
self.linear = nn.Linear(input_size, output_size)
def forward(self, x):
out = self.linear(x)
return F.log_softmax(out, dim=1)
训练模型
# 创建模型实例
model = LogisticRegression(4, 3)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 训练模型
for epoch in range(1000):
for i, data in enumerate(train_loader):
# 获取输入和标签
inputs, labels = data
# 前向传播
outputs = model(inputs)
# 计算损失
loss = criterion(outputs, labels)
# 反向传播
optimizer.zero_grad()
loss.backward()
# 更新权重
optimizer.step()
评估模型
# 评估模型
correct = 0
total = 0
with torch.no_grad():
for data in test_loader:
inputs, labels = data
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Accuracy of the network on the test data: {100 * correct / total}%')
总结
恭喜您!您已成功使用PyTorch对鸢尾花进行了逻辑回归分类。通过本次实践,您掌握了以下知识:
- 如何加载和预处理数据
- 如何定义和训练逻辑回归模型
- 如何评估模型的性能
继续探索机器学习和深度学习的精彩世界,并使用PyTorch释放您的想象力!