首页
/ Hi-FT/ERD项目:自定义模型在标准数据集上的训练指南

Hi-FT/ERD项目:自定义模型在标准数据集上的训练指南

2025-06-19 01:44:25作者:秋泉律Samson

前言

在计算机视觉领域,目标检测和实例分割是两项基础而重要的任务。Hi-FT/ERD项目提供了一个强大的框架,允许研究人员和开发者基于标准数据集训练自定义模型。本文将详细介绍如何在标准数据集上训练、测试和推理自定义模型,以Cityscapes数据集上训练自定义的Cascade Mask R-CNN R50模型为例。

准备工作

1. 数据集准备

首先需要准备标准数据集。以Cityscapes数据集为例,建议将数据集根目录链接到指定位置。数据集目录结构应如下:

data
├── cityscapes
│   ├── annotations
│   ├── leftImg8bit
│   │   ├── train
│   │   ├── val
│   ├── gtFine
│   │   ├── train
│   │   ├── val

Cityscapes数据集需要转换为COCO格式,使用提供的转换脚本:

pip install cityscapesscripts
python tools/dataset_converters/cityscapes.py ./data/cityscapes --nproc 8 --out-dir ./data/cityscapes/annotations

2. 预训练模型

当前配置使用COCO预训练权重进行初始化。建议提前下载预训练模型,以避免训练开始时因网络问题导致的错误。

自定义模型实现

1. 定义新模块

假设我们要用AugFPN替换默认的FPN作为neck部分。首先需要创建新的neck模块文件:

import torch.nn as nn
from mmdet.registry import MODELS

@MODELS.register_module()
class AugFPN(nn.Module):
    def __init__(self, in_channels, out_channels, num_outs,
                 start_level=0, end_level=-1, add_extra_convs=False):
        pass
    
    def forward(self, inputs):
        pass

2. 模块导入

有两种方式导入新模块:

  1. __init__.py中添加导入语句
  2. 在配置文件中通过custom_imports指定

3. 修改配置文件

在配置文件中指定使用新的neck模块:

neck=dict(
    type='AugFPN',
    in_channels=[256, 512, 1024, 2048],
    out_channels=256,
    num_outs=5)

配置文件准备

完整的配置文件示例如下:

_base_ = [
    '../_base_/models/cascade-mask-rcnn_r50_fpn.py',
    '../_base_/datasets/cityscapes_instance.py', 
    '../_base_/default_runtime.py'
]

model = dict(
    backbone=dict(init_cfg=None),
    neck=dict(
        type='AugFPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    roi_head=dict(
        bbox_head=[...],  # 修改类别数为8
        mask_head=dict(num_classes=8)))

train_pipeline = [
    ...  # 包含AutoAugment配置
]

train_dataloader = dict(batch_size=1, num_workers=3)
optim_wrapper = dict(...)
param_scheduler = [...]
train_cfg = dict(max_epochs=10, val_interval=1)
load_from = '预训练模型路径'

模型训练

准备好配置文件后,使用以下命令开始训练:

python tools/train.py configs/cityscapes/cascade-mask-rcnn_r50_augfpn_autoaug-10e_cityscapes.py

测试与推理

训练完成后,可以使用以下命令测试模型性能:

python tools/test.py configs/cityscapes/cascade-mask-rcnn_r50_augfpn_autoaug-10e_cityscapes.py work_dirs/cascade-mask-rcnn_r50_augfpn_autoaug-10e_cityscapes/epoch_10.pth

总结

本文详细介绍了在Hi-FT/ERD框架下,如何在标准数据集上训练自定义模型的完整流程。从数据集准备、自定义模块实现、配置文件修改到最终的训练和测试,每个步骤都提供了详细的说明。这种灵活的框架设计使得研究人员能够轻松尝试新的网络结构和训练策略,加速计算机视觉算法的研发进程。

对于更高级的自定义需求,如实现新的backbone、head或loss函数,以及自定义训练策略等,可以参考项目提供的其他高级指南。

登录后查看全文
热门项目推荐