解决 Optax L-BFGS 优化器 NaN 错误的实用技巧
2025-05-04 23:26:02
解决 Optax L-BFGS 优化器中的 NaN 问题
用 JAX 和 Optax 跑模型,选了 L-BFGS 这个优化器?挺好,二阶优化方法,有时候收敛起来嗖嗖的。但跑着跑着,突然蹦出个 NaN
(Not a Number)?程序直接挂掉,心情是不是瞬间就不美丽了?
你可能遇到了 Optax L-BFGS 实现中一个比较常见的数值稳定性问题。具体来说,问题通常出在 optax/_src/transform.py
里的 scale_by_lbfgs
函数:
# optax/optax/_src/transform.py (示意)
def scale_by_lbfgs(
...
def update_fn(
...
# s_k = x_{k+1} - x_k (参数变化)
# y_k = g_{k+1} - g_k (梯度变化)
# diff_params = s_k
# diff_grads = y_k
vdot_diff_params_updates = jnp.vdot(diff_params, diff_grads) # <-- 计算 s_k^T y_k
weight = jnp.where(
vdot_diff_params_updates == 0.0, 0.0, 1.0 / vdot_diff_params_updates # <-- 问题点!
)
# ... 后续计算 ...
看到 1.0 / vdot_diff_params_updates
这行没?这就是“罪魁祸首”。vdot_diff_params_updates
本质上是参数向量变化量 (s_k
) 和梯度变化量 (y_k
) 的点积,即 s_k^T y_k
。在 L-BFGS 算法里,这个值理论上应该大于零,因为它反映了函数沿 s_k
方向的曲率信息。
但!是!计算机的世界是浮点数的江湖。当 vdot_diff_params_updates
这个值变得非常非常小,小到接近浮点数的表示极限,但又不完全是 0.0
时,1.0
除以一个极小的正数,结果可能会超出浮点数能表示的最大范围,变成 inf
(无穷大)。更糟的是,如果后续计算中 inf
参与了不当运算(比如 inf - inf
, 0 * inf
),就很容易产生 NaN
。
问题根源:为啥 s_k^T y_k
会变得这么小?
搞清楚为啥这个点积会出问题,有助于我们对症下药。
- 曲率信息丢失或接近零:
s_k^T y_k
代表了目标函数在s_k
方向上的(近似)二阶导数信息。如果函数在某个区域非常“平坦”,梯度变化微乎其微,或者参数更新步长很小,这个点积就可能接近零。这在训练后期或者遇到平坦的损失景观(loss landscape)时比较常见。 - 数值精度限制: 标准的
float32
精度有限。当参数和梯度的变化量本身就很小,计算它们的点积可能因为舍入误差(rounding errors)导致结果不准确,甚至非常接近零。 - 不满足 Wolfe 条件: L-BFGS 通常依赖线搜索(line search)来保证
s_k^T y_k > 0
(满足 Wolfe 条件中的曲率条件)。Optax 的scale_by_lbfgs
本身不包含线搜索逻辑,它期望输入的梯度(或更新量)已经满足了某些条件。如果上游传来的更新信息不符合 L-BFGS 的理论假设,这个点积就可能出问题。 - L-BFGS 历史记录累积误差: L-BFGS 利用历史的
s_k
和y_k
对来近似 Hessian 矩阵的逆。如果历史记录中的某些配对数值上不稳定,可能在后续的递归计算中放大问题。
怎么搞定 NaN?几种靠谱方案
别急着换优化器,咱们有几种方法可以尝试解决或缓解这个问题。
方案一:给分母加个“小保险” (Epsilon 正则化)
这是最直接的想法:既然怕除以太小的数,那就在分母上加一个特别小的正数 epsilon
,让它永远不会真正变成零或接近零。
原理:
通过 1.0 / (vdot_diff_params_updates + epsilon)
替换原来的 1.0 / vdot_diff_params_updates
。这个 epsilon
很小(比如 1e-8
或 1e-6
),对计算结果影响不大,但能有效防止除零或浮点溢出。
操作步骤:
由于直接修改 Optax 源代码不是最佳实践(升级麻烦,不方便共享),我们可以自定义一个转换函数(Transformation)来包裹或替换 scale_by_lbfgs
的核心逻辑。
假设我们想保持 Optax 的链式结构,可以这样做:
import jax
import jax.numpy as jnp
import optax
def safe_division(numerator, denominator, epsilon=1e-8):
"""安全的除法,避免分母过小导致 NaN"""
return numerator / (denominator + jnp.where(denominator >= 0.0, epsilon, -epsilon))
# 或者更简单粗暴一点,如果确定 denominator 理论上非负:
# return numerator / (denominator + epsilon)
def scale_by_lbfgs_safe(
history_size: int = 10,
min_vdot_value: float = 1e-8, # 这个就是我们的 epsilon
initial_inv_hessian_diagonal: float = 1.0,
use_high_precision_vars: bool = False):
"""
一个尝试修复 NaN 问题的 scale_by_lbfgs 版本。
注意:这是一个简化的示例,可能需要根据 Optax 版本调整。
更稳妥的方式是理解 optax.lbfgs 的实现,并在必要时构建自己的优化步骤。
以下代码旨在说明添加 epsilon 的思路,并未完整复刻所有细节。
"""
# 这个自定义版本主要是为了演示核心思想
# 实际应用中,更推荐使用 Optax 内建方案组合或等待官方更新
# 如果急用,可能需要参考 Optax 源码,创建一个功能完整的自定义 Transformation
def init_fn(params):
# LBFGS 需要状态来存储历史 s_k, y_k 等信息
# 参考 optax.lbfgs 的 init_fn 实现
# 这里仅作示意,实际需要更完整的状态初始化
if use_high_precision_vars:
dtype = jnp.float64
else:
dtype = jnp.float32 # 或者 params 的 dtype
state = optax.LbfgsState(
iter_count=jnp.zeros([], jnp.int32),
params_history=jnp.zeros((history_size,) + params.shape, dtype=dtype), # 简化形状处理
grads_history=jnp.zeros((history_size,) + params.shape, dtype=dtype), # 简化形状处理
last_params=jnp.zeros_like(params, dtype=dtype),
last_grads=jnp.zeros_like(params, dtype=dtype),
sum_inv_hessian_updates=jnp.zeros([], dtype=dtype) # 可能需要更多状态
)
# 这里只是一个非常基础的 state 结构,实际 L-BFGS 需要更复杂的状态管理
# 请务必参考 Optax 源码中的 LbfgsState dataclass 定义
return state # 返回一个合法的 LbfgsState 对象 (这里是伪代码)
def update_fn(updates, state, params):
# 这是一个高度简化的 update_fn,仅用于说明关键计算点
# 真实的 L-BFGS two-loop recursion 要复杂得多
# 假设 state 中已经包含了计算所需的 s_k 和 y_k
# 从 state 中获取 diff_params (s_k) 和 diff_grads (y_k)
# ... (这部分逻辑需要完整实现 L-BFGS 的历史记录更新)
# 假设我们拿到了当前的 diff_params 和 diff_grads
diff_params = state.last_params - params # 这是一个简化示例 s_k = x_k - x_{k-1}
diff_grads = updates - state.last_grads # 这是一个简化示例 y_k = g_k - g_{k-1}
vdot_diff_params_updates = jnp.vdot(diff_params, diff_grads)
# !!! 关键修改点 !!!
# 使用安全除法
safe_weight = jnp.where(
# 注意:原始代码的条件是 == 0.0,这里沿用,但更稳健的可能是 abs() < threshold
vdot_diff_params_updates == 0.0,
0.0,
1.0 / (vdot_diff_params_updates + jnp.sign(vdot_diff_params_updates) * min_vdot_value)
# 添加一个小量,符号与原数相同,避免改变性质
# 或者,如果理论上 vdot > 0, 可以直接 + min_vdot_value
# safe_weight = 1.0 / jnp.maximum(vdot_diff_params_updates, min_vdot_value) # 另一种更简洁写法, 假定vdot理论非负
)
# ... L-BFGS 的后续计算,例如 two-loop recursion ...
# 使用 safe_weight 代替原始的 weight
# 这里的 updates 是 L-BFGS 算法计算出的最终更新方向 * 步长
# 这个简化版 update_fn 并没有计算实际的 L-BFGS 更新
# 返回计算得到的最终更新量 `new_updates` 和更新后的 `new_state`
new_updates = updates # 伪代码:这里应是经过LBFGS近似Hessian逆处理后的结果
new_state = state # 伪代码:这里应是更新历史记录后的状态
return new_updates, new_state
# return optax.GradientTransformation(init_fn, update_fn)
# 请注意:上面的 init_fn 和 update_fn 是高度简化的伪代码
# 真正实现需要完整复刻 Optax L-BFGS 的状态管理和 two-loop recursion
# 因此,更现实的做法可能是下面要介绍的其他方案,或者向 Optax 社区提出 issue
# 一个更实用的方法可能是:检查 Optax 是否提供了控制这个除法的选项
# 或者使用 optax.inject_hyperparams 来动态调整一些可能影响数值稳定性的参数 (如果LBFGS暴露了这类参数)
# 如果你非要自己实现一个 L-BFGS 的变种,务必仔细测试
# 这个方案的核心思想是在计算 1/vdot 时加入 epsilon 保护
print("注意:上述 safe_scale_by_lbfgs 是一个演示思路的简化伪代码。")
print("实际应用需要更完整的实现或考虑其他方案。")
# 返回一个示意性的无效对象,因为上面的代码不完整
return optax.identity() # 返回一个无操作的转换作为占位符
# 使用时 (概念上的)
# optimizer = optax.chain(
# # ... 其他转换 ...
# scale_by_lbfgs_safe(history_size=15, min_vdot_value=1e-7),
# # ... 可能的 scale step size ...
# )
注意: 上面的代码 scale_by_lbfgs_safe
只是一个 原理示意 ,并非 Optax 官方 optax.lbfgs
的完整、可直接替换的实现。完整实现 L-BFGS 的状态管理和双循环递归相当复杂。这里的重点是展示修改 1.0 / vdot
计算的思路。实际应用中,你可能需要更深入地研究 optax.lbfgs
的源代码,或者采用下面的其他方法。
安全建议: epsilon
的选择需要小心。太小了可能还是会溢出,太大了则可能影响优化器的性能,引入偏差。通常从 1e-8
到 1e-6
开始尝试。
方案二:管住梯度,别让它“放飞自我” (Gradient Clipping)
梯度裁剪是一种常用的稳定训练技巧,虽然不能直接解决 1.0 / vdot
的问题,但它能间接改善数值稳定性。
原理:
通过限制梯度的范数(Norm)或绝对值,防止梯度变得过大或过小。这有助于避免参数更新步子迈得太大扯着嗓子(导致震荡或发散),也可能间接使得 s_k
和 y_k
的变化更平稳,从而降低 s_k^T y_k
变得极端异常小的概率。
操作步骤:
Optax 提供了方便的梯度裁剪函数,通常加在优化链的早期。
import optax
# 方案 2.1: 按全局范数裁剪 (常用)
# 将所有参数的梯度视为一个大向量,计算其范数,如果超过 max_norm,则缩放整个梯度向量
clip_norm = 1.0 # 经验值,需要根据你的模型和数据调整
optimizer = optax.chain(
optax.clip_by_global_norm(clip_norm),
# ... 这里接你的 L-BFGS 或其他优化步骤 ...
# 例如:optax.lbfgs(learning_rate=1.0, history_size=10)
# 注意: optax.lbfgs 本身可能整合了 scale_by_lbfgs 的逻辑
)
# 方案 2.2: 按值裁剪 (不太常用,但有时有用)
# 对每个梯度值进行独立裁剪,限制在 [-max_value, max_value] 区间
# clip_value = 1.0
# optimizer = optax.chain(
# optax.clip(clip_value),
# # ... L-BFGS 步骤 ...
# )
进阶技巧: clip_norm
的值需要根据实验调整。设太小可能限制模型学习能力,太大则起不到裁剪效果。可以监控训练过程中的梯度范数来辅助选择合适的值。
方案三:提高精度,用“双精度”试试 (float64)
如果硬件支持且性能可以接受,使用 float64
(双精度) 可以大大提升数值计算的精度和范围,可能直接避免 float32
下的溢出问题。
原理:
float64
比 float32
有更长的尾数位和指数位,能表示更大范围和更高精度的数值。这使得它在处理非常小或非常大的数时更不容易出错。
操作步骤:
需要在 JAX 启动时全局启用 64 位浮点数支持。
import jax
# 在程序开头(import jax 之后)设置
jax.config.update("jax_enable_x64", True)
# 后续创建 JAX 数组、定义模型和优化器时,它们通常会默认使用 float64
# 或者可以在创建数组时显式指定 dtype=jnp.float64
# ... 定义你的模型和损失函数 ...
# 优化器也可能需要调整,检查 optax.lbfgs 是否有参数可以利用高精度状态
# 例如,一些 Optax 组件有 use_high_precision_vars 参数
# optimizer = optax.lbfgs(learning_rate=1.0, history_size=10, use_high_precision_vars=True) # 检查 LBFGS 是否支持此参数
# 注意:确保你的硬件(CPU/GPU)对 float64 支持良好,否则性能会大幅下降
# TPU 通常对 float64 支持有限或性能不佳,更倾向于 bfloat16/float32
进阶技巧: float64
会消耗两倍内存,并且计算通常比 float32
慢,尤其是在 GPU 或 TPU 上。这是一种“大力出奇迹”的方法,但在精度是瓶颈时非常有效。
方案四:调整 L-BFGS 历史记录大小 (History Size)
L-BFGS 通过存储最近 m
次的参数和梯度变化 (s_k
, y_k
) 来近似 Hessian 逆矩阵。这个 m
就是 history_size
。
原理:
减小 history_size
会让 L-BFGS 使用更少的历史信息,这可能:
- 减少了累积旧的、可能数值不稳定的 (
s_k
,y_k
) 对的机会。 - 降低了内存消耗。
- 但也牺牲了对 Hessian 矩阵逆的近似精度,可能影响收敛速度或效果。
操作步骤:
在创建 L-BFGS 优化器时设置 history_size
参数。
import optax
# 尝试减小 history_size,比如从默认的 10-20 减到 5-8
optimizer = optax.lbfgs(
learning_rate=1.0, # L-BFGS 的 learning_rate 通常设为 1.0,步长由线搜索或内部计算确定
history_size=8 # 减小历史记录大小
)
进阶技巧: history_size
是一个权衡参数。较大的值通常能更好地近似二阶信息,但计算量和内存占用也更大,且可能更容易引入数值问题。可以从小到大尝试不同的 history_size
。
方案五:检查你的模型和损失函数
有时候,数值不稳定并非优化器本身的问题,而是源于:
- 模型设计: 某些层或激活函数可能导致梯度爆炸或消失。
- 损失函数: 损失函数的梯度是否在某些区域行为异常?
- 数据预处理: 输入数据是否归一化得当?是否存在极端值?
排查这些方面,确保模型训练过程本身是健康的,有时也能间接解决优化器的数值问题。
希望以上这些方案能帮你解决 Optax L-BFGS 中的 NaN
烦恼。通常可以从方案一(加 Epsilon,如果能方便实现)、方案二(梯度裁剪)和方案三(双精度)入手尝试。
相关资源
- Optax Documentation: https://optax.readthedocs.io/
- JAX Numerical Stability Guide (if available, or related discussions on JAX GitHub/forums)