返回

从零实现对比损失:解决PyTorch梯度爆炸难题

Ai

从零实现对比损失 (Contrastive Loss):搞定梯度爆炸难题

写对比损失(Contrastive Loss)的代码时,碰上梯度爆炸(Gradient Explosion)这事儿挺让人头疼的。模型训练过程中,梯度突然变成 infNaN,训练直接中断,感觉就像代码里埋了个雷。

这不,就有朋友遇到了这个问题,他的 ContrastiveLoss 实现如下:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, projections_1, projections_2):
        z_i = projections_1
        z_j = projections_2
        z_i_norm = F.normalize(z_i, dim=1)
        z_j_norm = F.normalize(z_j, dim=1)
        
        # 问题可能出在这里
        cosine_num = torch.matmul(z_i, z_j.T) 
        cosine_denom = torch.matmul(z_i_norm, z_j_norm.T)
        cosine_similarity = cosine_num / cosine_denom # <- 潜在的除零风险或数值不稳定

        numerator = torch.exp(torch.diag(cosine_similarity) / self.temperature)

        # 分母的计算方式也有问题
        denominator = cosine_similarity
        diagonal_indices = torch.arange(denominator.size(0))
        denominator[diagonal_indices, diagonal_indices] = 0 # <- 原地修改张量
        denominator = torch.exp(torch.sum(cosine_similarity, dim=1)) # <- 应该先 exp 再 sum

        loss = -torch.log(numerator / denominator).sum() # <- log(0) 或除零风险
        return loss

这段代码尝试实现类似 SimCLR 中的 NT-Xent Loss,目标是拉近正样本对(同一输入的两种增强视图的投影)的表示,推开负样本对(不同输入的投影)的表示。

看着代码,梯度直接飙到无穷大,那大概率是哪里写错了。咱们来仔细扒一扒。

问题分析:雷区在哪?

这段代码里有几个地方看着就“不太对劲”,很可能是梯度爆炸的元凶:

  1. 余弦相似度的计算:画蛇添足还引入了不稳定性

    • 代码先分别计算了 未归一化 向量的点积 (cosine_num = torch.matmul(z_i, z_j.T)) 和 已归一化 向量的点积 (cosine_denom = torch.matmul(z_i_norm, z_j_norm.T))。
    • 然后用 cosine_similarity = cosine_num / cosine_denom 来得到最终的相似度。
    • 核心问题cosine_denom 本身 就是 余弦相似度矩阵!因为根据定义,cosine_similarity(A, B) = dot(A, B) / (norm(A) * norm(B))。当 A 和 B 已经是单位向量(经过 F.normalize 处理)时,它们的 L2 范数 norm(A)norm(B) 都等于 1,所以 dot(A_norm, B_norm) 直接就是余弦相似度。
    • 那个 cosine_num / cosine_denom 的除法操作不仅多余,而且极其危险。cosine_denom 的值域是 [-1, 1],如果其中某个元素非常接近 0,做除法时结果就会趋近无穷大,梯度自然也就“原地爆炸”了。
  2. 分母的计算:顺序搞反,结果全错

    • NT-Xent Loss 的分母应该是 所有负样本对 的相似度(经过温度系数缩放后)的 指数和 ,即 sum(exp(sim(i, k) / T)),其中 k 遍历所有与 i 不同的样本(包括 j 的其他实例,但不包括 i 本身)。
    • 代码里先把 cosine_similarity 对角线元素(正样本对相似度)置零:denominator[diagonal_indices, diagonal_indices] = 0。虽然意图是排除正样本,但这种原地修改张量的操作(in-place modification)有时会带来麻烦,尤其是在计算图中,可能干扰梯度计算。
    • 接着,代码计算了 denominator = torch.exp(torch.sum(cosine_similarity, dim=1))。这完全搞错了顺序!它先对一行中的(修改后的)相似度求和,然后 再对这个和取指数。正确的做法应该是先对每个负样本对的相似度 sim(i, k) / T 取指数,然后 再把这些指数值加起来。也就是 torch.sum(torch.exp(cosine_similarity / self.temperature), dim=1) (当然,还要排除掉正样本对)。
  3. 潜在的数值问题:log(0) 风险

    • 最后计算损失时用了 -torch.log(numerator / denominator)。如果 numerator 非常小或者 denominator 非常小(接近零),numerator / denominator 的结果也可能接近零,取对数 torch.log 会得到负无穷,导致 loss 变成 NaNinf

搞清楚了问题所在,修复起来就容易多了。

动手修复:两种方案任你选

我们可以提供两种修改方案:一种是尽量在原有代码结构上修正错误,另一种是采用更标准、更简洁的方式来实现 NT-Xent Loss。

方案一:修正版代码

