首页
/ TensorFlow Workshop项目:基于低阶API构建深度神经网络实现MNIST分类

TensorFlow Workshop项目:基于低阶API构建深度神经网络实现MNIST分类

2025-07-05 04:40:30作者:董宙帆

前言

在TensorFlow Workshop项目中,深度神经网络(DNN)的实现是一个重要环节。本文将详细介绍如何使用TensorFlow的低阶API构建一个包含两个隐藏层的全连接神经网络,并在MNIST手写数字数据集上实现分类任务。相比简单的线性分类器,深度神经网络能够捕捉更复杂的特征关系,获得更高的分类准确率。

环境准备与数据加载

首先需要导入必要的库和模块:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

初始化TensorFlow会话和计算图:

tf.reset_default_graph()
sess = tf.Session()

加载MNIST数据集,TensorFlow提供了便捷的接口:

mnist = input_data.read_data_sets('/tmp/data', one_hot=True)

模型参数配置

定义网络结构和训练参数:

# 隐藏层神经元数量
HIDDEN1_SIZE = 500  # 第一隐藏层
HIDDEN2_SIZE = 250  # 第二隐藏层

NUM_CLASSES = 10    # 输出类别数(0-9)
NUM_PIXELS = 28 * 28 # 输入像素数

# 训练参数
TRAIN_STEPS = 2000  # 训练步数
BATCH_SIZE = 100    # 批大小
LEARNING_RATE = 0.001 # 学习率

网络结构设计

输入层定义

使用name_scope组织输入层,便于TensorBoard可视化:

with tf.name_scope('input'):
    images = tf.placeholder(tf.float32, [None, NUM_PIXELS], name="pixels")
    labels = tf.placeholder(tf.float32, [None, NUM_CLASSES], name="labels")

全连接层函数

定义一个通用的全连接层创建函数,支持不同的激活函数:

def fc_layer(input, size_out, name="fc", activation=None):
    with tf.name_scope(name):
        size_in = int(input.shape[1])
        w = tf.Variable(tf.truncated_normal([size_in, size_out], stddev=0.1), name="weights")
        b = tf.Variable(tf.constant(0.1, shape=[size_out]), name="bias")
        wx_plus_b = tf.matmul(input, w) + b
        if activation: return activation(wx_plus_b)
        return wx_plus_b

网络架构构建

构建包含两个隐藏层的深度神经网络:

# 第一隐藏层(ReLU激活)
fc1 = fc_layer(images, HIDDEN1_SIZE, "fc1", activation=tf.nn.relu)

# 第二隐藏层(ReLU激活)
fc2 = fc_layer(fc1, HIDDEN2_SIZE, "fc2", activation=tf.nn.relu)

# Dropout层(防止过拟合)
dropped = tf.nn.dropout(fc2, keep_prob=0.9)

# 输出层(无激活函数)
y = fc_layer(dropped, NUM_CLASSES, name="output")

训练配置

损失函数

使用交叉熵损失函数:

with tf.name_scope("loss"):
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=labels))
    tf.summary.scalar('loss', loss)  # 记录损失值

优化器

采用Adam优化器而非普通的梯度下降:

with tf.name_scope("optimizer"):
    train = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

评估指标

定义准确率计算:

with tf.name_scope("evaluation"):
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar('accuracy', accuracy)  # 记录准确率

训练过程

日志记录配置

设置TensorBoard日志记录:

LOGDIR = './graphs'
train_writer = tf.summary.FileWriter(os.path.join(LOGDIR, "train"))
train_writer.add_graph(sess.graph)
test_writer = tf.summary.FileWriter(os.path.join(LOGDIR, "test"))
summary_op = tf.summary.merge_all()

模型训练

执行训练循环:

sess.run(tf.global_variables_initializer())

