Skip to content

[WIP] feat: add InfiniOps as optional kernel provider#161

Open
chen2021673 wants to merge 1 commit into
masterfrom
infiniops_plug_in
Open

[WIP] feat: add InfiniOps as optional kernel provider#161
chen2021673 wants to merge 1 commit into
masterfrom
infiniops_plug_in

Conversation

@chen2021673
Copy link
Copy Markdown
Contributor

@chen2021673 chen2021673 commented May 28, 2026

Summary

引入 InfiniOps 作为可选的 kernel provider,通过 USE_INFINIOPS=ON 启用。
插拔粒度落在 Dispatcher::Call():Dispatcher 在 GetKernel 里增加白名单 hook,命中的 key 路由到 InfiniOps registry,未命中则回退到默认 CUDA kernel。

linear、matmul、outer 三个上层算子从直接调用 GemmCuda 改为 Dispatcher::Instance().Call<void>({device.type(), "Gemm"}, ...),因此 InfiniOps 一次提供 Gemm 即可透明覆盖这三个上层算子,无需逐个包装。

Changes

  • 构建:新增 USE_INFINIOPS CMake 选项;third_party/InfiniOps 作为子模块接入;启用时按需 add_subdirectory 并链接 InfiniOps::infiniops
  • DispatcherGetKernel 增加 InfiniOps lookup hook;未命中 whitelist 时维持原行为。
  • Registry:新增 InfiniOpsRegistry(独立于主 Dispatcher 的 map)+ REGISTER_INFINIOPS_KERNEL 宏 + 全局 whitelist(当前包含 GemmAddForward)。
  • Adapteradapter.{h,cc} 提供 ToOpsDataType / ToOpsDevice / ToOpsTensor 类型与张量桥接,dtype/device 对照采用 inline const std::unordered_map
  • Handle:CUDA backend 通过 handle factory 注入 stream,避免 adapter 与 CUDA runtime 头硬绑定。
  • InfiniOps 算子封装gemm.ccelementwise.cc(AddForward)。
  • 现有 CUDA kernel 改写common/gemm.{cu,cuh}GemmCuda 重命名为 Gemm 并通过 REGISTER_KERNEL 注册;linear/matmul/outer 的所有 GEMM 调用走 Dispatcher。

测试

当前单卡测试性能精度对齐
img_v3_02124_d7186960-6324-4898-adb3-c4f56456980g

问题:InfiniTrain 原生 CUDA GEMM 在 gemm.cu (line 63) 固定用CUBLAS_COMPUTE_32F,但 InfiniOps NVIDIA GEMM 原来对 fp32 走的是CUBLAS_COMPUTE_32F_FAST_TF32,这里我手动修改了 InfiniOps 源码

TODO:解决编译warning

@chen2021673 chen2021673 changed the title feat: add InfiniOps as optional kernel provider [WIP] feat: add InfiniOps as optional kernel provider May 28, 2026
Wire InfiniOps in as a pluggable kernel provider keyed at the GEMM level:
Dispatcher consults a per-key whitelist hook and routes registered ops to
InfiniOps, falling back to the default CUDA kernel otherwise. linear, matmul
and outer now invoke Gemm via Dispatcher rather than calling the cuBLAS
wrapper directly, so InfiniOps Gemm transparently covers all three.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是出于什么原因要单独写一套 registry,而不能直接复用 InfiniTrain 原有的注册表呢?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个头文件内容没什么问题,但不适合放到 include 里作为公共头文件暴露,先放 infini_train/src/kernels/common 里吧

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里不应该给 infinops 开额外分支,之前接沐曦 kernel 这块是不需要动的。

@@ -0,0 +1,25 @@
#include "infini_train/include/core/kernel_provider/infiniops/adapter.h"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个不是配套头文件吧


} // namespace infini_train::kernel_provider::infiniops

REGISTER_INFINIOPS_KERNEL(AddForward, infini_train::kernel_provider::infiniops::AddForward)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果是为了修改注册 key 而专门给 infiniops 写一套注册机制的话感觉不是很有必要,直接按平台注册就行。

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这部分是必要的通用 gemm 接口抽象改动,不涉及 infiniops 相关,可以考虑单独提 pr 先合。

// FIXME: Requires stride tracking in the Tensor class before this can be implemented
// correctly. Currently always returns true as a placeholder. The contiguous guard in
// elementwise.cu ensures non-contiguous tensors fall back to the broadcast path.
// the elementwise provider ensures non-contiguous tensors fall back to the broadcast path.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块不用改吧

std::shared_ptr<Tensor> Contiguous();
// FIXME: Currently returns true unconditionally. Requires stride tracking in the Tensor
// class before this can be implemented correctly. The guard in elementwise.cu ensures
// class before this can be implemented correctly. The elementwise broadcast guard ensures
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不用改吧

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants