首页
/ 使用ObjAX框架实现MNIST手写数字识别教程

使用ObjAX框架实现MNIST手写数字识别教程

2025-06-19 12:29:05作者:伍希望

引言

ObjAX是一个基于JAX的深度学习框架,提供了类似于PyTorch的API设计风格。本教程将带领读者使用ObjAX框架构建和训练一个卷积神经网络(CNN)模型,用于MNIST手写数字识别任务。

环境准备与数据加载

首先需要导入必要的Python库:

import numpy as np
import jax
import jax.numpy as jn
import objax
import matplotlib.pyplot as plt

MNIST数据集可以通过Keras直接加载:

import tensorflow.keras as keras
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

ObjAX与PyTorch类似,采用通道优先的数据格式(CHW),因此我们需要调整数据维度并将像素值归一化到[0,1]范围:

x_train = x_train[:,None,:,:]/255.0
x_test = x_test[:,None,:,:]/255.0

模型构建

我们首先构建一个简单的CNN模型,包含三个卷积层和ReLU激活函数:

def conv_relu_pool(in_layers, out_layers, pool=True):
    ops = [objax.nn.Conv2D(in_layers, out_layers, 5),
            objax.functional.relu]
    if pool:
        ops.append(lambda x: objax.functional.average_pool_2d(x, size=2, strides=1))
    return ops

model = objax.nn.Sequential(conv_relu_pool(1, 32) + 
                          conv_relu_pool(32, 32) + 
                          conv_relu_pool(32, 64) + 
                          [objax.nn.Conv2D(64, 10, 1),
                           lambda x: x.mean((2,3))])

这个模型结构的特点是:

  1. 使用5x5的卷积核
  2. 每层后接ReLU激活
  3. 使用平均池化降采样
  4. 最后使用1x1卷积将特征图转换为10类输出

训练流程

优化器设置

使用Adam优化器:

opt = objax.optimizer.Adam(model.vars())

预测函数

使用JIT编译加速预测过程:

predict = objax.Jit(lambda x: objax.functional.softmax(model(x)), model.vars())

损失函数

定义交叉熵损失函数:

def loss(x, label):
    logit = model(x)
    return objax.functional.loss.cross_entropy_logits_sparse(logit, label).mean()

梯度计算与训练操作

gv = objax.GradValues(loss, model.vars())

def train_op_nojit(x, y):
    g, v = gv(x, y)
    opt(lr=0.002, grads=g)
    return v

train_op = objax.Jit(train_op_nojit, model.vars() + opt.vars())

模型训练

训练10个epoch:

def train_epoch(batch_size=50):    
    losses = []
    for x_batch, y_batch in zip(x_train.reshape((-1, batch_size, 1, 28, 28)), 
                             y_train.reshape((-1, batch_size))):
        losses.append(train_op(x_batch, y_batch))
    return np.mean(losses)

for i in range(10):
    print('loss', train_epoch())

模型评估

定义准确率计算函数:

def compute_accuracy():
    test_predictions = [predict(test_batch).argmax(1) 
                       for test_batch in x_test.reshape((-1, 50, 1, 28, 28))]
    return np.mean(y_test == np.array(test_predictions).flatten())

print('model accuracy', compute_accuracy())

更大模型的构建与训练

为了获得更好的性能,我们可以构建一个更大的模型:

model = objax.nn.Sequential(conv_relu_pool(1, 32, pool=False) + 
                          conv_relu_pool(32, 64) + 
                          conv_relu_pool(64, 64, pool=False) + 
                          conv_relu_pool(64, 128) + 
                          [objax.nn.Conv2D(128, 10, 1),
                           lambda x: x.mean((2,3))])

加入L1正则化防止过拟合:

def loss_with_wd(x, label):
    logit = model(x)
    xe_loss = objax.functional.loss.cross_entropy_logits_sparse(logit, label).mean()
    wd_loss = sum(jn.abs(v).sum() for k,v in model.vars().items() if k.endswith('.w'))
    return xe_loss + wd_loss * 1e-5

训练更大的模型:

for i in range(10):
    random_shuffle = np.arange(x_train.shape[0])
    np.random.shuffle(random_shuffle)
    x_train = x_train[random_shuffle]
    y_train = y_train[random_shuffle]
    print('loss', train_epoch(batch_size=200))

模型权重分析

可以分析模型权重的稀疏性:

for k,v in model.vars().items():
    if k.endswith('.w'):
        print("Small weight ratio on layer", k, (jn.abs(v) < 1e-2).mean())

总结

本教程展示了如何使用ObjAX框架:

  1. 加载和预处理MNIST数据集
  2. 构建CNN模型
  3. 设置优化器和损失函数
  4. 训练和评估模型
  5. 构建更大模型并分析权重

ObjAX结合了JAX的高性能自动微分和PyTorch风格的API设计,使得深度学习模型的开发和训练更加高效。通过本教程,读者可以掌握ObjAX的基本使用方法,并应用于自己的深度学习项目中。

登录后查看全文
热门项目推荐
相关项目推荐

项目优选

收起
kernelkernel
deepin linux kernel
C
27
11
docsdocs
OpenHarmony documentation | OpenHarmony开发者文档
Dockerfile
472
3.49 K
nop-entropynop-entropy
Nop Platform 2.0是基于可逆计算理论实现的采用面向语言编程范式的新一代低代码开发平台,包含基于全新原理从零开始研发的GraphQL引擎、ORM引擎、工作流引擎、报表引擎、规则引擎、批处理引引擎等完整设计。nop-entropy是它的后端部分,采用java语言实现,可选择集成Spring框架或者Quarkus框架。中小企业可以免费商用
Java
10
1
leetcodeleetcode
🔥LeetCode solutions in any programming language | 多种编程语言实现 LeetCode、《剑指 Offer(第 2 版)》、《程序员面试金典(第 6 版)》题解
Java
65
19
flutter_flutterflutter_flutter
暂无简介
Dart
719
173
giteagitea
喝着茶写代码!最易用的自托管一站式代码托管平台,包含Git托管,代码审查,团队协作,软件包和CI/CD。
Go
23
0
kernelkernel
openEuler内核是openEuler操作系统的核心,既是系统性能与稳定性的基石,也是连接处理器、设备与服务的桥梁。
C
213
86
RuoYi-Vue3RuoYi-Vue3
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
1.27 K
696
rainbondrainbond
无需学习 Kubernetes 的容器平台,在 Kubernetes 上构建、部署、组装和管理应用,无需 K8s 专业知识,全流程图形化管理
Go
15
1
apintoapinto
基于golang开发的网关。具有各种插件,可以自行扩展,即插即用。此外,它可以快速帮助企业管理API服务,提高API服务的稳定性和安全性。
Go
22
1