首页
/ TensorFlow Workshop项目:基于低阶API构建深度神经网络实现MNIST分类

TensorFlow Workshop项目:基于低阶API构建深度神经网络实现MNIST分类

2025-07-05 10:19:03作者:董宙帆

前言

在TensorFlow Workshop项目中,深度神经网络(DNN)的实现是一个重要环节。本文将详细介绍如何使用TensorFlow的低阶API构建一个包含两个隐藏层的全连接神经网络,并在MNIST手写数字数据集上实现分类任务。相比简单的线性分类器,深度神经网络能够捕捉更复杂的特征关系,获得更高的分类准确率。

环境准备与数据加载

首先需要导入必要的库和模块:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

初始化TensorFlow会话和计算图:

tf.reset_default_graph()
sess = tf.Session()

加载MNIST数据集,TensorFlow提供了便捷的接口:

mnist = input_data.read_data_sets('/tmp/data', one_hot=True)

模型参数配置

定义网络结构和训练参数:

# 隐藏层神经元数量
HIDDEN1_SIZE = 500  # 第一隐藏层
HIDDEN2_SIZE = 250  # 第二隐藏层

NUM_CLASSES = 10    # 输出类别数(0-9)
NUM_PIXELS = 28 * 28 # 输入像素数

# 训练参数
TRAIN_STEPS = 2000  # 训练步数
BATCH_SIZE = 100    # 批大小
LEARNING_RATE = 0.001 # 学习率

网络结构设计

输入层定义

使用name_scope组织输入层,便于TensorBoard可视化:

with tf.name_scope('input'):
    images = tf.placeholder(tf.float32, [None, NUM_PIXELS], name="pixels")
    labels = tf.placeholder(tf.float32, [None, NUM_CLASSES], name="labels")

全连接层函数

定义一个通用的全连接层创建函数,支持不同的激活函数:

def fc_layer(input, size_out, name="fc", activation=None):
    with tf.name_scope(name):
        size_in = int(input.shape[1])
        w = tf.Variable(tf.truncated_normal([size_in, size_out], stddev=0.1), name="weights")
        b = tf.Variable(tf.constant(0.1, shape=[size_out]), name="bias")
        wx_plus_b = tf.matmul(input, w) + b
        if activation: return activation(wx_plus_b)
        return wx_plus_b

网络架构构建

构建包含两个隐藏层的深度神经网络:

# 第一隐藏层(ReLU激活)
fc1 = fc_layer(images, HIDDEN1_SIZE, "fc1", activation=tf.nn.relu)

# 第二隐藏层(ReLU激活)
fc2 = fc_layer(fc1, HIDDEN2_SIZE, "fc2", activation=tf.nn.relu)

# Dropout层(防止过拟合)
dropped = tf.nn.dropout(fc2, keep_prob=0.9)

# 输出层(无激活函数)
y = fc_layer(dropped, NUM_CLASSES, name="output")

训练配置

损失函数

使用交叉熵损失函数:

with tf.name_scope("loss"):
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=labels))
    tf.summary.scalar('loss', loss)  # 记录损失值

优化器

采用Adam优化器而非普通的梯度下降:

with tf.name_scope("optimizer"):
    train = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

评估指标

定义准确率计算:

with tf.name_scope("evaluation"):
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar('accuracy', accuracy)  # 记录准确率

训练过程

日志记录配置

设置TensorBoard日志记录:

LOGDIR = './graphs'
train_writer = tf.summary.FileWriter(os.path.join(LOGDIR, "train"))
train_writer.add_graph(sess.graph)
test_writer = tf.summary.FileWriter(os.path.join(LOGDIR, "test"))
summary_op = tf.summary.merge_all()

模型训练

执行训练循环:

sess.run(tf.global_variables_initializer())

