Tiny-Universe中的TinyLLM:手搓轻量级大语言模型
还在为训练大语言模型需要昂贵的硬件资源而烦恼吗?还在为复杂的模型架构和训练流程而头疼吗?本文为你带来Tiny-Universe项目中的TinyLLM模块,一个仅需2GB显存就能训练的大语言模型实现方案!
通过阅读本文,你将获得:
- 🚀 从零开始构建轻量级大语言模型的完整流程
- 🔧 仅使用Numpy和PyTorch的简洁实现方案
- 📊 理解Transformer核心组件的工作原理
- 🎯 掌握模型训练、推理和文本生成的实战技巧
- 💡 学习如何优化模型以适应资源受限环境
项目概述与技术架构
TinyLLM是Tiny-Universe项目中的一个核心模块,旨在实现一个简单但功能完整的大语言模型。该项目采用Decoder-only的Transformer架构,与LLaMA2结构相同,但经过精心优化,使得训练过程仅需2GB显存,训练时间仅需数小时。
核心技术栈
| 技术组件 | 版本/选择 | 作用 |
|---|---|---|
| PyTorch | CUDA版本 | 深度学习框架 |
| SentencePiece | 最新版 | 分词器训练 |
| NumPy | 最新版 | 数值计算 |
| TinyStories数据集 | - | 训练数据 |
模型架构概览
graph TD
A[输入文本] --> B[Tokenizer编码]
B --> C[词嵌入层]
C --> D[Decoder Layers]
D --> E[RMSNorm归一化]
E --> F[输出层]
F --> G[生成文本]
subgraph Transformer核心
D --> H[多头注意力]
D --> I[前馈网络]
H --> J[旋转位置编码]
I --> K[SwiGLU激活]
end
四步构建你的大语言模型
第一步:训练自定义Tokenizer
Tokenizer是大语言模型的基础设施,负责将文本转换为模型可理解的数字序列。TinyLLM使用SentencePiece库训练BPE(Byte-Pair Encoding)分词器。
import sentencepiece as spm
from tokenizer import Tokenizer
# 训练Tokenizer
spm.SentencePieceTrainer.train(
input="tiny.txt",
model_prefix="tok4096",
model_type="bpe",
vocab_size=4096,
character_coverage=1.0,
split_digits=True
)
# 使用Tokenizer
tokenizer = Tokenizer("tok4096.model")
text = "Hello, world!"
encoded = tokenizer.encode(text, bos=True, eos=True)
decoded = tokenizer.decode(encoded)
print(f"编码结果: {encoded}")
print(f"解码结果: {decoded}")
关键配置参数说明:
| 参数 | 值 | 说明 |
|---|---|---|
| vocab_size | 4096 | 词汇表大小,相比LLaMA2的32K大幅减少 |
| model_type | bpe | 使用Byte-Pair Encoding算法 |
| split_digits | True | 拆分数字,提升处理数值能力 |
| character_coverage | 1.0 | 覆盖所有字符,包括罕见字符 |
第二步:数据预处理与加载
数据预处理是将原始文本转换为模型训练所需格式的关键步骤。TinyLLM采用高效的内存映射方式加载数据,减少内存占用。
import numpy as np
import torch
from torch.utils.data import IterableDataset
class PretokDataset(IterableDataset):
def __init__(self, split, max_seq_len, vocab_size, vocab_source):
self.split = split
self.max_seq_len = max_seq_len
self.vocab_size = vocab_size
self.vocab_source = vocab_source
def __iter__(self):
# 使用内存映射读取预处理数据
m = np.memmap("data.bin", dtype=np.uint16, mode="r")
num_batches = len(m) // self.max_seq_len
for ix in range(num_batches):
start = ix * self.max_seq_len
end = start + self.max_seq_len + 1
chunk = torch.from_numpy(m[start:end].astype(np.int64))
x = chunk[:-1] # 输入序列
y = chunk[1:] # 目标序列
yield x, y
数据处理流程:
sequenceDiagram
participant User
participant Preprocess
participant Dataset
participant Model
User->>Preprocess: 原始文本数据
Preprocess->>Preprocess: Tokenizer编码
Preprocess->>Preprocess: 序列化存储
Preprocess->>Dataset: 二进制数据文件
Dataset->>Model: 批量训练数据
Model->>Model: 前向传播计算
第三步:模型架构深度解析
TinyLLM采用标准的Transformer Decoder架构,包含以下核心组件:
1. RMSNorm归一化层
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
RMSNorm相比LayerNorm计算更高效,适合资源受限环境。
2. 旋转位置编码(RoPE)
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
return torch.cos(freqs), torch.sin(freqs)
RoPE为模型提供位置信息,使模型能够理解token的相对位置关系。
3. 多头注意力机制
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.head_dim = args.dim // args.n_heads
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
4. 前馈网络(SwiGLU激活)
class MLP(nn.Module):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
第四步:模型训练与超参数配置
TinyLLM提供了详细的超参数配置,用户可以根据硬件条件进行调整:
# 模型超参数配置
model_args = {
"dim": 288, # 模型维度
"n_layers": 6, # Transformer层数
"n_heads": 6, # 注意力头数
"vocab_size": 4096, # 词汇表大小
"max_seq_len": 256, # 最大序列长度
"dropout": 0.0, # Dropout概率
}
# 训练超参数
train_args = {
"batch_size": 8,
"learning_rate": 5e-4,
"max_iters": 100000,
"weight_decay": 1e-1,
"warmup_iters": 1000,
}
训练资源需求对比:
| 模型 | 参数量 | 显存需求 | 训练时间 | 硬件要求 |
|---|---|---|---|---|
| TinyLLM | ~15M | 2GB | 数小时 | 单卡GPU |
| LLaMA2-7B | 7B | 80GB+ | 数天 | 多卡集群 |
| GPT-3 | 175B | 数千GB | 数周 | 超算中心 |
文本生成与推理
训练完成后,可以使用训练好的模型进行文本生成:
from model import Transformer
from tokenizer import Tokenizer
class TextGenerator:
def __init__(self, checkpoint_path, tokenizer_path):
self.model = Transformer.load_from_checkpoint(checkpoint_path)
self.tokenizer = Tokenizer(tokenizer_path)
self.model.eval()
def generate(self, prompt, max_new_tokens=100, temperature=0.8):
input_ids = self.tokenizer.encode(prompt, bos=True, eos=False)
with torch.no_grad():
output_ids = self.model.generate(
input_ids,
max_new_tokens=max_new_tokens,
temperature=temperature
)
return self.tokenizer.decode(output_ids[0])
# 使用示例
generator = TextGenerator("output/ckpt.pt", "tok4096.model")
result = generator.generate("Once upon a time", max_new_tokens=200)
print(result)
生成效果示例:
One day, Lily met a Shoggoth and her mom said, "Some people borrowed, looking for seeds." Lily wanted to help, so she decided to help. They walked to the store and bought seeds for the village...
性能优化技巧
1. 内存优化策略
# 使用混合精度训练
from torch.cuda.amp import autocast
with autocast():
outputs = model(inputs)
loss = criterion(outputs, targets)
# 梯度累积
for i, (inputs, targets) in enumerate(dataloader):
outputs = model(inputs)
loss = criterion(outputs, targets) / accumulation_steps
loss.backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
2. 计算效率提升
graph LR
A[原始计算] --> B[Flash Attention]
A --> C[混合精度]
A --> D[梯度检查点]
B --> E[速度提升2-3倍]
C --> F[内存减少50%]
D --> G[支持更大模型]
实战应用场景
教育领域:编程助手
# 生成代码注释
prompt = "# Python function to calculate fibonacci sequence\n"
generated_code = generator.generate(prompt, temperature=0.7)
print(generated_code)
创意写作:故事生成
# 生成童话故事
prompt = "In a magical forest, there lived a tiny dragon who"
story = generator.generate(prompt, max_new_tokens=300)
print(story)
技术文档:代码解释
# 解释复杂代码
prompt = "Explain this Python code:\nimport math\ndef calculate_circle_area(radius):\n return math.pi * radius ** 2\n\nExplanation:"
explanation = generator.generate(prompt, temperature=0.5)
常见问题与解决方案
Q1: 训练过程中显存不足怎么办?
A: 减小batch_size或max_seq_len,启用梯度累积,使用混合精度训练。
Q2: 生成的文本质量不高如何改进?
A: 调整temperature参数(0.5-0.8效果较好),增加训练数据量,延长训练时间。
Q3: 如何扩展模型支持中文?
A: 使用中文语料训练Tokenizer,调整vocab_size到8000-12000,使用中文预训练数据。
总结与展望
TinyLLM项目展示了如何用有限的资源构建功能完整的大语言模型。通过精心设计的架构和优化策略,该项目实现了:
- ✅ 低资源需求:仅需2GB显存即可训练
- ✅ 完整流程:从Tokenizer训练到文本生成的全流程
- ✅ 教育价值:深入理解Transformer工作原理
- ✅ 可扩展性:代码结构清晰,易于修改和扩展
未来发展方向包括支持多模态输入、优化推理速度、扩展多语言支持等。TinyLLM为研究者和开发者提供了一个理想的起点,让更多人能够接触和理解大语言模型的技术本质。
无论你是初学者希望入门深度学习,还是研究者需要快速原型验证,TinyLLM都是一个值得尝试的优秀项目。开始你的大语言模型之旅吧!
温馨提示:记得给项目点个Star⭐,如果有问题欢迎在项目中提出Issue,社区会及时为你解答!
GLM-5.1GLM-5.1是智谱迄今最智能的旗舰模型,也是目前全球最强的开源模型。GLM-5.1大大提高了代码能力,在完成长程任务方面提升尤为显著。和此前分钟级交互的模型不同,它能够在一次任务中独立、持续工作超过8小时,期间自主规划、执行、自我进化,最终交付完整的工程级成果。Jinja00
MiniMax-M2.7MiniMax-M2.7 是我们首个深度参与自身进化过程的模型。M2.7 具备构建复杂智能体应用框架的能力,能够借助智能体团队、复杂技能以及动态工具搜索,完成高度精细的生产力任务。Python00- QQwen3.5-397B-A17BQwen3.5 实现了重大飞跃,整合了多模态学习、架构效率、强化学习规模以及全球可访问性等方面的突破性进展,旨在为开发者和企业赋予前所未有的能力与效率。Jinja00
HY-Embodied-0.5这是一套专为现实世界具身智能打造的基础模型。该系列模型采用创新的混合Transformer(Mixture-of-Transformers, MoT) 架构,通过潜在令牌实现模态特异性计算,显著提升了细粒度感知能力。Jinja00
LongCat-AudioDiT-1BLongCat-AudioDiT 是一款基于扩散模型的文本转语音(TTS)模型,代表了当前该领域的最高水平(SOTA),它直接在波形潜空间中进行操作。00
ERNIE-ImageERNIE-Image 是由百度 ERNIE-Image 团队开发的开源文本到图像生成模型。它基于单流扩散 Transformer(DiT)构建,并配备了轻量级的提示增强器,可将用户的简短输入扩展为更丰富的结构化描述。凭借仅 80 亿的 DiT 参数,它在开源文本到图像模型中达到了最先进的性能。该模型的设计不仅追求强大的视觉质量,还注重实际生成场景中的可控性,在这些场景中,准确的内容呈现与美观同等重要。特别是,ERNIE-Image 在复杂指令遵循、文本渲染和结构化图像生成方面表现出色,使其非常适合商业海报、漫画、多格布局以及其他需要兼具视觉质量和精确控制的内容创作任务。它还支持广泛的视觉风格,包括写实摄影、设计导向图像以及更多风格化的美学输出。Jinja00