首页
/ 人脸关键点检测:MMPose 68点/98点模型训练教程

人脸关键点检测:MMPose 68点/98点模型训练教程

2026-02-04 04:57:54作者:宗隆裙

1. 环境准备:从零搭建训练框架

1.1 系统要求

pie
    title 推荐开发环境配置
    "Ubuntu 20.04" : 45
    "Python 3.8+" : 25
    "CUDA 11.3+" : 30

1.2 快速安装步骤

# 克隆仓库
git clone https://gitcode.com/GitHub_Trending/mm/mmpose
cd mmpose

# 创建虚拟环境
conda create -n mmpose python=3.8 -y
conda activate mmpose

# 安装PyTorch (根据CUDA版本调整)
pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu113/torch_stable.html

# 安装MMPose依赖
pip install -U openmim
mim install mmengine
mim install "mmcv>=2.0.0"
mim install "mmdet>=3.0.0"
pip install -r requirements.txt

2. 数据集准备:构建标注数据体系

2.1 主流人脸关键点数据集对比

数据集 关键点数量 样本数 标注精度 适用场景
300W 68点 600张 ★★★★☆ 通用人脸对齐
WFLW 98点 10000张 ★★★★★ 遮挡/表情变化
COFW 29点 5000张 ★★★☆☆ 遮挡鲁棒性测试
LaPa 106点 22000张 ★★★★☆ 人脸 parsing 融合

2.2 300W数据集部署(68点标准)

# 创建数据目录
mkdir -p data/300w/images
mkdir -p data/300w/annotations

# 下载数据集(需手动获取授权)
wget https://openmmlab.oss-cn-hangzhou.aliyuncs.com/mmpose/datasets/300w_annotations.tar
tar -xf 300w_annotations.tar -C data/300w/annotations

# 目录结构验证
tree data/300w -L 2
# data/300w
# ├── annotations
# │   ├── face_landmarks_300w_train.json
# │   └── face_landmarks_300w_valid.json
# └── images
#     ├── afw
#     ├── helen
#     └── lfpw

2.3 WFLW数据集部署(98点扩展)

# 下载原始数据集
wget https://wywu.github.io/projects/LAB/support/WFLW_images.tar.gz
tar -zxf WFLW_images.tar.gz -C data/

# 转换标注格式
python tools/dataset_converters/wflw2coco.py \
    --img-root data/WFLW/images \
    --ann-file data/WFLW/WFLW_annotations/list_98pt_rect_attr_train_test/list_98pt_rect_attr_test.txt \
    --out-file data/WFLW/annotations/wflw_test.json

3. 模型训练:从配置到启动的全流程

3.1 模型选择策略

flowchart TD
    A[任务需求] --> B{关键点数量}
    B -->|68点| C[ResNet50+Heatmap]
    B -->|98点| D[RTMPose-S+SimCC]
    C --> E[300W数据集]
    D --> F[WFLW数据集]
    E --> G[256x256输入尺寸]
    F --> H[256x256输入尺寸]

3.2 68点模型训练(基于300W数据集)

# 单卡训练
python tools/train.py \
    configs/face_2d_keypoint/topdown_heatmap/300w/res50_300w_256x256.py \
    --work-dir work_dirs/face/300w/res50_256x256

# 多卡训练
bash tools/dist_train.sh \
    configs/face_2d_keypoint/topdown_heatmap/300w/res50_300w_256x256.py \
    4 \
    --work-dir work_dirs/face/300w/res50_256x256

3.3 98点模型训练(基于WFLW数据集)

# 使用RTMPose-S模型
python tools/train.py \
    configs/face_2d_keypoint/rtmpose/wflw/rtmpose-s_8xb256-420e_wflw-256x256.py \
    --work-dir work_dirs/face/wflw/rtmpose-s_256x256 \
    --amp

# 断点续训
python tools/train.py \
    configs/face_2d_keypoint/rtmpose/wflw/rtmpose-s_8xb256-420e_wflw-256x256.py \
    --work-dir work_dirs/face/wflw/rtmpose-s_256x256 \
    --resume work_dirs/face/wflw/rtmpose-s_256x256/latest.pth

4. 模型配置详解:参数调优指南

4.1 核心配置参数说明

# 数据集配置示例 (res50_300w_256x256.py)
train_dataloader = dict(
    batch_size=32,
    num_workers=8,
    dataset=dict(
        type='Face300WDataset',
        data_root='data/300w',
        ann_file='annotations/face_landmarks_300w_train.json',
        data_prefix=dict(img='images/'),
        pipeline=train_pipeline
    )
)

# 模型配置示例
model = dict(
    type='TopdownPoseEstimator',
    backbone=dict(
        type='ResNet',
        depth=50,
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')
    ),
    keypoint_head=dict(
        type='HeatmapHead',
        in_channels=2048,
        out_channels=68,  # 68个关键点
        loss=dict(type='KeypointMSELoss', use_target_weight=True),
        decoder=dict(type='HeatmapDecoder', sigma=3)
    )
)

