从0到1:用FlagEmbedding定制专属嵌入模型,让金融问答准确率提升20%
你是否遇到过通用嵌入模型在特定领域表现不佳的问题?当处理金融、医疗等专业领域数据时,通用模型往往无法捕捉行业术语和语境,导致检索准确率大幅下降。本文将带你通过FlagEmbedding框架,使用金融问答数据集微调嵌入模型,解决这一痛点。读完本文,你将掌握数据准备、模型微调、效果评估的全流程,让模型在你的业务场景中实现精准匹配。
为什么需要微调嵌入模型
嵌入模型(Embedding Model)是将文本转化为向量的工具,广泛用于搜索引擎、推荐系统等场景。通用模型如BGE虽然在一般任务上表现优异,但在金融领域可能出现以下问题:
- 专业术语理解不足:如"衍生品"、"资产负债表"等术语的向量表示不准确
- 领域特有语义缺失:无法识别金融问答中的隐含关系
- 检索精度不足:相关文档排名靠后,影响用户体验
通过微调,我们可以让模型学习特定领域的语言特征,提升嵌入质量。官方文档:docs/Introduction/quick_start.rst
微调全流程概览
微调嵌入模型主要包括三个步骤:数据准备、模型训练和效果评估。以下是完整流程图:
graph TD
A[数据准备] --> B[数据集格式转换]
B --> C[划分训练集/测试集]
C --> D[模型微调]
D --> E[超参数优化]
E --> F[模型评估]
F --> G[效果对比]
第一步:数据准备
数据格式要求
FlagEmbedding要求训练数据为JSON格式,每条数据包含以下字段:
query: 查询文本pos: 相关文本列表neg: 无关文本列表id: 样本唯一标识
详细格式说明:Tutorials/7_Fine-tuning/7.1.1_Data_preparation.ipynb
金融问答数据集处理
我们使用金融10K报告问答数据集作为示例,原始数据包含question和context字段,需要转换为FlagEmbedding要求的格式:
from datasets import load_dataset
# 加载原始数据集
ds = load_dataset("virattt/financial-qa-10K", split="train")
# 选择并更名必要字段
ds = ds.select_columns(column_names=["question", "context"])
ds = ds.rename_column("question", "query")
ds = ds.rename_column("context", "pos")
ds = ds.add_column("id", [str(i) for i in range(len(ds))])
构造负样本
由于原始数据没有负样本,我们从语料库中随机采样:
import numpy as np
np.random.seed(520)
neg_num = 10 # 每个query配10个负样本
def add_negatives(example):
# 从整个数据集随机选择负样本
ids = np.random.randint(0, len(ds), size=neg_num)
# 确保不选到自身
while example["id"] in ids:
ids = np.random.randint(0, len(ds), size=neg_num)
example["neg"] = [ds[i]["pos"] for i in ids]
return example
ds = ds.map(add_negatives)
添加查询指令
为模型提供明确的任务指令,提升嵌入质量:
instruction = "Represent this sentence for searching relevant passages: "
ds = ds.add_column("prompt", [instruction]*len(ds))
数据划分与保存
将数据集划分为训练集和测试集,并保存为JSON格式:
# 划分训练集和测试集(9:1)
split = ds.train_test_split(test_size=0.1, shuffle=True, seed=520)
train = split["train"]
test = split["test"]
# 保存为JSON格式
train.to_json("ft_data/training.json")
test.to_json("ft_data/test.json")
第二步:模型微调
环境准备
首先安装FlagEmbedding及微调依赖:
pip install -U FlagEmbedding[finetune]
微调脚本配置
FlagEmbedding提供了便捷的微调脚本,位于examples/finetune/embedder/encoder_only/base.sh,主要参数说明:
| 参数 | 说明 | 推荐值 |
|---|---|---|
model_name_or_path |
预训练模型路径 | BAAI/bge-large-en-v1.5 |
train_data |
训练数据路径 | ft_data/training.json |
output_dir |
模型保存路径 | ./financial_bge_model |
learning_rate |
学习率 | 1e-5 |
num_train_epochs |
训练轮数 | 2 |
per_device_train_batch_size |
每设备批次大小 | 2 |
query_max_len |
查询文本最大长度 | 512 |
passage_max_len |
文档文本最大长度 | 512 |
启动微调
使用DeepSpeed加速训练,配置文件位于Tutorials/7_Fine-tuning/config/ds_stage0.json:
deepspeed --num_gpus=1 ./examples/finetune/embedder/encoder_only/run.py \
--model_name_or_path BAAI/bge-large-en-v1.5 \
--train_data ./ft_data/training.json \
--output_dir ./financial_bge_model \
--learning_rate 1e-5 \
--num_train_epochs 2 \
--per_device_train_batch_size 2 \
--query_max_len 512 \
--passage_max_len 512 \
--train_group_size 8 \
--negatives_cross_device \
--temperature 0.02 \
--normalize_embeddings \
--do_train \
--fp16 \
--gradient_checkpointing \
--deepspeed ./Tutorials/7_Fine-tuning/config/ds_stage0.json
训练过程中会输出损失值变化,正常情况下损失应逐步下降:
{'loss': 0.0124, 'grad_norm': 1.094, 'learning_rate': 0.0, 'epoch': 0.0}
{'loss': 0.0067, 'grad_norm': 0.676, 'learning_rate': 1.909e-6, 'epoch': 0.0}
...
{'loss': 0.0001, 'grad_norm': 0.0092, 'learning_rate': 6.578e-6, 'epoch': 0.03}
第三步:效果评估
评估数据集准备
从测试集中提取查询、文档和相关性判断:
from datasets import load_dataset
# 加载测试数据
queries = load_dataset("json", data_files="ft_data/test_queries.jsonl")["train"]
corpus = load_dataset("json", data_files="ft_data/corpus.jsonl")["train"]
qrels = load_dataset("json", data_files="ft_data/test_qrels.jsonl")["train"]
# 转换为评估所需格式
qrels_dict = {}
for line in qrels:
if line['qid'] not in qrels_dict:
qrels_dict[line['qid']] = {}
qrels_dict[line['qid']][line['docid']] = line['relevance']
评估指标计算
使用官方评估工具Tutorials/7_Fine-tuning/7.1.3_Eval_FT_Model.ipynb,支持NDCG、MAP、MRR等主流指标:
from FlagEmbedding import FlagModel
from FlagEmbedding.abc.evaluation.utils import evaluate_metrics
# 加载微调前后的模型
raw_model = FlagModel("BAAI/bge-large-en-v1.5")
ft_model = FlagModel("./financial_bge_model")
# 评估原始模型
raw_results = search(raw_model, queries_text, corpus_text)
raw_metrics = evaluate_metrics(qrels_dict, raw_results, [10, 100])
# 评估微调后模型
ft_results = search(ft_model, queries_text, corpus_text)
ft_metrics = evaluate_metrics(qrels_dict, ft_results, [10, 100])
微调效果对比
在金融问答数据集上的评估结果(越高越好):
| 指标 | 原始模型 | 微调后模型 | 提升幅度 |
|---|---|---|---|
| NDCG@10 | 0.704 | 0.844 | +20% |
| MAP@10 | 0.666 | 0.816 | +22.5% |
| MRR@10 | 0.666 | 0.816 | +22.5% |
| Recall@10 | 0.823 | 0.931 | +13.1% |
可以看到,微调后的模型在各项指标上均有显著提升,特别是NDCG和MAP指标提升超过20%,证明模型更好地捕捉了金融领域的语义特征。
总结与展望
本文详细介绍了使用FlagEmbedding微调嵌入模型的完整流程,通过金融问答数据集的实践,验证了微调方法的有效性。关键步骤包括:
- 数据准备:格式转换、负样本构造、指令添加
- 模型微调:参数配置、高效训练
- 效果评估:指标对比、结果分析
后续可以尝试以下优化方向:
- 使用更专业的金融领域预训练模型作为基座
- 尝试难负样本挖掘技术,提升模型区分能力
- 结合领域知识图谱,增强语义理解
希望本文能帮助你在特定领域构建高性能的嵌入模型,FlagEmbedding更多高级功能可参考官方教程Tutorials/README.md。如有问题,欢迎在项目社区交流讨论。
点赞+收藏本文,关注FlagEmbedding项目,获取更多嵌入模型优化技巧!
Kimi-K2.5Kimi K2.5 是一款开源的原生多模态智能体模型,它在 Kimi-K2-Base 的基础上,通过对约 15 万亿混合视觉和文本 tokens 进行持续预训练构建而成。该模型将视觉与语言理解、高级智能体能力、即时模式与思考模式,以及对话式与智能体范式无缝融合。Python00
GLM-4.7-FlashGLM-4.7-Flash 是一款 30B-A3B MoE 模型。作为 30B 级别中的佼佼者,GLM-4.7-Flash 为追求性能与效率平衡的轻量化部署提供了全新选择。Jinja00
VLOOKVLOOK™ 是优雅好用的 Typora/Markdown 主题包和增强插件。 VLOOK™ is an elegant and practical THEME PACKAGE × ENHANCEMENT PLUGIN for Typora/Markdown.Less00
PaddleOCR-VL-1.5PaddleOCR-VL-1.5 是 PaddleOCR-VL 的新一代进阶模型,在 OmniDocBench v1.5 上实现了 94.5% 的全新 state-of-the-art 准确率。 为了严格评估模型在真实物理畸变下的鲁棒性——包括扫描伪影、倾斜、扭曲、屏幕拍摄和光照变化——我们提出了 Real5-OmniDocBench 基准测试集。实验结果表明,该增强模型在新构建的基准测试集上达到了 SOTA 性能。此外,我们通过整合印章识别和文本检测识别(text spotting)任务扩展了模型的能力,同时保持 0.9B 的超紧凑 VLM 规模,具备高效率特性。Python00
KuiklyUI基于KMP技术的高性能、全平台开发框架,具备统一代码库、极致易用性和动态灵活性。 Provide a high-performance, full-platform development framework with unified codebase, ultimate ease of use, and dynamic flexibility. 注意:本仓库为Github仓库镜像,PR或Issue请移步至Github发起,感谢支持!Kotlin07
compass-metrics-modelMetrics model project for the OSS CompassPython00