首页
/ TinyLlama项目中的微调实践指南

TinyLlama项目中的微调实践指南

2025-05-27 23:33:55作者:盛欣凯Ernestine

数据集预处理关键点

在TinyLlama项目中进行模型微调时,数据集的组织格式至关重要。根据项目代码分析,数据集需要包含两个关键字段:inputoutput。这两个字段分别代表了模型的输入文本和期望的输出文本。

正确的数据集格式示例如下:

{
    'input': '用户输入的文本内容',
    'output': '模型期望生成的回答内容'
}

数据预处理实现

预处理函数应当将原始数据转换为上述格式。假设原始数据是以制表符分隔的文本文件,预处理可以这样实现:

from datasets import load_dataset

def preprocess_function(examples):
    parts = examples["text"].split("\t")
    return {
        "input": parts[0],  # 第一部分作为输入
        "output": parts[1]  # 第二部分作为期望输出
    }

# 加载并预处理数据集
dataset = load_dataset('text', data_files='your_data.txt')['train']
dataset = dataset.map(preprocess_function).remove_columns('text')

训练器配置详解

微调TinyLlama需要使用特定的训练器配置,核心在于正确设置数据整理器(DataCollator)。以下是关键配置步骤:

  1. 训练参数设置
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./finetune_results",  # 输出目录
    num_train_epochs=1,              # 训练轮数
    per_device_train_batch_size=16,  # 训练批次大小
    per_device_eval_batch_size=64,   # 评估批次大小
    weight_decay=0.005,              # 权重衰减系数
    logging_dir="./logs",            # 日志目录
    remove_unused_columns=False      # 保留未使用列
)
  1. 数据整理器实现: 数据整理器负责将文本数据转换为模型可处理的张量格式,核心功能包括:
  • 添加特殊标记(BOS/EOS)
  • 控制输入输出长度
  • 处理填充和注意力掩码
from torch.nn.utils.rnn import pad_sequence
import torch
from dataclasses import dataclass
from typing import Dict, Sequence
import copy

@dataclass
class DataCollatorForCausalLM:
    tokenizer: transformers.PreTrainedTokenizer
    source_max_len: int = 128    # 输入最大长度
    target_max_len: int = 128    # 输出最大长度
    train_on_source: bool = True # 是否在源文本上训练
    predict_with_generate: bool = False  # 是否生成式预测

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # 添加特殊标记并分词
        sources = [f"{self.tokenizer.bos_token}{x['input']}" for x in instances]
        targets = [f"{x['output']}{self.tokenizer.eos_token}" for x in instances]
        
        # 分词处理
        tokenized_sources = self.tokenizer(
            sources, max_length=self.source_max_len, 
            truncation=True, add_special_tokens=False
        )
        tokenized_targets = self.tokenizer(
            targets, max_length=self.target_max_len,
            truncation=True, add_special_tokens=False
        )
        
        # 构建模型输入和标签
        input_ids, labels = [], []
        for src, tgt in zip(tokenized_sources['input_ids'], tokenized_targets['input_ids']):
            input_ids.append(torch.tensor(src + tgt))
            if not self.train_on_source:
                labels.append(torch.tensor(
                    [IGNORE_INDEX]*len(src) + tgt.copy()
                ))
            else:
                labels.append(torch.tensor(src + tgt.copy()))
        
        # 填充处理
        input_ids = pad_sequence(input_ids, batch_first=True, 
                                padding_value=self.tokenizer.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, 
                             padding_value=IGNORE_INDEX)
        
        return {
            'input_ids': input_ids,
            'attention_mask': input_ids.ne(self.tokenizer.pad_token_id),
            'labels': labels
        }

完整训练流程

结合上述组件,完整的微调流程如下:

from transformers import Seq2SeqTrainer

# 初始化数据整理器
data_collator = DataCollatorForCausalLM(
    tokenizer=tokenizer,
    source_max_len=128,
    target_max_len=128,
    train_on_source=True
)

# 创建训练器
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator
)

# 开始训练
trainer.train()

常见问题解决方案

  1. 数据类型错误:确保预处理后的数据是文本格式,而非其他类型对象。

  2. 索引越界错误:检查数据分割逻辑是否正确,确保每行数据都包含input和output两部分。

  3. 特殊标记处理:根据使用的tokenizer确认BOS(开始)和EOS(结束)标记是否正确添加。

  4. 长度控制:合理设置source_max_len和target_max_len参数,避免输入过长被截断或过短浪费资源。

通过以上步骤,开发者可以成功地在TinyLlama项目上实现模型的微调。关键在于理解数据流向:从原始文本→预处理→分词→训练这一完整流程,每个环节都需要正确配置才能保证训练顺利进行。

登录后查看全文
热门项目推荐
相关项目推荐