torchvision数据集大全:内置数据集使用指南
本文全面介绍了torchvision中各类内置数据集的使用方法,涵盖了图像分类(MNIST、CIFAR、ImageNet)、目标检测(COCO、VOC、WiderFace)、视频与光流(Kinetics、UCF101、HMDB51)等主流数据集,并详细讲解了自定义数据集的创建方法和数据加载最佳实践。
图像分类数据集(ImageNet、CIFAR、MNIST)
在计算机视觉领域,图像分类是最基础且重要的任务之一。torchvision提供了丰富的内置数据集,其中ImageNet、CIFAR和MNIST是最具代表性的图像分类数据集。这些数据集不仅为模型训练提供了标准化的基准,也为研究者们提供了可靠的评估标准。
MNIST数据集
MNIST(Modified National Institute of Standards and Technology)数据集是深度学习入门的经典数据集,包含手写数字的灰度图像。
数据集特性
graph TD
A[MNIST数据集] --> B[训练集: 60,000张]
A --> C[测试集: 10,000张]
B --> D[图像尺寸: 28x28像素]
C --> D
D --> E[灰度图像: 1通道]
E --> F[标签: 0-9数字]
使用方法
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# 基本使用方式
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# 下载并加载训练集
train_dataset = datasets.MNIST(
root='./data',
train=True,
download=True,
transform=transform
)
# 加载测试集
test_dataset = datasets.MNIST(
root='./data',
train=False,
transform=transform
)
数据集变体
torchvision还提供了MNIST的多个变体:
| 数据集 | 描述 | 类别数 | 图像尺寸 |
|---|---|---|---|
| FashionMNIST | 时尚物品图像 | 10 | 28x28 |
| KMNIST | 日文字符数据集 | 10 | 28x28 |
| EMNIST | 扩展MNIST(字母和数字) | 47 | 28x28 |
| QMNIST | 高质量MNIST扩展 | 10 | 28x28 |
CIFAR数据集
CIFAR(Canadian Institute For Advanced Research)数据集包含彩色图像,分为CIFAR-10和CIFAR-100两个版本。
CIFAR-10数据集结构
pie
title CIFAR-10类别分布
"飞机" : 6000
"汽车" : 6000
"鸟" : 6000
"猫" : 6000
"鹿" : 6000
"狗" : 6000
"青蛙" : 6000
"马" : 6000
"船" : 6000
"卡车" : 6000
技术实现细节
CIFAR数据集使用Python的pickle格式存储,数据组织方式如下:
# CIFAR数据文件结构示例
{
'data': ndarray, # 图像数据 (10000, 3072)
'labels': list, # 对应标签
'batch_label': str, # 批次标签
'filenames': list # 文件名列表
}
图像数据以3072维向量存储,前1024维是红色通道,中间1024维是绿色通道,最后1024维是蓝色通道。
使用示例
from torchvision import datasets, transforms
# 数据预处理
transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010]
)
])
# CIFAR-10数据集
cifar10_train = datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform
)
cifar10_test = datasets.CIFAR10(
root='./data',
train=False,
transform=transform
)
# CIFAR-100数据集
cifar100_train = datasets.CIFAR100(
root='./data',
train=True,
download=True,
transform=transform
)
ImageNet数据集
ImageNet是目前最大规模的图像分类数据集之一,包含超过1400万张图像,覆盖21841个类别。
数据集组织结构
ImageNet数据集采用WordNet层次结构组织,每个类别对应一个WordNet ID(wnid)。数据集的主要特点:
| 特性 | 描述 |
|---|---|
| 训练集 | 1,281,167张图像 |
| 验证集 | 50,000张图像 |
| 测试集 | 100,000张图像(无标签) |
| 类别数 | 1000个叶节点类别 |
| 图像尺寸 | 可变尺寸,平均约469x387像素 |
数据加载流程
sequenceDiagram
participant User
participant ImageNet
participant ImageFolder
participant FileSystem
User->>ImageNet: 初始化数据集
ImageNet->>ImageNet: 检查元数据文件
ImageNet->>ImageNet: 解析归档文件
ImageNet->>ImageFolder: 委托给ImageFolder加载
ImageFolder->>FileSystem: 读取图像文件
FileSystem-->>ImageFolder: 返回图像数据
ImageFolder-->>ImageNet: 返回处理后的数据
ImageNet-->>User: 返回数据集实例
使用注意事项
# ImageNet数据集需要手动下载并放置到指定目录
# 数据集目录结构应该如下:
# imagenet/
# ├── train/
# │ ├── n01440764/
# │ ├── n01443537/
# │ └── ...
# ├── val/
# │ ├── n01440764/
# │ ├── n01443537/
# │ └── ...
# └── devkit/ (可选)
# 使用示例
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
train_dataset = datasets.ImageNet(
root='/path/to/imagenet',
split='train',
transform=transform
)
val_dataset = datasets.ImageNet(
root='/path/to/imagenet',
split='val',
transform=transform
)
数据集对比分析
下表总结了三个主要图像分类数据集的关键特性:
| 特性 | MNIST | CIFAR-10/100 | ImageNet |
|---|---|---|---|
| 图像类型 | 灰度 | RGB彩色 | RGB彩色 |
| 图像尺寸 | 28x28 | 32x32 | 可变尺寸 |
| 类别数 | 10 | 10/100 | 1000 |
| 训练样本数 | 60,000 | 50,000 | 1.2M |
| 测试样本数 | 10,000 | 10,000 | 50,000 |
| 主要用途 | 入门教学 | 算法研究 | 实际应用 |
最佳实践建议
-
数据预处理标准化:每个数据集都有推荐的标准化参数,使用这些参数可以确保模型获得最佳性能。
-
数据增强策略:
- MNIST:简单的旋转和平移
- CIFAR:随机裁剪、水平翻转、颜色抖动
- ImageNet:大规模增强包括MixUp、CutMix等
-
内存管理:对于大规模数据集如ImageNet,建议使用DataLoader的pin_memory选项加速GPU数据传输。
-
分布式训练:对于ImageNet等大数据集,使用分布式数据并行(DDP)可以显著减少训练时间。
这些图像分类数据集为计算机视觉研究提供了坚实的基础,从简单的数字识别到复杂的场景理解,它们覆盖了不同难度的视觉任务。通过合理使用这些数据集,研究人员可以有效地开发和评估新的计算机视觉算法。
目标检测数据集(COCO、VOC、WiderFace)
在计算机视觉领域,目标检测是一个核心任务,而高质量的数据集是训练优秀检测模型的基础。TorchVision提供了三个业界标准的目标检测数据集:COCO、Pascal VOC和WIDER FACE。这些数据集涵盖了从通用物体检测到特定人脸检测的各种场景。
COCO数据集:大规模通用物体检测
COCO(Common Objects in Context)是当前最流行的目标检测数据集之一,包含80个物体类别和丰富的标注信息。
数据集特点
flowchart TD
A[COCO数据集] --> B[80个物体类别]
A --> C[33万张图像]
A --> D[250万个标注实例]
B --> E[日常物体]
B --> F[动物]
B --> G[交通工具]
B --> H[家具等]
使用示例
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from pycocotools.coco import COCO
# 初始化COCO检测数据集
coco_dataset = datasets.CocoDetection(
root='path/to/coco/images',
annFile='path/to/annotations/instances_train2017.json',
transform=transforms.Compose([
transforms.Resize((800, 800)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
)
# 获取样本
image, target = coco_dataset[0]
print(f"图像尺寸: {image.shape}")
print(f"标注数量: {len(target)}")
print(f"第一个标注: {target[0]}")
标注格式解析
COCO标注使用JSON格式,包含以下关键信息:
| 字段 | 描述 | 类型 |
|---|---|---|
id |
标注ID | int |
image_id |
图像ID | int |
category_id |
类别ID | int |
bbox |
边界框 [x,y,width,height] | list[float] |
area |
区域面积 | float |
iscrowd |
是否人群标注 | int |
Pascal VOC数据集:经典目标检测基准
Pascal VOC(Visual Object Classes)是目标检测领域的经典基准数据集,包含20个物体类别。
数据集版本
| 版本 | 图像数量 | 标注数量 | 特点 |
|---|---|---|---|
| VOC2007 | 9,963 | 24,640 | 首个完整版本 |
| VOC2012 | 11,530 | 27,450 | 扩展版本 |
| 其他年份 | 各年份不同 | 各年份不同 | 历史版本 |
使用示例
from torchvision.datasets import VOCDetection
import torchvision.transforms as transforms
# 下载并加载VOC数据集
voc_dataset = VOCDetection(
root='./data',
year='2012',
image_set='train',
download=True,
transform=transforms.Compose([
transforms.Resize((500, 500)),
transforms.ToTensor()
])
)
# 查看数据集信息
print(f"数据集大小: {len(voc_dataset)}")
image, target = voc_dataset[0]
print(f"标注信息键: {target['annotation'].keys()}")
XML标注结构
classDiagram
class Annotation {
+str folder
+str filename
+Source source
+Size size
+bool segmented
+List[Object] objects
}
class Object {
+str name
+str pose
+bool truncated
+bool difficult
+BndBox bndbox
}
class BndBox {
+int xmin
+int ymin
+int xmax
+int ymax
}
Annotation --> Object
Object --> BndBox
WIDER FACE数据集:大规模人脸检测
WIDER FACE是专门用于人脸检测的大规模数据集,包含32,203张图像和393,703个人脸标注。
数据集统计
# WIDER FACE数据集统计信息
widerface_stats = {
"训练集": {
"图像数量": 12,880,
"人脸数量": 159,424,
"平均每图像人脸数": 12.4
},
"验证集": {
"图像数量": 3,226,
"人脸数量": 39,593,
"平均每图像人脸数": 12.3
},
"测试集": {
"图像数量": 16,097,
"人脸数量": 未知, # 测试集标注未公开
"平均每图像人脸数": 未知
}
}
使用示例
from torchvision.datasets import WIDERFace
import torchvision.transforms as transforms
# 加载WIDER FACE数据集
widerface_dataset = WIDERFace(
root='./data',
split='train',
download=True,
transform=transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor()
])
)
# 获取样本数据
image, target = widerface_dataset[0]
print(f"人脸数量: {len(target['bbox'])}")
print(f"边界框格式: {target['bbox'].shape}")
丰富的人脸属性
WIDER FACE提供了详细的人脸属性标注:
| 属性 | 描述 | 取值范围 |
|---|---|---|
blur |
模糊程度 | 0-清晰, 1-一般, 2-严重 |
expression |
表情 | 0-典型, 1-夸张 |
illumination |
光照条件 | 0-正常, 1-异常 |
occlusion |
遮挡程度 | 0-无, 1-部分, 2-严重 |
pose |
姿态 | 0-典型, 1-异常 |
invalid |
是否有效 | 0-有效, 1-无效 |
数据集对比分析
| 特性 | COCO | Pascal VOC | WIDER FACE |
|---|---|---|---|
| 应用场景 | 通用物体检测 | 通用物体检测 | 人脸检测 |
| 类别数量 | 80 | 20 | 1(人脸) |
| 标注密度 | 高(7.7/图) | 中(2.4/图) | 很高(12.4/图) |
| 标注格式 | JSON | XML | 文本文件 |
| 挑战性 | 复杂场景、小物体 | 标准场景 | 尺度变化、遮挡 |
最佳实践建议
数据预处理流程
flowchart LR
A[原始图像] --> B[尺寸调整]
B --> C[数据增强]
C --> D[归一化]
D --> E[转换为张量]
E --> F[模型输入]
G[原始标注] --> H[格式解析]
H --> I[坐标转换]
I --> J[标签编码]
J --> K[模型目标]
性能优化技巧
- 批量加载:使用DataLoader进行批量处理
- 缓存机制:对预处理结果进行缓存
- 并行处理:利用多进程加速数据加载
- 内存映射:对于大型数据集使用内存映射文件
from torch.utils.data import DataLoader
from torchvision.datasets import CocoDetection
# 创建数据加载器
dataset = CocoDetection(...)
dataloader = DataLoader(
dataset,
batch_size=16,
shuffle=True,
num_workers=4,
pin_memory=True # 加速GPU传输
)
# 迭代训练
for batch_idx, (images, targets) in enumerate(dataloader):
# 训练代码
pass
常见问题解决
COCO数据集安装问题
# 安装pycocotools依赖
pip install pycocotools
# 或者使用conda安装
conda install -c conda-forge pycocotools
数据集下载问题
对于需要手动下载的数据集,确保:
- 目录结构正确
- 文件完整性验证
- 解压到指定位置
内存优化
对于大型数据集,建议:
- 使用增量加载
- 实施数据采样策略
- 利用磁盘缓存机制
这三个目标检测数据集为计算机视觉研究提供了坚实的基础,每个数据集都有其独特的优势和适用场景。选择合适的数据集并正确使用TorchVision提供的接口,可以显著提高目标检测模型的开发效率。
视频与光流数据集(Kinetics、UCF101、HMDB51)
在计算机视觉领域,视频理解是一个重要且具有挑战性的研究方向。torchvision提供了三个主流的视频动作识别数据集:Kinetics、UCF101和HMDB51。这些数据集广泛应用于视频分类、动作识别和时间序列分析等任务。
Kinetics数据集
Kinetics是由DeepMind开发的大规模视频动作识别数据集,包含多个版本:Kinetics-400、Kinetics-600和Kinetics-700,分别对应400、600和700个动作类别。
数据集特点
| 特性 | 描述 |
|---|---|
| 数据规模 | 大规模,数十万视频片段 |
| 动作类别 | 400/600/700个日常动作 |
| 视频来源 | YouTube视频 |
| 时间跨度 | 每个片段约10秒 |
| 数据质量 | 高质量,人工标注 |
使用示例
import torchvision
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor
# 数据预处理
transform = Compose([
Resize(256),
CenterCrop(224),
ToTensor(),
])
# 加载Kinetics-400训练集
dataset = torchvision.datasets.Kinetics(
root='./data/kinetics',
frames_per_clip=16,
num_classes='400',
split='train',
frame_rate=30,
step_between_clips=1,
transform=transform,
download=True
)
# 获取一个样本
video, audio, label = dataset[0]
print(f"视频形状: {video.shape}")
print(f"音频形状: {audio.shape}")
print(f"标签: {label}")
数据加载流程
flowchart TD
A[初始化Kinetics数据集] --> B[下载视频文件]
B --> C[提取视频片段]
C --> D[应用数据增强]
D --> E[返回视频张量]
E --> F[模型训练]
UCF101数据集
UCF101是佛罗里达中央大学开发的动作识别数据集,包含101个动作类别,涵盖日常生活中的各种动作。
数据集结构
classDiagram
class UCF101 {
+root: str
+annotation_path: str
+frames_per_clip: int
+fold: int
+train: bool
+__init__()
+__getitem__()
+__len__()
}
class VideoClips {
+video_list: List[str]
+frames_per_clip: int
+metadata: Dict
+get_clip()
}
UCF101 --> VideoClips : 使用
完整使用示例
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import UCF101
from torchvision.transforms import v2
# 数据增强管道
transform = v2.Compose([
v2.Resize(256),
v2.RandomCrop(224),
v2.RandomHorizontalFlip(p=0.5),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载UCF101数据集
dataset = UCF101(
root='./data/ucf101',
annotation_path='./data/ucf101/annotations',
frames_per_clip=16,
step_between_clips=2,
fold=1,
train=True,
transform=transform
)
# 创建数据加载器
dataloader = DataLoader(
dataset,
batch_size=8,
shuffle=True,
num_workers=4
)
# 训练循环示例
for batch_idx, (videos, audios, labels) in enumerate(dataloader):
print(f"批次 {batch_idx}: 视频形状 {videos.shape}, 标签形状 {labels.shape}")
# 这里添加模型训练代码
break
HMDB51数据集
HMDB51是由布朗大学开发的动作识别数据集,包含51个动作类别,主要关注人类动作。
数据集配置参数
| 参数 | 类型 | 默认值 | 描述 |
|---|---|---|---|
| root | str | 必需 | 数据集根目录 |
| annotation_path | str | 必需 | 标注文件路径 |
| frames_per_clip | int | 必需 | 每个片段的帧数 |
| step_between_clips | int | 1 | 片段间的帧步长 |
| fold | int | 1 | 交叉验证折数(1-3) |
| train | bool | True | 是否使用训练集 |
| transform | Callable | None | 数据变换函数 |
高级使用技巧
import torchvision
from torchvision.datasets import HMDB51
# 配置多个数据增强选项
def create_transforms(mode='train'):
if mode == 'train':
return torchvision.transforms.v2.Compose([
torchvision.transforms.v2.RandomResizedCrop(224),
torchvision.transforms.v2.RandomHorizontalFlip(),
torchvision.transforms.v2.ColorJitter(
brightness=0.2, contrast=0.2, saturation=0.2
),
torchvision.transforms.v2.ToDtype(torch.float32, scale=True),
torchvision.transforms.v2.Normalize(
mean=[0.43216, 0.394666, 0.37645],
std=[0.22803, 0.22145, 0.216989]
)
])
else:
return torchvision.transforms.v2.Compose([
torchvision.transforms.v2.Resize(256),
torchvision.transforms.v2.CenterCrop(224),
torchvision.transforms.v2.ToDtype(torch.float32, scale=True),
torchvision.transforms.v2.Normalize(
mean=[0.43216, 0.394666, 0.37645],
std=[0.22803, 0.22145, 0.216989]
)
])
# 创建训练和验证数据集
train_dataset = HMDB51(
root='./data/hmdb51',
annotation_path='./data/hmdb51/annotations',
frames_per_clip=32,
step_between_clips=2,
fold=1,
train=True,
transform=create_transforms('train')
)
val_dataset = HMDB51(
root='./data/hmdb51',
annotation_path='./data/hmdb51/annotations',
frames_per_clip=32,
step_between_clips=2,
fold=1,
train=False,
transform=create_transforms('val')
)
性能优化建议
内存管理
# 使用视频元数据预计算加速加载
dataset = Kinetics(
root='./data/kinetics',
frames_per_clip=16,
num_workers=4, # 多进程加速
_precomputed_metadata={'frame_rates': [30] * 1000} # 预计算元数据
)
批量处理策略
flowchart LR
A[原始视频] --> B[视频解码]
B --> C[片段提取]
C --> D[数据增强]
D --> E[批量组合]
E --> F[模型输入]
常见问题解决
- 内存不足:减少
frames_per_clip或使用更小的分辨率 - 下载失败:检查网络连接,或手动下载数据集
- 标注文件缺失:确保annotation_path包含正确的分割文件
- 视频格式不支持:确认系统安装了合适的视频编解码器
最佳实践
- 使用适当的数据增强提高模型泛化能力
- 根据硬件配置调整批次大小和帧数
- 利用多进程数据加载加速训练过程
- 定期验证数据集的完整性和正确性
这三个视频数据集为视频理解任务提供了丰富的训练资源,通过合理的配置和使用,可以构建出高性能的视频动作识别模型。
自定义数据集创建与数据加载最佳实践
在实际的计算机视觉项目中,我们经常需要处理自定义数据集。TorchVision 提供了强大的工具来帮助我们创建和管理自定义数据集,确保数据加载的高效性和灵活性。本节将深入探讨如何基于 TorchVision 的架构创建自定义数据集,并分享数据加载的最佳实践。
理解 VisionDataset 基类
TorchVision 的所有数据集都继承自 VisionDataset 基类,这是一个抽象基类,定义了数据集的基本接口:
class VisionDataset(data.Dataset):
def __init__(self, root=None, transforms=None, transform=None, target_transform=None):
# 初始化逻辑
pass
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
这个设计模式为我们创建自定义数据集提供了清晰的框架。让我们通过一个流程图来理解数据集的生命周期:
flowchart TD
A[数据集初始化] --> B[设置根目录和转换]
B --> C[实现 __getitem__ 方法]
C --> D[实现 __len__ 方法]
D --> E[数据加载器使用]
E --> F[模型训练/验证]
创建自定义数据集的三种方式
1. 继承 VisionDataset 基类
这是最灵活的方式,适用于复杂的数据集结构:
import os
from torchvision.datasets import VisionDataset
from PIL import Image
class CustomDataset(VisionDataset):
def __init__(self, root, transform=None, target_transform=None):
super().__init__(root, transform=transform, target_transform=target_transform)
self.samples = self._load_samples()
def _load_samples(self):
samples = []
for class_name in os.listdir(self.root):
class_dir = os.path.join(self.root, class_name)
if os.path.isdir(class_dir):
for img_name in os.listdir(class_dir):
if img_name.endswith(('.jpg', '.png', '.jpeg')):
img_path = os.path.join(class_dir, img_name)
samples.append((img_path, class_name))
return samples
def __getitem__(self, index):
img_path, label = self.samples[index]
image = Image.open(img_path).convert('RGB')
if self.transform is not None:
image = self.transform(image)
return image, label
def __len__(self):
return len(self.samples)
2. 使用 DatasetFolder 类
对于标准的文件夹结构数据集,可以使用 DatasetFolder 类:
from torchvision.datasets import DatasetFolder
from torchvision.datasets.folder import default_loader
class CustomImageFolder(DatasetFolder):
def __init__(self, root, transform=None, target_transform=None):
super().__init__(
root=root,
loader=default_loader,
extensions=('.jpg', '.jpeg', '.png', '.bmp'),
transform=transform,
target_transform=target_transform
)
3. 使用 ImageFolder 类
对于标准的图像分类数据集结构,这是最简单的方式:
from torchvision.datasets import ImageFolder
dataset = ImageFolder(
root='path/to/data',
transform=transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
)
数据加载最佳实践
高效的数据加载配置
from torch.utils.data import DataLoader
# 最佳实践配置
dataloader = DataLoader(
dataset,
batch_size=32,
shuffle=True,
num_workers=4, # 根据CPU核心数调整
pin_memory=True, # 加速GPU数据传输
persistent_workers=True, # 保持worker进程
drop_last=True # 避免最后一个不完整的batch
)
内存映射优化
对于大型数据集,使用内存映射可以显著提高加载速度:
class MemoryMappedDataset(VisionDataset):
def __init__(self, root, transform=None):
super().__init__(root, transform=transform)
self.data = np.load(os.path.join(root, 'data.npy'), mmap_mode='r')
self.labels = np.load(os.path.join(root, 'labels.npy'), mmap_mode='r')
def __getitem__(self, index):
img = self.data[index]
label = self.labels[index]
if self.transform:
img = self.transform(img)
return img, label
数据增强策略
根据不同的任务类型,选择合适的数据增强策略:
| 任务类型 | 推荐增强 | 注意事项 |
|---|---|---|
| 图像分类 | RandomRotation, RandomHorizontalFlip, ColorJitter | 保持标签不变 |
| 目标检测 | 同上,但需要同步变换bbox坐标 | 需要自定义变换逻辑 |
| 语义分割 | 同上,需要同步变换mask | 使用相同的随机参数 |
from torchvision import transforms
# 分类任务增强
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 检测任务增强(需要自定义)
class DetectionTransform:
def __call__(self, image, target):
# 同步应用相同的随机变换到图像和bbox
if random.random() > 0.5:
image = F.hflip(image)
target['boxes'][:, [0, 2]] = image.width - target['boxes'][:, [2, 0]]
return image, target
数据集验证与完整性检查
创建自定义数据集时,完整性检查至关重要:
def validate_dataset(dataset):
"""验证数据集的完整性"""
issues = []
# 检查样本数量
if len(dataset) == 0:
issues.append("数据集为空")
# 检查文件存在性
for i in range(min(100, len(dataset))): # 抽样检查
try:
sample, label = dataset[i]
if sample is None:
issues.append(f"样本 {i} 加载失败")
except Exception as e:
issues.append(f"样本 {i} 错误: {str(e)}")
# 检查标签分布
if hasattr(dataset, 'targets'):
label_counts = {}
for label in dataset.targets:
label_counts[label] = label_counts.get(label, 0) + 1
if len(label_counts) == 1:
issues.append("数据集中只有一个类别")
return issues
性能优化技巧
预加载和缓存
对于小型数据集,可以考虑预加载到内存中:
class CachedDataset(VisionDataset):
def __init__(self, dataset):
self.dataset = dataset
self.cache = [None] * len(dataset)
def __getitem__(self, index):
if self.cache[index] is None:
self.cache[index] = self.dataset[index]
return self.cache[index]
def __len__(self):
return len(self.dataset)
使用 LMDB 数据库
对于超大型数据集,使用 LMDB 可以显著提高IO性能:
import lmdb
import pickle
class LmdbDataset(VisionDataset):
def __init__(self, lmdb_path, transform=None):
self.env = lmdb.open(lmdb_path, readonly=True, lock=False)
self.transform = transform
with self.env.begin() as txn:
self.length = pickle.loads(txn.get(b'__len__'))
def __getitem__(self, index):
with self.env.begin() as txn:
key = f'data_{index}'.encode()
data = pickle.loads(txn.get(key))
image = data['image']
label = data['label']
if self.transform:
image = self.transform(image)
return image, label
错误处理和恢复
健壮的数据集应该包含适当的错误处理机制:
class RobustDataset(VisionDataset):
def __getitem__(self, index):
max_retries = 3
for attempt in range(max_retries):
try:
return self._get_item(index)
except Exception as e:
if attempt == max_retries - 1:
# 返回一个替代样本
return self._get_fallback_sample()
continue
def _get_fallback_sample(self):
# 返回一个中性样本,避免训练中断
return torch.zeros(3, 224, 224), 0
通过遵循这些最佳实践,您可以创建高效、健壮的自定义数据集,为计算机视觉项目的成功奠定坚实的基础。记住,良好的数据管道是模型性能的关键因素之一。
torchvision提供了丰富而强大的内置数据集支持,从经典的MNIST、CIFAR到大规模的ImageNet和COCO,涵盖了计算机视觉各个领域的任务需求。通过合理使用这些数据集,结合本文介绍的最佳实践和自定义数据集创建方法,研究人员和开发者可以高效地构建数据管道,为模型训练提供可靠的数据基础。掌握这些数据集的特性、使用技巧和优化策略,将显著提升计算机视觉项目的开发效率和质量。
Kimi-K2.5Kimi K2.5 是一款开源的原生多模态智能体模型,它在 Kimi-K2-Base 的基础上,通过对约 15 万亿混合视觉和文本 tokens 进行持续预训练构建而成。该模型将视觉与语言理解、高级智能体能力、即时模式与思考模式,以及对话式与智能体范式无缝融合。Python00
GLM-4.7-FlashGLM-4.7-Flash 是一款 30B-A3B MoE 模型。作为 30B 级别中的佼佼者,GLM-4.7-Flash 为追求性能与效率平衡的轻量化部署提供了全新选择。Jinja00
VLOOKVLOOK™ 是优雅好用的 Typora/Markdown 主题包和增强插件。 VLOOK™ is an elegant and practical THEME PACKAGE × ENHANCEMENT PLUGIN for Typora/Markdown.Less00
PaddleOCR-VL-1.5PaddleOCR-VL-1.5 是 PaddleOCR-VL 的新一代进阶模型,在 OmniDocBench v1.5 上实现了 94.5% 的全新 state-of-the-art 准确率。 为了严格评估模型在真实物理畸变下的鲁棒性——包括扫描伪影、倾斜、扭曲、屏幕拍摄和光照变化——我们提出了 Real5-OmniDocBench 基准测试集。实验结果表明,该增强模型在新构建的基准测试集上达到了 SOTA 性能。此外,我们通过整合印章识别和文本检测识别(text spotting)任务扩展了模型的能力,同时保持 0.9B 的超紧凑 VLM 规模,具备高效率特性。Python00
KuiklyUI基于KMP技术的高性能、全平台开发框架,具备统一代码库、极致易用性和动态灵活性。 Provide a high-performance, full-platform development framework with unified codebase, ultimate ease of use, and dynamic flexibility. 注意:本仓库为Github仓库镜像,PR或Issue请移步至Github发起,感谢支持!Kotlin07
compass-metrics-modelMetrics model project for the OSS CompassPython00