首页
/ Detectron2中Faster RCNN模型导出TorchScript格式问题解析

Detectron2中Faster RCNN模型导出TorchScript格式问题解析

2025-05-04 22:56:40作者:柯茵沙

问题背景

在使用Detectron2框架训练Faster RCNN模型(R_101_FPN_3x版本)并尝试导出为TorchScript格式时,开发者遇到了一个常见问题:导出的模型虽然能成功运行,但在推理时却无法输出预期的检测结果,而是返回空张量。这个问题在自定义数据集训练的场景下尤为常见。

问题现象

当开发者按照标准流程将训练好的Faster RCNN模型导出为TorchScript格式后,使用简单的测试脚本加载模型并进行推理时,输出结果如下:

(tensor([], size=(0, 4), grad_fn=<IndexBackward0>), 
 tensor([], dtype=torch.int64), 
 tensor([], grad_fn=<IndexBackward0>), 
 tensor([408, 612]))

这表明模型虽然运行成功,但未能检测到任何目标,而实际上在训练和评估阶段,同一张图片是可以正确检测出目标的。

问题根源分析

经过深入调查,发现这个问题主要由以下几个因素导致:

  1. 输入数据预处理不匹配:直接使用torchvision.transforms.ToTensor()进行图像转换,与模型训练时的预处理流程不一致。

  2. 输入格式差异:Detectron2训练时使用特定的输入格式,而导出后的TorchScript模型期望的输入格式有所不同。

  3. 输出结构变化:TorchScript导出的模型输出结构从字典变为了元组,需要开发者调整后处理逻辑。

解决方案

正确的处理流程应该包含以下几个关键步骤:

1. 正确的输入预处理

def get_input(image_file):
    # 使用OpenCV读取图像,保持与训练时一致的颜色通道顺序(BGR)
    input_img = cv2.imread(image_file, cv2.IMREAD_COLOR)
    height, width, channels = input_img.shape
    
    # 转换为PyTorch张量并调整维度顺序
    img_tensor = torch.from_numpy(input_img).view(height, width, channels).to(torch.float32)
    img_tensor = img_tensor.permute(2, 0, 1).contiguous()
    
    return img_tensor

2. 模型加载与推理

model = torch.jit.load(model_path)
model.eval()

input_tensor = get_input(img_path)
output = model.forward(input_tensor)

3. 输出结果解析

成功运行后,输出将变为如下格式:

(tensor([[377.6366, 397.9915, 779.1143, 561.9951],
         [550.6393, 388.4034, 778.9048, 546.3689]], grad_fn=<IndexBackward0>), 
 tensor([0, 0]), 
 tensor([0.9840, 0.0890], grad_fn=<IndexBackward0>), 
 tensor([1527,  990]))

其中:

  • 第一个张量:检测到的边界框坐标(x1,y1,x2,y2)
  • 第二个张量:预测的类别索引
  • 第三个张量:预测的置信度分数
  • 第四个张量:输入图像的尺寸

技术要点总结

  1. 图像读取一致性:必须使用与训练时相同的图像读取方式(OpenCV),因为PIL和OpenCV在颜色通道顺序上存在差异。

  2. 张量转换规范:需要确保张量的数据类型(float32)和维度顺序(C,H,W)正确。

  3. 输出结构调整:导出的TorchScript模型输出结构从字典变为元组,开发者需要相应调整后处理代码。

  4. 尺寸处理:虽然可以调整图像尺寸,但需要保持与训练时相同的宽高比处理方式。

最佳实践建议

  1. 在导出模型前,先在Python环境中验证模型的推理功能正常。

  2. 建立标准化的预处理流程,确保训练和部署阶段的一致性。

  3. 对于生产环境,建议封装专门的预处理和后处理类,提高代码的可维护性。

  4. 考虑添加日志记录输入输出张量的形状和数值范围,便于调试。

通过以上解决方案,开发者可以成功将Detectron2训练的Faster RCNN模型导出为TorchScript格式,并在部署环境中获得正确的推理结果。这一过程强调了深度学习模型从训练到部署全流程中数据一致性的重要性。

登录后查看全文
热门项目推荐

热门内容推荐

项目优选

收起
openHiTLS-examplesopenHiTLS-examples
本仓将为广大高校开发者提供开源实践和创新开发平台,收集和展示openHiTLS示例代码及创新应用,欢迎大家投稿,让全世界看到您的精巧密码实现设计,也让更多人通过您的优秀成果,理解、喜爱上密码技术。
C
52
461
kernelkernel
deepin linux kernel
C
22
5
openHiTLSopenHiTLS
旨在打造算法先进、性能卓越、高效敏捷、安全可靠的密码套件,通过轻量级、可剪裁的软件技术架构满足各行业不同场景的多样化要求,让密码技术应用更简单,同时探索后量子等先进算法创新实践,构建密码前沿技术底座!
C
349
381
nop-entropynop-entropy
Nop Platform 2.0是基于可逆计算理论实现的采用面向语言编程范式的新一代低代码开发平台,包含基于全新原理从零开始研发的GraphQL引擎、ORM引擎、工作流引擎、报表引擎、规则引擎、批处理引引擎等完整设计。nop-entropy是它的后端部分,采用java语言实现,可选择集成Spring框架或者Quarkus框架。中小企业可以免费商用
Java
7
0
openGauss-serveropenGauss-server
openGauss kernel ~ openGauss is an open source relational database management system
C++
131
185
RuoYi-Vue3RuoYi-Vue3
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
873
517
Cangjie-ExamplesCangjie-Examples
本仓将收集和展示高质量的仓颉示例代码,欢迎大家投稿,让全世界看到您的妙趣设计,也让更多人通过您的编码理解和喜爱仓颉语言。
Cangjie
336
1.09 K
ohos_react_nativeohos_react_native
React Native鸿蒙化仓库
C++
179
264
cherry-studiocherry-studio
🍒 Cherry Studio 是一款支持多个 LLM 提供商的桌面客户端
TypeScript
608
59
note-gennote-gen
一款跨平台的 Markdown AI 笔记软件,致力于使用 AI 建立记录和写作的桥梁。
TSX
83
4