首页
/ 线性可分支持向量机(SVM)原理与实现 - 基于GenTang/intro_ds项目分析

线性可分支持向量机(SVM)原理与实现 - 基于GenTang/intro_ds项目分析

2025-06-29 00:14:16作者:农烁颖Land

引言

支持向量机(Support Vector Machine, SVM)是一种强大的监督学习算法,特别适用于分类任务。本文将通过GenTang/intro_ds项目中的线性可分SVM实现代码,深入讲解SVM的核心原理和实际应用。

线性可分SVM基础概念

线性可分SVM的核心思想是找到一个最优超平面,将不同类别的数据点分开,并使两类数据点到这个超平面的最小距离最大化。这个最优超平面被称为最大间隔超平面

关键术语解释

  1. 支持向量:距离超平面最近的样本点,决定了超平面的位置
  2. 间隔(margin):两个平行于超平面且分别经过正负类支持向量的平面之间的距离
  3. 决策边界:用于分类的超平面

代码实现解析

1. 数据生成

项目提供了两种数据生成函数:

def generateSeparatableData(n):
    # 生成线性可分数据
    np.random.seed(2046)
    X = np.r_[np.random.randn(n, 2) - [1, 1], np.random.randn(n, 2) + [3, 3]]
    Y = [[0]] * n + [[1]] * n
    data = np.concatenate((Y, X), axis=1)
    data = pd.DataFrame(data, columns=["y", "x1", "x2"])
    return data

def generateInseparatableData(n):
    # 生成线性不可分数据
    data = generateSeparatableData(n)
    inseparatable = [[1, -1, 1.5], [0, 3, 1]]
    inseparatable = pd.DataFrame(inseparatable, columns=["y", "x1", "x2"])
    data = data.append(inseparatable)
    return data
  • generateSeparatableData生成两组正态分布数据,分别以(-1,-1)和(3,3)为中心
  • generateInseparatableData在可分数据基础上添加两个异常点,使数据变得线性不可分

2. 模型训练

def trainModel(data):
    # 训练SVM模型
    model = SVC(C=1e4, kernel="linear")
    model.fit(data[["x1", "x2"]], data["y"])
    return model

这里使用了sklearnSVC类,关键参数:

  • C=1e4:设置很大的惩罚系数,使模型更倾向于找到完美分类的超平面
  • kernel="linear":使用线性核函数

3. 结果可视化

def visualize(data, model=None):
    # 可视化结果
    fig = plt.figure(figsize=(6, 6), dpi=80)
    ax = fig.add_subplot(1, 1, 1)
    # 绘制数据点
    label1 = data[data["y"]>0]
    ax.scatter(label1[["x1"]], label1[["x2"]], marker="o")
    label0 = data[data["y"]==0]
    ax.scatter(label0[["x1"]], label0[["x2"]], marker="^", color="k")
    
    if model is not None:
        # 绘制决策边界和间隔
        w = model.coef_
        a = -w[0][0] / w[0][1]
        xx = np.linspace(-3, 5)
        yy = a * xx - (model.intercept_) / w[0][1]
        yy_down = yy - 1 / w[0][1]
        yy_up = yy + 1 / w[0][1]
        ax.plot(xx, yy, "r")
        ax.plot(xx, yy_down, "r--")
        ax.plot(xx, yy_up, "r--")
    plt.show()

可视化部分展示了:

  • 不同类别数据点的分布(圆形和三角形)
  • 决策边界(红色实线)
  • 间隔边界(红色虚线)

数学原理深入

超平面方程

决策超平面可以表示为: w·x + b = 0

其中:

  • w是法向量,决定超平面的方向
  • b是位移项,决定超平面与原点的距离

间隔计算

两个间隔边界的方程分别为: w·x + b = 1 w·x + b = -1

间隔距离为:2/||w||

优化目标

SVM的优化目标是最大化间隔,等价于最小化||w||²/2: min 1/2 ||w||² s.t. y_i(w·x_i + b) ≥ 1, ∀i

实际应用中的注意事项

  1. 线性可分性:真实数据往往不是严格线性可分的,这时需要引入松弛变量
  2. 参数C的选择:C值越大对误分类的惩罚越大,可能导致过拟合
  3. 特征缩放:SVM对特征的尺度敏感,建议先进行标准化处理
  4. 核函数选择:线性不可分时可考虑使用非线性核函数

总结

通过GenTang/intro_ds项目中的线性可分SVM实现,我们学习了:

  1. SVM的基本原理和最大间隔思想
  2. 如何使用scikit-learn实现线性SVM
  3. 如何可视化和解释SVM模型结果
  4. SVM的数学基础和优化目标

线性可分SVM是理解更复杂SVM模型的基础,掌握这些概念对于后续学习非线性SVM和核方法至关重要。

登录后查看全文
热门项目推荐