首页
/ 从零开始实现Softmax回归——d2l-ai/d2l-ko项目解析

从零开始实现Softmax回归——d2l-ai/d2l-ko项目解析

2025-06-04 05:29:54作者:何将鹤

引言

在机器学习中,分类问题是最常见的任务之一。Softmax回归(也称为多项逻辑回归)是解决多类分类问题的基本模型。本文将基于d2l-ai/d2l-ko项目中的实现,详细讲解如何从零开始构建Softmax回归模型。

数据准备

我们使用Fashion-MNIST数据集,它包含10个类别的服装图片,每个图片大小为28×28像素。首先设置批量大小为256的数据迭代器:

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

模型参数初始化

对于Softmax回归,我们需要为每个类别准备一组参数。由于输入是28×28=784像素的图片,输出是10个类别:

  1. 权重矩阵W:形状为784×10
  2. 偏置向量b:形状为1×10

我们使用正态分布初始化权重,偏置初始化为0:

num_inputs = 784
num_outputs = 10

W = np.random.normal(0, 0.01, (num_inputs, num_outputs))
b = np.zeros(num_outputs)

Softmax运算实现

Softmax函数将原始分数转换为概率分布,实现分为三步:

  1. 对每个元素取指数
  2. 计算每行的和(归一化常数)
  3. 将每行除以其归一化常数

数学表达式为:

softmax(X)ij=exp(Xij)kexp(Xik)\text{softmax}(X)_{ij} = \frac{\exp(X_{ij})}{\sum_k \exp(X_{ik})}

实现代码如下:

def softmax(X):
    X_exp = np.exp(X)
    partition = X_exp.sum(1, keepdims=True)
    return X_exp / partition  # 这里应用了广播机制

模型定义

模型将输入图片展平为向量,然后进行线性变换和Softmax运算:

def net(X):
    return softmax(np.dot(X.reshape(-1, W.shape[0]), W) + b)

损失函数:交叉熵

交叉熵损失是分类问题中最常用的损失函数。对于预测概率y_hat和真实标签y:

def cross_entropy(y_hat, y):
    return -np.log(y_hat[range(len(y_hat)), y])

评估指标:准确率

准确率是最直观的分类性能指标:

def accuracy(y_hat, y):
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.astype(y.dtype) == y
    return float(cmp.astype(y.dtype).sum())

训练过程

训练循环包括:

  1. 前向传播计算预测
  2. 计算损失
  3. 反向传播计算梯度
  4. 更新参数

我们使用小批量随机梯度下降进行优化:

lr = 0.1

def updater(batch_size):
    return d2l.sgd([W, b], lr, batch_size)

num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

预测与评估

训练完成后,我们可以用模型进行预测:

def predict_ch3(net, test_iter, n=6):
    for X, y in test_iter:
        break
    trues = d2l.get_fashion_mnist_labels(y)
    preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [f"{true}\n{pred}" for true, pred in zip(trues, preds)]
    d2l.show_images(X[0:n].reshape(n, 28, 28), 1, n, titles=titles[0:n])

predict_ch3(net, test_iter)

数值稳定性问题

在实际实现中,直接计算Softmax可能会遇到数值不稳定的问题:

  1. 当输入值很大时,exp(x)可能导致数值溢出
  2. 当输入值很小时,exp(x)可能导致数值下溢

解决方案通常是在计算Softmax前减去最大值:

def softmax(X):
    X_exp = np.exp(X - X.max(1, keepdims=True))
    return X_exp / X_exp.sum(1, keepdims=True)

总结

本文详细介绍了从零实现Softmax回归的关键步骤:

  1. 数据准备和预处理
  2. 模型参数初始化
  3. Softmax运算实现
  4. 交叉熵损失函数
  5. 训练循环和优化
  6. 预测和评估

Softmax回归虽然简单,但包含了深度学习模型的基本要素,是理解更复杂神经网络的基础。

登录后查看全文
热门项目推荐

项目优选

收起
kernelkernel
deepin linux kernel
C
22
6
docsdocs
OpenHarmony documentation | OpenHarmony开发者文档
Dockerfile
161
2.05 K
nop-entropynop-entropy
Nop Platform 2.0是基于可逆计算理论实现的采用面向语言编程范式的新一代低代码开发平台,包含基于全新原理从零开始研发的GraphQL引擎、ORM引擎、工作流引擎、报表引擎、规则引擎、批处理引引擎等完整设计。nop-entropy是它的后端部分,采用java语言实现,可选择集成Spring框架或者Quarkus框架。中小企业可以免费商用
Java
8
0
openGauss-serveropenGauss-server
openGauss kernel ~ openGauss is an open source relational database management system
C++
146
191
leetcodeleetcode
🔥LeetCode solutions in any programming language | 多种编程语言实现 LeetCode、《剑指 Offer(第 2 版)》、《程序员面试金典(第 6 版)》题解
Java
60
16
ohos_react_nativeohos_react_native
React Native鸿蒙化仓库
C++
198
279
apintoapinto
基于golang开发的网关。具有各种插件,可以自行扩展,即插即用。此外,它可以快速帮助企业管理API服务,提高API服务的稳定性和安全性。
Go
22
0
RuoYi-Vue3RuoYi-Vue3
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
949
556
ShopXO开源商城ShopXO开源商城
🔥🔥🔥ShopXO企业级免费开源商城系统,可视化DIY拖拽装修、包含PC、H5、多端小程序(微信+支付宝+百度+头条&抖音+QQ+快手)、APP、多仓库、多商户、多门店、IM客服、进销存,遵循MIT开源协议发布、基于ThinkPHP8框架研发
JavaScript
96
15
Cangjie-ExamplesCangjie-Examples
本仓将收集和展示高质量的仓颉示例代码,欢迎大家投稿,让全世界看到您的妙趣设计,也让更多人通过您的编码理解和喜爱仓颉语言。
Cangjie
346
1.33 K