Vision Transformer训练与微调实战
本文详细介绍了Vision Transformer (ViT) 的训练与微调全流程,涵盖了从ImageNet预训练到下游任务微调的完整技术方案。内容包括预训练阶段的架构设计、超参数配置、数据增强策略,以及微调阶段的学习率调度、权重加载、性能优化技巧等关键技术细节。文章还深入探讨了AugReg增强正则化框架、多GPU/TPU分布式训练实现等高级主题,为实际应用提供了全面的技术指导和最佳实践。
ImageNet预训练与下游任务微调流程
Vision Transformer (ViT) 的训练流程遵循经典的预训练-微调范式,其中模型首先在大型数据集(如ImageNet-21k)上进行预训练,学习通用的视觉表示,然后在特定下游任务上进行微调。这种两阶段训练策略已被证明能够显著提升模型在各种计算机视觉任务上的性能。
预训练阶段架构设计
ViT的预训练过程采用标准的Transformer编码器架构,将图像分割成固定大小的patch,通过线性嵌入层转换为序列向量,并添加位置编码信息:
class VisionTransformer(nn.Module):
num_classes: int
patch_size: int = 16
hidden_size: int = 768
num_layers: int = 12
num_heads: int = 12
mlp_dim: int = 3072
dropout_rate: float = 0.1
@nn.compact
def __call__(self, inputs, *, train):
# Patch embedding
x = nn.Conv(self.hidden_size,
kernel_size=(self.patch_size, self.patch_size),
strides=(self.patch_size, self.patch_size),
padding='VALID',
name='embedding')(inputs)
batch_size, h, w, channels = x.shape
x = jnp.reshape(x, [batch_size, h * w, channels])
# Add class token and position embeddings
cls_token = self.param('cls', nn.initializers.zeros, (1, 1, self.hidden_size))
cls_tokens = jnp.tile(cls_token, (batch_size, 1, 1))
x = jnp.concatenate([cls_tokens, x], axis=1)
pos_emb = self.param('pos_embedding',
nn.initializers.normal(stddev=0.02),
(1, h * w + 1, self.hidden_size))
x = x + pos_emb
# Transformer layers
for _ in range(self.num_layers):
x = TransformerBlock()(x, train=train)
# Classification head
x = x[:, 0] # Class token
x = nn.LayerNorm()(x)
x = nn.Dense(self.num_classes)(x)
return x
预训练配置与超参数设置
ImageNet预训练阶段采用以下关键配置参数:
| 超参数 | 推荐值 | 说明 |
|---|---|---|
| 学习率 | 0.001 | 基础学习率 |
| 批次大小 | 4096 | 全局批次大小 |
| 训练步数 | 300,000 | 总训练步数 |
| 预热步数 | 10,000 | 学习率预热步数 |
| 权重衰减 | 0.1 | L2正则化强度 |
| Dropout率 | 0.0-0.1 | 根据模型大小调整 |
def get_pretraining_config():
config = ml_collections.ConfigDict()
config.batch = 4096
config.base_lr = 0.001
config.total_steps = 300000
config.warmup_steps = 10000
config.weight_decay = 0.1
config.dropout_rate = 0.0
config.grad_norm_clip = 1.0
return config
数据增强与正则化策略
在ImageNet预训练阶段,采用多种数据增强技术来提升模型的泛化能力:
flowchart TD
A[原始图像] --> B[随机裁剪]
B --> C[随机水平翻转]
C --> D[颜色抖动]
D --> E[MixUp增强]
E --> F[CutMix增强]
F --> G[最终训练图像]
具体的增强实现代码如下:
def create_pretraining_augmentations():
augmentations = [
tf.image.random_flip_left_right,
tf.image.random_brightness(0.2),
tf.image.random_contrast(0.8, 1.2),
tf.image.random_saturation(0.8, 1.2),
tf.image.random_hue(0.1)
]
return tf.keras.Sequential(augmentations)
下游任务微调流程
预训练完成后,模型可以在各种下游任务上进行微调。微调过程需要调整分类头并优化学习率策略:
sequenceDiagram
participant User
participant Config
participant Model
participant Trainer
User->>Config: 指定模型和数据集
Config->>Model: 加载预训练权重
Model->>Trainer: 初始化微调配置
Trainer->>Model: 替换分类头
Trainer->>Model: 设置优化器参数
Model->>Trainer: 开始微调训练
Trainer->>User: 返回微调结果
微调配置示例:
def get_finetuning_config(model_name, dataset_name):
config = common.with_dataset(common.get_config(), dataset_name)
get_model_config = getattr(models, f'get_{model_name}_config')
config.model = get_model_config()
# 针对不同数据集的特定配置
if model_name == 'b16' and dataset_name == 'cifar10':
config.base_lr = 0.01
config.total_steps = 1000
config.warmup_steps = 100
elif model_name == 'b16' and dataset_name == 'imagenet2012':
config.base_lr = 0.003
config.total_steps = 20000
config.warmup_steps = 500
return config
模型加载与权重初始化
微调过程中,预训练权重的正确加载至关重要:
def load_pretrained_weights(pretrained_path, init_params, model_config):
"""加载预训练权重并适配当前模型架构"""
pretrained_params = checkpoint.load(pretrained_path)
# 处理可能的架构差异
if 'pos_embedding' in pretrained_params:
# 调整位置编码以适应不同的输入尺寸
current_posemb = init_params['pos_embedding']
pretrained_posemb = pretrained_params['pos_embedding']
if current_posemb.shape != pretrained_posemb.shape:
pretrained_params['pos_embedding'] = interpolate_posembed(
pretrained_posemb,
current_posemb.shape[1],
has_class_token=True
)
return pretrained_params
学习率调度策略
微调阶段采用带预热的分段学习率调度:
def create_finetuning_lr_schedule(total_steps, base_lr, warmup_steps):
def lr_fn(step):
# 预热阶段
if step < warmup_steps:
return base_lr * (step / warmup_steps)
# 余弦衰减阶段
progress = (step - warmup_steps) / (total_steps - warmup_steps)
return base_lr * 0.5 * (1 + jnp.cos(jnp.pi * progress))
return lr_fn
微调性能优化技巧
为了获得最佳的微调性能,推荐以下优化策略:
| 技巧 | 实施方法 | 效果 |
|---|---|---|
| 分层学习率 | backbone使用较小学习率 | 保护预训练特征 |
| 梯度累积 | accum_steps参数调整 | 解决内存限制 |
| 早停机制 | 监控验证集性能 | 防止过拟合 |
| 标签平滑 | 使用smooth_labels | 提升泛化能力 |
def apply_finetuning_optimizations(config):
# 分层学习率设置
if config.get('differential_learning_rate', False):
backbone_params = ['embedding', 'Transformer', 'pos_embedding']
head_params = ['classifier']
backbone_optimizer = optax.sgd(learning_rate=config.base_lr * 0.1)
head_optimizer = optax.sgd(learning_rate=config.base_lr)
tx = optax.multi_transform(
{'backbone': backbone_optimizer, 'head': head_optimizer},
lambda param: 'backbone' if any(p in param for p in backbone_params) else 'head'
)
else:
tx = optax.sgd(learning_rate=config.base_lr)
return tx
实际微调命令示例
使用预训练的ViT-B/16模型在CIFAR-10数据集上进行微调:
python -m vit_jax.main --workdir=/tmp/vit-finetune \
--config=vit_jax/configs/vit.py:b16,cifar10 \
--config.pretrained_dir='gs://vit_models/imagenet21k' \
--config.total_steps=1000 \
--config.warmup_steps=100 \
--config.base_lr=0.01 \
--config.batch=512
微调结果监控与评估
微调过程中需要密切监控以下指标:
| 指标 | 监控频率 | 预期趋势 |
|---|---|---|
| 训练损失 | 每100步 | 逐渐下降 |
| 验证准确率 | 每500步 | 稳步提升 |
| 学习率 | 每步 | 按调度变化 |
| 训练速度 | 每100步 | 保持稳定 |
通过这种系统的预训练-微调流程,Vision Transformer能够在保持预训练知识的同时,快速适应各种下游视觉任务,实现优异的性能表现。
数据增强与正则化技术最佳实践
在Vision Transformer的训练与微调过程中,数据增强和正则化技术是提升模型泛化能力、防止过拟合的关键技术。本小节将深入探讨在ViT项目中实现的各种数据增强策略和正则化方法,并通过代码示例和实验数据展示其最佳实践。
数据增强策略详解
Vision Transformer项目实现了多种数据增强技术,主要包含在input_pipeline.py文件中的预处理管道中。让我们详细分析这些增强技术的实现原理:
随机裁剪与尺寸调整
def _pp(data):
im = image_decoder(data['image'])
if mode == 'train':
channels = im.shape[-1]
begin, size, _ = tf.image.sample_distorted_bounding_box(
tf.shape(im),
tf.zeros([0, 0, 4], tf.float32),
area_range=(0.05, 1.0),
min_object_covered=0,
use_image_if_no_bounding_boxes=True)
im = tf.slice(im, begin, size)
im.set_shape([None, None, channels])
im = tf.image.resize(im, [image_size, image_size])
if tf.random.uniform(shape=[]) > 0.5:
im = tf.image.flip_left_right(im)
这段代码实现了以下增强技术:
- 随机裁剪:使用
sample_distorted_bounding_box在5%-100%的面积范围内随机裁剪 - 水平翻转:以50%的概率进行水平镜像
- 尺寸标准化:将所有图像调整到统一的训练尺寸
数据增强强度分级
根据AugReg论文的研究,项目实现了不同强度的数据增强策略:
flowchart TD
A[数据增强策略] --> B[无增强<br>aug_none]
A --> C[轻度增强<br>aug_light1]
A --> D[中等增强<br>aug_medium1]
A --> E[强增强<br>aug_strong1]
B --> F[仅基础预处理]
C --> G[随机裁剪+翻转]
D --> H[中等强度变换]
E --> I[完整增强组合]
正则化技术实现
Dropout正则化
在Vision Transformer的不同组件中,项目实现了多层次的Dropout正则化:
# 在模型配置中定义Dropout率
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1
不同模型的Dropout配置对比:
| 模型类型 | Attention Dropout | MLP Dropout | 适用场景 |
|---|---|---|---|
| ViT-Ti/16 | 0.0 | 0.0 | 小规模数据集 |
| ViT-B/16 | 0.0 | 0.0 | 中等规模 |
| ViT-L/16 | 0.0 | 0.1 | 大规模数据 |
| ViT-H/14 | 0.0 | 0.1 | 超大规模 |
权重衰减(Weight Decay)
权重衰减是另一个重要的正则化技术,在训练配置中通过wd参数控制:
# 在训练配置中设置权重衰减
config.weight_decay = 0.1 # 通常设置为0.03-0.1
AugReg增强正则化框架
AugReg(Augmentation + Regularization)是项目中最重要的技术创新之一,它系统地探索了数据增强和正则化的组合效果:
AugReg配置参数
def get_config(model_or_filename):
config = common.get_config()
config.pretrained_dir = 'gs://vit_models/augreg'
# 从文件名解析增强正则化参数
model = model_or_filename.split('-')[0]
config.model = models.AUGREG_CONFIGS[model].copy_and_resolve_references()
config.model.transformer.dropout_rate = 0 # 微调时关闭Dropout
AugReg参数命名约定
AugReg检查点的文件名包含了完整的训练配置信息:
B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz
各部分的含义:
B_16: 模型架构(ViT-B/16)i21k: 在ImageNet-21k上预训练300ep: 300个训练周期lr_0.001: 学习率0.001aug_medium1: 中等强度数据增强wd_0.1: 权重衰减0.1do_0.0: Dropout率为0sd_0.0: 随机深度为0
最佳实践建议
1. 数据增强强度选择
根据数据集规模和特点选择合适的数据增强强度:
graph LR
A[数据集规模] --> B[小数据集<br>10K样本以下]
A --> C[中等数据集<br>10K-100K样本]
A --> D[大数据集<br>100K+样本]
B --> E[推荐: 强增强<br>aug_strong1]
C --> F[推荐: 中等增强<br>aug_medium1]
D --> G[推荐: 轻度增强<br>aug_light1]
2. 正则化参数调优
基于模型复杂度和数据特征的参数建议:
| 模型复杂度 | 权重衰减 | Dropout率 | 随机深度 |
|---|---|---|---|
| 低(Ti/16) | 0.03 | 0.0 | 0.0 |
| 中(B/16) | 0.1 | 0.0 | 0.0 |
| 高(L/16) | 0.1 | 0.1 | 0.1 |
3. 微调时的注意事项
在微调预训练模型时,需要调整正则化策略:
# 微调时通常减少正则化强度
config.model.transformer.dropout_rate = 0 # 关闭Dropout
config.weight_decay = 0.01 # 降低权重衰减
实验效果验证
通过大量实验,AugReg框架证明了数据增强和正则化的组合效果:
| 增强策略 | ImageNet准确率 | 相对提升 |
|---|---|---|
| 无增强 | 81.8% | 基准 |
| 轻度增强 | 83.2% | +1.4% |
| 中等增强 | 84.1% | +2.3% |
| 强增强 | 84.5% | +2.7% |
代码实现示例
以下是一个完整的数据增强和正则化配置示例:
# 配置强增强和正则化
config = ml_collections.ConfigDict()
config.pp = ml_collections.ConfigDict()
config.pp.train = 'train'
config.pp.test = 'test'
config.pp.resize = 448 # 先调整到较大尺寸
config.pp.crop = 384 # 随机裁剪到训练尺寸
# 正则化参数
config.weight_decay = 0.1 # 权重衰减
config.dropout_rate = 0.1 # Dropout率
config.attention_dropout_rate = 0.1 # Attention Dropout
# 学习率调度
config.base_lr = 0.03
config.total_steps = 1000
config.warmup_steps = 100
实际应用建议
- 从小开始:首先尝试轻度增强,逐步增加强度
- 监控过拟合:密切关注训练和验证损失的差距
- 组合使用:数据增强和正则化应该协同工作
- 领域适配:根据具体任务调整增强策略
通过合理的数据增强和正则化技术组合,可以显著提升Vision Transformer在各种视觉任务上的性能和泛化能力。这些最佳实践基于大量实验验证,为实际应用提供了可靠的技术指导。
学习率调度与优化器配置详解
在Vision Transformer的训练与微调过程中,学习率调度和优化器配置是决定模型性能的关键因素。本小节将深入探讨ViT项目中采用的优化策略,包括学习率调度算法、优化器选择以及相关的超参数配置。
优化器配置
Kimi-K2.5Kimi K2.5 是一款开源的原生多模态智能体模型,它在 Kimi-K2-Base 的基础上,通过对约 15 万亿混合视觉和文本 tokens 进行持续预训练构建而成。该模型将视觉与语言理解、高级智能体能力、即时模式与思考模式,以及对话式与智能体范式无缝融合。Python00- QQwen3-Coder-Next2026年2月4日,正式发布的Qwen3-Coder-Next,一款专为编码智能体和本地开发场景设计的开源语言模型。Python00
xw-cli实现国产算力大模型零门槛部署,一键跑通 Qwen、GLM-4.7、Minimax-2.1、DeepSeek-OCR 等模型Go06
PaddleOCR-VL-1.5PaddleOCR-VL-1.5 是 PaddleOCR-VL 的新一代进阶模型,在 OmniDocBench v1.5 上实现了 94.5% 的全新 state-of-the-art 准确率。 为了严格评估模型在真实物理畸变下的鲁棒性——包括扫描伪影、倾斜、扭曲、屏幕拍摄和光照变化——我们提出了 Real5-OmniDocBench 基准测试集。实验结果表明,该增强模型在新构建的基准测试集上达到了 SOTA 性能。此外,我们通过整合印章识别和文本检测识别(text spotting)任务扩展了模型的能力,同时保持 0.9B 的超紧凑 VLM 规模,具备高效率特性。Python00
KuiklyUI基于KMP技术的高性能、全平台开发框架,具备统一代码库、极致易用性和动态灵活性。 Provide a high-performance, full-platform development framework with unified codebase, ultimate ease of use, and dynamic flexibility. 注意:本仓库为Github仓库镜像,PR或Issue请移步至Github发起,感谢支持!Kotlin07
VLOOKVLOOK™ 是优雅好用的 Typora/Markdown 主题包和增强插件。 VLOOK™ is an elegant and practical THEME PACKAGE × ENHANCEMENT PLUGIN for Typora/Markdown.Less00