从基础到进阶:PyTorch 中必不可少的 CNN 模块大全
2024-02-16 16:59:33
卷积神经网络(CNN)模块:PyTorch 中的基石
卷积神经网络 (CNN) 是计算机视觉领域的支柱,在图像分类、目标检测和语义分割等任务中表现出色。PyTorch 深度学习框架为开发人员提供了全面的 CNN 模块,赋予他们构建强大神经网络的工具。本文将深入探讨 PyTorch 中五种最常用的 CNN 模块,揭示它们的结构、优点和局限性。
SEBlock:注入注意力以增强特征
SEBlock(Squeeze-and-Excitation Block)是一种轻量级的注意力机制,旨在提升 CNN 的表征能力。它通过全局平均池化对特征图进行“压缩”,然后使用两个全连接层进行“激励”,生成通道注意力图。此注意力图与原始特征图相乘,加强信息丰富的特征并抑制无关特征。SEBlock 因其轻量性和有效性而广受青睐。
Inception:捕捉不同尺度的特征
Inception 模块是一种复杂但强大的 CNN 模块,由 Google 开发。它并行使用不同尺寸的卷积核,从图像中提取不同尺度的特征。通过连接这些不同尺度的特征,Inception 模块能够捕捉图像中丰富的空间信息。尽管其强大的性能,Inception 模块的计算成本较高,使其不适用于资源受限的应用。
ResNet:跳过连接以克服梯度消失
ResNet(残差网络)是一种深度 CNN 架构,通过引入残差连接解决了梯度消失问题。残差连接允许信息从网络的较早层直接流向较晚层,使网络能够学习更深的特征表示。ResNet 以其训练稳定性、高准确性和高效性而著称,使其成为图像分类和目标检测的热门选择。
DenseNet:密集连接以促进特征重用
DenseNet(密集连接网络)是一种另一种深度 CNN 架构,旨在通过密集连接不同层之间的特征图来最大化特征重用。与 ResNet 中的残差连接不同,DenseNet 中的每一层都直接连接到所有先前的层。这种密集连接允许特征图在网络中进行更广泛的传播,从而提高了模型的泛化能力。
VGGNet:简单而有效的堆叠卷积层
VGGNet 是一种经典的 CNN 架构,以其简单和有效的堆叠卷积层而闻名。VGGNet 主要用于图像分类,其深度和宽度的变体已广泛用于各种计算机视觉任务。与其他更复杂的 CNN 架构相比,VGGNet 的计算成本较低,使其成为资源受限环境中的一个有吸引力的选择。
选择合适的 CNN 模块
选择合适的 CNN 模块取决于特定任务和资源约束。对于需要轻量级注意力机制的任务,SEBlock 是一个不错的选择。对于需要捕捉不同尺度特征的任务,Inception 模块是一个强大的选择。对于深度学习任务,ResNet 和 DenseNet 提供了出色的训练稳定性和准确性。对于需要简单性和计算效率的任务,VGGNet 是一个可靠的选择。
代码示例
import torch.nn as nn
# SEBlock
class SEBlock(nn.Module):
def __init__(self, channel):
super(SEBlock, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // 4),
nn.ReLU(),
nn.Linear(channel // 4, channel),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y
# Inception
class Inception(nn.Module):
def __init__(self, in_channels):
super(Inception, self).__init__()
self.conv1x1 = nn.Conv2d(in_channels, 64, kernel_size=1)
self.conv3x3 = nn.Conv2d(in_channels, 128, kernel_size=3, padding=1)
self.conv5x5 = nn.Conv2d(in_channels, 32, kernel_size=5, padding=2)
self.pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
def forward(self, x):
return torch.cat([self.conv1x1(x), self.conv3x3(x), self.conv5x5(x), self.pool(x)], dim=1)
# ResNet
class BasicBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
nn.BatchNorm2d(out_channels)
)
else:
self.shortcut = nn.Identity()
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = F.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = F.relu(out)
out += self.shortcut(x)
return out
# DenseNet
class Bottleneck(nn.Module):
def __init__(self, in_channels, growth_rate):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_channels, 4 * growth_rate, kernel_size=1)
self.bn1 = nn.BatchNorm2d(4 * growth_rate)
self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(growth_rate)
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = F.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = F.relu(out)
out = torch.cat([out, x], dim=1)
return out
# VGGNet
class VGG(nn.Module):
def __init__(self, num_classes):
super(VGG, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.Conv2d(128, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(128, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(256, 512, kernel_size=3, padding=1),