首页
/ ConvNeXt预训练模型下载与加载指南

ConvNeXt预训练模型下载与加载指南

2026-02-05 04:39:31作者:宣聪麟

引言:解决预训练模型使用痛点

你是否在使用ConvNeXt模型时遇到过这些问题:预训练权重下载缓慢、模型加载代码报错、不同任务场景下权重不兼容?本文将系统解决这些问题,提供一套完整的ConvNeXt预训练模型获取与加载方案。读完本文后,你将能够:

  • 快速定位并获取所有ConvNeXt官方预训练模型
  • 掌握5种不同场景下的模型加载方法
  • 解决权重不匹配、设备兼容等常见错误
  • 针对分类、检测、分割任务选择最优预训练权重

一、ConvNeXt预训练模型概览

1.1 模型家族与权重分类

ConvNeXt提供了多个版本的预训练模型,按训练数据集可分为ImageNet-1K(120万图像)和ImageNet-22K(2200万图像)两类,后者通常具有更好的迁移学习能力。

模型名称 深度配置 特征维度 1K预训练 22K预训练 参数规模
convnext_tiny [3, 3, 9, 3] [96, 192, 384, 768] 28M
convnext_small [3, 3, 27, 3] [96, 192, 384, 768] 50M
convnext_base [3, 3, 27, 3] [128, 256, 512, 1024] 89M
convnext_large [3, 3, 27, 3] [192, 384, 768, 1536] 197M
convnext_xlarge [3, 3, 27, 3] [256, 512, 1024, 2048] 350M

1.2 官方权重存储位置

ConvNeXt的预训练模型URLs定义在models/convnext.py文件的model_urls字典中,包含9个预训练权重文件:

model_urls = {
    "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
    "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
    "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
    "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
    "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
    "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
    "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
    "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
    "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
}

二、预训练模型下载方法

2.1 命令行直接下载

使用wgetcurl命令可直接下载指定模型权重:

# 下载ConvNeXt-Tiny ImageNet-1K权重
wget https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth -O convnext_tiny_1k.pth

# 下载ConvNeXt-Base ImageNet-22K权重
curl -L https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth --output convnext_base_22k.pth

2.2 Python代码下载

通过PyTorch的torch.hub.load_state_dict_from_url函数下载:

import torch

# 下载并加载ConvNeXt-Large ImageNet-1K权重
url = "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"
checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=True)
torch.save(checkpoint, "convnext_large_1k.pth")

2.3 国内加速下载方案

由于官方URL在国内访问速度较慢,推荐使用国内镜像站点:

# 使用国内镜像下载(示例)
wget https://mirror.ghproxy.com/https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth

三、模型加载核心技术解析

3.1 权重加载流程

ConvNeXt模型加载主要通过utils.py中的load_state_dict函数实现,核心流程如下:

flowchart TD
    A[加载 checkpoint 文件] --> B[提取模型权重]
    B --> C[检查权重键匹配]
    C --> D{键是否匹配}
    D -->|是| E[直接加载权重]
    D -->|否| F[移除不匹配键]
    F --> E
    E --> G[应用权重到模型]

3.2 关键函数解析

utils.py中的load_state_dict函数提供了强大的权重加载能力,支持忽略特定不匹配的键:

def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # 核心逻辑:递归加载权重并处理不匹配情况
    # ...
    if len(missing_keys) > 0:
        print("Weights of {} not initialized from pretrained model: {}".format(
            model.__class__.__name__, missing_keys))
    if len(unexpected_keys) > 0:
        print("Weights from pretrained model not used in {}: {}".format(
            model.__class__.__name__, unexpected_keys))

四、五种场景下的模型加载实践

4.1 分类任务:直接使用官方API

ConvNeXt提供了注册模型函数,可直接通过timm库加载:

import torch
from timm.models import create_model

# 创建带预训练权重的ConvNeXt模型
model = create_model(
    "convnext_tiny",
    pretrained=True,
    num_classes=1000,
    drop_path_rate=0.2
)
model.eval()

# 测试输入
input_tensor = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    output = model(input_tensor)
print(f"输出形状: {output.shape}")  # 应为 (1, 1000)

4.2 迁移学习:微调分类头

加载预训练模型后替换分类头,用于自定义数据集:

# 加载预训练模型但不加载分类头
model = create_model(
    "convnext_base",
    pretrained=False,  # 设为False,手动加载
    num_classes=200,   # 自定义类别数
)

# 手动加载预训练权重
checkpoint = torch.load("convnext_base_1k.pth", map_location="cpu")
# 移除分类头权重
if "head.weight" in checkpoint["model"]:
    del checkpoint["model"]["head.weight"]
    del checkpoint["model"]["head.bias"]
