open_clip API完全参考:函数与参数详解
2026-02-04 05:04:50作者:胡唯隽
1. 核心模型构建接口
1.1 create_model:模型实例化主入口
def create_model(
model_name: str,
pretrained: Optional[str] = None,
load_weights: bool = True,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
force_patch_dropout: Optional[float] = None,
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
force_preprocess_cfg: Optional[Dict[str, Any]] = None,
force_context_length: Optional[int] = None,
pretrained_image: bool = False,
pretrained_text: bool = True,
pretrained_image_path: Optional[str] = None,
pretrained_text_path: Optional[str] = None,
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
require_pretrained: bool = False,
weights_only: bool = True,** model_kwargs
) -> torch.nn.Module:
参数说明
| 参数名 | 类型 | 默认值 | 说明 |
|---|---|---|---|
| model_name | str | 必需 | 模型标识符,支持内置名称(如"ViT-B-32")或带schema路径(如"hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K") |
| pretrained | str | None | 预训练权重来源,仅当model_name无schema时有效,可为模型标签(如"laion2b_s34b_b79k")或本地文件路径 |
| precision | str | 'fp32' | 模型精度,支持'fp32'/'fp16'/'bf16' |
| device | str/torch.device | 'cpu' | 模型加载设备 |
| force_image_size | int/tuple | None | 覆盖模型配置中的图像尺寸 |
| pretrained_image_path | str | None | 图像塔权重文件路径(优先级高于pretrained) |
使用示例
# 加载ViT-B-32预训练模型
model = create_model(
model_name="ViT-B-32",
pretrained="laion2b_s34b_b79k",
precision="fp16",
device="cuda"
)
# 加载本地目录模型
model = create_model(
model_name="local-dir:/path/to/model",
force_image_size=256
)
1.2 CLIP模型核心方法
encode_image
def encode_image(self, image: torch.Tensor, normalize: bool = False) -> torch.Tensor:
- 功能:将图像转换为特征向量
- 参数:
image: 预处理后的图像张量,形状为(batch_size, 3, H, W)normalize: 是否对输出特征进行L2归一化
encode_text
def encode_text(self, text: torch.Tensor, normalize: bool = False) -> torch.Tensor:
- 功能:将文本转换为特征向量
- 参数:
text: 分词后的文本张量,形状为(batch_size, context_length)normalize: 是否对输出特征进行L2归一化
get_logits
def get_logits(self, image: torch.Tensor, text: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
- 功能:计算图像-文本相似度分数
- 返回值:
image_logits: 图像到文本的相似度矩阵,形状为(batch_size, batch_size)text_logits: 文本到图像的相似度矩阵,形状为(batch_size, batch_size)
2. 数据预处理接口
2.1 图像预处理
image_transform_v2
def image_transform_v2(
cfg: PreprocessCfg,
is_train: bool,
aug_cfg: Optional[Union[Dict, AugmentationCfg]] = None
) -> Compose:
- 功能:创建图像预处理流水线
- 参数:
cfg: 预处理配置(包含尺寸/均值/标准差等)is_train: 是否为训练模式(启用数据增强)aug_cfg: 增强配置(如随机裁剪/颜色抖动)
PreprocessCfg配置类
@dataclass
class PreprocessCfg:
size: Union[int, Tuple[int, int]] = 224
mean: Tuple[float, ...] = OPENAI_DATASET_MEAN # (0.481, 0.457, 0.408)
std: Tuple[float, ...] = OPENAI_DATASET_STD # (0.268, 0.261, 0.275)
interpolation: str = 'bicubic' # 插值方式
resize_mode: str = 'shortest' # 缩放模式:'shortest'/'longest'/'squash'
使用示例
# 创建评估预处理流水线
preprocess = image_transform_v2(
cfg=PreprocessCfg(
size=224,
resize_mode="squash"
),
is_train=False
)
# 处理图像
image = preprocess(Image.open("image.jpg")).unsqueeze(0).to("cuda")
2.2 文本分词接口
SimpleTokenizer
def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
- 功能:将文本转换为模型输入令牌
- 参数:
context_length: 文本序列长度(超长截断,不足补0)
使用示例
tokenizer = SimpleTokenizer()
tokens = tokenizer(["a photo of a cat", "a picture of a dog"], context_length=77)
# tokens形状: (2, 77)
3. 预训练模型管理
3.1 预训练配置查询
list_pretrained_tags_by_model
def list_pretrained_tags_by_model(model: str) -> List[str]:
- 功能:获取指定模型可用的预训练标签
- 返回值:预训练标签列表
使用示例
# 查询ViT-L-14可用预训练版本
tags = list_pretrained_tags_by_model("ViT-L-14")
print(tags)
# ['openai', 'laion400m_e31', 'laion2b_s32b_b82k']
3.2 预训练权重下载
download_pretrained
def download_pretrained(
cfg: Dict,
cache_dir: Optional[str] = None
) -> str:
- 功能:下载指定配置的预训练权重
- 返回值:权重文件本地路径
4. 损失函数
4.1 ClipLoss
class ClipLoss(nn.Module):
def __init__(
self,
local_loss: bool = False,
gather_with_grad: bool = False,
cache_labels: bool = False,
rank: int = 0,
world_size: int = 1
):
- 核心参数:
local_loss: 是否仅计算本地设备内样本损失gather_with_grad: 梯度聚合模式(分布式训练)
- 前向传播:
def forward(
self,
image_features: torch.Tensor,
text_features: torch.Tensor,
logit_scale: torch.Tensor,
logit_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
4.2 SigLipLoss
class SigLipLoss(nn.Module):
def __init__(
self,
cache_labels: bool = False,
rank: int = 0,
world_size: int = 1,
dist_impl: str = "bidir"
):
- 特点:Sigmoid交叉熵损失,支持多种分布式实现
- dist_impl:分布式策略,支持"bidir"/"shift"/"gather"
5. 零样本分类工具
5.1 build_zero_shot_classifier
def build_zero_shot_classifier(
model,
tokenizer,
classnames: Sequence[str],
templates: Sequence[Union[Callable, str]],
num_classes_per_batch: int = 10,
device: str = 'cpu'
) -> torch.Tensor:
- 功能:构建零样本分类器权重矩阵
- 参数:
classnames: 类别名称列表templates: 文本模板列表(如["a photo of a {}"])
- 返回值:形状为
(d_model, num_classes)的分类权重矩阵
使用示例
# 构建ImageNet零样本分类器
classnames = ["cat", "dog", "bird"]
templates = ["a photo of a {}", "an image of a {}"]
zeroshot_weights = build_zero_shot_classifier(
model, tokenizer, classnames, templates, device="cuda"
)
# 图像分类
image_features = model.encode_image(image, normalize=True)
logits = image_features @ zeroshot_weights
predictions = logits.argmax(dim=1)
6. 高级模型:CoCa
6.1 CoCa模型特点
- 支持图像-文本对比学习和图像 captioning 任务
- 继承CLIP基础架构,新增文本解码器
6.2 核心方法
generate
def generate(
self,
image,
seq_len=30,
generation_type="beam_search",
num_beams=6,
temperature=1.0
) -> torch.Tensor:
- 功能:生成图像描述文本
- 参数:
generation_type: 生成策略:"beam_search"/"top_p"/"top_k"num_beams: 束搜索宽度temperature: 采样温度
使用示例
# 生成图像描述
image = preprocess(Image.open("image.jpg")).unsqueeze(0).to("cuda")
captions = model.generate(image, seq_len=20, num_beams=3)
# 解码令牌
decoded = tokenizer.decode(captions[0])
7. 模型配置与扩展
7.1 模型配置文件
模型配置文件位于src/open_clip/model_configs/目录,JSON格式包含:
embed_dim: 特征维度vision_cfg: 视觉编码器配置text_cfg: 文本编码器配置
7.2 自定义模型配置
# 加载自定义配置
model_cfg = json.load(open("custom_config.json"))
model = create_model(
model_name="custom",
pretrained=None,
model_kwargs=model_cfg
)
8. 性能优化工具
8.1 混合精度转换
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16) -> None:
- 功能:将模型权重转换为低精度格式
8.2 梯度检查点
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True) -> None:
- 功能:启用/禁用梯度检查点(节省显存)
model.set_grad_checkpointing(True) # 启用梯度检查点
9. 常用常量
| 常量 | 值 | 说明 |
|---|---|---|
| OPENAI_DATASET_MEAN | (0.481, 0.457, 0.408) | CLIP默认图像均值 |
| OPENAI_DATASET_STD | (0.268, 0.261, 0.275) | CLIP默认图像标准差 |
| DEFAULT_CONTEXT_LENGTH | 77 | 文本最大序列长度 |
10. 部署与集成
10.1 Hugging Face模型转换
push_to_hf_hub
def push_to_hf_hub(
model,
repo_id: str,
tokenizer=None,
preprocess_cfg: Dict = None
) -> None:
- 功能:将模型推送到Hugging Face Hub
10.2 ONNX导出
torch.onnx.export(
model, (image_input, text_input),
"clip.onnx",
input_names=["image", "text"],
output_names=["image_features", "text_features"]
)
附录:模型架构速查表
| 模型系列 | 特点 | 应用场景 |
|---|---|---|
| ViT-B-32 | 基础视觉Transformer | 通用图像-文本任务 |
| ViT-L-14 | 大型视觉Transformer | 高精度需求场景 |
| RN50 | ResNet-50视觉塔 | 计算资源有限时 |
| SigLIP系列 | 改进对比损失 | 多语言任务 |
| CoCa系列 | 带文本解码器 | 图像描述生成 |
登录后查看全文
热门项目推荐
相关项目推荐
Kimi-K2.5Kimi K2.5 是一款开源的原生多模态智能体模型,它在 Kimi-K2-Base 的基础上,通过对约 15 万亿混合视觉和文本 tokens 进行持续预训练构建而成。该模型将视觉与语言理解、高级智能体能力、即时模式与思考模式,以及对话式与智能体范式无缝融合。Python00
GLM-4.7-FlashGLM-4.7-Flash 是一款 30B-A3B MoE 模型。作为 30B 级别中的佼佼者,GLM-4.7-Flash 为追求性能与效率平衡的轻量化部署提供了全新选择。Jinja00
VLOOKVLOOK™ 是优雅好用的 Typora/Markdown 主题包和增强插件。 VLOOK™ is an elegant and practical THEME PACKAGE × ENHANCEMENT PLUGIN for Typora/Markdown.Less00
PaddleOCR-VL-1.5PaddleOCR-VL-1.5 是 PaddleOCR-VL 的新一代进阶模型,在 OmniDocBench v1.5 上实现了 94.5% 的全新 state-of-the-art 准确率。 为了严格评估模型在真实物理畸变下的鲁棒性——包括扫描伪影、倾斜、扭曲、屏幕拍摄和光照变化——我们提出了 Real5-OmniDocBench 基准测试集。实验结果表明,该增强模型在新构建的基准测试集上达到了 SOTA 性能。此外,我们通过整合印章识别和文本检测识别(text spotting)任务扩展了模型的能力,同时保持 0.9B 的超紧凑 VLM 规模,具备高效率特性。Python00
KuiklyUI基于KMP技术的高性能、全平台开发框架,具备统一代码库、极致易用性和动态灵活性。 Provide a high-performance, full-platform development framework with unified codebase, ultimate ease of use, and dynamic flexibility. 注意:本仓库为Github仓库镜像,PR或Issue请移步至Github发起,感谢支持!Kotlin07
compass-metrics-modelMetrics model project for the OSS CompassPython00
项目优选
收起
deepin linux kernel
C
27
11
OpenHarmony documentation | OpenHarmony开发者文档
Dockerfile
525
3.72 K
Ascend Extension for PyTorch
Python
329
391
本项目是CANN提供的数学类基础计算算子库,实现网络在NPU上加速计算。
C++
877
578
openEuler内核是openEuler操作系统的核心,既是系统性能与稳定性的基石,也是连接处理器、设备与服务的桥梁。
C
335
162
暂无简介
Dart
764
189
Nop Platform 2.0是基于可逆计算理论实现的采用面向语言编程范式的新一代低代码开发平台,包含基于全新原理从零开始研发的GraphQL引擎、ORM引擎、工作流引擎、报表引擎、规则引擎、批处理引引擎等完整设计。nop-entropy是它的后端部分,采用java语言实现,可选择集成Spring框架或者Quarkus框架。中小企业可以免费商用
Java
12
1
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
1.33 K
746
🔥LeetCode solutions in any programming language | 多种编程语言实现 LeetCode、《剑指 Offer(第 2 版)》、《程序员面试金典(第 6 版)》题解
Java
67
20
React Native鸿蒙化仓库
JavaScript
302
350