首页
/ 3D Gaussian Splatting模型保存与加载:PLY格式与优化状态

3D Gaussian Splatting模型保存与加载:PLY格式与优化状态

2026-02-04 05:05:23作者:谭伦延

引言:模型持久化的技术挑战

在3D Gaussian Splatting(3DGS)实时渲染系统中,模型的保存与加载是连接训练与部署的关键环节。不同于传统网格模型仅需存储顶点与拓扑关系,3DGS的核心表示包含球形高斯分布(Spherical Gaussian)的几何参数、辐射场系数和优化状态,这对存储格式和恢复机制提出了特殊要求。本文将深入解析3DGS官方实现中的模型持久化方案,重点剖析PLY(Polygon File Format)格式的扩展应用、优化器状态的保存策略,以及复杂参数在磁盘与内存间的高效转换机制。

读完本文你将掌握:

  • PLY格式如何扩展以存储高斯分布的全部参数
  • 模型训练状态(含优化器)的完整持久化方案
  • 加载过程中的参数重构与内存管理技巧
  • 大规模场景下的模型存储优化实践

PLY格式的扩展:超越传统网格的参数存储

标准PLY格式的局限性

传统PLY格式主要面向多边形网格数据,其顶点属性通常包含三维坐标(x,y,z)、法向量(nx,ny,nz)和颜色(red,green,blue)。3DGS模型包含的球谐系数缩放因子旋转四元数不透明度等参数无法通过标准属性表达,需要设计扩展属性集。

3DGS专用PLY属性定义

gaussian_model.pyconstruct_list_of_attributes方法中,官方实现定义了包含23个属性的扩展PLY格式:

def construct_list_of_attributes(self):
    l = ['x', 'y', 'z', 'nx', 'ny', 'nz']  # 几何基础属性
    # 球谐系数(DC分量3个 + 高阶分量3*(SH_degree²-1)个)
    for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
        l.append('f_dc_{}'.format(i))
    for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
        l.append('f_rest_{}'.format(i))
    l.append('opacity')  # 不透明度
    for i in range(self._scaling.shape[1]):  # 缩放因子(3个分量)
        l.append('scale_{}'.format(i))
    for i in range(self._rotation.shape[1]):  # 旋转四元数(4个分量)
        l.append('rot_{}'.format(i))
    return l

属性解析表

属性类别 字段名前缀 数量 数据意义 存储类型
几何坐标 x,y,z 3 高斯中心三维坐标 float32
法向量 nx,ny,nz 3 占位字段(未使用) float32
球谐DC分量 f_dc_* 3 低频辐射场系数 float32
球谐高阶分量 f_rest_* 3*(L²-1) 高频辐射场系数(L为SH阶数) float32
不透明度 opacity 1 密度控制参数(sigmoid前) float32
缩放因子 scale_* 3 轴对齐缩放(指数前) float32
旋转四元数 rot_* 4 朝向参数(规范化前) float32

关键设计决策:所有参数均存储原始优化值而非激活后的值。例如缩放因子存储的是对数空间值(_scaling),加载时需通过scaling_activation(指数函数)转换为实际缩放值。这种设计确保模型加载后可直接继续训练。

数据打包与文件写入

save_ply方法实现了内存参数到磁盘文件的转换流程:

def save_ply(self, path):
    mkdir_p(os.path.dirname(path))
    
    # 提取并转换参数(CPU迁移与格式调整)
    xyz = self._xyz.detach().cpu().numpy()
    normals = np.zeros_like(xyz)  # 占位法向量
    f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).cpu().numpy()
    f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).cpu().numpy()
    opacities = self._opacity.detach().cpu().numpy()
    scale = self._scaling.detach().cpu().numpy()
    rotation = self._rotation.detach().cpu().numpy()
    
    # 构建属性数组
    dtype_full = [(attr, 'f4') for attr in self.construct_list_of_attributes()]
    elements = np.empty(xyz.shape[0], dtype=dtype_full)
    attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
    elements[:] = list(map(tuple, attributes))
    
    # 写入PLY文件
    el = PlyElement.describe(elements, 'vertex')
    PlyData([el]).write(path)

性能考量:对于包含100万个高斯的模型,单个PLY文件体积约为100MB(每个高斯约100字节)。通过detach().cpu().numpy()的分步执行,可避免峰值内存占用过高。

模型加载:从文件到优化状态的重建

PLY文件解析流程

load_ply方法实现了与保存过程的逆向操作,重点解决参数类型转换和内存布局重建:

