open_clip API完全参考:函数与参数详解
2026-02-04 05:04:50作者:胡唯隽
1. 核心模型构建接口
1.1 create_model:模型实例化主入口
def create_model(
model_name: str,
pretrained: Optional[str] = None,
load_weights: bool = True,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
force_patch_dropout: Optional[float] = None,
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
force_preprocess_cfg: Optional[Dict[str, Any]] = None,
force_context_length: Optional[int] = None,
pretrained_image: bool = False,
pretrained_text: bool = True,
pretrained_image_path: Optional[str] = None,
pretrained_text_path: Optional[str] = None,
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
require_pretrained: bool = False,
weights_only: bool = True,** model_kwargs
) -> torch.nn.Module:
参数说明
| 参数名 | 类型 | 默认值 | 说明 |
|---|---|---|---|
| model_name | str | 必需 | 模型标识符,支持内置名称(如"ViT-B-32")或带schema路径(如"hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K") |
| pretrained | str | None | 预训练权重来源,仅当model_name无schema时有效,可为模型标签(如"laion2b_s34b_b79k")或本地文件路径 |
| precision | str | 'fp32' | 模型精度,支持'fp32'/'fp16'/'bf16' |
| device | str/torch.device | 'cpu' | 模型加载设备 |
| force_image_size | int/tuple | None | 覆盖模型配置中的图像尺寸 |
| pretrained_image_path | str | None | 图像塔权重文件路径(优先级高于pretrained) |
使用示例
# 加载ViT-B-32预训练模型
model = create_model(
model_name="ViT-B-32",
pretrained="laion2b_s34b_b79k",
precision="fp16",
device="cuda"
)
# 加载本地目录模型
model = create_model(
model_name="local-dir:/path/to/model",
force_image_size=256
)
1.2 CLIP模型核心方法
encode_image
def encode_image(self, image: torch.Tensor, normalize: bool = False) -> torch.Tensor:
- 功能:将图像转换为特征向量
- 参数:
image: 预处理后的图像张量,形状为(batch_size, 3, H, W)normalize: 是否对输出特征进行L2归一化
encode_text
def encode_text(self, text: torch.Tensor, normalize: bool = False) -> torch.Tensor:
- 功能:将文本转换为特征向量
- 参数:
text: 分词后的文本张量,形状为(batch_size, context_length)normalize: 是否对输出特征进行L2归一化
get_logits
def get_logits(self, image: torch.Tensor, text: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
- 功能:计算图像-文本相似度分数
- 返回值:
image_logits: 图像到文本的相似度矩阵,形状为(batch_size, batch_size)text_logits: 文本到图像的相似度矩阵,形状为(batch_size, batch_size)
2. 数据预处理接口
2.1 图像预处理
image_transform_v2
def image_transform_v2(
cfg: PreprocessCfg,
is_train: bool,
aug_cfg: Optional[Union[Dict, AugmentationCfg]] = None
) -> Compose:
- 功能:创建图像预处理流水线
- 参数:
cfg: 预处理配置(包含尺寸/均值/标准差等)is_train: 是否为训练模式(启用数据增强)aug_cfg: 增强配置(如随机裁剪/颜色抖动)
PreprocessCfg配置类
@dataclass
class PreprocessCfg:
size: Union[int, Tuple[int, int]] = 224
mean: Tuple[float, ...] = OPENAI_DATASET_MEAN # (0.481, 0.457, 0.408)
std: Tuple[float, ...] = OPENAI_DATASET_STD # (0.268, 0.261, 0.275)
interpolation: str = 'bicubic' # 插值方式
resize_mode: str = 'shortest' # 缩放模式:'shortest'/'longest'/'squash'
使用示例
# 创建评估预处理流水线
preprocess = image_transform_v2(
cfg=PreprocessCfg(
size=224,
resize_mode="squash"
),
is_train=False
)
# 处理图像
image = preprocess(Image.open("image.jpg")).unsqueeze(0).to("cuda")
2.2 文本分词接口
SimpleTokenizer
def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
- 功能:将文本转换为模型输入令牌
- 参数:
context_length: 文本序列长度(超长截断,不足补0)
使用示例
tokenizer = SimpleTokenizer()
tokens = tokenizer(["a photo of a cat", "a picture of a dog"], context_length=77)
# tokens形状: (2, 77)
3. 预训练模型管理
3.1 预训练配置查询
list_pretrained_tags_by_model
def list_pretrained_tags_by_model(model: str) -> List[str]:
- 功能:获取指定模型可用的预训练标签
- 返回值:预训练标签列表
使用示例
# 查询ViT-L-14可用预训练版本
tags = list_pretrained_tags_by_model("ViT-L-14")
print(tags)
# ['openai', 'laion400m_e31', 'laion2b_s32b_b82k']
3.2 预训练权重下载
download_pretrained
def download_pretrained(
cfg: Dict,
cache_dir: Optional[str] = None
) -> str:
- 功能:下载指定配置的预训练权重
- 返回值:权重文件本地路径
4. 损失函数
4.1 ClipLoss
class ClipLoss(nn.Module):
def __init__(
self,
local_loss: bool = False,
gather_with_grad: bool = False,
cache_labels: bool = False,
rank: int = 0,
world_size: int = 1
):
- 核心参数:
local_loss: 是否仅计算本地设备内样本损失gather_with_grad: 梯度聚合模式(分布式训练)
- 前向传播:
def forward(
self,
image_features: torch.Tensor,
text_features: torch.Tensor,
logit_scale: torch.Tensor,
logit_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
4.2 SigLipLoss
class SigLipLoss(nn.Module):
def __init__(
self,
cache_labels: bool = False,
rank: int = 0,
world_size: int = 1,
dist_impl: str = "bidir"
):
- 特点:Sigmoid交叉熵损失,支持多种分布式实现
- dist_impl:分布式策略,支持"bidir"/"shift"/"gather"
5. 零样本分类工具
5.1 build_zero_shot_classifier
def build_zero_shot_classifier(
model,
tokenizer,
classnames: Sequence[str],
templates: Sequence[Union[Callable, str]],
num_classes_per_batch: int = 10,
device: str = 'cpu'
) -> torch.Tensor:
- 功能:构建零样本分类器权重矩阵
- 参数:
classnames: 类别名称列表templates: 文本模板列表(如["a photo of a {}"])
- 返回值:形状为
(d_model, num_classes)的分类权重矩阵
使用示例
# 构建ImageNet零样本分类器
classnames = ["cat", "dog", "bird"]
templates = ["a photo of a {}", "an image of a {}"]
zeroshot_weights = build_zero_shot_classifier(
model, tokenizer, classnames, templates, device="cuda"
)
# 图像分类
image_features = model.encode_image(image, normalize=True)
logits = image_features @ zeroshot_weights
predictions = logits.argmax(dim=1)
6. 高级模型:CoCa
6.1 CoCa模型特点
- 支持图像-文本对比学习和图像 captioning 任务
- 继承CLIP基础架构,新增文本解码器
6.2 核心方法
generate
def generate(
self,
image,
seq_len=30,
generation_type="beam_search",
num_beams=6,
temperature=1.0
) -> torch.Tensor:
- 功能:生成图像描述文本
- 参数:
generation_type: 生成策略:"beam_search"/"top_p"/"top_k"num_beams: 束搜索宽度temperature: 采样温度
使用示例
# 生成图像描述
image = preprocess(Image.open("image.jpg")).unsqueeze(0).to("cuda")
captions = model.generate(image, seq_len=20, num_beams=3)
# 解码令牌
decoded = tokenizer.decode(captions[0])
7. 模型配置与扩展
7.1 模型配置文件
模型配置文件位于src/open_clip/model_configs/目录,JSON格式包含:
embed_dim: 特征维度vision_cfg: 视觉编码器配置text_cfg: 文本编码器配置
7.2 自定义模型配置
# 加载自定义配置
model_cfg = json.load(open("custom_config.json"))
model = create_model(
model_name="custom",
pretrained=None,
model_kwargs=model_cfg
)
8. 性能优化工具
8.1 混合精度转换
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16) -> None:
- 功能:将模型权重转换为低精度格式
8.2 梯度检查点
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True) -> None:
- 功能:启用/禁用梯度检查点(节省显存)
model.set_grad_checkpointing(True) # 启用梯度检查点
9. 常用常量
| 常量 | 值 | 说明 |
|---|---|---|
| OPENAI_DATASET_MEAN | (0.481, 0.457, 0.408) | CLIP默认图像均值 |
| OPENAI_DATASET_STD | (0.268, 0.261, 0.275) | CLIP默认图像标准差 |
| DEFAULT_CONTEXT_LENGTH | 77 | 文本最大序列长度 |
10. 部署与集成
10.1 Hugging Face模型转换
push_to_hf_hub
def push_to_hf_hub(
model,
repo_id: str,
tokenizer=None,
preprocess_cfg: Dict = None
) -> None:
- 功能:将模型推送到Hugging Face Hub
10.2 ONNX导出
torch.onnx.export(
model, (image_input, text_input),
"clip.onnx",
input_names=["image", "text"],
output_names=["image_features", "text_features"]
)
附录:模型架构速查表
| 模型系列 | 特点 | 应用场景 |
|---|---|---|
| ViT-B-32 | 基础视觉Transformer | 通用图像-文本任务 |
| ViT-L-14 | 大型视觉Transformer | 高精度需求场景 |
| RN50 | ResNet-50视觉塔 | 计算资源有限时 |
| SigLIP系列 | 改进对比损失 | 多语言任务 |
| CoCa系列 | 带文本解码器 | 图像描述生成 |
登录后查看全文
热门项目推荐
相关项目推荐
GLM-5智谱 AI 正式发布 GLM-5,旨在应对复杂系统工程和长时域智能体任务。Jinja00
GLM-5.1GLM-5.1是智谱迄今最智能的旗舰模型,也是目前全球最强的开源模型。GLM-5.1大大提高了代码能力,在完成长程任务方面提升尤为显著。和此前分钟级交互的模型不同,它能够在一次任务中独立、持续工作超过8小时,期间自主规划、执行、自我进化,最终交付完整的工程级成果。Jinja00
LongCat-AudioDiT-1BLongCat-AudioDiT 是一款基于扩散模型的文本转语音(TTS)模型,代表了当前该领域的最高水平(SOTA),它直接在波形潜空间中进行操作。00- QQwen3.5-397B-A17BQwen3.5 实现了重大飞跃,整合了多模态学习、架构效率、强化学习规模以及全球可访问性等方面的突破性进展,旨在为开发者和企业赋予前所未有的能力与效率。Jinja00
AtomGit城市坐标计划AtomGit 城市坐标计划开启!让开源有坐标,让城市有星火。致力于与城市合伙人共同构建并长期运营一个健康、活跃的本地开发者生态。01
CAP基于最终一致性的微服务分布式事务解决方案,也是一种采用 Outbox 模式的事件总线。C#00
热门内容推荐
最新内容推荐
3种实用方案解决软件试用期管理难题SMUDebugTool:重新定义AMD Ryzen硬件调试的开源解决方案企业级视频本地化:技术架构与商业落地指南4个效率优化维度:Kronos金融大模型资源配置与训练实战指南3步打造高效键盘效率工具:MyKeymap个性化配置指南RapidOCR:企业级本地化OCR工具的技术解析与应用实践开源小说下载工具:实现网络小说本地存储的完整方案Detect-It-Easy技术教程:精准识别PyInstaller打包文件的核心方法GDevelop零代码游戏开发:3大痛点解决方案与实战案例高效解决知识星球内容备份难题:完全掌握zsxq-spider从爬取到PDF的知识管理方案
项目优选
收起
OpenHarmony documentation | OpenHarmony开发者文档
Dockerfile
650
4.23 K
deepin linux kernel
C
27
14
Ascend Extension for PyTorch
Python
487
596
openEuler内核是openEuler操作系统的核心,既是系统性能与稳定性的基石,也是连接处理器、设备与服务的桥梁。
C
390
279
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
1.53 K
886
Oohos_react_native
React Native鸿蒙化仓库
JavaScript
332
387
本项目是CANN提供的数学类基础计算算子库,实现网络在NPU上加速计算。
C++
937
851
暂无简介
Dart
899
215
昇腾LLM分布式训练框架
Python
141
167
AscendNPU-IR是基于MLIR(Multi-Level Intermediate Representation)构建的,面向昇腾亲和算子编译时使用的中间表示,提供昇腾完备表达能力,通过编译优化提升昇腾AI处理器计算效率,支持通过生态框架使能昇腾AI处理器与深度调优
C++
123
194