首页
/ TinyLlama项目中的微调实践指南

TinyLlama项目中的微调实践指南

2025-05-27 03:31:25作者:盛欣凯Ernestine

数据集预处理关键点

在TinyLlama项目中进行模型微调时,数据集的组织格式至关重要。根据项目代码分析,数据集需要包含两个关键字段:inputoutput。这两个字段分别代表了模型的输入文本和期望的输出文本。

正确的数据集格式示例如下:

{
    'input': '用户输入的文本内容',
    'output': '模型期望生成的回答内容'
}

数据预处理实现

预处理函数应当将原始数据转换为上述格式。假设原始数据是以制表符分隔的文本文件,预处理可以这样实现:

from datasets import load_dataset

def preprocess_function(examples):
    parts = examples["text"].split("\t")
    return {
        "input": parts[0],  # 第一部分作为输入
        "output": parts[1]  # 第二部分作为期望输出
    }

# 加载并预处理数据集
dataset = load_dataset('text', data_files='your_data.txt')['train']
dataset = dataset.map(preprocess_function).remove_columns('text')

训练器配置详解

微调TinyLlama需要使用特定的训练器配置,核心在于正确设置数据整理器(DataCollator)。以下是关键配置步骤:

  1. 训练参数设置
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./finetune_results",  # 输出目录
    num_train_epochs=1,              # 训练轮数
    per_device_train_batch_size=16,  # 训练批次大小
    per_device_eval_batch_size=64,   # 评估批次大小
    weight_decay=0.005,              # 权重衰减系数
    logging_dir="./logs",            # 日志目录
    remove_unused_columns=False      # 保留未使用列
)
  1. 数据整理器实现: 数据整理器负责将文本数据转换为模型可处理的张量格式,核心功能包括:
  • 添加特殊标记(BOS/EOS)
  • 控制输入输出长度
  • 处理填充和注意力掩码
from torch.nn.utils.rnn import pad_sequence
import torch
from dataclasses import dataclass
from typing import Dict, Sequence
import copy

@dataclass
class DataCollatorForCausalLM:
    tokenizer: transformers.PreTrainedTokenizer
    source_max_len: int = 128    # 输入最大长度
    target_max_len: int = 128    # 输出最大长度
    train_on_source: bool = True # 是否在源文本上训练
    predict_with_generate: bool = False  # 是否生成式预测

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # 添加特殊标记并分词
        sources = [f"{self.tokenizer.bos_token}{x['input']}" for x in instances]
        targets = [f"{x['output']}{self.tokenizer.eos_token}" for x in instances]
        
        # 分词处理
        tokenized_sources = self.tokenizer(
            sources, max_length=self.source_max_len, 
            truncation=True, add_special_tokens=False
        )
        tokenized_targets = self.tokenizer(
            targets, max_length=self.target_max_len,
            truncation=True, add_special_tokens=False
        )
        
        # 构建模型输入和标签
        input_ids, labels = [], []
        for src, tgt in zip(tokenized_sources['input_ids'], tokenized_targets['input_ids']):
            input_ids.append(torch.tensor(src + tgt))
            if not self.train_on_source:
                labels.append(torch.tensor(
                    [IGNORE_INDEX]*len(src) + tgt.copy()
                ))
            else:
                labels.append(torch.tensor(src + tgt.copy()))
        
        # 填充处理
        input_ids = pad_sequence(input_ids, batch_first=True, 
                                padding_value=self.tokenizer.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, 
                             padding_value=IGNORE_INDEX)
        
        return {
            'input_ids': input_ids,
            'attention_mask': input_ids.ne(self.tokenizer.pad_token_id),
            'labels': labels
        }

完整训练流程

结合上述组件,完整的微调流程如下:

from transformers import Seq2SeqTrainer

# 初始化数据整理器
data_collator = DataCollatorForCausalLM(
    tokenizer=tokenizer,
    source_max_len=128,
    target_max_len=128,
    train_on_source=True
)

# 创建训练器
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator
)

# 开始训练
trainer.train()

常见问题解决方案

  1. 数据类型错误:确保预处理后的数据是文本格式,而非其他类型对象。

  2. 索引越界错误:检查数据分割逻辑是否正确,确保每行数据都包含input和output两部分。

  3. 特殊标记处理:根据使用的tokenizer确认BOS(开始)和EOS(结束)标记是否正确添加。

  4. 长度控制:合理设置source_max_len和target_max_len参数,避免输入过长被截断或过短浪费资源。

通过以上步骤,开发者可以成功地在TinyLlama项目上实现模型的微调。关键在于理解数据流向:从原始文本→预处理→分词→训练这一完整流程,每个环节都需要正确配置才能保证训练顺利进行。

登录后查看全文
热门项目推荐
相关项目推荐

项目优选

收起
kernelkernel
deepin linux kernel
C
27
11
docsdocs
OpenHarmony documentation | OpenHarmony开发者文档
Dockerfile
472
3.49 K
nop-entropynop-entropy
Nop Platform 2.0是基于可逆计算理论实现的采用面向语言编程范式的新一代低代码开发平台,包含基于全新原理从零开始研发的GraphQL引擎、ORM引擎、工作流引擎、报表引擎、规则引擎、批处理引引擎等完整设计。nop-entropy是它的后端部分,采用java语言实现,可选择集成Spring框架或者Quarkus框架。中小企业可以免费商用
Java
10
1
leetcodeleetcode
🔥LeetCode solutions in any programming language | 多种编程语言实现 LeetCode、《剑指 Offer(第 2 版)》、《程序员面试金典(第 6 版)》题解
Java
65
19
flutter_flutterflutter_flutter
暂无简介
Dart
719
173
giteagitea
喝着茶写代码!最易用的自托管一站式代码托管平台,包含Git托管,代码审查,团队协作,软件包和CI/CD。
Go
23
0
kernelkernel
openEuler内核是openEuler操作系统的核心,既是系统性能与稳定性的基石,也是连接处理器、设备与服务的桥梁。
C
213
86
RuoYi-Vue3RuoYi-Vue3
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
1.27 K
696
rainbondrainbond
无需学习 Kubernetes 的容器平台,在 Kubernetes 上构建、部署、组装和管理应用,无需 K8s 专业知识,全流程图形化管理
Go
15
1
apintoapinto
基于golang开发的网关。具有各种插件,可以自行扩展,即插即用。此外,它可以快速帮助企业管理API服务,提高API服务的稳定性和安全性。
Go
22
1