4.2 训练策略调整

参数类别 推荐设置 作用
学习率 1e-3 (ResNet) / 5e-3 (RTMPose) 控制参数更新步长
批大小 32-128 平衡GPU利用率和梯度稳定性
迭代次数 420e (42万次) 保证模型充分收敛
数据增强 随机旋转(-30°~30°)、缩放(0.75~1.5x) 提升模型泛化能力

5. 模型评估与可视化

5.1 量化评估指标

# 评估68点模型
python tools/test.py \
    configs/face_2d_keypoint/topdown_heatmap/300w/res50_300w_256x256.py \
    work_dirs/face/300w/res50_256x256/best_PCK_epoch_200.pth \
    --eval PCK PCKh

5.2 可视化结果生成

# 单张图片推理
python demo/image_demo.py \
    demo/demo.jpg \
    configs/face_2d_keypoint/topdown_heatmap/300w/res50_300w_256x256.py \
    work_dirs/face/300w/res50_256x256/best_PCK_epoch_200.pth \
    --out-file vis_results/face_demo.jpg

# 关键点热力图可视化
python tools/visualizations/vis_heatmap.py \
    --img-path demo/demo.jpg \
    --config configs/face_2d_keypoint/topdown_heatmap/300w/res50_300w_256x256.py \
    --checkpoint work_dirs/face/300w/res50_256x256/best_PCK_epoch_200.pth \
    --out-dir vis_results/heatmaps

6. 常见问题解决方案

6.1 训练不收敛问题

stateDiagram
    [*] --> 低学习率
    低学习率 --> 提高学习率至2e-3
    [*] --> 数据不足
    数据不足 --> 添加WFLW数据集联合训练
    [*] --> 梯度爆炸
    梯度爆炸 --> 启用梯度裁剪(grad_clip=dict(max_norm=35, norm_type=2))

6.2 精度优化技巧

  1. 预训练模型选择:优先使用ImageNet预训练权重,而非随机初始化
  2. 损失函数调整:遮挡场景下添加 focal loss:loss=dict(type='FocalHeatmapLoss', gamma=2.0)
  3. 测试时增强:启用多尺度测试:test_pipeline = [dict(type='TopdownAffine', scales=[0.5, 1.0, 1.5])]

7. 工程化部署指南

7.1 模型导出为ONNX格式

python tools/export.py \
    configs/face_2d_keypoint/rtmpose/wflw/rtmpose-s_8xb256-420e_wflw-256x256.py \
    work_dirs/face/wflw/rtmpose-s_256x256/best_PCK_epoch_300.pth \
    --export-format onnx \
    --save-file rtmpose_face_98.onnx \
    --input-shape 1 3 256 256

7.2 性能优化指标

模型 输入尺寸 推理速度(ms) 精度(PCK) 模型体积(MB)
ResNet50(68点) 256x256 28 0.965 92
RTMPose-S(98点) 256x256 12 0.972 13

8. 实战案例:从训练到应用

8.1 实时人脸美妆试戴系统

import cv2
from mmpose.apis import inference_topdown
from mmpose.utils import visualize

# 加载模型
model = init_pose_model('configs/face_2d_keypoint/rtmpose/wflw/rtmpose-s_8xb256-420e_wflw-256x256.py',
                        'work_dirs/face/wflw/rtmpose-s_256x256/best_PCK_epoch_300.pth')

# 摄像头实时处理
cap = cv2.VideoCapture(0)
while cap.isOpened():
    ret, frame = cap.read()
    if not ret: break
    
    # 关键点检测
    results = inference_topdown(model, frame)
    
    # 可视化关键点
    vis_frame = visualize.visualize_keypoints(
        frame,
        results[0].pred_instances.keypoints[0],
        skeleton=model.dataset_meta['skeleton'],
        keypoint_score=results[0].pred_instances.keypoint_scores[0]
    )
    
    cv2.imshow('Face Keypoints', vis_frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

9. 总结与扩展

9.1 技术路线图

timeline
    title 人脸关键点技术演进
    2016 : 300W数据集发布,68点标注标准确立
    2018 : WFLW数据集提出,支持98点精细标注
    2020 : Heatmap方法达到性能瓶颈(PCK≈0.96)
    2022 : RTMPose系列模型发布,精度速度双重突破
    2023 : 基于生成式AI的无监督标注技术兴起

9.2 未来研究方向

  1. 跨模态融合:结合RGB-D数据提升遮挡场景鲁棒性
  2. 动态关键点跟踪:视频序列中的关键点时序一致性优化
  3. 轻量化部署:模型压缩至移动端实时运行(<5ms)

通过本教程,您已掌握MMPose框架下68点和98点人脸关键点模型的完整训练流程。实际应用中建议优先选择RTMPose系列模型,在保证精度的同时获得更优的推理速度。后续可关注MMPose官方仓库获取最新模型和技术更新。

登录后查看全文
热门项目推荐
相关项目推荐