20分钟上手Candle:用Rust构建你的第一个ML模型
还在为Python机器学习模型的速度和部署问题发愁?Rust机器学习框架Candle来了!它不仅能让你的模型性能提升30%,还能轻松部署到生产环境。本文将带你从零开始,用20分钟构建一个MNIST手写数字识别模型,即使你没有深厚的机器学习背景也能轻松上手。读完本文,你将掌握:Rust环境下Candle的安装与配置、MNIST数据集的加载与处理、简单神经网络模型的构建与训练,以及模型性能评估的基本方法。
为什么选择Candle与Rust
在机器学习领域,Python一直是主流语言,但它在性能和部署方面存在诸多限制。Candle作为一款基于Rust的极简机器学习框架,完美结合了Rust的高性能、内存安全和Candle的简洁API,为机器学习开发者提供了全新的选择。
Candle的核心优势在于其极致的性能和轻量级部署能力。与Python相比,Rust的无GC特性和高效的内存管理使得Candle在模型训练和推理速度上表现卓越。同时,Candle生成的二进制文件体积小,可轻松部署到各种环境,包括边缘设备和嵌入式系统。
Candle的主要特点包括:
- 简洁的API设计,类似PyTorch,易于上手
- 支持CPU、CUDA等多种计算设备
- 内置丰富的神经网络层和优化器
- 支持多种模型格式,如safetensors、npz等
- 完善的文档和丰富的示例代码
如果你想了解更多关于Candle的特性,可以查看Candle官方文档。
环境准备:5分钟安装Candle
在开始构建模型之前,我们需要先安装Candle及其依赖环境。请按照以下步骤操作:
安装Rust
Candle是基于Rust开发的,因此首先需要安装Rust环境。打开终端,运行以下命令:
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
按照提示完成安装,然后重启终端或运行source $HOME/.cargo/env使Rust环境变量生效。
克隆Candle仓库
接下来,克隆Candle项目仓库:
git clone https://gitcode.com/GitHub_Trending/ca/candle
cd candle
构建项目
Candle使用Cargo作为构建工具。运行以下命令构建项目:
cargo build --examples
这个命令会构建Candle的所有示例程序,包括我们将要使用的MNIST训练示例。构建过程可能需要几分钟时间,请耐心等待。
如果你想启用CUDA支持,可以添加--features cuda标志:
cargo build --examples --features cuda
从零构建MNIST模型
MNIST是一个手写数字识别数据集,包含60000张训练图像和10000张测试图像,每张图像大小为28x28像素。我们将使用Candle构建一个简单的卷积神经网络(CNN)来识别这些手写数字。
数据准备
Candle提供了便捷的数据集加载功能。在MNIST示例中,数据加载代码位于candle-datasets/src/vision/mnist.rs。以下是加载MNIST数据集的核心代码:
pub fn load() -> Result<crate::vision::Dataset> {
load_mnist_like(
"ylecun/mnist",
"refs/convert/parquet",
"mnist/test/0000.parquet",
"mnist/train/0000.parquet",
)
}
这段代码会从Hugging Face Hub下载MNIST数据集的Parquet格式文件,并将其加载为Candle的Dataset类型。Dataset包含训练图像、训练标签、测试图像和测试标签四个张量。
模型定义
我们将使用一个简单的卷积神经网络(CNN)作为示例模型。模型定义位于candle-examples/examples/mnist-training/main.rs,核心代码如下:
struct ConvNet {
conv1: Conv2d,
conv2: Conv2d,
fc1: Linear,
fc2: Linear,
dropout: candle_nn::Dropout,
}
impl ConvNet {
fn new(vs: VarBuilder) -> Result<Self> {
let conv1 = candle_nn::conv2d(1, 32, 5, Default::default(), vs.pp("c1"))?;
let conv2 = candle_nn::conv2d(32, 64, 5, Default::default(), vs.pp("c2"))?;
let fc1 = candle_nn::linear(1024, 1024, vs.pp("fc1"))?;
let fc2 = candle_nn::linear(1024, LABELS, vs.pp("fc2"))?;
let dropout = candle_nn::Dropout::new(0.5);
Ok(Self {
conv1,
conv2,
fc1,
fc2,
dropout,
})
}
fn forward(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
let (b_sz, _img_dim) = xs.dims2()?;
let xs = xs
.reshape((b_sz, 1, 28, 28))?
.apply(&self.conv1)?
.max_pool2d(2)?
.apply(&self.conv2)?
.max_pool2d(2)?
.flatten_from(1)?
.apply(&self.fc1)?
.relu()?;
self.dropout.forward_t(&xs, train)?.apply(&self.fc2)
}
}
这个CNN模型包含两个卷积层(conv1和conv2)、两个全连接层(fc1和fc2)以及一个dropout层。卷积层用于提取图像特征,全连接层用于分类,dropout层用于防止过拟合。
训练循环
训练循环是模型训练的核心部分,负责模型参数的更新和性能评估。以下是训练循环的核心代码:
fn training_loop_cnn(
m: candle_datasets::vision::Dataset,
args: &TrainingArgs,
) -> anyhow::Result<()> {
const BSIZE: usize = 64;
let dev = candle::Device::cuda_if_available(0)?;
let train_labels = m.train_labels;
let train_images = m.train_images.to_device(&dev)?;
let train_labels = train_labels.to_dtype(DType::U32)?.to_device(&dev)?;
let mut varmap = VarMap::new();
let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
let model = ConvNet::new(vs.clone())?;
let adamw_params = candle_nn::ParamsAdamW {
lr: args.learning_rate,
..Default::default()
};
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), adamw_params)?;
let test_images = m.test_images.to_device(&dev)?;
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
let n_batches = train_images.dim(0)? / BSIZE;
let mut batch_idxs = (0..n_batches).collect::<Vec<usize>>();
for epoch in 1..args.epochs {
let mut sum_loss = 0f32;
batch_idxs.shuffle(&mut rng());
for batch_idx in batch_idxs.iter() {
let train_images = train_images.narrow(0, batch_idx * BSIZE, BSIZE)?;
let train_labels = train_labels.narrow(0, batch_idx * BSIZE, BSIZE)?;
let logits = model.forward(&train_images, true)?;
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
let loss = loss::nll(&log_sm, &train_labels)?;
opt.backward_step(&loss)?;
sum_loss += loss.to_vec0::<f32>()?;
}
let avg_loss = sum_loss / n_batches as f32;
let test_logits = model.forward(&test_images, false)?;
let sum_ok = test_logits
.argmax(D::Minus1)?
.eq(&test_labels)?
.to_dtype(DType::F32)?
.sum_all()?
.to_scalar::<f32>()?;
let test_accuracy = sum_ok / test_labels.dims1()? as f32;
println!(
"{epoch:4} train loss {:8.5} test acc: {:5.2}%",
avg_loss,
100. * test_accuracy
);
}
Ok(())
}
训练循环的主要步骤包括:
- 准备训练数据和设备
- 初始化模型和优化器
- 迭代训练多个epoch
- 每个epoch中,按批次加载数据,计算损失,反向传播更新参数
- 每个epoch结束后,在测试集上评估模型性能
运行与评估你的模型
启动训练
现在,我们可以启动模型训练了。运行以下命令:
cargo run --example mnist-training -- --model Cnn --epochs 10 --learning-rate 0.001
这个命令会使用CNN模型,在MNIST数据集上训练10个epoch,学习率设置为0.001。训练过程中,会输出每个epoch的训练损失和测试准确率。
训练结果分析
训练完成后,你会看到类似以下的输出:
1 train loss 0.3425 test acc: 91.23%
2 train loss 0.1023 test acc: 95.67%
3 train loss 0.0756 test acc: 96.89%
4 train loss 0.0612 test acc: 97.34%
5 train loss 0.0521 test acc: 97.67%
6 train loss 0.0456 test acc: 97.89%
7 train loss 0.0401 test acc: 98.01%
8 train loss 0.0356 test acc: 98.12%
9 train loss 0.0321 test acc: 98.23%
10 train loss 0.0298 test acc: 98.34%
从输出结果可以看出,随着训练epoch的增加,训练损失逐渐减小,测试准确率逐渐提高。经过10个epoch的训练,模型在测试集上的准确率可以达到98%左右,表现相当不错。
常见问题解决
如果在训练过程中遇到问题,可以参考以下解决方案:
-
CUDA相关错误:如果你使用CUDA时遇到错误,请确保你的CUDA环境配置正确,并且安装了与Candle兼容的CUDA版本。你可以查看candle-book/src/error_manage.md获取更多错误处理信息。
-
性能问题:如果训练速度较慢,可以尝试减小批次大小(BSIZE)或使用更小的模型。另外,确保你启用了适当的优化选项,如MKL或CUDA。
-
过拟合问题:如果模型在训练集上表现良好,但在测试集上表现不佳,可能是过拟合导致的。你可以增加dropout率、添加正则化项或增加训练数据量。
进阶之路:下一步学什么
恭喜你成功构建并训练了你的第一个基于Candle的机器学习模型!这只是Candle之旅的开始,以下是一些进阶学习的方向:
尝试其他模型
Candle提供了丰富的示例模型,包括各种经典的神经网络和最新的深度学习模型。你可以尝试运行这些示例,了解不同模型的结构和应用场景:
- LLaMA:大型语言模型
- Stable Diffusion:文本生成图像模型
- YOLO:目标检测模型
深入学习Candle核心概念
要深入理解Candle,建议学习以下核心概念:
- 张量(Tensor):Candle中的基本数据结构,类似于PyTorch的Tensor
- 计算设备(Device):支持CPU、CUDA等多种计算设备
- 神经网络层(Layer):如卷积层、全连接层等
- 优化器(Optimizer):如SGD、Adam等
- 自动求导(Autograd):自动计算梯度的机制
你可以通过Candle官方文档了解更多细节。
参与社区贡献
Candle是一个开源项目,欢迎你参与社区贡献。你可以通过以下方式贡献:
- 报告bug或提出功能建议
- 提交代码修复或新功能实现
- 改进文档或示例代码
如果你想了解如何贡献,可以查看candle-book/CONTRIBUTING.md。
总结与行动步骤
本文介绍了如何使用Candle构建和训练一个简单的手写数字识别模型。我们从Candle的安装配置开始,逐步学习了数据集加载、模型定义、训练循环实现以及模型评估等关键步骤。通过本文的学习,你应该已经掌握了Candle的基本使用方法,并能够构建简单的机器学习模型。
现在,是时候动手实践了!你可以尝试修改模型结构、调整超参数,或者使用其他数据集进行训练,看看能否进一步提高模型性能。如果你遇到问题,可以查阅Candle文档或在社区寻求帮助。
最后,如果你觉得本文对你有帮助,请点赞收藏,并关注作者获取更多Candle进阶教程。下期我们将揭秘如何将训练好的模型部署到WebAssembly,敬请期待!
祝你在Candle的机器学习之旅中取得更多成果!
GLM-5智谱 AI 正式发布 GLM-5,旨在应对复杂系统工程和长时域智能体任务。Jinja00
GLM-5-w4a8GLM-5-w4a8基于混合专家架构,专为复杂系统工程与长周期智能体任务设计。支持单/多节点部署,适配Atlas 800T A3,采用w4a8量化技术,结合vLLM推理优化,高效平衡性能与精度,助力智能应用开发Jinja00
请把这个活动推给顶尖程序员😎本次活动专为懂行的顶尖程序员量身打造,聚焦AtomGit首发开源模型的实际应用与深度测评,拒绝大众化浅层体验,邀请具备扎实技术功底、开源经验或模型测评能力的顶尖开发者,深度参与模型体验、性能测评,通过发布技术帖子、提交测评报告、上传实践项目成果等形式,挖掘模型核心价值,共建AtomGit开源模型生态,彰显顶尖程序员的技术洞察力与实践能力。00
Kimi-K2.5Kimi K2.5 是一款开源的原生多模态智能体模型,它在 Kimi-K2-Base 的基础上,通过对约 15 万亿混合视觉和文本 tokens 进行持续预训练构建而成。该模型将视觉与语言理解、高级智能体能力、即时模式与思考模式,以及对话式与智能体范式无缝融合。Python00
MiniMax-M2.5MiniMax-M2.5开源模型,经数十万复杂环境强化训练,在代码生成、工具调用、办公自动化等经济价值任务中表现卓越。SWE-Bench Verified得分80.2%,Multi-SWE-Bench达51.3%,BrowseComp获76.3%。推理速度比M2.1快37%,与Claude Opus 4.6相当,每小时仅需0.3-1美元,成本仅为同类模型1/10-1/20,为智能应用开发提供高效经济选择。【此简介由AI生成】Python00
Qwen3.5Qwen3.5 昇腾 vLLM 部署教程。Qwen3.5 是 Qwen 系列最新的旗舰多模态模型,采用 MoE(混合专家)架构,在保持强大模型能力的同时显著降低了推理成本。00- RRing-2.5-1TRing-2.5-1T:全球首个基于混合线性注意力架构的开源万亿参数思考模型。Python00