首页
/ LLM-Sequential-Recommendation项目核心技术实现解析

LLM-Sequential-Recommendation项目核心技术实现解析

2025-06-12 18:19:25作者:凤尚柏Louis

项目概述

LLM-Sequential-Recommendation是一个基于大型语言模型(LLM)的序列推荐系统项目,该项目创新性地将传统序列推荐算法与大型语言模型相结合,实现了多种先进的推荐模型架构。本文将深入解析该项目的核心技术实现细节,包括神经网络模型设计、SKNN算法实现、混合模型策略以及超参数搜索方法。

神经网络模型实现

基础架构设计

项目中的BERT4Rec、SASRec和GRU4Rec模型均继承自NeuralModel基类,该基类封装了Keras模型的训练和预测功能。训练过程中采用了创新的"内部验证"机制:

  1. 从训练会话中抽取10%(最多500个会话)作为验证集
  2. 使用NDCG@20作为早停指标
  3. 设置2个epoch的耐心值(patience),在验证指标连续2个epoch未提升时停止训练
  4. 恢复模型到验证指标最佳的epoch权重

这种设计有效防止了模型过拟合,同时考虑到模型通常在5-15个epoch内收敛,较小的耐心值既保证了效率又不会影响最终性能。

所有神经网络模型都采用了N=20的会话截断长度,确保不同长度的会话可以统一处理。训练过程使用Keras的AdamW优化器。

BERT4Rec实现细节

项目中的BERT4Rec实现主要遵循原始论文设计,但在初始化策略上有所改进:

  • 使用GlorotUniform初始化权重(keras默认)
  • 偏置项初始化为零
  • 相比原始论文的截断正态分布初始化,这种策略获得了更好的性能表现

SASRec实现优化

SASRec实现有以下关键改进:

  1. 采用BERT4Rec的投影头(projection head)替代原始设计
  2. 使用分类交叉熵损失函数,省略了原始论文中的负采样
  3. 支持多头注意力机制(虽然原始论文发现单头效果已足够)
  4. 在嵌入层后添加了dropout层

这些改进使得模型性能得到提升,同时保持了原始架构的核心思想。

GRU4Rec实现创新

项目对GRU4Rec进行了几项重要改进:

  1. 采用整会话批处理方式,简化了实现复杂度
  2. 同样使用BERT4Rec的投影头设计,减少了模型参数量
  3. 使用分类交叉熵损失函数,未出现原始论文中报告的不稳定问题
  4. 通过会话截断有效去除了长尾噪声

SKNN算法实现

项目中的SKNN实现支持多种变体和配置:

  1. 相似度计算:支持点积和余弦相似度
  2. 采样策略:随机采样或基于时间戳的最近优先采样
  3. IDF加权:可选是否基于项目IDF分数进行加权
  4. 变体支持
    • V-SKNN(使用衰减参数)
    • SF-SKNN(设置sequential_filter为true)
    • S-SKNN(设置sequential_weighting为true)
    • 标准SKNN
  5. 嵌入版本:支持使用项目嵌入表示,需配置:
    • 提示会话嵌入组合策略
    • 训练会话嵌入组合策略
    • 降维配置
    • 衰减参数

混合模型策略

嵌入模型集成

LLMSeqSim模型基于四个核心维度构建:

  1. 源LLM嵌入模型(OpenAI或Google)
  2. 降维方法
  3. 降维后的维度数
  4. 会话嵌入计算方法

项目通过分析不同变体推荐结果的交集发现:

  1. 不同配置的推荐结果差异显著
  2. 通过排名或置信度组合可以产生独特的推荐列表
  3. 这种集成能够捕捉互补的语义信号,提升推荐性能

数据处理策略

项目采用了严格的数据预处理方法:

  1. Beauty和Steam数据集

    • 使用5-core方法处理,确保每个会话和项目至少有5次交互
    • 采用时间分割策略划分训练测试集
    • 过滤测试提示中未出现在训练数据中的项目
  2. Delivery Hero数据集

    • 保持原始数据以模拟真实场景
    • 仅从测试集中移除单交互会话

时间分割策略最能模拟真实场景,优于文献中常见的随机分割或演化分割方法。

超参数搜索方法

