NVIDIA CUTLASS 中的 EVT 功能扩展:实现 Sqrt 操作支持
2025-05-31 17:33:09作者:邬祺芯Juliet
背景介绍
NVIDIA CUTLASS 是一个高性能 CUDA C++ 模板库,用于实现矩阵乘法和其他线性代数运算。其中的 Epilogue Visitor Tree (EVT) 提供了一种灵活的方式来定义和组合核函数的尾端操作。在 Python 接口中,用户可以通过 cutlass.epilogue.trace 方法来定义自定义的尾端操作。
问题发现
在尝试使用 Python 接口定义 Adam 优化器的尾端操作时,开发者发现当前 EVT 实现缺少对 sqrt (平方根) 运算的支持。具体场景是在 Adam 优化器的实现中,需要计算梯度归一化项时使用了 torch.sqrt 操作,但 EVT 系统无法识别这一操作。
技术分析
通过分析 CUTLASS 的代码结构,我们发现 EVT 系统的操作映射主要在以下几个部分实现:
- C++ 核心功能:在
include/cutlass/functional.h中定义了各种基础数学运算 - Python 绑定:在 Python 接口中通过
ast_op_to_bindings方法将 Python AST 节点映射到 CUTLASS 的功能操作
当前系统已经支持了基本的算术运算(加、减、乘、除)和一些常用函数(如最大值、最小值),但缺少对平方根运算的支持。
解决方案
要实现 sqrt 操作的支持,需要在以下几个层面进行修改:
-
C++ 核心层:
- 在
functional.h中添加sqrt的函数实现 - 确保实现支持 CUDA 设备代码和模板参数
- 在
-
Python 接口层:
- 扩展
ast_op_to_bindings的映射表,添加对sqrt函数的支持 - 处理 Python AST 中
Call节点的特殊处理逻辑
- 扩展
-
类型系统:
- 确保新操作支持 CUTLASS 支持的各种数据类型(float16, float32, bfloat16 等)
- 实现类型推导规则
实现建议
对于想要贡献此功能的开发者,建议按照以下步骤进行:
- 首先在
functional.h中添加sqrt的模板函数实现 - 在 Python 接口中添加对应的操作映射
- 添加单元测试验证功能正确性
- 考虑性能优化(如使用 CUDA 内置函数)
- 文档更新,说明新支持的操作
扩展思考
这个问题反映了 EVT 系统的一个通用扩展模式。类似的数学函数(如指数、对数等)也可以通过相同的方式添加。CUTLASS 团队可以考虑建立一个更系统化的机制来支持常见数学函数的添加,而不是逐个硬编码。
总结
通过为 CUTLASS EVT 添加 sqrt 操作支持,可以显著增强其在机器学习优化算法(如 Adam)中的应用能力。这一改进不仅解决了眼前的问题,也为未来扩展更多数学函数提供了参考模式。对于深度学习框架开发者来说,这样的扩展意味着能够更灵活地在高性能核函数中实现复杂的数学运算组合。
登录后查看全文
热门项目推荐
相关项目推荐
Kimi-K2.5Kimi K2.5 是一款开源的原生多模态智能体模型,它在 Kimi-K2-Base 的基础上,通过对约 15 万亿混合视觉和文本 tokens 进行持续预训练构建而成。该模型将视觉与语言理解、高级智能体能力、即时模式与思考模式,以及对话式与智能体范式无缝融合。Python00
GLM-4.7-FlashGLM-4.7-Flash 是一款 30B-A3B MoE 模型。作为 30B 级别中的佼佼者,GLM-4.7-Flash 为追求性能与效率平衡的轻量化部署提供了全新选择。Jinja00
VLOOKVLOOK™ 是优雅好用的 Typora/Markdown 主题包和增强插件。 VLOOK™ is an elegant and practical THEME PACKAGE × ENHANCEMENT PLUGIN for Typora/Markdown.Less00
PaddleOCR-VL-1.5PaddleOCR-VL-1.5 是 PaddleOCR-VL 的新一代进阶模型,在 OmniDocBench v1.5 上实现了 94.5% 的全新 state-of-the-art 准确率。 为了严格评估模型在真实物理畸变下的鲁棒性——包括扫描伪影、倾斜、扭曲、屏幕拍摄和光照变化——我们提出了 Real5-OmniDocBench 基准测试集。实验结果表明,该增强模型在新构建的基准测试集上达到了 SOTA 性能。此外,我们通过整合印章识别和文本检测识别(text spotting)任务扩展了模型的能力,同时保持 0.9B 的超紧凑 VLM 规模,具备高效率特性。Python00
KuiklyUI基于KMP技术的高性能、全平台开发框架,具备统一代码库、极致易用性和动态灵活性。 Provide a high-performance, full-platform development framework with unified codebase, ultimate ease of use, and dynamic flexibility. 注意:本仓库为Github仓库镜像,PR或Issue请移步至Github发起,感谢支持!Kotlin07
compass-metrics-modelMetrics model project for the OSS CompassPython00
项目优选
收起
deepin linux kernel
C
27
11
OpenHarmony documentation | OpenHarmony开发者文档
Dockerfile
523
3.72 K
Ascend Extension for PyTorch
Python
328
387
本项目是CANN提供的数学类基础计算算子库,实现网络在NPU上加速计算。
C++
876
576
openEuler内核是openEuler操作系统的核心,既是系统性能与稳定性的基石,也是连接处理器、设备与服务的桥梁。
C
335
161
暂无简介
Dart
762
187
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
1.33 K
745
Nop Platform 2.0是基于可逆计算理论实现的采用面向语言编程范式的新一代低代码开发平台,包含基于全新原理从零开始研发的GraphQL引擎、ORM引擎、工作流引擎、报表引擎、规则引擎、批处理引引擎等完整设计。nop-entropy是它的后端部分,采用java语言实现,可选择集成Spring框架或者Quarkus框架。中小企业可以免费商用
Java
12
1
React Native鸿蒙化仓库
JavaScript
302
349
华为昇腾面向大规模分布式训练的多模态大模型套件,支撑多模态生成、多模态理解。
Python
112
136