返回

解决 Optax L-BFGS 优化器 NaN 错误的实用技巧

python

解决 Optax L-BFGS 优化器中的 NaN 问题

用 JAX 和 Optax 跑模型,选了 L-BFGS 这个优化器?挺好,二阶优化方法,有时候收敛起来嗖嗖的。但跑着跑着,突然蹦出个 NaN (Not a Number)?程序直接挂掉,心情是不是瞬间就不美丽了?

你可能遇到了 Optax L-BFGS 实现中一个比较常见的数值稳定性问题。具体来说,问题通常出在 optax/_src/transform.py 里的 scale_by_lbfgs 函数:

# optax/optax/_src/transform.py (示意)
def scale_by_lbfgs(
    ...
    def update_fn(
        ...
        # s_k = x_{k+1} - x_k  (参数变化)
        # y_k = g_{k+1} - g_k  (梯度变化)
        # diff_params = s_k
        # diff_grads = y_k
        vdot_diff_params_updates = jnp.vdot(diff_params, diff_grads) #  <-- 计算 s_k^T y_k

        weight = jnp.where(
            vdot_diff_params_updates == 0.0, 0.0, 1.0 / vdot_diff_params_updates #  <-- 问题点!
        )
        # ... 后续计算 ...

看到 1.0 / vdot_diff_params_updates 这行没?这就是“罪魁祸首”。vdot_diff_params_updates 本质上是参数向量变化量 (s_k) 和梯度变化量 (y_k) 的点积,即 s_k^T y_k。在 L-BFGS 算法里,这个值理论上应该大于零,因为它反映了函数沿 s_k 方向的曲率信息。

但!是!计算机的世界是浮点数的江湖。当 vdot_diff_params_updates 这个值变得非常非常小,小到接近浮点数的表示极限,但又不完全是 0.0 时,1.0 除以一个极小的正数,结果可能会超出浮点数能表示的最大范围,变成 inf (无穷大)。更糟的是,如果后续计算中 inf 参与了不当运算(比如 inf - inf, 0 * inf),就很容易产生 NaN

问题根源:为啥 s_k^T y_k 会变得这么小?

搞清楚为啥这个点积会出问题,有助于我们对症下药。

  1. 曲率信息丢失或接近零: s_k^T y_k 代表了目标函数在 s_k 方向上的(近似)二阶导数信息。如果函数在某个区域非常“平坦”,梯度变化微乎其微,或者参数更新步长很小,这个点积就可能接近零。这在训练后期或者遇到平坦的损失景观(loss landscape)时比较常见。
  2. 数值精度限制: 标准的 float32 精度有限。当参数和梯度的变化量本身就很小,计算它们的点积可能因为舍入误差(rounding errors)导致结果不准确,甚至非常接近零。
  3. 不满足 Wolfe 条件: L-BFGS 通常依赖线搜索(line search)来保证 s_k^T y_k > 0 (满足 Wolfe 条件中的曲率条件)。Optax 的 scale_by_lbfgs 本身不包含线搜索逻辑,它期望输入的梯度(或更新量)已经满足了某些条件。如果上游传来的更新信息不符合 L-BFGS 的理论假设,这个点积就可能出问题。
  4. L-BFGS 历史记录累积误差: L-BFGS 利用历史的 s_ky_k 对来近似 Hessian 矩阵的逆。如果历史记录中的某些配对数值上不稳定,可能在后续的递归计算中放大问题。

怎么搞定 NaN?几种靠谱方案

别急着换优化器,咱们有几种方法可以尝试解决或缓解这个问题。

方案一:给分母加个“小保险” (Epsilon 正则化)

这是最直接的想法:既然怕除以太小的数,那就在分母上加一个特别小的正数 epsilon,让它永远不会真正变成零或接近零。

原理:

通过 1.0 / (vdot_diff_params_updates + epsilon) 替换原来的 1.0 / vdot_diff_params_updates。这个 epsilon 很小(比如 1e-81e-6),对计算结果影响不大,但能有效防止除零或浮点溢出。

操作步骤:

由于直接修改 Optax 源代码不是最佳实践(升级麻烦,不方便共享),我们可以自定义一个转换函数(Transformation)来包裹或替换 scale_by_lbfgs 的核心逻辑。

假设我们想保持 Optax 的链式结构,可以这样做:

import jax
import jax.numpy as jnp
import optax

def safe_division(numerator, denominator, epsilon=1e-8):
  """安全的除法,避免分母过小导致 NaN"""
  return numerator / (denominator + jnp.where(denominator >= 0.0, epsilon, -epsilon))
  # 或者更简单粗暴一点,如果确定 denominator 理论上非负:
  # return numerator / (denominator + epsilon)

def scale_by_lbfgs_safe(
    history_size: int = 10,
    min_vdot_value: float = 1e-8, # 这个就是我们的 epsilon
    initial_inv_hessian_diagonal: float = 1.0,
    use_high_precision_vars: bool = False):
  """
  一个尝试修复 NaN 问题的 scale_by_lbfgs 版本。
  注意:这是一个简化的示例,可能需要根据 Optax 版本调整。
  更稳妥的方式是理解 optax.lbfgs 的实现,并在必要时构建自己的优化步骤。
  以下代码旨在说明添加 epsilon 的思路,并未完整复刻所有细节。
  """

  # 这个自定义版本主要是为了演示核心思想
  # 实际应用中,更推荐使用 Optax 内建方案组合或等待官方更新
  # 如果急用,可能需要参考 Optax 源码,创建一个功能完整的自定义 Transformation

  def init_fn(params):
    # LBFGS 需要状态来存储历史 s_k, y_k 等信息
    # 参考 optax.lbfgs 的 init_fn 实现
    # 这里仅作示意,实际需要更完整的状态初始化
    if use_high_precision_vars:
        dtype = jnp.float64
    else:
        dtype = jnp.float32 # 或者 params 的 dtype
    
    state = optax.LbfgsState(
        iter_count=jnp.zeros([], jnp.int32),
        params_history=jnp.zeros((history_size,) + params.shape, dtype=dtype), # 简化形状处理
        grads_history=jnp.zeros((history_size,) + params.shape, dtype=dtype),  # 简化形状处理
        last_params=jnp.zeros_like(params, dtype=dtype),
        last_grads=jnp.zeros_like(params, dtype=dtype),
        sum_inv_hessian_updates=jnp.zeros([], dtype=dtype) # 可能需要更多状态
    )
    # 这里只是一个非常基础的 state 结构,实际 L-BFGS 需要更复杂的状态管理
    # 请务必参考 Optax 源码中的 LbfgsState dataclass 定义
    return state # 返回一个合法的 LbfgsState 对象 (这里是伪代码)


  def update_fn(updates, state, params):
    # 这是一个高度简化的 update_fn,仅用于说明关键计算点
    # 真实的 L-BFGS two-loop recursion 要复杂得多

    # 假设 state 中已经包含了计算所需的 s_k 和 y_k
    # 从 state 中获取 diff_params (s_k) 和 diff_grads (y_k)
    # ... (这部分逻辑需要完整实现 L-BFGS 的历史记录更新)

    # 假设我们拿到了当前的 diff_params 和 diff_grads
    diff_params = state.last_params - params # 这是一个简化示例 s_k = x_k - x_{k-1}
    diff_grads = updates - state.last_grads  # 这是一个简化示例 y_k = g_k - g_{k-1}

    vdot_diff_params_updates = jnp.vdot(diff_params, diff_grads)

    # !!! 关键修改点 !!!
    # 使用安全除法
    safe_weight = jnp.where(
        # 注意:原始代码的条件是 == 0.0,这里沿用,但更稳健的可能是 abs() < threshold
        vdot_diff_params_updates == 0.0,
        0.0,
        1.0 / (vdot_diff_params_updates + jnp.sign(vdot_diff_params_updates) * min_vdot_value)
        # 添加一个小量,符号与原数相同,避免改变性质
        # 或者,如果理论上 vdot > 0, 可以直接 + min_vdot_value
        # safe_weight = 1.0 / jnp.maximum(vdot_diff_params_updates, min_vdot_value) # 另一种更简洁写法, 假定vdot理论非负
    )

    # ... L-BFGS 的后续计算,例如 two-loop recursion ...
    # 使用 safe_weight 代替原始的 weight

    # 这里的 updates 是 L-BFGS 算法计算出的最终更新方向 * 步长
    # 这个简化版 update_fn 并没有计算实际的 L-BFGS 更新
    # 返回计算得到的最终更新量 `new_updates` 和更新后的 `new_state`
    
    new_updates = updates # 伪代码:这里应是经过LBFGS近似Hessian逆处理后的结果
    new_state = state # 伪代码:这里应是更新历史记录后的状态
    return new_updates, new_state

  # return optax.GradientTransformation(init_fn, update_fn)
  # 请注意:上面的 init_fn 和 update_fn 是高度简化的伪代码
  # 真正实现需要完整复刻 Optax L-BFGS 的状态管理和 two-loop recursion
  # 因此,更现实的做法可能是下面要介绍的其他方案,或者向 Optax 社区提出 issue

  # 一个更实用的方法可能是:检查 Optax 是否提供了控制这个除法的选项
  # 或者使用 optax.inject_hyperparams 来动态调整一些可能影响数值稳定性的参数 (如果LBFGS暴露了这类参数)

  # 如果你非要自己实现一个 L-BFGS 的变种,务必仔细测试
  # 这个方案的核心思想是在计算 1/vdot 时加入 epsilon 保护
  print("注意:上述 safe_scale_by_lbfgs 是一个演示思路的简化伪代码。")
  print("实际应用需要更完整的实现或考虑其他方案。")
  # 返回一个示意性的无效对象,因为上面的代码不完整
  return optax.identity() # 返回一个无操作的转换作为占位符

# 使用时 (概念上的)
# optimizer = optax.chain(
#     # ... 其他转换 ...
#     scale_by_lbfgs_safe(history_size=15, min_vdot_value=1e-7),
#     # ... 可能的 scale step size ...
# )

注意: 上面的代码 scale_by_lbfgs_safe 只是一个 原理示意 ,并非 Optax 官方 optax.lbfgs 的完整、可直接替换的实现。完整实现 L-BFGS 的状态管理和双循环递归相当复杂。这里的重点是展示修改 1.0 / vdot 计算的思路。实际应用中,你可能需要更深入地研究 optax.lbfgs 的源代码,或者采用下面的其他方法。

安全建议: epsilon 的选择需要小心。太小了可能还是会溢出,太大了则可能影响优化器的性能,引入偏差。通常从 1e-81e-6 开始尝试。

方案二:管住梯度,别让它“放飞自我” (Gradient Clipping)

梯度裁剪是一种常用的稳定训练技巧,虽然不能直接解决 1.0 / vdot 的问题,但它能间接改善数值稳定性。

原理:

通过限制梯度的范数(Norm)或绝对值,防止梯度变得过大或过小。这有助于避免参数更新步子迈得太大扯着嗓子(导致震荡或发散),也可能间接使得 s_ky_k 的变化更平稳,从而降低 s_k^T y_k 变得极端异常小的概率。

操作步骤:

Optax 提供了方便的梯度裁剪函数,通常加在优化链的早期。

import optax

# 方案 2.1: 按全局范数裁剪 (常用)
# 将所有参数的梯度视为一个大向量,计算其范数,如果超过 max_norm,则缩放整个梯度向量
clip_norm = 1.0 # 经验值,需要根据你的模型和数据调整
optimizer = optax.chain(
    optax.clip_by_global_norm(clip_norm),
    # ... 这里接你的 L-BFGS 或其他优化步骤 ...
    # 例如:optax.lbfgs(learning_rate=1.0, history_size=10)
    # 注意: optax.lbfgs 本身可能整合了 scale_by_lbfgs 的逻辑
)

# 方案 2.2: 按值裁剪 (不太常用,但有时有用)
# 对每个梯度值进行独立裁剪,限制在 [-max_value, max_value] 区间
# clip_value = 1.0
# optimizer = optax.chain(
#     optax.clip(clip_value),
#     # ... L-BFGS 步骤 ...
# )

进阶技巧: clip_norm 的值需要根据实验调整。设太小可能限制模型学习能力,太大则起不到裁剪效果。可以监控训练过程中的梯度范数来辅助选择合适的值。

方案三:提高精度,用“双精度”试试 (float64)

如果硬件支持且性能可以接受,使用 float64 (双精度) 可以大大提升数值计算的精度和范围,可能直接避免 float32 下的溢出问题。

原理:

float64float32 有更长的尾数位和指数位,能表示更大范围和更高精度的数值。这使得它在处理非常小或非常大的数时更不容易出错。

操作步骤:

需要在 JAX 启动时全局启用 64 位浮点数支持。

import jax

# 在程序开头(import jax 之后)设置
jax.config.update("jax_enable_x64", True)

# 后续创建 JAX 数组、定义模型和优化器时,它们通常会默认使用 float64
# 或者可以在创建数组时显式指定 dtype=jnp.float64

# ... 定义你的模型和损失函数 ...

# 优化器也可能需要调整,检查 optax.lbfgs 是否有参数可以利用高精度状态
# 例如,一些 Optax 组件有 use_high_precision_vars 参数
# optimizer = optax.lbfgs(learning_rate=1.0, history_size=10, use_high_precision_vars=True) # 检查 LBFGS 是否支持此参数

# 注意:确保你的硬件(CPU/GPU)对 float64 支持良好,否则性能会大幅下降
# TPU 通常对 float64 支持有限或性能不佳,更倾向于 bfloat16/float32

进阶技巧: float64 会消耗两倍内存,并且计算通常比 float32 慢,尤其是在 GPU 或 TPU 上。这是一种“大力出奇迹”的方法,但在精度是瓶颈时非常有效。

方案四:调整 L-BFGS 历史记录大小 (History Size)

L-BFGS 通过存储最近 m 次的参数和梯度变化 (s_k, y_k) 来近似 Hessian 逆矩阵。这个 m 就是 history_size

原理:

减小 history_size 会让 L-BFGS 使用更少的历史信息,这可能:

  • 减少了累积旧的、可能数值不稳定的 (s_k, y_k) 对的机会。
  • 降低了内存消耗。
  • 但也牺牲了对 Hessian 矩阵逆的近似精度,可能影响收敛速度或效果。

操作步骤:

在创建 L-BFGS 优化器时设置 history_size 参数。

import optax

# 尝试减小 history_size,比如从默认的 10-20 减到 5-8
optimizer = optax.lbfgs(
    learning_rate=1.0, # L-BFGS 的 learning_rate 通常设为 1.0,步长由线搜索或内部计算确定
    history_size=8 # 减小历史记录大小
)

进阶技巧: history_size 是一个权衡参数。较大的值通常能更好地近似二阶信息,但计算量和内存占用也更大,且可能更容易引入数值问题。可以从小到大尝试不同的 history_size

方案五:检查你的模型和损失函数

有时候,数值不稳定并非优化器本身的问题,而是源于:

  • 模型设计: 某些层或激活函数可能导致梯度爆炸或消失。
  • 损失函数: 损失函数的梯度是否在某些区域行为异常?
  • 数据预处理: 输入数据是否归一化得当?是否存在极端值?

排查这些方面,确保模型训练过程本身是健康的,有时也能间接解决优化器的数值问题。

希望以上这些方案能帮你解决 Optax L-BFGS 中的 NaN 烦恼。通常可以从方案一(加 Epsilon,如果能方便实现)、方案二(梯度裁剪)和方案三(双精度)入手尝试。

相关资源