这个方案主要修正计算余弦相似度和分母的逻辑错误,并添加一些数值稳定性处理。

import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLossFixed(nn.Module):
    def __init__(self, temperature=0.5, epsilon=1e-6):
        """
        修复版的对比损失函数
        Args:
            temperature (float): 温度系数,控制分布的锐利程度
            epsilon (float): 一个很小的数,用于防止 log(0) 和除零错误
        """
        super(ContrastiveLossFixed, self).__init__()
        self.temperature = temperature
        self.epsilon = epsilon

    def forward(self, projections_1, projections_2):
        z_i = projections_1
        z_j = projections_2

        # 1. 归一化特征向量
        z_i_norm = F.normalize(z_i, p=2, dim=1)
        z_j_norm = F.normalize(z_j, p=2, dim=1)

        # 2. 计算正确的余弦相似度矩阵
        # 直接用归一化后的向量做点积,结果就是余弦相似度
        # 注意这里 z_i_norm 和 z_j_norm 来自同一批数据,所以得到的是 N x N 的矩阵
        # 其中对角线元素是 (z_i[k], z_j[k]) 对的相似度,它们是正样本对
        # 非对角线元素是 (z_i[k], z_j[l]) (k!=l) 对的相似度,是负样本对的一部分
        # 理论上还应该包含 (z_i[k], z_i[l]) 和 (z_j[k], z_j[l]) 等负样本对,标准实现会更复杂
        # 但我们先基于原代码的简化假设:只考虑 z_i 和 z_j 之间的 N^2 对相似度
        cosine_similarity_matrix = torch.matmul(z_i_norm, z_j_norm.T)

        # 提取正样本对的相似度 (对角线元素)
        # sim(i, i) in the paper, calculated between z_i[k] and z_j[k]
        positive_pairs_similarity = torch.diag(cosine_similarity_matrix)

        # 3. 计算分子: exp(sim(i,j) / T)
        # 确保不会除以零温度
        safe_temperature = max(self.temperature, self.epsilon)
        numerator = torch.exp(positive_pairs_similarity / safe_temperature)

        # 4. 计算分母: sum_{k!=i} exp(sim(i, k) / T)
        # 这里的分母是针对每个 z_i[k] 而言的,需要计算它与所有 z_j[l] (l!=k) 的相似度的指数和
        
        # 创建一个 N x N 的 mask,对角线为 False,其余为 True
        mask = ~torch.eye(cosine_similarity_matrix.size(0), device=cosine_similarity_matrix.device, dtype=torch.bool)
        
        # 选取所有负样本对的相似度
        # 对 cosine_similarity_matrix 应用 mask,得到一个稀疏表示或只是标记
        # 我们需要计算每一行的 sum(exp(sim(i, k) / T)) for k!=i
        
        # 计算所有项的 exp(sim / T)
        exp_similarity = torch.exp(cosine_similarity_matrix / safe_temperature)
        
        # 将对角线(正样本)的 exp 值置零,或者使用 mask 来求和
        # 这里我们用 mask 来选取负样本项,然后求和
        # 注意:原代码的 sum 在 exp 之前是错误的!
        denominator = torch.sum(exp_similarity * mask, dim=1) # 对每一行(每个 z_i)求和

        # 5. 计算 Loss
        # 防止分母为零
        loss_per_sample = -torch.log(numerator / (denominator + self.epsilon))
        
        # 这里原代码是 .sum(),但通常 NT-Xent Loss 是对 Batch 取平均 .mean()
        # 如果遵循原代码的 .sum(),这里保留
        loss = loss_per_sample.sum() 
        # 或者更常见的: loss = loss_per_sample.mean() 

        return loss

这个修正版做了什么?

  • 直接计算余弦相似度: 使用 torch.matmul(z_i_norm, z_j_norm.T) 直接得到正确的 cosine_similarity_matrix。简单、直接、数值稳定。
  • 修正分母计算: 先对所有相似度(除以温度 T)取指数 torch.exp(),然后利用布尔掩码 mask 过滤掉正样本对(对角线元素),最后对每行的负样本对指数值求和 torch.sum(..., dim=1)。这才是正确的计算顺序。
  • 增加数值稳定性: 加入 epsilon 防止除以零和 log(0) 的情况。同时检查 temperature 防止其为零。
  • Loss 聚合方式: 指出原代码使用 .sum(),但 .mean() 更常见,让使用者可以根据需要选择。

注意: 上述代码仍然是基于原代码的一个简化假设,即只考虑了 z_iz_j 构成的 N*N 相似度矩阵。一个更完整的 NT-Xent 实现(如下面的方案二所示)通常会将 z_iz_j 合并,计算一个 2N * 2N 的相似度矩阵,从而包含更多类型的负样本对(例如 z_i[k] vs z_i[l])。