项目采用Tree-Parzen-Estimator(TPE)采样器进行超参数搜索:

  1. 搜索策略

    • 初始40个随机配置避免偏差
    • 后续使用TPE采样器建议
    • 连续100次试验未改进则停止
    • 最长运行时间72小时
  2. 验证折叠

    • 将训练数据分为4个时间箱
    • 创建3个验证折叠,保持时间顺序
    • 使用早剪枝策略(每折后剪除后20%配置)
  3. LLMSeqPrompt特殊处理

    • 仅搜索预测参数(temperature和top_p)
    • 固定训练参数以减少成本
    • 针对GPT和PaLM采用不同温度范围

超参数搜索范围

项目为不同模型设定了详细的超参数搜索范围,主要包含:

  1. 神经网络共同参数

    • 学习率(0.0001-0.01)
    • 权重衰减(0-0.1)
    • 批量大小(32-512)
    • 嵌入维度(16-512)
    • dropout率(0-0.9)
  2. 模型特定参数

    • BERT4Rec/SASRec:层数(1-4)、头数(1-4)、mask概率(0.05-0.9)
    • GRU4Rec:隐藏层维度(16-512)
    • 降维方法:PCA、自编码器、LDA、随机投影
  3. SKNN相关参数

    • 近邻数(50-500)
    • 采样大小(500-2000)
    • 衰减策略(无、线性、谐波)
    • 相似度度量(点积、余弦)
  4. LLMSeqPrompt参数

    • 温度(GPT:0-0.5, PaLM:0.125-0.25)
    • 固定top_p=1

技术实现亮点

  1. 高效的早停机制:基于内部验证集的2-epoch耐心值,平衡了训练效率和模型性能。

  2. 创新的模型改进:在各原始模型基础上进行针对性优化,如投影头共享、初始化策略调整等。

  3. 灵活的SKNN实现:单一类支持多种SKNN变体,便于比较和组合。

  4. 智能的超参数搜索:结合时间分割的验证策略和早剪枝,大幅提升搜索效率。

  5. 严谨的数据处理:5-core处理和时间分割策略确保了数据质量和真实场景模拟。

通过以上技术实现,LLM-Sequential-Recommendation项目构建了一个强大而灵活的序列推荐系统框架,为相关领域的研究和应用提供了有价值的参考。

登录后查看全文
热门项目推荐

项目优选

收起
ohos_react_nativeohos_react_native
React Native鸿蒙化仓库
C++
179
263
RuoYi-Vue3RuoYi-Vue3
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
869
514
openGauss-serveropenGauss-server
openGauss kernel ~ openGauss is an open source relational database management system
C++
130
183
openHiTLSopenHiTLS
旨在打造算法先进、性能卓越、高效敏捷、安全可靠的密码套件,通过轻量级、可剪裁的软件技术架构满足各行业不同场景的多样化要求,让密码技术应用更简单,同时探索后量子等先进算法创新实践,构建密码前沿技术底座!
C
328
377
Cangjie-ExamplesCangjie-Examples
本仓将收集和展示高质量的仓颉示例代码,欢迎大家投稿,让全世界看到您的妙趣设计,也让更多人通过您的编码理解和喜爱仓颉语言。
Cangjie
333
1.09 K
harmony-utilsharmony-utils
harmony-utils 一款功能丰富且极易上手的HarmonyOS工具库,借助众多实用工具类,致力于助力开发者迅速构建鸿蒙应用。其封装的工具涵盖了APP、设备、屏幕、授权、通知、线程间通信、弹框、吐司、生物认证、用户首选项、拍照、相册、扫码、文件、日志,异常捕获、字符、字符串、数字、集合、日期、随机、base64、加密、解密、JSON等一系列的功能和操作,能够满足各种不同的开发需求。
ArkTS
28
0
CangjieCommunityCangjieCommunity
为仓颉编程语言开发者打造活跃、开放、高质量的社区环境
Markdown
1.08 K
0
kernelkernel
deepin linux kernel
C
22
5
WxJavaWxJava
微信开发 Java SDK,支持微信支付、开放平台、公众号、视频号、企业微信、小程序等的后端开发,记得关注公众号及时接受版本更新信息,以及加入微信群进行深入讨论
Java
829
22
cherry-studiocherry-studio
🍒 Cherry Studio 是一款支持多个 LLM 提供商的桌面客户端
TypeScript
601
58