def load_ply(self, path):
    plydata = PlyData.read(path)
    xyz = np.stack((plydata.elements[0]["x"], plydata.elements[0]["y"], plydata.elements[0]["z"]), axis=1)
    
    # 球谐系数重组(DC分量与高阶分量分离)
    features_dc = np.zeros((xyz.shape[0], 3, 1))
    features_dc[:, 0, 0] = plydata.elements[0]["f_dc_0"]
    features_dc[:, 1, 0] = plydata.elements[0]["f_dc_1"]
    features_dc[:, 2, 0] = plydata.elements[0]["f_dc_2"]
    
    # 高阶系数需要根据SH阶数动态解析
    extra_f_names = sorted([p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")],
                          key=lambda x: int(x.split('_')[-1]))
    features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
    for idx, attr_name in enumerate(extra_f_names):
        features_extra[:, idx] = plydata.elements[0][attr_name]
    features_extra = features_extra.reshape((xyz.shape[0], 3, (self.max_sh_degree + 1)**2 - 1))
    
    # 参数转换为PyTorch张量
    self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
    self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
    # ... 其他参数的加载过程 ...

关键挑战:球谐高阶系数的维度取决于max_sh_degree,加载时需根据当前配置动态调整。官方实现通过reshape((xyz.shape[0], 3, (self.max_sh_degree + 1)**2 - 1))确保维度匹配。

优化器状态的完整恢复

PLY文件仅存储模型参数,不包含优化器状态。为实现训练断点续训,3DGS设计了双持久化机制

def capture(self):
    """捕获完整训练状态(含优化器)"""
    return (
        self.active_sh_degree,
        self._xyz, self._features_dc, self._features_rest,
        self._scaling, self._rotation, self._opacity,
        self.max_radii2D, self.xyz_gradient_accum, self.denom,
        self.optimizer.state_dict(),  # 优化器状态
        self.spatial_lr_scale,
    )

def restore(self, model_args, training_args):
    """从捕获的状态恢复训练"""
    (self.active_sh_degree, self._xyz, self._features_dc, self._features_rest,
     self._scaling, self._rotation, self._opacity, self.max_radii2D,
     xyz_gradient_accum, denom, opt_dict, self.spatial_lr_scale) = model_args
    self.training_setup(training_args)
    self.xyz_gradient_accum = xyz_gradient_accum
    self.denom = denom
    self.optimizer.load_state_dict(opt_dict)  # 恢复优化器

优化器状态包含

  • 动量项(exp_avg)和二阶矩估计(exp_avg_sq
  • 参数学习率和权重衰减配置
  • 迭代次数记录(隐式包含在学习率调度器中)

高级主题:大规模场景的存储优化

参数压缩策略

对于包含1000万个高斯的超大场景,原始PLY格式可能导致1GB以上的存储开销。实践中可采用以下优化:

  1. 量化存储:将球谐系数从32位浮点数压缩为16位甚至8位定点数。官方实现中features_dcfeatures_rest均使用32位float,可通过numpy.astype(np.float16)降低50%存储占用。

  2. 空间分区存储:参考dataset_readers.py中的场景分块逻辑,将大场景分割为多个子PLY文件:

# 伪代码:空间分区存储实现
def save_partitioned_ply(xyz, params, partition_size=1e6):
    for i in range(0, xyz.shape[0], partition_size):
        partition_xyz = xyz[i:i+partition_size]
        save_ply(f"model_part_{i//partition_size}.ply", partition_xyz, params[i:i+partition_size])
  1. 增量更新机制:仅保存训练过程中变化的参数,通过基线模型+增量文件减少重复存储。

加载性能优化

大规模模型加载时,内存占用可能成为瓶颈。load_ply方法可通过以下改进提升效率:

  1. 内存映射:使用numpy.memmap延迟加载大文件:
def load_large_ply(path):
    plydata = PlyData.read(path)  # 基础元数据读取
    xyz = np.memmap(path, dtype='f4', mode='r', offset=plydata.elements[0].data.offset,
                   shape=(plydata.elements[0].count, 3))
    return xyz
  1. 按需加载:结合视锥体剔除,仅加载当前视场可见的高斯参数,实现流式渲染。

完整工作流:从训练到部署

训练状态保存流程

sequenceDiagram
    participant Trainer
    participant GaussianModel
    participant FileSystem
    
    Trainer->>GaussianModel: capture()
    GaussianModel->>GaussianModel: 收集参数与优化器状态
    GaussianModel-->>Trainer: 返回状态元组
    Trainer->>FileSystem: 保存状态元组到.pth文件
    Trainer->>GaussianModel: save_ply("output/model.ply")
    GaussianModel->>FileSystem: 写入PLY文件

部署加载流程

flowchart TD
    A[加载PLY文件] --> B[解析xyz坐标]
    A --> C[解析球谐系数]
    A --> D[解析几何参数]
    B --> E[构建高斯位置张量]
    C --> F[重组SH系数矩阵]
    D --> G[构建缩放/旋转张量]
    E & F & G --> H[创建GaussianModel实例]
    H --> I[渲染管线集成]

常见问题与解决方案

问题场景 原因分析 解决方案
加载后渲染颜色异常 球谐系数维度不匹配 检查max_sh_degree是否与保存时一致
优化器加载失败 参数数量变化 使用_prune_optimizer修剪优化器状态
PLY文件体积过大 未过滤低贡献高斯 保存前执行prune_points(不透明度阈值0.01)
加载速度慢 文件IO瓶颈 采用分块PLY+内存映射方案

总结与展望

3D Gaussian Splatting的模型持久化方案通过扩展PLY格式和优化器状态捕获,实现了训练过程的无缝衔接和部署阶段的高效加载。随着实时渲染需求的增长,未来可能出现更紧凑的专用格式(如二进制GS格式)和硬件加速的加载路径。开发者在实际应用中需根据场景规模平衡存储开销与加载速度,关键是理解PLY文件中每个参数的物理意义和优化器状态的恢复机制。

登录后查看全文
热门项目推荐
相关项目推荐