首页
/ tch-rs项目中实现Rust版NMS算法的探索

tch-rs项目中实现Rust版NMS算法的探索

2025-06-11 13:50:48作者:韦蓉瑛

在将YOLO目标检测器从Python迁移到Rust的过程中,一个关键挑战是如何实现torchvision.ops.nms功能。本文将详细介绍如何在tch-rs项目中实现非极大值抑制(NMS)算法及其相关功能。

NMS算法的重要性

非极大值抑制(Non-Maximum Suppression)是目标检测后处理中的关键步骤,它用于消除冗余的检测框,保留最有可能代表真实目标的检测结果。在Python生态中,这一功能通常由torchvision.ops.nms提供。

Rust实现方案

由于tch-rs项目尚未提供torchvision功能的完整绑定,开发者需要自行实现NMS相关算法。以下是完整的Rust实现方案:

基础NMS实现

fn nms(boxes: Tensor, scores: &Tensor, iou_threshold: f32) -> Tensor {
    let mut sorting: Vec<i64> = scores.argsort(0, false).try_into().unwrap();
    let mut keep: Vec<i64> = Vec::new();
    while let Some(idx) = sorting.pop() {
        keep.push(idx);
        for i in (0..sorting.len()).rev() {
            if iou(&boxes.i(idx), &boxes.i(sorting[i])).double_value(&[]) > iou_threshold as f64 {
                _ = sorting.remove(i);
            }
        }
    }
    Tensor::try_from(keep).unwrap().to_device(boxes.device())
}

这个实现采用了经典的NMS算法流程:

  1. 根据置信度分数对检测框进行排序
  2. 选择分数最高的检测框作为保留结果
  3. 移除与该检测框IoU超过阈值的其他检测框
  4. 重复上述过程直到处理完所有检测框

IoU计算

交并比(Intersection over Union)是NMS算法的核心计算:

fn iou(box1: &Tensor, box2: &Tensor) -> Tensor {
    let zero = Tensor::zeros_like(&box1.i(0));
    let b1_area = (box1.i(2) - box1.i(0) + 1) * (box1.i(3) - box1.i(1) + 1);
    let b2_area = (box2.i(2) - box2.i(0) + 1) * (box2.i(3) - box2.i(1) + 1);
    let i_xmin = box1.i(0).max_other(&box2.i(0));
    let i_xmax = box1.i(2).min_other(&box2.i(2));
    let i_ymin = box1.i(1).max_other(&box2.i(1));
    let i_ymax = box1.i(3).min_other(&box2.i(3));
    let i_area = (i_xmax - i_xmin + 1).max_other(&zero) * (i_ymax - i_ymin + 1).max_other(&zero);
    &i_area / (b1_area + b2_area - &i_area)
}

批处理NMS实现

针对多类别检测任务,需要实现批处理版本的NMS:

fn batched_nms(boxes: &Tensor, scores: &Tensor, idxs: &Tensor, iou_threshold: f32) -> Tensor {
    if boxes.numel() > (if boxes.device() == tch::Device::Cpu {4000} else {20000}) {
        _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
    } else {
        _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
    }
}

根据检测框数量,自动选择两种不同的实现策略:

  1. 坐标偏移法:适用于少量检测框

    fn _batched_nms_coordinate_trick(boxes: &Tensor, scores: &Tensor, idxs: &Tensor, iou_threshold: f32) -> Tensor {
        let max_coordinate = boxes.max();
        let offsets = idxs * (max_coordinate + Tensor::ones([1], (tch::Kind::Float, boxes.device())));
        let boxes_for_nms = boxes + offsets.unsqueeze(1);
        nms(boxes_for_nms, scores, iou_threshold)
    }
    
  2. 逐类处理法:适用于大量检测框

    fn _batched_nms_vanilla(boxes: &Tensor, scores: &Tensor, idxs: &Tensor, iou_threshold: f32) -> Tensor {
        let mut keep_mask = Tensor::zeros_like(scores).to_kind(tch::Kind::Bool);
        let unique = idxs.view(-1).unique_dim(0, false, false, false).0;
        for i in 0..unique.size()[0] {
            let curr_indices = Tensor::where_(&idxs.eq_tensor(&unique.i(i))).remove(0);
            let curr_keep_indices = nms(boxes.i(&curr_indices), &scores.i(&curr_indices), iou_threshold);
            keep_mask = keep_mask.index_fill(0, &curr_indices.i(&curr_keep_indices), 1);
        }
        let keep_indices = Tensor::where_(&keep_mask).remove(0);
        keep_indices.i(&scores.i(&keep_indices).sort(-1, true).1)
    }
    

性能考虑

实现中考虑了不同场景下的性能优化:

  1. 根据设备类型(CPU/GPU)设置不同的检测框数量阈值
  2. 小规模数据使用坐标偏移法减少计算开销
  3. 大规模数据使用逐类处理法避免内存问题

总结

本文展示了在tch-rs项目中实现NMS算法的完整方案,包括基础NMS、IoU计算以及批处理NMS。这些实现虽然可能不是最优性能,但为Rust生态中的目标检测任务提供了可行的解决方案。随着tch-rs项目的不断发展,未来可能会提供更优化的官方实现。

登录后查看全文
热门项目推荐
相关项目推荐

项目优选

收起
kernelkernel
deepin linux kernel
C
32
16
pytorchpytorch
Ascend Extension for PyTorch
Python
746
926
flutter_flutterflutter_flutter
本仓库是 Flutter SDK 与 Flutter Engine 的 OpenHarmony 适配版本,由 CPF-Flutter 团队维护。开发者可使用熟悉的 Flutter 技术栈开发 OpenHarmony 应用,3.35.7 及以后的适配版本可基于本仓库源码构建支持 OpenHarmony 的 Flutter Engine。
Dart
1.02 K
266
docsdocs
暂无描述
Dockerfile
771
5.02 K
ops-transformerops-transformer
本项目是CANN提供的transformer类大模型算子库,实现网络在NPU上加速计算。
C++
865
1.96 K
leetcodeleetcode
🔥LeetCode solutions in any programming language | 多种编程语言实现 LeetCode、《剑指 Offer(第 2 版)》、《程序员面试金典(第 6 版)》题解
Java
70
22
atomcodeatomcode
Claude Code 的开源替代方案。连接任意大模型,编辑代码,运行命令,自动验证 — 全自动执行。用 Rust 构建,极致性能。 | An open-source alternative to Claude Code. Connect any LLM, edit code, run commands, and verify changes — autonomously. Built in Rust for speed. Get Started
Rust
1.94 K
201
ops-nnops-nn
本项目是CANN提供的神经网络类计算算子库,实现网络在NPU上加速计算。
C++
693
1.36 K
kernelkernel
openEuler内核是openEuler操作系统的核心,既是系统性能与稳定性的基石,也是连接处理器、设备与服务的桥梁。
C
461
455
Cangjie-ExamplesCangjie-Examples
本仓将收集和展示高质量的仓颉示例代码,欢迎大家投稿,让全世界看到您的妙趣设计,也让更多人通过您的编码理解和喜爱仓颉语言。
C
458
5.24 K