首页
/ Transformer Reinforcement Learning X:大规模强化学习框架

Transformer Reinforcement Learning X:大规模强化学习框架

2024-09-20 12:27:17作者:蔡丛锟

项目介绍

Transformer Reinforcement Learning X (trlX) 是一个专为大规模语言模型微调而设计的分布式训练框架。它利用强化学习技术,通过提供的奖励函数或奖励标注数据集来优化语言模型。trlX 支持对高达 20B 参数的模型进行微调,如 facebook/opt-6.7bEleutherAI/gpt-neox-20bgoogle/flan-t5-xxl。对于超过 20B 参数的模型,trlX 提供了基于 NVIDIA NeMo 的训练器,利用高效的并行技术进行扩展。

项目技术分析

trlX 的核心技术包括:

  1. 强化学习算法:目前支持 Proximal Policy Optimization (PPO) 和 Implicit Language Q-Learning (ILQL) 两种算法。
  2. 分布式训练:通过 Hugging Face 的 Accelerate 和 NVIDIA 的 NeMo 框架,trlX 能够高效地进行分布式训练。
  3. 模型支持:支持多种大型语言模型,包括 GPT、T5 等。
  4. 灵活的训练配置:用户可以通过配置文件自定义训练参数,如批量大小、序列长度等。

项目及技术应用场景

trlX 适用于以下场景:

  1. 对话系统优化:通过强化学习优化对话生成模型,提升对话质量和用户满意度。
  2. 文本生成任务:如新闻生成、故事创作等,通过奖励函数优化生成文本的质量。
  3. 代码生成:优化代码生成模型,提高代码的正确性和可读性。
  4. 数据增强:通过强化学习生成高质量的训练数据,提升模型的泛化能力。

项目特点

  1. 高效性:利用分布式训练技术,能够高效地处理大规模语言模型。
  2. 灵活性:支持多种强化学习算法和模型,用户可以根据需求选择合适的配置。
  3. 易用性:提供了详细的文档和示例代码,方便用户快速上手。
  4. 扩展性:支持多种模型和训练框架,能够适应不同的应用场景。

总结

trlX 是一个功能强大且灵活的强化学习框架,特别适合大规模语言模型的微调任务。无论你是研究者还是开发者,trlX 都能为你提供高效的解决方案。快来尝试吧!

📖 文档

🧀 CHEESE:用于强化学习应用的人类标注数据收集库。

安装

git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu118
pip install -e .

示例

更多使用示例请参考 examples。你也可以尝试以下 Colab 笔记本:

描述 链接
Simulacra (GPT2, ILQL) Open In Colab
Sentiment (GPT2, ILQL) Open In Colab

最新运行示例请查看 Weights & Biases

如何训练

你可以使用奖励函数或奖励标注数据集来训练模型。

使用奖励函数

trainer = trlx.train('gpt2', reward_fn=lambda samples, **kwargs: [sample.count('cats') for sample in samples])

使用奖励标注数据集

trainer = trlx.train('EleutherAI/gpt-j-6B', samples=['dolphins', 'geese'], rewards=[1.0, 100.0])

使用提示-完成数据集

trainer = trlx.train('gpt2', samples=[['Question: 1 + 2 Answer:', '3'], ['Question: Solve this equation: ∀n>0, s=2, sum(n ** -s). Answer:', '(pi ** 2)/ 6']])

配置超参数

from trlx.data.default_configs import default_ppo_config

config = default_ppo_config()
config.model.model_path = 'EleutherAI/gpt-neox-20b'
config.tokenizer.tokenizer_path = 'EleutherAI/gpt-neox-20b'
config.train.seq_length = 2048

trainer = trlx.train(config=config, reward_fn=lambda samples, **kwargs: [len(sample) for sample in samples])

保存模型

trainer.save_pretrained('/path/to/output/folder/')

贡献

欢迎贡献代码和提出建议!请参考 贡献指南文档

引用

@inproceedings{havrilla-etal-2023-trlx,
    title = "trl{X}: A Framework for Large Scale Reinforcement Learning from Human Feedback",
    author = "Havrilla, Alexander  and
      Zhuravinskyi, Maksym  and
      Phung, Duy  and
      Tiwari, Aman  and
      Tow, Jonathan  and
      Biderman, Stella  and
      Anthony, Quentin  and
      Castricato, Louis",
    booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing",
    month = dec,
    year = "2023",
    address = "Singapore",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2023.emnlp-main.530",
    doi = "10.18653/v1/2023.emnlp-main.530",
    pages = "8578--8595",
}

致谢

特别感谢 Leandro von Werra 对 trl 的贡献,该库最初启发了本项目的开发。

热门项目推荐
相关项目推荐

项目优选

收起
Python-100-DaysPython-100-Days
Python - 100天从新手到大师
Python
609
115
HarmonyOS-ExamplesHarmonyOS-Examples
本仓将收集和展示仓颉鸿蒙应用示例代码,欢迎大家投稿,在仓颉鸿蒙社区展现你的妙趣设计!
Cangjie
286
79
mdmd
✍ WeChat Markdown Editor | 一款高度简洁的微信 Markdown 编辑器:支持 Markdown 语法、色盘取色、多图上传、一键下载文档、自定义 CSS 样式、一键重置等特性
Vue
111
25
openHiTLSopenHiTLS
旨在打造算法先进、性能卓越、高效敏捷、安全可靠的密码套件,通过轻量级、可剪裁的软件技术架构满足各行业不同场景的多样化要求,让密码技术应用更简单,同时探索后量子等先进算法创新实践,构建密码前沿技术底座!
C
60
48
RuoYi-Cloud-Vue3RuoYi-Cloud-Vue3
🎉 基于Spring Boot、Spring Cloud & Alibaba、Vue3 & Vite、Element Plus的分布式前后端分离微服务架构权限管理系统
Vue
45
29
go-stockgo-stock
🦄🦄🦄AI赋能股票分析:自选股行情获取,成本盈亏展示,涨跌报警推送,市场整体/个股情绪分析,K线技术指标分析等。数据全部保留在本地。支持DeepSeek,OpenAI, Ollama,LMStudio,AnythingLLM,硅基流动,火山方舟,阿里云百炼等平台或模型。
Go
1
0
Cangjie-ExamplesCangjie-Examples
本仓将收集和展示高质量的仓颉示例代码,欢迎大家投稿,让全世界看到您的妙趣设计,也让更多人通过您的编码理解和喜爱仓颉语言。
Cangjie
205
57
MateChatMateChat
前端智能化场景解决方案UI库,轻松构建你的AI应用,我们将持续完善更新,欢迎你的使用与建议。 官网地址:https://matechat.gitcode.com
184
34
RuoYi-VueRuoYi-Vue
🎉 基于SpringBoot,Spring Security,JWT,Vue & Element 的前后端分离权限管理系统,同时提供了 Vue3 的版本
Java
182
44
frogfrog
这是一个人工生命试验项目,最终目标是创建“有自我意识表现”的模拟生命体。
Java
8
0