返回

BiFormer:插即用系列解锁视觉任务新范式

后端

双向路由注意力网络:BiFormer

简介

在计算机视觉领域,模型需要灵活高效地处理各种任务,从图像分类到实例分割。近日发表在 CVPR'2023 上的论文《BiFormer:构建具有双向路由注意力的金字塔网络》提出了一种创新模型,可满足这一需求。

双向路由注意力(BRA)模块

BiFormer 的核心是双向路由注意力(BRA)模块。BRA 模块是一种动态注意力机制,可分配计算资源,以关注输入特征图中最相关的区域。它通过构建和修剪区域级有向图来实现此目标,从而过滤掉无关区域。这种机制使模型能够灵活地适应不同任务和输入的要求。

BiFormer 模型

基于 BRA 模块,BiFormer 模型采用金字塔结构,包含多个阶段。每个阶段由 BRA 模块和一个下采样层组成。BRA 模块负责计算特征图之间的注意力,而下采样层负责减小特征图的尺寸。这种设计允许 BiFormer 模型在保持高分辨率的同时提取强大特征。

卓越的性能

BiFormer 模型在图像分类、目标检测、实例分割和语义分割等四种主要视觉任务上表现出色。在 ImageNet 图像分类数据集上,BiFormer 取得了 84.5% 的顶级准确率,在 COCO 目标检测数据集上取得了 45.3% 的框 AP,在 COCO 实例分割数据集上取得了 36.3% 的掩码 AP,在 ADE20K 语义分割数据集上取得了 46.6% 的 mIoU。

BiFormer 的优势

BiFormer 模型具有以下优势:

  • 即插即用: BiFormer 可以轻松应用于各种视觉任务,无需大量修改,使其成为通用工具。
  • 高效: BRA 模块的动态计算资源分配机制提高了 BiFormer 的计算效率,使其在相同资源下优于其他模型。
  • 准确: BiFormer 在四项视觉任务中均达到或超越了最先进的性能,证明了其强大的特征表示能力和泛化能力。

应用与前景

BiFormer 的插即用、高效和准确特性使其在广泛的视觉任务中具有应用前景,包括图像分类、目标检测、实例分割、语义分割、人脸识别和医疗图像分析。它有望在这些领域发挥变革性作用,为我们带来新的见解和可能性。

常见问题解答

  1. BiFormer 如何处理不同大小的输入?
    BiFormer 的金字塔结构使它能够处理不同大小的输入,通过下采样和上采样层调整特征图的分辨率。

  2. BRA 模块如何平衡效率和准确性?
    BRA 模块通过动态计算资源分配在效率和准确性之间取得平衡,在无关区域修剪计算并将其分配给重要区域。

  3. BiFormer 与其他金字塔模型有何不同?
    BiFormer 利用 BRA 模块,该模块具有查询感知能力,可以更精确地分配计算资源,从而提高了效率和性能。

  4. BiFormer 是否需要预训练?
    BiFormer 可以从头开始训练,也可以使用在 ImageNet 等大型数据集上预训练的权重进行微调。

  5. BiFormer 是否适用于实时应用?
    BiFormer 的高效性使其适用于需要低延迟的实时应用,例如目标检测和自动驾驶。

代码示例

import torch
import torch.nn as nn

class BRA(nn.Module):
    def __init__(self, dim, num_heads, qkv_bias=True):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.qkv_bias = qkv_bias

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        # Query, Key, Value
        qkv = self.qkv(x).reshape(x.shape[0], x.shape[1], 3, self.num_heads, -1).permute(2, 0, 1, 3, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Routing
        sim = torch.einsum('bnhd,bnhq->bnqhd', k, q) / self.dim ** 0.5
        attn = sim.softmax(dim=-2)

        # Value Update
        out = torch.einsum('bnqhd,bnhe->bnhd', attn, v)
        out = out.reshape(x.shape[0], x.shape[1], -1)

        # Projection
        out = self.proj(out)

        return out

class BiFormer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.stages = nn.ModuleList()

        for stage_cfg in cfg:
            stage = nn.Sequential()
            for block_cfg in stage_cfg:
                if block_cfg['type'] == 'BRA':
                    stage.add_module(f'BRA-{block_cfg["num_blocks"]}',
                                      nn.Sequential(*[BRA(block_cfg['dim'], block_cfg['num_heads']) for _ in range(block_cfg['num_blocks'])])
                                     )
                else:
                    stage.add_module(f'Down-{block_cfg["name"]}', nn.MaxPool2d(block_cfg['kernel_size'], block_cfg['stride']))
            self.stages.append(stage)

    def forward(self, x):
        for stage in self.stages:
            x = stage(x)
        return x