[WIP] feat: add InfiniOps as optional kernel provider#161
Open
chen2021673 wants to merge 1 commit into
Open
Conversation
614baf6 to
91b309a
Compare
91b309a to
c7e3c27
Compare
Wire InfiniOps in as a pluggable kernel provider keyed at the GEMM level: Dispatcher consults a per-key whitelist hook and routes registered ops to InfiniOps, falling back to the default CUDA kernel otherwise. linear, matmul and outer now invoke Gemm via Dispatcher rather than calling the cuBLAS wrapper directly, so InfiniOps Gemm transparently covers all three.
c7e3c27 to
865b51c
Compare
kilinchange
reviewed
May 29, 2026
Collaborator
There was a problem hiding this comment.
这里是出于什么原因要单独写一套 registry,而不能直接复用 InfiniTrain 原有的注册表呢?
Collaborator
There was a problem hiding this comment.
这个头文件内容没什么问题,但不适合放到 include 里作为公共头文件暴露,先放 infini_train/src/kernels/common 里吧
Collaborator
There was a problem hiding this comment.
这里不应该给 infinops 开额外分支,之前接沐曦 kernel 这块是不需要动的。
| @@ -0,0 +1,25 @@ | |||
| #include "infini_train/include/core/kernel_provider/infiniops/adapter.h" | |||
|
|
||
| } // namespace infini_train::kernel_provider::infiniops | ||
|
|
||
| REGISTER_INFINIOPS_KERNEL(AddForward, infini_train::kernel_provider::infiniops::AddForward) |
Collaborator
There was a problem hiding this comment.
如果是为了修改注册 key 而专门给 infiniops 写一套注册机制的话感觉不是很有必要,直接按平台注册就行。
Collaborator
There was a problem hiding this comment.
这部分是必要的通用 gemm 接口抽象改动,不涉及 infiniops 相关,可以考虑单独提 pr 先合。
| // FIXME: Requires stride tracking in the Tensor class before this can be implemented | ||
| // correctly. Currently always returns true as a placeholder. The contiguous guard in | ||
| // elementwise.cu ensures non-contiguous tensors fall back to the broadcast path. | ||
| // the elementwise provider ensures non-contiguous tensors fall back to the broadcast path. |
| std::shared_ptr<Tensor> Contiguous(); | ||
| // FIXME: Currently returns true unconditionally. Requires stride tracking in the Tensor | ||
| // class before this can be implemented correctly. The guard in elementwise.cu ensures | ||
| // class before this can be implemented correctly. The elementwise broadcast guard ensures |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
引入 InfiniOps 作为可选的 kernel provider,通过
USE_INFINIOPS=ON启用。插拔粒度落在 Dispatcher::Call():Dispatcher 在
GetKernel里增加白名单 hook,命中的 key 路由到 InfiniOps registry,未命中则回退到默认 CUDA kernel。linear、matmul、outer 三个上层算子从直接调用
GemmCuda改为Dispatcher::Instance().Call<void>({device.type(), "Gemm"}, ...),因此 InfiniOps 一次提供 Gemm 即可透明覆盖这三个上层算子,无需逐个包装。Changes
USE_INFINIOPSCMake 选项;third_party/InfiniOps作为子模块接入;启用时按需 add_subdirectory 并链接InfiniOps::infiniops。GetKernel增加 InfiniOps lookup hook;未命中 whitelist 时维持原行为。InfiniOpsRegistry(独立于主 Dispatcher 的 map)+REGISTER_INFINIOPS_KERNEL宏 + 全局 whitelist(当前包含Gemm、AddForward)。adapter.{h,cc}提供ToOpsDataType/ToOpsDevice/ToOpsTensor类型与张量桥接,dtype/device 对照采用inline const std::unordered_map。gemm.cc、elementwise.cc(AddForward)。common/gemm.{cu,cuh}把GemmCuda重命名为Gemm并通过REGISTER_KERNEL注册;linear/matmul/outer 的所有 GEMM 调用走 Dispatcher。测试
当前单卡测试性能精度对齐

问题:InfiniTrain 原生 CUDA GEMM 在 gemm.cu (line 63) 固定用CUBLAS_COMPUTE_32F,但 InfiniOps NVIDIA GEMM 原来对 fp32 走的是CUBLAS_COMPUTE_32F_FAST_TF32,这里我手动修改了 InfiniOps 源码
TODO:解决编译warning