首页
/ 深入理解随机梯度下降法在机器学习中的应用——以GenTang/intro_ds项目为例

深入理解随机梯度下降法在机器学习中的应用——以GenTang/intro_ds项目为例

2025-06-29 18:02:46作者:裴麒琰

概述

随机梯度下降法(Stochastic Gradient Descent, SGD)是机器学习中最核心的优化算法之一。本文将基于GenTang/intro_ds项目中的实现,深入解析SGD的工作原理、实现细节以及在TensorFlow中的应用方式。

随机梯度下降法基础

什么是随机梯度下降法

随机梯度下降法是传统梯度下降法的变种,它通过每次迭代仅使用一小部分数据(称为mini-batch)来计算梯度并更新参数,而不是使用全部数据集。这种方法具有以下优势:

  1. 计算效率更高,特别适合大规模数据集
  2. 可以逃离局部极小值点
  3. 在线学习能力更强

与批量梯度下降的区别

  • 批量梯度下降(Batch GD):每次迭代使用全部数据计算梯度
  • 随机梯度下降(SGD):每次迭代使用单个样本或小批量样本计算梯度

代码实现解析

数据准备

dimension = 30
num = 10000
X, Y = generateLinearData(dimension, num)

这段代码生成了一个包含10000个样本、30个特征的自变量X和对应的因变量Y。这种规模的数据集非常适合展示SGD的优势。

模型定义

model = createLinearModel(dimension)

创建了一个线性模型,包含模型参数、损失函数(默认应为均方误差)、自变量和因变量的占位符。

SGD核心实现

def stochasticGradientDescent(X, Y, model, learningRate=0.01,
        miniBatchFraction=0.01, epoch=10000, tol=1.e-6):

函数参数说明:

  • learningRate: 学习率,控制参数更新步长
  • miniBatchFraction: 小批量数据占总数据的比例
  • epoch: 最大训练轮次
  • tol: 收敛阈值,当损失函数变化小于此值时停止训练

优化器设置

method = tf.train.GradientDescentOptimizer(learning_rate=learningRate)
optimizer = method.minimize(model["loss_function"])

使用TensorFlow的GradientDescentOptimizer作为优化器,目标是最小化损失函数。

训练过程

  1. 计算小批量大小:

    batchSize = int(X.shape[0] * miniBatchFraction)
    batchNum = int(math.ceil(1 / miniBatchFraction))
    
  2. 迭代训练:

    • 每次选取一个小批量数据
    • 运行优化器更新参数
    • 计算并记录损失函数值
    • 检查收敛条件

可视化支持

代码中集成了TensorBoard支持,可以可视化训练过程中的损失函数变化、参数分布等信息:

tf.summary.scalar("loss_function", model["loss_function"])
tf.summary.histogram("params", model["model_params"])

关键实现细节

小批量处理

batchX = X[i * batchSize: (i + 1) * batchSize]
batchY = Y[i * batchSize: (i + 1) * batchSize]

这种实现确保了:

  1. 每个epoch中所有数据都会被使用一次
  2. 数据是顺序选取的(实际应用中常会先打乱数据)

收敛判断

diff = abs(prevLoss - loss)
if diff <= tol:
    break

采用损失函数的变化量作为收敛标准,比单纯依靠epoch数更合理。

实际应用建议

  1. 学习率选择:可以从0.01开始尝试,根据训练情况调整
  2. 批量大小:miniBatchFraction=0.01表示使用1%的数据作为小批量
  3. 特征缩放:在实际应用中,应先对特征进行标准化处理
  4. 随机打乱:建议在每个epoch前打乱数据顺序

总结

通过GenTang/intro_ds项目中的这个实现,我们深入了解了随机梯度下降法的核心思想和实现细节。SGD作为深度学习的基础优化算法,其高效性和灵活性使其成为处理大规模数据集的理想选择。理解这个实现有助于我们在实际项目中更好地应用和调整优化算法。

登录后查看全文
热门项目推荐

最新内容推荐

项目优选

收起
Cangjie-ExamplesCangjie-Examples
本仓将收集和展示高质量的仓颉示例代码,欢迎大家投稿,让全世界看到您的妙趣设计,也让更多人通过您的编码理解和喜爱仓颉语言。
Cangjie
338
1.19 K
RuoYi-Vue3RuoYi-Vue3
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
898
534
ohos_react_nativeohos_react_native
React Native鸿蒙化仓库
C++
188
265
kernelkernel
deepin linux kernel
C
22
6
openGauss-serveropenGauss-server
openGauss kernel ~ openGauss is an open source relational database management system
C++
140
188
openHiTLSopenHiTLS
旨在打造算法先进、性能卓越、高效敏捷、安全可靠的密码套件,通过轻量级、可剪裁的软件技术架构满足各行业不同场景的多样化要求,让密码技术应用更简单,同时探索后量子等先进算法创新实践,构建密码前沿技术底座!
C
374
387
CangjieCommunityCangjieCommunity
为仓颉编程语言开发者打造活跃、开放、高质量的社区环境
Markdown
1.09 K
0
note-gennote-gen
一款跨平台的 Markdown AI 笔记软件,致力于使用 AI 建立记录和写作的桥梁。
TSX
86
4
nop-entropynop-entropy
Nop Platform 2.0是基于可逆计算理论实现的采用面向语言编程范式的新一代低代码开发平台,包含基于全新原理从零开始研发的GraphQL引擎、ORM引擎、工作流引擎、报表引擎、规则引擎、批处理引引擎等完整设计。nop-entropy是它的后端部分,采用java语言实现,可选择集成Spring框架或者Quarkus框架。中小企业可以免费商用
Java
7
0
arkanalyzerarkanalyzer
方舟分析器:面向ArkTS语言的静态程序分析框架
TypeScript
114
45