首页
/ 在Python中直接调用nnUNet训练模型进行推理的方法

在Python中直接调用nnUNet训练模型进行推理的方法

2025-06-02 09:46:06作者:钟日瑜

背景介绍

nnUNet是医学图像分割领域广泛使用的优秀框架,其标准使用方式是通过命令行工具进行模型训练和预测。但在实际应用中,开发者有时需要将训练好的模型直接集成到Python代码中,而不是通过终端命令调用。

标准命令行预测方式

nnUNet通常通过以下命令进行预测:

nnUNetv2_predict -d Dataset510_Testsplits_cardiac -i "$input_data" -o "$output_data" -f 0 1 2 3 4 -tr nnUNetTrainer -c 2d -p nnUNetPlans --save_probabilities

这种方式虽然简单直接,但在需要将模型集成到更复杂的工作流中时,就显得不够灵活。

Python直接集成方案

nnUNet提供了Python API,允许开发者直接在代码中调用训练好的模型。核心类是nnUNetPredictor,它封装了完整的预测流程,包括预处理、网络前向传播和后处理等步骤。

基本使用示例

from nnunetv2.paths import nnUNet_results, nnUNet_raw
import torch
from batchgenerators.utilities.file_and_folder_operations import join
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

# 初始化预测器
predictor = nnUNetPredictor(
    tile_step_size=0.5,
    use_gaussian=True,
    use_mirroring=True,
    perform_everything_on_device=True,
    device=torch.device('cuda', 0),
    verbose=False,
    verbose_preprocessing=False,
    allow_tqdm=True
)

# 加载训练好的模型
predictor.initialize_from_trained_model_folder(
    join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'),
    use_folds=(0,),
    checkpoint_name='checkpoint_final.pth',
)

# 执行预测
predictor.predict_from_files(
    join(nnUNet_raw, 'Dataset003_Liver/imagesTs'),
    join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'),
    save_probabilities=False,
    overwrite=False,
    num_processes_preprocessing=2,
    num_processes_segmentation_export=2,
    folder_with_segs_from_prev_stage=None,
    num_parts=1,
    part_id=0
)

关键参数说明

  1. 预测器初始化参数

    • tile_step_size: 控制重叠区域大小的步长
    • use_gaussian: 是否使用高斯权重
    • use_mirroring: 是否使用测试时数据增强
    • perform_everything_on_device: 是否全程在GPU上执行
  2. 模型加载参数

    • 需要指定模型存储路径和检查点名称
    • 可以指定使用的交叉验证折数
  3. 预测参数

    • 可以控制是否保存概率图
    • 支持多进程预处理和结果导出

注意事项

  1. nnUNet的预测流程不仅仅是简单的模型前向传播,还包括了完整的预处理和后处理流程,这是保证预测质量的关键。

  2. 如果确实需要直接访问模型对象,可以通过predictor.network属性获取,但建议仅在充分理解nnUNet内部机制的情况下这样做。

  3. 对于大多数应用场景,使用封装好的predict_from_files方法已经足够,它提供了与命令行工具相同的功能,但更加灵活。

总结

通过nnUNet提供的Python API,开发者可以方便地将训练好的模型集成到自己的Python工作流中,实现更灵活的医学图像分割应用。这种方式既保留了nnUNet强大的预处理和预测能力,又提供了更好的程序集成性。

登录后查看全文
热门项目推荐
相关项目推荐