首页
/ Optax项目中使用LBFGS优化器结合线搜索的实践指南

Optax项目中使用LBFGS优化器结合线搜索的实践指南

2025-07-07 14:34:02作者:戚魁泉Nursing

背景介绍

在机器学习模型训练过程中,二阶优化算法因其收敛速度快的特点受到广泛关注。L-BFGS作为经典的拟牛顿法优化算法,在深度学习领域也有重要应用。Google DeepMind开发的Optax库提供了LBFGS优化器的实现,并支持线搜索功能,能够进一步提升优化效果。

核心问题

当在Equinox框架下构建自定义模型时,如何正确使用Optax的LBFGS优化器并启用线搜索功能,需要特别注意几个关键点:

  1. 线搜索需要单独定义纯损失函数(value_fn),而非同时计算值和梯度的函数
  2. 在Equinox框架下需要正确处理模型的可微分和不可微分部分
  3. 参数更新时需要保持模型结构的完整性

解决方案详解

1. 定义纯损失函数

线搜索算法需要一个仅返回损失值的函数,因此我们需要单独定义:

def loss_fn(model, ts, ys_true):
    y0 = jnp.array([0.0])
    y_pred = model(ts, y0)
    return jnp.mean((y_pred - ys_true) ** 2)

这与常见的value_and_grad函数不同,后者会同时返回损失值和梯度。

2. 模型参数处理

在Equinox框架下,我们需要区分模型的可微分和不可微分部分:

model_params, model_struct = eqx.partition(model, eqx.is_array)

这种分割确保了在优化过程中只更新可训练参数,同时保持模型结构不变。

3. 创建线搜索兼容的损失函数

为了在线搜索中使用,我们需要创建一个闭包函数,将当前模型结构和输入数据绑定:

loss_fn_ = lambda model_params: loss_fn(
    eqx.combine(model_params, model_struct), ti, yi)

这个lambda函数将模型参数与固定结构重新组合,确保每次线搜索评估时模型结构保持一致。

4. 优化步骤实现

完整的训练步骤实现如下:

@eqx.filter_jit
def make_step(ti, yi, model, opt_state):
    # 计算损失和梯度
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model, ti, yi)
    
    # 准备优化器输入
    grads = eqx.filter(grads, eqx.is_array)
    opt_state = eqx.filter(opt_state, eqx.is_array)
    model_params, model_struct = eqx.partition(model, eqx.is_array)
    
    # 创建线搜索兼容的损失函数
    loss_fn_ = lambda model_params: loss_fn(
        eqx.combine(model_params, model_struct), ti, yi)
    
    # 执行优化步骤
    updates, opt_state = optim.update(
        grads, opt_state, model_params, 
        value=loss, grad=grads, value_fn=loss_fn_)
    
    # 更新模型
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

关键注意事项

  1. 函数类型匹配:确保传递给线搜索的是纯损失函数,而非value_and_grad函数
  2. 模型完整性:在参数更新前后保持模型结构的完整性
  3. 性能考量:线搜索会增加每次迭代的计算量,但通常能减少总迭代次数
  4. 学习率设置:即使启用了线搜索,初始学习率的设置仍然会影响优化效果

实际应用建议

对于科学计算和物理信息神经网络(PINN)等场景,LBFGS+线搜索的组合往往能取得比一阶优化器更好的效果。建议:

  1. 先使用Adam等一阶优化器进行预训练
  2. 切换到LBFGS进行精细优化
  3. 监控线搜索的接受率,调整初始学习率

通过合理使用Optax提供的LBFGS实现,可以在保持代码简洁的同时,获得接近二阶优化算法的性能表现。

登录后查看全文
热门项目推荐

热门内容推荐

最新内容推荐

项目优选

收起
docsdocs
OpenHarmony documentation | OpenHarmony开发者文档
Dockerfile
144
1.93 K
kernelkernel
deepin linux kernel
C
22
6
ohos_react_nativeohos_react_native
React Native鸿蒙化仓库
C++
192
274
openGauss-serveropenGauss-server
openGauss kernel ~ openGauss is an open source relational database management system
C++
145
189
RuoYi-Vue3RuoYi-Vue3
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
930
553
nop-entropynop-entropy
Nop Platform 2.0是基于可逆计算理论实现的采用面向语言编程范式的新一代低代码开发平台,包含基于全新原理从零开始研发的GraphQL引擎、ORM引擎、工作流引擎、报表引擎、规则引擎、批处理引引擎等完整设计。nop-entropy是它的后端部分,采用java语言实现,可选择集成Spring框架或者Quarkus框架。中小企业可以免费商用
Java
8
0
openHiTLSopenHiTLS
旨在打造算法先进、性能卓越、高效敏捷、安全可靠的密码套件,通过轻量级、可剪裁的软件技术架构满足各行业不同场景的多样化要求,让密码技术应用更简单,同时探索后量子等先进算法创新实践,构建密码前沿技术底座!
C
423
392
金融AI编程实战金融AI编程实战
为非计算机科班出身 (例如财经类高校金融学院) 同学量身定制,新手友好,让学生以亲身实践开源开发的方式,学会使用计算机自动化自己的科研/创新工作。案例以量化投资为主线,涉及 Bash、Python、SQL、BI、AI 等全技术栈,培养面向未来的数智化人才 (如数据工程师、数据分析师、数据科学家、数据决策者、量化投资人)。
Jupyter Notebook
75
66
CangjieCommunityCangjieCommunity
为仓颉编程语言开发者打造活跃、开放、高质量的社区环境
Markdown
1.11 K
0
openHiTLS-examplesopenHiTLS-examples
本仓将为广大高校开发者提供开源实践和创新开发平台,收集和展示openHiTLS示例代码及创新应用,欢迎大家投稿,让全世界看到您的精巧密码实现设计,也让更多人通过您的优秀成果,理解、喜爱上密码技术。
C
64
511