返回

PyTorch实操:利用逻辑回归对鸢尾花分类的实战指南

人工智能

在机器学习的浩瀚世界中,探索实践至关重要。而PyTorch是一个功能强大的深度学习框架,为实践提供了广阔的天地。在这篇文章中,我们将踏上PyTorch的实战之旅,探索如何利用逻辑回归对鸢尾花进行分类。

了解逻辑回归

逻辑回归是一种广为人知的分类算法,特别适用于二分类问题。它的核心思想是通过一个线性函数将输入特征映射到概率分布中,其中0表示某个类的概率,而1表示另一类的概率。

获取数据集

我们的目标是训练一个逻辑回归模型来区分三种鸢尾花品种:山鸢尾、变色鸢尾和弗吉尼亚鸢尾。为此,我们将使用著名的鸢尾花数据集,该数据集包含150个样本,每个样本有4个特征和一个目标类别。

使用PyTorch实现逻辑回归

导入必要的库

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

加载和预处理数据

# 加载鸢尾花数据集
df_iris = pd.read_csv('iris.csv')

# 将数据转换为PyTorch张量
np_iris = df_iris.values
X_train, X_test, y_train, y_test = train_test_split(np_iris[:, :-1], np_iris[:, -1], test_size=0.2)
X_train = torch.FloatTensor(X_train)
X_test = torch.FloatTensor(X_test)
y_train = torch.LongTensor(y_train)
y_test = torch.LongTensor(y_test)

# 创建数据加载器
train_loader = DataLoader(dataset=TensorDataset(X_train, y_train), batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=TensorDataset(X_test, y_test), batch_size=64, shuffle=False)

定义逻辑回归模型

class LogisticRegression(nn.Module):
    def __init__(self, input_size, output_size):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        out = self.linear(x)
        return F.log_softmax(out, dim=1)

训练模型

# 创建模型实例
model = LogisticRegression(4, 3)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 训练模型
for epoch in range(1000):
    for i, data in enumerate(train_loader):
        # 获取输入和标签
        inputs, labels = data
        
        # 前向传播
        outputs = model(inputs)
        
        # 计算损失
        loss = criterion(outputs, labels)
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        
        # 更新权重
        optimizer.step()

评估模型

# 评估模型
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the test data: {100 * correct / total}%')

总结

恭喜您!您已成功使用PyTorch对鸢尾花进行了逻辑回归分类。通过本次实践,您掌握了以下知识:

  • 如何加载和预处理数据
  • 如何定义和训练逻辑回归模型
  • 如何评估模型的性能

继续探索机器学习和深度学习的精彩世界,并使用PyTorch释放您的想象力!