NVIDIA CUTLASS 中的 EVT 功能扩展:实现 Sqrt 操作支持
2025-05-31 00:56:57作者:邬祺芯Juliet
背景介绍
NVIDIA CUTLASS 是一个高性能 CUDA C++ 模板库,用于实现矩阵乘法和其他线性代数运算。其中的 Epilogue Visitor Tree (EVT) 提供了一种灵活的方式来定义和组合核函数的尾端操作。在 Python 接口中,用户可以通过 cutlass.epilogue.trace 方法来定义自定义的尾端操作。
问题发现
在尝试使用 Python 接口定义 Adam 优化器的尾端操作时,开发者发现当前 EVT 实现缺少对 sqrt (平方根) 运算的支持。具体场景是在 Adam 优化器的实现中,需要计算梯度归一化项时使用了 torch.sqrt 操作,但 EVT 系统无法识别这一操作。
技术分析
通过分析 CUTLASS 的代码结构,我们发现 EVT 系统的操作映射主要在以下几个部分实现:
- C++ 核心功能:在
include/cutlass/functional.h中定义了各种基础数学运算 - Python 绑定:在 Python 接口中通过
ast_op_to_bindings方法将 Python AST 节点映射到 CUTLASS 的功能操作
当前系统已经支持了基本的算术运算(加、减、乘、除)和一些常用函数(如最大值、最小值),但缺少对平方根运算的支持。
解决方案
要实现 sqrt 操作的支持,需要在以下几个层面进行修改:
-
C++ 核心层:
- 在
functional.h中添加sqrt的函数实现 - 确保实现支持 CUDA 设备代码和模板参数
- 在
-
Python 接口层:
- 扩展
ast_op_to_bindings的映射表,添加对sqrt函数的支持 - 处理 Python AST 中
Call节点的特殊处理逻辑
- 扩展
-
类型系统:
- 确保新操作支持 CUTLASS 支持的各种数据类型(float16, float32, bfloat16 等)
- 实现类型推导规则
实现建议
对于想要贡献此功能的开发者,建议按照以下步骤进行:
- 首先在
functional.h中添加sqrt的模板函数实现 - 在 Python 接口中添加对应的操作映射
- 添加单元测试验证功能正确性
- 考虑性能优化(如使用 CUDA 内置函数)
- 文档更新,说明新支持的操作
扩展思考
这个问题反映了 EVT 系统的一个通用扩展模式。类似的数学函数(如指数、对数等)也可以通过相同的方式添加。CUTLASS 团队可以考虑建立一个更系统化的机制来支持常见数学函数的添加,而不是逐个硬编码。
总结
通过为 CUTLASS EVT 添加 sqrt 操作支持,可以显著增强其在机器学习优化算法(如 Adam)中的应用能力。这一改进不仅解决了眼前的问题,也为未来扩展更多数学函数提供了参考模式。对于深度学习框架开发者来说,这样的扩展意味着能够更灵活地在高性能核函数中实现复杂的数学运算组合。
登录后查看全文
热门项目推荐
相关项目推荐
AutoGLM-Phone-9BAutoGLM-Phone-9B是基于AutoGLM构建的移动智能助手框架,依托多模态感知理解手机屏幕并执行自动化操作。Jinja00
Kimi-K2-ThinkingKimi K2 Thinking 是最新、性能最强的开源思维模型。从 Kimi K2 开始,我们将其打造为能够逐步推理并动态调用工具的思维智能体。通过显著提升多步推理深度,并在 200–300 次连续调用中保持稳定的工具使用能力,它在 Humanity's Last Exam (HLE)、BrowseComp 等基准测试中树立了新的技术标杆。同时,K2 Thinking 是原生 INT4 量化模型,具备 256k 上下文窗口,实现了推理延迟和 GPU 内存占用的无损降低。Python00
GLM-4.6V-FP8GLM-4.6V-FP8是GLM-V系列开源模型,支持128K上下文窗口,融合原生多模态函数调用能力,实现从视觉感知到执行的闭环。具备文档理解、图文生成、前端重构等功能,适用于云集群与本地部署,在同类参数规模中视觉理解性能领先。Jinja00
HunyuanOCRHunyuanOCR 是基于混元原生多模态架构打造的领先端到端 OCR 专家级视觉语言模型。它采用仅 10 亿参数的轻量化设计,在业界多项基准测试中取得了当前最佳性能。该模型不仅精通复杂多语言文档解析,还在文本检测与识别、开放域信息抽取、视频字幕提取及图片翻译等实际应用场景中表现卓越。00
GLM-ASR-Nano-2512GLM-ASR-Nano-2512 是一款稳健的开源语音识别模型,参数规模为 15 亿。该模型专为应对真实场景的复杂性而设计,在保持紧凑体量的同时,多项基准测试表现优于 OpenAI Whisper V3。Python00
GLM-TTSGLM-TTS 是一款基于大语言模型的高质量文本转语音(TTS)合成系统,支持零样本语音克隆和流式推理。该系统采用两阶段架构,结合了用于语音 token 生成的大语言模型(LLM)和用于波形合成的流匹配(Flow Matching)模型。 通过引入多奖励强化学习框架,GLM-TTS 显著提升了合成语音的表现力,相比传统 TTS 系统实现了更自然的情感控制。Python00
Spark-Formalizer-X1-7BSpark-Formalizer 是由科大讯飞团队开发的专用大型语言模型,专注于数学自动形式化任务。该模型擅长将自然语言数学问题转化为精确的 Lean4 形式化语句,在形式化语句生成方面达到了业界领先水平。Python00
项目优选
收起
deepin linux kernel
C
24
9
Ascend Extension for PyTorch
Python
223
245
Nop Platform 2.0是基于可逆计算理论实现的采用面向语言编程范式的新一代低代码开发平台,包含基于全新原理从零开始研发的GraphQL引擎、ORM引擎、工作流引擎、报表引擎、规则引擎、批处理引引擎等完整设计。nop-entropy是它的后端部分,采用java语言实现,可选择集成Spring框架或者Quarkus框架。中小企业可以免费商用
Java
9
1
暂无简介
Dart
672
157
本项目是CANN提供的数学类基础计算算子库,实现网络在NPU上加速计算。
C++
662
312
React Native鸿蒙化仓库
JavaScript
262
322
🔥LeetCode solutions in any programming language | 多种编程语言实现 LeetCode、《剑指 Offer(第 2 版)》、《程序员面试金典(第 6 版)》题解
Java
64
19
仓颉编译器源码及 cjdb 调试工具。
C++
134
867
仓颉编程语言测试用例。
Cangjie
37
860
openGauss kernel ~ openGauss is an open source relational database management system
C++
160
218