首页
/ Hi-FT/ERD项目训练指南:从标准数据集到自定义数据集

Hi-FT/ERD项目训练指南:从标准数据集到自定义数据集

2025-06-19 14:30:27作者:裘旻烁

项目概述

Hi-FT/ERD是一个基于MMDetection框架的目标检测与实例分割项目,提供了丰富的预训练模型和灵活的配置选项。本文将详细介绍如何在标准数据集和自定义数据集上进行模型训练,帮助用户快速上手项目并应用于实际场景。

标准数据集训练

准备工作

在开始训练前,需要确保已经准备好标准数据集(如COCO)。数据集应按照项目要求的格式组织,通常包括图像文件和对应的标注文件。

重要提示:某些配置文件(如configs/cityscapes下的配置)使用COCO预训练权重进行初始化。如果网络连接不稳定,建议提前下载好预训练模型以避免训练初期出现错误。

学习率自动缩放机制

项目支持学习率自动缩放功能,这是基于线性扩展规则实现的。默认配置针对8块GPU(每块2张图像,总batch size=16)设置。关键点包括:

  1. 自动缩放功能默认关闭,需通过--auto-scale-lr参数启用
  2. 配置文件中auto_scale_lr.base_batch_size定义了基准batch size
  3. 不同配置文件的默认batch size可能不同,可通过文件名识别(如_NxM_表示N GPU×M图像)

使用示例:

python tools/train.py config_file --auto-scale-lr

训练方式

单GPU训练

基本命令格式:

python tools/train.py config_file [可选参数]

常用参数说明:

  • --work-dir:指定工作目录
  • --resume-from:从检查点恢复训练(保留优化器状态和迭代次数)
  • --no-validate:关闭训练过程中的验证(不推荐)

CPU训练

虽然支持,但由于性能问题仅建议用于调试:

export CUDA_VISIBLE_DEVICES=-1
python tools/train.py config_file

多GPU训练

使用分布式训练脚本:

bash ./tools/dist_train.sh config_file GPU数量 [可选参数]

多任务并行时需指定不同端口避免冲突:

CUDA_VISIBLE_DEVICES=0,1,2,3 PORT=29500 ./tools/dist_train.sh config_file 4
CUDA_VISIBLE_DEVICES=4,5,6,7 PORT=29501 ./tools/dist_train.sh config_file 4

多机训练

通过以太网连接的多机训练命令: 第一台机器:

NNODES=2 NODE_RANK=0 PORT=MASTER_PORT MASTER_ADDR=MASTER_ADDR sh tools/dist_train.sh config_file GPUS

第二台机器:

NNODES=2 NODE_RANK=1 PORT=MASTER_PORT MASTER_ADDR=MASTER_ADDR sh tools/dist_train.sh config_file GPUS

Slurm集群管理

在Slurm管理的集群上训练:

GPUS=16 ./tools/slurm_train.sh 分区名称 任务名称 配置文件 工作目录

端口设置建议通过--options参数实现,避免修改原始配置文件:

--options 'dist_params.port=29500'

自定义数据集训练

数据集准备

项目支持三种自定义数据集方式:

  1. 转换为COCO格式(推荐)
  2. 转换为中间格式
  3. 实现全新数据集类

以balloon数据集为例,转换到COCO格式的关键步骤包括:

  1. 解析原始标注文件
  2. 构建COCO格式所需的images、annotations和categories字段
  3. 处理多边形标注和边界框信息

转换后的标注文件示例结构:

{
    "images": [{"id": 0, "file_name": "image1.jpg", ...}],
    "annotations": [{"image_id": 0, "category_id": 0, ...}],
    "categories": [{"id": 0, "name": "balloon"}]
}

配置文件调整

基于现有配置修改是最高效的方式。以Mask R-CNN为例:

  1. 修改基础配置路径
  2. 调整模型head中的类别数
  3. 配置数据集路径和元信息
  4. 设置数据加载器和评估器

关键配置示例:

_base_ = '../mask_rcnn/base_config.py'

model = dict(
    roi_head=dict(
        bbox_head=dict(num_classes=1),  # 修改为自定义类别数
        mask_head=dict(num_classes=1)))

metainfo = {
    'classes': ('balloon',),  # 自定义类别名称
    'palette': [(220, 20, 60)]  # 可视化颜色
}

训练与评估

启动训练:

python tools/train.py configs/balloon/custom_config.py

模型测试:

python tools/test.py configs/balloon/custom_config.py work_dirs/checkpoint.pth

最佳实践建议

  1. 数据准备:确保标注质量,特别是边界框和多边形标注的准确性
  2. 学习率设置:对于小数据集,建议使用较小的学习率
  3. 预训练权重:尽量使用与目标领域相近的预训练模型
  4. 验证频率:根据数据集大小调整验证间隔,大数据集可适当减少验证频率
  5. 可视化分析:定期检查训练日志和预测结果,及时发现问题

通过本指南,用户应该能够顺利地在Hi-FT/ERD项目上开展从标准数据集到自定义数据集的模型训练工作。根据实际需求调整配置和参数,可以获得更好的模型性能。

登录后查看全文

项目优选

收起
leetcodeleetcode
🔥LeetCode solutions in any programming language | 多种编程语言实现 LeetCode、《剑指 Offer(第 2 版)》、《程序员面试金典(第 6 版)》题解
Java
51
15
RuoYi-Vue3RuoYi-Vue3
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
566
410
ohos_react_nativeohos_react_native
React Native鸿蒙化仓库
C++
125
208
openGauss-serveropenGauss-server
openGauss kernel ~ openGauss is an open source relational database management system
C++
75
145
cherry-studiocherry-studio
🍒 Cherry Studio 是一款支持多个 LLM 提供商的桌面客户端
TypeScript
430
38
openHiTLSopenHiTLS
旨在打造算法先进、性能卓越、高效敏捷、安全可靠的密码套件,通过轻量级、可剪裁的软件技术架构满足各行业不同场景的多样化要求,让密码技术应用更简单,同时探索后量子等先进算法创新实践,构建密码前沿技术底座!
C
98
253
MateChatMateChat
前端智能化场景解决方案UI库,轻松构建你的AI应用,我们将持续完善更新,欢迎你的使用与建议。 官网地址:https://matechat.gitcode.com
693
91
folibfolib
FOLib 是一个为Ai研发而生的、全语言制品库和供应链服务平台
Java
42
2
CS-BooksCS-Books
🔥🔥超过1000本的计算机经典书籍、个人笔记资料以及本人在各平台发表文章中所涉及的资源等。书籍资源包括C/C++、Java、Python、Go语言、数据结构与算法、操作系统、后端架构、计算机系统知识、数据库、计算机网络、设计模式、前端、汇编以及校招社招各种面经~
97
13
Cangjie-ExamplesCangjie-Examples
本仓将收集和展示高质量的仓颉示例代码,欢迎大家投稿,让全世界看到您的妙趣设计,也让更多人通过您的编码理解和喜爱仓颉语言。
Cangjie
298
1.03 K