Tunix工具使用指南:从安装到模型部署全流程
Tunix(Tune-in-JAX)是一个基于JAX的大型语言模型后训练库,提供高效且可扩展的支持,包括监督微调、强化学习(RL)和知识蒸馏。本文将详细介绍Tunix的安装步骤、模型训练流程以及部署方法,帮助用户快速上手使用该工具。
1. 安装Tunix
1.1 环境准备
在安装Tunix之前,需要确保系统已安装Python环境。推荐使用Python 3.8及以上版本。此外,Tunix依赖JAX库,需要根据系统环境安装相应版本的JAX和JAXlib。
1.2 安装方式
Tunix提供多种安装方式,用户可以根据需求选择适合的方式:
1.2.1 从PyPI安装(推荐)
使用pip命令直接从PyPI安装Tunix:
pip install "google-tunix[prod]"
1.2.2 从源码安装(开发模式)
如果需要修改代码或使用最新开发版本,可以从源码安装:
git clone https://gitcode.com/GitHub_Trending/tu/tunix.git
cd tunix
pip install -e ".[dev]"
1.2.3 安装vLLM依赖(TPU支持)
对于需要在TPU上使用vLLM的用户,可以运行以下脚本安装相关依赖:
bash scripts/install_tunix_vllm_requirement.sh
该脚本会安装vLLM及其依赖,并配置TPU支持。
2. 模型训练流程
Tunix支持多种模型训练任务,包括监督微调(SFT)、强化学习(RL)和知识蒸馏。下面以强化学习和监督微调为例,介绍具体的训练流程。
2.1 强化学习训练(GRPO)
以Llama3.2-1B模型在GSM8K数据集上的训练为例,使用GRPO(Group Relative Policy Optimization)算法进行强化学习训练。
2.1.1 训练脚本
训练脚本位于examples/rl/gsm8k/run_llama3.2_1b.sh,主要参数设置如下:
python3 -m tunix.cli.grpo_main \
base_config.yaml \
reference_model_config.model_name="llama3.2-1b" \
reference_model_config.model_id="meta-llama/Llama-3.2-1B" \
reference_model_config.model_source="huggingface" \
reference_model_config.mesh.shape="(2,4)" \
reference_model_config.mesh.axis_names="('fsdp','tp')" \
actor_model_config.lora_config.rank=64 \
actor_model_config.lora_config.alpha=64.0 \
batch_size=1 \
num_batches=3738 \
num_train_epochs=1 \
rl_training_config.max_steps=3738 \
rollout_config.total_generation_steps=768 \
reward_functions="['tunix/cli/reward_fn/gsm8k.py']"
2.1.2 参数说明
reference_model_config:参考模型配置,包括模型名称、ID、来源等actor_model_config: Actor模型配置,设置LoRA参数mesh.shape和mesh.axis_names:设置模型并行策略batch_size和num_batches:设置训练批次大小和数量rollout_config:生成配置,设置生成步数等参数reward_functions:奖励函数配置
2.2 监督微调(SFT)
以Gemma3-4B模型在MTNT数据集上的监督微调和为例,介绍SFT训练流程。
2.2.1 训练脚本
训练脚本位于examples/sft/mtnt/run_gemma3_4b.sh,主要参数设置如下:
python3 -m tunix.cli.peft_main \
base_config.yaml \
model_name="gemma3-4b" \
model_id="gs://gemma-data/checkpoints/gemma3-4b-pt" \
model_source="gcs" \
tokenizer_path="gs://gemma-data/tokenizers/tokenizer_gemma3.model" \
dataset_name="mtnt/en-fr" \
optimizer_config.learning_rate=1e-5 \
training_config.max_steps=100 \
mesh.shape="(2,2)" \
mesh.axis_names="('fsdp','tp')"
2.2.2 参数说明
model_name和model_id:模型名称和IDdataset_name:数据集名称optimizer_config:优化器配置,设置学习率等参数training_config:训练配置,设置最大步数等参数mesh:模型并行配置
2.3 训练监控
Tunix支持使用TensorBoard监控训练过程。训练脚本中通过metrics_logging_options.log_dir参数指定日志目录,例如:
training_config.metrics_logging_options.log_dir="/tmp/tensorboard/full"
启动TensorBoard查看训练指标:
tensorboard --logdir=/tmp/tensorboard/full
3. 模型部署
3.1 模型导出
训练完成后,模型会保存到指定的 checkpoint 目录。可以通过以下方式导出模型:
from tunix.sft.checkpoint_manager import CheckpointManager
checkpoint_manager = CheckpointManager(
checkpoint_dir="/path/to/checkpoint",
model=model,
)
checkpoint_manager.save_checkpoint(step=1000)
3.2 模型加载与推理
使用Tunix加载训练好的模型进行推理:
from tunix.models.llama3.model import Llama3Model
from tunix.models.llama3.params import Llama3Params
params = Llama3Params.from_pretrained(
model_id="/path/to/checkpoint",
mesh_shape=(2, 4),
axis_names=("fsdp", "tp"),
)
model = Llama3Model(params)
inputs = tokenizer("Hello, world!", return_tensors="jax")
outputs = model.generate(**inputs, max_length=50)
print(tokenizer.decode(outputs[0]))
4. 总结
本文介绍了Tunix的安装步骤、模型训练流程和部署方法,涵盖了从环境准备到模型推理的全流程。Tunix提供了丰富的功能和灵活的配置选项,能够满足不同场景下的大语言模型后训练需求。
更多详细信息可以参考:
- 官方文档:docs/index.md
- 社区教程:README.md
- API参考:docs/api/api_sft.rst
通过本文的指南,用户可以快速上手Tunix工具,进行模型的微调和部署,加速大语言模型的后训练过程。
GLM-5智谱 AI 正式发布 GLM-5,旨在应对复杂系统工程和长时域智能体任务。Jinja00
GLM-5-w4a8GLM-5-w4a8基于混合专家架构,专为复杂系统工程与长周期智能体任务设计。支持单/多节点部署,适配Atlas 800T A3,采用w4a8量化技术,结合vLLM推理优化,高效平衡性能与精度,助力智能应用开发Jinja00
jiuwenclawJiuwenClaw 是一款基于openJiuwen开发的智能AI Agent,它能够将大语言模型的强大能力,通过你日常使用的各类通讯应用,直接延伸至你的指尖。Python0201- QQwen3.5-397B-A17BQwen3.5 实现了重大飞跃,整合了多模态学习、架构效率、强化学习规模以及全球可访问性等方面的突破性进展,旨在为开发者和企业赋予前所未有的能力与效率。Jinja00
AtomGit城市坐标计划AtomGit 城市坐标计划开启!让开源有坐标,让城市有星火。致力于与城市合伙人共同构建并长期运营一个健康、活跃的本地开发者生态。01
awesome-zig一个关于 Zig 优秀库及资源的协作列表。Makefile00