首页
/ Tunix工具使用指南:从安装到模型部署全流程

Tunix工具使用指南:从安装到模型部署全流程

2026-02-05 05:41:21作者:宣利权Counsellor

Tunix(Tune-in-JAX)是一个基于JAX的大型语言模型后训练库,提供高效且可扩展的支持,包括监督微调、强化学习(RL)和知识蒸馏。本文将详细介绍Tunix的安装步骤、模型训练流程以及部署方法,帮助用户快速上手使用该工具。

1. 安装Tunix

1.1 环境准备

在安装Tunix之前,需要确保系统已安装Python环境。推荐使用Python 3.8及以上版本。此外,Tunix依赖JAX库,需要根据系统环境安装相应版本的JAX和JAXlib。

1.2 安装方式

Tunix提供多种安装方式,用户可以根据需求选择适合的方式:

1.2.1 从PyPI安装(推荐)

使用pip命令直接从PyPI安装Tunix:

pip install "google-tunix[prod]"

1.2.2 从源码安装(开发模式)

如果需要修改代码或使用最新开发版本,可以从源码安装:

git clone https://gitcode.com/GitHub_Trending/tu/tunix.git
cd tunix
pip install -e ".[dev]"

1.2.3 安装vLLM依赖(TPU支持)

对于需要在TPU上使用vLLM的用户,可以运行以下脚本安装相关依赖:

bash scripts/install_tunix_vllm_requirement.sh

该脚本会安装vLLM及其依赖,并配置TPU支持。

2. 模型训练流程

Tunix支持多种模型训练任务,包括监督微调(SFT)、强化学习(RL)和知识蒸馏。下面以强化学习和监督微调为例,介绍具体的训练流程。

2.1 强化学习训练(GRPO)

以Llama3.2-1B模型在GSM8K数据集上的训练为例,使用GRPO(Group Relative Policy Optimization)算法进行强化学习训练。

2.1.1 训练脚本

训练脚本位于examples/rl/gsm8k/run_llama3.2_1b.sh,主要参数设置如下:

python3 -m tunix.cli.grpo_main \
  base_config.yaml \
  reference_model_config.model_name="llama3.2-1b" \
  reference_model_config.model_id="meta-llama/Llama-3.2-1B" \
  reference_model_config.model_source="huggingface" \
  reference_model_config.mesh.shape="(2,4)" \
  reference_model_config.mesh.axis_names="('fsdp','tp')" \
  actor_model_config.lora_config.rank=64 \
  actor_model_config.lora_config.alpha=64.0 \
  batch_size=1 \
  num_batches=3738 \
  num_train_epochs=1 \
  rl_training_config.max_steps=3738 \
  rollout_config.total_generation_steps=768 \
  reward_functions="['tunix/cli/reward_fn/gsm8k.py']"

2.1.2 参数说明

  • reference_model_config:参考模型配置,包括模型名称、ID、来源等
  • actor_model_config: Actor模型配置,设置LoRA参数
  • mesh.shapemesh.axis_names:设置模型并行策略
  • batch_sizenum_batches:设置训练批次大小和数量
  • rollout_config:生成配置,设置生成步数等参数
  • reward_functions:奖励函数配置

2.2 监督微调(SFT)

以Gemma3-4B模型在MTNT数据集上的监督微调和为例,介绍SFT训练流程。

2.2.1 训练脚本

训练脚本位于examples/sft/mtnt/run_gemma3_4b.sh,主要参数设置如下:

python3 -m tunix.cli.peft_main \
  base_config.yaml \
  model_name="gemma3-4b" \
  model_id="gs://gemma-data/checkpoints/gemma3-4b-pt" \
  model_source="gcs" \
  tokenizer_path="gs://gemma-data/tokenizers/tokenizer_gemma3.model" \
  dataset_name="mtnt/en-fr" \
  optimizer_config.learning_rate=1e-5 \
  training_config.max_steps=100 \
  mesh.shape="(2,2)" \
  mesh.axis_names="('fsdp','tp')"

2.2.2 参数说明

  • model_namemodel_id:模型名称和ID
  • dataset_name:数据集名称
  • optimizer_config:优化器配置,设置学习率等参数
  • training_config:训练配置,设置最大步数等参数
  • mesh:模型并行配置

2.3 训练监控

Tunix支持使用TensorBoard监控训练过程。训练脚本中通过metrics_logging_options.log_dir参数指定日志目录,例如:

training_config.metrics_logging_options.log_dir="/tmp/tensorboard/full"

启动TensorBoard查看训练指标:

tensorboard --logdir=/tmp/tensorboard/full

3. 模型部署

3.1 模型导出

训练完成后,模型会保存到指定的 checkpoint 目录。可以通过以下方式导出模型:

from tunix.sft.checkpoint_manager import CheckpointManager

checkpoint_manager = CheckpointManager(
    checkpoint_dir="/path/to/checkpoint",
    model=model,
)
checkpoint_manager.save_checkpoint(step=1000)

3.2 模型加载与推理

使用Tunix加载训练好的模型进行推理:

from tunix.models.llama3.model import Llama3Model
from tunix.models.llama3.params import Llama3Params

params = Llama3Params.from_pretrained(
    model_id="/path/to/checkpoint",
    mesh_shape=(2, 4),
    axis_names=("fsdp", "tp"),
)
model = Llama3Model(params)

inputs = tokenizer("Hello, world!", return_tensors="jax")
outputs = model.generate(**inputs, max_length=50)
print(tokenizer.decode(outputs[0]))

4. 总结

本文介绍了Tunix的安装步骤、模型训练流程和部署方法,涵盖了从环境准备到模型推理的全流程。Tunix提供了丰富的功能和灵活的配置选项,能够满足不同场景下的大语言模型后训练需求。

更多详细信息可以参考:

通过本文的指南,用户可以快速上手Tunix工具,进行模型的微调和部署,加速大语言模型的后训练过程。

登录后查看全文
热门项目推荐
相关项目推荐