tch-rs项目中实现Rust版NMS算法的探索
2025-06-11 10:59:01作者:韦蓉瑛
在将YOLO目标检测器从Python迁移到Rust的过程中,一个关键挑战是如何实现torchvision.ops.nms功能。本文将详细介绍如何在tch-rs项目中实现非极大值抑制(NMS)算法及其相关功能。
NMS算法的重要性
非极大值抑制(Non-Maximum Suppression)是目标检测后处理中的关键步骤,它用于消除冗余的检测框,保留最有可能代表真实目标的检测结果。在Python生态中,这一功能通常由torchvision.ops.nms提供。
Rust实现方案
由于tch-rs项目尚未提供torchvision功能的完整绑定,开发者需要自行实现NMS相关算法。以下是完整的Rust实现方案:
基础NMS实现
fn nms(boxes: Tensor, scores: &Tensor, iou_threshold: f32) -> Tensor {
let mut sorting: Vec<i64> = scores.argsort(0, false).try_into().unwrap();
let mut keep: Vec<i64> = Vec::new();
while let Some(idx) = sorting.pop() {
keep.push(idx);
for i in (0..sorting.len()).rev() {
if iou(&boxes.i(idx), &boxes.i(sorting[i])).double_value(&[]) > iou_threshold as f64 {
_ = sorting.remove(i);
}
}
}
Tensor::try_from(keep).unwrap().to_device(boxes.device())
}
这个实现采用了经典的NMS算法流程:
- 根据置信度分数对检测框进行排序
- 选择分数最高的检测框作为保留结果
- 移除与该检测框IoU超过阈值的其他检测框
- 重复上述过程直到处理完所有检测框
IoU计算
交并比(Intersection over Union)是NMS算法的核心计算:
fn iou(box1: &Tensor, box2: &Tensor) -> Tensor {
let zero = Tensor::zeros_like(&box1.i(0));
let b1_area = (box1.i(2) - box1.i(0) + 1) * (box1.i(3) - box1.i(1) + 1);
let b2_area = (box2.i(2) - box2.i(0) + 1) * (box2.i(3) - box2.i(1) + 1);
let i_xmin = box1.i(0).max_other(&box2.i(0));
let i_xmax = box1.i(2).min_other(&box2.i(2));
let i_ymin = box1.i(1).max_other(&box2.i(1));
let i_ymax = box1.i(3).min_other(&box2.i(3));
let i_area = (i_xmax - i_xmin + 1).max_other(&zero) * (i_ymax - i_ymin + 1).max_other(&zero);
&i_area / (b1_area + b2_area - &i_area)
}
批处理NMS实现
针对多类别检测任务,需要实现批处理版本的NMS:
fn batched_nms(boxes: &Tensor, scores: &Tensor, idxs: &Tensor, iou_threshold: f32) -> Tensor {
if boxes.numel() > (if boxes.device() == tch::Device::Cpu {4000} else {20000}) {
_batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
} else {
_batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
}
}
根据检测框数量,自动选择两种不同的实现策略:
-
坐标偏移法:适用于少量检测框
fn _batched_nms_coordinate_trick(boxes: &Tensor, scores: &Tensor, idxs: &Tensor, iou_threshold: f32) -> Tensor { let max_coordinate = boxes.max(); let offsets = idxs * (max_coordinate + Tensor::ones([1], (tch::Kind::Float, boxes.device()))); let boxes_for_nms = boxes + offsets.unsqueeze(1); nms(boxes_for_nms, scores, iou_threshold) }
-
逐类处理法:适用于大量检测框
fn _batched_nms_vanilla(boxes: &Tensor, scores: &Tensor, idxs: &Tensor, iou_threshold: f32) -> Tensor { let mut keep_mask = Tensor::zeros_like(scores).to_kind(tch::Kind::Bool); let unique = idxs.view(-1).unique_dim(0, false, false, false).0; for i in 0..unique.size()[0] { let curr_indices = Tensor::where_(&idxs.eq_tensor(&unique.i(i))).remove(0); let curr_keep_indices = nms(boxes.i(&curr_indices), &scores.i(&curr_indices), iou_threshold); keep_mask = keep_mask.index_fill(0, &curr_indices.i(&curr_keep_indices), 1); } let keep_indices = Tensor::where_(&keep_mask).remove(0); keep_indices.i(&scores.i(&keep_indices).sort(-1, true).1) }
性能考虑
实现中考虑了不同场景下的性能优化:
- 根据设备类型(CPU/GPU)设置不同的检测框数量阈值
- 小规模数据使用坐标偏移法减少计算开销
- 大规模数据使用逐类处理法避免内存问题
总结
本文展示了在tch-rs项目中实现NMS算法的完整方案,包括基础NMS、IoU计算以及批处理NMS。这些实现虽然可能不是最优性能,但为Rust生态中的目标检测任务提供了可行的解决方案。随着tch-rs项目的不断发展,未来可能会提供更优化的官方实现。
登录后查看全文
热门项目推荐
- Ggpt-oss-20bgpt-oss-20b —— 适用于低延迟和本地或特定用途的场景(210 亿参数,其中 36 亿活跃参数)Jinja00
- Ggpt-oss-120bgpt-oss-120b是OpenAI开源的高性能大模型,专为复杂推理任务和智能代理场景设计。这款拥有1170亿参数的混合专家模型采用原生MXFP4量化技术,可单卡部署在H100 GPU上运行。它支持可调节的推理强度(低/中/高),完整思维链追溯,并内置函数调用、网页浏览等智能体能力。模型遵循Apache 2.0许可,允许自由商用和微调,特别适合需要生产级推理能力的开发者。通过Transformers、vLLM等主流框架即可快速调用,还能在消费级硬件通过Ollama运行,为AI应用开发提供强大而灵活的基础设施。【此简介由AI生成】Jinja00
- QQwen3-Coder-480B-A35B-InstructQwen3-Coder-480B-A35B-Instruct是当前最强大的开源代码模型之一,专为智能编程与工具调用设计。它拥有4800亿参数,支持256K长上下文,并可扩展至1M,特别擅长处理复杂代码库任务。模型在智能编码、浏览器操作等任务上表现卓越,性能媲美Claude Sonnet。支持多种平台工具调用,内置优化的函数调用格式,能高效完成代码生成与逻辑推理。推荐搭配温度0.7、top_p 0.8等参数使用,单次输出最高支持65536个token。无论是快速排序算法实现,还是数学工具链集成,都能流畅执行,为开发者提供接近人类水平的编程辅助体验。【此简介由AI生成】Python00
- GGLM-4.5-AirGLM-4.5 系列模型是专为智能体设计的基础模型。GLM-4.5拥有 3550 亿总参数量,其中 320 亿活跃参数;GLM-4.5-Air采用更紧凑的设计,拥有 1060 亿总参数量,其中 120 亿活跃参数。GLM-4.5模型统一了推理、编码和智能体能力,以满足智能体应用的复杂需求Jinja00
hello-uniapp
uni-app 是一个使用 Vue.js 开发所有前端应用的框架,开发者编写一套代码,可发布到iOS、Android、鸿蒙Next、Web(响应式)、以及各种小程序(微信/支付宝/百度/抖音/飞书/QQ/快手/钉钉/淘宝/京东/小红书)、快应用、鸿蒙元服务等多个平台Vue00GitCode百大开源项目
GitCode百大计划旨在表彰GitCode平台上积极推动项目社区化,拥有广泛影响力的G-Star项目,入选项目不仅代表了GitCode开源生态的蓬勃发展,也反映了当下开源行业的发展趋势。05GOT-OCR-2.0-hf
阶跃星辰StepFun推出的GOT-OCR-2.0-hf是一款强大的多语言OCR开源模型,支持从普通文档到复杂场景的文字识别。它能精准处理表格、图表、数学公式、几何图形甚至乐谱等特殊内容,输出结果可通过第三方工具渲染成多种格式。模型支持1024×1024高分辨率输入,具备多页批量处理、动态分块识别和交互式区域选择等创新功能,用户可通过坐标或颜色指定识别区域。基于Apache 2.0协议开源,提供Hugging Face演示和完整代码,适用于学术研究到工业应用的广泛场景,为OCR领域带来突破性解决方案。00openHiTLS
旨在打造算法先进、性能卓越、高效敏捷、安全可靠的密码套件,通过轻量级、可剪裁的软件技术架构满足各行业不同场景的多样化要求,让密码技术应用更简单,同时探索后量子等先进算法创新实践,构建密码前沿技术底座!C0255Yi-Coder
Yi Coder 编程模型,小而强大的编程助手HTML013RuoYi-Cloud-Plus
微服务管理系统 重写RuoYi-Cloud所有功能 整合 SpringCloudAlibaba、Dubbo3.0、Sa-Token、Mybatis-Plus、MQ、Warm-Flow工作流、ES、Docker 全方位升级 定期同步Java014
热门内容推荐
最新内容推荐
左手Annotators,右手GPT-4:企业AI战略的“开源”与“闭源”之辩 左手controlnet-openpose-sdxl-1.0,右手GPT-4:企业AI战略的“开源”与“闭源”之辩 左手ERNIE-4.5-VL-424B-A47B-Paddle,右手GPT-4:企业AI战略的“开源”与“闭源”之辩 左手m3e-base,右手GPT-4:企业AI战略的“开源”与“闭源”之辩 左手SDXL-Lightning,右手GPT-4:企业AI战略的“开源”与“闭源”之辩 左手wav2vec2-base-960h,右手GPT-4:企业AI战略的“开源”与“闭源”之辩 左手nsfw_image_detection,右手GPT-4:企业AI战略的“开源”与“闭源”之辩 左手XTTS-v2,右手GPT-4:企业AI战略的“开源”与“闭源”之辩 左手whisper-large-v3,右手GPT-4:企业AI战略的“开源”与“闭源”之辩 左手flux-ip-adapter,右手GPT-4:企业AI战略的“开源”与“闭源”之辩
项目优选
收起

🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
763
475

React Native鸿蒙化仓库
C++
150
241

本仓将收集和展示高质量的仓颉示例代码,欢迎大家投稿,让全世界看到您的妙趣设计,也让更多人通过您的编码理解和喜爱仓颉语言。
Cangjie
318
1.05 K

一个轻量级 java 权限认证框架,让鉴权变得简单、优雅! —— 登录认证、权限认证、分布式Session会话、微服务网关鉴权、SSO 单点登录、OAuth2.0 统一认证
Java
73
13

🔥🔥🔥ShopXO企业级免费开源商城系统,可视化DIY拖拽装修、包含PC、H5、多端小程序(微信+支付宝+百度+头条&抖音+QQ+快手)、APP、多仓库、多商户、多门店、IM客服、进销存,遵循MIT开源协议发布、基于ThinkPHP8框架研发
JavaScript
85
15

本仓将收集和展示仓颉鸿蒙应用示例代码,欢迎大家投稿,在仓颉鸿蒙社区展现你的妙趣设计!
Cangjie
377
361

一款跨平台的 Markdown AI 笔记软件,致力于使用 AI 建立记录和写作的桥梁。
TSX
79
2

旨在打造算法先进、性能卓越、高效敏捷、安全可靠的密码套件,通过轻量级、可剪裁的软件技术架构满足各行业不同场景的多样化要求,让密码技术应用更简单,同时探索后量子等先进算法创新实践,构建密码前沿技术底座!
C
128
255

为仓颉编程语言开发者打造活跃、开放、高质量的社区环境
Markdown
1.04 K
0

一个高性能、可扩展、轻量、省心的仓颉Web框架。Rest, 宏路由,Json, 中间件,参数绑定与校验,文件上传下载,MCP......
Cangjie
78
9