首页
/ 使用ObjAX框架实现MNIST手写数字识别教程

使用ObjAX框架实现MNIST手写数字识别教程

2025-06-19 17:38:42作者:伍希望

引言

ObjAX是一个基于JAX的深度学习框架,提供了类似于PyTorch的API设计风格。本教程将带领读者使用ObjAX框架构建和训练一个卷积神经网络(CNN)模型,用于MNIST手写数字识别任务。

环境准备与数据加载

首先需要导入必要的Python库:

import numpy as np
import jax
import jax.numpy as jn
import objax
import matplotlib.pyplot as plt

MNIST数据集可以通过Keras直接加载:

import tensorflow.keras as keras
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

ObjAX与PyTorch类似,采用通道优先的数据格式(CHW),因此我们需要调整数据维度并将像素值归一化到[0,1]范围:

x_train = x_train[:,None,:,:]/255.0
x_test = x_test[:,None,:,:]/255.0

模型构建

我们首先构建一个简单的CNN模型,包含三个卷积层和ReLU激活函数:

def conv_relu_pool(in_layers, out_layers, pool=True):
    ops = [objax.nn.Conv2D(in_layers, out_layers, 5),
            objax.functional.relu]
    if pool:
        ops.append(lambda x: objax.functional.average_pool_2d(x, size=2, strides=1))
    return ops

model = objax.nn.Sequential(conv_relu_pool(1, 32) + 
                          conv_relu_pool(32, 32) + 
                          conv_relu_pool(32, 64) + 
                          [objax.nn.Conv2D(64, 10, 1),
                           lambda x: x.mean((2,3))])

这个模型结构的特点是:

  1. 使用5x5的卷积核
  2. 每层后接ReLU激活
  3. 使用平均池化降采样
  4. 最后使用1x1卷积将特征图转换为10类输出

训练流程

优化器设置

使用Adam优化器:

opt = objax.optimizer.Adam(model.vars())

预测函数

使用JIT编译加速预测过程:

predict = objax.Jit(lambda x: objax.functional.softmax(model(x)), model.vars())

损失函数

定义交叉熵损失函数:

def loss(x, label):
    logit = model(x)
    return objax.functional.loss.cross_entropy_logits_sparse(logit, label).mean()

梯度计算与训练操作

gv = objax.GradValues(loss, model.vars())

def train_op_nojit(x, y):
    g, v = gv(x, y)
    opt(lr=0.002, grads=g)
    return v

train_op = objax.Jit(train_op_nojit, model.vars() + opt.vars())

模型训练

训练10个epoch:

def train_epoch(batch_size=50):    
    losses = []
    for x_batch, y_batch in zip(x_train.reshape((-1, batch_size, 1, 28, 28)), 
                             y_train.reshape((-1, batch_size))):
        losses.append(train_op(x_batch, y_batch))
    return np.mean(losses)

for i in range(10):
    print('loss', train_epoch())

模型评估

定义准确率计算函数:

def compute_accuracy():
    test_predictions = [predict(test_batch).argmax(1) 
                       for test_batch in x_test.reshape((-1, 50, 1, 28, 28))]
    return np.mean(y_test == np.array(test_predictions).flatten())

print('model accuracy', compute_accuracy())

更大模型的构建与训练

为了获得更好的性能,我们可以构建一个更大的模型:

model = objax.nn.Sequential(conv_relu_pool(1, 32, pool=False) + 
                          conv_relu_pool(32, 64) + 
                          conv_relu_pool(64, 64, pool=False) + 
                          conv_relu_pool(64, 128) + 
                          [objax.nn.Conv2D(128, 10, 1),
                           lambda x: x.mean((2,3))])

加入L1正则化防止过拟合:

def loss_with_wd(x, label):
    logit = model(x)
    xe_loss = objax.functional.loss.cross_entropy_logits_sparse(logit, label).mean()
    wd_loss = sum(jn.abs(v).sum() for k,v in model.vars().items() if k.endswith('.w'))
    return xe_loss + wd_loss * 1e-5

训练更大的模型:

for i in range(10):
    random_shuffle = np.arange(x_train.shape[0])
    np.random.shuffle(random_shuffle)
    x_train = x_train[random_shuffle]
    y_train = y_train[random_shuffle]
    print('loss', train_epoch(batch_size=200))

模型权重分析

可以分析模型权重的稀疏性:

for k,v in model.vars().items():
    if k.endswith('.w'):
        print("Small weight ratio on layer", k, (jn.abs(v) < 1e-2).mean())

总结

本教程展示了如何使用ObjAX框架:

  1. 加载和预处理MNIST数据集
  2. 构建CNN模型
  3. 设置优化器和损失函数
  4. 训练和评估模型
  5. 构建更大模型并分析权重

ObjAX结合了JAX的高性能自动微分和PyTorch风格的API设计,使得深度学习模型的开发和训练更加高效。通过本教程,读者可以掌握ObjAX的基本使用方法,并应用于自己的深度学习项目中。

登录后查看全文
热门项目推荐

热门内容推荐

最新内容推荐

项目优选

收起
ohos_react_nativeohos_react_native
React Native鸿蒙化仓库
C++
179
263
RuoYi-Vue3RuoYi-Vue3
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
871
515
openGauss-serveropenGauss-server
openGauss kernel ~ openGauss is an open source relational database management system
C++
131
184
openHiTLSopenHiTLS
旨在打造算法先进、性能卓越、高效敏捷、安全可靠的密码套件,通过轻量级、可剪裁的软件技术架构满足各行业不同场景的多样化要求,让密码技术应用更简单,同时探索后量子等先进算法创新实践,构建密码前沿技术底座!
C
346
380
Cangjie-ExamplesCangjie-Examples
本仓将收集和展示高质量的仓颉示例代码,欢迎大家投稿,让全世界看到您的妙趣设计,也让更多人通过您的编码理解和喜爱仓颉语言。
Cangjie
334
1.09 K
harmony-utilsharmony-utils
harmony-utils 一款功能丰富且极易上手的HarmonyOS工具库,借助众多实用工具类,致力于助力开发者迅速构建鸿蒙应用。其封装的工具涵盖了APP、设备、屏幕、授权、通知、线程间通信、弹框、吐司、生物认证、用户首选项、拍照、相册、扫码、文件、日志,异常捕获、字符、字符串、数字、集合、日期、随机、base64、加密、解密、JSON等一系列的功能和操作,能够满足各种不同的开发需求。
ArkTS
31
0
CangjieCommunityCangjieCommunity
为仓颉编程语言开发者打造活跃、开放、高质量的社区环境
Markdown
1.08 K
0
kernelkernel
deepin linux kernel
C
22
5
WxJavaWxJava
微信开发 Java SDK,支持微信支付、开放平台、公众号、视频号、企业微信、小程序等的后端开发,记得关注公众号及时接受版本更新信息,以及加入微信群进行深入讨论
Java
829
22
cherry-studiocherry-studio
🍒 Cherry Studio 是一款支持多个 LLM 提供商的桌面客户端
TypeScript
603
58