首页
/ TinyLlama项目中的微调实践指南

TinyLlama项目中的微调实践指南

2025-05-27 07:24:41作者:盛欣凯Ernestine

数据集预处理关键点

在TinyLlama项目中进行模型微调时,数据集的组织格式至关重要。根据项目代码分析,数据集需要包含两个关键字段:inputoutput。这两个字段分别代表了模型的输入文本和期望的输出文本。

正确的数据集格式示例如下:

{
    'input': '用户输入的文本内容',
    'output': '模型期望生成的回答内容'
}

数据预处理实现

预处理函数应当将原始数据转换为上述格式。假设原始数据是以制表符分隔的文本文件,预处理可以这样实现:

from datasets import load_dataset

def preprocess_function(examples):
    parts = examples["text"].split("\t")
    return {
        "input": parts[0],  # 第一部分作为输入
        "output": parts[1]  # 第二部分作为期望输出
    }

# 加载并预处理数据集
dataset = load_dataset('text', data_files='your_data.txt')['train']
dataset = dataset.map(preprocess_function).remove_columns('text')

训练器配置详解

微调TinyLlama需要使用特定的训练器配置,核心在于正确设置数据整理器(DataCollator)。以下是关键配置步骤:

  1. 训练参数设置
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./finetune_results",  # 输出目录
    num_train_epochs=1,              # 训练轮数
    per_device_train_batch_size=16,  # 训练批次大小
    per_device_eval_batch_size=64,   # 评估批次大小
    weight_decay=0.005,              # 权重衰减系数
    logging_dir="./logs",            # 日志目录
    remove_unused_columns=False      # 保留未使用列
)
  1. 数据整理器实现: 数据整理器负责将文本数据转换为模型可处理的张量格式,核心功能包括:
  • 添加特殊标记(BOS/EOS)
  • 控制输入输出长度
  • 处理填充和注意力掩码
from torch.nn.utils.rnn import pad_sequence
import torch
from dataclasses import dataclass
from typing import Dict, Sequence
import copy

@dataclass
class DataCollatorForCausalLM:
    tokenizer: transformers.PreTrainedTokenizer
    source_max_len: int = 128    # 输入最大长度
    target_max_len: int = 128    # 输出最大长度
    train_on_source: bool = True # 是否在源文本上训练
    predict_with_generate: bool = False  # 是否生成式预测

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # 添加特殊标记并分词
        sources = [f"{self.tokenizer.bos_token}{x['input']}" for x in instances]
        targets = [f"{x['output']}{self.tokenizer.eos_token}" for x in instances]
        
        # 分词处理
        tokenized_sources = self.tokenizer(
            sources, max_length=self.source_max_len, 
            truncation=True, add_special_tokens=False
        )
        tokenized_targets = self.tokenizer(
            targets, max_length=self.target_max_len,
            truncation=True, add_special_tokens=False
        )
        
        # 构建模型输入和标签
        input_ids, labels = [], []
        for src, tgt in zip(tokenized_sources['input_ids'], tokenized_targets['input_ids']):
            input_ids.append(torch.tensor(src + tgt))
            if not self.train_on_source:
                labels.append(torch.tensor(
                    [IGNORE_INDEX]*len(src) + tgt.copy()
                ))
            else:
                labels.append(torch.tensor(src + tgt.copy()))
        
        # 填充处理
        input_ids = pad_sequence(input_ids, batch_first=True, 
                                padding_value=self.tokenizer.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, 
                             padding_value=IGNORE_INDEX)
        
        return {
            'input_ids': input_ids,
            'attention_mask': input_ids.ne(self.tokenizer.pad_token_id),
            'labels': labels
        }

完整训练流程

结合上述组件,完整的微调流程如下:

from transformers import Seq2SeqTrainer

# 初始化数据整理器
data_collator = DataCollatorForCausalLM(
    tokenizer=tokenizer,
    source_max_len=128,
    target_max_len=128,
    train_on_source=True
)

# 创建训练器
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator
)

# 开始训练
trainer.train()

常见问题解决方案

  1. 数据类型错误:确保预处理后的数据是文本格式,而非其他类型对象。

  2. 索引越界错误:检查数据分割逻辑是否正确,确保每行数据都包含input和output两部分。

  3. 特殊标记处理:根据使用的tokenizer确认BOS(开始)和EOS(结束)标记是否正确添加。

  4. 长度控制:合理设置source_max_len和target_max_len参数,避免输入过长被截断或过短浪费资源。

通过以上步骤,开发者可以成功地在TinyLlama项目上实现模型的微调。关键在于理解数据流向:从原始文本→预处理→分词→训练这一完整流程,每个环节都需要正确配置才能保证训练顺利进行。

登录后查看全文
热门项目推荐

热门内容推荐

最新内容推荐

项目优选

收起
openGauss-serveropenGauss-server
openGauss kernel ~ openGauss is an open source relational database management system
C++
137
188
RuoYi-Vue3RuoYi-Vue3
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
885
527
openHiTLSopenHiTLS
旨在打造算法先进、性能卓越、高效敏捷、安全可靠的密码套件,通过轻量级、可剪裁的软件技术架构满足各行业不同场景的多样化要求,让密码技术应用更简单,同时探索后量子等先进算法创新实践,构建密码前沿技术底座!
C
368
382
ohos_react_nativeohos_react_native
React Native鸿蒙化仓库
C++
184
265
kernelkernel
deepin linux kernel
C
22
5
MateChatMateChat
前端智能化场景解决方案UI库,轻松构建你的AI应用,我们将持续完善更新,欢迎你的使用与建议。 官网地址:https://matechat.gitcode.com
735
105
note-gennote-gen
一款跨平台的 Markdown AI 笔记软件,致力于使用 AI 建立记录和写作的桥梁。
TSX
84
4
CangjieCommunityCangjieCommunity
为仓颉编程语言开发者打造活跃、开放、高质量的社区环境
Markdown
1.08 K
0
harmony-utilsharmony-utils
harmony-utils 一款功能丰富且极易上手的HarmonyOS工具库,借助众多实用工具类,致力于助力开发者迅速构建鸿蒙应用。其封装的工具涵盖了APP、设备、屏幕、授权、通知、线程间通信、弹框、吐司、生物认证、用户首选项、拍照、相册、扫码、文件、日志,异常捕获、字符、字符串、数字、集合、日期、随机、base64、加密、解密、JSON等一系列的功能和操作,能够满足各种不同的开发需求。
ArkTS
54
1
HarmonyOS-ExamplesHarmonyOS-Examples
本仓将收集和展示仓颉鸿蒙应用示例代码,欢迎大家投稿,在仓颉鸿蒙社区展现你的妙趣设计!
Cangjie
400
376