首页
/ 使用Cirkit构建和训练概率电路模型:MNIST图像分布学习实战

使用Cirkit构建和训练概率电路模型:MNIST图像分布学习实战

2025-06-02 08:24:03作者:盛欣凯Ernestine

引言

概率电路(Probabilistic Circuits)是一种强大的概率图模型表示方法,它通过分层结构组合简单分布来建模复杂概率分布。本文将介绍如何使用Cirkit工具包构建、训练和评估一个专门用于MNIST手写数字图像分布建模的概率电路模型。

概率电路基础概念

概率电路由三种基本层组成:

  1. 输入层:直接建模原始变量的分布(如像素值)
  2. 乘积层:对输入进行分解组合
  3. 求和层:混合不同组件

与传统神经网络不同,概率电路具有明确的可解释性结构,并且能够保证某些重要性质如可分解性和平滑性。

构建符号电路

在Cirkit中,我们首先构建一个符号电路,这是一种抽象表示,定义了电路的结构和连接方式,但不包含具体参数。

from cirkit.templates import data_modalities, utils

symbolic_circuit = data_modalities.image_data(
    (1, 28, 28),                # MNIST图像形状(通道数,高,宽)
    region_graph='quad-graph',   # 使用四叉树区域图结构
    input_layer='categorical',   # 输入层使用分类分布
    num_input_units=64,          # 每个输入层64个单元
    sum_product_layer='cp',       # 使用CP(张量积)求和-乘积层
    num_sum_units=64,            # 每个求和层64个单元
    sum_weight_param=utils.Parameterization(
        activation='softmax',    # 使用softmax激活
        initialization='normal'  # 权重正态初始化
    )
)

这种结构特别适合图像数据,因为它通过区域图利用了像素间的空间局部性。

电路编译与参数初始化

符号电路需要编译为可执行的PyTorch模块才能进行训练和推理:

from cirkit.pipeline import compile
circuit = compile(symbolic_circuit)

编译后的电路实际上是一个PyTorch模型,包含约2500万可学习参数。我们可以检查电路的结构性质:

print(f'结构性质:')
print(f'平滑性: {circuit.is_smooth}')      # 保证边缘化计算正确
print(f'可分解性: {circuit.is_decomposable}') # 保证高效精确推理

MNIST数据准备

我们使用标准MNIST数据集,并进行适当预处理:

from torchvision import transforms, datasets

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: (255 * x.view(-1)).long()) # 展平并缩放像素值
])

data_train = datasets.MNIST('data', train=True, transform=transform)
data_test = datasets.MNIST('data', train=False, transform=transform)

模型训练

训练过程采用标准的PyTorch训练流程,使用负对数似然作为损失函数:

optimizer = optim.Adam(circuit.parameters(), lr=0.01)

for epoch in range(10):
    for batch, _ in train_loader:
        batch = batch.to(device)
        log_likelihoods = circuit(batch)
        loss = -log_likelihoods.mean()
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

训练过程中可以观察到负对数似然(NLL)稳步下降,表明模型正在学习数据分布。

模型评估

我们使用测试集评估模型性能,计算两个关键指标:

  1. 平均对数似然(LL):衡量模型对测试数据的拟合程度
  2. 每维比特数(bpd):将似然转换为信息论单位,便于比较
with torch.no_grad():
    test_ll = 0.0
    for batch, _ in test_loader:
        batch = batch.to(device)
        test_ll += circuit(batch).sum().item()
    
    avg_ll = test_ll / len(data_test)
    bpd = -avg_ll / (28*28*np.log(2)) # 转换为比特/维度

典型的结果可能在1.25 bpd左右,这与简单的生成模型性能相当,但概率电路具有更好的可解释性。

高级应用

训练好的概率电路可以支持多种推理任务:

  1. 密度估计:计算任意图像的概率密度
  2. 边缘查询:计算部分像素的边际分布
  3. 采样:生成新的图像样本
  4. 条件推理:在已知部分像素时推断其余部分

例如,进行条件推理的代码可能如下:

# 假设我们已知前100个像素的值
evidence = torch.zeros(784)
evidence[:100] = observed_values

# 计算条件分布
marginal = circuit.marginal(evidence)

结论

通过Cirkit构建的概率电路为图像分布建模提供了一种结构清晰、理论保证的方法。虽然其性能可能不及最先进的深度生成模型,但在可解释性、精确推理能力方面具有独特优势。这种技术特别适合需要可靠概率估计的应用场景,如异常检测、不确定性量化等。

未来工作可以探索更复杂的电路结构、结合深度学习组件,或者应用于更高分辨率的图像数据。概率电路为概率建模提供了一个富有前景的研究方向。

登录后查看全文
热门项目推荐

热门内容推荐

最新内容推荐

项目优选

收起
docsdocs
OpenHarmony documentation | OpenHarmony开发者文档
Dockerfile
144
1.93 K
kernelkernel
deepin linux kernel
C
22
6
ohos_react_nativeohos_react_native
React Native鸿蒙化仓库
C++
192
274
openGauss-serveropenGauss-server
openGauss kernel ~ openGauss is an open source relational database management system
C++
145
189
RuoYi-Vue3RuoYi-Vue3
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
930
553
nop-entropynop-entropy
Nop Platform 2.0是基于可逆计算理论实现的采用面向语言编程范式的新一代低代码开发平台,包含基于全新原理从零开始研发的GraphQL引擎、ORM引擎、工作流引擎、报表引擎、规则引擎、批处理引引擎等完整设计。nop-entropy是它的后端部分,采用java语言实现,可选择集成Spring框架或者Quarkus框架。中小企业可以免费商用
Java
8
0
openHiTLSopenHiTLS
旨在打造算法先进、性能卓越、高效敏捷、安全可靠的密码套件,通过轻量级、可剪裁的软件技术架构满足各行业不同场景的多样化要求,让密码技术应用更简单,同时探索后量子等先进算法创新实践,构建密码前沿技术底座!
C
423
392
金融AI编程实战金融AI编程实战
为非计算机科班出身 (例如财经类高校金融学院) 同学量身定制,新手友好,让学生以亲身实践开源开发的方式,学会使用计算机自动化自己的科研/创新工作。案例以量化投资为主线,涉及 Bash、Python、SQL、BI、AI 等全技术栈,培养面向未来的数智化人才 (如数据工程师、数据分析师、数据科学家、数据决策者、量化投资人)。
Jupyter Notebook
75
66
CangjieCommunityCangjieCommunity
为仓颉编程语言开发者打造活跃、开放、高质量的社区环境
Markdown
1.11 K
0
openHiTLS-examplesopenHiTLS-examples
本仓将为广大高校开发者提供开源实践和创新开发平台,收集和展示openHiTLS示例代码及创新应用,欢迎大家投稿,让全世界看到您的精巧密码实现设计,也让更多人通过您的优秀成果,理解、喜爱上密码技术。
C
64
511