首页
/ PyTorch-Meta 元学习数据集详解与使用指南

PyTorch-Meta 元学习数据集详解与使用指南

2026-02-04 05:17:44作者:郜逊炳

元学习数据集概述

PyTorch-Meta 是一个专注于元学习(Meta-Learning)的 PyTorch 扩展库,它提供了一系列专为元学习任务设计的数据集。这些数据集在少样本学习(Few-Shot Learning)领域被广泛使用,具有明确的训练集、验证集和测试集划分,非常适合用于评估元学习算法的性能。

核心数据集介绍

1. Omniglot 数据集

Omniglot 是一个经典的手写字符识别数据集,由 1623 个来自 50 种不同字母表的手写字符组成。该数据集在元学习领域被广泛用于评估少样本分类算法。

关键特性

  • 数据规模:1623 个字符类别
  • 默认划分:1028 训练类 / 172 验证类 / 423 测试类
  • 支持 Vinyals 划分方式(默认启用)

使用示例

from torchmeta.datasets import Omniglot

dataset = Omniglot(
    root='./data',
    num_classes_per_task=5,  # 5-way分类
    meta_train=True,
    download=True
)

技术要点

  • use_vinyals_split 参数控制是否使用标准划分方式
  • 支持字符增强(如水平翻转)来扩充类别
  • 每个字符包含 20 个样本,适合 K-shot 学习任务

2. MiniImagenet 数据集

MiniImagenet 是从 ImageNet 数据集中精选的子集,包含 100 个类别,是评估图像少样本学习算法的重要基准。

关键特性

  • 数据规模:100 个类别
  • 标准划分:64 训练类 / 16 验证类 / 20 测试类
  • 每类 600 张图片(84×84 像素)

使用示例

from torchmeta.datasets import MiniImagenet

dataset = MiniImagenet(
    root='./data',
    num_classes_per_task=5,
    meta_val=True,
    download=True
)

技术要点

  • 图像尺寸较小(84×84),适合快速实验
  • 类别间差异较大,挑战性适中
  • 常用于原型网络(Prototypical Networks)等算法的基准测试

3. TieredImagenet 数据集

TieredImagenet 是 ImageNet 的更大型子集,包含 608 个类别,采用层级划分方式确保训练和测试类别差异更大。

关键特性

  • 数据规模:608 个类别(34 个高级类别)
  • 层级划分:20 训练类 / 6 验证类 / 8 测试类(高级类别)
  • 每高级类别包含 10-30 个具体类别

使用示例

from torchmeta.datasets import TieredImagenet

dataset = TieredImagenet(
    root='./data',
    num_classes_per_task=5,
    meta_test=True,
    download=True
)

技术要点

  • 层级划分确保训练和测试类别差异明显
  • 比 MiniImagenet 更具挑战性
  • 适合评估算法在更大规模数据上的泛化能力

4. FC100 数据集

FC100 (Fewshot-CIFAR100) 是基于 CIFAR100 数据集构建的元学习专用数据集,按照超类别进行划分。

关键特性

  • 数据规模:100 个类别(20 个超类别)
  • 划分方式:60 训练类 / 20 验证类 / 20 测试类
  • 图像尺寸:32×32 像素

使用示例

from torchmeta.datasets import FC100

dataset = FC100(
    root='./data',
    num_classes_per_task=5,
    meta_train=True,
    download=True
)

技术要点

  • 超类别划分确保训练和测试数据差异
  • 图像尺寸小,训练速度快
  • 适合算法快速迭代和验证

5. CIFARFS 数据集

CIFAR-FS 是另一种基于 CIFAR100 的元学习数据集,采用不同的类别划分方式。

关键特性

  • 数据规模:100 个类别
  • 标准划分:64 训练类 / 16 验证类 / 20 测试类
  • 图像尺寸:32×32 像素

使用示例

from torchmeta.datasets import CIFARFS

dataset = CIFARFS(
    root='./data',
    num_classes_per_task=5,
    meta_train=True,
    download=True
)

技术要点

  • 与 FC100 相同的源数据,不同划分方式
  • 提供与 MiniImagenet 相似的类别数量
  • 适合与 MiniImagenet 结果进行对比

6. CUB 数据集

CUB (Caltech-UCSD Birds) 是一个细粒度鸟类识别数据集,在元学习中用于评估细粒度分类能力。

关键特性

  • 数据规模:200 种鸟类
  • 图像具有丰富的背景变化
  • 细粒度分类挑战大

使用示例

from torchmeta.datasets import CUB

dataset = CUB(
    root='./data',
    num_classes_per_task=5,
    meta_train=True,
    download=True
)

技术要点

  • 细粒度分类的代表性数据集
  • 图像背景复杂,分类难度高
  • 适合评估算法在细粒度任务上的表现

数据集通用参数解析

所有 PyTorch-Meta 数据集共享一组核心参数,理解这些参数对正确使用数据集至关重要:

  1. root (字符串): 数据集存储根目录
  2. num_classes_per_task (整数): N-way 分类中的 N 值
  3. meta_train/meta_val/meta_test (布尔值): 选择数据划分
  4. transform (可调用对象): 图像预处理变换
  5. target_transform (可调用对象): 标签变换
  6. dataset_transform (可调用对象): 整个数据集的变换
  7. class_augmentations (可调用对象列表): 类别增强方法
  8. download (布尔值): 是否自动下载数据集

最佳实践建议

  1. 数据预处理:合理使用 transform 参数进行图像标准化、增强等操作

    from torchvision.transforms import Compose, Resize, ToTensor
    
    transform = Compose([
        Resize(84),
        ToTensor()
    ])
    
  2. 任务构造:使用 ClassSplitter 创建少样本学习任务

    from torchmeta.transforms import ClassSplitter
    
    dataset_transform = ClassSplitter(
        num_train_per_class=5,  # 5-shot
        num_test_per_class=15
    )
    
  3. 类别增强:利用 class_augmentations 增加类别多样性

    from torchmeta.transforms import HorizontalFlip
    
    class_augmentations = [HorizontalFlip()]
    
  4. 数据加载:结合 MetaDataLoader 进行批量任务加载

    from torchmeta.utils.data import BatchMetaDataLoader
    
    dataloader = BatchMetaDataLoader(dataset, batch_size=4)
    

总结

PyTorch-Meta 提供了一系列精心设计的元学习数据集,覆盖了从简单字符识别到复杂细粒度分类的各种挑战。理解这些数据集的特性和正确使用方式,对于开发有效的元学习算法至关重要。建议初学者从 Omniglot 或 MiniImagenet 开始,逐步挑战更复杂的数据集如 TieredImagenet 和 CUB。

登录后查看全文
热门项目推荐
相关项目推荐