方案二:更标准的 NT-Xent Loss 实现

这个方案直接实现更标准的 NT-Xent Loss,它将两个视图的投影 z_iz_j 拼接到一起,计算一个更大的相似度矩阵,处理起来更规范,也更容易理解。

import torch
import torch.nn as nn
import torch.nn.functional as F

class NTXentLoss(nn.Module):
    """
    更标准的 NT-Xent Loss 实现 (来自 SimCLR)
    """
    def __init__(self, temperature=0.5, batch_size=None, world_size=1, device=None, epsilon=1e-6):
        """
        Args:
            temperature (float): 温度系数.
            batch_size (int): 每个 GPU 上的批次大小. 如果不提供,会尝试从输入推断.
            world_size (int): 分布式训练的总进程数 (GPU数量).
            device (torch.device): 计算所在的设备.
            epsilon (float): 防止 log(0) 的小常数.
        """
        super(NTXentLoss, self).__init__()
        self.temperature = temperature
        self.batch_size = batch_size # batch_size in one GPU
        self.world_size = world_size # num of GPUs
        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.epsilon = epsilon
        self.criterion = nn.CrossEntropyLoss(reduction="sum") # 使用交叉熵简化计算

    def forward(self, projections_1, projections_2):
        """
        输入是来自两个视图的投影 (N, embedding_dim).
        """
        z_i = projections_1
        z_j = projections_2
        
        # 动态获取 batch size
        current_batch_size = z_i.size(0)
        if self.batch_size is None:
            # 警告:如果在 DDP 中不显式设置 batch_size,这里的 N 可能是 total_batch_size / world_size
            print(f"Warning: batch_size not provided, inferred as {current_batch_size}. Ensure this is correct, especially in DDP.")
            inferred_total_batch_size = current_batch_size * self.world_size
        else:
            inferred_total_batch_size = self.batch_size * self.world_size

        # 1. 归一化特征
        z_i = F.normalize(z_i, p=2, dim=1)
        z_j = F.normalize(z_j, p=2, dim=1)

        # 2. [重要] 处理分布式训练 (DDP): 收集所有 GPU 上的特征
        # 如果在单卡上运行,world_size=1,gather 操作不起作用,z_i, z_j 保持不变 (N, dim)
        if self.world_size > 1:
            # 需要安装并配置好 torch.distributed
            import torch.distributed as dist
            # Gather z_i
            z_i_list = [torch.zeros_like(z_i) for _ in range(self.world_size)]
            dist.all_gather(z_i_list, z_i)
            z_i = torch.cat(z_i_list, dim=0) # Shape: (N * world_size, dim)
            
            # Gather z_j
            z_j_list = [torch.zeros_like(z_j) for _ in range(self.world_size)]
            dist.all_gather(z_j_list, z_j)
            z_j = torch.cat(z_j_list, dim=0) # Shape: (N * world_size, dim)

        # 现在 z_i 和 z_j 的 shape 都是 (Total_N, dim),其中 Total_N = local_N * world_size

        # 3. 合并两个视图的特征
        features = torch.cat([z_i, z_j], dim=0) # Shape: (2 * Total_N, dim)

        # 4. 计算相似度矩阵 (余弦相似度)
        # (2*Total_N, dim) @ (dim, 2*Total_N) -> (2*Total_N, 2*Total_N)
        similarity_matrix = torch.matmul(features, features.T)

        # 5. 创建 Mask 过滤掉自身与自身的比较 (对角线)
        # (2*Total_N)
        mask = torch.eye(similarity_matrix.size(0), dtype=torch.bool, device=self.device)
        # 将对角线(自身相似度)置为一个很小的值(或负无穷),避免影响后续 softmax
        similarity_matrix.masked_fill_(mask, -float('inf')) 
        
        # 6. 准备 NCE Loss 的 Logits 和 Labels
        # similarity_matrix shape: (2*Total_N, 2*Total_N)
        # scale by temperature
        logits = similarity_matrix / self.temperature

        # Labels: 对于 z_i[k],它的正样本是 z_j[k];对于 z_j[k],它的正样本是 z_i[k]
        # 在合并后的 features (2*Total_N) 中:
        # 前 Total_N 个是 z_i,后 Total_N 个是 z_j
        # 对于 features[k] (0 <= k < Total_N),即 z_i[k],其对应的正样本是 features[k + Total_N],即 z_j[k]
        # 对于 features[k] (Total_N <= k < 2*Total_N),即 z_j[k-Total_N],其对应的正样本是 features[k - Total_N],即 z_i[k-Total_N]
        labels = torch.arange(inferred_total_batch_size, device=self.device).roll(shifts=current_batch_size)
        labels = labels.repeat(2) # Shape: (2 * Total_N)

        # 7. 计算 Loss (使用 CrossEntropyLoss 自动完成 log_softmax 和 NLLLoss)
        # CrossEntropyLoss(logits, labels) 等价于 NLLLoss(log_softmax(logits), labels)
        # logit 的每一行代表一个样本 i 与所有其他样本 k 的相似度 (scaled by T)
        # label 指示了哪一列 k 是样本 i 的正样本
        loss = self.criterion(logits, labels)

        # 标准 SimCLR loss 是除以 2 * Total_N
        # 我们使用了 reduction="sum",所以需要手动归一化
        loss = loss / (2 * inferred_total_batch_size)

        return loss

