Kolmogorov-Arnold Networks(KAN)入门教程:从原理到实践
认知篇:KAN网络的数学原理与核心优势
什么是KAN?
Kolmogorov-Arnold Networks(KAN)是一种基于数学理论构建的神经网络架构,它结合了Kolmogorov定理和Arnold的函数逼近思想,通过样条函数(可理解为平滑连接的曲线段)构建具有高度可解释性的非线性模型。与传统神经网络相比,KAN不仅能够精确拟合复杂函数,还能显式揭示输入与输出之间的数学关系。
图1:KAN模型的组成及其三大核心优势(数学性、准确性、可解释性)
KAN的数学基础
KAN的核心在于其激活函数设计,每个神经元的输出由两部分组成:
其中:
- 是基础函数(如SILU、线性函数等)
- 是B样条函数,通过网格点控制曲线形状
- 和 是可学习的尺度参数
B样条函数通过grid参数控制的网格点进行定义,k参数控制多项式阶数(通常为3,表示三次样条)。这种结构使KAN能够灵活捕捉数据中的非线性模式,同时保持数学可解释性。
KAN与传统神经网络对比
| 特性 | KAN | 传统神经网络(MLP) |
|---|---|---|
| 激活函数 | 自适应样条函数 | 固定非线性函数(ReLU等) |
| 可解释性 | 显式函数关系 | 黑箱模型 |
| 参数效率 | 高(少量参数实现高精度) | 低(需要大量参数) |
| 数学可解释性 | 支持符号化表达 | 不支持 |
| 网格自适应 | 动态调整采样点 | 无此机制 |
| 适用场景 | 科学计算、物理模拟 | 图像识别、语音处理 |
核心优势:KAN通过数学原理与神经网络的结合,在保持高精度的同时,解决了传统神经网络"黑箱"问题,特别适合需要模型解释性的科学研究领域。
实战篇:环境部署与基础案例实现
从零开始的KAN开发环境
1. 系统要求与前置条件
| 组件 | 最低版本要求 | 推荐版本 |
|---|---|---|
| Python | 3.6+ | 3.9.7+ |
| PyTorch | 1.10.0+ | 2.2.2 |
| 操作系统 | Windows 10 / macOS 10.15+ / Linux | 任意现代系统 |
2. 安装步骤(推荐源码安装)
# 克隆仓库
git clone https://gitcode.com/GitHub_Trending/pyk/pykan.git
cd pykan
# 创建并激活虚拟环境
python -m venv .venv
source .venv/bin/activate # Linux/macOS
# 或
.venv\Scripts\activate # Windows
# 安装依赖
pip install -e .
3. 环境验证
import torch
from kan import KAN
# 验证安装
print("pykan安装成功!")
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {torch.cuda.is_available()}")
基础案例:函数拟合
下面实现一个简单的函数拟合任务,使用KAN逼近 :
import torch
from kan import KAN
from kan.utils import create_dataset
# 1. 创建数据集
f = lambda x: torch.sin(x[:,[0]]) + x[:,[0]]**2
dataset = create_dataset(f, n_var=1, train_num=100, test_num=30)
# 2. 初始化KAN模型
model = KAN(width=[1, 4, 1], grid=5, k=3, device='cpu')
# 3. 训练模型
model.fit(dataset, opt="LBFGS", steps=50, lamb=0.001)
# 4. 评估结果
results = model.evaluate(dataset)
print(f"训练损失: {results['train_loss']:.6f}")
print(f"测试损失: {results['test_loss']:.6f}")
💡 提示:
width参数定义网络结构,如[1,4,1]表示1个输入神经元、4个隐藏神经元和1个输出神经元。grid参数控制样条函数的网格密度。
模型可视化
训练完成后,可视化KAN网络结构和激活函数:
# 绘制网络结构
model.plot(
beta=3,
in_vars=['x'],
out_vars=['f(x)'],
title="KAN函数拟合网络"
)
进阶篇:参数调优与性能提升策略
核心参数配置指南
KAN的性能很大程度上取决于参数配置,以下是关键参数的调优建议:
| 参数名称 | 作用 | 推荐范围 | 调整策略 |
|---|---|---|---|
grid |
控制样条网格密度 | 3-10 | 复杂函数增大网格 |
k |
样条多项式阶数 | 2-5 | 通常使用3(三次样条) |
lamb |
稀疏正则化系数 | 0.001-0.1 | 过拟合时增大 |
grid_eps |
网格自适应程度 | 0-1 | 0表示完全自适应 |
mult_arity |
乘法节点元数 | 2-4 | 增加可捕捉高阶交互 |
正则化策略详解
KAN提供多种正则化机制控制模型复杂度:
# 多正则化训练示例
model.fit(
dataset,
steps=100,
lamb=0.001, # 稀疏正则化
lamb_l1=0.5, # L1正则化
lamb_entropy=2.0 # 熵正则化
)
- 稀疏正则化(
lamb):控制整体连接稀疏度 - L1正则化(
lamb_l1):促进权重稀疏,简化模型 - 熵正则化(
lamb_entropy):平衡激活函数分布
自适应网格更新机制
KAN的独特优势在于其动态网格调整能力,通过update_grid参数启用:
model.fit(
dataset,
steps=150,
update_grid=True, # 启用网格更新
grid_update_num=10, # 更新次数
grid_eps=0.02 # 自适应程度
)
🔧 工具原理:网格更新通过分析数据分布和梯度信息,在函数变化剧烈区域增加网格点密度,在平缓区域减少网格点,实现精度与效率的平衡。
剪枝优化流程
训练后剪枝可以进一步简化模型,提高推理速度:
# 剪枝冗余连接和节点
model.prune(
node_th=1e-2, # 节点剪枝阈值
edge_th=3e-2 # 边剪枝阈值
)
# 剪枝后微调
model.fit(dataset, steps=30, lamb=0.0001)
应用篇:典型场景解决方案与代码示例
物理系统模拟
KAN在物理系统模拟中表现出色,能够精确捕捉物理规律。以下是黑洞引力时间延迟模拟的示例:
# 黑洞引力时间延迟模拟
from kan import KAN
import torch
# 定义物理方程(史瓦西解)
def blackhole_time_delay(r):
return -2 * torch.sqrt(r) - torch.log((torch.sqrt(r)-1)/(torch.sqrt(r)+1))
# 创建数据集
dataset = create_dataset(
blackhole_time_delay,
n_var=1,
ranges=[[1.25, 3.0]],
train_num=200
)
# 训练KAN模型
model = KAN(width=[1, 8, 1], grid=7, k=3)
model.fit(dataset, steps=200, lamb=0.001, update_grid=True)
# 可视化结果
model.plot(in_vars=['r'], out_vars=['Δt'])
图2:KAN模拟的黑洞引力时间延迟曲线(蓝色实线为模拟结果,黄色虚线为理论解)
特殊函数逼近
KAN特别适合逼近数学物理中的特殊函数,如相对论中的质能关系:
# 相对论质能关系逼近
def relativistic_mass(v, m0=1.0, c=3e8):
return m0 / torch.sqrt(1 - (v**2)/(c**2))
# 创建数据集
v = torch.linspace(0, 0.9*3e8, 100).unsqueeze(1)
m = relativistic_mass(v)
dataset = create_dataset_from_data(v, m, train_ratio=0.8)
# 训练模型
model = KAN(width=[1, 6, 1], grid=6, k=3, mult_arity=2)
model.fit(dataset, steps=150, lamb=0.005)
图3:逼近相对论质能关系的KAN网络结构,显示了输入变量间的乘法交互
模型部署:TensorRT转换
训练好的KAN模型可以转换为TensorRT格式以加速推理:
import torch
from kan import KAN
# 加载训练好的模型
model = KAN(width=[1, 8, 1])
model.load_state_dict(torch.load('kan_model.pth'))
model.eval()
# 导出为ONNX格式
dummy_input = torch.randn(1, 1)
torch.onnx.export(
model,
dummy_input,
"kan_model.onnx",
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
)
# 使用TensorRT转换(需安装TensorRT)
# trtexec --onnx=kan_model.onnx --saveEngine=kan_model.trt
💡 部署提示:对于边缘设备部署,可以进一步使用量化技术将模型精度从FP32降至FP16或INT8,通常可减少75%的模型大小而精度损失很小。
总结
KAN作为一种新兴的神经网络架构,通过融合数学理论与神经网络优势,在可解释性和准确性之间取得了平衡。本教程从KAN的基本原理出发,详细介绍了环境配置、基础实现、参数调优和实际应用,展示了KAN在科学计算、物理模拟等领域的独特价值。
随着研究的深入,KAN有望在更多领域发挥作用,特别是在需要结合物理先验知识和数据驱动方法的场景中。对于机器学习研究者和从业者来说,掌握KAN将为解决复杂问题提供新的思路和工具。
通过本文提供的代码示例和最佳实践,您可以快速开始KAN的实践之旅,探索这一创新架构带来的无限可能。
GLM-5智谱 AI 正式发布 GLM-5,旨在应对复杂系统工程和长时域智能体任务。Jinja00
GLM-5-w4a8GLM-5-w4a8基于混合专家架构,专为复杂系统工程与长周期智能体任务设计。支持单/多节点部署,适配Atlas 800T A3,采用w4a8量化技术,结合vLLM推理优化,高效平衡性能与精度,助力智能应用开发Jinja00
jiuwenclawJiuwenClaw 是一款基于openJiuwen开发的智能AI Agent,它能够将大语言模型的强大能力,通过你日常使用的各类通讯应用,直接延伸至你的指尖。Python0192- QQwen3.5-397B-A17BQwen3.5 实现了重大飞跃,整合了多模态学习、架构效率、强化学习规模以及全球可访问性等方面的突破性进展,旨在为开发者和企业赋予前所未有的能力与效率。Jinja00
AtomGit城市坐标计划AtomGit 城市坐标计划开启!让开源有坐标,让城市有星火。致力于与城市合伙人共同构建并长期运营一个健康、活跃的本地开发者生态。01
awesome-zig一个关于 Zig 优秀库及资源的协作列表。Makefile00