for step in range(TRAIN_STEPS):
    batch_xs, batch_ys = mnist.train.next_batch(BATCH_SIZE)
    summary_result, _ = sess.run([summary_op, train], 
                               feed_dict={images: batch_xs, labels: batch_ys})
    
    train_writer.add_summary(summary_result, step)
    
    # 每100步在测试集上评估一次
    if step % 100 == 0:
        summary_result, acc = sess.run([summary_op, accuracy], 
                                     feed_dict={images: mnist.test.images, 
                                                labels: mnist.test.labels})
        test_writer.add_summary(summary_result, step)
        print("test accuracy: %f at step %d" % (acc, step))

# 最终测试准确率
print("Accuracy %f" % sess.run(accuracy, 
                             feed_dict={images: mnist.test.images,
                                       labels: mnist.test.labels}))
train_writer.close()
test_writer.close()

模型优化技巧

  1. 权重初始化:使用截断正态分布初始化权重,标准差设为0.1,避免梯度消失或爆炸。

  2. 激活函数:隐藏层使用ReLU激活函数,相比Sigmoid或Tanh能有效缓解梯度消失问题。

  3. Dropout:在第二隐藏层后添加Dropout(keep_prob=0.9),随机丢弃10%的神经元,防止过拟合。

  4. 优化器选择:使用Adam优化器而非普通梯度下降,自适应调整学习率,通常能获得更好的收敛效果。

扩展练习

建议读者尝试以下扩展练习:

  1. 添加第三个隐藏层,观察模型性能变化
  2. 调整各隐藏层的神经元数量
  3. 尝试不同的Dropout比率
  4. 比较Adam优化器与SGD优化器的效果差异
  5. 在TensorBoard中可视化训练过程

通过本教程,读者可以掌握使用TensorFlow低阶API构建深度神经网络的基本方法,理解各组件的作用原理,为进一步学习更复杂的神经网络架构打下坚实基础。

登录后查看全文
热门项目推荐

项目优选

收起
kernelkernel
deepin linux kernel
C
22
6
docsdocs
OpenHarmony documentation | OpenHarmony开发者文档
Dockerfile
192
2.16 K
nop-entropynop-entropy
Nop Platform 2.0是基于可逆计算理论实现的采用面向语言编程范式的新一代低代码开发平台,包含基于全新原理从零开始研发的GraphQL引擎、ORM引擎、工作流引擎、报表引擎、规则引擎、批处理引引擎等完整设计。nop-entropy是它的后端部分,采用java语言实现,可选择集成Spring框架或者Quarkus框架。中小企业可以免费商用
Java
9
1
金融AI编程实战金融AI编程实战
为非计算机科班出身 (例如财经类高校金融学院) 同学量身定制,新手友好,让学生以亲身实践开源开发的方式,学会使用计算机自动化自己的科研/创新工作。案例以量化投资为主线,涉及 Bash、Python、SQL、BI、AI 等全技术栈,培养面向未来的数智化人才 (如数据工程师、数据分析师、数据科学家、数据决策者、量化投资人)。
Python
78
72
RuoYi-Vue3RuoYi-Vue3
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
971
572
ops-mathops-math
本项目是CANN提供的数学类基础计算算子库,实现网络在NPU上加速计算。
C++
548
76
Cangjie-ExamplesCangjie-Examples
本仓将收集和展示高质量的仓颉示例代码,欢迎大家投稿,让全世界看到您的妙趣设计,也让更多人通过您的编码理解和喜爱仓颉语言。
Cangjie
349
1.36 K
giteagitea
喝着茶写代码!最易用的自托管一站式代码托管平台,包含Git托管,代码审查,团队协作,软件包和CI/CD。
Go
17
0
ohos_react_nativeohos_react_native
React Native鸿蒙化仓库
C++
206
284
leetcodeleetcode
🔥LeetCode solutions in any programming language | 多种编程语言实现 LeetCode、《剑指 Offer(第 2 版)》、《程序员面试金典(第 6 版)》题解
Java
60
17