首页
/ PyTorch Lightning中如何避免保存预训练子模块的检查点

PyTorch Lightning中如何避免保存预训练子模块的检查点

2025-05-05 06:43:36作者:袁立春Spencer

在PyTorch Lightning项目中,当模型包含大型预训练子模块(如LLM、VAE等)时,默认的检查点机制会保存所有子模块的状态,这会导致存储空间和时间的浪费。本文将介绍如何优雅地解决这一问题。

问题背景

在构建复杂模型时,我们经常会使用预训练好的子模块作为模型的组成部分。这些子模块通常在整个训练过程中保持冻结状态,不会被更新。然而,PyTorch Lightning默认的检查点机制会保存所有子模块的参数,这带来了两个问题:

  1. 存储空间浪费:大型预训练模型(如LLM)的参数可能占用数GB空间
  2. 时间浪费:每次保存检查点时都需要序列化这些不变的数据

解决方案

PyTorch Lightning提供了两种主要方式来解决这个问题:

1. 自定义state_dict方法

通过重写模型的state_dict方法,可以精确控制哪些参数需要保存。例如:

def state_dict(self, *args, **kwargs):
    # 获取完整的状态字典
    state_dict = super().state_dict(*args, **kwargs)
    
    # 移除不需要保存的子模块
    for key in list(state_dict.keys()):
        if key.startswith("vae."):  # 假设vae是预训练的子模块
            del state_dict[key]
    
    return state_dict

这种方法提供了最大的灵活性,可以精确控制保存哪些参数。

2. 使用strict_loading=False选项

从PyTorch Lightning 2.2版本开始,可以设置self.strict_loading = False来允许加载部分检查点。这样即使检查点中不包含某些子模块的参数,模型也能正常加载。

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.strict_loading = False
        self.vae = AutoencoderKL.from_pretrained(...)  # 预训练的子模块

最佳实践

结合上述两种方法,可以构建既节省存储空间又便于部署的模型:

  1. 对于完全冻结的预训练子模块,从state_dict中排除
  2. 设置strict_loading=False以确保模型能加载部分检查点
  3. 在文档中明确说明哪些子模块需要单独加载

注意事项

  1. 确保排除的子模块确实不需要训练
  2. 在部署时需要单独提供预训练子模块的权重
  3. 测试模型加载逻辑,确保排除子模块后不影响功能

通过合理使用这些技术,可以显著减少模型检查点的大小,加快保存和加载速度,同时保持模型的完整功能。

登录后查看全文
热门项目推荐
相关项目推荐

项目优选

收起
kernelkernel
deepin linux kernel
C
22
6
docsdocs
OpenHarmony documentation | OpenHarmony开发者文档
Dockerfile
197
2.17 K
ohos_react_nativeohos_react_native
React Native鸿蒙化仓库
C++
208
285
pytorchpytorch
Ascend Extension for PyTorch
Python
59
94
RuoYi-Vue3RuoYi-Vue3
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
973
574
nop-entropynop-entropy
Nop Platform 2.0是基于可逆计算理论实现的采用面向语言编程范式的新一代低代码开发平台,包含基于全新原理从零开始研发的GraphQL引擎、ORM引擎、工作流引擎、报表引擎、规则引擎、批处理引引擎等完整设计。nop-entropy是它的后端部分,采用java语言实现,可选择集成Spring框架或者Quarkus框架。中小企业可以免费商用
Java
9
1
ops-mathops-math
本项目是CANN提供的数学类基础计算算子库,实现网络在NPU上加速计算。
C++
549
81
openHiTLSopenHiTLS
旨在打造算法先进、性能卓越、高效敏捷、安全可靠的密码套件,通过轻量级、可剪裁的软件技术架构满足各行业不同场景的多样化要求,让密码技术应用更简单,同时探索后量子等先进算法创新实践,构建密码前沿技术底座!
C
1.02 K
399
communitycommunity
本项目是CANN开源社区的核心管理仓库,包含社区的治理章程、治理组织、通用操作指引及流程规范等基础信息
393
27
MateChatMateChat
前端智能化场景解决方案UI库,轻松构建你的AI应用,我们将持续完善更新,欢迎你的使用与建议。 官网地址:https://matechat.gitcode.com
1.2 K
133