这个标准版有什么不同?

  • 处理所有负样本: 通过拼接 z_iz_j,计算 (2N, 2N) 的相似度矩阵,自然地包含了所有类型的负样本对 (z_i[k] vs z_i[l], z_j[k] vs z_j[l], z_i[k] vs z_j[l] where k!=l)。
  • 利用 CrossEntropyLoss: 将问题巧妙地转化为一个分类问题。对于每个样本 i(无论是来自 z_i 还是 z_j),目标是正确地从 2N-1 个其他样本中识别出它的“正配对”样本。CrossEntropyLoss 内部处理了 log_softmax,这在数值上比手动计算 log(numerator / denominator) 更稳定。
  • 分布式训练支持: 包含了使用 torch.distributed.all_gather 来收集所有 GPU 上的特征,这对于保证对比学习有足够多的负样本至关重要。
  • 掩码对角线: 使用 masked_fill_ 将对角线元素(自身与自身的相似度)设置为负无穷,这样在 softmax 计算中它们就变成了 0,不会影响损失。
  • 清晰的标签生成: labels 的生成逻辑清晰地对应了正样本对的位置。
  • 规范化损失: 最后除以 2 * Total_N(或 2 * self.batch_size * self.world_size),得到每个样本的平均损失。

安全建议:

  • 检查数据类型: 确保输入 projections_1, projections_2 是浮点类型(如 torch.float32),并且精度足够。
  • 防止除以零的 Temperature: 确认 temperature 的值大于零。实现中已添加保护。
  • 处理空 Batch: 在实际应用中,可能需要检查输入的 batch size 是否为 0。

想玩得更溜?进阶技巧来了

搞定了基础实现,这里还有些能让你玩转对比损失的小技巧:

  1. 温度系数 (Temperature) T:不只是个数字

    • 它控制着 softmax 的锐利程度。
    • 低 T (比如 0.1, 0.05): 使得模型更关注区分那些非常相似的负样本(所谓的 hard negatives),这能学到更精细的特征,但也可能导致训练不稳定或收敛困难。
    • 高 T (比如 1.0): 使得 softmax 更平滑,所有负样本的权重相对平均,训练更稳定,但可能无法有效区分难负例。
    • 通常需要根据数据集和模型进行调优,SimCLR 论文中常用 0.07 到 0.5 之间的值。
  2. 大 Batch Size 的重要性

    • 对比学习的效果很大程度上依赖于有足够多、足够好的负样本。
    • Batch Size 越大,每个正样本对就能看到越多的负样本,学习到的表示就越鲁棒。
    • 这也是为什么在多 GPU 上训练时,通常需要 all_gather 操作来聚合所有卡上的样本,变相增大有效 Batch Size。如果资源有限,可以考虑 MoCo(Momentum Contrast)等方法,它们使用队列来缓存大量的负样本,减少对大 Batch Size 的依赖。
  3. 数值稳定性:细节决定成败

    • 使用 log_softmax 如方案二所示,利用 CrossEntropyLoss 内部的 log_softmax 是处理 log(sum(exp(...))) 这类计算最稳健的方式。
    • 添加 Epsilon: 在进行除法或取对数操作前,给分母或参数加上一个微小的正数 epsilon (如 1e-61e-8),可以有效避免除零错误和 log(0)
    • 梯度裁剪 (Gradient Clipping): 如果梯度爆炸仍然偶尔发生,可以考虑在优化器步骤之前进行梯度裁剪 (torch.nn.utils.clip_grad_norm_torch.nn.utils.clip_grad_value_),强制将过大的梯度限制在一个范围内。
  4. 归一化:基础中的基础

    • 在计算余弦相似度之前对特征向量进行 L2 归一化 (F.normalize) 是标准操作。这确保了相似度度量不受向量长度的影响,只关注方向。忘记归一化会导致结果偏差很大。

现在,你应该对如何从零开始实现(以及修复)对比损失有了更清晰的认识,也了解了可能导致梯度爆炸的原因和避免方法。选择一个你觉得更清晰、更符合你需求的实现方式,动手试试看吧!