首页
/ Transformer Reinforcement Learning X:大规模强化学习框架

Transformer Reinforcement Learning X:大规模强化学习框架

2024-09-20 12:27:17作者:蔡丛锟
trlx
A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)

项目介绍

Transformer Reinforcement Learning X (trlX) 是一个专为大规模语言模型微调而设计的分布式训练框架。它利用强化学习技术,通过提供的奖励函数或奖励标注数据集来优化语言模型。trlX 支持对高达 20B 参数的模型进行微调,如 facebook/opt-6.7bEleutherAI/gpt-neox-20bgoogle/flan-t5-xxl。对于超过 20B 参数的模型,trlX 提供了基于 NVIDIA NeMo 的训练器,利用高效的并行技术进行扩展。

项目技术分析

trlX 的核心技术包括:

  1. 强化学习算法:目前支持 Proximal Policy Optimization (PPO) 和 Implicit Language Q-Learning (ILQL) 两种算法。
  2. 分布式训练:通过 Hugging Face 的 Accelerate 和 NVIDIA 的 NeMo 框架,trlX 能够高效地进行分布式训练。
  3. 模型支持:支持多种大型语言模型,包括 GPT、T5 等。
  4. 灵活的训练配置:用户可以通过配置文件自定义训练参数,如批量大小、序列长度等。

项目及技术应用场景

trlX 适用于以下场景:

  1. 对话系统优化:通过强化学习优化对话生成模型,提升对话质量和用户满意度。
  2. 文本生成任务:如新闻生成、故事创作等,通过奖励函数优化生成文本的质量。
  3. 代码生成:优化代码生成模型,提高代码的正确性和可读性。
  4. 数据增强:通过强化学习生成高质量的训练数据,提升模型的泛化能力。

项目特点

  1. 高效性:利用分布式训练技术,能够高效地处理大规模语言模型。
  2. 灵活性:支持多种强化学习算法和模型,用户可以根据需求选择合适的配置。
  3. 易用性:提供了详细的文档和示例代码,方便用户快速上手。
  4. 扩展性:支持多种模型和训练框架,能够适应不同的应用场景。

总结

trlX 是一个功能强大且灵活的强化学习框架,特别适合大规模语言模型的微调任务。无论你是研究者还是开发者,trlX 都能为你提供高效的解决方案。快来尝试吧!

📖 文档

🧀 CHEESE:用于强化学习应用的人类标注数据收集库。

安装

git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu118
pip install -e .

示例

更多使用示例请参考 examples。你也可以尝试以下 Colab 笔记本:

描述 链接
Simulacra (GPT2, ILQL) Open In Colab
Sentiment (GPT2, ILQL) Open In Colab

最新运行示例请查看 Weights & Biases

如何训练

你可以使用奖励函数或奖励标注数据集来训练模型。

使用奖励函数

trainer = trlx.train('gpt2', reward_fn=lambda samples, **kwargs: [sample.count('cats') for sample in samples])

使用奖励标注数据集

trainer = trlx.train('EleutherAI/gpt-j-6B', samples=['dolphins', 'geese'], rewards=[1.0, 100.0])

使用提示-完成数据集

trainer = trlx.train('gpt2', samples=[['Question: 1 + 2 Answer:', '3'], ['Question: Solve this equation: ∀n>0, s=2, sum(n ** -s). Answer:', '(pi ** 2)/ 6']])

配置超参数

from trlx.data.default_configs import default_ppo_config

config = default_ppo_config()
config.model.model_path = 'EleutherAI/gpt-neox-20b'
config.tokenizer.tokenizer_path = 'EleutherAI/gpt-neox-20b'
config.train.seq_length = 2048

trainer = trlx.train(config=config, reward_fn=lambda samples, **kwargs: [len(sample) for sample in samples])

保存模型

trainer.save_pretrained('/path/to/output/folder/')

贡献

欢迎贡献代码和提出建议!请参考 贡献指南文档

引用

@inproceedings{havrilla-etal-2023-trlx,
    title = "trl{X}: A Framework for Large Scale Reinforcement Learning from Human Feedback",
    author = "Havrilla, Alexander  and
      Zhuravinskyi, Maksym  and
      Phung, Duy  and
      Tiwari, Aman  and
      Tow, Jonathan  and
      Biderman, Stella  and
      Anthony, Quentin  and
      Castricato, Louis",
    booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing",
    month = dec,
    year = "2023",
    address = "Singapore",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2023.emnlp-main.530",
    doi = "10.18653/v1/2023.emnlp-main.530",
    pages = "8578--8595",
}

致谢

特别感谢 Leandro von Werra 对 trl 的贡献,该库最初启发了本项目的开发。

trlx
A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
热门项目推荐
相关项目推荐

项目优选

收起
CangjieCommunity
为仓颉编程语言开发者打造活跃、开放、高质量的社区环境
Markdown
669
0
RuoYi-Vue
🎉 基于SpringBoot,Spring Security,JWT,Vue & Element 的前后端分离权限管理系统,同时提供了 Vue3 的版本
Java
136
18
openHiTLS
旨在打造算法先进、性能卓越、高效敏捷、安全可靠的密码套件,通过轻量级、可剪裁的软件技术架构满足各行业不同场景的多样化要求,让密码技术应用更简单,同时探索后量子等先进算法创新实践,构建密码前沿技术底座!
C
10
4
redis-sdk
仓颉语言实现的Redis客户端SDK。已适配仓颉0.53.4 Beta版本。接口设计兼容jedis接口语义,支持RESP2和RESP3协议,支持发布订阅模式,支持哨兵模式和集群模式。
Cangjie
322
26
advanced-java
Advanced-Java是一个Java进阶教程,适合用于学习Java高级特性和编程技巧。特点:内容深入、实例丰富、适合进阶学习。
JavaScript
75.83 K
19.04 K
qwerty-learner
为键盘工作者设计的单词记忆与英语肌肉记忆锻炼软件 / Words learning and English muscle memory training software designed for keyboard workers
TSX
15.56 K
1.44 K
Jpom
🚀简而轻的低侵入式在线构建、自动部署、日常运维、项目监控软件
Java
1.41 K
292
Yi-Coder
Yi Coder 编程模型,小而强大的编程助手
HTML
30
5
easy-es
Elasticsearch 国内Top1 elasticsearch搜索引擎框架es ORM框架,索引全自动智能托管,如丝般顺滑,与Mybatis-plus一致的API,屏蔽语言差异,开发者只需要会MySQL语法即可完成对Es的相关操作,零额外学习成本.底层采用RestHighLevelClient,兼具低码,易用,易拓展等特性,支持es独有的高亮,权重,分词,Geo,嵌套,父子类型等功能...
Java
1.42 K
231
taro
开放式跨端跨框架解决方案,支持使用 React/Vue/Nerv 等框架来开发微信/京东/百度/支付宝/字节跳动/ QQ 小程序/H5/React Native 等应用。 https://taro.zone/
TypeScript
35.34 K
4.77 K