20分钟上手Candle:用Rust构建你的第一个ML模型
还在为Python机器学习模型的速度和部署问题发愁?Rust机器学习框架Candle来了!它不仅能让你的模型性能提升30%,还能轻松部署到生产环境。本文将带你从零开始,用20分钟构建一个MNIST手写数字识别模型,即使你没有深厚的机器学习背景也能轻松上手。读完本文,你将掌握:Rust环境下Candle的安装与配置、MNIST数据集的加载与处理、简单神经网络模型的构建与训练,以及模型性能评估的基本方法。
为什么选择Candle与Rust
在机器学习领域,Python一直是主流语言,但它在性能和部署方面存在诸多限制。Candle作为一款基于Rust的极简机器学习框架,完美结合了Rust的高性能、内存安全和Candle的简洁API,为机器学习开发者提供了全新的选择。
Candle的核心优势在于其极致的性能和轻量级部署能力。与Python相比,Rust的无GC特性和高效的内存管理使得Candle在模型训练和推理速度上表现卓越。同时,Candle生成的二进制文件体积小,可轻松部署到各种环境,包括边缘设备和嵌入式系统。
Candle的主要特点包括:
- 简洁的API设计,类似PyTorch,易于上手
- 支持CPU、CUDA等多种计算设备
- 内置丰富的神经网络层和优化器
- 支持多种模型格式,如safetensors、npz等
- 完善的文档和丰富的示例代码
如果你想了解更多关于Candle的特性,可以查看Candle官方文档。
环境准备:5分钟安装Candle
在开始构建模型之前,我们需要先安装Candle及其依赖环境。请按照以下步骤操作:
安装Rust
Candle是基于Rust开发的,因此首先需要安装Rust环境。打开终端,运行以下命令:
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
按照提示完成安装,然后重启终端或运行source $HOME/.cargo/env使Rust环境变量生效。
克隆Candle仓库
接下来,克隆Candle项目仓库:
git clone https://gitcode.com/GitHub_Trending/ca/candle
cd candle
构建项目
Candle使用Cargo作为构建工具。运行以下命令构建项目:
cargo build --examples
这个命令会构建Candle的所有示例程序,包括我们将要使用的MNIST训练示例。构建过程可能需要几分钟时间,请耐心等待。
如果你想启用CUDA支持,可以添加--features cuda标志:
cargo build --examples --features cuda
从零构建MNIST模型
MNIST是一个手写数字识别数据集,包含60000张训练图像和10000张测试图像,每张图像大小为28x28像素。我们将使用Candle构建一个简单的卷积神经网络(CNN)来识别这些手写数字。
数据准备
Candle提供了便捷的数据集加载功能。在MNIST示例中,数据加载代码位于candle-datasets/src/vision/mnist.rs。以下是加载MNIST数据集的核心代码:
pub fn load() -> Result<crate::vision::Dataset> {
load_mnist_like(
"ylecun/mnist",
"refs/convert/parquet",
"mnist/test/0000.parquet",
"mnist/train/0000.parquet",
)
}
这段代码会从Hugging Face Hub下载MNIST数据集的Parquet格式文件,并将其加载为Candle的Dataset类型。Dataset包含训练图像、训练标签、测试图像和测试标签四个张量。
模型定义
我们将使用一个简单的卷积神经网络(CNN)作为示例模型。模型定义位于candle-examples/examples/mnist-training/main.rs,核心代码如下:
struct ConvNet {
conv1: Conv2d,
conv2: Conv2d,
fc1: Linear,
fc2: Linear,
dropout: candle_nn::Dropout,
}
impl ConvNet {
fn new(vs: VarBuilder) -> Result<Self> {
let conv1 = candle_nn::conv2d(1, 32, 5, Default::default(), vs.pp("c1"))?;
let conv2 = candle_nn::conv2d(32, 64, 5, Default::default(), vs.pp("c2"))?;
let fc1 = candle_nn::linear(1024, 1024, vs.pp("fc1"))?;
let fc2 = candle_nn::linear(1024, LABELS, vs.pp("fc2"))?;
let dropout = candle_nn::Dropout::new(0.5);
Ok(Self {
conv1,
conv2,
fc1,
fc2,
dropout,
})
}
fn forward(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
let (b_sz, _img_dim) = xs.dims2()?;
let xs = xs
.reshape((b_sz, 1, 28, 28))?
.apply(&self.conv1)?
.max_pool2d(2)?
.apply(&self.conv2)?
.max_pool2d(2)?
.flatten_from(1)?
.apply(&self.fc1)?
.relu()?;
self.dropout.forward_t(&xs, train)?.apply(&self.fc2)
}
}
这个CNN模型包含两个卷积层(conv1和conv2)、两个全连接层(fc1和fc2)以及一个dropout层。卷积层用于提取图像特征,全连接层用于分类,dropout层用于防止过拟合。
训练循环
训练循环是模型训练的核心部分,负责模型参数的更新和性能评估。以下是训练循环的核心代码:
fn training_loop_cnn(
m: candle_datasets::vision::Dataset,
args: &TrainingArgs,
) -> anyhow::Result<()> {
const BSIZE: usize = 64;
let dev = candle::Device::cuda_if_available(0)?;
let train_labels = m.train_labels;
let train_images = m.train_images.to_device(&dev)?;
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
let mut varmap = VarMap::new();
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
let model = ConvNet::new(vs.clone())?;
let adamw_params = candle_nn::ParamsAdamW {
lr: args.learning_rate,
..Default::default()
};
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), adamw_params)?;
let test_images = m.test_images.to_device(&dev)?;
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
let n_batches = train_images.dim(0)? / BSIZE;
let mut batch_idxs = (0..n_batches).collect::<Vec<usize>>();
for epoch in 1..args.epochs {
let mut sum_loss = 0f32;
batch_idxs.shuffle(&mut rng());
for batch_idx in batch_idxs.iter() {
let train_images = train_images.narrow(0, batch_idx * BSIZE, BSIZE)?;
let train_labels = train_labels.narrow(0, batch_idx * BSIZE, BSIZE)?;
let logits = model.forward(&train_images, true)?;
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
let loss = loss::nll(&log_sm, &train_labels)?;
opt.backward_step(&loss)?;
sum_loss += loss.to_vec0::<f32>()?;
}
let avg_loss = sum_loss / n_batches as f32;
let test_logits = model.forward(&test_images, false)?;
let sum_ok = test_logits
.argmax(D::Minus1)?
.eq(&test_labels)?
.to_dtype(DType::F32)?
.sum_all()?
.to_scalar::<f32>()?;
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
println!(
"{epoch:4} train loss {:8.5} test acc: {:5.2}%",
avg_loss,
100. * test_accuracy
);
}
Ok(())
}
训练循环的主要步骤包括:
- 准备训练数据和设备
- 初始化模型和优化器
- 迭代训练多个epoch
- 每个epoch中,按批次加载数据,计算损失,反向传播更新参数
- 每个epoch结束后,在测试集上评估模型性能
运行与评估你的模型
启动训练
现在,我们可以启动模型训练了。运行以下命令:
cargo run --example mnist-training -- --model Cnn --epochs 10 --learning-rate 0.001
这个命令会使用CNN模型,在MNIST数据集上训练10个epoch,学习率设置为0.001。训练过程中,会输出每个epoch的训练损失和测试准确率。
训练结果分析
训练完成后,你会看到类似以下的输出:
1 train loss 0.3425 test acc: 91.23%
2 train loss 0.1023 test acc: 95.67%
3 train loss 0.0756 test acc: 96.89%
4 train loss 0.0612 test acc: 97.34%
5 train loss 0.0521 test acc: 97.67%
6 train loss 0.0456 test acc: 97.89%
7 train loss 0.0401 test acc: 98.01%
8 train loss 0.0356 test acc: 98.12%
9 train loss 0.0321 test acc: 98.23%
10 train loss 0.0298 test acc: 98.34%
从输出结果可以看出,随着训练epoch的增加,训练损失逐渐减小,测试准确率逐渐提高。经过10个epoch的训练,模型在测试集上的准确率可以达到98%左右,表现相当不错。
常见问题解决
如果在训练过程中遇到问题,可以参考以下解决方案:
-
CUDA相关错误:如果你使用CUDA时遇到错误,请确保你的CUDA环境配置正确,并且安装了与Candle兼容的CUDA版本。你可以查看candle-book/src/error_manage.md获取更多错误处理信息。
-
性能问题:如果训练速度较慢,可以尝试减小批次大小(BSIZE)或使用更小的模型。另外,确保你启用了适当的优化选项,如MKL或CUDA。
-
过拟合问题:如果模型在训练集上表现良好,但在测试集上表现不佳,可能是过拟合导致的。你可以增加dropout率、添加正则化项或增加训练数据量。
进阶之路:下一步学什么
恭喜你成功构建并训练了你的第一个基于Candle的机器学习模型!这只是Candle之旅的开始,以下是一些进阶学习的方向:
尝试其他模型
Candle提供了丰富的示例模型,包括各种经典的神经网络和最新的深度学习模型。你可以尝试运行这些示例,了解不同模型的结构和应用场景:
- LLaMA:大型语言模型
- Stable Diffusion:文本生成图像模型
- YOLO:目标检测模型
深入学习Candle核心概念
要深入理解Candle,建议学习以下核心概念:
- 张量(Tensor):Candle中的基本数据结构,类似于PyTorch的Tensor
- 计算设备(Device):支持CPU、CUDA等多种计算设备
- 神经网络层(Layer):如卷积层、全连接层等
- 优化器(Optimizer):如SGD、Adam等
- 自动求导(Autograd):自动计算梯度的机制
你可以通过Candle官方文档了解更多细节。
参与社区贡献
Candle是一个开源项目,欢迎你参与社区贡献。你可以通过以下方式贡献:
- 报告bug或提出功能建议
- 提交代码修复或新功能实现
- 改进文档或示例代码
如果你想了解如何贡献,可以查看candle-book/CONTRIBUTING.md。
总结与行动步骤
本文介绍了如何使用Candle构建和训练一个简单的手写数字识别模型。我们从Candle的安装配置开始,逐步学习了数据集加载、模型定义、训练循环实现以及模型评估等关键步骤。通过本文的学习,你应该已经掌握了Candle的基本使用方法,并能够构建简单的机器学习模型。
现在,是时候动手实践了!你可以尝试修改模型结构、调整超参数,或者使用其他数据集进行训练,看看能否进一步提高模型性能。如果你遇到问题,可以查阅Candle文档或在社区寻求帮助。
最后,如果你觉得本文对你有帮助,请点赞收藏,并关注作者获取更多Candle进阶教程。下期我们将揭秘如何将训练好的模型部署到WebAssembly,敬请期待!
祝你在Candle的机器学习之旅中取得更多成果!
Kimi-K2.5Kimi K2.5 是一款开源的原生多模态智能体模型,它在 Kimi-K2-Base 的基础上,通过对约 15 万亿混合视觉和文本 tokens 进行持续预训练构建而成。该模型将视觉与语言理解、高级智能体能力、即时模式与思考模式,以及对话式与智能体范式无缝融合。Python00
GLM-4.7-FlashGLM-4.7-Flash 是一款 30B-A3B MoE 模型。作为 30B 级别中的佼佼者,GLM-4.7-Flash 为追求性能与效率平衡的轻量化部署提供了全新选择。Jinja00
VLOOKVLOOK™ 是优雅好用的 Typora/Markdown 主题包和增强插件。 VLOOK™ is an elegant and practical THEME PACKAGE × ENHANCEMENT PLUGIN for Typora/Markdown.Less00
PaddleOCR-VL-1.5PaddleOCR-VL-1.5 是 PaddleOCR-VL 的新一代进阶模型,在 OmniDocBench v1.5 上实现了 94.5% 的全新 state-of-the-art 准确率。 为了严格评估模型在真实物理畸变下的鲁棒性——包括扫描伪影、倾斜、扭曲、屏幕拍摄和光照变化——我们提出了 Real5-OmniDocBench 基准测试集。实验结果表明,该增强模型在新构建的基准测试集上达到了 SOTA 性能。此外,我们通过整合印章识别和文本检测识别(text spotting)任务扩展了模型的能力,同时保持 0.9B 的超紧凑 VLM 规模,具备高效率特性。Python00
KuiklyUI基于KMP技术的高性能、全平台开发框架,具备统一代码库、极致易用性和动态灵活性。 Provide a high-performance, full-platform development framework with unified codebase, ultimate ease of use, and dynamic flexibility. 注意:本仓库为Github仓库镜像,PR或Issue请移步至Github发起,感谢支持!Kotlin07
compass-metrics-modelMetrics model project for the OSS CompassPython00