for step in range(TRAIN_STEPS):
    batch_xs, batch_ys = mnist.train.next_batch(BATCH_SIZE)
    summary_result, _ = sess.run([summary_op, train], 
                               feed_dict={images: batch_xs, labels: batch_ys})
    
    train_writer.add_summary(summary_result, step)
    
    # 每100步在测试集上评估一次
    if step % 100 == 0:
        summary_result, acc = sess.run([summary_op, accuracy], 
                                     feed_dict={images: mnist.test.images, 
                                                labels: mnist.test.labels})
        test_writer.add_summary(summary_result, step)
        print("test accuracy: %f at step %d" % (acc, step))

# 最终测试准确率
print("Accuracy %f" % sess.run(accuracy, 
                             feed_dict={images: mnist.test.images,
                                       labels: mnist.test.labels}))
train_writer.close()
test_writer.close()

模型优化技巧

  1. 权重初始化:使用截断正态分布初始化权重,标准差设为0.1,避免梯度消失或爆炸。

  2. 激活函数:隐藏层使用ReLU激活函数,相比Sigmoid或Tanh能有效缓解梯度消失问题。

  3. Dropout:在第二隐藏层后添加Dropout(keep_prob=0.9),随机丢弃10%的神经元,防止过拟合。

  4. 优化器选择:使用Adam优化器而非普通梯度下降,自适应调整学习率,通常能获得更好的收敛效果。

扩展练习

建议读者尝试以下扩展练习:

  1. 添加第三个隐藏层,观察模型性能变化
  2. 调整各隐藏层的神经元数量
  3. 尝试不同的Dropout比率
  4. 比较Adam优化器与SGD优化器的效果差异
  5. 在TensorBoard中可视化训练过程

通过本教程,读者可以掌握使用TensorFlow低阶API构建深度神经网络的基本方法,理解各组件的作用原理,为进一步学习更复杂的神经网络架构打下坚实基础。

登录后查看全文
热门项目推荐

热门内容推荐

最新内容推荐

项目优选

收起
ohos_react_nativeohos_react_native
React Native鸿蒙化仓库
C++
179
263
RuoYi-Vue3RuoYi-Vue3
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
871
515
openGauss-serveropenGauss-server
openGauss kernel ~ openGauss is an open source relational database management system
C++
131
184
openHiTLSopenHiTLS
旨在打造算法先进、性能卓越、高效敏捷、安全可靠的密码套件,通过轻量级、可剪裁的软件技术架构满足各行业不同场景的多样化要求,让密码技术应用更简单,同时探索后量子等先进算法创新实践,构建密码前沿技术底座!
C
346
380
Cangjie-ExamplesCangjie-Examples
本仓将收集和展示高质量的仓颉示例代码,欢迎大家投稿,让全世界看到您的妙趣设计,也让更多人通过您的编码理解和喜爱仓颉语言。
Cangjie
334
1.09 K
harmony-utilsharmony-utils
harmony-utils 一款功能丰富且极易上手的HarmonyOS工具库,借助众多实用工具类,致力于助力开发者迅速构建鸿蒙应用。其封装的工具涵盖了APP、设备、屏幕、授权、通知、线程间通信、弹框、吐司、生物认证、用户首选项、拍照、相册、扫码、文件、日志,异常捕获、字符、字符串、数字、集合、日期、随机、base64、加密、解密、JSON等一系列的功能和操作,能够满足各种不同的开发需求。
ArkTS
31
0
CangjieCommunityCangjieCommunity
为仓颉编程语言开发者打造活跃、开放、高质量的社区环境
Markdown
1.08 K
0
kernelkernel
deepin linux kernel
C
22
5
WxJavaWxJava
微信开发 Java SDK,支持微信支付、开放平台、公众号、视频号、企业微信、小程序等的后端开发,记得关注公众号及时接受版本更新信息,以及加入微信群进行深入讨论
Java
829
22
cherry-studiocherry-studio
🍒 Cherry Studio 是一款支持多个 LLM 提供商的桌面客户端
TypeScript
603
58