pykan实战指南:从零开始构建你的第一个KAN模型
本文是一份详细的pykan实战指南,全面介绍了从环境配置到模型训练评估的完整流程。文章首先详细讲解了pykan开发环境的配置方法,包括系统要求、多种安装方式(PyPI、GitHub源码、Conda)以及依赖包管理。然后深入解析了KAN模型的初始化参数配置,包括网络结构、网格设置、激活函数选择等核心参数。接着介绍了数据集创建与预处理的最佳实践,涵盖合成数据生成、现有数据处理、归一化策略等。最后详细阐述了模型训练、评估与可视化的完整流程,包括网格自适应机制、正则化策略、剪枝优化和性能监控等重要内容。
环境配置与依赖安装详细步骤
在开始使用pykan构建Kolmogorov-Arnold Networks之前,确保您拥有一个稳定且兼容的开发环境至关重要。本节将详细介绍从零开始配置pykan开发环境的完整流程,涵盖多种安装方式和环境管理策略。
系统要求与前置条件
pykan对运行环境有明确的要求,以下是必须满足的基本条件:
| 组件 | 最低版本要求 | 推荐版本 |
|---|---|---|
| Python | 3.6+ | 3.9.7+ |
| pip | 最新版本 | 最新版本 |
| 操作系统 | Windows 10 / macOS 10.15+ / Linux | 任意现代系统 |
方法一:使用PyPI安装(推荐初学者)
对于大多数用户,通过PyPI安装是最简单快捷的方式:
# 创建并激活虚拟环境(推荐)
python -m venv pykan-env
source pykan-env/bin/activate # Linux/macOS
# 或
pykan-env\Scripts\activate # Windows
# 安装pykan
pip install pykan
安装过程会自动处理所有依赖关系,包括:
torch==2.2.2- PyTorch深度学习框架numpy==1.24.4- 数值计算库matplotlib==3.6.2- 数据可视化scikit-learn==1.1.3- 机器学习工具- 以及其他必要的科学计算库
方法二:从GitHub源码安装(开发者模式)
如果您需要最新功能或希望贡献代码,建议从源码安装:
# 克隆仓库
git clone https://gitcode.com/gh_mirrors/pyk/pykan.git
cd pykan
# 创建虚拟环境
python -m venv .venv
source .venv/bin/activate
# 安装开发模式依赖
pip install -e .
这种安装方式允许您直接修改源代码并立即看到效果,非常适合开发和调试。
方法三:使用Conda环境管理
对于习惯使用Anaconda/Miniconda的用户:
# 创建Conda环境
conda create --name pykan-env python=3.9.7
conda activate pykan-env
# 安装PyTorch(根据您的CUDA版本选择)
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
# 安装其他依赖
pip install matplotlib numpy scikit-learn sympy tqdm pandas seaborn pyyaml
# 安装pykan
pip install pykan
依赖包详细说明
pykan的核心依赖关系如下表所示:
| 包名称 | 版本 | 用途描述 |
|---|---|---|
| torch | 2.2.2 | 深度学习框架核心 |
| numpy | 1.24.4 | 数值计算和数组操作 |
| matplotlib | 3.6.2 | 数据可视化和绘图 |
| scikit-learn | 1.1.3 | 机器学习工具和评估指标 |
| sympy | 1.11.1 | 符号数学计算 |
| tqdm | 4.66.2 | 进度条显示 |
| pandas | 2.0.1 | 数据处理和分析 |
| seaborn | 最新 | 统计数据可视化 |
| pyyaml | 最新 | YAML配置文件解析 |
环境验证与测试
安装完成后,通过以下步骤验证环境配置是否正确:
# 验证pykan安装
import pykan
print(f"pykan版本: {pykan.__version__}")
# 验证核心依赖
import torch
import numpy as np
import matplotlib.pyplot as plt
print(f"PyTorch版本: {torch.__version__}")
print(f"NumPy版本: {np.__version__}")
print(f"Matplotlib版本: {plt.__version__}")
# 测试GPU支持(如果可用)
print(f"CUDA可用: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"GPU设备: {torch.cuda.get_device_name(0)}")
常见问题排查
问题1:安装过程中出现版本冲突
# 解决方案:使用虚拟环境隔离依赖
python -m venv clean-env
source clean-env/bin/activate
pip install --upgrade pip
pip install pykan
问题2:PyTorch CUDA版本不匹配
# 先卸载现有PyTorch
pip uninstall torch torchvision torchaudio
# 安装指定版本的PyTorch
pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118
问题3:依赖包版本过旧
# 更新所有依赖到兼容版本
pip install --upgrade matplotlib numpy scikit-learn sympy tqdm pandas
开发环境配置最佳实践
为了获得最佳的开发体验,建议配置以下开发工具:
- 代码编辑器: VS Code with Python extension
- Jupyter Notebook: 用于交互式开发和实验
- 调试工具: 配置Python调试器用于代码调试
- 版本控制: 使用Git进行代码版本管理
通过遵循上述步骤,您将获得一个完整且稳定的pykan开发环境,为后续的KAN模型构建和实验打下坚实基础。
KAN模型初始化与参数配置详解
Kolmogorov-Arnold Networks (KANs) 的初始化过程是确保模型训练成功的关键第一步。与传统的多层感知机不同,KANs采用了一种基于样条函数和基础函数的混合激活机制,这使得其初始化参数配置更加丰富和灵活。
KAN初始化核心参数解析
KAN模型的初始化通过MultKAN类的构造函数完成,该函数提供了丰富的参数来控制模型的初始状态。以下是主要的初始化参数及其作用:
| 参数名称 | 类型 | 默认值 | 描述 |
|---|---|---|---|
width |
list | None | 网络宽度配置,如[2,5,1]表示2输入、5隐藏神经元、1输出 |
grid |
int | 3 | 网格间隔数量,控制样条的分辨率 |
k |
int | 3 | 样条多项式阶数,通常为3(三次样条) |
mult_arity |
int/list | 2 | 乘法节点的元数(要相乘的数字数量) |
noise_scale |
float | 0.3 | 样条初始注入噪声的尺度 |
scale_base_mu |
float | 0.0 | 基础函数尺度的均值 |
scale_base_sigma |
float | 1.0 | 基础函数尺度的标准差 |
base_fun |
str | 'silu' | 基础函数类型:'silu', 'identity', 'zero' |
symbolic_enabled |
bool | True | 是否启用符号计算分支 |
grid_eps |
float | 0.02 | 网格自适应参数,0-1之间插值 |
grid_range |
list | [-1, 1] | 网格范围设置 |
sp_trainable |
bool | True | 样条尺度是否可训练 |
sb_trainable |
bool | True | 基础函数尺度是否可训练 |
sparse_init |
bool | False | 是否使用稀疏初始化 |
激活函数初始化机制
KAN的每个激活函数初始化为:
其中各组件的作用如下:
graph TD
A[激活函数 φ(x)] --> B[基础函数 b(x)]
A --> C[样条函数 spline(x)]
B --> D[尺度参数 scale_base]
C --> E[尺度参数 scale_sp]
D --> F[正态分布 N(μ, σ²)]
E --> G[正态分布 N(0, noise_scale²)]
网格配置参数详解
网格参数控制着样条函数的分布和自适应特性:
# 网格配置示例
model = MultKAN(
width=[2, 5, 1],
grid=5, # 5个网格间隔
k=3, # 三次样条
grid_eps=0.02, # 接近均匀网格
grid_range=[-1, 1] # 网格范围从-1到1
)
grid_eps参数控制网格的自适应程度:
grid_eps = 1:完全均匀网格grid_eps = 0:基于样本分位数的自适应网格0 < grid_eps < 1:两种极端情况的插值
基础函数选择策略
KAN支持多种基础函数,每种适用于不同的场景:
# 不同基础函数配置示例
model_silu = MultKAN(width=[2,5,1], base_fun='silu') # 默认SILU函数
model_linear = MultKAN(width=[2,5,1], base_fun='identity') # 线性初始化
model_zero = MultKAN(width=[2,5,1], base_fun='zero') # 零基础函数
初始化模式对比
以下是几种常见的初始化配置模式及其适用场景:
| 初始化模式 | 参数配置 | 适用场景 | 效果 |
|---|---|---|---|
| 默认初始化 | noise_scale=0.3, base_fun='silu' |
通用场景 | 平衡的初始状态 |
| 线性初始化 | noise_scale=0, base_fun='identity' |
理论分析 | 所有激活函数初始为线性 |
| 高噪声初始化 | noise_scale=10.0 |
实验演示 | 高度波动的初始状态 |
| 稀疏初始化 | sparse_init=True |
特征选择 | 大多数参数初始为零 |
参数配置最佳实践
基于项目文档和实际经验,以下是KAN模型初始化的推荐配置:
# 推荐的基础配置
model = MultKAN(
width=[input_dim, hidden_dim, output_dim],
grid=3, # 适中的网格分辨率
k=3, # 三次样条
noise_scale=0.1, # 适度的初始噪声
scale_base_mu=0.0, # 基础函数尺度均值为0
scale_base_sigma=1.0, # 基础函数尺度标准差为1
base_fun='silu', # 使用SILU作为基础函数
grid_eps=0.02, # 接近均匀网格
grid_range=[-1, 1], # 标准化的输入范围
seed=42 # 固定随机种子确保可重现性
)
设备与内存配置
KAN模型支持GPU加速,初始化时可以通过.to(device)方法指定设备:
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MultKAN(width=[2,5,1], grid=5, k=3)
model.to(device) # 移动到指定设备
检查点与状态管理
KAN提供了自动检查点功能,可以保存模型的不同状态:
model = MultKAN(
width=[2,5,1],
auto_save=True, # 启用自动保存
ckpt_path='./model', # 检查点保存路径
state_id=0 # 初始状态ID
)
这种设计使得模型可以在训练过程中回滚到之前的任何状态,为超参数调优和实验管理提供了便利。
正确的初始化配置是KAN模型成功训练的基础。通过合理设置网格参数、基础函数类型和初始化噪声,可以为后续的训练过程奠定良好的基础,确保模型能够有效地学习数据中的复杂模式。
数据集创建与预处理最佳实践
在Kolmogorov-Arnold Networks (KANs) 的训练过程中,高质量的数据集创建和适当的预处理是获得优秀模型性能的关键。pykan提供了灵活的工具来支持各种类型的数据集创建,从简单的函数拟合到复杂的物理系统建模。本节将深入探讨数据集创建的最佳实践、预处理技巧以及常见问题的解决方案。
数据集创建的核心工具
pykan提供了两个主要的数据集创建函数:create_dataset 用于从数学函数生成合成数据,create_dataset_from_data 用于从现有数据创建训练测试分割。
1. 使用 create_dataset 创建合成数据
create_dataset 函数是创建函数拟合任务数据集的强大工具,它支持多种配置选项:
from kan.utils import create_dataset
import torch
# 基本用法:创建二维函数数据集
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2, device='cpu')
# 高级配置:自定义范围和样本数量
dataset = create_dataset(
f,
n_var=2,
ranges=[[-2, 2], [-3, 3]], # 每个变量的不同范围
train_num=5000, # 训练样本数
test_num=1000, # 测试样本数
normalize_input=True, # 输入归一化
normalize_label=True, # 标签归一化
seed=42 # 随机种子
)
2. 使用 create_dataset_from_data 处理现有数据
对于已有数据,可以使用 create_dataset_from_data 函数:
from kan.utils import create_dataset_from_data
import numpy as np
# 从numpy数组创建数据集
x = np.random.randn(1000, 3) # 1000个样本,3个特征
y = np.sin(x[:, 0]) + np.cos(x[:, 1]) * x[:, 2] # 目标函数
dataset = create_dataset_from_data(
torch.from_numpy(x).float(),
torch.from_numpy(y).float().unsqueeze(1),
train_ratio=0.8, # 训练集比例
device='cuda' if torch.cuda.is_available() else 'cpu'
)
数据预处理最佳实践
1. 输入归一化策略
输入归一化对于KAN训练至关重要,可以加速收敛并提高数值稳定性:
# 方法1:使用内置归一化
dataset = create_dataset(f, n_var=2, normalize_input=True)
# 方法2:手动归一化(更灵活的控制)
def custom_normalize(data):
mean = torch.mean(data, dim=0, keepdim=True)
std = torch.std(data, dim=0, keepdim=True)
# 添加小常数避免除零
return (data - mean) / (std + 1e-8)
# 应用自定义归一化
train_input_normalized = custom_normalize(dataset['train_input'])
test_input_normalized = custom_normalize(dataset['test_input'])
2. 输出标签处理
对于不同的任务类型,标签处理策略有所不同:
回归任务:
# 标签归一化有助于训练稳定性
dataset = create_dataset(f, normalize_label=True)
# 或者手动标准化
train_label = dataset['train_label']
label_mean = torch.mean(train_label)
label_std = torch.std(train_label)
normalized_labels = (train_label - label_mean) / label_std
分类任务:
from sklearn.datasets import make_moons
from sklearn.preprocessing import OneHotEncoder
# 创建分类数据集
X, y = make_moons(n_samples=1000, noise=0.1)
encoder = OneHotEncoder(sparse_output=False)
y_onehot = encoder.fit_transform(y.reshape(-1, 1))
dataset = {
'train_input': torch.from_numpy(X).float(),
'train_label': torch.from_numpy(y_onehot).float(),
'test_input': torch.from_numpy(X_test).float(),
'test_label': torch.from_numpy(y_test_onehot).float()
}
高级数据集配置技巧
1. 多变量范围配置
对于不同输入变量,可以设置不同的取值范围:
# 每个变量独立的范围配置
ranges = [
[-5, 5], # 第一个变量范围
[0, 10], # 第二个变量范围
[-1, 1] # 第三个变量范围
]
dataset = create_dataset(
lambda x: x[:,0]**2 + torch.sin(x[:,1]) * x[:,2],
n_var=3,
ranges=ranges,
train_num=2000
)
2. 函数模式选择
pykan支持两种函数计算模式,适应不同的函数定义习惯:
# 列模式(默认)- 使用 x[:,[i]] 索引
f_col = lambda x: x[:,[0]] * x[:,[1]] # 保持二维结构
dataset_col = create_dataset(f_col, n_var=2, f_mode='col')
# 行模式 - 使用 x[i] 索引
f_row = lambda x: x[0] * x[1] # 直接使用标量索引
dataset_row = create_dataset(f_row, n_var=2, f_mode='row')
3. 设备兼容性处理
确保数据与模型在同一设备上:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 创建时指定设备
dataset = create_dataset(f, n_var=2, device=device)
# 或者后期迁移设备
dataset = {k: v.to(device) for k, v in dataset.items()}
数据质量验证
创建数据集后,进行基本的数据质量检查:
def validate_dataset(dataset):
"""验证数据集的基本属性"""
print(f"训练输入形状: {dataset['train_input'].shape}")
print(f"训练标签形状: {dataset['train_label'].shape}")
print(f"测试输入形状: {dataset['test_input'].shape}")
print(f"测试标签形状: {dataset['test_label'].shape}")
# 检查NaN值
for key, value in dataset.items():
if torch.isnan(value).any():
print(f"警告: {key} 包含NaN值")
# 检查数值范围
print(f"输入数值范围: [{dataset['train_input'].min():.3f}, {dataset['train_input'].max():.3f}]")
print(f"标签数值范围: [{dataset['train_label'].min():.3f}, {dataset['train_label'].max():.3f}]")
validate_dataset(dataset)
特殊场景处理
1. 处理奇异点问题
对于包含除零、对数等可能产生奇异点的函数:
def safe_function(x):
"""处理可能产生奇异点的安全函数"""
# 避免除零
denominator = x[:, 1].clone()
denominator[torch.abs(denominator) < 1e-8] = 1e-8 * torch.sign(denominator[torch.abs(denominator) < 1e-8])
# 避免对数负值
log_input = x[:, 0].clone()
log_input[log_input <= 0] = 1e-8
return torch.log(log_input) / denominator
dataset = create_dataset(safe_function, n_var=2)
2. 时间序列数据准备
对于时间序列预测任务:
def create_time_series_dataset(series, lookback=10, forecast=1):
"""创建时间序列数据集"""
X, y = [], []
for i in range(len(series) - lookback - forecast + 1):
X.append(series[i:i+lookback])
y.append(series[i+lookback:i+lookback+forecast])
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)
return create_dataset_from_data(X, y, train_ratio=0.8)
性能优化建议
1. 内存效率优化
对于大型数据集,使用内存高效的创建方式:
# 分批创建大数据集
def create_large_dataset(f, n_var, total_samples=100000, batch_size=10000):
datasets = []
for i in range(0, total_samples, batch_size):
current_batch = min(batch_size, total_samples - i)
dataset_batch = create_dataset(f, n_var=n_var, train_num=current_batch, test_num=0)
datasets.append(dataset_batch)
# 合并批次
combined = {
'train_input': torch.cat([d['train_input'] for d in datasets], dim=0),
'train_label': torch.cat([d['train_label'] for d in datasets], dim=0)
}
return combined
2. 数据增强策略
对于数据稀缺的场景,考虑数据增强:
def augment_dataset(dataset, noise_level=0.01, num_augment=5):
"""通过添加噪声增强数据集"""
augmented_inputs = []
augmented_labels = []
for _ in range(num_augment):
noise = torch.randn_like(dataset['train_input']) * noise_level
augmented_inputs.append(dataset['train_input'] + noise)
augmented_labels.append(dataset['train_label'])
return {
'train_input': torch.cat([dataset['train_input']] + augmented_inputs, dim=0),
'train_label': torch.cat([dataset['train_label']] + augmented_labels, dim=0),
'test_input': dataset['test_input'],
'test_label': dataset['test_label']
}
可视化与调试
创建数据集后,进行可视化验证:
import matplotlib.pyplot as plt
def visualize_dataset(dataset, max_points=1000):
"""可视化数据集样本"""
if dataset['train_input'].shape[1] == 1:
# 一维输入可视化
plt.figure(figsize=(10, 6))
plt.scatter(dataset['train_input'][:max_points, 0].cpu().numpy(),
dataset['train_label'][:max_points, 0].cpu().numpy(),
alpha=0.5, label='训练数据')
plt.scatter(dataset['test_input'][:max_points//5, 0].cpu().numpy(),
dataset['test_label'][:max_points//5, 0].cpu().numpy(),
alpha=0.5, label='测试数据')
plt.legend()
plt.xlabel('输入')
plt.ylabel('输出')
plt.title('数据集分布')
plt.show()
elif dataset['train_input'].shape[1] == 2:
# 二维输入可视化
fig = plt.figure(figsize=(12, 5))
ax1 = fig.add_subplot(121)
scatter = ax1.scatter(dataset['train_input'][:max_points, 0].cpu().numpy(),
dataset['train_input'][:max_points, 1].cpu().numpy(),
c=dataset['train_label'][:max_points, 0].cpu().numpy(),
cmap='viridis', alpha=0.6)
plt.colorbar(scatter, label='输出值')
ax1.set_title('训练数据分布')
ax1.set_xlabel('x1')
ax1.set_ylabel('x2')
ax2 = fig.add_subplot(122)
scatter = ax2.scatter(dataset['test_input'][:max_points//5, 0].cpu().numpy(),
dataset['test_input'][:max_points//5, 1].cpu().numpy(),
c=dataset['test_label'][:max_points//5, 0].cpu().numpy(),
cmap='viridis', alpha=0.6)
plt.colorbar(scatter, label='输出值')
ax2.set_title('测试数据分布')
ax2.set_xlabel('x1')
ax2.set_ylabel('x2')
plt.tight_layout()
plt.show()
# 使用可视化函数
visualize_dataset(dataset)
通过遵循这些最佳实践,您可以创建高质量、适合KAN模型训练的数据集。正确的数据预处理不仅能够提高训练效率,还能显著改善模型的最终性能和泛化能力。记住,在机器学习项目中,数据质量往往比模型架构的选择更为重要。
模型训练、评估与可视化完整流程
Kolmogorov-Arnold Networks (KANs) 的训练、评估和可视化是一个完整的工作流程,每个环节都体现了KAN模型独特的优势。与传统的MLP不同,KAN的训练过程更加精细,包含了网格更新、正则化控制、剪枝优化等多个关键步骤。
训练流程详解
KAN模型的训练通过fit方法实现,该方法提供了丰富的参数来控制训练过程:
# 基本训练配置
model.fit(
dataset=dataset, # 训练数据集
opt="LBFGS", # 优化器选择(LBFGS/Adam)
steps=100, # 训练步数
lamb=0.001, # 稀疏正则化系数
lamb_l1=1.0, # L1正则化系数
lamb_entropy=2.0, # 熵正则化系数
update_grid=True, # 是否更新网格
grid_update_num=10, # 网格更新次数
lr=1.0, # 学习率
batch=-1, # 批次大小(-1表示全批次)
metrics=['train_loss', 'test_loss'] # 监控指标
)
训练过程的关键特性
网格自适应机制: KAN的独特之处在于其网格自适应能力。在训练过程中,模型会根据输入数据的分布动态调整B样条基函数的网格点:
flowchart TD
A[输入数据采样] --> B[计算数据分布]
B --> C{是否需要更新网格?}
C -->|是| D[重新计算网格点]
C -->|否| E[保持当前网格]
D --> F[更新B样条系数]
E --> G[继续参数优化]
F --> G
G --> H[完成当前训练步]
多目标正则化策略: KAN采用多层次正则化来控制模型复杂度和促进稀疏性:
| 正则化类型 | 参数 | 作用 | 推荐值 |
|---|---|---|---|
| 稀疏正则化 | lamb | 控制整体稀疏度 | 0.001-0.1 |
| L1正则化 | lamb_l1 | 促进权重稀疏 | 0.1-2.0 |
| 熵正则化 | lamb_entropy | 平衡激活分布 | 1.0-5.0 |
| 系数平滑 | lamb_coef | 平滑B样条系数 | 0.0-0.1 |
模型评估与性能监控
训练过程中,KAN会自动监控多个性能指标:
# 评估模型性能
results = model.evaluate(dataset)
print(f"训练损失: {results['train_loss']:.4e}")
print(f"测试损失: {results['test_loss']:.4e}")
print(f"正则化项: {results['reg']:.4e}")
关键评估指标
损失函数分解: KAN的总损失由三部分组成:
- 预测损失:衡量模型输出与真实值的差异
- 正则化损失:控制模型复杂度和稀疏性
- 网格约束损失:确保B样条基函数的平滑性
过拟合检测: 通过监控训练/测试损失比来识别过拟合:
def detect_overfitting(train_loss, test_loss, threshold=1.5):
ratio = test_loss / train_loss
if ratio > threshold:
return f"可能过拟合,比率: {ratio:.2f}"
return f"训练正常,比率: {ratio:.2f}"
可视化分析技术
KAN提供了丰富的可视化功能来理解模型内部工作机制:
网络结构可视化
# 绘制KAN网络结构
model.plot(
beta=3, # 线条粗细系数
metric='backward', # 可视化指标
scale=0.5, # 缩放因子
in_vars=['x', 'y'], # 输入变量名
out_vars=['f(x,y)'], # 输出变量名
title="训练后的KAN网络"
)
可视化输出展示了:
- 每个边的激活函数形状
- 连接权重大小(通过线条粗细表示)
- 节点的激活强度
- 输入输出变量的对应关系
激活函数分析
graph LR
A[原始B样条] --> B[训练后调整]
B --> C[符号函数拟合]
C --> D[最终激活函数]
style A fill:#e1f5fe
style D fill:#f1f8e9
剪枝与优化流程
训练完成后,通常需要进行剪枝来简化网络结构:
# 剪枝流程
model.prune(
node_th=1e-2, # 节点剪枝阈值
edge_th=3e-2 # 边剪枝阈值
)
# 剪枝后重新训练
model.fit(dataset, steps=20, lamb=0.0001)
剪枝策略对比
| 剪枝类型 | 阈值范围 | 效果 | 适用场景 |
|---|---|---|---|
| 节点剪枝 | 1e-3 to 1e-2 | 移除冗余神经元 | 高度冗余网络 |
| 边剪枝 | 1e-2 to 1e-1 | 移除弱连接 | 一般稀疏化 |
| 输入剪枝 | 1e-3 to 1e-2 | 移除不重要输入 | 特征选择 |
超参数调优指南
基于实践经验,以下超参数组合通常效果较好:
# 推荐超参数配置
hyperparameter_configs = {
'简单任务': {'grid': 3, 'k': 3, 'lamb': 0.001, 'steps': 50},
'中等任务': {'grid': 5, 'k': 3, 'lamb': 0.01, 'steps': 100},
'复杂任务': {'grid': 7, 'k': 4, 'lamb': 0.1, 'steps': 200}
}
网格大小选择策略
xychart-beta
title "网格大小对性能的影响"
x-axis [3, 5, 7, 10]
y-axis "模型复杂度" 0 --> 100
y-axis "训练时间" 0 --> 100
line [30, 60, 85, 100]
line [20, 50, 80, 95]
完整训练示例
下面是一个完整的KAN训练、评估和可视化示例:
# 1. 初始化模型
model = KAN(width=[2, 5, 1], grid=5, k=3, device=device)
# 2. 初始可视化
model.plot(title="初始KAN网络")
# 3. 第一阶段训练(基础拟合)
model.fit(dataset, opt="LBFGS", steps=50, lamb=0.001)
model.plot(title="第一阶段训练后")
# 4. 剪枝优化
model.prune(node_th=1e-2, edge_th=3e-2)
model.plot(title="剪枝后网络")
# 5. 第二阶段训练(精细调优)
model.fit(dataset, steps=30, lamb=0.0001, update_grid=False)
# 6. 最终评估
results = model.evaluate(dataset)
print(f"最终性能: 训练损失={results['train_loss']:.4e}, 测试损失={results['test_loss']:.4e}")
# 7. 最终可视化
model.plot(beta=4, scale=0.6, title="最终KAN网络")
这个完整流程展示了KAN模型从初始化到最终优化的全过程,每个步骤都包含了相应的评估和可视化,确保开发者能够全面理解模型的训练状态和性能表现。
通过这种系统化的训练、评估和可视化流程,KAN模型不仅能够实现高精度的函数逼近,还能保持很好的可解释性,为科学计算和工程应用提供了强大的工具。
通过本指南的完整学习,您已经掌握了使用pykan构建Kolmogorov-Arnold Networks的全流程技能。从环境配置到模型初始化,从数据预处理到训练优化,每个环节都提供了详细的实践指导和最佳实践。KAN模型相比传统MLP具有更好的可解释性和自适应能力,通过网格更新机制和多重正则化策略,能够有效学习复杂的数据模式。文章中的可视化技术和剪枝方法进一步增强了模型的可解释性和实用性。这份指南为您提供了从零开始构建高质量KAN模型所需的所有工具和知识,为后续的科学计算和机器学习项目奠定了坚实基础。
Kimi-K2.5Kimi K2.5 是一款开源的原生多模态智能体模型,它在 Kimi-K2-Base 的基础上,通过对约 15 万亿混合视觉和文本 tokens 进行持续预训练构建而成。该模型将视觉与语言理解、高级智能体能力、即时模式与思考模式,以及对话式与智能体范式无缝融合。Python00
GLM-4.7-FlashGLM-4.7-Flash 是一款 30B-A3B MoE 模型。作为 30B 级别中的佼佼者,GLM-4.7-Flash 为追求性能与效率平衡的轻量化部署提供了全新选择。Jinja00
VLOOKVLOOK™ 是优雅好用的 Typora/Markdown 主题包和增强插件。 VLOOK™ is an elegant and practical THEME PACKAGE × ENHANCEMENT PLUGIN for Typora/Markdown.Less00
PaddleOCR-VL-1.5PaddleOCR-VL-1.5 是 PaddleOCR-VL 的新一代进阶模型,在 OmniDocBench v1.5 上实现了 94.5% 的全新 state-of-the-art 准确率。 为了严格评估模型在真实物理畸变下的鲁棒性——包括扫描伪影、倾斜、扭曲、屏幕拍摄和光照变化——我们提出了 Real5-OmniDocBench 基准测试集。实验结果表明,该增强模型在新构建的基准测试集上达到了 SOTA 性能。此外,我们通过整合印章识别和文本检测识别(text spotting)任务扩展了模型的能力,同时保持 0.9B 的超紧凑 VLM 规模,具备高效率特性。Python00
KuiklyUI基于KMP技术的高性能、全平台开发框架,具备统一代码库、极致易用性和动态灵活性。 Provide a high-performance, full-platform development framework with unified codebase, ultimate ease of use, and dynamic flexibility. 注意:本仓库为Github仓库镜像,PR或Issue请移步至Github发起,感谢支持!Kotlin07
compass-metrics-modelMetrics model project for the OSS CompassPython00