从零实现对比损失:解决PyTorch梯度爆炸难题
2025-04-18 00:07:50
从零实现对比损失 (Contrastive Loss):搞定梯度爆炸难题
写对比损失(Contrastive Loss)的代码时,碰上梯度爆炸(Gradient Explosion)这事儿挺让人头疼的。模型训练过程中,梯度突然变成 inf
或 NaN
,训练直接中断,感觉就像代码里埋了个雷。
这不,就有朋友遇到了这个问题,他的 ContrastiveLoss
实现如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
class ContrastiveLoss(nn.Module):
def __init__(self, temperature=0.5):
super(ContrastiveLoss, self).__init__()
self.temperature = temperature
def forward(self, projections_1, projections_2):
z_i = projections_1
z_j = projections_2
z_i_norm = F.normalize(z_i, dim=1)
z_j_norm = F.normalize(z_j, dim=1)
# 问题可能出在这里
cosine_num = torch.matmul(z_i, z_j.T)
cosine_denom = torch.matmul(z_i_norm, z_j_norm.T)
cosine_similarity = cosine_num / cosine_denom # <- 潜在的除零风险或数值不稳定
numerator = torch.exp(torch.diag(cosine_similarity) / self.temperature)
# 分母的计算方式也有问题
denominator = cosine_similarity
diagonal_indices = torch.arange(denominator.size(0))
denominator[diagonal_indices, diagonal_indices] = 0 # <- 原地修改张量
denominator = torch.exp(torch.sum(cosine_similarity, dim=1)) # <- 应该先 exp 再 sum
loss = -torch.log(numerator / denominator).sum() # <- log(0) 或除零风险
return loss
这段代码尝试实现类似 SimCLR 中的 NT-Xent Loss,目标是拉近正样本对(同一输入的两种增强视图的投影)的表示,推开负样本对(不同输入的投影)的表示。
看着代码,梯度直接飙到无穷大,那大概率是哪里写错了。咱们来仔细扒一扒。
问题分析:雷区在哪?
这段代码里有几个地方看着就“不太对劲”,很可能是梯度爆炸的元凶:
-
余弦相似度的计算:画蛇添足还引入了不稳定性
- 代码先分别计算了 未归一化 向量的点积 (
cosine_num = torch.matmul(z_i, z_j.T)
) 和 已归一化 向量的点积 (cosine_denom = torch.matmul(z_i_norm, z_j_norm.T)
)。 - 然后用
cosine_similarity = cosine_num / cosine_denom
来得到最终的相似度。 - 核心问题 :
cosine_denom
本身 就是 余弦相似度矩阵!因为根据定义,cosine_similarity(A, B) = dot(A, B) / (norm(A) * norm(B))
。当 A 和 B 已经是单位向量(经过F.normalize
处理)时,它们的 L2 范数norm(A)
和norm(B)
都等于 1,所以dot(A_norm, B_norm)
直接就是余弦相似度。 - 那个
cosine_num / cosine_denom
的除法操作不仅多余,而且极其危险。cosine_denom
的值域是 [-1, 1],如果其中某个元素非常接近 0,做除法时结果就会趋近无穷大,梯度自然也就“原地爆炸”了。
- 代码先分别计算了 未归一化 向量的点积 (
-
分母的计算:顺序搞反,结果全错
- NT-Xent Loss 的分母应该是 所有负样本对 的相似度(经过温度系数缩放后)的 指数和 ,即
sum(exp(sim(i, k) / T))
,其中k
遍历所有与i
不同的样本(包括j
的其他实例,但不包括i
本身)。 - 代码里先把
cosine_similarity
对角线元素(正样本对相似度)置零:denominator[diagonal_indices, diagonal_indices] = 0
。虽然意图是排除正样本,但这种原地修改张量的操作(in-place modification)有时会带来麻烦,尤其是在计算图中,可能干扰梯度计算。 - 接着,代码计算了
denominator = torch.exp(torch.sum(cosine_similarity, dim=1))
。这完全搞错了顺序!它先对一行中的(修改后的)相似度求和,然后 再对这个和取指数。正确的做法应该是先对每个负样本对的相似度sim(i, k) / T
取指数,然后 再把这些指数值加起来。也就是torch.sum(torch.exp(cosine_similarity / self.temperature), dim=1)
(当然,还要排除掉正样本对)。
- NT-Xent Loss 的分母应该是 所有负样本对 的相似度(经过温度系数缩放后)的 指数和 ,即
-
潜在的数值问题:
log(0)
风险- 最后计算损失时用了
-torch.log(numerator / denominator)
。如果numerator
非常小或者denominator
非常小(接近零),numerator / denominator
的结果也可能接近零,取对数torch.log
会得到负无穷,导致loss
变成NaN
或inf
。
- 最后计算损失时用了
搞清楚了问题所在,修复起来就容易多了。
动手修复:两种方案任你选
我们可以提供两种修改方案:一种是尽量在原有代码结构上修正错误,另一种是采用更标准、更简洁的方式来实现 NT-Xent Loss。
方案一:修正版代码
这个方案主要修正计算余弦相似度和分母的逻辑错误,并添加一些数值稳定性处理。
import torch
import torch.nn as nn
import torch.nn.functional as F
class ContrastiveLossFixed(nn.Module):
def __init__(self, temperature=0.5, epsilon=1e-6):
"""
修复版的对比损失函数
Args:
temperature (float): 温度系数,控制分布的锐利程度
epsilon (float): 一个很小的数,用于防止 log(0) 和除零错误
"""
super(ContrastiveLossFixed, self).__init__()
self.temperature = temperature
self.epsilon = epsilon
def forward(self, projections_1, projections_2):
z_i = projections_1
z_j = projections_2
# 1. 归一化特征向量
z_i_norm = F.normalize(z_i, p=2, dim=1)
z_j_norm = F.normalize(z_j, p=2, dim=1)
# 2. 计算正确的余弦相似度矩阵
# 直接用归一化后的向量做点积,结果就是余弦相似度
# 注意这里 z_i_norm 和 z_j_norm 来自同一批数据,所以得到的是 N x N 的矩阵
# 其中对角线元素是 (z_i[k], z_j[k]) 对的相似度,它们是正样本对
# 非对角线元素是 (z_i[k], z_j[l]) (k!=l) 对的相似度,是负样本对的一部分
# 理论上还应该包含 (z_i[k], z_i[l]) 和 (z_j[k], z_j[l]) 等负样本对,标准实现会更复杂
# 但我们先基于原代码的简化假设:只考虑 z_i 和 z_j 之间的 N^2 对相似度
cosine_similarity_matrix = torch.matmul(z_i_norm, z_j_norm.T)
# 提取正样本对的相似度 (对角线元素)
# sim(i, i) in the paper, calculated between z_i[k] and z_j[k]
positive_pairs_similarity = torch.diag(cosine_similarity_matrix)
# 3. 计算分子: exp(sim(i,j) / T)
# 确保不会除以零温度
safe_temperature = max(self.temperature, self.epsilon)
numerator = torch.exp(positive_pairs_similarity / safe_temperature)
# 4. 计算分母: sum_{k!=i} exp(sim(i, k) / T)
# 这里的分母是针对每个 z_i[k] 而言的,需要计算它与所有 z_j[l] (l!=k) 的相似度的指数和
# 创建一个 N x N 的 mask,对角线为 False,其余为 True
mask = ~torch.eye(cosine_similarity_matrix.size(0), device=cosine_similarity_matrix.device, dtype=torch.bool)
# 选取所有负样本对的相似度
# 对 cosine_similarity_matrix 应用 mask,得到一个稀疏表示或只是标记
# 我们需要计算每一行的 sum(exp(sim(i, k) / T)) for k!=i
# 计算所有项的 exp(sim / T)
exp_similarity = torch.exp(cosine_similarity_matrix / safe_temperature)
# 将对角线(正样本)的 exp 值置零,或者使用 mask 来求和
# 这里我们用 mask 来选取负样本项,然后求和
# 注意:原代码的 sum 在 exp 之前是错误的!
denominator = torch.sum(exp_similarity * mask, dim=1) # 对每一行(每个 z_i)求和
# 5. 计算 Loss
# 防止分母为零
loss_per_sample = -torch.log(numerator / (denominator + self.epsilon))
# 这里原代码是 .sum(),但通常 NT-Xent Loss 是对 Batch 取平均 .mean()
# 如果遵循原代码的 .sum(),这里保留
loss = loss_per_sample.sum()
# 或者更常见的: loss = loss_per_sample.mean()
return loss
这个修正版做了什么?
- 直接计算余弦相似度: 使用
torch.matmul(z_i_norm, z_j_norm.T)
直接得到正确的cosine_similarity_matrix
。简单、直接、数值稳定。 - 修正分母计算: 先对所有相似度(除以温度 T)取指数
torch.exp()
,然后利用布尔掩码mask
过滤掉正样本对(对角线元素),最后对每行的负样本对指数值求和torch.sum(..., dim=1)
。这才是正确的计算顺序。 - 增加数值稳定性: 加入
epsilon
防止除以零和log(0)
的情况。同时检查temperature
防止其为零。 - Loss 聚合方式: 指出原代码使用
.sum()
,但.mean()
更常见,让使用者可以根据需要选择。
注意: 上述代码仍然是基于原代码的一个简化假设,即只考虑了 z_i
和 z_j
构成的 N*N 相似度矩阵。一个更完整的 NT-Xent 实现(如下面的方案二所示)通常会将 z_i
和 z_j
合并,计算一个 2N * 2N 的相似度矩阵,从而包含更多类型的负样本对(例如 z_i[k]
vs z_i[l]
)。
方案二:更标准的 NT-Xent Loss 实现
这个方案直接实现更标准的 NT-Xent Loss,它将两个视图的投影 z_i
和 z_j
拼接到一起,计算一个更大的相似度矩阵,处理起来更规范,也更容易理解。
import torch
import torch.nn as nn
import torch.nn.functional as F
class NTXentLoss(nn.Module):
"""
更标准的 NT-Xent Loss 实现 (来自 SimCLR)
"""
def __init__(self, temperature=0.5, batch_size=None, world_size=1, device=None, epsilon=1e-6):
"""
Args:
temperature (float): 温度系数.
batch_size (int): 每个 GPU 上的批次大小. 如果不提供,会尝试从输入推断.
world_size (int): 分布式训练的总进程数 (GPU数量).
device (torch.device): 计算所在的设备.
epsilon (float): 防止 log(0) 的小常数.
"""
super(NTXentLoss, self).__init__()
self.temperature = temperature
self.batch_size = batch_size # batch_size in one GPU
self.world_size = world_size # num of GPUs
self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.epsilon = epsilon
self.criterion = nn.CrossEntropyLoss(reduction="sum") # 使用交叉熵简化计算
def forward(self, projections_1, projections_2):
"""
输入是来自两个视图的投影 (N, embedding_dim).
"""
z_i = projections_1
z_j = projections_2
# 动态获取 batch size
current_batch_size = z_i.size(0)
if self.batch_size is None:
# 警告:如果在 DDP 中不显式设置 batch_size,这里的 N 可能是 total_batch_size / world_size
print(f"Warning: batch_size not provided, inferred as {current_batch_size}. Ensure this is correct, especially in DDP.")
inferred_total_batch_size = current_batch_size * self.world_size
else:
inferred_total_batch_size = self.batch_size * self.world_size
# 1. 归一化特征
z_i = F.normalize(z_i, p=2, dim=1)
z_j = F.normalize(z_j, p=2, dim=1)
# 2. [重要] 处理分布式训练 (DDP): 收集所有 GPU 上的特征
# 如果在单卡上运行,world_size=1,gather 操作不起作用,z_i, z_j 保持不变 (N, dim)
if self.world_size > 1:
# 需要安装并配置好 torch.distributed
import torch.distributed as dist
# Gather z_i
z_i_list = [torch.zeros_like(z_i) for _ in range(self.world_size)]
dist.all_gather(z_i_list, z_i)
z_i = torch.cat(z_i_list, dim=0) # Shape: (N * world_size, dim)
# Gather z_j
z_j_list = [torch.zeros_like(z_j) for _ in range(self.world_size)]
dist.all_gather(z_j_list, z_j)
z_j = torch.cat(z_j_list, dim=0) # Shape: (N * world_size, dim)
# 现在 z_i 和 z_j 的 shape 都是 (Total_N, dim),其中 Total_N = local_N * world_size
# 3. 合并两个视图的特征
features = torch.cat([z_i, z_j], dim=0) # Shape: (2 * Total_N, dim)
# 4. 计算相似度矩阵 (余弦相似度)
# (2*Total_N, dim) @ (dim, 2*Total_N) -> (2*Total_N, 2*Total_N)
similarity_matrix = torch.matmul(features, features.T)
# 5. 创建 Mask 过滤掉自身与自身的比较 (对角线)
# (2*Total_N)
mask = torch.eye(similarity_matrix.size(0), dtype=torch.bool, device=self.device)
# 将对角线(自身相似度)置为一个很小的值(或负无穷),避免影响后续 softmax
similarity_matrix.masked_fill_(mask, -float('inf'))
# 6. 准备 NCE Loss 的 Logits 和 Labels
# similarity_matrix shape: (2*Total_N, 2*Total_N)
# scale by temperature
logits = similarity_matrix / self.temperature
# Labels: 对于 z_i[k],它的正样本是 z_j[k];对于 z_j[k],它的正样本是 z_i[k]
# 在合并后的 features (2*Total_N) 中:
# 前 Total_N 个是 z_i,后 Total_N 个是 z_j
# 对于 features[k] (0 <= k < Total_N),即 z_i[k],其对应的正样本是 features[k + Total_N],即 z_j[k]
# 对于 features[k] (Total_N <= k < 2*Total_N),即 z_j[k-Total_N],其对应的正样本是 features[k - Total_N],即 z_i[k-Total_N]
labels = torch.arange(inferred_total_batch_size, device=self.device).roll(shifts=current_batch_size)
labels = labels.repeat(2) # Shape: (2 * Total_N)
# 7. 计算 Loss (使用 CrossEntropyLoss 自动完成 log_softmax 和 NLLLoss)
# CrossEntropyLoss(logits, labels) 等价于 NLLLoss(log_softmax(logits), labels)
# logit 的每一行代表一个样本 i 与所有其他样本 k 的相似度 (scaled by T)
# label 指示了哪一列 k 是样本 i 的正样本
loss = self.criterion(logits, labels)
# 标准 SimCLR loss 是除以 2 * Total_N
# 我们使用了 reduction="sum",所以需要手动归一化
loss = loss / (2 * inferred_total_batch_size)
return loss
这个标准版有什么不同?
- 处理所有负样本: 通过拼接
z_i
和z_j
,计算(2N, 2N)
的相似度矩阵,自然地包含了所有类型的负样本对 (z_i[k]
vsz_i[l]
,z_j[k]
vsz_j[l]
,z_i[k]
vsz_j[l]
wherek!=l
)。 - 利用 CrossEntropyLoss: 将问题巧妙地转化为一个分类问题。对于每个样本
i
(无论是来自z_i
还是z_j
),目标是正确地从2N-1
个其他样本中识别出它的“正配对”样本。CrossEntropyLoss
内部处理了log_softmax
,这在数值上比手动计算log(numerator / denominator)
更稳定。 - 分布式训练支持: 包含了使用
torch.distributed.all_gather
来收集所有 GPU 上的特征,这对于保证对比学习有足够多的负样本至关重要。 - 掩码对角线: 使用
masked_fill_
将对角线元素(自身与自身的相似度)设置为负无穷,这样在softmax
计算中它们就变成了 0,不会影响损失。 - 清晰的标签生成:
labels
的生成逻辑清晰地对应了正样本对的位置。 - 规范化损失: 最后除以
2 * Total_N
(或2 * self.batch_size * self.world_size
),得到每个样本的平均损失。
安全建议:
- 检查数据类型: 确保输入
projections_1
,projections_2
是浮点类型(如torch.float32
),并且精度足够。 - 防止除以零的 Temperature: 确认
temperature
的值大于零。实现中已添加保护。 - 处理空 Batch: 在实际应用中,可能需要检查输入的 batch size 是否为 0。
想玩得更溜?进阶技巧来了
搞定了基础实现,这里还有些能让你玩转对比损失的小技巧:
-
温度系数 (Temperature) T:不只是个数字
- 它控制着
softmax
的锐利程度。 - 低 T (比如 0.1, 0.05): 使得模型更关注区分那些非常相似的负样本(所谓的 hard negatives),这能学到更精细的特征,但也可能导致训练不稳定或收敛困难。
- 高 T (比如 1.0): 使得
softmax
更平滑,所有负样本的权重相对平均,训练更稳定,但可能无法有效区分难负例。 - 通常需要根据数据集和模型进行调优,SimCLR 论文中常用 0.07 到 0.5 之间的值。
- 它控制着
-
大 Batch Size 的重要性
- 对比学习的效果很大程度上依赖于有足够多、足够好的负样本。
- Batch Size 越大,每个正样本对就能看到越多的负样本,学习到的表示就越鲁棒。
- 这也是为什么在多 GPU 上训练时,通常需要
all_gather
操作来聚合所有卡上的样本,变相增大有效 Batch Size。如果资源有限,可以考虑 MoCo(Momentum Contrast)等方法,它们使用队列来缓存大量的负样本,减少对大 Batch Size 的依赖。
-
数值稳定性:细节决定成败
- 使用
log_softmax
: 如方案二所示,利用CrossEntropyLoss
内部的log_softmax
是处理log(sum(exp(...)))
这类计算最稳健的方式。 - 添加 Epsilon: 在进行除法或取对数操作前,给分母或参数加上一个微小的正数
epsilon
(如1e-6
或1e-8
),可以有效避免除零错误和log(0)
。 - 梯度裁剪 (Gradient Clipping): 如果梯度爆炸仍然偶尔发生,可以考虑在优化器步骤之前进行梯度裁剪 (
torch.nn.utils.clip_grad_norm_
或torch.nn.utils.clip_grad_value_
),强制将过大的梯度限制在一个范围内。
- 使用
-
归一化:基础中的基础
- 在计算余弦相似度之前对特征向量进行 L2 归一化 (
F.normalize
) 是标准操作。这确保了相似度度量不受向量长度的影响,只关注方向。忘记归一化会导致结果偏差很大。
- 在计算余弦相似度之前对特征向量进行 L2 归一化 (
现在,你应该对如何从零开始实现(以及修复)对比损失有了更清晰的认识,也了解了可能导致梯度爆炸的原因和避免方法。选择一个你觉得更清晰、更符合你需求的实现方式,动手试试看吧!