# 加载权重
load_state_dict(model, checkpoint["model"])

# 初始化新分类头
nn.init.trunc_normal_(model.head.weight, std=0.02)
nn.init.constant_(model.head.bias, 0)

4.3 目标检测:作为主干网络加载

在目标检测任务中(object_detection/mmdet/models/backbones/convnext.py):

from mmdet.models import ConvNeXt

# 创建用于检测的ConvNeXt主干
model = ConvNeXt(
    in_channels=3,
    depths=[3, 3, 27, 3],
    dims=[128, 256, 512, 1024],
    out_indices=[0, 1, 2, 3],  # 输出所有阶段特征
)

# 加载预训练权重
model.init_weights(pretrained="convnext_base_1k.pth")

4.4 语义分割:中层特征提取

语义分割任务中加载预训练权重(semantic_segmentation/backbone/convnext.py):

from semantic_segmentation.backbone.convnext import ConvNeXt

model = ConvNeXt(
    pretrained=True,
    model_name='convnext_large',
    out_indices=[0, 1, 2, 3],
    drop_path_rate=0.3,
)

4.5 断点续训:加载训练状态

通过main.py中的训练脚本实现断点续训:

# 从保存的检查点继续训练
python main.py \
  --model convnext_base \
  --resume ./output_dir/checkpoint-100.pth \
  --batch_size 64 \
  --epochs 300

内部通过auto_load_model函数实现:

# utils.py 中的自动加载函数
def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        # 同时加载优化器和调度器状态
        if 'optimizer' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1

四、常见问题解决方案

4.1 权重不匹配错误

问题size mismatch for head.weight: copying a param with shape torch.Size([1000, 768]) from checkpoint, the shape in current model is torch.Size([200, 768]).

解决方案:加载前删除分类头权重:

checkpoint = torch.load("convnext_tiny_1k.pth")
if "head.weight" in checkpoint["model"]:
    del checkpoint["model"]["head.weight"]
    del checkpoint["model"]["head.bias"]
model.load_state_dict(checkpoint["model"], strict=False)

4.2 键名前缀问题

问题:某些权重键名带有前缀如module.

解决方案:使用prefix参数:

load_state_dict(model, checkpoint_model, prefix="module.")

4.3 内存不足问题

解决方案:分阶段加载和转换设备:

# 低内存加载策略
checkpoint = torch.load("convnext_xlarge_22k.pth", map_location="cpu")
# 创建模型
model = convnext_xlarge()
# 分部分加载
for name, param in model.named_parameters():
    if name in checkpoint["model"]:
        param.data.copy_(checkpoint["model"][name])

五、最佳实践与性能对比

5.1 预训练权重选择指南

应用场景 推荐模型 预训练数据集 原因
图像分类 convnext_base ImageNet-1K 平衡精度与速度
迁移学习 convnext_large ImageNet-22K 特征更丰富,迁移性能好
目标检测 convnext_base ImageNet-22K 检测任务需要更多语义信息
语义分割 convnext_large ImageNet-22K 分割需要高分辨率特征
移动端部署 convnext_tiny ImageNet-1K 模型小,推理快

5.2 加载性能对比

模型 加载时间(CPU) 内存占用峰值 推荐设备
convnext_tiny 2.3s 800MB 笔记本
convnext_base 5.7s 2.1GB 中端GPU
convnext_large 12.4s 4.3GB 高端GPU
convnext_xlarge 23.1s 7.8GB 服务器GPU

六、总结与扩展

本文详细介绍了ConvNeXt预训练模型的下载与加载方法,包括:

  1. 完整的预训练模型家族与获取方式
  2. 五种核心应用场景的加载代码示例
  3. 常见错误的解决方案与最佳实践

建议根据具体任务需求选择合适的预训练模型,并遵循本文提供的加载代码模板。对于大规模部署,可考虑模型量化或蒸馏技术进一步优化性能。

收藏本文,以便在使用ConvNeXt模型时快速查阅。如有疑问,请参考官方代码库或提交issue。

附录:模型参数字典

完整的ConvNeXt模型配置参数:

# 各模型深度和维度配置
model_configs = {
    'convnext_tiny': {'depths': [3, 3, 9, 3], 'dims': [96, 192, 384, 768]},
    'convnext_small': {'depths': [3, 3, 27, 3], 'dims': [96, 192, 384, 768]},
    'convnext_base': {'depths': [3, 3, 27, 3], 'dims': [128, 256, 512, 1024]},
    'convnext_large': {'depths': [3, 3, 27, 3], 'dims': [192, 384, 768, 1536]},
    'convnext_xlarge': {'depths': [3, 3, 27, 3], 'dims': [256, 512, 1024, 2048]},
}
登录后查看全文
热门项目推荐
相关项目推荐