tch-rs项目中实现Rust版NMS算法的探索
2025-06-11 13:50:48作者:韦蓉瑛
在将YOLO目标检测器从Python迁移到Rust的过程中,一个关键挑战是如何实现torchvision.ops.nms功能。本文将详细介绍如何在tch-rs项目中实现非极大值抑制(NMS)算法及其相关功能。
NMS算法的重要性
非极大值抑制(Non-Maximum Suppression)是目标检测后处理中的关键步骤,它用于消除冗余的检测框,保留最有可能代表真实目标的检测结果。在Python生态中,这一功能通常由torchvision.ops.nms提供。
Rust实现方案
由于tch-rs项目尚未提供torchvision功能的完整绑定,开发者需要自行实现NMS相关算法。以下是完整的Rust实现方案:
基础NMS实现
fn nms(boxes: Tensor, scores: &Tensor, iou_threshold: f32) -> Tensor {
let mut sorting: Vec<i64> = scores.argsort(0, false).try_into().unwrap();
let mut keep: Vec<i64> = Vec::new();
while let Some(idx) = sorting.pop() {
keep.push(idx);
for i in (0..sorting.len()).rev() {
if iou(&boxes.i(idx), &boxes.i(sorting[i])).double_value(&[]) > iou_threshold as f64 {
_ = sorting.remove(i);
}
}
}
Tensor::try_from(keep).unwrap().to_device(boxes.device())
}
这个实现采用了经典的NMS算法流程:
- 根据置信度分数对检测框进行排序
- 选择分数最高的检测框作为保留结果
- 移除与该检测框IoU超过阈值的其他检测框
- 重复上述过程直到处理完所有检测框
IoU计算
交并比(Intersection over Union)是NMS算法的核心计算:
fn iou(box1: &Tensor, box2: &Tensor) -> Tensor {
let zero = Tensor::zeros_like(&box1.i(0));
let b1_area = (box1.i(2) - box1.i(0) + 1) * (box1.i(3) - box1.i(1) + 1);
let b2_area = (box2.i(2) - box2.i(0) + 1) * (box2.i(3) - box2.i(1) + 1);
let i_xmin = box1.i(0).max_other(&box2.i(0));
let i_xmax = box1.i(2).min_other(&box2.i(2));
let i_ymin = box1.i(1).max_other(&box2.i(1));
let i_ymax = box1.i(3).min_other(&box2.i(3));
let i_area = (i_xmax - i_xmin + 1).max_other(&zero) * (i_ymax - i_ymin + 1).max_other(&zero);
&i_area / (b1_area + b2_area - &i_area)
}
批处理NMS实现
针对多类别检测任务,需要实现批处理版本的NMS:
fn batched_nms(boxes: &Tensor, scores: &Tensor, idxs: &Tensor, iou_threshold: f32) -> Tensor {
if boxes.numel() > (if boxes.device() == tch::Device::Cpu {4000} else {20000}) {
_batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
} else {
_batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
}
}
根据检测框数量,自动选择两种不同的实现策略:
-
坐标偏移法:适用于少量检测框
fn _batched_nms_coordinate_trick(boxes: &Tensor, scores: &Tensor, idxs: &Tensor, iou_threshold: f32) -> Tensor { let max_coordinate = boxes.max(); let offsets = idxs * (max_coordinate + Tensor::ones([1], (tch::Kind::Float, boxes.device()))); let boxes_for_nms = boxes + offsets.unsqueeze(1); nms(boxes_for_nms, scores, iou_threshold) } -
逐类处理法:适用于大量检测框
fn _batched_nms_vanilla(boxes: &Tensor, scores: &Tensor, idxs: &Tensor, iou_threshold: f32) -> Tensor { let mut keep_mask = Tensor::zeros_like(scores).to_kind(tch::Kind::Bool); let unique = idxs.view(-1).unique_dim(0, false, false, false).0; for i in 0..unique.size()[0] { let curr_indices = Tensor::where_(&idxs.eq_tensor(&unique.i(i))).remove(0); let curr_keep_indices = nms(boxes.i(&curr_indices), &scores.i(&curr_indices), iou_threshold); keep_mask = keep_mask.index_fill(0, &curr_indices.i(&curr_keep_indices), 1); } let keep_indices = Tensor::where_(&keep_mask).remove(0); keep_indices.i(&scores.i(&keep_indices).sort(-1, true).1) }
性能考虑
实现中考虑了不同场景下的性能优化:
- 根据设备类型(CPU/GPU)设置不同的检测框数量阈值
- 小规模数据使用坐标偏移法减少计算开销
- 大规模数据使用逐类处理法避免内存问题
总结
本文展示了在tch-rs项目中实现NMS算法的完整方案,包括基础NMS、IoU计算以及批处理NMS。这些实现虽然可能不是最优性能,但为Rust生态中的目标检测任务提供了可行的解决方案。随着tch-rs项目的不断发展,未来可能会提供更优化的官方实现。
登录后查看全文
热门项目推荐
相关项目推荐
AutoGLM-Phone-9BAutoGLM-Phone-9B是基于AutoGLM构建的移动智能助手框架,依托多模态感知理解手机屏幕并执行自动化操作。Jinja00
Kimi-K2-ThinkingKimi K2 Thinking 是最新、性能最强的开源思维模型。从 Kimi K2 开始,我们将其打造为能够逐步推理并动态调用工具的思维智能体。通过显著提升多步推理深度,并在 200–300 次连续调用中保持稳定的工具使用能力,它在 Humanity's Last Exam (HLE)、BrowseComp 等基准测试中树立了新的技术标杆。同时,K2 Thinking 是原生 INT4 量化模型,具备 256k 上下文窗口,实现了推理延迟和 GPU 内存占用的无损降低。Python00
GLM-4.6V-FP8GLM-4.6V-FP8是GLM-V系列开源模型,支持128K上下文窗口,融合原生多模态函数调用能力,实现从视觉感知到执行的闭环。具备文档理解、图文生成、前端重构等功能,适用于云集群与本地部署,在同类参数规模中视觉理解性能领先。Jinja00
HunyuanOCRHunyuanOCR 是基于混元原生多模态架构打造的领先端到端 OCR 专家级视觉语言模型。它采用仅 10 亿参数的轻量化设计,在业界多项基准测试中取得了当前最佳性能。该模型不仅精通复杂多语言文档解析,还在文本检测与识别、开放域信息抽取、视频字幕提取及图片翻译等实际应用场景中表现卓越。00
GLM-ASR-Nano-2512GLM-ASR-Nano-2512 是一款稳健的开源语音识别模型,参数规模为 15 亿。该模型专为应对真实场景的复杂性而设计,在保持紧凑体量的同时,多项基准测试表现优于 OpenAI Whisper V3。Python00
GLM-TTSGLM-TTS 是一款基于大语言模型的高质量文本转语音(TTS)合成系统,支持零样本语音克隆和流式推理。该系统采用两阶段架构,结合了用于语音 token 生成的大语言模型(LLM)和用于波形合成的流匹配(Flow Matching)模型。 通过引入多奖励强化学习框架,GLM-TTS 显著提升了合成语音的表现力,相比传统 TTS 系统实现了更自然的情感控制。Python00
Spark-Formalizer-X1-7BSpark-Formalizer 是由科大讯飞团队开发的专用大型语言模型,专注于数学自动形式化任务。该模型擅长将自然语言数学问题转化为精确的 Lean4 形式化语句,在形式化语句生成方面达到了业界领先水平。Python00
最新内容推荐
MQTT 3.1.1协议中文版文档:物联网开发者的必备技术指南 Solidcam后处理文件下载与使用完全指南:提升CNC编程效率的必备资源 Python案例资源下载 - 从入门到精通的完整项目代码合集 TortoiseSVN 1.14.5.29465 中文版:高效版本控制的终极解决方案 CrystalIndex资源文件管理系统:高效索引与文件管理的最佳实践指南 QT连接阿里云MySQL数据库完整指南:从环境配置到问题解决 Windows Server 2016 .NET Framework 3.5 SXS文件下载与安装完整指南 Python开发者的macOS终极指南:VSCode安装配置全攻略 瀚高迁移工具migration-4.1.4:企业级数据库迁移的智能解决方案 STM32到GD32项目移植完全指南:从兼容性到实战技巧
项目优选
收起
deepin linux kernel
C
24
9
OpenHarmony documentation | OpenHarmony开发者文档
Dockerfile
402
3.12 K
Ascend Extension for PyTorch
Python
224
249
暂无简介
Dart
672
159
本项目是CANN提供的数学类基础计算算子库,实现网络在NPU上加速计算。
C++
663
315
React Native鸿蒙化仓库
JavaScript
262
324
Nop Platform 2.0是基于可逆计算理论实现的采用面向语言编程范式的新一代低代码开发平台,包含基于全新原理从零开始研发的GraphQL引擎、ORM引擎、工作流引擎、报表引擎、规则引擎、批处理引引擎等完整设计。nop-entropy是它的后端部分,采用java语言实现,可选择集成Spring框架或者Quarkus框架。中小企业可以免费商用
Java
9
1
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
1.2 K
655
🔥LeetCode solutions in any programming language | 多种编程语言实现 LeetCode、《剑指 Offer(第 2 版)》、《程序员面试金典(第 6 版)》题解
Java
64
19
openGauss kernel ~ openGauss is an open source relational database management system
C++
160
219