NVIDIA CUTLASS 中的 EVT 功能扩展:实现 Sqrt 操作支持
2025-05-31 17:33:09作者:邬祺芯Juliet
背景介绍
NVIDIA CUTLASS 是一个高性能 CUDA C++ 模板库,用于实现矩阵乘法和其他线性代数运算。其中的 Epilogue Visitor Tree (EVT) 提供了一种灵活的方式来定义和组合核函数的尾端操作。在 Python 接口中,用户可以通过 cutlass.epilogue.trace 方法来定义自定义的尾端操作。
问题发现
在尝试使用 Python 接口定义 Adam 优化器的尾端操作时,开发者发现当前 EVT 实现缺少对 sqrt (平方根) 运算的支持。具体场景是在 Adam 优化器的实现中,需要计算梯度归一化项时使用了 torch.sqrt 操作,但 EVT 系统无法识别这一操作。
技术分析
通过分析 CUTLASS 的代码结构,我们发现 EVT 系统的操作映射主要在以下几个部分实现:
- C++ 核心功能:在
include/cutlass/functional.h中定义了各种基础数学运算 - Python 绑定:在 Python 接口中通过
ast_op_to_bindings方法将 Python AST 节点映射到 CUTLASS 的功能操作
当前系统已经支持了基本的算术运算(加、减、乘、除)和一些常用函数(如最大值、最小值),但缺少对平方根运算的支持。
解决方案
要实现 sqrt 操作的支持,需要在以下几个层面进行修改:
-
C++ 核心层:
- 在
functional.h中添加sqrt的函数实现 - 确保实现支持 CUDA 设备代码和模板参数
- 在
-
Python 接口层:
- 扩展
ast_op_to_bindings的映射表,添加对sqrt函数的支持 - 处理 Python AST 中
Call节点的特殊处理逻辑
- 扩展
-
类型系统:
- 确保新操作支持 CUTLASS 支持的各种数据类型(float16, float32, bfloat16 等)
- 实现类型推导规则
实现建议
对于想要贡献此功能的开发者,建议按照以下步骤进行:
- 首先在
functional.h中添加sqrt的模板函数实现 - 在 Python 接口中添加对应的操作映射
- 添加单元测试验证功能正确性
- 考虑性能优化(如使用 CUDA 内置函数)
- 文档更新,说明新支持的操作
扩展思考
这个问题反映了 EVT 系统的一个通用扩展模式。类似的数学函数(如指数、对数等)也可以通过相同的方式添加。CUTLASS 团队可以考虑建立一个更系统化的机制来支持常见数学函数的添加,而不是逐个硬编码。
总结
通过为 CUTLASS EVT 添加 sqrt 操作支持,可以显著增强其在机器学习优化算法(如 Adam)中的应用能力。这一改进不仅解决了眼前的问题,也为未来扩展更多数学函数提供了参考模式。对于深度学习框架开发者来说,这样的扩展意味着能够更灵活地在高性能核函数中实现复杂的数学运算组合。
登录后查看全文
热门项目推荐
相关项目推荐
GLM-5智谱 AI 正式发布 GLM-5,旨在应对复杂系统工程和长时域智能体任务。Jinja00
GLM-5-w4a8GLM-5-w4a8基于混合专家架构,专为复杂系统工程与长周期智能体任务设计。支持单/多节点部署,适配Atlas 800T A3,采用w4a8量化技术,结合vLLM推理优化,高效平衡性能与精度,助力智能应用开发Jinja00
jiuwenclawJiuwenClaw 是一款基于openJiuwen开发的智能AI Agent,它能够将大语言模型的强大能力,通过你日常使用的各类通讯应用,直接延伸至你的指尖。Python0193- QQwen3.5-397B-A17BQwen3.5 实现了重大飞跃,整合了多模态学习、架构效率、强化学习规模以及全球可访问性等方面的突破性进展,旨在为开发者和企业赋予前所未有的能力与效率。Jinja00
AtomGit城市坐标计划AtomGit 城市坐标计划开启!让开源有坐标,让城市有星火。致力于与城市合伙人共同构建并长期运营一个健康、活跃的本地开发者生态。01
awesome-zig一个关于 Zig 优秀库及资源的协作列表。Makefile00
项目优选
收起
deepin linux kernel
C
27
12
OpenHarmony documentation | OpenHarmony开发者文档
Dockerfile
601
4.04 K
🔥LeetCode solutions in any programming language | 多种编程语言实现 LeetCode、《剑指 Offer(第 2 版)》、《程序员面试金典(第 6 版)》题解
Java
69
21
Ascend Extension for PyTorch
Python
441
531
AscendNPU-IR是基于MLIR(Multi-Level Intermediate Representation)构建的,面向昇腾亲和算子编译时使用的中间表示,提供昇腾完备表达能力,通过编译优化提升昇腾AI处理器计算效率,支持通过生态框架使能昇腾AI处理器与深度调优
C++
112
170
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
1.46 K
825
本项目是CANN提供的数学类基础计算算子库,实现网络在NPU上加速计算。
C++
922
770
暂无简介
Dart
847
204
React Native鸿蒙化仓库
JavaScript
321
375
openGauss kernel ~ openGauss is an open source relational database management system
C++
174
249