首页
/ TabPFN 项目使用教程

TabPFN 项目使用教程

2026-02-06 05:27:38作者:胡唯隽

项目概述

TabPFN 是一个基于 Transformer 架构的表格数据基础模型,能够在极短时间内(约1秒)解决小型表格分类问题。该项目由 Prior Labs 开发,提供分类和回归任务的快速预测能力,特别适合中小规模数据集。

项目目录结构

TabPFN 项目采用标准 Python 包结构,主要目录如下:

TabPFN/
├── CHANGELOG.md          # 版本变更记录
├── LICENSE               # 许可证文件
├── README.md            # 项目说明文档
├── TELEMETRY.md         # 遥测数据说明
├── TabPFN_Demo_Local.ipynb  # 本地演示笔记本
├── examples/            # 示例代码目录
│   ├── finetune_classifier.py      # 分类器微调示例
│   ├── finetune_regressor.py       # 回归器微调示例
│   ├── kv_cache_fast_prediction.py # 快速预测示例
│   ├── notebooks/                  # 笔记本目录
│   ├── save_and_load_model.py      # 模型保存加载示例
│   ├── tabpfn_for_binary_classification.py  # 二分类示例
│   ├── tabpfn_for_multiclass_classification.py  # 多分类示例
│   ├── tabpfn_for_regression.py    # 回归示例
│   └── tabpfn_with_tuning.py       # 调优示例
├── pyproject.toml       # 项目构建配置文件
├── scripts/             # 脚本目录
│   ├── __init__.py
│   └── download_all_models.py  # 模型下载脚本
├── src/                 # 源代码目录
│   └── tabpfn/         # 核心包
│       ├── __init__.py
│       ├── architectures/     # 模型架构
│       ├── base.py           # 基础模块
│       ├── classifier.py     # 分类器实现
│       ├── constants.py      # 常量定义
│       ├── finetune_utils.py # 微调工具
│       ├── inference.py      # 推理模块
│       ├── model_loading.py  # 模型加载
│       ├── preprocessors/    # 预处理器
│       ├── regressor.py      # 回归器实现
│       └── utils.py          # 工具函数
└── tests/               # 测试目录
    ├── test_classifier_interface.py  # 分类器接口测试
    ├── test_regressor_interface.py   # 回归器接口测试
    └── 其他测试文件

安装配置

环境要求

  • Python 3.9 或更高版本
  • PyTorch >= 2.1
  • CUDA 支持(推荐,用于 GPU 加速)

安装方式

官方 PyPI 安装:

pip install tabpfn

从源码安装:

pip install "tabpfn @ git+https://gitcode.com/gh_mirrors/ta/TabPFN.git"

本地开发安装:

git clone https://gitcode.com/gh_mirrors/ta/TabPFN.git --depth 1
cd TabPFN
pip install -e ".[dev]"

快速开始

分类任务示例

from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from tabpfn import TabPFNClassifier

# 加载数据
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)

# 初始化分类器
clf = TabPFNClassifier()
clf.fit(X_train, y_train)

# 预测概率
prediction_probabilities = clf.predict_proba(X_test)
print("ROC AUC:", roc_auc_score(y_test, prediction_probabilities[:, 1]))

# 预测标签
predictions = clf.predict(X_test)
print("Accuracy", accuracy_score(y_test, predictions))

回归任务示例

from sklearn.datasets import fetch_openml
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from tabpfn import TabPFNRegressor

# 加载波士顿房价数据
df = fetch_openml(data_id=531, as_frame=True)
X = df.data
y = df.target.astype(float)

# 训练测试分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)

# 初始化回归器
regressor = TabPFNRegressor()
regressor.fit(X_train, y_train)

# 预测
predictions = regressor.predict(X_test)

# 评估
mse = mean_squared_error(y_test, predictions)
r2 = r2_score(y_test, predictions)
print("Mean Squared Error (MSE):", mse)
print("R² Score:", r2)

核心功能

1. 分类器 (TabPFNClassifier)

提供快速的表格数据分类能力,支持二分类和多分类任务。

2. 回归器 (TabPFNRegressor)

提供表格数据回归预测功能,适用于连续值预测任务。

3. 模型微调

支持对预训练模型进行微调,以适应特定数据集:

from tabpfn import TabPFNClassifier
from tabpfn.finetune_utils import finetune

# 微调分类器
clf = TabPFNClassifier()
finetuned_clf = finetune(clf, X_train, y_train, epochs=10)

4. 模型保存与加载

from tabpfn.model_loading import save_fitted_tabpfn_model, load_fitted_tabpfn_model

# 保存模型
save_fitted_tabpfn_model(clf, "my_model.tabpfn_fit")

# 加载模型
loaded_clf = load_fitted_tabpfn_model("my_model.tabpfn_fit", device="cpu")

配置说明

环境变量配置

TabPFN 支持通过环境变量进行配置:

# 设置模型缓存目录
export TABPFN_MODEL_CACHE_DIR="/path/to/models"

# 允许在 CPU 上运行大型数据集
export TABPFN_ALLOW_CPU_LARGE_DATASET=true

# 配置 PyTorch CUDA 内存分配
export PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:512"

配置文件

项目使用 pyproject.toml 作为主要配置文件,定义了项目依赖、构建系统和开发工具配置。

开发指南

设置开发环境

python -m venv venv
source venv/bin/activate  # Windows: venv\Scripts\activate
git clone https://gitcode.com/gh_mirrors/ta/TabPFN.git
cd TabPFN
pip install -e ".[dev]"
pre-commit install

运行测试

# 运行所有测试
pytest tests/

# 运行特定测试
pytest tests/test_classifier_interface.py

代码规范

项目使用 Ruff 进行代码格式化和 linting:

# 格式化代码
ruff format .

# 检查代码规范
ruff check .

常见问题解答

Q: TabPFN 支持的数据集大小是多少?

A: TabPFN-2.5 针对最多 50,000 行的数据集进行了优化。对于更大的数据集,建议使用随机森林预处理或其他扩展方法。

Q: 如何在没有互联网连接的情况下使用 TabPFN?

A: 使用提供的下载脚本下载所有模型:

python scripts/download_all_models.py

Q: TabPFN 能处理缺失值吗?

A: 是的,TabPFN 内置了缺失值处理能力。

性能优化建议

  1. 使用 GPU:推荐使用 GPU 以获得最佳性能,即使是较旧的 8GB VRAM GPU 也能良好工作
  2. 启用 KV 缓存:使用 fit_mode='fit_with_cache' 来加快预测速度
  3. 批量处理:对于多个数据集,使用批量处理来提高效率

许可证说明

TabPFN-2.5 模型权重使用非商业许可证。代码和 TabPFN-2 模型权重使用 Prior Labs 许可证(Apache 2.0 带有额外的归属要求)。

通过本教程,您应该能够快速上手使用 TabPFN 进行表格数据的分类和回归任务。项目的示例代码和详细文档为您提供了丰富的参考资源。

登录后查看全文
热门项目推荐
相关项目推荐