首页
/ open_clip API完全参考:函数与参数详解

open_clip API完全参考:函数与参数详解

2026-02-04 05:04:50作者:胡唯隽

1. 核心模型构建接口

1.1 create_model:模型实例化主入口

def create_model(
    model_name: str,
    pretrained: Optional[str] = None,
    load_weights: bool = True,
    precision: str = 'fp32',
    device: Union[str, torch.device] = 'cpu',
    jit: bool = False,
    force_quick_gelu: bool = False,
    force_custom_text: bool = False,
    force_patch_dropout: Optional[float] = None,
    force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
    force_preprocess_cfg: Optional[Dict[str, Any]] = None,
    force_context_length: Optional[int] = None,
    pretrained_image: bool = False,
    pretrained_text: bool = True,
    pretrained_image_path: Optional[str] = None,
    pretrained_text_path: Optional[str] = None,
    cache_dir: Optional[str] = None,
    output_dict: Optional[bool] = None,
    require_pretrained: bool = False,
    weights_only: bool = True,** model_kwargs
) -> torch.nn.Module:

参数说明

参数名 类型 默认值 说明
model_name str 必需 模型标识符,支持内置名称(如"ViT-B-32")或带schema路径(如"hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
pretrained str None 预训练权重来源,仅当model_name无schema时有效,可为模型标签(如"laion2b_s34b_b79k")或本地文件路径
precision str 'fp32' 模型精度,支持'fp32'/'fp16'/'bf16'
device str/torch.device 'cpu' 模型加载设备
force_image_size int/tuple None 覆盖模型配置中的图像尺寸
pretrained_image_path str None 图像塔权重文件路径(优先级高于pretrained)

使用示例

# 加载ViT-B-32预训练模型
model = create_model(
    model_name="ViT-B-32",
    pretrained="laion2b_s34b_b79k",
    precision="fp16",
    device="cuda"
)

# 加载本地目录模型
model = create_model(
    model_name="local-dir:/path/to/model",
    force_image_size=256
)

1.2 CLIP模型核心方法

encode_image

def encode_image(self, image: torch.Tensor, normalize: bool = False) -> torch.Tensor:
  • 功能:将图像转换为特征向量
  • 参数
    • image: 预处理后的图像张量,形状为(batch_size, 3, H, W)
    • normalize: 是否对输出特征进行L2归一化

encode_text

def encode_text(self, text: torch.Tensor, normalize: bool = False) -> torch.Tensor:
  • 功能:将文本转换为特征向量
  • 参数
    • text: 分词后的文本张量,形状为(batch_size, context_length)
    • normalize: 是否对输出特征进行L2归一化

get_logits

def get_logits(self, image: torch.Tensor, text: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  • 功能:计算图像-文本相似度分数
  • 返回值
    • image_logits: 图像到文本的相似度矩阵,形状为(batch_size, batch_size)
    • text_logits: 文本到图像的相似度矩阵,形状为(batch_size, batch_size)

2. 数据预处理接口

2.1 图像预处理

image_transform_v2

def image_transform_v2(
    cfg: PreprocessCfg,
    is_train: bool,
    aug_cfg: Optional[Union[Dict, AugmentationCfg]] = None
) -> Compose:
  • 功能:创建图像预处理流水线
  • 参数
    • cfg: 预处理配置(包含尺寸/均值/标准差等)
    • is_train: 是否为训练模式(启用数据增强)
    • aug_cfg: 增强配置(如随机裁剪/颜色抖动)

PreprocessCfg配置类

@dataclass
class PreprocessCfg:
    size: Union[int, Tuple[int, int]] = 224
    mean: Tuple[float, ...] = OPENAI_DATASET_MEAN  # (0.481, 0.457, 0.408)
    std: Tuple[float, ...] = OPENAI_DATASET_STD    # (0.268, 0.261, 0.275)
    interpolation: str = 'bicubic'                 # 插值方式
    resize_mode: str = 'shortest'                  # 缩放模式:'shortest'/'longest'/'squash'

使用示例

# 创建评估预处理流水线
preprocess = image_transform_v2(
    cfg=PreprocessCfg(
        size=224,
        resize_mode="squash"
    ),
    is_train=False
)

# 处理图像
image = preprocess(Image.open("image.jpg")).unsqueeze(0).to("cuda")

2.2 文本分词接口

SimpleTokenizer

def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
  • 功能:将文本转换为模型输入令牌
  • 参数
    • context_length: 文本序列长度(超长截断,不足补0)

使用示例

tokenizer = SimpleTokenizer()
tokens = tokenizer(["a photo of a cat", "a picture of a dog"], context_length=77)
# tokens形状: (2, 77)

3. 预训练模型管理

3.1 预训练配置查询

list_pretrained_tags_by_model

def list_pretrained_tags_by_model(model: str) -> List[str]:
  • 功能:获取指定模型可用的预训练标签
  • 返回值:预训练标签列表

使用示例

# 查询ViT-L-14可用预训练版本
tags = list_pretrained_tags_by_model("ViT-L-14")
print(tags)
# ['openai', 'laion400m_e31', 'laion2b_s32b_b82k']

3.2 预训练权重下载

download_pretrained

def download_pretrained(
    cfg: Dict,
    cache_dir: Optional[str] = None
) -> str:
  • 功能:下载指定配置的预训练权重
  • 返回值:权重文件本地路径

4. 损失函数

4.1 ClipLoss

class ClipLoss(nn.Module):
    def __init__(
        self,
        local_loss: bool = False,
        gather_with_grad: bool = False,
        cache_labels: bool = False,
        rank: int = 0,
        world_size: int = 1
    ):
  • 核心参数
    • local_loss: 是否仅计算本地设备内样本损失
    • gather_with_grad: 梯度聚合模式(分布式训练)
  • 前向传播
def forward(
    self,
    image_features: torch.Tensor,
    text_features: torch.Tensor,
    logit_scale: torch.Tensor,
    logit_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:

4.2 SigLipLoss

class SigLipLoss(nn.Module):
    def __init__(
        self,
        cache_labels: bool = False,
        rank: int = 0,
        world_size: int = 1,
        dist_impl: str = "bidir"
    ):
  • 特点:Sigmoid交叉熵损失,支持多种分布式实现
  • dist_impl:分布式策略,支持"bidir"/"shift"/"gather"

5. 零样本分类工具

5.1 build_zero_shot_classifier

def build_zero_shot_classifier(
    model,
    tokenizer,
    classnames: Sequence[str],
    templates: Sequence[Union[Callable, str]],
    num_classes_per_batch: int = 10,
    device: str = 'cpu'
) -> torch.Tensor:
  • 功能:构建零样本分类器权重矩阵
  • 参数
    • classnames: 类别名称列表
    • templates: 文本模板列表(如["a photo of a {}"])
  • 返回值:形状为(d_model, num_classes)的分类权重矩阵

使用示例

# 构建ImageNet零样本分类器
classnames = ["cat", "dog", "bird"]
templates = ["a photo of a {}", "an image of a {}"]

zeroshot_weights = build_zero_shot_classifier(
    model, tokenizer, classnames, templates, device="cuda"
)

# 图像分类
image_features = model.encode_image(image, normalize=True)
logits = image_features @ zeroshot_weights
predictions = logits.argmax(dim=1)

6. 高级模型:CoCa

6.1 CoCa模型特点

  • 支持图像-文本对比学习和图像 captioning 任务
  • 继承CLIP基础架构,新增文本解码器

6.2 核心方法

generate

def generate(
    self,
    image,
    seq_len=30,
    generation_type="beam_search",
    num_beams=6,
    temperature=1.0
) -> torch.Tensor:
  • 功能:生成图像描述文本
  • 参数
    • generation_type: 生成策略:"beam_search"/"top_p"/"top_k"
    • num_beams: 束搜索宽度
    • temperature: 采样温度

使用示例

# 生成图像描述
image = preprocess(Image.open("image.jpg")).unsqueeze(0).to("cuda")
captions = model.generate(image, seq_len=20, num_beams=3)
# 解码令牌
decoded = tokenizer.decode(captions[0])

7. 模型配置与扩展

7.1 模型配置文件

模型配置文件位于src/open_clip/model_configs/目录,JSON格式包含:

  • embed_dim: 特征维度
  • vision_cfg: 视觉编码器配置
  • text_cfg: 文本编码器配置

7.2 自定义模型配置

# 加载自定义配置
model_cfg = json.load(open("custom_config.json"))
model = create_model(
    model_name="custom",
    pretrained=None,
    model_kwargs=model_cfg
)

8. 性能优化工具

8.1 混合精度转换

def convert_weights_to_lp(model: nn.Module, dtype=torch.float16) -> None:
  • 功能:将模型权重转换为低精度格式

8.2 梯度检查点

@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True) -> None:
  • 功能:启用/禁用梯度检查点(节省显存)
model.set_grad_checkpointing(True)  # 启用梯度检查点

9. 常用常量

常量 说明
OPENAI_DATASET_MEAN (0.481, 0.457, 0.408) CLIP默认图像均值
OPENAI_DATASET_STD (0.268, 0.261, 0.275) CLIP默认图像标准差
DEFAULT_CONTEXT_LENGTH 77 文本最大序列长度

10. 部署与集成

10.1 Hugging Face模型转换

push_to_hf_hub

def push_to_hf_hub(
    model,
    repo_id: str,
    tokenizer=None,
    preprocess_cfg: Dict = None
) -> None:
  • 功能:将模型推送到Hugging Face Hub

10.2 ONNX导出

torch.onnx.export(
    model, (image_input, text_input),
    "clip.onnx",
    input_names=["image", "text"],
    output_names=["image_features", "text_features"]
)

附录:模型架构速查表

模型系列 特点 应用场景
ViT-B-32 基础视觉Transformer 通用图像-文本任务
ViT-L-14 大型视觉Transformer 高精度需求场景
RN50 ResNet-50视觉塔 计算资源有限时
SigLIP系列 改进对比损失 多语言任务
CoCa系列 带文本解码器 图像描述生成
登录后查看全文
热门项目推荐
相关项目推荐