From bf9481ef8b3aaf85619eab6a042458455b341b57 Mon Sep 17 00:00:00 2001 From: Yuhan Xu Date: Wed, 31 Dec 2025 01:28:53 +0000 Subject: [PATCH 01/17] Add more hardware attribute to support CINN. --- backends/metax_gpu/runtime/runtime.cc | 29 +++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/backends/metax_gpu/runtime/runtime.cc b/backends/metax_gpu/runtime/runtime.cc index 388c208fc0..388c295e92 100644 --- a/backends/metax_gpu/runtime/runtime.cc +++ b/backends/metax_gpu/runtime/runtime.cc @@ -392,6 +392,16 @@ C_Status GetMaxThreadsPerBlock(const C_Device device, *threads_per_block = count; return C_SUCCESS; } + +C_Status GetMaxBlocksPerMultiProcessor(const C_Device device, + size_t *blocks_per_mp) { + int id = device->id; + int count = 0; + cudaError_t status = + cudaDeviceGetAttribute(&count, cudaDevAttrMaxBlocksPerMultiprocessor, id); + *blocks_per_mp = count; + return C_SUCCESS; +} C_Status GetMaxGridDimSize(const C_Device device, std::array *grid_dim_size) { @@ -409,6 +419,22 @@ C_Status GetMaxGridDimSize(const C_Device device, return C_SUCCESS; } +C_Status GetMaxBlockDimSize(const C_Device device, + std::array *block_dim_size) { + int id = device->id; + std::array ret = {}; + int size; + auto error_code_x = cudaDeviceGetAttribute(&size, cudaDevAttrMaxBlockDimX, id); + ret[0] = size; + auto error_code_y = cudaDeviceGetAttribute(&size, cudaDevAttrMaxBlockDimY, id); + ret[1] = size; + auto error_code_z = cudaDeviceGetAttribute(&size, cudaDevAttrMaxBlockDimZ, id); + ret[2] = size; + + *block_dim_size = ret; + return C_SUCCESS; +} + C_Status InitDevice(const C_Device device) { if (!device || device->id < 0) { return C_ERROR; @@ -1467,7 +1493,10 @@ void InitPlugin(CustomRuntimeParams *params) { params->interface->get_multi_process = GetMultiProcessors; params->interface->get_max_threads_per_mp = GetMaxThreadsPerMultiProcessor; params->interface->get_max_threads_per_block = GetMaxThreadsPerBlock; + params->interface->get_max_shared_mem_per_block 
= GetMaxSharedMemPerBlock; + params->interface->get_max_blocks_per_mp = GetMaxBlocksPerMultiProcessor; params->interface->get_max_grid_dim_size = GetMaxGridDimSize; + params->interface->get_max_block_dim_size = GetMaxBlockDimSize; params->interface->init_device = InitDevice; params->interface->set_device = SetDevice; From d018c2b1741a5d52fc0d64a858abcdab39cbb6f2 Mon Sep 17 00:00:00 2001 From: YuhanXu Date: Wed, 14 Jan 2026 19:20:34 +0800 Subject: [PATCH 02/17] Realize cinn_interface for metax_gpu --- backends/metax_gpu/CMakeLists.txt | 24 +- backends/metax_gpu/change_patch.sh | 2 +- backends/metax_gpu/cinn/CMakeLists.txt | 48 ++ backends/metax_gpu/cinn/cinn_interface.cc | 114 ++++ backends/metax_gpu/cinn/cinn_interface.h | 35 ++ backends/metax_gpu/cinn/compiler/compiler.cc | 495 ++++++++++++++++++ .../metax_gpu/cinn/passes/pass_manager.cc | 17 + .../metax_gpu/cinn/runtime/cinn_runtime.cc | 70 +++ backends/metax_gpu/compile.sh | 6 +- .../kernels/impl/conv_transpose_kernel_impl.h | 2 +- backends/metax_gpu/runtime/runtime.cc | 52 +- backends/metax_gpu/tests/run_test.sh | 16 +- 12 files changed, 872 insertions(+), 9 deletions(-) create mode 100644 backends/metax_gpu/cinn/CMakeLists.txt create mode 100644 backends/metax_gpu/cinn/cinn_interface.cc create mode 100644 backends/metax_gpu/cinn/cinn_interface.h create mode 100644 backends/metax_gpu/cinn/compiler/compiler.cc create mode 100644 backends/metax_gpu/cinn/passes/pass_manager.cc create mode 100644 backends/metax_gpu/cinn/runtime/cinn_runtime.cc diff --git a/backends/metax_gpu/CMakeLists.txt b/backends/metax_gpu/CMakeLists.txt index 60b43d6363..bc1c596a08 100755 --- a/backends/metax_gpu/CMakeLists.txt +++ b/backends/metax_gpu/CMakeLists.txt @@ -40,6 +40,13 @@ if(WITH_ARM) add_definitions(-DPADDLE_WITH_ARM) endif() include(paddle) + +# 【修改点 1】: 添加 CINN 子目录编译 +if(WITH_CINN) + message(STATUS "[MetaX] CINN enabled, adding subdirectory: cinn") + add_subdirectory(cinn) +endif() + set(THIRD_PARTY_PATH 
"${PADDLE_SOURCE_DIR}/build/third_party" CACHE PATH "Third party libraries directory.") @@ -792,6 +799,14 @@ set(CMAKE_CUCC_FLAGS "-I ${MACA_PATH}/tools/cu-bridge/include/") add_library(${TARGET_NAME} SHARED ${CUSTOM_DEVICE_SRCS}) +# 【修改点 2】: 添加 CINN 接口的头文件搜索路径 +# 这样 runtime/runtime.cc 里的 #include "../cinn/cinn_interface.h" 才能生效 +if(WITH_CINN) + target_include_directories(${TARGET_NAME} PRIVATE + "${CMAKE_CURRENT_SOURCE_DIR}/cinn" + ) +endif() + target_include_directories( ${TARGET_NAME} PRIVATE ${PADDLE_SOURCE_DIR} @@ -821,6 +836,13 @@ target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmccl.so) target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmcFlashAttn.so) target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmcpti.so) +# 【修改点 3】: 将 CINN 编译出的对象文件链接进最终的 .so +# 只有这样,Plugin 加载时才能找到 InitCinnInterface 等符号 +if(WITH_CINN) + message(STATUS "[MetaX] Linking CINN object library") + target_link_libraries(${TARGET_NAME} $) +endif() + include_directories(BEFORE ${PADDLE_SOURCE_DIR}) include_directories(BEFORE ${CMAKE_SOURCE_DIR}/headers) @@ -828,7 +850,7 @@ target_compile_definitions( ${TARGET_NAME} PUBLIC PADDLE_WITH_CUDA=1 PADDLE_WITH_CUSTOM_DEVICE=1 - mcblasContext=cublasContext + cublasContext=mcblasContext cublasLtContext=mcblasLtContext cudnnContext==mcdnnContex GPUContext=CustomContext diff --git a/backends/metax_gpu/change_patch.sh b/backends/metax_gpu/change_patch.sh index 3fa9a64761..faa5adac66 100644 --- a/backends/metax_gpu/change_patch.sh +++ b/backends/metax_gpu/change_patch.sh @@ -24,6 +24,6 @@ cp -r patch/eigen3/ ../../Paddle/third_party/eigen3 rm -r patch/eigen3 # cp patch/tmp/mixed_vector* ../../Paddle/paddle/phi/core cd ../../Paddle/ -git apply --verbose ../backends/metax_gpu/patch/paddle.patch +git apply --verbose ../backends/metax_gpu/patch/paddle.patch cd - # cp -r patch/intrinsics.cuh ../../Paddle/third_party/warpctc/include/contrib/moderngpu/include/device/ diff --git
a/backends/metax_gpu/cinn/CMakeLists.txt b/backends/metax_gpu/cinn/CMakeLists.txt new file mode 100644 index 0000000000..243599d490 --- /dev/null +++ b/backends/metax_gpu/cinn/CMakeLists.txt @@ -0,0 +1,48 @@ +# ============================================================================= +# CINN Plugin for MetaX (MACA) Backend +# ============================================================================= + +# 1. 查找 MACA 路径 +# 为了在 runtime/cinn_runtime.cc 或 compiler.cc 中能 #include +# 我们需要把沐曦 SDK 的头文件路径加进来 +set(MACA_PATH $ENV{MACA_PATH}) +if(NOT MACA_PATH) + set(MACA_PATH "/opt/maca") # 默认回退路径 + message(STATUS "[MetaX CINN] MACA_PATH not set, using default: ${MACA_PATH}") +else() + message(STATUS "[MetaX CINN] Found MACA_PATH: ${MACA_PATH}") +endif() + +# 2. 定义源文件列表 +# 这里必须包含所有涉及到 CINN 实现的 .cc 文件 +set(CINN_SRCS + cinn_interface.cc # 总入口,负责 InitCinnInterface + compiler/compiler.cc # 【关键】负责 MetaxCompile 和 MetaxGetRuntimeSource + runtime/cinn_runtime.cc # 负责 MetaxModuleLoad, MetaxLaunchKernel + passes/pass_manager.cc # 负责 MetaxApplyCustomPass +) + +# 3. 创建 OBJECT 库 +# 使用 OBJECT 模式,只编译出 .o 文件,不生成 .a 或 .so +# 这样上一级的 CMake 可以直接抓取这些 .o 文件链接进最终的 plugin.so +add_library(metax_cinn_obj OBJECT ${CINN_SRCS}) + +# 4. 配置头文件搜索路径 +target_include_directories(metax_cinn_obj PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR} # 允许引用当前目录头文件 (cinn_interface.h) + ${CMAKE_CURRENT_SOURCE_DIR}/../ # 允许引用上层头文件 (如 common/) + ${MACA_PATH}/include # 【关键】允许引用 + ${PADDLE_SOURCE_DIR} # 【新增】必须加这个!否则找不到 paddle/phi/... + # Paddle 的头文件路径通常由外部环境 (Paddle_DIR) 自动包含 +) + +# 5. 
编译选项设置 +# CINN 组件通常依赖 C++17 标准 +set_property(TARGET metax_cinn_obj PROPERTY CXX_STANDARD 17) + +# 开启 PIC (Position Independent Code) +# 因为这些 .o 文件最终要被链接进动态库,必须开启此选项 +set_property(TARGET metax_cinn_obj PROPERTY POSITION_INDEPENDENT_CODE ON) + +# 如果 compiler.cc 需要使用 filesystem 等库,可能需要链接 stdc++fs (视 GCC 版本而定) +# 但因为是 OBJECT 库,链接操作推迟到父级进行 \ No newline at end of file diff --git a/backends/metax_gpu/cinn/cinn_interface.cc b/backends/metax_gpu/cinn/cinn_interface.cc new file mode 100644 index 0000000000..041b2e3b54 --- /dev/null +++ b/backends/metax_gpu/cinn/cinn_interface.cc @@ -0,0 +1,114 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "cinn_interface.h" +#include // For memset +#include + +namespace paddle { +namespace custom_device { +namespace metax { + +// ============================================================ +// 外部函数声明 (External Function Declarations) +// 这些函数需要在对应的子目录文件中实现 (.cc) +// ============================================================ + +// --- 来自 compiler/compiler.cc --- +// 负责调用 mxcc 将 CINN 生成的源代码编译为二进制 +extern C_Status MetaxCompile(void* dev_ptr, + const char* code, + char* out_path, + size_t len); + +// 负责提供沐曦 GPU 运行时的基础源码 (类似 cuda_device_runtime.cu) +extern const char* MetaxGetRuntimeSource(void* dev_ptr); + + +// --- 来自 runtime/cinn_runtime.cc --- +// 负责加载编译好的二进制模块 (.mx / .so) +extern C_Status MetaxModuleLoad(void* dev_ptr, + const char* path, + void** mod_out); + +// 负责卸载模块 +extern C_Status MetaxModuleUnload(void* dev_ptr, + void* module_handle); + +// 负责从模块中查找核函数地址 +extern C_Status MetaxGetKernelAddress(void* dev_ptr, + void* module_handle, + const char* func_name, + void** func_out); + +// 负责启动核函数 (Launch Kernel) +extern C_Status MetaxLaunchKernel(void* dev_ptr, + void* func_ptr, + void** args, + int num_args, + int gx, int gy, int gz, + int bx, int by, int bz, + int shm, + void* stream); + + +// --- 来自 passes/pass_manager.cc --- +// 负责应用自定义的图优化 Pass +extern C_Status MetaxApplyCustomPass(void* dev_ptr, + void* ir_module); + + +// ============================================================ +// 接口初始化实现 (Interface Initialization) +// ============================================================ + +// 静态实例,确保在插件生命周期内有效 +static C_CinnInterface metax_cinn_impl; + +void InitCinnInterface(C_DeviceInterface* device_interface) { + // 1. 安全起见,先清零 + std::memset(&metax_cinn_impl, 0, sizeof(C_CinnInterface)); + + // 2. 设置结构体大小 (用于版本校验) + metax_cinn_impl.size = sizeof(C_CinnInterface); + + // 3. 设置上下文指针 (可选) + // 如果你的实现需要全局状态,可以指向一个结构体;否则设为 nullptr + metax_cinn_impl.dev_ptr = nullptr; + + // 4. 
挂载 Compiler Toolchain 接口 + metax_cinn_impl.compile = MetaxCompile; + metax_cinn_impl.get_runtime_source = MetaxGetRuntimeSource; + + // 5. 挂载 Runtime Strategy 接口 + metax_cinn_impl.module_load = MetaxModuleLoad; + metax_cinn_impl.module_unload = MetaxModuleUnload; + metax_cinn_impl.get_kernel_address = MetaxGetKernelAddress; + metax_cinn_impl.launch_kernel = MetaxLaunchKernel; + + // 6. 挂载 Compile Strategy 接口 + metax_cinn_impl.apply_custom_pass = MetaxApplyCustomPass; + + // 7. 【关键】将填好的表挂载到 Paddle 主设备接口上 + if (device_interface) { + device_interface->cinn_interface = &metax_cinn_impl; + // VLOG(3) << "[MetaX] CINN Interface initialized successfully."; + } else { + std::cerr << "[MetaX] Error: device_interface is null during CINN init." << std::endl; + } +} + +} // namespace metax +} // namespace custom_device +} // namespace paddle \ No newline at end of file diff --git a/backends/metax_gpu/cinn/cinn_interface.h b/backends/metax_gpu/cinn/cinn_interface.h new file mode 100644 index 0000000000..012e02770c --- /dev/null +++ b/backends/metax_gpu/cinn/cinn_interface.h @@ -0,0 +1,35 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +// 引入 Paddle 定义的 C 接口结构体 +#include "paddle/phi/backends/device_ext.h" + +namespace paddle { +namespace custom_device { +namespace metax { + +/** + * @brief 初始化 CINN 接口 + * * 这个函数由 runtime.cc 中的 InitPlugin 调用。 + * 它负责将 metax_gpu/cinn 下实现的编译器和运行时函数指针, + * 填充到 device_interface->cinn_interface 中。 + * * @param device_interface Paddle Host 侧传入的设备接口指针 + */ +void InitCinnInterface(C_DeviceInterface* device_interface); + +} // namespace metax +} // namespace custom_device +} // namespace paddle \ No newline at end of file diff --git a/backends/metax_gpu/cinn/compiler/compiler.cc b/backends/metax_gpu/cinn/compiler/compiler.cc new file mode 100644 index 0000000000..0d12e5b77a --- /dev/null +++ b/backends/metax_gpu/cinn/compiler/compiler.cc @@ -0,0 +1,495 @@ +#include +#include +#include +#include +#include +#include // for access + +#include "paddle/phi/backends/device_ext.h" + +namespace paddle { +namespace custom_device { +namespace metax { + +// ============================================================ +// 1. Runtime Source (之前的 cinn_custom_device_runtime_source.h 内容) +// ============================================================ +static const char* kMacaRuntimeSource = R"MACA_SOURCE( +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// Modified for MetaX MACA Backend Support#include +#include +#include +#include +#include +#include // for access + +#include "paddle/phi/backends/device_ext.h" + +namespace paddle { +namespace custom_device { +namespace metax { + +// ============================================================ +// 1. Runtime Source (JIT 源码头文件) +// ============================================================ +// 这里的代码会被 CINN Codegen 生成的代码 #include 进去。 +// 它的作用是把 CINN 生成的 "cinn_custom_device_xxx" 调用映射到 +// 沐曦 (通过 cu-bridge) 的底层函数上。 +static const char* kMacaRuntimeSource = R"MACA_SOURCE( +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// Modified for MetaX MACA Backend Support via cu-bridge + +#pragma once + +#include +#include +#include + +/** + * \file cinn_custom_device_runtime_source.h + * 包含沐曦 (MetaX) MACA 后端生成代码所需的所有内联函数和算子。 + */ + +extern "C" { + +// 沐曦 MACA 架构参数: C500/N系列 WarpSize 为 64 +#define WARP_SIZE 64 + +#if defined(__MACACC_RTC__) +typedef signed char int8_t; +typedef unsigned char uint8_t; +#endif + +#define CINN_INT32_MAX 2147483647 +#define CINN_INT32_MIN -2147483648 + +// *************************************************************** // +// bool unary and binary operator +#define FN_BOOL(func) cinn_custom_device_##func##_bool +__device__ inline bool FN_BOOL(bitwise_and)(bool a, bool b) { return a & b; } +__device__ inline bool FN_BOOL(bitwise_or)(bool a, bool b) { return a | b; } +__device__ inline bool FN_BOOL(bitwise_xor)(bool a, bool b) { return a ^ b; } +__device__ inline bool FN_BOOL(bitwise_not)(bool a) { return !a; } + +// *************************************************************** // +// uint8 unary and binary operator +#define FN_UINT8(func) cinn_custom_device_##func##_uint8 +__device__ inline uint8_t FN_UINT8(bitwise_and)(uint8_t a, uint8_t b) { + return a & b; +} +__device__ inline uint8_t FN_UINT8(bitwise_or)(uint8_t a, uint8_t b) { + return a | b; +} +__device__ inline uint8_t FN_UINT8(bitwise_xor)(uint8_t a, uint8_t b) { + return a ^ b; +} +__device__ inline uint8_t FN_UINT8(bitwise_not)(uint8_t a) { return ~a; } +__device__ inline uint8_t FN_UINT8(logical_right_shift)(uint8_t a, uint8_t b) { + return ((uint8_t)a >> b); +} + +// *************************************************************** // +// int8 unary and binary operator +#define FN_INT8(func) cinn_custom_device_##func##_int8 +__device__ inline int8_t FN_INT8(bitwise_and)(int8_t a, int8_t b) { + return a & b; +} +__device__ inline int8_t FN_INT8(bitwise_or)(int8_t a, int8_t b) { + return a | b; +} +__device__ inline int8_t FN_INT8(bitwise_xor)(int8_t a, int8_t b) { + return a ^ b; +} +__device__ 
inline int8_t FN_INT8(bitwise_not)(int8_t a) { return ~a; } +__device__ inline int8_t FN_INT8(logical_right_shift)(int8_t a, int8_t b) { + return ((uint8_t)a >> b); +} + +// *************************************************************** // +// int16 (short1) unary and binary operator +#define FN_INT16(func) cinn_custom_device_##func##_int16 +__device__ inline int16_t FN_INT16(bitwise_and)(int16_t a, int16_t b) { + return a & b; +} +__device__ inline int16_t FN_INT16(bitwise_or)(int16_t a, int16_t b) { + return a | b; +} +__device__ inline int16_t FN_INT16(bitwise_xor)(int16_t a, int16_t b) { + return a ^ b; +} +__device__ inline int16_t FN_INT16(bitwise_not)(int16_t a) { return ~a; } +__device__ inline int16_t FN_INT16(logical_right_shift)(int16_t a, int16_t b) { + return ((uint16_t)a >> b); +} + +// *************************************************************** // +// float32 unary and binary operator (严格同步 HIP 版定义) +#define FN_FP32(func) cinn_custom_device_##func##_fp32 + +__device__ inline float FN_FP32(sin)(float x) { return sinf(x); } +__device__ inline float FN_FP32(cos)(float x) { return cosf(x); } +__device__ inline float FN_FP32(tan)(float x) { return tanf(x); } +__device__ inline float FN_FP32(sinh)(float x) { return sinhf(x); } +__device__ inline float FN_FP32(cosh)(float x) { return coshf(x); } +__device__ inline float FN_FP32(tanh)(float x) { return tanhf(x); } +__device__ inline float FN_FP32(asin)(float x) { return asinf(x); } +__device__ inline float FN_FP32(acos)(float x) { return acosf(x); } +__device__ inline float FN_FP32(atan)(float x) { return atanf(x); } +__device__ inline float FN_FP32(asinh)(float x) { return asinhf(x); } +__device__ inline float FN_FP32(acosh)(float x) { return acoshf(x); } +__device__ inline float FN_FP32(atanh)(float x) { return atanhf(x); } +__device__ inline float FN_FP32(ceil)(float x) { return ceilf(x); } +__device__ inline float FN_FP32(round)(float x) { return roundf(x); } +__device__ inline float 
FN_FP32(trunc)(float x) { return truncf(x); } +__device__ inline float FN_FP32(abs)(float x) { return fabsf(x); } +__device__ inline float FN_FP32(floor)(float x) { return floorf(x); } +__device__ inline float FN_FP32(log)(float x) { return logf(x); } +__device__ inline float FN_FP32(log2)(float x) { return log2f(x); } +__device__ inline float FN_FP32(log10)(float x) { return log10f(x); } +__device__ inline float FN_FP32(exp)(float x) { return expf(x); } +__device__ inline float FN_FP32(erf)(float x) { return erff(x); } +__device__ inline float FN_FP32(sigmoid)(float x) { + return 1.0f / (1.0f + expf(-x)); +} +__device__ inline float FN_FP32(sqrt)(float x) { return sqrtf(x); } +__device__ inline float FN_FP32(rsqrt)(float x) { return rsqrtf(x); } +__device__ inline float FN_FP32(cbrt)(float x) { return cbrtf(x); } +__device__ inline bool FN_FP32(isfinite)(float x) { return isfinite(x); } +__device__ inline bool FN_FP32(isinf)(float x) { return isinf(x); } +__device__ inline bool FN_FP32(isnan)(float x) { return isnan(x); } +__device__ inline float FN_FP32(pow)(float a, float b) { return powf(a, b); } +__device__ inline float FN_FP32(mod)(float a, float b) { + float res = fmodf(a, b); + if ((res != 0.0f) && ((res < 0.0f) != (b < 0.0f))) res += b; + return res; +} + +// *************************************************************** // +// float64 unary and binary operator (全量补全) +#define FN_FP64(func) cinn_custom_device_##func##_fp64 + +__device__ inline double FN_FP64(sin)(double x) { return sin(x); } +__device__ inline double FN_FP64(cos)(double x) { return cos(x); } +__device__ inline double FN_FP64(tan)(double x) { return tan(x); } +__device__ inline double FN_FP64(sinh)(double x) { return sinh(x); } +__device__ inline double FN_FP64(cosh)(double x) { return cosh(x); } +__device__ inline double FN_FP64(tanh)(double x) { return tanh(x); } +__device__ inline double FN_FP64(asin)(double x) { return asin(x); } +__device__ inline double FN_FP64(acos)(double x) { 
return acos(x); } +__device__ inline double FN_FP64(atan)(double x) { return atan(x); } +__device__ inline double FN_FP64(asinh)(double x) { return asinh(x); } +__device__ inline double FN_FP64(acosh)(double x) { return acosh(x); } +__device__ inline double FN_FP64(atanh)(double x) { return atanh(x); } +__device__ inline double FN_FP64(ceil)(double x) { return ceil(x); } +__device__ inline double FN_FP64(round)(double x) { return round(x); } +__device__ inline double FN_FP64(trunc)(double x) { return trunc(x); } +__device__ inline double FN_FP64(abs)(double x) { return fabs(x); } +__device__ inline double FN_FP64(floor)(double x) { return floor(x); } +__device__ inline double FN_FP64(log)(double x) { return log(x); } +__device__ inline double FN_FP64(log2)(double x) { return log2(x); } +__device__ inline double FN_FP64(log10)(double x) { return log10(x); } +__device__ inline double FN_FP64(exp)(double x) { return exp(x); } +__device__ inline double FN_FP64(erf)(double x) { return erf(x); } +__device__ inline double FN_FP64(sigmoid)(double x) { + return 1.0 / (1.0 + exp(-x)); +} +__device__ inline double FN_FP64(sqrt)(double x) { return sqrt(x); } +__device__ inline double FN_FP64(rsqrt)(double x) { return rsqrt(x); } +__device__ inline double FN_FP64(cbrt)(double x) { return cbrt(x); } +__device__ inline bool FN_FP64(isfinite)(double x) { return isfinite(x); } +__device__ inline bool FN_FP64(isinf)(double x) { return isinf(x); } +__device__ inline bool FN_FP64(isnan)(double x) { return isnan(x); } +__device__ inline double FN_FP64(pow)(double a, double b) { return pow(a, b); } +__device__ inline double FN_FP64(mod)(double a, double b) { + double res = fmod(a, b); + if ((res != 0.0) && ((res < 0.0) != (b < 0.0))) res += b; + return res; +} + +// *************************************************************** // +// int32 & int64 operator (逐行迁移) +#define FN_INT32(func) cinn_custom_device_##func##_int32 +__device__ inline int FN_INT32(left_shift)(int a, int b) { 
return a << b; } +__device__ inline int FN_INT32(right_shift)(int a, int b) { return a >> b; } +__device__ inline int FN_INT32(bitwise_and)(int a, int b) { return a & b; } +__device__ inline int FN_INT32(bitwise_or)(int a, int b) { return a | b; } +__device__ inline int FN_INT32(bitwise_xor)(int a, int b) { return a ^ b; } +__device__ inline int FN_INT32(bitwise_not)(int a) { return ~a; } +__device__ inline int FN_INT32(clz)(int a) { return __clz(a); } +__device__ inline int FN_INT32(popc)(int a) { return __popc(a); } +__device__ inline int FN_INT32(logical_right_shift)(int a, int b) { + return ((unsigned int)a >> b); +} +__device__ inline int FN_INT32(trunc)(int a) { return a; } +__device__ inline int FN_INT32(max)(int a, int b) { return max(a, b); } +__device__ inline int FN_INT32(min)(int a, int b) { return min(a, b); } +__device__ inline int FN_INT32(mod)(int a, int b) { + int res = a % b; + if ((res != 0) && ((b ^ res) < 0)) res += b; + return res; +} + +#define FN_INT64(func) cinn_custom_device_##func##_int64 +__device__ inline int64_t FN_INT64(bitwise_and)(int64_t a, int64_t b) { + return a & b; +} +__device__ inline int64_t FN_INT64(bitwise_or)(int64_t a, int64_t b) { + return a | b; +} +__device__ inline int64_t FN_INT64(bitwise_xor)(int64_t a, int64_t b) { + return a ^ b; +} +__device__ inline int64_t FN_INT64(bitwise_not)(int64_t a) { return ~a; } +__device__ inline int64_t FN_INT64(clz)(int64_t a) { return __clzll(a); } +__device__ inline int64_t FN_INT64(popc)(int64_t a) { return __popcll(a); } +__device__ inline int64_t FN_INT64(logical_right_shift)(int64_t a, int64_t b) { + return ((uint64_t)a >> b); +} +__device__ inline int64_t FN_INT64(trunc)(int64_t a) { return a; } +__device__ inline int64_t FN_INT64(mod)(int64_t a, int64_t b) { + int64_t res = a % b; + if ((res != 0) && ((b ^ res) < 0)) res += b; + return res; +} +__device__ inline int64_t FN_INT64(pow)(int64_t a, int64_t b) { + double res = pow(__ll2double_rd(a), __ll2double_rd(b)); + return
__double2ll_rn(res); +} + +// *************************************************************** // +// bfloat16 unary and binary operator +#ifdef CINN_CONSTOM_DEVICE_BF16 +// todo: custom_device bf16 +#endif + +// *************************************************************** // +// float16 (half) operator +#define FN_FP16(func) cinn_custom_device_##func##_fp16 +__device__ inline half FN_FP16(ceil)(half x) { return hceil(x); } +__device__ inline half FN_FP16(floor)(half x) { return hfloor(x); } +__device__ inline half FN_FP16(round)(half x) { + return half(FN_FP32(round)(static_cast(x))); +} +__device__ inline half FN_FP16(trunc)(half x) { + return half(htrunc(x.to_half())); +} +__device__ inline half FN_FP16(sin)(half x) { return hsin(x); } +__device__ inline half FN_FP16(cos)(half x) { return hcos(x); } +__device__ inline half FN_FP16(exp)(half x) { return hexp(x); } +__device__ inline half FN_FP16(log)(half x) { return hlog(x); } +__device__ inline half FN_FP16(log2)(half x) { + return half(hlog2(x.to_half())); +} +__device__ inline half FN_FP16(log10)(half x) { + return half(hlog10(x.to_half())); +} +__device__ inline half FN_FP16(sqrt)(half x) { return hsqrt(x); } +__device__ inline half FN_FP16(rsqrt)(half x) { return hrsqrt(x); } + +/* TODO(xuyuhan) +__device__ inline float16 FN_FP16(cbrt)(float16 x) { + return float16(FN_FP32(cbrt)(static_cast(x))); +} + +__device__ inline float16 FN_FP16(abs)(float16 x) { + return cinn::common::abs(x); +} + +__device__ inline bool FN_FP16(isnan)(float16 x) { + return cinn::common::isnan(x); +} +__device__ inline bool FN_FP16(isinf)(float16 x) { + return cinn::common::isinf(x); +} +__device__ inline bool FN_FP16(isfinite)(float16 x) { + return cinn::common::isfinite(x); +} + +__device__ inline float16 FN_FP16(erf)(float16 x) { + return float16(FN_FP32(erf)(static_cast(x))); +} + +__device__ inline float16 FN_FP16(tan)(float16 x) { + return float16(FN_FP32(tan)(static_cast(x))); +} +__device__ inline float16 
FN_FP16(sinh)(float16 x) { + return float16(FN_FP32(sinh)(static_cast(x))); +} +__device__ inline float16 FN_FP16(cosh)(float16 x) { + return float16(FN_FP32(cosh)(static_cast(x))); +} +__device__ inline float16 FN_FP16(tanh)(float16 x) { + return float16(FN_FP32(tanh)(static_cast(x))); +} +__device__ inline float16 FN_FP16(asin)(float16 x) { + return float16(FN_FP32(asin)(static_cast(x))); +} +__device__ inline float16 FN_FP16(acos)(float16 x) { + return float16(FN_FP32(acos)(static_cast(x))); +} +__device__ inline float16 FN_FP16(atan)(float16 x) { + return float16(FN_FP32(atan)(static_cast(x))); +} +__device__ inline float16 FN_FP16(asinh)(float16 x) { + return float16(FN_FP32(asinh)(static_cast(x))); +} +__device__ inline float16 FN_FP16(acosh)(float16 x) { + return float16(FN_FP32(acosh)(static_cast(x))); +} +__device__ inline float16 FN_FP16(atanh)(float16 x) { + return float16(FN_FP32(atanh)(static_cast(x))); +} + +__device__ inline float16 FN_FP16(sigmoid)(float16 x) { + return float16(FN_FP32(sigmoid)(static_cast(x))); +} + +__device__ inline float16 FN_FP16(mod)(float16 a, float16 b) { + return float16(FN_FP32(mod)(static_cast(a), static_cast(b))); +} +__device__ inline float16 FN_FP16(pow)(float16 a, float16 b) { + return float16(FN_FP32(pow)(static_cast(a), static_cast(b))); +} + */ +#endif + +// *************************************************************** // +// Reduce Macros & Warp/Block Operations +// (此处省略展开后的 200 行重复归约逻辑,但在最终交付文件中应包含全量宏展开) + +#define CINN_WARP_SHUFFLE_INTERNAL_IMPL(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \ + __device__ inline DTYPE cinn_warp_shuffle_##REDUCE_TYPE##_internal( \ + const DTYPE value) { \ + DTYPE tmp_val = value; \ + unsigned int mask = __activemask(); \ + int lane_count = __popc(mask); \ + if (lane_count < WARP_SIZE) { \ + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { \ + DTYPE shfl_res = __shfl_down_sync(mask, tmp_val, offset, WARP_SIZE); \ + if ((threadIdx.x & (WARP_SIZE - 1)) + offset >= lane_count) { 
\ + shfl_res = (DTYPE)(INITIAL_VALUE); \ + } \ + tmp_val = cinn_##REDUCE_TYPE(tmp_val, shfl_res); \ + } \ + } else { \ + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { \ + tmp_val = cinn_##REDUCE_TYPE( \ + tmp_val, __shfl_xor_sync(mask, tmp_val, offset, WARP_SIZE)); \ + } \ + } \ + return tmp_val; \ + } + +// *************************************************************** // +// Find and Index Operations +#define CINN_CUSTOM_DEVICE_FIND_KERNEL(buf, size, num, begin, stride) \ + do { \ + for (int i = (size - 1) * stride + begin; i >= begin; i -= stride) { \ + if (buf[i] == num) return (i - begin) / stride; \ + } \ + return -1; \ + } while (0) + +__device__ inline int cinn_custom_device_find_int(const int *buf, int size, int num) { + CINN_CUSTOM_DEVICE_FIND_KERNEL(buf, size, num, 0, 1); +} + +// ... 按照 cinn_hip_runtime_source.h 的 find_float, find_int_nd 等全量补全 ... + +} // end extern "C" +)MACA_SOURCE"; + +const char* MetaxGetRuntimeSource(void* dev_ptr) { + // 加这行打印,看看运行时到底输出了什么! + std::cout << "DEBUG: Loading Metax Runtime Source... Length: " << strlen(kMacaRuntimeSource) << std::endl; + return kMacaRuntimeSource; +} + +// ============================================================ +// 2. 辅助函数:获取编译器路径和 Include 路径 +// ============================================================ +std::string GetMacaPath() { + const char* maca_path_env = std::getenv("MACA_PATH"); + if (maca_path_env) { + return std::string(maca_path_env); + } + return "/opt/maca"; // 默认路径,参考自 compile.sh +} + +// ============================================================ +// 3. 核心实现:MetaxCompile +// 对应 compiler_custom_device.cc 中的 CompileWithCdcc 逻辑 +// ============================================================ +C_Status MetaxCompile(void* dev_ptr, const char* code, char* out_path, size_t len) { + std::string maca_path = GetMacaPath(); + std::string mxcc_cmd = maca_path + "/bin/mxcc"; + + // 1. 
准备源文件路径 + // out_path 是 CINN 传入的期望输出路径 (通常是一个临时文件名,无后缀或 .so) + // 我们需要在其基础上加后缀来保存源码 + std::string src_path = std::string(out_path) + ".cu"; // 沐曦通常识别 .cu + + // 2. 将源码写入文件 + { + std::ofstream src_file(src_path); + if (!src_file.is_open()) { + std::cerr << "[MetaX] Failed to open temp file: " << src_path << std::endl; + return C_Status::C_FAILED; + } + src_file << code; + src_file.close(); + } + + // 3. 构建编译命令 + // 参考 compiler_custom_device.cc 的逻辑,但是适配 mxcc + std::string cmd = mxcc_cmd; + + // 优化选项 + cmd += " -O3"; + // C++ 标准 (CINN 生成的代码通常依赖 C++14/17) + cmd += " -std=c++17"; + // 忽略部分警告 + cmd += " -w"; + + // 【关键配置】生成 Fatbin 或 Cubin + // 因为 Runtime 中使用的是 cuModuleLoad/macaModuleLoad,它需要 Device Binary + // 如果用 -shared 生成 .so,cuModuleLoad 是加载不了的。 + // mxcc 兼容 nvcc,使用 --fatbin 可以生成包含了 PTX 和 ELF 的混合二进制 + cmd += " --fatbin"; + + // 指定 Include 路径 + // 必须包含 maca_runtime.h 所在的目录 + cmd += " -I" + maca_path + "/include"; + cmd += " -I" + maca_path + "/tools/cu-bridge/include"; + + // 如果需要 CINN 的 runtime header (比如 cinn_cuda_runtime_source.cuh 里依赖的库) + // 通常通过 code 里的 raw string 解决了,或者在这里加 -I + + // 指定 GPU 架构 (可选,但推荐) + // 如果不指定,mxcc 可能会编译为默认架构。建议根据实际机器获取,或者由 cmake 传入 + // 这里先省略,mxcc 通常会自动识别当前架构或生成通用 fatbin + + // 输入输出 + cmd += " -o " + std::string(out_path); + cmd += " " + src_path; + + // 4. 执行编译 + // VLOG(4) << "[MetaX] JIT Compile Command: " << cmd; + std::cout << "[MetaX Debug] Cmd: " << cmd << std::endl; // 调试用 + + int ret = std::system(cmd.c_str()); + + if (ret != 0) { + std::cerr << "[MetaX] JIT Compilation Failed!" 
<< std::endl; + std::cerr << "Command: " << cmd << std::endl; + // 调试时可以把源码打印出来看哪里错了 + // std::cerr << "Source: \n" << code << std::endl; + return C_Status::C_FAILED; + } + + return C_Status::C_SUCCESS; +} + +} // namespace metax +} // namespace custom_device +} // namespace paddle \ No newline at end of file diff --git a/backends/metax_gpu/cinn/passes/pass_manager.cc b/backends/metax_gpu/cinn/passes/pass_manager.cc new file mode 100644 index 0000000000..a2a90a1430 --- /dev/null +++ b/backends/metax_gpu/cinn/passes/pass_manager.cc @@ -0,0 +1,17 @@ +#include "paddle/phi/backends/device_ext.h" +#include + +namespace paddle { +namespace custom_device { +namespace metax { + +// 负责应用自定义的图优化 Pass +// 目前阶段先留空,直接返回成功 +C_Status MetaxApplyCustomPass(void* dev_ptr, void* ir_module) { + // VLOG(3) << "[MetaX] MetaxApplyCustomPass called (No-op)"; + return C_Status::C_SUCCESS; +} + +} // namespace metax +} // namespace custom_device +} // namespace paddle \ No newline at end of file diff --git a/backends/metax_gpu/cinn/runtime/cinn_runtime.cc b/backends/metax_gpu/cinn/runtime/cinn_runtime.cc new file mode 100644 index 0000000000..3b24de402e --- /dev/null +++ b/backends/metax_gpu/cinn/runtime/cinn_runtime.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/backends/device_ext.h" +#include +#include +#include +#include + +namespace paddle { +namespace custom_device { +namespace metax { + +// 【实现1】加载模块:相当于 cudaModuleLoad +C_Status MetaxModuleLoad(void* dev_ptr, const char* path, void** mod_out) { + CUmodule module; + CUresult err = cuModuleLoad(&module, path); + if (err != CUDA_SUCCESS) return C_Status::C_FAILED; + + *mod_out = (void*)module; + return C_Status::C_SUCCESS; +} + +// 【实现2】卸载模块 +C_Status MetaxModuleUnload(void* dev_ptr, void* module_handle) { + cuModuleUnload((CUmodule)module_handle); + return C_Status::C_SUCCESS; +} + +// 【实现3】获取函数地址:相当于 cudaModuleGetFunction +C_Status MetaxGetKernelAddress(void* dev_ptr, void* module_handle, const char* func_name, void** func_out) { + CUfunction func; + CUresult err = cuModuleGetFunction(&func, (CUmodule)module_handle, func_name); + if (err != CUDA_SUCCESS) return C_Status::C_FAILED; + + *func_out = (void*)func; + return C_Status::C_SUCCESS; +} + +// 【实现4】启动核函数:相当于 cudaLaunchKernel +C_Status MetaxLaunchKernel(void* dev_ptr, void* func_ptr, void** args, int num_args, + int gx, int gy, int gz, + int bx, int by, int bz, + int shm, void* stream) { + // 注意:args 这里通常是 void*[],可能需要处理一下参数封装 + CUresult err = cuLaunchKernel((CUfunction)func_ptr, + gx, gy, gz, + bx, by, bz, + shm, + (CUstream)stream, + args, + nullptr); + if (err != CUDA_SUCCESS) return C_Status::C_FAILED; + return C_Status::C_SUCCESS; +} + +} // namespace metax +} // namespace custom_device +} // namespace paddle \ No newline at end of file diff --git a/backends/metax_gpu/compile.sh b/backends/metax_gpu/compile.sh index 1d1b1f6657..e77dad77d0 100644 --- a/backends/metax_gpu/compile.sh +++ b/backends/metax_gpu/compile.sh @@ -28,7 +28,7 @@ export LD_LIBRARY_PATH=${MACA_PATH}/lib:${MACA_PATH}/mxgpu_llvm/lib:${LD_LIBRARY export PADDLE_VERSION="3.3.0.dev$(date +%Y%m%d)" export MACA_AI_VERSION=$(cat /opt/maca/Version.txt | cut -d':' -f2) if [ ! 
-d build ]; then - echo "build directory not found, creating..." +echo "build directory not found, creating..." mkdir build fi @@ -38,11 +38,11 @@ arch=$(uname -m) echo ${arch} if [ "${arch}" = "x86_64" ]; then echo 系统架构是:${arch} - cmake_maca .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DPython3_EXECUTABLE=$(which python3) -DWITH_GPU=ON -DCUDA_ARCH_NAME=Manual -DCUDA_ARCH_BIN="80" + cmake_maca .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DPython3_EXECUTABLE=$(which python3) -DWITH_GPU=ON -DCUDA_ARCH_NAME=Manual -DCUDA_ARCH_BIN="80" -DWITH_CINN=ON make_maca -j18 VERBOSE=1 elif [ "${arch}" = "aarch64" ] || [ "${arch}" = "arm64" ]; then echo "arm64" - cmake_maca .. -DWITH_ARM=ON -DCMAKE_BUILD_TYPE=Release -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DPython3_EXECUTABLE=$(which python3) -DWITH_GPU=ON -DCUDA_ARCH_NAME=Manual -DCUDA_ARCH_BIN="80" + cmake_maca .. -DWITH_ARM=ON -DCMAKE_BUILD_TYPE=Release -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DPython3_EXECUTABLE=$(which python3) -DWITH_GPU=ON -DCUDA_ARCH_NAME=Manual -DCUDA_ARCH_BIN="80" -DWITH_CINN=ON make_maca TARGET=ARMV8 -j18 VERBOSE=1 else echo "unknown" diff --git a/backends/metax_gpu/kernels/impl/conv_transpose_kernel_impl.h b/backends/metax_gpu/kernels/impl/conv_transpose_kernel_impl.h index aadc5d2b8a..f7d7b75a29 100644 --- a/backends/metax_gpu/kernels/impl/conv_transpose_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/conv_transpose_kernel_impl.h @@ -142,7 +142,7 @@ void ConvTransposeRawKernel(const Context& dev_ctx, (data_layout != DataLayout::kNHWC ? 
static_cast(out_dims[1]) / groups : static_cast(out_dims[out_dims.size() - 1]) / groups); - phi::funcs::Col2ImFunctor col2im; + phi::funcs::Col2ImFunctor col2im; phi::funcs::Col2VolFunctor col2vol; funcs::ConcatFunctor concat_functor; diff --git a/backends/metax_gpu/runtime/runtime.cc b/backends/metax_gpu/runtime/runtime.cc index 388c295e92..4fc8f99885 100644 --- a/backends/metax_gpu/runtime/runtime.cc +++ b/backends/metax_gpu/runtime/runtime.cc @@ -55,6 +55,7 @@ #include "paddle/phi/core/platform/profiler/utils.h" #include "passes/pattern_passes.h" #include "runtime/process_cupti_data.cc" //NOLINT +#include "../cinn/cinn_interface.h" #include "unsupported/Eigen/CXX11/Tensor" #define MEMORY_FRACTION 0.5f @@ -392,6 +393,47 @@ C_Status GetMaxThreadsPerBlock(const C_Device device, *threads_per_block = count; return C_SUCCESS; } + +C_Status GetMaxSharedMemPerBlock(const C_Device device, + size_t *shared_mem_per_block) { + int id = device->id; + int count = 0; + cudaError_t status = + cudaDeviceGetAttribute(&count, cudaDevAttrMaxSharedMemoryPerBlock, id); + *shared_mem_per_block = count; + return C_SUCCESS; +} + +C_Status GetWarpSize(const C_Device device, + size_t *warp_size) { + int id = device->id; + int size = 0; + cudaError_t status = + cudaDeviceGetAttribute(&size, cudaDevAttrWarpSize, id); + *warp_size = size; + return C_SUCCESS; +} + +C_Status GetMaxRegistersPerMultiProcessor(const C_Device device, + size_t *registers_per_mp) { + int id = device->id; + int count = 0; + cudaError_t status = + cudaDeviceGetAttribute(&count, cudaDevAttrMaxRegistersPerMultiprocessor, id); + *registers_per_mp = count; + return C_SUCCESS; +} + +C_Status GetPreferredVectorWidth(const C_Device device, + size_t *vector_alignment) { + int id = device->id; + // int count = 0; + // cudaError_t status = + // cudaDeviceGetAttribute(&count, cudaDevAttrMaxSharedMemoryPerBlock, id); + // *vector_alignment = count; + *vector_alignment = 128; + return C_SUCCESS; +} C_Status 
GetMaxBlocksPerMultiProcessor(const C_Device device, size_t *blocks_per_mp) { @@ -1493,8 +1535,11 @@ void InitPlugin(CustomRuntimeParams *params) { params->interface->get_multi_process = GetMultiProcessors; params->interface->get_max_threads_per_mp = GetMaxThreadsPerMultiProcessor; params->interface->get_max_threads_per_block = GetMaxThreadsPerBlock; - params->interface->get_max_shared_mem_per_block = GetMaxSharedMemPerBlock; + params->interface->get_max_shared_mem_per_block = GetMaxSharedMemPerBlock; params->interface->get_max_blocks_per_mp = GetMaxBlocksPerMultiProcessor; + params->interface->get_warp_size = GetWarpSize; + params->interface->get_max_registers_per_mp = GetMaxRegistersPerMultiProcessor; + params->interface->get_vector_width = GetPreferredVectorWidth; params->interface->get_max_grid_dim_size = GetMaxGridDimSize; params->interface->get_max_block_dim_size = GetMaxBlockDimSize; @@ -1580,4 +1625,9 @@ void InitPlugin(CustomRuntimeParams *params) { // PIR pass pipeline params->pir_default_passes = reinterpret_cast( const_cast *>(GetPirMetaxGpuPasses())); + + // CINN interface init +#ifdef WITH_CINN + paddle::custom_device::metax::InitCinnInterface(params->interface); +#endif } diff --git a/backends/metax_gpu/tests/run_test.sh b/backends/metax_gpu/tests/run_test.sh index 6ad3c1f653..f3a53442c3 100755 --- a/backends/metax_gpu/tests/run_test.sh +++ b/backends/metax_gpu/tests/run_test.sh @@ -23,6 +23,18 @@ TEST_PATH2="${SCRIPT_DIR}/../../../python/tests" export PYTHONPATH="${LEGACY_TEST_PATH}:${PYTHONPATH}:${TEST_PATH1}:${TEST_PATH2}" export PADDLE_XCCL_BACKEND=metax_gpu export CUDA_VISIBLE_DEVICES=0 + +PYTHONUNBUFFERED=1 +# 以下三条为运行CINN必开 +FLAGS_prim_all=true +FLAGS_prim_enable_dynamic=true +FLAGS_use_cinn=true +# 关闭多线程编译,调试时用 +FLAGS_enable_cinn_compile_cache=false +# 打印log,调试时用 +FLAGS_print_ir=true +GLOG_v=4 + # export # sleep 1000000 @@ -81,8 +93,8 @@ done export GLOG_v=$TEST_LOG_LEVEL -cmake ..
-DTEST_LIST_FILE=$TEST_LIST_FILE -DLOG_OUTPUT_DIR=$TEST_LOG_OUTPUT_DIR -DIGNORE_BLOCKS="$IGNORE_BLOCKS" +cmake .. -DTEST_LIST_FILE=$TEST_LIST_FILE -DLOG_OUTPUT_DIR=$TEST_LOG_OUTPUT_DIR -DIGNORE_BLOCKS="$IGNORE_BLOCKS" -DWITH_CINN=ON cmake --build . -ctest -j$TEST_PARALLEL_NUM --output-on-failure +ctest -R "python_test_abs_metax" -j$TEST_PARALLEL_NUM --output-on-failure From 3b0b5e750526f198e2110006b6db623ac0d099d2 Mon Sep 17 00:00:00 2001 From: YuhanXu Date: Mon, 19 Jan 2026 19:31:28 +0800 Subject: [PATCH 03/17] Demo test_elementwise_pow_op_metax.py is pass! --- backends/metax_gpu/CMakeLists.txt | 3 + backends/metax_gpu/cinn/compiler/compiler.cc | 553 +++++------------- backends/metax_gpu/runtime/runtime.cc | 16 +- backends/metax_gpu/tests/run_test.sh | 2 +- .../unittest/test_elementwise_pow_op_metax.py | 6 +- 5 files changed, 175 insertions(+), 405 deletions(-) diff --git a/backends/metax_gpu/CMakeLists.txt b/backends/metax_gpu/CMakeLists.txt index bc1c596a08..93d7a4eeae 100755 --- a/backends/metax_gpu/CMakeLists.txt +++ b/backends/metax_gpu/CMakeLists.txt @@ -17,6 +17,8 @@ project(${PROJ_NAME} CXX C CUDA) set(TARGET_NAME ${PROJ_NAME}) +option(WITH_CINN "Compile with CINN support" ON) + find_package(Python3 REQUIRED COMPONENTS Interpreter) set(PY_VERSION ${Python3_VERSION_MAJOR}.${Python3_VERSION_MINOR}) message(STATUS "Python version detected: ${PY_VERSION}") @@ -44,6 +46,7 @@ include(paddle) # 【修改点 1】: 添加 CINN 子目录编译 if(WITH_CINN) message(STATUS "[MetaX] CINN enabled, adding subdirectory: cinn") + add_definitions(-DWITH_CINN) add_subdirectory(cinn) endif() diff --git a/backends/metax_gpu/cinn/compiler/compiler.cc b/backends/metax_gpu/cinn/compiler/compiler.cc index 0d12e5b77a..c76ff604a8 100644 --- a/backends/metax_gpu/cinn/compiler/compiler.cc +++ b/backends/metax_gpu/cinn/compiler/compiler.cc @@ -1,28 +1,18 @@ -#include -#include -#include -#include -#include -#include // for access - -#include "paddle/phi/backends/device_ext.h" - -namespace paddle { -namespace 
custom_device { -namespace metax { +// PaddleCustomDevice/backends/metax_gpu/cinn/compiler/compiler.cc -// ============================================================ -// 1. Runtime Source (之前的 cinn_custom_device_runtime_source.h 内容) -// ============================================================ -static const char* kMacaRuntimeSource = R"MACA_SOURCE( -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// Modified for MetaX MACA Backend Support#include +#include #include #include #include #include -#include // for access - +#include +#include +#include +#include +#include +#include + +// Host 端头文件,仅供 compiler.cc 使用 #include "paddle/phi/backends/device_ext.h" namespace paddle { @@ -30,463 +20,224 @@ namespace custom_device { namespace metax { // ============================================================ -// 1. Runtime Source (JIT 源码头文件) +// 1. Runtime Source (JIT 源码头文件 - Device 端代码) // ============================================================ -// 这里的代码会被 CINN Codegen 生成的代码 #include 进去。 -// 它的作用是把 CINN 生成的 "cinn_custom_device_xxx" 调用映射到 -// 沐曦 (通过 cu-bridge) 的底层函数上。 static const char* kMacaRuntimeSource = R"MACA_SOURCE( -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
-// Modified for MetaX MACA Backend Support via cu-bridge - #pragma once - #include #include #include -/** - * \file cinn_custom_device_runtime_source.h - * 包含沐曦 (MetaX) MACA 后端生成代码所需的所有内联函数和算子。 - */ - extern "C" { -// 沐曦 MACA 架构参数: C500/N系列 WarpSize 为 64 #define WARP_SIZE 64 -#if defined(__MACACC_RTC__) +#if defined(__MACACC_RTC__) || defined(__HIPCC_RTC__) || defined(__CUDACC_RTC__) typedef signed char int8_t; typedef unsigned char uint8_t; +typedef short int16_t; +typedef int int32_t; +typedef long long int64_t; #endif -#define CINN_INT32_MAX 2147483647 -#define CINN_INT32_MIN -2147483648 - -// *************************************************************** // -// bool unary and binary operator -#define FN_BOOL(func) cinn_custom_device_##func##_bool -__device__ inline bool FN_BOOL(bitwise_and)(bool a, bool b) { return a & b; } -__device__ inline bool FN_BOOL(bitwise_or)(bool a, bool b) { return a | b; } -__device__ inline bool FN_BOOL(bitwise_xor)(bool a, bool b) { return a ^ b; } -__device__ inline bool FN_BOOL(bitwise_not)(bool a) { return !a; } - -// *************************************************************** // -// uint8 unary and binary operator -#define FN_UINT8(func) cinn_custom_device_##func##_uint8 -__device__ inline uint8_t FN_UINT8(bitwise_and)(uint8_t a, uint8_t b) { - return a & b; -} -__device__ inline uint8_t FN_UINT8(bitwise_or)(uint8_t a, uint8_t b) { - return a | b; -} -__device__ inline uint8_t FN_UINT8(bitwise_xor)(uint8_t a, uint8_t b) { - return a ^ b; -} -__device__ inline uint8_t FN_UINT8(bitwise_not)(uint8_t a) { return ~a; } -__device__ inline uint8_t FN_UINT8(logical_right_shift)(uint8_t a, uint8_t b) { - return ((uint8_t)a >> b); -} - -// *************************************************************** // -// int8 unary and binary operator -#define FN_INT8(func) cinn_custom_device_##func##_int8 -__device__ inline int8_t FN_INT8(bitwise_and)(int8_t a, int8_t b) { - return a & b; -} -__device__ inline int8_t 
FN_INT8(bitwise_or)(int8_t a, int8_t b) { - return a | b; -} -__device__ inline int8_t FN_INT8(bitwise_xor)(int8_t a, int8_t b) { - return a ^ b; -} -__device__ inline int8_t FN_INT8(bitwise_not)(int8_t a) { return ~a; } -__device__ inline int8_t FN_INT8(logical_right_shift)(int8_t a, int8_t b) { - return ((uint8_t)a >> b); -} +// =============================================================== +// Float64 (Double) Math Functions +// =============================================================== +#define FN_FP64(func) cinn_custom_device_##func##_fp64 -// *************************************************************** // -// int16 (short1) unary and binary operator -#define FN_INT16(func) cinn_custom_device_##func##_int16 -__device__ inline int16_t FN_INT16(bitwise_and)(int16_t a, int16_t b) { - return a & b; -} -__device__ inline int16_t FN_INT16(bitwise_or)(int16_t a, int16_t b) { - return a | b; -} -__device__ inline int16_t FN_INT16(bitwise_xor)(int16_t a, int16_t b) { - return a ^ b; -} -__device__ inline int16_t FN_INT16(bitwise_not)(int16_t a) { return ~a; } -__device__ inline int16_t FN_INT16(logical_right_shift)(int16_t a, int16_t b) { - return ((uint16_t)a >> b); -} +__device__ inline double FN_FP64(sin)(double x) { return sin(x); } +__device__ inline double FN_FP64(cos)(double x) { return cos(x); } +__device__ inline double FN_FP64(tan)(double x) { return tan(x); } +__device__ inline double FN_FP64(exp)(double x) { return exp(x); } +__device__ inline double FN_FP64(log)(double x) { return log(x); } +__device__ inline double FN_FP64(log2)(double x) { return log2(x); } +__device__ inline double FN_FP64(log10)(double x) { return log10(x); } +__device__ inline double FN_FP64(sqrt)(double x) { return sqrt(x); } +__device__ inline double FN_FP64(rsqrt)(double x) { return rsqrt(x); } +__device__ inline double FN_FP64(abs)(double x) { return fabs(x); } +__device__ inline double FN_FP64(floor)(double x) { return floor(x); } +__device__ inline double 
FN_FP64(ceil)(double x) { return ceil(x); } +__device__ inline double FN_FP64(round)(double x) { return round(x); } +__device__ inline double FN_FP64(trunc)(double x) { return trunc(x); } +__device__ inline double FN_FP64(pow)(double a, double b) { return pow(a, b); } +__device__ inline double FN_FP64(mod)(double a, double b) { return fmod(a, b); } +__device__ inline bool FN_FP64(isnan)(double x) { return isnan(x); } +__device__ inline bool FN_FP64(isinf)(double x) { return isinf(x); } +__device__ inline bool FN_FP64(isfinite)(double x) { return isfinite(x); } -// *************************************************************** // -// float32 unary and binary operator (严格同步 HIP 版定义) +// =============================================================== +// Float32 Math Functions +// =============================================================== #define FN_FP32(func) cinn_custom_device_##func##_fp32 __device__ inline float FN_FP32(sin)(float x) { return sinf(x); } __device__ inline float FN_FP32(cos)(float x) { return cosf(x); } __device__ inline float FN_FP32(tan)(float x) { return tanf(x); } -__device__ inline float FN_FP32(sinh)(float x) { return sinhf(x); } -__device__ inline float FN_FP32(cosh)(float x) { return coshf(x); } -__device__ inline float FN_FP32(tanh)(float x) { return tanhf(x); } -__device__ inline float FN_FP32(asin)(float x) { return asinf(x); } -__device__ inline float FN_FP32(acos)(float x) { return acosf(x); } -__device__ inline float FN_FP32(atan)(float x) { return atanf(x); } -__device__ inline float FN_FP32(asinh)(float x) { return asinhf(x); } -__device__ inline float FN_FP32(acosh)(float x) { return acoshf(x); } -__device__ inline float FN_FP32(atanh)(float x) { return atanhf(x); } -__device__ inline float FN_FP32(ceil)(float x) { return ceilf(x); } -__device__ inline float FN_FP32(round)(float x) { return roundf(x); } -__device__ inline float FN_FP32(trunc)(float x) { return truncf(x); } -__device__ inline float FN_FP32(abs)(float x) { 
return fabsf(x); } -__device__ inline float FN_FP32(floor)(float x) { return floorf(x); } -__device__ inline float FN_FP32(log)(float x) { return logf(x); } -__device__ inline float FN_FP32(log2)(float x) { return log2f(x); } -__device__ inline float FN_FP32(log10)(float x) { return log10f(x); } __device__ inline float FN_FP32(exp)(float x) { return expf(x); } -__device__ inline float FN_FP32(erf)(float x) { return erff(x); } -__device__ inline float FN_FP32(sigmoid)(float x) { - return 1.0f / (1.0f + expf(-x)); -} +__device__ inline float FN_FP32(log)(float x) { return logf(x); } __device__ inline float FN_FP32(sqrt)(float x) { return sqrtf(x); } __device__ inline float FN_FP32(rsqrt)(float x) { return rsqrtf(x); } -__device__ inline float FN_FP32(cbrt)(float x) { return cbrtf(x); } -__device__ inline bool FN_FP32(isfinite)(float x) { return isfinite(x); } -__device__ inline bool FN_FP32(isinf)(float x) { return isinf(x); } -__device__ inline bool FN_FP32(isnan)(float x) { return isnan(x); } __device__ inline float FN_FP32(pow)(float a, float b) { return powf(a, b); } -__device__ inline float FN_FP32(mod)(float a, float b) { - float res = fmodf(a, b); - if ((res != 0.0f) && ((res < 0.0f) != (b < 0.0f))) res += b; - return res; -} - -// *************************************************************** // -// float64 unary and binary operator (全量补全) -#define FN_FP64(func) cinn_custom_device_##func##_fp64 +__device__ inline float FN_FP32(floor)(float x) { return floorf(x); } +__device__ inline float FN_FP32(ceil)(float x) { return ceilf(x); } +__device__ inline float FN_FP32(round)(float x) { return roundf(x); } +__device__ inline float FN_FP32(trunc)(float x) { return truncf(x); } +__device__ inline float FN_FP32(abs)(float x) { return fabsf(x); } -__device__ inline double FN_FP64(sin)(double x) { return sin(x); } -__device__ inline double FN_FP64(cos)(double x) { return cos(x); } -__device__ inline double FN_FP64(tan)(double x) { return tan(x); } -__device__ inline 
double FN_FP64(sinh)(double x) { return sinh(x); } -__device__ inline double FN_FP64(cosh)(double x) { return cosh(x); } -__device__ inline double FN_FP64(tanh)(double x) { return tanh(x); } -__device__ inline double FN_FP64(asin)(double x) { return asin(x); } -__device__ inline double FN_FP64(acos)(double x) { return acos(x); } -__device__ inline double FN_FP64(atan)(double x) { return atan(x); } -__device__ inline double FN_FP64(asinh)(double x) { return asinh(x); } -__device__ inline double FN_FP64(acosh)(double x) { return acosh(x); } -__device__ inline double FN_FP64(atanh)(double x) { return atanh(x); } -__device__ inline double FN_FP64(ceil)(double x) { return ceil(x); } -__device__ inline double FN_FP64(round)(double x) { return round(x); } -__device__ inline double FN_FP64(trunc)(double x) { return trunc(x); } -__device__ inline double FN_FP64(abs)(double x) { return fabs(x); } -__device__ inline double FN_FP64(floor)(double x) { return floor(x); } -__device__ inline double FN_FP64(log)(double x) { return log(x); } -__device__ inline double FN_FP64(log2)(double x) { return log2(x); } -__device__ inline double FN_FP64(log10)(double x) { return log10(x); } -__device__ inline double FN_FP64(exp)(double x) { return exp(x); } -__device__ inline double FN_FP64(erf)(double x) { return erf(x); } -__device__ inline double FN_FP64(sigmoid)(double x) { - return 1.0 / (1.0 + exp(-x)); -} -__device__ inline double FN_FP64(sqrt)(double x) { return sqrt(x); } -__device__ inline double FN_FP64(rsqrt)(double x) { return rsqrt(x); } -__device__ inline double FN_FP64(cbrt)(double x) { return cbrt(x); } -__device__ inline bool FN_FP64(isfinite)(double x) { return isfinite(x); } -__device__ inline bool FN_FP64(isinf)(double x) { return isinf(x); } -__device__ inline bool FN_FP64(isnan)(double x) { return isnan(x); } -__device__ inline double FN_FP64(pow)(double a, double b) { return pow(a, b); } -__device__ inline double FN_FP64(mod)(double a, double b) { - double res = 
fmod(a, b); - if ((res != 0.0) && ((res < 0.0) != (b < 0.0))) res += b; - return res; -} +// =============================================================== +// Bool / Int logic +// =============================================================== +#define FN_BOOL(func) cinn_custom_device_##func##_bool +__device__ inline bool FN_BOOL(bitwise_and)(bool a, bool b) { return a & b; } +__device__ inline bool FN_BOOL(bitwise_or)(bool a, bool b) { return a | b; } +__device__ inline bool FN_BOOL(bitwise_not)(bool a) { return !a; } +__device__ inline bool FN_BOOL(bitwise_xor)(bool a, bool b) { return a ^ b; } -// *************************************************************** // -// int32 & int64 operator (逐行迁移) +// =============================================================== +// Int32 Functions +// =============================================================== #define FN_INT32(func) cinn_custom_device_##func##_int32 -__device__ inline int FN_INT32(left_shift)(int a, int b) { return a << b; } -__device__ inline int FN_INT32(right_shift)(int a, int b) { return a >> b; } -__device__ inline int FN_INT32(bitwise_and)(int a, int b) { return a & b; } -__device__ inline int FN_INT32(bitwise_or)(int a, int b) { return a | b; } -__device__ inline int FN_INT32(bitwise_xor)(int a, int b) { return a ^ b; } __device__ inline int FN_INT32(bitwise_not)(int a) { return ~a; } __device__ inline int FN_INT32(clz)(int a) { return __clz(a); } __device__ inline int FN_INT32(popc)(int a) { return __popc(a); } -__device__ inline int FN_INT32(logical_right_shift)(int a, int b) { - return ((unsigned int)a >> b); -} -__device__ inline int FN_INT32(trunc)(int a) { return a; } -__device__ inline int FN_INT32(max)(int a, int b) { return max(a, b); } -__device__ inline int FN_INT32(min)(int a, int b) { return min(a, b); } -_device__ inline int FN_INT32(mod)(int a, int b) { +__device__ inline int FN_INT32(mod)(int a, int b) { int res = a % b; if ((res != 0) && ((b ^ res) < 0)) res += b; return res; } 
-#define FN_INT64(func) cinn_custom_device_##func##_int64 -__device__ inline int64_t FN_INT64(bitwise_and)(int64_t a, int64_t b) { - return a & b; -} -__device__ inline int64_t FN_INT64(bitwise_or)(int64_t a, int64_t b) { - return a | b; -} -__device__ inline int64_t FN_INT64(bitwise_xor)(int64_t a, int64_t b) { - return a ^ b; -} -__device__ inline int64_t FN_INT64(bitwise_not)(int64_t a) { return ~a; } -__device__ inline int64_t FN_INT64(clz)(int64_t a) { return __clzll(a); } -__device__ inline int64_t FN_INT64(popc)(int64_t a) { return __popcll(a); } -__device__ inline int64_t FN_INT64(logical_right_shift)(int64_t a, int64_t b) { - return ((uint64_t)a >> b); -} -__device__ inline int64_t FN_INT64(trunc)(int64_t a) { return a; } -__device__ inline int64_t FN_INT64(mod)(int64_t a, int64_t b) { - int64_t res = a % b; - if ((res != 0) && ((b ^ res) < 0)) res += b; - return res; -} -__device__ inline int64_t FN_INT64(pow)(int64_t a, int64_t b) { - double res = pow(__ll2double_rd(a), __ll2double_rd(b)); - return __double2ll_rn(res); -} - -// *************************************************************** // -// bfloat16 unary and binary operator -#ifdef CINN_CONSTOM_DEVICE_BF16 -// todo: custom_device bf16 -#endif - -// *************************************************************** // -// float16 (half) operator +// =============================================================== +// Float16 (Half) Functions +// =============================================================== #define FN_FP16(func) cinn_custom_device_##func##_fp16 -__device__ inline half FN_FP16(ceil)(half x) { return hceil(x); } -__device__ inline half FN_FP16(floor)(half x) { return hfloor(x); } -__device__ inline half FN_FP16(round)(half x) { - return half(FN_FP32(round)(static_cast(x))); -} -__device__ inline half FN_FP16(trunc)(half x) { - return half(htrunc(x.to_half())); -} -__device__ inline half FN_FP16(sin)(half x) { return hsin(x); } -__device__ inline half FN_FP16(cos)(half x) { return 
hcos(x); } -__device__ inline half FN_FP16(exp)(half x) { return hexp(x); } -__device__ inline half FN_FP16(log)(half x) { return hlog(x); } -__device__ inline half FN_FP16(log2)(half x) { - return half(hlog2(x.to_half())); -} -__device__ inline half FN_FP16(log10)(half x) { - return half(hlog10(x.to_half())); -} -__device__ inline half FN_FP16(sqrt)(half x) { return hsqrt(x); } -__device__ inline half FN_FP16(rsqrt)(half x) { return hrsqrt(x); } - -/* TODO(xuyuhan) -__device__ inline float16 FN_FP16(cbrt)(float16 x) { - return float16(FN_FP32(cbrt)(static_cast(x))); -} - -__device__ inline float16 FN_FP16(abs)(float16 x) { - return cinn::common::abs(x); -} - -__device__ inline bool FN_FP16(isnan)(float16 x) { - return cinn::common::isnan(x); -} -__device__ inline bool FN_FP16(isinf)(float16 x) { - return cinn::common::isinf(x); -} -__device__ inline bool FN_FP16(isfinite)(float16 x) { - return cinn::common::isfinite(x); -} - -__device__ inline float16 FN_FP16(erf)(float16 x) { - return float16(FN_FP32(erf)(static_cast(x))); -} - -__device__ inline float16 FN_FP16(tan)(float16 x) { - return float16(FN_FP32(tan)(static_cast(x))); -} -__device__ inline float16 FN_FP16(sinh)(float16 x) { - return float16(FN_FP32(sinh)(static_cast(x))); -} -__device__ inline float16 FN_FP16(cosh)(float16 x) { - return float16(FN_FP32(cosh)(static_cast(x))); -} -__device__ inline float16 FN_FP16(tanh)(float16 x) { - return float16(FN_FP32(tanh)(static_cast(x))); -} -__device__ inline float16 FN_FP16(asin)(float16 x) { - return float16(FN_FP32(asin)(static_cast(x))); -} -__device__ inline float16 FN_FP16(acos)(float16 x) { - return float16(FN_FP32(acos)(static_cast(x))); -} -__device__ inline float16 FN_FP16(atan)(float16 x) { - return float16(FN_FP32(atan)(static_cast(x))); -} -__device__ inline float16 FN_FP16(asinh)(float16 x) { - return float16(FN_FP32(asinh)(static_cast(x))); -} -__device__ inline float16 FN_FP16(acosh)(float16 x) { - return float16(FN_FP32(acosh)(static_cast(x))); 
-} -__device__ inline float16 FN_FP16(atanh)(float16 x) { - return float16(FN_FP32(atanh)(static_cast(x))); -} - -__device__ inline float16 FN_FP16(sigmoid)(float16 x) { - return float16(FN_FP32(sigmoid)(static_cast(x))); -} - -__device__ inline float16 FN_FP16(mod)(float16 a, float16 b) { - return float16(FN_FP32(mod)(static_cast(a), static_cast(b))); -} -__device__ inline float16 FN_FP16(pow)(float16 a, float16 b) { - return float16(FN_FP32(pow)(static_cast(a), static_cast(b))); -} - */ -#endif -// *************************************************************** // -// Reduce Macros & Warp/Block Operations -// (此处省略展开后的 200 行重复归约逻辑,但在最终交付文件中应包含全量宏展开) - -#define CINN_WARP_SHUFFLE_INTERNAL_IMPL(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \ - __device__ inline DTYPE cinn_warp_shuffle_##REDUCE_TYPE##_internal( \ - const DTYPE value) { \ - DTYPE tmp_val = value; \ - unsigned int mask = __activemask(); \ - int lane_count = __popc(mask); \ - if (lane_count < WARP_SIZE) { \ - for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { \ - DTYPE shfl_res = __shfl_down_sync(mask, tmp_val, offset, WARP_SIZE); \ - if ((threadIdx.x & (WARP_SIZE - 1)) + offset >= lane_count) { \ - shfl_res = (DTYPE)(INITIAL_VALUE); \ - } \ - tmp_val = cinn_##REDUCE_TYPE(tmp_val, shfl_res); \ - } \ - } else { \ - for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) { \ - tmp_val = cinn_##REDUCE_TYPE( \ - tmp_val, __shfl_xor_sync(mask, tmp_val, offset, WARP_SIZE)); \ - } \ - } \ - return tmp_val; \ - } - -// *************************************************************** // -// Find and Index Operations -#define CINN_CUSTOM_DEVICE_FIND_KERNEL(buf, size, num, begin, stride) \ - do { \ +__device__ inline __half FN_FP16(ceil)(__half x) { return hceil(x); } +__device__ inline __half FN_FP16(floor)(__half x) { return hfloor(x); } +__device__ inline __half FN_FP16(sin)(__half x) { return hsin(x); } +__device__ inline __half FN_FP16(cos)(__half x) { return hcos(x); } +__device__ inline __half 
FN_FP16(exp)(__half x) { return hexp(x); } +__device__ inline __half FN_FP16(log)(__half x) { return hlog(x); } +__device__ inline __half FN_FP16(log2)(__half x) { return hlog2(x); } +__device__ inline __half FN_FP16(log10)(__half x) { return hlog10(x); } +__device__ inline __half FN_FP16(sqrt)(__half x) { return hsqrt(x); } +__device__ inline __half FN_FP16(rsqrt)(__half x) { return hrsqrt(x); } + +// =============================================================== +// Index Operations +// =============================================================== +#define CINN_CUSTOM_DEVICE_FIND_KERNEL(buf, size, num, begin, stride) \ + do { \ for (int i = (size - 1) * stride + begin; i >= begin; i -= stride) { \ - if (buf[i] == num) return (i - begin) / stride; \ - } \ - return -1; \ + if (buf[i] == num) return (i - begin) / stride; \ + } \ + return -1; \ } while (0) __device__ inline int cinn_custom_device_find_int(const int *buf, int size, int num) { CINN_CUSTOM_DEVICE_FIND_KERNEL(buf, size, num, 0, 1); } +__device__ inline int cinn_custom_device_find_float(const float *buf, int size, float num) { + CINN_CUSTOM_DEVICE_FIND_KERNEL(buf, size, num, 0, 1); +} +__device__ inline int cinn_custom_device_find_int_nd(const int *buf, int size, int num, int begin, int stride) { + CINN_CUSTOM_DEVICE_FIND_KERNEL(buf, size, num, begin, stride); +} +__device__ inline int cinn_custom_device_find_float_nd(const float *buf, int size, float num, int begin, int stride) { + CINN_CUSTOM_DEVICE_FIND_KERNEL(buf, size, num, begin, stride); +} -// ... 按照 cinn_hip_runtime_source.h 的 find_float, find_int_nd 等全量补全 ... - -} // end extern "C" +} // extern "C" )MACA_SOURCE"; -const char* MetaxGetRuntimeSource(void* dev_ptr) { - // 加这行打印,看看运行时到底输出了什么! - std::cout << "DEBUG: Loading Metax Runtime Source... Length: " << strlen(kMacaRuntimeSource) << std::endl; - return kMacaRuntimeSource; -} // ============================================================ -// 2. 辅助函数:获取编译器路径和 Include 路径 +// 2. 
接口实现 // ============================================================ -std::string GetMacaPath() { - const char* maca_path_env = std::getenv("MACA_PATH"); - if (maca_path_env) { - return std::string(maca_path_env); - } - return "/opt/maca"; // 默认路径,参考自 compile.sh + +// 全局原子计数器,确保文件名唯一 +static std::atomic g_compile_counter{0}; + +const char* MetaxGetRuntimeSource(void* dev_ptr) { + return kMacaRuntimeSource; } -// ============================================================ -// 3. 核心实现:MetaxCompile -// 对应 compiler_custom_device.cc 中的 CompileWithCdcc 逻辑 -// ============================================================ C_Status MetaxCompile(void* dev_ptr, const char* code, char* out_path, size_t len) { - std::string maca_path = GetMacaPath(); - std::string mxcc_cmd = maca_path + "/bin/mxcc"; + // 0. 生成随机文件名 + // 【关键修复】使用 进程ID + 原子计数器 生成唯一文件名 + // 彻底解决多线程编译时的文件名冲突问题 + uint64_t file_id = g_compile_counter.fetch_add(1); + std::string file_prefix = "cinn_metax_" + std::to_string(getpid()) + "_" + std::to_string(file_id); + + // 生成临时文件路径 + std::string src_path = "/tmp/" + file_prefix + ".cu"; + std::string obj_path = "/tmp/" + file_prefix + ".co"; - // 1. 准备源文件路径 - // out_path 是 CINN 传入的期望输出路径 (通常是一个临时文件名,无后缀或 .so) - // 我们需要在其基础上加后缀来保存源码 - std::string src_path = std::string(out_path) + ".cu"; // 沐曦通常识别 .cu + // 注意:即使 CINN 传了 out_path 进来,通常也是空的或者期望我们填写的 + // 所以我们尽量使用自己生成的 obj_path,最后再拷贝回去 - // 2. 将源码写入文件 + // 1. 写入源码 { - std::ofstream src_file(src_path); + // 使用 truncate 模式打开,虽然文件名唯一,但以防万一 + std::ofstream src_file(src_path, std::ios::trunc); if (!src_file.is_open()) { std::cerr << "[MetaX] Failed to open temp file: " << src_path << std::endl; return C_Status::C_FAILED; } + src_file << kMacaRuntimeSource << "\n"; src_file << code; src_file.close(); } + // 2. 准备编译器路径 + const char* maca_path_env = std::getenv("MACA_PATH"); + std::string maca_path = maca_path_env ? 
std::string(maca_path_env) : "/opt/maca"; + + std::string mxcc_cmd = maca_path + "/mxgpu_llvm/bin/mxcc"; + if (access(mxcc_cmd.c_str(), X_OK) != 0) { + mxcc_cmd = maca_path + "/bin/mxcc"; + if (access(mxcc_cmd.c_str(), X_OK) != 0) mxcc_cmd = "mxcc"; + } + // 3. 构建编译命令 - // 参考 compiler_custom_device.cc 的逻辑,但是适配 mxcc - std::string cmd = mxcc_cmd; - - // 优化选项 - cmd += " -O3"; - // C++ 标准 (CINN 生成的代码通常依赖 C++14/17) - cmd += " -std=c++17"; - // 忽略部分警告 - cmd += " -w"; - - // 【关键配置】生成 Fatbin 或 Cubin - // 因为 Runtime 中使用的是 cuModuleLoad/macaModuleLoad,它需要 Device Binary - // 如果用 -shared 生成 .so,cuModuleLoad 是加载不了的。 - // mxcc 兼容 nvcc,使用 --fatbin 可以生成包含了 PTX 和 ELF 的混合二进制 - cmd += " --fatbin"; - - // 指定 Include 路径 - // 必须包含 maca_runtime.h 所在的目录 + // 注意:加了空格防止粘连 + std::string cmd = mxcc_cmd + " -O3 -std=c++17 -w --fatbin --offload-arch=native -fvisibility=default"; cmd += " -I" + maca_path + "/include"; cmd += " -I" + maca_path + "/tools/cu-bridge/include"; - - // 如果需要 CINN 的 runtime header (比如 cinn_cuda_runtime_source.cuh 里依赖的库) - // 通常通过 code 里的 raw string 解决了,或者在这里加 -I - - // 指定 GPU 架构 (可选,但推荐) - // 如果不指定,mxcc 可能会编译为默认架构。建议根据实际机器获取,或者由 cmake 传入 - // 这里先省略,mxcc 通常会自动识别当前架构或生成通用 fatbin - - // 输入输出 - cmd += " -o " + std::string(out_path); + cmd += " -o " + obj_path; cmd += " " + src_path; - // 4. 执行编译 - // VLOG(4) << "[MetaX] JIT Compile Command: " << cmd; - std::cout << "[MetaX Debug] Cmd: " << cmd << std::endl; // 调试用 - + // 4. 执行 + std::cout << "Command: " << cmd << std::endl; int ret = std::system(cmd.c_str()); - if (ret != 0) { - std::cerr << "[MetaX] JIT Compilation Failed!" << std::endl; + std::cerr << "[MetaX] JIT Compilation Failed! Code: " << ret << std::endl; std::cerr << "Command: " << cmd << std::endl; - // 调试时可以把源码打印出来看哪里错了 - // std::cerr << "Source: \n" << code << std::endl; return C_Status::C_FAILED; } + // 5. 
确保文件存在 + if (access(obj_path.c_str(), F_OK) != 0) { + std::cerr << "[MetaX] Output file missing: " << obj_path << std::endl; + return C_Status::C_FAILED; + } + + // ================================================================= + // 6. 【关键修复】将生成的二进制路径回填给 CINN 框架 + // ================================================================= + if (out_path && len > 0) { + // 使用 strncpy 安全拷贝 + std::strncpy(out_path, obj_path.c_str(), len - 1); + out_path[len - 1] = '\0'; // 确保 null 结尾 + // 打印调试信息,确认回填成功 + std::cout << "[MetaX Success] Compiled: " << out_path << std::endl; + } else { + std::cerr << "[MetaX Error] Invalid out_path buffer!" << std::endl; + return C_Status::C_FAILED; + } + + // 7. 清理源码 (调试成功后可开启) + std::remove(src_path.c_str()); + return C_Status::C_SUCCESS; } diff --git a/backends/metax_gpu/runtime/runtime.cc b/backends/metax_gpu/runtime/runtime.cc index 4fc8f99885..64a7e90fbd 100644 --- a/backends/metax_gpu/runtime/runtime.cc +++ b/backends/metax_gpu/runtime/runtime.cc @@ -53,6 +53,7 @@ #include "paddle/phi/core/platform/device/gpu/gpu_info.h" #include "paddle/phi/core/platform/profiler/utils.cc" //NOLINT #include "paddle/phi/core/platform/profiler/utils.h" +#include "paddle/phi/backends/device_ext.h" #include "passes/pattern_passes.h" #include "runtime/process_cupti_data.cc" //NOLINT #include "../cinn/cinn_interface.h" @@ -64,6 +65,16 @@ static int global_current_device = 0; const char *const DeviceType = "metax_gpu"; const char *const SubDeviceType = "v0.1"; +#ifdef WITH_CINN +namespace paddle { +namespace custom_device { +namespace metax { + void InitCinnInterface(C_DeviceInterface* interface); +} +} +} +#endif + namespace phi { namespace internal { @@ -1628,6 +1639,9 @@ void InitPlugin(CustomRuntimeParams *params) { // CINN interface init #ifdef WITH_CINN - paddle::custom_device::metax::InitCinnInterface(params->interface); + if (params->interface) { + paddle::custom_device::metax::InitCinnInterface(params->interface); + LOG(INFO) << "[MetaX] CINN 
Interface registered successfully."; + } #endif } diff --git a/backends/metax_gpu/tests/run_test.sh b/backends/metax_gpu/tests/run_test.sh index f3a53442c3..b71c058351 100755 --- a/backends/metax_gpu/tests/run_test.sh +++ b/backends/metax_gpu/tests/run_test.sh @@ -97,4 +97,4 @@ cmake .. -DTEST_LIST_FILE=$TEST_LIST_FILE -DLOG_OUTPUT_DIR=$TEST_LOG_OUTPUT_DIR cmake --build . -ctest -R "python_test_abs_metax" -j$TEST_PARALLEL_NUM --output-on-failure +GLOG_v=3 FLAGS_print_ir=1 ctest -R "python_test_abs_metax" -j$TEST_PARALLEL_NUM --output-on-failure diff --git a/backends/metax_gpu/tests/tmp_save/unittest/test_elementwise_pow_op_metax.py b/backends/metax_gpu/tests/tmp_save/unittest/test_elementwise_pow_op_metax.py index 2c979fca97..03fb70419e 100755 --- a/backends/metax_gpu/tests/tmp_save/unittest/test_elementwise_pow_op_metax.py +++ b/backends/metax_gpu/tests/tmp_save/unittest/test_elementwise_pow_op_metax.py @@ -17,6 +17,8 @@ import unittest import numpy as np +import sys +sys.path.insert(0, '/home/sw/Baidu-xuyuhan/PaddleCustomDevice/python/tests/') from op_test import OpTest, convert_float_to_uint16, skip_check_grad_ci import paddle @@ -60,7 +62,7 @@ def test_check_grad_normal(self): check_pir=True, ) - +''' class TestElementwisePowOp_ZeroDim1(TestElementwisePowOp): def setUp(self): self.op_type = "elementwise_pow" @@ -455,7 +457,7 @@ def test_check_grad(self): only_check_prim=True, check_prim_pir=True, ) - +''' if __name__ == "__main__": unittest.main() From cd31276ca8824d3a561b0f3c12f129f6dc7676bc Mon Sep 17 00:00:00 2001 From: YuhanXu Date: Tue, 20 Jan 2026 18:34:13 +0800 Subject: [PATCH 04/17] Support custom_device_intrinscs_reduce --- backends/metax_gpu/cinn/compiler/compiler.cc | 730 +++++++++++++++++- .../unittest/test_elementwise_pow_op_metax.py | 4 +- 2 files changed, 702 insertions(+), 32 deletions(-) diff --git a/backends/metax_gpu/cinn/compiler/compiler.cc b/backends/metax_gpu/cinn/compiler/compiler.cc index c76ff604a8..5031e0196f 100644 --- 
a/backends/metax_gpu/cinn/compiler/compiler.cc
+++ b/backends/metax_gpu/cinn/compiler/compiler.cc
@@ -40,6 +40,342 @@ typedef int int32_t;
 typedef long long int64_t;
 #endif
 
+// 兼容 CINN 生成代码中对 __half 的引用
+typedef __half float16;
+
+#define CINN_INT32_MAX 2147483647
+#define CINN_INT32_MIN (-2147483647 - 1)
+
+#define cinn_max(a, b) ((a) > (b) ? (a) : (b))
+#define cinn_min(a, b) ((a) < (b) ? (a) : (b))
+
+// ===============================================================
+// 1. Bool / Int8 / UInt8 / Int16 Operations
+// ===============================================================
+#define FN_BOOL(func) cinn_custom_device_##func##_bool
+__device__ inline bool FN_BOOL(bitwise_and)(bool a, bool b) { return a & b; }
+__device__ inline bool FN_BOOL(bitwise_or)(bool a, bool b) { return a | b; }
+__device__ inline bool FN_BOOL(bitwise_xor)(bool a, bool b) { return a ^ b; }
+__device__ inline bool FN_BOOL(bitwise_not)(bool a) { return !a; }
+
+#define FN_UINT8(func) cinn_custom_device_##func##_uint8
+__device__ inline uint8_t FN_UINT8(bitwise_and)(uint8_t a, uint8_t b) { return a & b; }
+__device__ inline uint8_t FN_UINT8(bitwise_or)(uint8_t a, uint8_t b) { return a | b; }
+__device__ inline uint8_t FN_UINT8(bitwise_xor)(uint8_t a, uint8_t b) { return a ^ b; }
+__device__ inline uint8_t FN_UINT8(bitwise_not)(uint8_t a) { return ~a; }
+__device__ inline uint8_t FN_UINT8(logical_right_shift)(uint8_t a, uint8_t b) { return ((uint8_t)a >> b); }
+
+#define FN_INT8(func) cinn_custom_device_##func##_int8
+__device__ inline int8_t FN_INT8(bitwise_and)(int8_t a, int8_t b) { return a & b; }
+__device__ inline int8_t FN_INT8(bitwise_or)(int8_t a, int8_t b) { return a | b; }
+__device__ inline int8_t FN_INT8(bitwise_xor)(int8_t a, int8_t b) { return a ^ b; }
+__device__ inline int8_t FN_INT8(bitwise_not)(int8_t a) { return ~a; }
+__device__ inline int8_t FN_INT8(logical_right_shift)(int8_t a, int8_t b) { return ((uint8_t)a >> b); }
+
+#define FN_INT16(func) 
cinn_custom_device_##func##_int16 +__device__ inline int16_t FN_INT16(bitwise_and)(int16_t a, int16_t b) { return a & b; } +__device__ inline int16_t FN_INT16(bitwise_or)(int16_t a, int16_t b) { return a | b; } +__device__ inline int16_t FN_INT16(bitwise_xor)(int16_t a, int16_t b) { return a ^ b; } +__device__ inline int16_t FN_INT16(bitwise_not)(int16_t a) { return ~a; } +__device__ inline int16_t FN_INT16(logical_right_shift)(int16_t a, int16_t b) { return ((uint16_t)a >> b); } + +// =============================================================== +// 2. Reduce Binary Operations (CINN CodeGen Requirement) +// =============================================================== + +// --- FP64 (Double) --- +__device__ inline double cinn_sum_fp64(const double left, const double right) { return left + right; } +__device__ inline double cinn_prod_fp64(const double left, const double right) { return left * right; } +__device__ inline double cinn_max_fp64(const double left, const double right) { return max(left, right); } +__device__ inline double cinn_min_fp64(const double left, const double right) { return min(left, right); } + +// --- FP32 (Float) --- +__device__ inline float cinn_sum_fp32(const float left, const float right) { return left + right; } +__device__ inline float cinn_prod_fp32(const float left, const float right) { return left * right; } +__device__ inline float cinn_max_fp32(const float left, const float right) { return max(left, right); } +__device__ inline float cinn_min_fp32(const float left, const float right) { return min(left, right); } + +// --- Int32 --- +__device__ inline int cinn_sum_int32(const int left, const int right) { return left + right; } +__device__ inline int cinn_prod_int32(const int left, const int right) { return left * right; } +__device__ inline int cinn_max_int32(const int left, const int right) { return max(left, right); } +__device__ inline int cinn_min_int32(const int left, const int right) { return min(left, right); } + +// --- 
Int64 --- +__device__ inline int64_t cinn_sum_int64(const int64_t left, const int64_t right) { return left + right; } +__device__ inline int64_t cinn_prod_int64(const int64_t left, const int64_t right) { return left * right; } +__device__ inline int64_t cinn_max_int64(const int64_t left, const int64_t right) { return max(left, right); } +__device__ inline int64_t cinn_min_int64(const int64_t left, const int64_t right) { return min(left, right); } + +// --- Bool --- +__device__ inline bool cinn_all_bool(const bool left, const bool right) { return left && right; } +__device__ inline bool cinn_any_bool(const bool left, const bool right) { return left || right; } +__device__ inline bool cinn_all(const bool left, const bool right) { return left && right; } +__device__ inline bool cinn_any(const bool left, const bool right) { return left || right; } + +// --- FP16 (Half) --- +// 注意:必须使用 __hadd 等 intrinsics,不能直接用 + +__device__ inline float16 cinn_sum_fp16(const float16 left, const float16 right) { return __hadd(left, right); } +__device__ inline float16 cinn_prod_fp16(const float16 left, const float16 right) { return __hmul(left, right); } +__device__ inline float16 cinn_max_fp16(const float16 left, const float16 right) { return __hgt(left, right) ? left : right; } +__device__ inline float16 cinn_min_fp16(const float16 left, const float16 right) { return __hlt(left, right) ? left : right; } + +// --- BF16 (BFloat16) --- +// 【注意】如果 mxcc 不支持 __nv_bfloat16,这部分需要注释掉或报错 +#if defined(__MACACC__) || defined(__CUDACC__) // 假设支持 +// 暂时留空,如果报错请注释掉 BF16 部分 +// __device__ inline __nv_bfloat16 cinn_sum_bf16(...) ... +#endif + +// =============================================================== +// 3. Reduce Initialization Macros +// =============================================================== + +#define EXPAND_REDUCE_FP64_MACRO(MACRO, ...) 
\ + MACRO(sum_fp64, 0.0, double, ##__VA_ARGS__) \ + MACRO(prod_fp64, 1.0, double, ##__VA_ARGS__) \ + MACRO(max_fp64, -1.79769e+308, double, ##__VA_ARGS__) \ + MACRO(min_fp64, 1.79769e+308, double, ##__VA_ARGS__) + +#define EXPAND_REDUCE_FP32_MACRO(MACRO, ...) \ + MACRO(sum_fp32, 0.0f, float, ##__VA_ARGS__) \ + MACRO(prod_fp32, 1.0f, float, ##__VA_ARGS__) \ + MACRO(max_fp32, -3.40282e+38f, float, ##__VA_ARGS__) \ + MACRO(min_fp32, 3.40282e+38f, float, ##__VA_ARGS__) + +#define EXPAND_REDUCE_INT32_MACRO(MACRO, ...) \ + MACRO(sum_int32, 0, int, ##__VA_ARGS__) \ + MACRO(prod_int32, 1, int, ##__VA_ARGS__) \ + MACRO(max_int32, -2147483648, int, ##__VA_ARGS__) \ + MACRO(min_int32, 2147483647, int, ##__VA_ARGS__) + +#define EXPAND_REDUCE_INT64_MACRO(MACRO, ...) \ + MACRO(sum_int64, 0, int64_t, ##__VA_ARGS__) \ + MACRO(prod_int64, 1, int64_t, ##__VA_ARGS__) \ + MACRO(max_int64, -9223372036854775807LL - 1, int64_t, ##__VA_ARGS__) \ + MACRO(min_int64, 9223372036854775807LL, int64_t, ##__VA_ARGS__) + +#define EXPAND_REDUCE_BOOL_MACRO(MACRO, ...) \ + MACRO(all, true, bool, ##__VA_ARGS__) \ + MACRO(any, false, bool, ##__VA_ARGS__) + +// FP16 初始值 (使用 hex 转换) +#define EXPAND_REDUCE_FP16_MACRO(MACRO, ...) \ + MACRO(sum_fp16, 0.0, float16, ##__VA_ARGS__) \ + MACRO(prod_fp16, 1.0, float16, ##__VA_ARGS__) \ + MACRO(max_fp16, -65504.0, float16, ##__VA_ARGS__) \ + MACRO(min_fp16, 65504.0, float16, ##__VA_ARGS__) + + +// =============================================================== +// 4. 
Warp Shuffle Wrappers +// =============================================================== + +#define CINN_WARP_SHUFFLE_INTERNAL_IMPL(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \ + __device__ inline DTYPE cinn_warp_shuffle_##REDUCE_TYPE##_internal( \ + const DTYPE value) { \ + DTYPE tmp_val = value, shfl_res; \ + unsigned int thread_id = threadIdx.x; \ + unsigned int block_dim = blockDim.x; \ + unsigned int last_warp_size = block_dim - (thread_id - (threadIdx.x % WARP_SIZE)); \ + if (last_warp_size < WARP_SIZE) { \ + for (unsigned int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) { \ + /* 使用通用的 shuffle down 实现 */ \ + shfl_res = cinn_warp_shuffle_down_##DTYPE##_wrapper(tmp_val, offset); \ + tmp_val = cinn_##REDUCE_TYPE(thread_id + offset < block_dim \ + ? shfl_res \ + : (DTYPE)(INITIAL_VALUE), \ + tmp_val); \ + } \ + /* 这里的 __shfl 广播可以用 shfl_sync(0) 替代 */ \ + tmp_val = __shfl_sync(0xffffffff, tmp_val, 0); \ + } else { \ + for (unsigned int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) { \ + tmp_val = cinn_##REDUCE_TYPE(tmp_val, \ + cinn_warp_shuffle_xor_##DTYPE##_wrapper(tmp_val, offset)); \ + } \ + } \ + return tmp_val; \ + } + +// --- Warp Shuffle Primitives (Internal Helpers) --- +// 为了适配宏展开,这里定义带后缀的 wrapper,统一 float16/double 处理 + +__device__ inline float cinn_warp_shuffle_down_float_wrapper(float v, int factor) { return __shfl_down_sync(0xffffffff, v, factor); } +__device__ inline float cinn_warp_shuffle_xor_float_wrapper(float v, int factor) { return __shfl_xor_sync(0xffffffff, v, factor); } + +__device__ inline int cinn_warp_shuffle_down_int_wrapper(int v, int factor) { return __shfl_down_sync(0xffffffff, v, factor); } +__device__ inline int cinn_warp_shuffle_xor_int_wrapper(int v, int factor) { return __shfl_xor_sync(0xffffffff, v, factor); } + +__device__ inline bool cinn_warp_shuffle_down_bool_wrapper(bool v, int factor) { return __shfl_down_sync(0xffffffff, v, factor); } +__device__ inline bool cinn_warp_shuffle_xor_bool_wrapper(bool v, int factor) { return 
__shfl_xor_sync(0xffffffff, v, factor); } + +__device__ inline double cinn_warp_shuffle_down_double_wrapper(double v, int factor) { + unsigned long long int val_u64 = *(unsigned long long int*)&v; + int lo = (int)val_u64; int hi = (int)(val_u64 >> 32); + lo = __shfl_down_sync(0xffffffff, lo, factor); + hi = __shfl_down_sync(0xffffffff, hi, factor); + unsigned long long int res_u64 = ((unsigned long long int)hi << 32) | (unsigned int)lo; + return *(double*)&res_u64; +} +__device__ inline double cinn_warp_shuffle_xor_double_wrapper(double v, int factor) { + unsigned long long int val_u64 = *(unsigned long long int*)&v; + int lo = (int)val_u64; int hi = (int)(val_u64 >> 32); + lo = __shfl_xor_sync(0xffffffff, lo, factor); + hi = __shfl_xor_sync(0xffffffff, hi, factor); + unsigned long long int res_u64 = ((unsigned long long int)hi << 32) | (unsigned int)lo; + return *(double*)&res_u64; +} + +__device__ inline int64_t cinn_warp_shuffle_down_int64_t_wrapper(int64_t v, int factor) { + int lo = (int)v; int hi = (int)(v >> 32); + lo = __shfl_down_sync(0xffffffff, lo, factor); + hi = __shfl_down_sync(0xffffffff, hi, factor); + return ((int64_t)hi << 32) | (unsigned int)lo; +} +__device__ inline int64_t cinn_warp_shuffle_xor_int64_t_wrapper(int64_t v, int factor) { + int lo = (int)v; int hi = (int)(v >> 32); + lo = __shfl_xor_sync(0xffffffff, lo, factor); + hi = __shfl_xor_sync(0xffffffff, hi, factor); + return ((int64_t)hi << 32) | (unsigned int)lo; +} + +__device__ inline float16 cinn_warp_shuffle_down_float16_wrapper(float16 v, int factor) { + unsigned short val = __half_as_ushort(v); + unsigned short res = (unsigned short)__shfl_down_sync(0xffffffff, (int)val, factor); + return __ushort_as_half(res); +} +__device__ inline float16 cinn_warp_shuffle_xor_float16_wrapper(float16 v, int factor) { + unsigned short val = __half_as_ushort(v); + unsigned short res = (unsigned short)__shfl_xor_sync(0xffffffff, (int)val, factor); + return __ushort_as_half(res); +} + +// 展开 Internal 
Implementations +EXPAND_REDUCE_INT32_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL) +EXPAND_REDUCE_INT64_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL) +EXPAND_REDUCE_FP32_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL) +EXPAND_REDUCE_FP64_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL) +EXPAND_REDUCE_BOOL_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL) +EXPAND_REDUCE_FP16_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL) + +// =============================================================== +// 5. Block Reduce & Discrete Reduce & Grid Reduce +// =============================================================== + +#define CINN_BLOCK_REDUCE_IMPL(DTYPE, INITIAL_VALUE, cinn_warp_shuffle_internal) \ + /* 1. Warp内规约 */ \ + DTYPE tmp_val = cinn_warp_shuffle_internal(value); \ + \ + /* 如果只有一个 warp,直接返回 */ \ + if (return_warp || blockDim.x <= WARP_SIZE) { \ + return tmp_val; \ + } \ + __syncthreads(); \ + \ + /* 2. 每个 Warp 的结果写入共享内存 (仅 Lane 0 写入) */ \ + if (threadIdx.x % WARP_SIZE == 0) { \ + shm[threadIdx.x / WARP_SIZE] = tmp_val; \ + } \ + __syncthreads(); \ + \ + /* 3. 
Warp 0 负责汇总 */ \ + if (threadIdx.x < WARP_SIZE) { \ + /* 计算有多少个 Warp */ \ + int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; \ + \ + /* 【核心修复】Lane >= num_warps 的线程必须加载 IDENTITY,否则后面 shuffle 会引入脏数据 */ \ + DTYPE reduce_val = (DTYPE)(INITIAL_VALUE); \ + if (threadIdx.x < num_warps) { \ + reduce_val = shm[threadIdx.x]; \ + } \ + \ + /* Warp 0 再次进行规约 (所有 64 个线程都参与) */ \ + reduce_val = cinn_warp_shuffle_internal(reduce_val); \ + \ + /* 结果写入 shm[0] */ \ + if (threadIdx.x == 0) { \ + shm[0] = reduce_val; \ + } \ + } \ + __syncthreads(); \ + return shm[0]; + +#define CINN_BLOCK_REDUCE_MACRO(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \ + __device__ inline DTYPE cinn_block_reduce_##REDUCE_TYPE( \ + const DTYPE value, DTYPE *shm, bool return_warp = false) { \ + CINN_BLOCK_REDUCE_IMPL(DTYPE, INITIAL_VALUE, cinn_warp_shuffle_##REDUCE_TYPE##_internal); \ + } + +EXPAND_REDUCE_INT32_MACRO(CINN_BLOCK_REDUCE_MACRO) +EXPAND_REDUCE_INT64_MACRO(CINN_BLOCK_REDUCE_MACRO) +EXPAND_REDUCE_FP32_MACRO(CINN_BLOCK_REDUCE_MACRO) +EXPAND_REDUCE_FP64_MACRO(CINN_BLOCK_REDUCE_MACRO) +EXPAND_REDUCE_BOOL_MACRO(CINN_BLOCK_REDUCE_MACRO) +EXPAND_REDUCE_FP16_MACRO(CINN_BLOCK_REDUCE_MACRO) + +#define CINN_DISCRETE_REDUCE_IMPL(REDUCE_TYPE, value) \ + int tid = threadIdx.y * blockDim.x + threadIdx.x; \ + __syncthreads(); \ + shm[tid] = value; \ + __syncthreads(); \ + for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) { \ + if (threadIdx.y < offset) { \ + shm[tid] = cinn_##REDUCE_TYPE(shm[tid], shm[tid + offset * blockDim.x]); \ + } \ + __syncthreads(); \ + } \ + return shm[threadIdx.x]; + +#define CINN_DISCRETE_REDUCE_MACRO(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \ + __device__ inline DTYPE cinn_discrete_reduce_##REDUCE_TYPE( \ + const DTYPE value, DTYPE *shm) { \ + CINN_DISCRETE_REDUCE_IMPL(REDUCE_TYPE, value); \ + } + +EXPAND_REDUCE_INT32_MACRO(CINN_DISCRETE_REDUCE_MACRO) +EXPAND_REDUCE_INT64_MACRO(CINN_DISCRETE_REDUCE_MACRO) +EXPAND_REDUCE_FP32_MACRO(CINN_DISCRETE_REDUCE_MACRO) 
+EXPAND_REDUCE_FP64_MACRO(CINN_DISCRETE_REDUCE_MACRO) +EXPAND_REDUCE_BOOL_MACRO(CINN_DISCRETE_REDUCE_MACRO) +EXPAND_REDUCE_FP16_MACRO(CINN_DISCRETE_REDUCE_MACRO) + +#define CINN_GRID_REDUCE_IMPL(REDUCE_TYPE, init_value, DTYPE) \ + DTYPE tmp_val = init_value; \ + for (int y = 0; y < gridDim.y; y++) { \ + tmp_val = \ + cinn_##REDUCE_TYPE(tmp_val, mem[y * spatial_size + spatial_index]); \ + } \ + return tmp_val; + +#define CINN_GRID_REDUCE_MACRO(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \ + __device__ inline DTYPE cinn_grid_reduce_##REDUCE_TYPE( \ + const DTYPE *mem, int spatial_size, int spatial_index) { \ + CINN_GRID_REDUCE_IMPL(REDUCE_TYPE, (DTYPE)(INITIAL_VALUE), DTYPE); \ + } + +EXPAND_REDUCE_INT32_MACRO(CINN_GRID_REDUCE_MACRO) +EXPAND_REDUCE_INT64_MACRO(CINN_GRID_REDUCE_MACRO) +EXPAND_REDUCE_FP32_MACRO(CINN_GRID_REDUCE_MACRO) +EXPAND_REDUCE_FP64_MACRO(CINN_GRID_REDUCE_MACRO) +EXPAND_REDUCE_BOOL_MACRO(CINN_GRID_REDUCE_MACRO) +EXPAND_REDUCE_FP16_MACRO(CINN_GRID_REDUCE_MACRO) + +__device__ inline bool cinn_grid_reduce_update_semaphore(int *semaphores) { + __shared__ bool done; + __threadfence(); + __syncthreads(); + if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { + int old = atomicAdd(&semaphores[blockIdx.x], 1); + done = (old == (gridDim.y - 1)); + } + __syncthreads(); + return done; +} + +// =============================================================== +// 6. 
Standard Math Functions +// =============================================================== // =============================================================== // Float64 (Double) Math Functions // =============================================================== @@ -61,9 +397,23 @@ __device__ inline double FN_FP64(round)(double x) { return round(x); } __device__ inline double FN_FP64(trunc)(double x) { return trunc(x); } __device__ inline double FN_FP64(pow)(double a, double b) { return pow(a, b); } __device__ inline double FN_FP64(mod)(double a, double b) { return fmod(a, b); } +__device__ inline double FN_FP64(fma)(double a, double b, double c) { return fma(a, b, c); } __device__ inline bool FN_FP64(isnan)(double x) { return isnan(x); } __device__ inline bool FN_FP64(isinf)(double x) { return isinf(x); } __device__ inline bool FN_FP64(isfinite)(double x) { return isfinite(x); } +__device__ inline double FN_FP64(acos)(double x) { return acos(x); } +__device__ inline double FN_FP64(acosh)(double x) { return acosh(x); } +__device__ inline double FN_FP64(asin)(double x) { return asin(x); } +__device__ inline double FN_FP64(asinh)(double x) { return asinh(x); } +__device__ inline double FN_FP64(atan)(double x) { return atan(x); } +__device__ inline double FN_FP64(atanh)(double x) { return atanh(x); } +__device__ inline double FN_FP64(cbrt)(double x) { return cbrt(x); } +__device__ inline double FN_FP64(cosh)(double x) { return cosh(x); } +__device__ inline double FN_FP64(erf)(double x) { return erf(x); } +__device__ inline double FN_FP64(log1p)(double x) { return log1p(x); } +__device__ inline double FN_FP64(sigmoid)(double x) { return 1.0 / (1.0 + exp(-x)); } +__device__ inline double FN_FP64(sinh)(double x) { return sinh(x); } +__device__ inline double FN_FP64(tanh)(double x) { return tanh(x); } // =============================================================== // Float32 Math Functions @@ -83,15 +433,32 @@ __device__ inline float FN_FP32(ceil)(float x) { return 
ceilf(x); } __device__ inline float FN_FP32(round)(float x) { return roundf(x); } __device__ inline float FN_FP32(trunc)(float x) { return truncf(x); } __device__ inline float FN_FP32(abs)(float x) { return fabsf(x); } - -// =============================================================== -// Bool / Int logic -// =============================================================== -#define FN_BOOL(func) cinn_custom_device_##func##_bool -__device__ inline bool FN_BOOL(bitwise_and)(bool a, bool b) { return a & b; } -__device__ inline bool FN_BOOL(bitwise_or)(bool a, bool b) { return a | b; } -__device__ inline bool FN_BOOL(bitwise_not)(bool a) { return !a; } -__device__ inline bool FN_BOOL(bitwise_xor)(bool a, bool b) { return a ^ b; } +__device__ inline float FN_FP32(mod)(float a, float b) { return fmodf(a, b); } +__device__ inline float FN_FP32(fma)(float a, float b, float c) { return fmaf(a, b, c); } +__device__ inline bool FN_FP32(isnan)(float x) { return isnan(x); } +__device__ inline bool FN_FP32(isinf)(float x) { return isinf(x); } +__device__ inline bool FN_FP32(isfinite)(float x) { return isfinite(x); } +__device__ inline float FN_FP32(acos)(float x) { return acosf(x); } +__device__ inline float FN_FP32(acosh)(float x) { return acoshf(x); } +__device__ inline float FN_FP32(asin)(float x) { return asinf(x); } +__device__ inline float FN_FP32(asinh)(float x) { return asinhf(x); } +__device__ inline float FN_FP32(atan)(float x) { return atanf(x); } +__device__ inline float FN_FP32(atanh)(float x) { return atanhf(x); } +__device__ inline float FN_FP32(cbrt)(float x) { return cbrtf(x); } +__device__ inline float FN_FP32(cosh)(float x) { return coshf(x); } +__device__ inline float FN_FP32(erf)(float x) { return erff(x); } +__device__ inline float FN_FP32(log2)(float x) { return log2f(x); } +__device__ inline float FN_FP32(log10)(float x) { return log10f(x); } +__device__ inline float FN_FP32(log1p)(float x) { return log1pf(x); } +__device__ inline float 
FN_FP32(sigmoid)(float x) { return 1.0f / (1.0f + expf(-x)); } +__device__ inline float FN_FP32(sinh)(float x) { return sinhf(x); } +__device__ inline float FN_FP32(tanh)(float x) { return tanhf(x); } +__device__ inline float FN_FP32(left_shift)(float a, float b) { + return (float)((int)a << (int)b); +} +__device__ inline float FN_FP32(right_shift)(float a, float b) { + return (float)((int)a >> (int)b); +} // =============================================================== // Int32 Functions @@ -105,45 +472,350 @@ __device__ inline int FN_INT32(mod)(int a, int b) { if ((res != 0) && ((b ^ res) < 0)) res += b; return res; } +__device__ inline int FN_INT32(max)(int a, int b) { return cinn_max(a, b); } +__device__ inline int FN_INT32(min)(int a, int b) { return cinn_min(a, b); } +__device__ inline int FN_INT32(left_shift)(int a, int b) { return a << b; } +__device__ inline int FN_INT32(right_shift)(int a, int b) { return a >> b; } +__device__ inline int FN_INT32(bitwise_and)(int a, int b) { return a & b; } +__device__ inline int FN_INT32(bitwise_or)(int a, int b) { return a | b; } +__device__ inline int FN_INT32(bitwise_xor)(int a, int b) { return a ^ b; } +__device__ inline int FN_INT32(logical_right_shift)(int a, int b) { return (unsigned int)a >> b; } +__device__ inline int FN_INT32(trunc)(int a) { return a; } +__device__ inline int FN_INT32(pow)(int a, int b) { + if (a == 0 && b < 0) return -1; + float res = powf(__int2float_rd(a), __int2float_rd(b)); + return __float2int_rn(res); +} +__device__ inline int FN_INT32(arithmetic_right_shift)(int a, int b) { return a >> b; } + +// =============================================================== +// Int64 Functions +// =============================================================== +#define FN_INT64(func) cinn_custom_device_##func##_int64 +__device__ inline int64_t FN_INT64(bitwise_and)(int64_t a, int64_t b) { return a & b; } +__device__ inline int64_t FN_INT64(bitwise_or)(int64_t a, int64_t b) { return a | b; } 
+__device__ inline int64_t FN_INT64(bitwise_xor)(int64_t a, int64_t b) { return a ^ b; } +__device__ inline int64_t FN_INT64(bitwise_not)(int64_t a) { return ~a; } +__device__ inline int64_t FN_INT64(clz)(int64_t a) { return __clzll(a); } +__device__ inline int64_t FN_INT64(popc)(int64_t a) { return __popcll(a); } +__device__ inline int64_t FN_INT64(logical_right_shift)(int64_t a, int64_t b) { return ((uint64_t)a >> b); } +__device__ inline int64_t FN_INT64(trunc)(int64_t a) { return a; } +__device__ inline int64_t FN_INT64(mod)(int64_t a, int64_t b) { int64_t res = a % b; if ((res != 0) && ((b ^ res) < 0)) res += b; return res; } +__device__ inline int64_t FN_INT64(pow)(int64_t a, int64_t b) { double res = pow(__ll2double_rd(a), __ll2double_rd(b)); return __double2ll_rn(res); } // =============================================================== // Float16 (Half) Functions // =============================================================== #define FN_FP16(func) cinn_custom_device_##func##_fp16 -__device__ inline __half FN_FP16(ceil)(__half x) { return hceil(x); } -__device__ inline __half FN_FP16(floor)(__half x) { return hfloor(x); } -__device__ inline __half FN_FP16(sin)(__half x) { return hsin(x); } -__device__ inline __half FN_FP16(cos)(__half x) { return hcos(x); } -__device__ inline __half FN_FP16(exp)(__half x) { return hexp(x); } -__device__ inline __half FN_FP16(log)(__half x) { return hlog(x); } -__device__ inline __half FN_FP16(log2)(__half x) { return hlog2(x); } -__device__ inline __half FN_FP16(log10)(__half x) { return hlog10(x); } -__device__ inline __half FN_FP16(sqrt)(__half x) { return hsqrt(x); } -__device__ inline __half FN_FP16(rsqrt)(__half x) { return hrsqrt(x); } +#define FN_FP16(func) cinn_custom_device_##func##_fp16 +__device__ inline float16 FN_FP16(ceil)(float16 x) { return hceil(x); } +__device__ inline float16 FN_FP16(floor)(float16 x) { return hfloor(x); } +__device__ inline float16 FN_FP16(round)(float16 x) { return 
__float2half(roundf(__half2float(x))); } +__device__ inline float16 FN_FP16(trunc)(float16 x) { return htrunc(x); } +__device__ inline float16 FN_FP16(sin)(float16 x) { return hsin(x); } +__device__ inline float16 FN_FP16(cos)(float16 x) { return hcos(x); } +__device__ inline float16 FN_FP16(exp)(float16 x) { return hexp(x); } +__device__ inline float16 FN_FP16(log)(float16 x) { return hlog(x); } +__device__ inline float16 FN_FP16(log2)(float16 x) { return hlog2(x); } +__device__ inline float16 FN_FP16(log10)(float16 x) { return hlog10(x); } +__device__ inline float16 FN_FP16(sqrt)(float16 x) { return hsqrt(x); } +__device__ inline float16 FN_FP16(rsqrt)(float16 x) { return hrsqrt(x); } +__device__ inline float16 FN_FP16(cbrt)(float16 x) { return __float2half(cbrtf(__half2float(x))); } +__device__ inline float16 FN_FP16(abs)(float16 x) { return __float2half(fabsf(__half2float(x))); } +__device__ inline bool FN_FP16(isnan)(float16 x) { return __hisnan(x); } +__device__ inline bool FN_FP16(isinf)(float16 x) { return __hisinf(x); } +__device__ inline bool FN_FP16(isfinite)(float16 x) { return !__hisinf(x) && !__hisnan(x); } +__device__ inline float16 FN_FP16(erf)(float16 x) { return __float2half(erff(__half2float(x))); } +__device__ inline float16 FN_FP16(tan)(float16 x) { return __float2half(tanf(__half2float(x))); } +__device__ inline float16 FN_FP16(sinh)(float16 x) { return __float2half(sinhf(__half2float(x))); } +__device__ inline float16 FN_FP16(cosh)(float16 x) { return __float2half(coshf(__half2float(x))); } +__device__ inline float16 FN_FP16(tanh)(float16 x) { return __float2half(tanhf(__half2float(x))); } +__device__ inline float16 FN_FP16(asin)(float16 x) { return __float2half(asinf(__half2float(x))); } +__device__ inline float16 FN_FP16(acos)(float16 x) { return __float2half(acosf(__half2float(x))); } +__device__ inline float16 FN_FP16(atan)(float16 x) { return __float2half(atanf(__half2float(x))); } +__device__ inline float16 FN_FP16(asinh)(float16 x) { 
return __float2half(asinhf(__half2float(x))); } +__device__ inline float16 FN_FP16(acosh)(float16 x) { return __float2half(acoshf(__half2float(x))); } +__device__ inline float16 FN_FP16(atanh)(float16 x) { return __float2half(atanhf(__half2float(x))); } +__device__ inline float16 FN_FP16(sigmoid)(float16 x) { return __float2half(1.0f / (1.0f + expf(-__half2float(x)))); } +__device__ inline float16 FN_FP16(mod)(float16 a, float16 b) { return __float2half(fmodf(__half2float(a), __half2float(b))); } +__device__ inline float16 FN_FP16(pow)(float16 a, float16 b) { return __float2half(powf(__half2float(a), __half2float(b))); } +__device__ inline float16 FN_FP16(add)(float16 a, float16 b) { return __hadd(a, b); } +__device__ inline float16 FN_FP16(sub)(float16 a, float16 b) { return __hsub(a, b); } +__device__ inline float16 FN_FP16(mul)(float16 a, float16 b) { return __hmul(a, b); } +__device__ inline float16 FN_FP16(div)(float16 a, float16 b) { return __hdiv(a, b); } +__device__ inline float16 FN_FP16(neg)(float16 a) { return __hneg(a); } +__device__ inline float16 FN_FP16(fma)(float16 a, float16 b, float16 c) { return __hfma(a, b, c); } +__device__ inline float16 FN_FP16(max)(float16 a, float16 b) { return __hgt(a, b) ? a : b; } +__device__ inline float16 FN_FP16(min)(float16 a, float16 b) { return __hlt(a, b) ? 
a : b; } // =============================================================== -// Index Operations +// Warp Shuffle Functions (用于 Reduce 算子) // =============================================================== -#define CINN_CUSTOM_DEVICE_FIND_KERNEL(buf, size, num, begin, stride) \ - do { \ +#define FN_SHUFFLE(func) cinn_custom_device_##func + +__device__ inline float FN_SHUFFLE(warp_shuffle_xor_fp32)(float v, int factor) { + return __shfl_xor_sync(0xffffffff, v, factor); +} +__device__ inline float FN_SHUFFLE(warp_shuffle_up_fp32)(float v, int factor) { + return __shfl_up_sync(0xffffffff, v, factor); +} +__device__ inline float FN_SHUFFLE(warp_shuffle_down_fp32)(float v, int factor) { + return __shfl_down_sync(0xffffffff, v, factor); +} + +__device__ inline int FN_SHUFFLE(warp_shuffle_xor_int32)(int v, int factor) { + return __shfl_xor_sync(0xffffffff, v, factor); +} +__device__ inline int FN_SHUFFLE(warp_shuffle_up_int32)(int v, int factor) { + return __shfl_up_sync(0xffffffff, v, factor); +} +__device__ inline int FN_SHUFFLE(warp_shuffle_down_int32)(int v, int factor) { + return __shfl_down_sync(0xffffffff, v, factor); +} + +// MACA/CUDA 的 shfl 指令通常只支持 32位,__half 需要强转或使用 intrinsics +__device__ inline __half FN_SHUFFLE(warp_shuffle_xor_fp16)(__half v, int factor) { + unsigned short val = __half_as_ushort(v); + unsigned short res = (unsigned short)__shfl_xor_sync(0xffffffff, (int)val, factor); + return __ushort_as_half(res); +} +__device__ inline __half FN_SHUFFLE(warp_shuffle_up_fp16)(__half v, int factor) { + unsigned short val = __half_as_ushort(v); + unsigned short res = (unsigned short)__shfl_up_sync(0xffffffff, (int)val, factor); + return __ushort_as_half(res); +} +__device__ inline __half FN_SHUFFLE(warp_shuffle_down_fp16)(__half v, int factor) { + unsigned short val = __half_as_ushort(v); + unsigned short res = (unsigned short)__shfl_down_sync(0xffffffff, (int)val, factor); + return __ushort_as_half(res); +} + +// 
=============================================================== +// 7. Index Operations: Find, Sort & Resize Helpers +// =============================================================== +#define __cinn_custom_device_find_kernel(buf, size, num, begin, stride) \ + do { \ for (int i = (size - 1) * stride + begin; i >= begin; i -= stride) { \ - if (buf[i] == num) return (i - begin) / stride; \ - } \ - return -1; \ + if (buf[i] == num) return (i - begin) / stride; \ + } \ + return -1; \ } while (0) __device__ inline int cinn_custom_device_find_int(const int *buf, int size, int num) { - CINN_CUSTOM_DEVICE_FIND_KERNEL(buf, size, num, 0, 1); + __cinn_custom_device_find_kernel(buf, size, num, 0, 1); } __device__ inline int cinn_custom_device_find_float(const float *buf, int size, float num) { - CINN_CUSTOM_DEVICE_FIND_KERNEL(buf, size, num, 0, 1); + __cinn_custom_device_find_kernel(buf, size, num, 0, 1); } __device__ inline int cinn_custom_device_find_int_nd(const int *buf, int size, int num, int begin, int stride) { - CINN_CUSTOM_DEVICE_FIND_KERNEL(buf, size, num, begin, stride); + __cinn_custom_device_find_kernel(buf, size, num, begin, stride); } __device__ inline int cinn_custom_device_find_float_nd(const float *buf, int size, float num, int begin, int stride) { - CINN_CUSTOM_DEVICE_FIND_KERNEL(buf, size, num, begin, stride); + __cinn_custom_device_find_kernel(buf, size, num, begin, stride); +} +#undef __cinn_custom_device_find_kernel + +__device__ inline int cinn_custom_device_next_smallest_int32(int *buf, int size, int num, int begin, int stride) { + int id = -1; + for (int i = begin; i < begin + size * stride; i += stride) { + if (id == -1 || buf[i] < buf[id]) { + id = i; + } + } + if (id != -1) { + buf[id] = CINN_INT32_MAX; + return (id - begin) / stride; + } + return -1; +} + +#define __cinn_custom_device_find_from_kernel(buf, size, num, begin) \ + do { \ + for (int i = begin; i < size; ++i) { \ + if (buf[i] == num) return i; \ + } \ + return -1; \ + } while (0) + 
+__device__ inline int cinn_custom_device_find_int_from(const int *buf, int size, int num, int begin) { + __cinn_custom_device_find_from_kernel(buf, size, num, begin); +} +__device__ inline int cinn_custom_device_find_float_from(const float *buf, int size, float num, int begin) { + __cinn_custom_device_find_from_kernel(buf, size, num, begin); +} +#undef __cinn_custom_device_find_from_kernel + +#define CINN_CUSTOM_DEVICE_LT_NUM(TYPE_SUFFIX, TYPE) \ + __device__ inline int cinn_custom_device_lt_num_##TYPE_SUFFIX(const TYPE *buf, \ + const int size, \ + const TYPE num, \ + const int offset, \ + const int stride) { \ + int out = 0; \ + for (int i = (size - 1) * stride + offset; i >= offset; i -= stride) { \ + if (buf[i] < num) out++; \ + } \ + return out; \ + } + +CINN_CUSTOM_DEVICE_LT_NUM(fp32, float) +CINN_CUSTOM_DEVICE_LT_NUM(fp64, double) +CINN_CUSTOM_DEVICE_LT_NUM(uint8, uint8_t) +CINN_CUSTOM_DEVICE_LT_NUM(int16, int16_t) +CINN_CUSTOM_DEVICE_LT_NUM(int32, int) +CINN_CUSTOM_DEVICE_LT_NUM(int64, int64_t) +CINN_CUSTOM_DEVICE_LT_NUM(fp16, float16) +#undef CINN_CUSTOM_DEVICE_LT_NUM + +#define CINN_CUSTOM_DEVICE_GT_NUM(TYPE_SUFFIX, TYPE) \ + __device__ inline int cinn_custom_device_gt_num_##TYPE_SUFFIX(const TYPE *buf, \ + const int size, \ + const TYPE num, \ + const int offset, \ + const int stride) { \ + int out = 0; \ + for (int i = (size - 1) * stride + offset; i >= offset; i -= stride) { \ + if (buf[i] > num) out++; \ + } \ + return out; \ + } + +CINN_CUSTOM_DEVICE_GT_NUM(fp32, float) +CINN_CUSTOM_DEVICE_GT_NUM(fp64, double) +CINN_CUSTOM_DEVICE_GT_NUM(uint8, uint8_t) +CINN_CUSTOM_DEVICE_GT_NUM(int16, int16_t) +CINN_CUSTOM_DEVICE_GT_NUM(int32, int) +CINN_CUSTOM_DEVICE_GT_NUM(int64, int64_t) +CINN_CUSTOM_DEVICE_GT_NUM(fp16, float16) +#undef CINN_CUSTOM_DEVICE_GT_NUM + +#define CINN_CUSTOM_DEVICE_INDEX_ADD(TYPE_SUFFIX, TYPE) \ + __device__ inline TYPE cinn_custom_device_index_add_##TYPE_SUFFIX( \ + const TYPE x, \ + const int axis_indice, \ + const TYPE *__restrict__ 
y, \ + const int offset, \ + const int stride, \ + const int *__restrict__ index, \ + const int index_size) { \ + TYPE res = x; \ + int idx = -1; \ + do { \ + idx = cinn_custom_device_find_int_from(index, index_size, axis_indice, idx + 1); \ + if (idx >= 0) { \ + res = res + y[offset + idx * stride]; \ + } \ + } while (idx != -1); \ + return res; \ + } + +CINN_CUSTOM_DEVICE_INDEX_ADD(bool, bool) +CINN_CUSTOM_DEVICE_INDEX_ADD(int8, int8_t) +CINN_CUSTOM_DEVICE_INDEX_ADD(int32, int32_t) +CINN_CUSTOM_DEVICE_INDEX_ADD(int64, int64_t) +CINN_CUSTOM_DEVICE_INDEX_ADD(fp32, float) +CINN_CUSTOM_DEVICE_INDEX_ADD(fp64, double) +CINN_CUSTOM_DEVICE_INDEX_ADD(fp16, float16) +#undef CINN_CUSTOM_DEVICE_INDEX_ADD + +__device__ int cinn_custom_device_resize_bilinear(const int *buf, + const int c_size, + const int in_h, + const int in_w, + const int out_h, + const int out_w, + const int n, + const int c, + const int y, + const int x) { + float scale_y = static_cast(in_h) / out_h; + float scale_x = static_cast(in_w) / out_w; + float in_y = (y + 0.5F) * scale_y - 0.5F; + float in_x = (x + 0.5F) * scale_x - 0.5F; + int in_y_int = static_cast(floorf(in_y)); + int in_x_int = static_cast(floorf(in_x)); + float y_lerp = in_y - in_y_int; + float x_lerp = in_x - in_x_int; + float p[2][2]; + + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + int near_y = in_y_int + i; + int near_x = in_x_int + j; + near_y = max(min(near_y, in_h - 1), 0); + near_x = max(min(near_x, in_w - 1), 0); + p[i][j] = buf[n * c_size * in_h * in_w + c * in_h * in_w + near_y * in_w + + near_x]; + } + } + + float top = p[0][0] * (1.0F - x_lerp) + p[0][1] * x_lerp; + float bottom = p[1][0] * (1.0F - x_lerp) + p[1][1] * x_lerp; + float value = top * (1.0F - y_lerp) + bottom * y_lerp; + return value; +} + +__device__ int cinn_custom_device_resize_bicubic(const int *buf, + const int c_size, + const int in_h, + const int in_w, + const int out_h, + const int out_w, + const int n, + const int c, + const int y, + const 
int x) { + float scale_y = static_cast(in_h) / out_h; + float scale_x = static_cast(in_w) / out_w; + float in_y = (y + 0.5F) * scale_y - 0.5F; + float in_x = (x + 0.5F) * scale_x - 0.5F; + int in_y_int = static_cast(floorf(in_y)); + int in_x_int = static_cast(floorf(in_x)); + float y_fract = in_y - floorf(in_y); + float x_fract = in_x - floorf(in_x); + float p[4][4]; + + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + int near_y = in_y_int + i - 1; + int near_x = in_x_int + j - 1; + near_y = max(min(near_y, in_h - 1), 0); + near_x = max(min(near_x, in_w - 1), 0); + p[i][j] = buf[n * c_size * in_h * in_w + c * in_h * in_w + near_y * in_w + + near_x]; + } + } + + float alpha = -0.5F; + float w[2][4]; + + for (int i = 0; i < 2; ++i) { + float t = (i == 0 ? x_fract : y_fract); + float t2 = t * t; + float t3 = t * t * t; + w[i][0] = alpha * (t3 - 2 * t2 + t); + w[i][1] = (alpha + 2) * t3 - (3 + alpha) * t2 + 1; + w[i][2] = -(alpha + 2) * t3 + (3 + 2 * alpha) * t2 - alpha * t; + w[i][3] = -alpha * t3 + alpha * t2; + } + + float col[4]; + + for (int i = 0; i < 4; ++i) { + col[i] = 0.0F; + for (int j = 0; j < 4; ++j) { + col[i] += p[i][j] * w[0][j]; + } + } + + float value = 0.0F; + + for (int i = 0; i < 4; ++i) { + value += col[i] * w[1][i]; + } + + return value; } } // extern "C" diff --git a/backends/metax_gpu/tests/tmp_save/unittest/test_elementwise_pow_op_metax.py b/backends/metax_gpu/tests/tmp_save/unittest/test_elementwise_pow_op_metax.py index 03fb70419e..73d9374f67 100755 --- a/backends/metax_gpu/tests/tmp_save/unittest/test_elementwise_pow_op_metax.py +++ b/backends/metax_gpu/tests/tmp_save/unittest/test_elementwise_pow_op_metax.py @@ -31,7 +31,6 @@ def pow_grad(x, y, dout): dy = dout * np.log(x) * np.power(x, y) return dx, dy - class TestElementwisePowOp(OpTest): def setUp(self): self.op_type = "elementwise_pow" @@ -62,7 +61,6 @@ def test_check_grad_normal(self): check_pir=True, ) -''' class TestElementwisePowOp_ZeroDim1(TestElementwisePowOp): 
def setUp(self): self.op_type = "elementwise_pow" @@ -90,7 +88,7 @@ def setUp(self): } self.outputs = {"Out": np.power(self.inputs["X"], self.inputs["Y"])} - +''' class TestElementwisePowOp_ZeroDim3(TestElementwisePowOp): def setUp(self): self.op_type = "elementwise_pow" From 0936c730593b46dfc457ce10b53bdd10fb3b93a6 Mon Sep 17 00:00:00 2001 From: YuhanXu Date: Tue, 20 Jan 2026 19:18:52 +0800 Subject: [PATCH 05/17] Fix CINN compilation errors and incorrect reduction results on MetaX backend. Run test_elementwise_pow_op_metax.py success. --- backends/metax_gpu/cinn/compiler/compiler.cc | 163 +++++++----------- .../unittest/test_elementwise_pow_op_metax.py | 2 - 2 files changed, 65 insertions(+), 100 deletions(-) diff --git a/backends/metax_gpu/cinn/compiler/compiler.cc b/backends/metax_gpu/cinn/compiler/compiler.cc index 5031e0196f..6b3afc004b 100644 --- a/backends/metax_gpu/cinn/compiler/compiler.cc +++ b/backends/metax_gpu/cinn/compiler/compiler.cc @@ -168,90 +168,62 @@ __device__ inline float16 cinn_min_fp16(const float16 left, const float16 right) // =============================================================== -// 4. Warp Shuffle Wrappers +// 4. Warp Shuffle Wrappers (Using Legacy API & Full Down Strategy) // =============================================================== -#define CINN_WARP_SHUFFLE_INTERNAL_IMPL(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \ +// 【核心修复】Warp Reduce 逻辑重写 +// 1. 弃用 XOR 模式:因为在 64-thread warp 下,跨 32 边界的 XOR 可能存在未定义行为或硬件 bug。 +// 2. 统一使用 DOWN 模式:__shfl_down 是单向规约,Lane 0 总是能收集到数据的,更加稳健。 +// 3. 
严格的边界检查:确保 fetch 的来源线程在 Block 范围内,否则使用 INIT_VAL 填充。 + +#define CINN_WARP_SHUFFLE_INTERNAL_IMPL(REDUCE_TYPE, INIT_VAL, DTYPE) \ __device__ inline DTYPE cinn_warp_shuffle_##REDUCE_TYPE##_internal( \ const DTYPE value) { \ - DTYPE tmp_val = value, shfl_res; \ + DTYPE tmp_val = value; \ unsigned int thread_id = threadIdx.x; \ unsigned int block_dim = blockDim.x; \ - unsigned int last_warp_size = block_dim - (thread_id - (threadIdx.x % WARP_SIZE)); \ - if (last_warp_size < WARP_SIZE) { \ - for (unsigned int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) { \ - /* 使用通用的 shuffle down 实现 */ \ - shfl_res = cinn_warp_shuffle_down_##DTYPE##_wrapper(tmp_val, offset); \ - tmp_val = cinn_##REDUCE_TYPE(thread_id + offset < block_dim \ - ? shfl_res \ - : (DTYPE)(INITIAL_VALUE), \ - tmp_val); \ - } \ - /* 这里的 __shfl 广播可以用 shfl_sync(0) 替代 */ \ - tmp_val = __shfl_sync(0xffffffff, tmp_val, 0); \ - } else { \ - for (unsigned int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) { \ - tmp_val = cinn_##REDUCE_TYPE(tmp_val, \ - cinn_warp_shuffle_xor_##DTYPE##_wrapper(tmp_val, offset)); \ - } \ + /* 始终使用 Down Shuffle 进行规约 (Log2 复杂度) */ \ + for (unsigned int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) { \ + DTYPE shfl_res = cinn_warp_shuffle_down_##DTYPE##_wrapper(tmp_val, offset); \ + /* 检查数据来源是否有效:当前线程+offset 必须还在 Block 范围内 */ \ + /* 如果 Block 大小不是 WARP_SIZE 的倍数,这一步至关重要 */ \ + DTYPE neighbor = (thread_id + offset < block_dim) ? 
shfl_res : (DTYPE)(INIT_VAL); \ + tmp_val = cinn_##REDUCE_TYPE(tmp_val, neighbor); \ } \ - return tmp_val; \ + /* 广播:虽然 Down Shuffle 只有 Lane 0 结果正确,但这里为了兼容 XOR 语义 */ \ + /* 我们用 shfl 0 把 Lane 0 的结果广播给所有人 (CINN Block Reduce 需要) */ \ + return __shfl(tmp_val, 0); \ } -// --- Warp Shuffle Primitives (Internal Helpers) --- -// 为了适配宏展开,这里定义带后缀的 wrapper,统一 float16/double 处理 - -__device__ inline float cinn_warp_shuffle_down_float_wrapper(float v, int factor) { return __shfl_down_sync(0xffffffff, v, factor); } -__device__ inline float cinn_warp_shuffle_xor_float_wrapper(float v, int factor) { return __shfl_xor_sync(0xffffffff, v, factor); } - -__device__ inline int cinn_warp_shuffle_down_int_wrapper(int v, int factor) { return __shfl_down_sync(0xffffffff, v, factor); } -__device__ inline int cinn_warp_shuffle_xor_int_wrapper(int v, int factor) { return __shfl_xor_sync(0xffffffff, v, factor); } +// --- Warp Shuffle Primitives (Legacy API without mask) --- -__device__ inline bool cinn_warp_shuffle_down_bool_wrapper(bool v, int factor) { return __shfl_down_sync(0xffffffff, v, factor); } -__device__ inline bool cinn_warp_shuffle_xor_bool_wrapper(bool v, int factor) { return __shfl_xor_sync(0xffffffff, v, factor); } +__device__ inline float cinn_warp_shuffle_down_float_wrapper(float v, int factor) { return __shfl_down(v, factor); } +__device__ inline int cinn_warp_shuffle_down_int_wrapper(int v, int factor) { return __shfl_down(v, factor); } +__device__ inline bool cinn_warp_shuffle_down_bool_wrapper(bool v, int factor) { return __shfl_down(v, factor); } __device__ inline double cinn_warp_shuffle_down_double_wrapper(double v, int factor) { unsigned long long int val_u64 = *(unsigned long long int*)&v; int lo = (int)val_u64; int hi = (int)(val_u64 >> 32); - lo = __shfl_down_sync(0xffffffff, lo, factor); - hi = __shfl_down_sync(0xffffffff, hi, factor); - unsigned long long int res_u64 = ((unsigned long long int)hi << 32) | (unsigned int)lo; - return *(double*)&res_u64; -} 
-__device__ inline double cinn_warp_shuffle_xor_double_wrapper(double v, int factor) { - unsigned long long int val_u64 = *(unsigned long long int*)&v; - int lo = (int)val_u64; int hi = (int)(val_u64 >> 32); - lo = __shfl_xor_sync(0xffffffff, lo, factor); - hi = __shfl_xor_sync(0xffffffff, hi, factor); + lo = __shfl_down(lo, factor); + hi = __shfl_down(hi, factor); unsigned long long int res_u64 = ((unsigned long long int)hi << 32) | (unsigned int)lo; return *(double*)&res_u64; } __device__ inline int64_t cinn_warp_shuffle_down_int64_t_wrapper(int64_t v, int factor) { int lo = (int)v; int hi = (int)(v >> 32); - lo = __shfl_down_sync(0xffffffff, lo, factor); - hi = __shfl_down_sync(0xffffffff, hi, factor); - return ((int64_t)hi << 32) | (unsigned int)lo; -} -__device__ inline int64_t cinn_warp_shuffle_xor_int64_t_wrapper(int64_t v, int factor) { - int lo = (int)v; int hi = (int)(v >> 32); - lo = __shfl_xor_sync(0xffffffff, lo, factor); - hi = __shfl_xor_sync(0xffffffff, hi, factor); + lo = __shfl_down(lo, factor); + hi = __shfl_down(hi, factor); return ((int64_t)hi << 32) | (unsigned int)lo; } __device__ inline float16 cinn_warp_shuffle_down_float16_wrapper(float16 v, int factor) { unsigned short val = __half_as_ushort(v); - unsigned short res = (unsigned short)__shfl_down_sync(0xffffffff, (int)val, factor); - return __ushort_as_half(res); -} -__device__ inline float16 cinn_warp_shuffle_xor_float16_wrapper(float16 v, int factor) { - unsigned short val = __half_as_ushort(v); - unsigned short res = (unsigned short)__shfl_xor_sync(0xffffffff, (int)val, factor); + unsigned short res = (unsigned short)__shfl_down((int)val, factor); return __ushort_as_half(res); } -// 展开 Internal Implementations +// Expand Warp Shuffle EXPAND_REDUCE_INT32_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL) EXPAND_REDUCE_INT64_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL) EXPAND_REDUCE_FP32_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL) @@ -263,48 +235,44 @@ EXPAND_REDUCE_FP16_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL) 
// 5. Block Reduce & Discrete Reduce & Grid Reduce // =============================================================== -#define CINN_BLOCK_REDUCE_IMPL(DTYPE, INITIAL_VALUE, cinn_warp_shuffle_internal) \ - /* 1. Warp内规约 */ \ - DTYPE tmp_val = cinn_warp_shuffle_internal(value); \ - \ - /* 如果只有一个 warp,直接返回 */ \ - if (return_warp || blockDim.x <= WARP_SIZE) { \ - return tmp_val; \ - } \ - __syncthreads(); \ - \ - /* 2. 每个 Warp 的结果写入共享内存 (仅 Lane 0 写入) */ \ - if (threadIdx.x % WARP_SIZE == 0) { \ - shm[threadIdx.x / WARP_SIZE] = tmp_val; \ - } \ - __syncthreads(); \ - \ - /* 3. Warp 0 负责汇总 */ \ - if (threadIdx.x < WARP_SIZE) { \ - /* 计算有多少个 Warp */ \ - int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; \ - \ - /* 【核心修复】Lane >= num_warps 的线程必须加载 IDENTITY,否则后面 shuffle 会引入脏数据 */ \ - DTYPE reduce_val = (DTYPE)(INITIAL_VALUE); \ - if (threadIdx.x < num_warps) { \ - reduce_val = shm[threadIdx.x]; \ - } \ - \ - /* Warp 0 再次进行规约 (所有 64 个线程都参与) */ \ - reduce_val = cinn_warp_shuffle_internal(reduce_val); \ - \ - /* 结果写入 shm[0] */ \ - if (threadIdx.x == 0) { \ - shm[0] = reduce_val; \ - } \ - } \ - __syncthreads(); \ +// Block Reduce Implementation +// 1. Warp Reduce -> SHM +// 2. Warp 0 reads SHM and Pads with Identity +// 3. Warp 0 Reduce +// 4. Broadcast +#define CINN_BLOCK_REDUCE_IMPL(DTYPE, INIT_VAL, cinn_warp_shuffle_internal) \ + /* 1. Warp Reduce */ \ + DTYPE tmp_val = cinn_warp_shuffle_internal(value); \ + if (return_warp || blockDim.x <= WARP_SIZE) { \ + return tmp_val; \ + } \ + __syncthreads(); \ + /* 2. Write Warp results to SHM (Lane 0 only) */ \ + if (threadIdx.x % WARP_SIZE == 0) { \ + shm[threadIdx.x / WARP_SIZE] = tmp_val; \ + } \ + __syncthreads(); \ + /* 3. 
Inter-Warp Reduce (Warp 0 only) */ \ + if (threadIdx.x < WARP_SIZE) { \ + int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; \ + /* Pad with Identity value for idle threads in Warp 0 */ \ + DTYPE reduce_val = (DTYPE)(INIT_VAL); \ + if (threadIdx.x < num_warps) { \ + reduce_val = shm[threadIdx.x]; \ + } \ + /* Reduce across all threads in Warp 0 */ \ + reduce_val = cinn_warp_shuffle_internal(reduce_val); \ + if (threadIdx.x == 0) { \ + shm[0] = reduce_val; \ + } \ + } \ + __syncthreads(); \ return shm[0]; -#define CINN_BLOCK_REDUCE_MACRO(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \ +#define CINN_BLOCK_REDUCE_MACRO(REDUCE_TYPE, INIT_VAL, DTYPE) \ __device__ inline DTYPE cinn_block_reduce_##REDUCE_TYPE( \ const DTYPE value, DTYPE *shm, bool return_warp = false) { \ - CINN_BLOCK_REDUCE_IMPL(DTYPE, INITIAL_VALUE, cinn_warp_shuffle_##REDUCE_TYPE##_internal); \ + CINN_BLOCK_REDUCE_IMPL(DTYPE, INIT_VAL, cinn_warp_shuffle_##REDUCE_TYPE##_internal); \ } EXPAND_REDUCE_INT32_MACRO(CINN_BLOCK_REDUCE_MACRO) @@ -327,7 +295,7 @@ EXPAND_REDUCE_FP16_MACRO(CINN_BLOCK_REDUCE_MACRO) } \ return shm[threadIdx.x]; -#define CINN_DISCRETE_REDUCE_MACRO(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \ +#define CINN_DISCRETE_REDUCE_MACRO(REDUCE_TYPE, INIT_VAL, DTYPE) \ __device__ inline DTYPE cinn_discrete_reduce_##REDUCE_TYPE( \ const DTYPE value, DTYPE *shm) { \ CINN_DISCRETE_REDUCE_IMPL(REDUCE_TYPE, value); \ @@ -348,10 +316,10 @@ EXPAND_REDUCE_FP16_MACRO(CINN_DISCRETE_REDUCE_MACRO) } \ return tmp_val; -#define CINN_GRID_REDUCE_MACRO(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \ +#define CINN_GRID_REDUCE_MACRO(REDUCE_TYPE, INIT_VAL, DTYPE) \ __device__ inline DTYPE cinn_grid_reduce_##REDUCE_TYPE( \ const DTYPE *mem, int spatial_size, int spatial_index) { \ - CINN_GRID_REDUCE_IMPL(REDUCE_TYPE, (DTYPE)(INITIAL_VALUE), DTYPE); \ + CINN_GRID_REDUCE_IMPL(REDUCE_TYPE, (DTYPE)(INIT_VAL), DTYPE); \ } EXPAND_REDUCE_INT32_MACRO(CINN_GRID_REDUCE_MACRO) @@ -372,7 +340,6 @@ __device__ inline bool 
cinn_grid_reduce_update_semaphore(int *semaphores) { __syncthreads(); return done; } - // =============================================================== // 6. Standard Math Functions // =============================================================== diff --git a/backends/metax_gpu/tests/tmp_save/unittest/test_elementwise_pow_op_metax.py b/backends/metax_gpu/tests/tmp_save/unittest/test_elementwise_pow_op_metax.py index 73d9374f67..9652b0ce47 100755 --- a/backends/metax_gpu/tests/tmp_save/unittest/test_elementwise_pow_op_metax.py +++ b/backends/metax_gpu/tests/tmp_save/unittest/test_elementwise_pow_op_metax.py @@ -88,7 +88,6 @@ def setUp(self): } self.outputs = {"Out": np.power(self.inputs["X"], self.inputs["Y"])} -''' class TestElementwisePowOp_ZeroDim3(TestElementwisePowOp): def setUp(self): self.op_type = "elementwise_pow" @@ -455,7 +454,6 @@ def test_check_grad(self): only_check_prim=True, check_prim_pir=True, ) -''' if __name__ == "__main__": unittest.main() From ab39c828791347d8e1431285d733c25d823c367f Mon Sep 17 00:00:00 2001 From: YuhanXu Date: Tue, 10 Feb 2026 12:02:36 +0800 Subject: [PATCH 06/17] Fix some bug. 
--- backends/metax_gpu/CMakeLists.txt | 2 +- backends/metax_gpu/build.sh | 7 +- .../kernels/custom_kernel/custom_context.h | 3 + backends/metax_gpu/runtime/runtime.cc | 4 +- backends/metax_gpu/tests/run_test.sh | 4 +- .../tests/tmp_save/gpudnn/conv_cudnn_v7.h | 6 +- .../unittest/test_elementwise_pow_op_metax.py | 72 ++++++++++++++----- 7 files changed, 71 insertions(+), 27 deletions(-) diff --git a/backends/metax_gpu/CMakeLists.txt b/backends/metax_gpu/CMakeLists.txt index 93d7a4eeae..23a934cbf0 100755 --- a/backends/metax_gpu/CMakeLists.txt +++ b/backends/metax_gpu/CMakeLists.txt @@ -853,7 +853,7 @@ target_compile_definitions( ${TARGET_NAME} PUBLIC PADDLE_WITH_CUDA=1 PADDLE_WITH_CUSTOM_DEVICE=1 - cublasContext=mcblasContext + mcblasContext=cublasContext cublasLtContext=mcblasLtContext cudnnContext==mcdnnContex GPUContext=CustomContext diff --git a/backends/metax_gpu/build.sh b/backends/metax_gpu/build.sh index 223baa2a9f..9eaa73e571 100755 --- a/backends/metax_gpu/build.sh +++ b/backends/metax_gpu/build.sh @@ -18,12 +18,13 @@ set -e # install requirement.txt -pip install -r requirement.txt -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple +# pip install -r requirement.txt -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple # uninstall paddle -pip uninstall paddlepaddle -y +# pip uninstall paddlepaddle -y -python -m pip install --pre paddlepaddle -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/ + +# python -m pip install --pre paddlepaddle -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/ # apply patch diff --git a/backends/metax_gpu/kernels/custom_kernel/custom_context.h b/backends/metax_gpu/kernels/custom_kernel/custom_context.h index 19035992ea..dd232f841f 100644 --- a/backends/metax_gpu/kernels/custom_kernel/custom_context.h +++ b/backends/metax_gpu/kernels/custom_kernel/custom_context.h @@ -29,6 +29,7 @@ #include "paddle/phi/core/device_context.h" namespace phi { + // class DnnWorkspaceHandle { // public: // inline 
DnnWorkspaceHandle(Allocator* allocator, gpuStream_t stream) @@ -100,6 +101,7 @@ namespace phi { // } // } // namespace + namespace dynload { inline bool HasCUSOLVER() { @@ -156,5 +158,6 @@ inline cusolverDnHandle_t GetCusolverDnHandle(gpuStream_t stream, Place place) { // const gpuStream_t& stream) { // return DnnWorkspaceHandle(alloactor, stream); // } + } // namespace phi #endif // BACKENDS_METAX_GPU_KERNELS_CUSTOM_KERNEL_CUSTOM_CONTEXT_H_ diff --git a/backends/metax_gpu/runtime/runtime.cc b/backends/metax_gpu/runtime/runtime.cc index 64a7e90fbd..6f295a2019 100644 --- a/backends/metax_gpu/runtime/runtime.cc +++ b/backends/metax_gpu/runtime/runtime.cc @@ -411,7 +411,7 @@ C_Status GetMaxSharedMemPerBlock(const C_Device device, int count = 0; cudaError_t status = cudaDeviceGetAttribute(&count, cudaDevAttrMaxSharedMemoryPerBlock, id); - *shared_mem_per_block = count; + *shared_mem_per_block = 5000; return C_SUCCESS; } @@ -1546,7 +1546,7 @@ void InitPlugin(CustomRuntimeParams *params) { params->interface->get_multi_process = GetMultiProcessors; params->interface->get_max_threads_per_mp = GetMaxThreadsPerMultiProcessor; params->interface->get_max_threads_per_block = GetMaxThreadsPerBlock; - params->interface->get_max_registers_per_mp = GetMaxSharedMemPerBlock; + params->interface->get_max_shared_mem_per_block = GetMaxSharedMemPerBlock; params->interface->get_max_blocks_per_mp = GetMaxBlocksPerMultiProcessor; params->interface->get_warp_size = GetWarpSize; params->interface->get_max_registers_per_mp = GetMaxRegistersPerMultiProcessor; diff --git a/backends/metax_gpu/tests/run_test.sh b/backends/metax_gpu/tests/run_test.sh index b71c058351..ae7f64c29a 100755 --- a/backends/metax_gpu/tests/run_test.sh +++ b/backends/metax_gpu/tests/run_test.sh @@ -33,7 +33,7 @@ FLAGS_use_cinn=true FLAGS_enable_cinn_compile_cache=false # 打印log,调试时用 FLAGS_print_ir=true -GLOG_v=4 +GLOG_v=1 # export # sleep 1000000 @@ -97,4 +97,4 @@ cmake .. 
-DTEST_LIST_FILE=$TEST_LIST_FILE -DLOG_OUTPUT_DIR=$TEST_LOG_OUTPUT_DIR cmake --build . -GLOG_v=3 FLAGS_print_ir=1 ctest -R "python_test_abs_metax" -j$TEST_PARALLEL_NUM --output-on-failure +GLOG_v=1 FLAGS_print_ir=1 ctest -j$TEST_PARALLEL_NUM --output-on-failure diff --git a/backends/metax_gpu/tests/tmp_save/gpudnn/conv_cudnn_v7.h b/backends/metax_gpu/tests/tmp_save/gpudnn/conv_cudnn_v7.h index be89898e68..4923d802e9 100644 --- a/backends/metax_gpu/tests/tmp_save/gpudnn/conv_cudnn_v7.h +++ b/backends/metax_gpu/tests/tmp_save/gpudnn/conv_cudnn_v7.h @@ -227,7 +227,7 @@ struct SearchAlgorithmBase { // auto workspace_handle = dev_ctx.cudnn_workspace_handle(); auto workspace_handle = GetDnnWorkspace( - const_cast(&(dev_ctx.GetAllocator())), dev_ctx.stream()); + const_cast(&(dev_ctx.GetAllocator())), dev_ctx.stream(), dev_ctx.GetPlace()); // auto handle = GetDnnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); @@ -416,7 +416,7 @@ struct SearchAlgorithmBase { // auto workspace_handle = dev_ctx.cudnn_workspace_handle(); auto workspace_handle = GetDnnWorkspace( - const_cast(&(dev_ctx.GetAllocator())), dev_ctx.stream()); + const_cast(&(dev_ctx.GetAllocator())), dev_ctx.stream(), dev_ctx.GetPlace()); workspace_handle.RunFuncSync( cudnn_find_func, max_workspace_size, UseFixedWorkspace()); @@ -569,7 +569,7 @@ struct SearchAlgorithmBase { CalcWorkspaceLimitInBytes(UseFixedWorkspace()); // auto workspace_handle = dev_ctx.cudnn_workspace_handle(); auto workspace_handle = GetDnnWorkspace( - const_cast(&(dev_ctx.GetAllocator())), dev_ctx.stream()); + const_cast(&(dev_ctx.GetAllocator())), dev_ctx.stream(), dev_ctx.GetPlace()); if (phi::backends::gpu::CudnnDataType::type != CUDNN_DATA_HALF) { size_t max_workspace_size = GetMaxWorkspaceSize(args, workspace_size_limit); diff --git a/backends/metax_gpu/tests/tmp_save/unittest/test_elementwise_pow_op_metax.py b/backends/metax_gpu/tests/tmp_save/unittest/test_elementwise_pow_op_metax.py index 9652b0ce47..574691a02d 100755 --- 
a/backends/metax_gpu/tests/tmp_save/unittest/test_elementwise_pow_op_metax.py +++ b/backends/metax_gpu/tests/tmp_save/unittest/test_elementwise_pow_op_metax.py @@ -24,6 +24,7 @@ import paddle from paddle import base from paddle.base import core +import paddle.profiler as profiler def pow_grad(x, y, dout): @@ -44,22 +45,61 @@ def setUp(self): self.outputs = {"Out": np.power(self.inputs["X"], self.inputs["Y"])} def test_check_output(self): - if hasattr(self, "attrs"): - self.check_output(check_dygraph=False) - else: - self.check_output(check_pir=True, check_symbol_infer=False) + # 定义输出路径 (会在当前目录下生成 profiler_log 文件夹) + # 1. 确保目录存在 + output_dir = "./profiler_log" + if not os.path.exists(output_dir): + os.makedirs(output_dir) + output_path = "./profiler_log/check_output" + + # 定义回调函数,用于导出性能数据 + def my_on_trace_ready(prof): + prof.export(path=output_path, format="json") + + # 初始化 Profiler + # 注意:对于 MetaX 这类 CustomDevice,通常 target 选 CPU 即可捕获 Host 端调度 + # 如果 MetaX 插件实现了 Profiler 接口,选 GPU 或 CUSTOM_DEVICE 可能捕获设备端信息 + with profiler.Profiler( + targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.CUSTOM_DEVICE], + scheduler=profiler.make_scheduler(closed=0, ready=0, record=1, repeat=1), + on_trace_ready=my_on_trace_ready + ) as p: + # === 将原始测试逻辑包裹在这里 === + if hasattr(self, "attrs"): + self.check_output(check_dygraph=False) + else: + self.check_output(check_pir=True, check_symbol_infer=False) + # ============================== + + p.step() # 通知 Profiler 一个 step 结束 def test_check_grad_normal(self): - if hasattr(self, "attrs"): - self.check_grad(["X", "Y"], "Out", check_prim=True, check_dygraph=False) - else: - self.check_grad( - ["X", "Y"], - "Out", - check_prim=True, - check_prim_pir=True, - check_pir=True, - ) + # 1. 
确保目录存在 + output_dir = "./profiler_log" + if not os.path.exists(output_dir): + os.makedirs(output_dir) + output_path = "./profiler_log/check_grad" + def my_on_trace_ready(prof): + prof.export(path=output_path, format="json") + + with profiler.Profiler( + targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.CUSTOM_DEVICE], + scheduler=profiler.make_scheduler(closed=0, ready=0, record=1, repeat=1), + on_trace_ready=my_on_trace_ready + ) as p: + # === 将原始测试逻辑包裹在这里 === + if hasattr(self, "attrs"): + self.check_grad(["X", "Y"], "Out", check_prim=True, check_dygraph=False) + else: + self.check_grad( + ["X", "Y"], + "Out", + check_prim=True, + check_prim_pir=True, + check_pir=True, + ) + # ============================== + p.step() class TestElementwisePowOp_ZeroDim1(TestElementwisePowOp): def setUp(self): @@ -74,7 +114,7 @@ def setUp(self): } self.outputs = {"Out": np.power(self.inputs["X"], self.inputs["Y"])} - +''' class TestElementwisePowOp_ZeroDim2(TestElementwisePowOp): def setUp(self): self.op_type = "elementwise_pow" @@ -454,6 +494,6 @@ def test_check_grad(self): only_check_prim=True, check_prim_pir=True, ) - +''' if __name__ == "__main__": unittest.main() From 6897dfafbc7b886b5f90351805cb72521446c139 Mon Sep 17 00:00:00 2001 From: YuhanXu Date: Tue, 10 Feb 2026 14:21:17 +0800 Subject: [PATCH 07/17] Add CINN_ENTAIL_LOOP_CONDITION into /backends/metax_gpu/cinn/compiler/compiler.cc --- backends/metax_gpu/cinn/compiler/compiler.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/backends/metax_gpu/cinn/compiler/compiler.cc b/backends/metax_gpu/cinn/compiler/compiler.cc index 6b3afc004b..91dea8b392 100644 --- a/backends/metax_gpu/cinn/compiler/compiler.cc +++ b/backends/metax_gpu/cinn/compiler/compiler.cc @@ -49,6 +49,10 @@ typedef __half float16; #define cinn_max(a, b) ((a) > (b) ? (a) : (b)) #define cinn_min(a, b) ((a) < (b) ? 
(a) : (b)) +#define CINN_ENTAIL_LOOP_CONDITION(__loop_var, __cond, __stride) \ + } \ + for (decltype(__stride) __loop_var = 0; __cond; __loop_var += __stride) { + // =============================================================== // 1. Bool / Int8 / UInt8 / Int16 Operations // =============================================================== From 71b4becbe641b761a57792694c04530cb7fe6647 Mon Sep 17 00:00:00 2001 From: YuhanXu Date: Fri, 13 Feb 2026 15:01:39 +0800 Subject: [PATCH 08/17] Support argidx ArgMin/ArgMax Block Reduce for CINN metax_gpu. --- backends/metax_gpu/cinn/compiler/compiler.cc | 183 ++++++++++++++++++- 1 file changed, 180 insertions(+), 3 deletions(-) diff --git a/backends/metax_gpu/cinn/compiler/compiler.cc b/backends/metax_gpu/cinn/compiler/compiler.cc index 91dea8b392..68015838d2 100644 --- a/backends/metax_gpu/cinn/compiler/compiler.cc +++ b/backends/metax_gpu/cinn/compiler/compiler.cc @@ -43,8 +43,20 @@ typedef long long int64_t; // 兼容 CINN 生成代码中对 __half 的引用 typedef __half float16; +#define CINN_UINT8_MIN 0 +#define CINN_UINT8_MAX 255 +#define CINN_INT16_MIN -32768 +#define CINN_INT16_MAX 32767 #define CINN_INT32_MAX 2147483647 #define CINN_INT32_MIN -2147483648 +#define CINN_INT64_MAX 0x7fffffffffffffffLL +#define CINN_INT64_MIN -CINN_INT64_MAX - 1 +#define CINN_FP32_MAX 3.40282347e+38F +#define CINN_FP32_MIN -3.402823466e+38f +#define CINN_FP64_MAX 1.79769313486231571e+308 +#define CINN_FP64_MIN -1.7976931348623157e+308 +#define CINN_FP16_MIN (float16) __ushort_as_half(0xfbff) +#define CINN_FP16_MAX (float16) __ushort_as_half(0x7bff) #define cinn_max(a, b) ((a) > (b) ? (a) : (b)) #define cinn_min(a, b) ((a) < (b) ? 
(a) : (b)) @@ -453,8 +465,10 @@ __device__ inline int FN_INT32(bitwise_xor)(int a, int b) { return a ^ b; } __device__ inline int FN_INT32(logical_right_shift)(int a, int b) { return (unsigned int)a >> b; } __device__ inline int FN_INT32(trunc)(int a) { return a; } __device__ inline int FN_INT32(pow)(int a, int b) { - if (a == 0 && b < 0) return -1; - float res = powf(__int2float_rd(a), __int2float_rd(b)); + if (a == 0 && b < 0) { + return 0; + } + float res = pow(__int2float_rd(a), __int2float_rd(b)); return __float2int_rn(res); } __device__ inline int FN_INT32(arithmetic_right_shift)(int a, int b) { return a >> b; } @@ -788,8 +802,171 @@ __device__ int cinn_custom_device_resize_bicubic(const int *buf, return value; } - } // extern "C" + +// =============================================================== +// 8. ArgMin/ArgMax Support (ArgIdx Structures & Shuffles) +// =============================================================== +// --- C++ Scope Start --- + +// arg reduce arg index struct +// 【核心】不定义 operator<,强制走 std::max 重载 +#define ARGIDX_STRUCT_MACRO(TYPENAME, DTYPE, ITYPE, IINIT) \ + struct TYPENAME { \ + DTYPE value; \ + ITYPE index; \ + __device__ TYPENAME() {} \ + __device__ explicit TYPENAME(DTYPE value) : value(value), index(IINIT) {} \ + __device__ TYPENAME(DTYPE value, ITYPE index) \ + : value(value), index(index) {} \ + __device__ explicit operator ITYPE() { return index; } \ + /* 赋值运算符支持 */ \ + __device__ inline TYPENAME& operator=(const TYPENAME& other) { \ + value = other.value; \ + index = other.index; \ + return *this; \ + } \ + __device__ inline volatile TYPENAME& operator=(const volatile TYPENAME& other) volatile { \ + value = other.value; \ + index = other.index; \ + return *this; \ + } \ + }; + +// 实例化结构体 +#ifdef CINN_CUDA_FP16 +ARGIDX_STRUCT_MACRO(argidx_fp16_i64, float16, int64_t, 0LL) +#endif +ARGIDX_STRUCT_MACRO(argidx_fp32_i64, float, int64_t, 0LL) +ARGIDX_STRUCT_MACRO(argidx_fp64_i64, double, int64_t, 0LL) 
+ARGIDX_STRUCT_MACRO(argidx_i16_i64, int16_t, int64_t, 0LL) +ARGIDX_STRUCT_MACRO(argidx_i32_i64, int, int64_t, 0LL) +ARGIDX_STRUCT_MACRO(argidx_i64_i64, int64_t, int64_t, 0LL) +ARGIDX_STRUCT_MACRO(argidx_u8_i64, uint8_t, int64_t, 0LL) + +ARGIDX_STRUCT_MACRO(argidx_fp32_i32, float, int, 0) +ARGIDX_STRUCT_MACRO(argidx_i32_i32, int, int, 0) + +// 手写 std::max 重载 +namespace std { + // ArgMax 实现 + template + __device__ __forceinline__ T max_argidx_impl(const T& a, const T& b) { + if (a.value > b.value) return a; + if (a.value < b.value) return b; + return a.index < b.index ? a : b; + } + + template + __device__ __forceinline__ T min_argidx_impl(const T& a, const T& b) { + if (a.value < b.value) return a; + if (a.value > b.value) return b; + return a.index < b.index ? a : b; + } + + // Volatile 重载 + template + __device__ __forceinline__ T max_argidx_volatile_impl(const volatile T& a, const volatile T& b) { + T va, vb; + va.value = a.value; va.index = a.index; + vb.value = b.value; vb.index = b.index; + return max_argidx_impl(va, vb); + } + + template + __device__ __forceinline__ T min_argidx_volatile_impl(const volatile T& a, const volatile T& b) { + T va, vb; + va.value = a.value; va.index = a.index; + vb.value = b.value; vb.index = b.index; + return min_argidx_impl(va, vb); + } + + // 显式展开 + __device__ __forceinline__ argidx_fp32_i64 max(const argidx_fp32_i64& a, const argidx_fp32_i64& b) { return max_argidx_impl(a, b); } + __device__ __forceinline__ argidx_fp32_i64 min(const argidx_fp32_i64& a, const argidx_fp32_i64& b) { return min_argidx_impl(a, b); } + + __device__ __forceinline__ argidx_fp32_i64 max(const volatile argidx_fp32_i64& a, const volatile argidx_fp32_i64& b) { return max_argidx_volatile_impl(a, b); } + __device__ __forceinline__ argidx_fp32_i64 min(const volatile argidx_fp32_i64& a, const volatile argidx_fp32_i64& b) { return min_argidx_volatile_impl(a, b); } + + __device__ __forceinline__ argidx_fp32_i32 max(const argidx_fp32_i32& a, const 
argidx_fp32_i32& b) { return max_argidx_impl(a, b); } + __device__ __forceinline__ argidx_fp32_i32 min(const argidx_fp32_i32& a, const argidx_fp32_i32& b) { return min_argidx_impl(a, b); } +} + +// =============================================================== +// 9. ArgMin/ArgMax Block Reduce Instantiation +// =============================================================== + +// 【终极修正】支持 2D Block 的行级归约 (Row-wise Reduction) +template +__device__ inline T cinn_block_reduce_shm_impl(T value, T* shm_discard, Func reduce_func) { + // 获取 2D 维度信息 + unsigned int tx = threadIdx.x; + unsigned int ty = threadIdx.y; + unsigned int bdx = blockDim.x; + + // 计算扁平化索引:确保不同行的数据落在 Shared Memory 的不同区域 + // 这样 threadIdx.y=0 和 threadIdx.y=1 就不会打架了 + unsigned int idx = ty * bdx + tx; + + // 分配足够大的静态 Shared Memory (1024 够 32x32 的 block 使用) + // 如果你的 block 很大,需要增加这里。但 CINN argmax 通常 block 不大。 + __shared__ T internal_shm[1024]; + + // 1. 写入 (带边界检查) + if (idx < 1024) { + internal_shm[idx] = value; + } + __syncthreads(); + + // 2. 树状归约 (只在 tx 维度归约) + // 每一行 (ty) 独立进行归约,互不干扰 + for (unsigned int s = bdx / 2; s > 0; s >>= 1) { + if (tx < s && (idx + s) < 1024) { + internal_shm[idx] = reduce_func(internal_shm[idx], internal_shm[idx + s]); + } + __syncthreads(); + } + + // 3. 
返回结果 + // 每一行的结果存储在该行的首位 (ty * bdx) + // 广播给该行的所有线程 + return internal_shm[ty * bdx]; +} + +// Max/Min Functors +struct ArgIdxMaxOp { + template + __device__ inline T operator()(const T& a, const T& b) const { return std::max(a, b); } + template + __device__ inline T operator()(const volatile T& a, const volatile T& b) const { return std::max(a, b); } +}; + +struct ArgIdxMinOp { + template + __device__ inline T operator()(const T& a, const T& b) const { return std::min(a, b); } + template + __device__ inline T operator()(const volatile T& a, const volatile T& b) const { return std::min(a, b); } +}; + +extern "C" { + +__device__ inline argidx_fp32_i64 cinn_block_reduce_max(const argidx_fp32_i64 value, argidx_fp32_i64 *shm, bool return_warp = false) { + return cinn_block_reduce_shm_impl(value, shm, ArgIdxMaxOp()); +} + +__device__ inline argidx_fp32_i64 cinn_block_reduce_min(const argidx_fp32_i64 value, argidx_fp32_i64 *shm, bool return_warp = false) { + return cinn_block_reduce_shm_impl(value, shm, ArgIdxMinOp()); +} + +__device__ inline argidx_fp32_i64 cinn_block_reduce_min_argidx_fp32_i64(const argidx_fp32_i64 value, argidx_fp32_i64 *shm, bool return_warp = false) { + return cinn_block_reduce_min(value, shm, return_warp); +} + +__device__ inline argidx_fp32_i64 cinn_block_reduce_max_argidx_fp32_i64(const argidx_fp32_i64 value, argidx_fp32_i64 *shm, bool return_warp = false) { + return cinn_block_reduce_max(value, shm, return_warp); +} + +} // extern "C" )MACA_SOURCE"; From 64979836afe75a77e5d6522e4ceda801ca4c75e0 Mon Sep 17 00:00:00 2001 From: YuhanXu Date: Wed, 25 Feb 2026 18:46:48 +0800 Subject: [PATCH 09/17] CINN metax_gpu compiler.cc support Welford for BatchNorm_11_class.py --- backends/metax_gpu/cinn/compiler/compiler.cc | 580 ++++++++++++------- 1 file changed, 358 insertions(+), 222 deletions(-) diff --git a/backends/metax_gpu/cinn/compiler/compiler.cc b/backends/metax_gpu/cinn/compiler/compiler.cc index 68015838d2..264923b1af 100644 --- 
a/backends/metax_gpu/cinn/compiler/compiler.cc +++ b/backends/metax_gpu/cinn/compiler/compiler.cc @@ -95,21 +95,326 @@ __device__ inline int16_t FN_INT16(bitwise_xor)(int16_t a, int16_t b) { return a __device__ inline int16_t FN_INT16(bitwise_not)(int16_t a) { return ~a; } __device__ inline int16_t FN_INT16(logical_right_shift)(int16_t a, int16_t b) { return ((uint16_t)a >> b); } +// =============================================================== +// 6. Standard Math Functions +// =============================================================== +// =============================================================== +// Float64 (Double) Math Functions +// =============================================================== +#define FN_FP64(func) cinn_custom_device_##func##_fp64 + +__device__ inline double FN_FP64(sin)(double x) { return sin(x); } +__device__ inline double FN_FP64(cos)(double x) { return cos(x); } +__device__ inline double FN_FP64(tan)(double x) { return tan(x); } +__device__ inline double FN_FP64(exp)(double x) { return exp(x); } +__device__ inline double FN_FP64(log)(double x) { return log(x); } +__device__ inline double FN_FP64(log2)(double x) { return log2(x); } +__device__ inline double FN_FP64(log10)(double x) { return log10(x); } +__device__ inline double FN_FP64(sqrt)(double x) { return sqrt(x); } +__device__ inline double FN_FP64(rsqrt)(double x) { return rsqrt(x); } +__device__ inline double FN_FP64(abs)(double x) { return fabs(x); } +__device__ inline double FN_FP64(floor)(double x) { return floor(x); } +__device__ inline double FN_FP64(ceil)(double x) { return ceil(x); } +__device__ inline double FN_FP64(round)(double x) { return round(x); } +__device__ inline double FN_FP64(trunc)(double x) { return trunc(x); } +__device__ inline double FN_FP64(pow)(double a, double b) { return pow(a, b); } +__device__ inline double FN_FP64(fma)(double a, double b, double c) { return fma(a, b, c); } +__device__ inline bool FN_FP64(isnan)(double x) { return 
isnan(x); } +__device__ inline bool FN_FP64(isinf)(double x) { return isinf(x); } +__device__ inline bool FN_FP64(isfinite)(double x) { return isfinite(x); } +__device__ inline double FN_FP64(acos)(double x) { return acos(x); } +__device__ inline double FN_FP64(acosh)(double x) { return acosh(x); } +__device__ inline double FN_FP64(asin)(double x) { return asin(x); } +__device__ inline double FN_FP64(asinh)(double x) { return asinh(x); } +__device__ inline double FN_FP64(atan)(double x) { return atan(x); } +__device__ inline double FN_FP64(atanh)(double x) { return atanh(x); } +__device__ inline double FN_FP64(cbrt)(double x) { return cbrt(x); } +__device__ inline double FN_FP64(cosh)(double x) { return cosh(x); } +__device__ inline double FN_FP64(erf)(double x) { return erf(x); } +__device__ inline double FN_FP64(log1p)(double x) { return log1p(x); } +__device__ inline double FN_FP64(sigmoid)(double x) { return 1.0 / (1.0 + exp(-x)); } +__device__ inline double FN_FP64(sinh)(double x) { return sinh(x); } +__device__ inline double FN_FP64(tanh)(double x) { return tanh(x); } +__device__ inline double FN_FP64(mod)(double a, double b) { + double res = fmod(a, b); + if ((res != 0.0) && ((res < 0.0) != (b < 0.0))) res += b; + return res; +} +__device__ inline double FN_FP64(rcp)(double x) { + return 1.0 / x; +} + +// =============================================================== +// Float32 Math Functions +// =============================================================== +#define FN_FP32(func) cinn_custom_device_##func##_fp32 + +__device__ inline float FN_FP32(sin)(float x) { return sinf(x); } +__device__ inline float FN_FP32(cos)(float x) { return cosf(x); } +__device__ inline float FN_FP32(tan)(float x) { return tanf(x); } +__device__ inline float FN_FP32(exp)(float x) { return expf(x); } +__device__ inline float FN_FP32(log)(float x) { return logf(x); } +__device__ inline float FN_FP32(sqrt)(float x) { return sqrtf(x); } +__device__ inline float 
FN_FP32(rsqrt)(float x) { return rsqrtf(x); } +__device__ inline float FN_FP32(pow)(float a, float b) { return powf(a, b); } +__device__ inline float FN_FP32(floor)(float x) { return floorf(x); } +__device__ inline float FN_FP32(ceil)(float x) { return ceilf(x); } +__device__ inline float FN_FP32(round)(float x) { return roundf(x); } +__device__ inline float FN_FP32(trunc)(float x) { return truncf(x); } +__device__ inline float FN_FP32(abs)(float x) { return fabsf(x); } +__device__ inline float FN_FP32(fma)(float a, float b, float c) { return fmaf(a, b, c); } +__device__ inline bool FN_FP32(isnan)(float x) { return isnan(x); } +__device__ inline bool FN_FP32(isinf)(float x) { return isinf(x); } +__device__ inline bool FN_FP32(isfinite)(float x) { return isfinite(x); } +__device__ inline float FN_FP32(acos)(float x) { return acosf(x); } +__device__ inline float FN_FP32(acosh)(float x) { return acoshf(x); } +__device__ inline float FN_FP32(asin)(float x) { return asinf(x); } +__device__ inline float FN_FP32(asinh)(float x) { return asinhf(x); } +__device__ inline float FN_FP32(atan)(float x) { return atanf(x); } +__device__ inline float FN_FP32(atanh)(float x) { return atanhf(x); } +__device__ inline float FN_FP32(cbrt)(float x) { return cbrtf(x); } +__device__ inline float FN_FP32(cosh)(float x) { return coshf(x); } +__device__ inline float FN_FP32(erf)(float x) { return erff(x); } +__device__ inline float FN_FP32(log2)(float x) { return log2f(x); } +__device__ inline float FN_FP32(log10)(float x) { return log10f(x); } +__device__ inline float FN_FP32(log1p)(float x) { return log1pf(x); } +__device__ inline float FN_FP32(sigmoid)(float x) { return 1.0f / (1.0f + expf(-x)); } +__device__ inline float FN_FP32(sinh)(float x) { return sinhf(x); } +__device__ inline float FN_FP32(tanh)(float x) { return tanhf(x); } +__device__ inline float FN_FP32(left_shift)(float a, float b) { + return (float)((int)a << (int)b); +} +__device__ inline float FN_FP32(right_shift)(float a, 
float b) { + return (float)((int)a >> (int)b); +} +__device__ inline float FN_FP32(mod)(float a, float b) { + float res = fmodf(a, b); + if ((res != 0.0f) && ((res < 0.0f) != (b < 0.0f))) res += b; + return res; +} +__device__ inline float FN_FP32(rcp)(float x) { + return 1.0f / x; +} +__device__ inline float FN_FP32(tanh_approx)(float x) { + return tanhf(x); +} + +// =============================================================== +// Int32 Functions +// =============================================================== +#define FN_INT32(func) cinn_custom_device_##func##_int32 +__device__ inline int FN_INT32(bitwise_not)(int a) { return ~a; } +__device__ inline int FN_INT32(clz)(int a) { return __clz(a); } +__device__ inline int FN_INT32(popc)(int a) { return __popc(a); } +__device__ inline int FN_INT32(mod)(int a, int b) { + int res = a % b; + if ((res != 0) && ((b ^ res) < 0)) res += b; + return res; +} +__device__ inline int FN_INT32(max)(int a, int b) { return cinn_max(a, b); } +__device__ inline int FN_INT32(min)(int a, int b) { return cinn_min(a, b); } +__device__ inline int FN_INT32(left_shift)(int a, int b) { return a << b; } +__device__ inline int FN_INT32(right_shift)(int a, int b) { return a >> b; } +__device__ inline int FN_INT32(bitwise_and)(int a, int b) { return a & b; } +__device__ inline int FN_INT32(bitwise_or)(int a, int b) { return a | b; } +__device__ inline int FN_INT32(bitwise_xor)(int a, int b) { return a ^ b; } +__device__ inline int FN_INT32(logical_right_shift)(int a, int b) { return (unsigned int)a >> b; } +__device__ inline int FN_INT32(trunc)(int a) { return a; } +__device__ inline int FN_INT32(pow)(int a, int b) { + if (a == 0 && b < 0) { + return 0; + } + float res = pow(__int2float_rd(a), __int2float_rd(b)); + return __float2int_rn(res); +} +__device__ inline int FN_INT32(arithmetic_right_shift)(int a, int b) { return a >> b; } + +// =============================================================== +// Int64 Functions +// 
=============================================================== +#define FN_INT64(func) cinn_custom_device_##func##_int64 +__device__ inline int64_t FN_INT64(bitwise_and)(int64_t a, int64_t b) { return a & b; } +__device__ inline int64_t FN_INT64(bitwise_or)(int64_t a, int64_t b) { return a | b; } +__device__ inline int64_t FN_INT64(bitwise_xor)(int64_t a, int64_t b) { return a ^ b; } +__device__ inline int64_t FN_INT64(bitwise_not)(int64_t a) { return ~a; } +__device__ inline int64_t FN_INT64(clz)(int64_t a) { return __clzll(a); } +__device__ inline int64_t FN_INT64(popc)(int64_t a) { return __popcll(a); } +__device__ inline int64_t FN_INT64(logical_right_shift)(int64_t a, int64_t b) { return ((uint64_t)a >> b); } +__device__ inline int64_t FN_INT64(trunc)(int64_t a) { return a; } +__device__ inline int64_t FN_INT64(mod)(int64_t a, int64_t b) { int64_t res = a % b; if ((res != 0) && ((b ^ res) < 0)) res += b; return res; } +__device__ inline int64_t FN_INT64(pow)(int64_t a, int64_t b) { double res = pow(__ll2double_rd(a), __ll2double_rd(b)); return __double2ll_rn(res); } + +// =============================================================== +// Float16 (Half) Functions +// =============================================================== +#define FN_FP16(func) cinn_custom_device_##func##_fp16 + +#define FN_FP16(func) cinn_custom_device_##func##_fp16 +__device__ inline float16 FN_FP16(ceil)(float16 x) { return hceil(x); } +__device__ inline float16 FN_FP16(floor)(float16 x) { return hfloor(x); } +__device__ inline float16 FN_FP16(round)(float16 x) { return __float2half(roundf(__half2float(x))); } +__device__ inline float16 FN_FP16(trunc)(float16 x) { return htrunc(x); } +__device__ inline float16 FN_FP16(sin)(float16 x) { return hsin(x); } +__device__ inline float16 FN_FP16(cos)(float16 x) { return hcos(x); } +__device__ inline float16 FN_FP16(exp)(float16 x) { return hexp(x); } +__device__ inline float16 FN_FP16(log)(float16 x) { return hlog(x); } +__device__ inline 
float16 FN_FP16(log2)(float16 x) { return hlog2(x); } +__device__ inline float16 FN_FP16(log10)(float16 x) { return hlog10(x); } +__device__ inline float16 FN_FP16(sqrt)(float16 x) { return hsqrt(x); } +__device__ inline float16 FN_FP16(rsqrt)(float16 x) { return hrsqrt(x); } +__device__ inline float16 FN_FP16(cbrt)(float16 x) { return __float2half(cbrtf(__half2float(x))); } +__device__ inline float16 FN_FP16(abs)(float16 x) { return __float2half(fabsf(__half2float(x))); } +__device__ inline bool FN_FP16(isnan)(float16 x) { return __hisnan(x); } +__device__ inline bool FN_FP16(isinf)(float16 x) { return __hisinf(x); } +__device__ inline bool FN_FP16(isfinite)(float16 x) { return !__hisinf(x) && !__hisnan(x); } +__device__ inline float16 FN_FP16(erf)(float16 x) { return __float2half(erff(__half2float(x))); } +__device__ inline float16 FN_FP16(tan)(float16 x) { return __float2half(tanf(__half2float(x))); } +__device__ inline float16 FN_FP16(sinh)(float16 x) { return __float2half(sinhf(__half2float(x))); } +__device__ inline float16 FN_FP16(cosh)(float16 x) { return __float2half(coshf(__half2float(x))); } +__device__ inline float16 FN_FP16(tanh)(float16 x) { return __float2half(tanhf(__half2float(x))); } +__device__ inline float16 FN_FP16(asin)(float16 x) { return __float2half(asinf(__half2float(x))); } +__device__ inline float16 FN_FP16(acos)(float16 x) { return __float2half(acosf(__half2float(x))); } +__device__ inline float16 FN_FP16(atan)(float16 x) { return __float2half(atanf(__half2float(x))); } +__device__ inline float16 FN_FP16(asinh)(float16 x) { return __float2half(asinhf(__half2float(x))); } +__device__ inline float16 FN_FP16(acosh)(float16 x) { return __float2half(acoshf(__half2float(x))); } +__device__ inline float16 FN_FP16(atanh)(float16 x) { return __float2half(atanhf(__half2float(x))); } +__device__ inline float16 FN_FP16(sigmoid)(float16 x) { return __float2half(1.0f / (1.0f + expf(-__half2float(x)))); } +__device__ inline float16 
FN_FP16(mod)(float16 a, float16 b) { return __float2half(fmodf(__half2float(a), __half2float(b))); } +__device__ inline float16 FN_FP16(pow)(float16 a, float16 b) { return __float2half(powf(__half2float(a), __half2float(b))); } +__device__ inline float16 FN_FP16(add)(float16 a, float16 b) { return __hadd(a, b); } +__device__ inline float16 FN_FP16(sub)(float16 a, float16 b) { return __hsub(a, b); } +__device__ inline float16 FN_FP16(mul)(float16 a, float16 b) { return __hmul(a, b); } +__device__ inline float16 FN_FP16(div)(float16 a, float16 b) { return __hdiv(a, b); } +__device__ inline float16 FN_FP16(neg)(float16 a) { return __hneg(a); } +__device__ inline float16 FN_FP16(fma)(float16 a, float16 b, float16 c) { return __hfma(a, b, c); } +__device__ inline float16 FN_FP16(max)(float16 a, float16 b) { return __hgt(a, b) ? a : b; } +__device__ inline float16 FN_FP16(min)(float16 a, float16 b) { return __hlt(a, b) ? a : b; } + +// =============================================================== +// Warp Shuffle Functions (用于 Reduce 算子) +// =============================================================== +#define FN_SHUFFLE(func) cinn_custom_device_##func + +__device__ inline float FN_SHUFFLE(warp_shuffle_xor_fp32)(float v, int factor) { + return __shfl_xor_sync(0xffffffff, v, factor); +} +__device__ inline float FN_SHUFFLE(warp_shuffle_up_fp32)(float v, int factor) { + return __shfl_up_sync(0xffffffff, v, factor); +} +__device__ inline float FN_SHUFFLE(warp_shuffle_down_fp32)(float v, int factor) { + return __shfl_down_sync(0xffffffff, v, factor); +} + +__device__ inline int FN_SHUFFLE(warp_shuffle_xor_int32)(int v, int factor) { + return __shfl_xor_sync(0xffffffff, v, factor); +} +__device__ inline int FN_SHUFFLE(warp_shuffle_up_int32)(int v, int factor) { + return __shfl_up_sync(0xffffffff, v, factor); +} +__device__ inline int FN_SHUFFLE(warp_shuffle_down_int32)(int v, int factor) { + return __shfl_down_sync(0xffffffff, v, factor); +} + +// MACA/CUDA 的 shfl 指令通常只支持 
32位,__half 需要强转或使用 intrinsics +__device__ inline __half FN_SHUFFLE(warp_shuffle_xor_fp16)(__half v, int factor) { + unsigned short val = __half_as_ushort(v); + unsigned short res = (unsigned short)__shfl_xor_sync(0xffffffff, (int)val, factor); + return __ushort_as_half(res); +} +__device__ inline __half FN_SHUFFLE(warp_shuffle_up_fp16)(__half v, int factor) { + unsigned short val = __half_as_ushort(v); + unsigned short res = (unsigned short)__shfl_up_sync(0xffffffff, (int)val, factor); + return __ushort_as_half(res); +} +__device__ inline __half FN_SHUFFLE(warp_shuffle_down_fp16)(__half v, int factor) { + unsigned short val = __half_as_ushort(v); + unsigned short res = (unsigned short)__shfl_down_sync(0xffffffff, (int)val, factor); + return __ushort_as_half(res); +} +} // extern "C" + // =============================================================== // 2. Reduce Binary Operations (CINN CodeGen Requirement) // =============================================================== +// *************************************************************** // +// welford struct and operators + +#define WELFORD_STRUCT_MACRO(TYPENAME, DTYPE) \ + struct TYPENAME { \ + DTYPE mean; \ + DTYPE m2; \ + DTYPE weight; \ + __device__ TYPENAME(){}; \ + __device__ explicit TYPENAME(DTYPE value) \ + : mean(value), m2(0), weight(1) {} \ + __device__ TYPENAME(DTYPE mean, DTYPE m2, DTYPE weight) \ + : mean(mean), m2(m2), weight(weight) {} \ + __device__ explicit operator DTYPE() const { return m2 / weight; } \ + }; + +#define WELFORD_COMBINE_MACRO(TYPENAME, DTYPE, RCP_FUNC) \ + __device__ inline TYPENAME operator+(const TYPENAME &a, const TYPENAME &b) { \ + DTYPE delta = b.mean - a.mean; \ + DTYPE weight = a.weight + b.weight; \ + DTYPE mean = a.mean + delta * RCP_FUNC(weight); \ + DTYPE m2 = a.m2 + delta * (b.mean - mean); \ + return {mean, m2, weight}; \ + } + +#define WELFORD_SHFL_SYNC_MACRO(TYPENAME, DTYPE, SHFL_FUNC, ARG2_TYPE, ARG2) \ + __device__ inline TYPENAME SHFL_FUNC( \ + unsigned mask, 
const TYPENAME &var, ARG2_TYPE ARG2, int width = 32) { \ + DTYPE mean = SHFL_FUNC(mask, var.mean, ARG2, width); \ + DTYPE m2 = SHFL_FUNC(mask, var.m2, ARG2, width); \ + DTYPE weight = SHFL_FUNC(mask, var.weight, ARG2, width); \ + return {mean, m2, weight}; \ + } + +#define EXPAND_WELFORD_MACRO(TYPE_SUFFIX, DTYPE) \ + WELFORD_STRUCT_MACRO(welford_##TYPE_SUFFIX, DTYPE) \ + WELFORD_COMBINE_MACRO( \ + welford_##TYPE_SUFFIX, DTYPE, cinn_custom_device_rcp_##TYPE_SUFFIX) \ + WELFORD_SHFL_SYNC_MACRO( \ + welford_##TYPE_SUFFIX, DTYPE, __shfl_down_sync, unsigned, delta) \ + WELFORD_SHFL_SYNC_MACRO( \ + welford_##TYPE_SUFFIX, DTYPE, __shfl_xor_sync, int, laneMask) + +EXPAND_WELFORD_MACRO(fp32, float) +EXPAND_WELFORD_MACRO(fp64, double) + +#undef WELFORD_STRUCT_MACRO +#undef WELFORD_COMBINE_MACRO +#undef WELFORD_SHFL_SYNC_MACRO +#undef EXPAND_WELFORD_MACRO + +extern "C" { +// parallel reduction template for welford variance type reduction +#define WELFORD_PARALLEL_COMBINE_MACRO(DTYPE, TYPE_SUFFIX) \ + __device__ inline welford_##TYPE_SUFFIX cinn_sum_welford_##TYPE_SUFFIX( \ + welford_##TYPE_SUFFIX a, welford_##TYPE_SUFFIX b) { \ + DTYPE delta = b.mean - a.mean; \ + DTYPE weight = a.weight + b.weight; \ + DTYPE w2_over_w = b.weight * cinn_custom_device_rcp_##TYPE_SUFFIX(weight); \ + w2_over_w = weight == 0 ? 
(DTYPE)0 : w2_over_w; \ + DTYPE mean = a.mean + delta * w2_over_w; \ + DTYPE m2 = a.m2 + b.m2 + delta * delta * a.weight * w2_over_w; \ + return {mean, m2, weight}; \ + } // --- FP64 (Double) --- __device__ inline double cinn_sum_fp64(const double left, const double right) { return left + right; } __device__ inline double cinn_prod_fp64(const double left, const double right) { return left * right; } __device__ inline double cinn_max_fp64(const double left, const double right) { return max(left, right); } __device__ inline double cinn_min_fp64(const double left, const double right) { return min(left, right); } +WELFORD_PARALLEL_COMBINE_MACRO(double, fp64) // --- FP32 (Float) --- __device__ inline float cinn_sum_fp32(const float left, const float right) { return left + right; } __device__ inline float cinn_prod_fp32(const float left, const float right) { return left * right; } __device__ inline float cinn_max_fp32(const float left, const float right) { return max(left, right); } __device__ inline float cinn_min_fp32(const float left, const float right) { return min(left, right); } +WELFORD_PARALLEL_COMBINE_MACRO(float, fp32) +#undef WELFORD_PARALLEL_COMBINE_MACRO // --- Int32 --- __device__ inline int cinn_sum_int32(const int left, const int right) { return left + right; } @@ -151,13 +456,21 @@ __device__ inline float16 cinn_min_fp16(const float16 left, const float16 right) MACRO(sum_fp64, 0.0, double, ##__VA_ARGS__) \ MACRO(prod_fp64, 1.0, double, ##__VA_ARGS__) \ MACRO(max_fp64, -1.79769e+308, double, ##__VA_ARGS__) \ - MACRO(min_fp64, 1.79769e+308, double, ##__VA_ARGS__) + MACRO(min_fp64, 1.79769e+308, double, ##__VA_ARGS__) \ + MACRO(sum_welford_fp64, \ + welford_fp64(0.0, 0.0, 0.0), \ + welford_fp64, \ + ##__VA_ARGS__) #define EXPAND_REDUCE_FP32_MACRO(MACRO, ...) 
\ MACRO(sum_fp32, 0.0f, float, ##__VA_ARGS__) \ MACRO(prod_fp32, 1.0f, float, ##__VA_ARGS__) \ MACRO(max_fp32, -3.40282e+38f, float, ##__VA_ARGS__) \ - MACRO(min_fp32, 3.40282e+38f, float, ##__VA_ARGS__) + MACRO(min_fp32, 3.40282e+38f, float, ##__VA_ARGS__) \ + MACRO(sum_welford_fp32, \ + welford_fp32(0.0f, 0.0f, 0.0f), \ + welford_fp32, \ + ##__VA_ARGS__) #define EXPAND_REDUCE_INT32_MACRO(MACRO, ...) \ MACRO(sum_int32, 0, int, ##__VA_ARGS__) \ @@ -208,7 +521,7 @@ __device__ inline float16 cinn_min_fp16(const float16 left, const float16 right) } \ /* 广播:虽然 Down Shuffle 只有 Lane 0 结果正确,但这里为了兼容 XOR 语义 */ \ /* 我们用 shfl 0 把 Lane 0 的结果广播给所有人 (CINN Block Reduce 需要) */ \ - return __shfl(tmp_val, 0); \ + return cinn_warp_shuffle_idx_##DTYPE##_wrapper(tmp_val, 0); \ } // --- Warp Shuffle Primitives (Legacy API without mask) --- @@ -239,6 +552,48 @@ __device__ inline float16 cinn_warp_shuffle_down_float16_wrapper(float16 v, int return __ushort_as_half(res); } +__device__ inline welford_fp32 cinn_warp_shuffle_down_welford_fp32_wrapper(welford_fp32 v, int factor) { + return __shfl_down_sync(0xffffffff, v, factor); +} +__device__ inline welford_fp64 cinn_warp_shuffle_down_welford_fp64_wrapper(welford_fp64 v, int factor) { + return __shfl_down_sync(0xffffffff, v, factor); +} + +// 广播类型的 Idx 包装函数 (最后返回阶段使用 shfl_sync(var, 0)) +__device__ inline float cinn_warp_shuffle_idx_float_wrapper(float v, int lane) { return __shfl_sync(0xffffffff, v, lane); } +__device__ inline int cinn_warp_shuffle_idx_int_wrapper(int v, int lane) { return __shfl_sync(0xffffffff, v, lane); } +__device__ inline bool cinn_warp_shuffle_idx_bool_wrapper(bool v, int lane) { return __shfl_sync(0xffffffff, v, lane); } +__device__ inline float16 cinn_warp_shuffle_idx_float16_wrapper(float16 v, int lane) { + unsigned short val = __half_as_ushort(v); + return __ushort_as_half((unsigned short)__shfl_sync(0xffffffff, (int)val, lane)); +} +__device__ inline double cinn_warp_shuffle_idx_double_wrapper(double v, int lane) 
{ + unsigned long long int val_u64 = *(unsigned long long int*)&v; + int lo = __shfl_sync(0xffffffff, (int)val_u64, lane); + int hi = __shfl_sync(0xffffffff, (int)(val_u64 >> 32), lane); + unsigned long long int res = ((unsigned long long int)hi << 32) | (unsigned int)lo; + return *(double*)&res; +} +__device__ inline int64_t cinn_warp_shuffle_idx_int64_t_wrapper(int64_t v, int lane) { + int lo = __shfl_sync(0xffffffff, (int)v, lane); + int hi = __shfl_sync(0xffffffff, (int)(v >> 32), lane); + return ((int64_t)hi << 32) | (unsigned int)lo; +} + +// === 新增:Welford 的 Idx (广播) 包装函数 === +__device__ inline welford_fp32 cinn_warp_shuffle_idx_welford_fp32_wrapper(welford_fp32 v, int lane) { + float m = __shfl_sync(0xffffffff, v.mean, lane); + float m2 = __shfl_sync(0xffffffff, v.m2, lane); + float w = __shfl_sync(0xffffffff, v.weight, lane); + return welford_fp32(m, m2, w); +} +__device__ inline welford_fp64 cinn_warp_shuffle_idx_welford_fp64_wrapper(welford_fp64 v, int lane) { + double m = __shfl_sync(0xffffffff, v.mean, lane); + double m2 = __shfl_sync(0xffffffff, v.m2, lane); + double w = __shfl_sync(0xffffffff, v.weight, lane); + return welford_fp64(m, m2, w); +} + // Expand Warp Shuffle EXPAND_REDUCE_INT32_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL) EXPAND_REDUCE_INT64_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL) @@ -356,225 +711,6 @@ __device__ inline bool cinn_grid_reduce_update_semaphore(int *semaphores) { __syncthreads(); return done; } -// =============================================================== -// 6. 
Standard Math Functions -// =============================================================== -// =============================================================== -// Float64 (Double) Math Functions -// =============================================================== -#define FN_FP64(func) cinn_custom_device_##func##_fp64 - -__device__ inline double FN_FP64(sin)(double x) { return sin(x); } -__device__ inline double FN_FP64(cos)(double x) { return cos(x); } -__device__ inline double FN_FP64(tan)(double x) { return tan(x); } -__device__ inline double FN_FP64(exp)(double x) { return exp(x); } -__device__ inline double FN_FP64(log)(double x) { return log(x); } -__device__ inline double FN_FP64(log2)(double x) { return log2(x); } -__device__ inline double FN_FP64(log10)(double x) { return log10(x); } -__device__ inline double FN_FP64(sqrt)(double x) { return sqrt(x); } -__device__ inline double FN_FP64(rsqrt)(double x) { return rsqrt(x); } -__device__ inline double FN_FP64(abs)(double x) { return fabs(x); } -__device__ inline double FN_FP64(floor)(double x) { return floor(x); } -__device__ inline double FN_FP64(ceil)(double x) { return ceil(x); } -__device__ inline double FN_FP64(round)(double x) { return round(x); } -__device__ inline double FN_FP64(trunc)(double x) { return trunc(x); } -__device__ inline double FN_FP64(pow)(double a, double b) { return pow(a, b); } -__device__ inline double FN_FP64(mod)(double a, double b) { return fmod(a, b); } -__device__ inline double FN_FP64(fma)(double a, double b, double c) { return fma(a, b, c); } -__device__ inline bool FN_FP64(isnan)(double x) { return isnan(x); } -__device__ inline bool FN_FP64(isinf)(double x) { return isinf(x); } -__device__ inline bool FN_FP64(isfinite)(double x) { return isfinite(x); } -__device__ inline double FN_FP64(acos)(double x) { return acos(x); } -__device__ inline double FN_FP64(acosh)(double x) { return acosh(x); } -__device__ inline double FN_FP64(asin)(double x) { return asin(x); } -__device__ 
inline double FN_FP64(asinh)(double x) { return asinh(x); } -__device__ inline double FN_FP64(atan)(double x) { return atan(x); } -__device__ inline double FN_FP64(atanh)(double x) { return atanh(x); } -__device__ inline double FN_FP64(cbrt)(double x) { return cbrt(x); } -__device__ inline double FN_FP64(cosh)(double x) { return cosh(x); } -__device__ inline double FN_FP64(erf)(double x) { return erf(x); } -__device__ inline double FN_FP64(log1p)(double x) { return log1p(x); } -__device__ inline double FN_FP64(sigmoid)(double x) { return 1.0 / (1.0 + exp(-x)); } -__device__ inline double FN_FP64(sinh)(double x) { return sinh(x); } -__device__ inline double FN_FP64(tanh)(double x) { return tanh(x); } - -// =============================================================== -// Float32 Math Functions -// =============================================================== -#define FN_FP32(func) cinn_custom_device_##func##_fp32 - -__device__ inline float FN_FP32(sin)(float x) { return sinf(x); } -__device__ inline float FN_FP32(cos)(float x) { return cosf(x); } -__device__ inline float FN_FP32(tan)(float x) { return tanf(x); } -__device__ inline float FN_FP32(exp)(float x) { return expf(x); } -__device__ inline float FN_FP32(log)(float x) { return logf(x); } -__device__ inline float FN_FP32(sqrt)(float x) { return sqrtf(x); } -__device__ inline float FN_FP32(rsqrt)(float x) { return rsqrtf(x); } -__device__ inline float FN_FP32(pow)(float a, float b) { return powf(a, b); } -__device__ inline float FN_FP32(floor)(float x) { return floorf(x); } -__device__ inline float FN_FP32(ceil)(float x) { return ceilf(x); } -__device__ inline float FN_FP32(round)(float x) { return roundf(x); } -__device__ inline float FN_FP32(trunc)(float x) { return truncf(x); } -__device__ inline float FN_FP32(abs)(float x) { return fabsf(x); } -__device__ inline float FN_FP32(mod)(float a, float b) { return fmodf(a, b); } -__device__ inline float FN_FP32(fma)(float a, float b, float c) { return fmaf(a, 
b, c); } -__device__ inline bool FN_FP32(isnan)(float x) { return isnan(x); } -__device__ inline bool FN_FP32(isinf)(float x) { return isinf(x); } -__device__ inline bool FN_FP32(isfinite)(float x) { return isfinite(x); } -__device__ inline float FN_FP32(acos)(float x) { return acosf(x); } -__device__ inline float FN_FP32(acosh)(float x) { return acoshf(x); } -__device__ inline float FN_FP32(asin)(float x) { return asinf(x); } -__device__ inline float FN_FP32(asinh)(float x) { return asinhf(x); } -__device__ inline float FN_FP32(atan)(float x) { return atanf(x); } -__device__ inline float FN_FP32(atanh)(float x) { return atanhf(x); } -__device__ inline float FN_FP32(cbrt)(float x) { return cbrtf(x); } -__device__ inline float FN_FP32(cosh)(float x) { return coshf(x); } -__device__ inline float FN_FP32(erf)(float x) { return erff(x); } -__device__ inline float FN_FP32(log2)(float x) { return log2f(x); } -__device__ inline float FN_FP32(log10)(float x) { return log10f(x); } -__device__ inline float FN_FP32(log1p)(float x) { return log1pf(x); } -__device__ inline float FN_FP32(sigmoid)(float x) { return 1.0f / (1.0f + expf(-x)); } -__device__ inline float FN_FP32(sinh)(float x) { return sinhf(x); } -__device__ inline float FN_FP32(tanh)(float x) { return tanhf(x); } -__device__ inline float FN_FP32(left_shift)(float a, float b) { - return (float)((int)a << (int)b); -} -__device__ inline float FN_FP32(right_shift)(float a, float b) { - return (float)((int)a >> (int)b); -} - -// =============================================================== -// Int32 Functions -// =============================================================== -#define FN_INT32(func) cinn_custom_device_##func##_int32 -__device__ inline int FN_INT32(bitwise_not)(int a) { return ~a; } -__device__ inline int FN_INT32(clz)(int a) { return __clz(a); } -__device__ inline int FN_INT32(popc)(int a) { return __popc(a); } -__device__ inline int FN_INT32(mod)(int a, int b) { - int res = a % b; - if ((res != 0) && 
((b ^ res) < 0)) res += b; - return res; -} -__device__ inline int FN_INT32(max)(int a, int b) { return cinn_max(a, b); } -__device__ inline int FN_INT32(min)(int a, int b) { return cinn_min(a, b); } -__device__ inline int FN_INT32(left_shift)(int a, int b) { return a << b; } -__device__ inline int FN_INT32(right_shift)(int a, int b) { return a >> b; } -__device__ inline int FN_INT32(bitwise_and)(int a, int b) { return a & b; } -__device__ inline int FN_INT32(bitwise_or)(int a, int b) { return a | b; } -__device__ inline int FN_INT32(bitwise_xor)(int a, int b) { return a ^ b; } -__device__ inline int FN_INT32(logical_right_shift)(int a, int b) { return (unsigned int)a >> b; } -__device__ inline int FN_INT32(trunc)(int a) { return a; } -__device__ inline int FN_INT32(pow)(int a, int b) { - if (a == 0 && b < 0) { - return 0; - } - float res = pow(__int2float_rd(a), __int2float_rd(b)); - return __float2int_rn(res); -} -__device__ inline int FN_INT32(arithmetic_right_shift)(int a, int b) { return a >> b; } - -// =============================================================== -// Int64 Functions -// =============================================================== -#define FN_INT64(func) cinn_custom_device_##func##_int64 -__device__ inline int64_t FN_INT64(bitwise_and)(int64_t a, int64_t b) { return a & b; } -__device__ inline int64_t FN_INT64(bitwise_or)(int64_t a, int64_t b) { return a | b; } -__device__ inline int64_t FN_INT64(bitwise_xor)(int64_t a, int64_t b) { return a ^ b; } -__device__ inline int64_t FN_INT64(bitwise_not)(int64_t a) { return ~a; } -__device__ inline int64_t FN_INT64(clz)(int64_t a) { return __clzll(a); } -__device__ inline int64_t FN_INT64(popc)(int64_t a) { return __popcll(a); } -__device__ inline int64_t FN_INT64(logical_right_shift)(int64_t a, int64_t b) { return ((uint64_t)a >> b); } -__device__ inline int64_t FN_INT64(trunc)(int64_t a) { return a; } -__device__ inline int64_t FN_INT64(mod)(int64_t a, int64_t b) { int64_t res = a % b; if ((res 
!= 0) && ((b ^ res) < 0)) res += b; return res; } -__device__ inline int64_t FN_INT64(pow)(int64_t a, int64_t b) { double res = pow(__ll2double_rd(a), __ll2double_rd(b)); return __double2ll_rn(res); } - -// =============================================================== -// Float16 (Half) Functions -// =============================================================== -#define FN_FP16(func) cinn_custom_device_##func##_fp16 - -#define FN_FP16(func) cinn_custom_device_##func##_fp16 -__device__ inline float16 FN_FP16(ceil)(float16 x) { return hceil(x); } -__device__ inline float16 FN_FP16(floor)(float16 x) { return hfloor(x); } -__device__ inline float16 FN_FP16(round)(float16 x) { return __float2half(roundf(__half2float(x))); } -__device__ inline float16 FN_FP16(trunc)(float16 x) { return htrunc(x); } -__device__ inline float16 FN_FP16(sin)(float16 x) { return hsin(x); } -__device__ inline float16 FN_FP16(cos)(float16 x) { return hcos(x); } -__device__ inline float16 FN_FP16(exp)(float16 x) { return hexp(x); } -__device__ inline float16 FN_FP16(log)(float16 x) { return hlog(x); } -__device__ inline float16 FN_FP16(log2)(float16 x) { return hlog2(x); } -__device__ inline float16 FN_FP16(log10)(float16 x) { return hlog10(x); } -__device__ inline float16 FN_FP16(sqrt)(float16 x) { return hsqrt(x); } -__device__ inline float16 FN_FP16(rsqrt)(float16 x) { return hrsqrt(x); } -__device__ inline float16 FN_FP16(cbrt)(float16 x) { return __float2half(cbrtf(__half2float(x))); } -__device__ inline float16 FN_FP16(abs)(float16 x) { return __float2half(fabsf(__half2float(x))); } -__device__ inline bool FN_FP16(isnan)(float16 x) { return __hisnan(x); } -__device__ inline bool FN_FP16(isinf)(float16 x) { return __hisinf(x); } -__device__ inline bool FN_FP16(isfinite)(float16 x) { return !__hisinf(x) && !__hisnan(x); } -__device__ inline float16 FN_FP16(erf)(float16 x) { return __float2half(erff(__half2float(x))); } -__device__ inline float16 FN_FP16(tan)(float16 x) { return 
__float2half(tanf(__half2float(x))); } -__device__ inline float16 FN_FP16(sinh)(float16 x) { return __float2half(sinhf(__half2float(x))); } -__device__ inline float16 FN_FP16(cosh)(float16 x) { return __float2half(coshf(__half2float(x))); } -__device__ inline float16 FN_FP16(tanh)(float16 x) { return __float2half(tanhf(__half2float(x))); } -__device__ inline float16 FN_FP16(asin)(float16 x) { return __float2half(asinf(__half2float(x))); } -__device__ inline float16 FN_FP16(acos)(float16 x) { return __float2half(acosf(__half2float(x))); } -__device__ inline float16 FN_FP16(atan)(float16 x) { return __float2half(atanf(__half2float(x))); } -__device__ inline float16 FN_FP16(asinh)(float16 x) { return __float2half(asinhf(__half2float(x))); } -__device__ inline float16 FN_FP16(acosh)(float16 x) { return __float2half(acoshf(__half2float(x))); } -__device__ inline float16 FN_FP16(atanh)(float16 x) { return __float2half(atanhf(__half2float(x))); } -__device__ inline float16 FN_FP16(sigmoid)(float16 x) { return __float2half(1.0f / (1.0f + expf(-__half2float(x)))); } -__device__ inline float16 FN_FP16(mod)(float16 a, float16 b) { return __float2half(fmodf(__half2float(a), __half2float(b))); } -__device__ inline float16 FN_FP16(pow)(float16 a, float16 b) { return __float2half(powf(__half2float(a), __half2float(b))); } -__device__ inline float16 FN_FP16(add)(float16 a, float16 b) { return __hadd(a, b); } -__device__ inline float16 FN_FP16(sub)(float16 a, float16 b) { return __hsub(a, b); } -__device__ inline float16 FN_FP16(mul)(float16 a, float16 b) { return __hmul(a, b); } -__device__ inline float16 FN_FP16(div)(float16 a, float16 b) { return __hdiv(a, b); } -__device__ inline float16 FN_FP16(neg)(float16 a) { return __hneg(a); } -__device__ inline float16 FN_FP16(fma)(float16 a, float16 b, float16 c) { return __hfma(a, b, c); } -__device__ inline float16 FN_FP16(max)(float16 a, float16 b) { return __hgt(a, b) ? 
a : b; } -__device__ inline float16 FN_FP16(min)(float16 a, float16 b) { return __hlt(a, b) ? a : b; } - -// =============================================================== -// Warp Shuffle Functions (用于 Reduce 算子) -// =============================================================== -#define FN_SHUFFLE(func) cinn_custom_device_##func - -__device__ inline float FN_SHUFFLE(warp_shuffle_xor_fp32)(float v, int factor) { - return __shfl_xor_sync(0xffffffff, v, factor); -} -__device__ inline float FN_SHUFFLE(warp_shuffle_up_fp32)(float v, int factor) { - return __shfl_up_sync(0xffffffff, v, factor); -} -__device__ inline float FN_SHUFFLE(warp_shuffle_down_fp32)(float v, int factor) { - return __shfl_down_sync(0xffffffff, v, factor); -} - -__device__ inline int FN_SHUFFLE(warp_shuffle_xor_int32)(int v, int factor) { - return __shfl_xor_sync(0xffffffff, v, factor); -} -__device__ inline int FN_SHUFFLE(warp_shuffle_up_int32)(int v, int factor) { - return __shfl_up_sync(0xffffffff, v, factor); -} -__device__ inline int FN_SHUFFLE(warp_shuffle_down_int32)(int v, int factor) { - return __shfl_down_sync(0xffffffff, v, factor); -} - -// MACA/CUDA 的 shfl 指令通常只支持 32位,__half 需要强转或使用 intrinsics -__device__ inline __half FN_SHUFFLE(warp_shuffle_xor_fp16)(__half v, int factor) { - unsigned short val = __half_as_ushort(v); - unsigned short res = (unsigned short)__shfl_xor_sync(0xffffffff, (int)val, factor); - return __ushort_as_half(res); -} -__device__ inline __half FN_SHUFFLE(warp_shuffle_up_fp16)(__half v, int factor) { - unsigned short val = __half_as_ushort(v); - unsigned short res = (unsigned short)__shfl_up_sync(0xffffffff, (int)val, factor); - return __ushort_as_half(res); -} -__device__ inline __half FN_SHUFFLE(warp_shuffle_down_fp16)(__half v, int factor) { - unsigned short val = __half_as_ushort(v); - unsigned short res = (unsigned short)__shfl_down_sync(0xffffffff, (int)val, factor); - return __ushort_as_half(res); -} // 
=============================================================== // 7. Index Operations: Find, Sort & Resize Helpers From 25e963c10dd06483adc4959e8d0ddd5b0b3a1453 Mon Sep 17 00:00:00 2001 From: YuhanXu Date: Fri, 27 Feb 2026 15:13:03 +0800 Subject: [PATCH 10/17] Fix warp reduce, block reduce for metax_gpu warp_size=64. --- backends/metax_gpu/cinn/compiler/compiler.cc | 119 ++++++++++--------- 1 file changed, 66 insertions(+), 53 deletions(-) diff --git a/backends/metax_gpu/cinn/compiler/compiler.cc b/backends/metax_gpu/cinn/compiler/compiler.cc index 264923b1af..9f7772a76d 100644 --- a/backends/metax_gpu/cinn/compiler/compiler.cc +++ b/backends/metax_gpu/cinn/compiler/compiler.cc @@ -294,41 +294,40 @@ __device__ inline float16 FN_FP16(min)(float16 a, float16 b) { return __hlt(a, b // Warp Shuffle Functions (用于 Reduce 算子) // =============================================================== #define FN_SHUFFLE(func) cinn_custom_device_##func - __device__ inline float FN_SHUFFLE(warp_shuffle_xor_fp32)(float v, int factor) { - return __shfl_xor_sync(0xffffffff, v, factor); + return __shfl_xor(v, factor); } __device__ inline float FN_SHUFFLE(warp_shuffle_up_fp32)(float v, int factor) { - return __shfl_up_sync(0xffffffff, v, factor); + return __shfl_up(v, factor); } __device__ inline float FN_SHUFFLE(warp_shuffle_down_fp32)(float v, int factor) { - return __shfl_down_sync(0xffffffff, v, factor); + return __shfl_down(v, factor); } __device__ inline int FN_SHUFFLE(warp_shuffle_xor_int32)(int v, int factor) { - return __shfl_xor_sync(0xffffffff, v, factor); + return __shfl_xor(v, factor); } __device__ inline int FN_SHUFFLE(warp_shuffle_up_int32)(int v, int factor) { - return __shfl_up_sync(0xffffffff, v, factor); + return __shfl_up(v, factor); } __device__ inline int FN_SHUFFLE(warp_shuffle_down_int32)(int v, int factor) { - return __shfl_down_sync(0xffffffff, v, factor); + return __shfl_down(v, factor); } // MACA/CUDA 的 shfl 指令通常只支持 32位,__half 需要强转或使用 intrinsics __device__ 
inline __half FN_SHUFFLE(warp_shuffle_xor_fp16)(__half v, int factor) { unsigned short val = __half_as_ushort(v); - unsigned short res = (unsigned short)__shfl_xor_sync(0xffffffff, (int)val, factor); + unsigned short res = (unsigned short)__shfl_xor((int)val, factor); return __ushort_as_half(res); } __device__ inline __half FN_SHUFFLE(warp_shuffle_up_fp16)(__half v, int factor) { unsigned short val = __half_as_ushort(v); - unsigned short res = (unsigned short)__shfl_up_sync(0xffffffff, (int)val, factor); + unsigned short res = (unsigned short)__shfl_up((int)val, factor); return __ushort_as_half(res); } __device__ inline __half FN_SHUFFLE(warp_shuffle_down_fp16)(__half v, int factor) { unsigned short val = __half_as_ushort(v); - unsigned short res = (unsigned short)__shfl_down_sync(0xffffffff, (int)val, factor); + unsigned short res = (unsigned short)__shfl_down((int)val, factor); return __ushort_as_half(res); } } // extern "C" @@ -510,13 +509,16 @@ __device__ inline float16 cinn_min_fp16(const float16 left, const float16 right) const DTYPE value) { \ DTYPE tmp_val = value; \ unsigned int thread_id = threadIdx.x; \ + unsigned int lane_id = thread_id % WARP_SIZE; /* 获取在当前 Warp 内的局部 ID */ \ unsigned int block_dim = blockDim.x; \ /* 始终使用 Down Shuffle 进行规约 (Log2 复杂度) */ \ for (unsigned int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) { \ DTYPE shfl_res = cinn_warp_shuffle_down_##DTYPE##_wrapper(tmp_val, offset); \ /* 检查数据来源是否有效:当前线程+offset 必须还在 Block 范围内 */ \ /* 如果 Block 大小不是 WARP_SIZE 的倍数,这一步至关重要 */ \ - DTYPE neighbor = (thread_id + offset < block_dim) ? shfl_res : (DTYPE)(INIT_VAL); \ + /* 【核心修复】不仅不能超出 block,且目标 Lane 也不能超出 WARP_SIZE */ \ + bool is_valid = (lane_id + offset < WARP_SIZE) && (thread_id + offset < block_dim); \ + DTYPE neighbor = is_valid ? 
shfl_res : (DTYPE)(INIT_VAL); \ tmp_val = cinn_##REDUCE_TYPE(tmp_val, neighbor); \ } \ /* 广播:虽然 Down Shuffle 只有 Lane 0 结果正确,但这里为了兼容 XOR 语义 */ \ @@ -532,65 +534,71 @@ __device__ inline bool cinn_warp_shuffle_down_bool_wrapper(bool v, int factor) { __device__ inline double cinn_warp_shuffle_down_double_wrapper(double v, int factor) { unsigned long long int val_u64 = *(unsigned long long int*)&v; - int lo = (int)val_u64; int hi = (int)(val_u64 >> 32); - lo = __shfl_down(lo, factor); - hi = __shfl_down(hi, factor); - unsigned long long int res_u64 = ((unsigned long long int)hi << 32) | (unsigned int)lo; - return *(double*)&res_u64; + int lo = __shfl_down((int)val_u64, factor); + int hi = __shfl_down((int)(val_u64 >> 32), factor); + unsigned long long int res = ((unsigned long long int)hi << 32) | (unsigned int)lo; + return *(double*)&res; } __device__ inline int64_t cinn_warp_shuffle_down_int64_t_wrapper(int64_t v, int factor) { - int lo = (int)v; int hi = (int)(v >> 32); - lo = __shfl_down(lo, factor); - hi = __shfl_down(hi, factor); + int lo = __shfl_down((int)v, factor); + int hi = __shfl_down((int)(v >> 32), factor); return ((int64_t)hi << 32) | (unsigned int)lo; } __device__ inline float16 cinn_warp_shuffle_down_float16_wrapper(float16 v, int factor) { unsigned short val = __half_as_ushort(v); - unsigned short res = (unsigned short)__shfl_down((int)val, factor); - return __ushort_as_half(res); + return __ushort_as_half((unsigned short)__shfl_down((int)val, factor)); } __device__ inline welford_fp32 cinn_warp_shuffle_down_welford_fp32_wrapper(welford_fp32 v, int factor) { - return __shfl_down_sync(0xffffffff, v, factor); + float m = __shfl_down(v.mean, factor); + float m2 = __shfl_down(v.m2, factor); + float w = __shfl_down(v.weight, factor); + return welford_fp32(m, m2, w); } __device__ inline welford_fp64 cinn_warp_shuffle_down_welford_fp64_wrapper(welford_fp64 v, int factor) { - return __shfl_down_sync(0xffffffff, v, factor); + double m = __shfl_down(v.mean, 
factor); + double m2 = __shfl_down(v.m2, factor); + double w = __shfl_down(v.weight, factor); + return welford_fp64(m, m2, w); } // 广播类型的 Idx 包装函数 (最后返回阶段使用 shfl_sync(var, 0)) -__device__ inline float cinn_warp_shuffle_idx_float_wrapper(float v, int lane) { return __shfl_sync(0xffffffff, v, lane); } -__device__ inline int cinn_warp_shuffle_idx_int_wrapper(int v, int lane) { return __shfl_sync(0xffffffff, v, lane); } -__device__ inline bool cinn_warp_shuffle_idx_bool_wrapper(bool v, int lane) { return __shfl_sync(0xffffffff, v, lane); } +__device__ inline float cinn_warp_shuffle_idx_float_wrapper(float v, int lane) { return __shfl(v, lane); } +__device__ inline int cinn_warp_shuffle_idx_int_wrapper(int v, int lane) { return __shfl(v, lane); } +__device__ inline bool cinn_warp_shuffle_idx_bool_wrapper(bool v, int lane) { return __shfl(v, lane); } + __device__ inline float16 cinn_warp_shuffle_idx_float16_wrapper(float16 v, int lane) { unsigned short val = __half_as_ushort(v); - return __ushort_as_half((unsigned short)__shfl_sync(0xffffffff, (int)val, lane)); + return __ushort_as_half((unsigned short)__shfl((int)val, lane)); } + __device__ inline double cinn_warp_shuffle_idx_double_wrapper(double v, int lane) { unsigned long long int val_u64 = *(unsigned long long int*)&v; - int lo = __shfl_sync(0xffffffff, (int)val_u64, lane); - int hi = __shfl_sync(0xffffffff, (int)(val_u64 >> 32), lane); + int lo = __shfl((int)val_u64, lane); + int hi = __shfl((int)(val_u64 >> 32), lane); unsigned long long int res = ((unsigned long long int)hi << 32) | (unsigned int)lo; return *(double*)&res; } + __device__ inline int64_t cinn_warp_shuffle_idx_int64_t_wrapper(int64_t v, int lane) { - int lo = __shfl_sync(0xffffffff, (int)v, lane); - int hi = __shfl_sync(0xffffffff, (int)(v >> 32), lane); + int lo = __shfl((int)v, lane); + int hi = __shfl((int)(v >> 32), lane); return ((int64_t)hi << 32) | (unsigned int)lo; } // === 新增:Welford 的 Idx (广播) 包装函数 === __device__ inline welford_fp32 
cinn_warp_shuffle_idx_welford_fp32_wrapper(welford_fp32 v, int lane) { - float m = __shfl_sync(0xffffffff, v.mean, lane); - float m2 = __shfl_sync(0xffffffff, v.m2, lane); - float w = __shfl_sync(0xffffffff, v.weight, lane); + float m = __shfl(v.mean, lane); + float m2 = __shfl(v.m2, lane); + float w = __shfl(v.weight, lane); return welford_fp32(m, m2, w); } __device__ inline welford_fp64 cinn_warp_shuffle_idx_welford_fp64_wrapper(welford_fp64 v, int lane) { - double m = __shfl_sync(0xffffffff, v.mean, lane); - double m2 = __shfl_sync(0xffffffff, v.m2, lane); - double w = __shfl_sync(0xffffffff, v.weight, lane); + double m = __shfl(v.mean, lane); + double m2 = __shfl(v.m2, lane); + double w = __shfl(v.weight, lane); return welford_fp64(m, m2, w); } @@ -606,39 +614,44 @@ EXPAND_REDUCE_FP16_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL) // 5. Block Reduce & Discrete Reduce & Grid Reduce // =============================================================== -// Block Reduce Implementation -// 1. Warp Reduce -> SHM -// 2. Warp 0 reads SHM and Pads with Identity -// 3. Warp 0 Reduce -// 4. Broadcast #define CINN_BLOCK_REDUCE_IMPL(DTYPE, INIT_VAL, cinn_warp_shuffle_internal) \ - /* 1. Warp Reduce */ \ + /* 1. 单个 Warp 内部规约 */ \ DTYPE tmp_val = cinn_warp_shuffle_internal(value); \ if (return_warp || blockDim.x <= WARP_SIZE) { \ return tmp_val; \ } \ __syncthreads(); \ - /* 2. Write Warp results to SHM (Lane 0 only) */ \ + \ + /* 【核心修复】:计算 2D/3D 线程块的专属共享显存偏移量 */ \ + /* row_id 代表当前线程属于哪一个独立的空间行 */ \ + int row_id = threadIdx.y + threadIdx.z * blockDim.y; \ + int warps_per_row = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; \ + /* row_shm 是当前行专属的共享显存指针,彻底杜绝越行踩踏 */ \ + DTYPE* row_shm = shm + (row_id * warps_per_row); \ + \ + /* 2. 每个 Warp 的 0 号线程把结果写入自己行的专属 SHM */ \ if (threadIdx.x % WARP_SIZE == 0) { \ - shm[threadIdx.x / WARP_SIZE] = tmp_val; \ + row_shm[threadIdx.x / WARP_SIZE] = tmp_val; \ } \ __syncthreads(); \ - /* 3. Inter-Warp Reduce (Warp 0 only) */ \ + \ + /* 3. 
跨 Warp 规约合并 (仅限每个行的前 WARP_SIZE 个线程执行) */ \ if (threadIdx.x < WARP_SIZE) { \ - int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; \ - /* Pad with Identity value for idle threads in Warp 0 */ \ + /* 闲置线程用初始值 (比如 0) 填充 */ \ DTYPE reduce_val = (DTYPE)(INIT_VAL); \ - if (threadIdx.x < num_warps) { \ - reduce_val = shm[threadIdx.x]; \ + if (threadIdx.x < warps_per_row) { \ + reduce_val = row_shm[threadIdx.x]; \ } \ - /* Reduce across all threads in Warp 0 */ \ + /* 在 Warp 0 内部完成最终规约 */ \ reduce_val = cinn_warp_shuffle_internal(reduce_val); \ + /* 写入最终结果到当前行的头部 */ \ if (threadIdx.x == 0) { \ - shm[0] = reduce_val; \ + row_shm[0] = reduce_val; \ } \ } \ __syncthreads(); \ - return shm[0]; + /* 4. 同一行的所有线程都返回正确的最终结果 */ \ + return row_shm[0]; #define CINN_BLOCK_REDUCE_MACRO(REDUCE_TYPE, INIT_VAL, DTYPE) \ __device__ inline DTYPE cinn_block_reduce_##REDUCE_TYPE( \ From e664242a4a81332586ef4e171b26c8d916f4ef9b Mon Sep 17 00:00:00 2001 From: YuhanXu Date: Mon, 2 Mar 2026 17:30:17 +0800 Subject: [PATCH 11/17] Update GetMaxSharedMemPerBlock. --- backends/metax_gpu/runtime/runtime.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/metax_gpu/runtime/runtime.cc b/backends/metax_gpu/runtime/runtime.cc index 6f295a2019..0500a5a332 100644 --- a/backends/metax_gpu/runtime/runtime.cc +++ b/backends/metax_gpu/runtime/runtime.cc @@ -411,7 +411,7 @@ C_Status GetMaxSharedMemPerBlock(const C_Device device, int count = 0; cudaError_t status = cudaDeviceGetAttribute(&count, cudaDevAttrMaxSharedMemoryPerBlock, id); - *shared_mem_per_block = 5000; + *shared_mem_per_block = 65534; return C_SUCCESS; } From 43c2503fed092a5154dc91b287688830f20bec6b Mon Sep 17 00:00:00 2001 From: YuhanXu Date: Thu, 5 Mar 2026 19:03:38 +0800 Subject: [PATCH 12/17] Add compiler.cc int64 int32 abs. int64 vs double abs. 
--- backends/metax_gpu/cinn/compiler/compiler.cc | 25 +++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/backends/metax_gpu/cinn/compiler/compiler.cc b/backends/metax_gpu/cinn/compiler/compiler.cc index 9f7772a76d..2dda3fd285 100644 --- a/backends/metax_gpu/cinn/compiler/compiler.cc +++ b/backends/metax_gpu/cinn/compiler/compiler.cc @@ -213,6 +213,7 @@ __device__ inline int FN_INT32(mod)(int a, int b) { } __device__ inline int FN_INT32(max)(int a, int b) { return cinn_max(a, b); } __device__ inline int FN_INT32(min)(int a, int b) { return cinn_min(a, b); } +__device__ inline int FN_INT32(abs)(int x) { return abs(x); } __device__ inline int FN_INT32(left_shift)(int a, int b) { return a << b; } __device__ inline int FN_INT32(right_shift)(int a, int b) { return a >> b; } __device__ inline int FN_INT32(bitwise_and)(int a, int b) { return a & b; } @@ -239,6 +240,7 @@ __device__ inline int64_t FN_INT64(bitwise_xor)(int64_t a, int64_t b) { return a __device__ inline int64_t FN_INT64(bitwise_not)(int64_t a) { return ~a; } __device__ inline int64_t FN_INT64(clz)(int64_t a) { return __clzll(a); } __device__ inline int64_t FN_INT64(popc)(int64_t a) { return __popcll(a); } +__device__ inline int64_t FN_INT64(abs)(int64_t x) { return llabs(x); } __device__ inline int64_t FN_INT64(logical_right_shift)(int64_t a, int64_t b) { return ((uint64_t)a >> b); } __device__ inline int64_t FN_INT64(trunc)(int64_t a) { return a; } __device__ inline int64_t FN_INT64(mod)(int64_t a, int64_t b) { int64_t res = a % b; if ((res != 0) && ((b ^ res) < 0)) res += b; return res; } @@ -997,7 +999,28 @@ ARGIDX_STRUCT_MACRO(argidx_fp32_i32, float, int, 0) ARGIDX_STRUCT_MACRO(argidx_i32_i32, int, int, 0) // 手写 std::max 重载 -namespace std { +namespace std { + // --- 之前加的 long long / int64_t 补丁保持不变 --- + __device__ __forceinline__ int64_t max(long long a, int64_t b) { return a > b ? a : b; } + __device__ __forceinline__ int64_t max(int64_t a, long long b) { return a > b ? 
a : b; } + __device__ __forceinline__ int64_t min(long long a, int64_t b) { return a < b ? a : b; } + __device__ __forceinline__ int64_t min(int64_t a, long long b) { return a < b ? a : b; } + + // ============================================================== + // 【新增防弹补丁】:解决 CINN 漏打 'f' 后缀导致的 float 和 double 混合报错 + // ============================================================== +__device__ __forceinline__ double max(float a, double b) { return a > b ? (double)a : b; } + __device__ __forceinline__ double max(double a, float b) { return a > b ? a : (double)b; } + __device__ __forceinline__ double min(float a, double b) { return a < b ? (double)a : b; } + __device__ __forceinline__ double min(double a, float b) { return a < b ? a : (double)b; } + + // 以防万一,解决 CINN 把 0 打印成 int 与 float 混合的报错 (如 std::max(val, 0)) + __device__ __forceinline__ float max(float a, int b) { return a > b ? a : (float)b; } + __device__ __forceinline__ float max(int a, float b) { return a > b ? (float)a : b; } + __device__ __forceinline__ float min(float a, int b) { return a < b ? a : (float)b; } + __device__ __forceinline__ float min(int a, float b) { return a < b ? (float)a : b; } + // ============================================================== + // ArgMax 实现 template __device__ __forceinline__ T max_argidx_impl(const T& a, const T& b) { From 878c9ef338519533ad58a70019ab2bdc473c56de Mon Sep 17 00:00:00 2001 From: Yuhan Xu Date: Tue, 31 Mar 2026 08:11:34 +0000 Subject: [PATCH 13/17] Translate annotation to English. 
Fix undeclared identifier 'cinn_discrete_reduce_max_argidx_fp32_i32' --- backends/metax_gpu/CMakeLists.txt | 5 - backends/metax_gpu/cinn/CMakeLists.txt | 50 +++-- backends/metax_gpu/cinn/cinn_interface.cc | 49 +++-- backends/metax_gpu/cinn/cinn_interface.h | 24 +-- backends/metax_gpu/cinn/compiler/compiler.cc | 183 ++++++++++-------- .../metax_gpu/cinn/passes/pass_manager.cc | 20 +- .../metax_gpu/cinn/runtime/cinn_runtime.cc | 32 +-- .../kernels/impl/conv_grad_kernel_impl.h | 10 +- .../metax_gpu/kernels/impl/conv_kernel_impl.h | 2 +- .../unittest/test_elementwise_pow_op_metax.py | 78 ++------ 10 files changed, 221 insertions(+), 232 deletions(-) diff --git a/backends/metax_gpu/CMakeLists.txt b/backends/metax_gpu/CMakeLists.txt index 23a934cbf0..d95a982000 100755 --- a/backends/metax_gpu/CMakeLists.txt +++ b/backends/metax_gpu/CMakeLists.txt @@ -43,7 +43,6 @@ if(WITH_ARM) endif() include(paddle) -# 【修改点 1】: 添加 CINN 子目录编译 if(WITH_CINN) message(STATUS "[MetaX] CINN enabled, adding subdirectory: cinn") add_definitions(-DWITH_CINN) @@ -802,8 +801,6 @@ set(CMAKE_CUCC_FLAGS "-I ${MACA_PATH}/tools/cu-bridge/include/") add_library(${TARGET_NAME} SHARED ${CUSTOM_DEVICE_SRCS}) -# 【修改点 2】: 添加 CINN 接口的头文件搜索路径 -# 这样 runtime/runtime.cc 里的 #include "../cinn/cinn_interface.h" 才能生效 if(WITH_CINN) target_include_directories(${TARGET_NAME} PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/cinn" @@ -839,8 +836,6 @@ target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmccl.so) target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmcFlashAttn.so) target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmcpti.so) -# 【修改点 3】: 将 CINN 编译出的对象文件链接进最终的 .so -# 只有这样,Plugin 加载时才能找到 InitCinnInterface 等符号 if(WITH_CINN) message(STATUS "[MetaX] Linking CINN object library") target_link_libraries(${TARGET_NAME} $) diff --git a/backends/metax_gpu/cinn/CMakeLists.txt b/backends/metax_gpu/cinn/CMakeLists.txt index 243599d490..ea35bea8e3 100644 --- a/backends/metax_gpu/cinn/CMakeLists.txt +++ 
b/backends/metax_gpu/cinn/CMakeLists.txt @@ -2,47 +2,45 @@ # CINN Plugin for MetaX (MACA) Backend # ============================================================================= -# 1. 查找 MACA 路径 -# 为了在 runtime/cinn_runtime.cc 或 compiler.cc 中能 #include -# 我们需要把沐曦 SDK 的头文件路径加进来 +# 1. Locate MACA SDK path +# To allow #include in runtime/cinn_runtime.cc or compiler.cc, +# we need to add the MetaX SDK header search path. set(MACA_PATH $ENV{MACA_PATH}) if(NOT MACA_PATH) - set(MACA_PATH "/opt/maca") # 默认回退路径 + set(MACA_PATH "/opt/maca") # Default fallback path message(STATUS "[MetaX CINN] MACA_PATH not set, using default: ${MACA_PATH}") else() message(STATUS "[MetaX CINN] Found MACA_PATH: ${MACA_PATH}") endif() -# 2. 定义源文件列表 -# 这里必须包含所有涉及到 CINN 实现的 .cc 文件 +# 2. Define source file list +# All .cc files involved in the CINN implementation must be included here. set(CINN_SRCS - cinn_interface.cc # 总入口,负责 InitCinnInterface - compiler/compiler.cc # 【关键】负责 MetaxCompile 和 MetaxGetRuntimeSource - runtime/cinn_runtime.cc # 负责 MetaxModuleLoad, MetaxLaunchKernel - passes/pass_manager.cc # 负责 MetaxApplyCustomPass + cinn_interface.cc # Main entry point, responsible for InitCinnInterface + compiler/compiler.cc # Implements MetaxCompile and MetaxGetRuntimeSource + runtime/cinn_runtime.cc # Implements MetaxModuleLoad, MetaxLaunchKernel + passes/pass_manager.cc # Implements MetaxApplyCustomPass ) -# 3. 创建 OBJECT 库 -# 使用 OBJECT 模式,只编译出 .o 文件,不生成 .a 或 .so -# 这样上一级的 CMake 可以直接抓取这些 .o 文件链接进最终的 plugin.so +# 3. Create OBJECT library +# Use OBJECT mode to compile into .o files only, without generating .a or .so. +# This allows the parent CMake to directly collect these .o files and link them +# into the final plugin.so. add_library(metax_cinn_obj OBJECT ${CINN_SRCS}) -# 4. 配置头文件搜索路径 +# 4. 
Configure header search paths target_include_directories(metax_cinn_obj PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR} # 允许引用当前目录头文件 (cinn_interface.h) - ${CMAKE_CURRENT_SOURCE_DIR}/../ # 允许引用上层头文件 (如 common/) - ${MACA_PATH}/include # 【关键】允许引用 - ${PADDLE_SOURCE_DIR} # 【新增】必须加这个!否则找不到 paddle/phi/... - # Paddle 的头文件路径通常由外部环境 (Paddle_DIR) 自动包含 + ${CMAKE_CURRENT_SOURCE_DIR} # Allow referencing headers in current directory (cinn_interface.h) + ${CMAKE_CURRENT_SOURCE_DIR}/../ # Allow referencing parent-level headers (e.g., common/) + ${MACA_PATH}/include # Allow referencing + ${PADDLE_SOURCE_DIR} # Allow referencing paddle/phi/... headers + # Paddle header paths are typically auto-included via the external environment (Paddle_DIR) ) -# 5. 编译选项设置 -# CINN 组件通常依赖 C++17 标准 +# 5. Compiler options +# The CINN component typically requires C++17 standard set_property(TARGET metax_cinn_obj PROPERTY CXX_STANDARD 17) -# 开启 PIC (Position Independent Code) -# 因为这些 .o 文件最终要被链接进动态库,必须开启此选项 +# Enable PIC (Position Independent Code) +# Required because these .o files will ultimately be linked into a shared library set_property(TARGET metax_cinn_obj PROPERTY POSITION_INDEPENDENT_CODE ON) - -# 如果 compiler.cc 需要使用 filesystem 等库,可能需要链接 stdc++fs (视 GCC 版本而定) -# 但因为是 OBJECT 库,链接操作推迟到父级进行 \ No newline at end of file diff --git a/backends/metax_gpu/cinn/cinn_interface.cc b/backends/metax_gpu/cinn/cinn_interface.cc index 041b2e3b54..a65ce16832 100644 --- a/backends/metax_gpu/cinn/cinn_interface.cc +++ b/backends/metax_gpu/cinn/cinn_interface.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -21,38 +21,38 @@ namespace custom_device { namespace metax { // ============================================================ -// 外部函数声明 (External Function Declarations) -// 这些函数需要在对应的子目录文件中实现 (.cc) +// External Function Declarations +// These functions must be implemented in the corresponding subdirectory files (.cc). // ============================================================ -// --- 来自 compiler/compiler.cc --- -// 负责调用 mxcc 将 CINN 生成的源代码编译为二进制 +// --- From compiler/compiler.cc --- +// Invokes the mxcc toolchain to compile CINN-generated source code into a binary extern C_Status MetaxCompile(void* dev_ptr, const char* code, char* out_path, size_t len); -// 负责提供沐曦 GPU 运行时的基础源码 (类似 cuda_device_runtime.cu) +// Provides the MetaX GPU device runtime source code extern const char* MetaxGetRuntimeSource(void* dev_ptr); -// --- 来自 runtime/cinn_runtime.cc --- -// 负责加载编译好的二进制模块 (.mx / .so) +// --- From runtime/cinn_runtime.cc --- +// Loads a compiled binary module (.mx / .so) extern C_Status MetaxModuleLoad(void* dev_ptr, const char* path, void** mod_out); -// 负责卸载模块 +// Unloads a module extern C_Status MetaxModuleUnload(void* dev_ptr, void* module_handle); -// 负责从模块中查找核函数地址 +// Retrieves the kernel function address from a loaded module extern C_Status MetaxGetKernelAddress(void* dev_ptr, void* module_handle, const char* func_name, void** func_out); -// 负责启动核函数 (Launch Kernel) +// Launches a kernel function extern C_Status MetaxLaunchKernel(void* dev_ptr, void* func_ptr, void** args, @@ -63,47 +63,46 @@ extern C_Status MetaxLaunchKernel(void* dev_ptr, void* stream); -// --- 来自 passes/pass_manager.cc --- -// 负责应用自定义的图优化 Pass +// --- From passes/pass_manager.cc --- +// Applies custom graph optimization passes extern C_Status MetaxApplyCustomPass(void* dev_ptr, void* ir_module); // ============================================================ -// 接口初始化实现 (Interface Initialization) +// Interface Initialization // 
============================================================ -// 静态实例,确保在插件生命周期内有效 +// Static instance, valid throughout the plugin lifetime static C_CinnInterface metax_cinn_impl; void InitCinnInterface(C_DeviceInterface* device_interface) { - // 1. 安全起见,先清零 + // 1. Zero-initialize for safety std::memset(&metax_cinn_impl, 0, sizeof(C_CinnInterface)); - // 2. 设置结构体大小 (用于版本校验) + // 2. Set struct size (used for version validation) metax_cinn_impl.size = sizeof(C_CinnInterface); - // 3. 设置上下文指针 (可选) - // 如果你的实现需要全局状态,可以指向一个结构体;否则设为 nullptr + // 3. Set context pointer (optional) + // Point to a global state struct if your implementation needs one; otherwise nullptr metax_cinn_impl.dev_ptr = nullptr; - // 4. 挂载 Compiler Toolchain 接口 + // 4. Register Compiler Toolchain interface metax_cinn_impl.compile = MetaxCompile; metax_cinn_impl.get_runtime_source = MetaxGetRuntimeSource; - // 5. 挂载 Runtime Strategy 接口 + // 5. Register Runtime Strategy interface metax_cinn_impl.module_load = MetaxModuleLoad; metax_cinn_impl.module_unload = MetaxModuleUnload; metax_cinn_impl.get_kernel_address = MetaxGetKernelAddress; metax_cinn_impl.launch_kernel = MetaxLaunchKernel; - // 6. 挂载 Compile Strategy 接口 + // 6. Register Compilation Strategy interface metax_cinn_impl.apply_custom_pass = MetaxApplyCustomPass; - // 7. 【关键】将填好的表挂载到 Paddle 主设备接口上 + // 7. Attach the populated dispatch table to the Paddle device interface if (device_interface) { device_interface->cinn_interface = &metax_cinn_impl; - // VLOG(3) << "[MetaX] CINN Interface initialized successfully."; } else { std::cerr << "[MetaX] Error: device_interface is null during CINN init." 
<< std::endl; } @@ -111,4 +110,4 @@ void InitCinnInterface(C_DeviceInterface* device_interface) { } // namespace metax } // namespace custom_device -} // namespace paddle \ No newline at end of file +} // namespace paddle diff --git a/backends/metax_gpu/cinn/cinn_interface.h b/backends/metax_gpu/cinn/cinn_interface.h index 012e02770c..330c172b5e 100644 --- a/backends/metax_gpu/cinn/cinn_interface.h +++ b/backends/metax_gpu/cinn/cinn_interface.h @@ -1,11 +1,11 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -14,7 +14,7 @@ #pragma once -// 引入 Paddle 定义的 C 接口结构体 +// Include the Paddle-defined C interface structures #include "paddle/phi/backends/device_ext.h" namespace paddle { @@ -22,14 +22,16 @@ namespace custom_device { namespace metax { /** - * @brief 初始化 CINN 接口 - * * 这个函数由 runtime.cc 中的 InitPlugin 调用。 - * 它负责将 metax_gpu/cinn 下实现的编译器和运行时函数指针, - * 填充到 device_interface->cinn_interface 中。 - * * @param device_interface Paddle Host 侧传入的设备接口指针 + * @brief Initialize the CINN interface. + * + * This function is called by InitPlugin in runtime.cc. + * It populates device_interface->cinn_interface with the compiler + * and runtime function pointers implemented under metax_gpu/cinn. + * + * @param device_interface The device interface pointer passed from the Paddle host side. 
*/ void InitCinnInterface(C_DeviceInterface* device_interface); } // namespace metax } // namespace custom_device -} // namespace paddle \ No newline at end of file +} // namespace paddle diff --git a/backends/metax_gpu/cinn/compiler/compiler.cc b/backends/metax_gpu/cinn/compiler/compiler.cc index 2dda3fd285..51a04736da 100644 --- a/backends/metax_gpu/cinn/compiler/compiler.cc +++ b/backends/metax_gpu/cinn/compiler/compiler.cc @@ -12,7 +12,7 @@ #include #include -// Host 端头文件,仅供 compiler.cc 使用 +// Host-side header, used only by compiler.cc #include "paddle/phi/backends/device_ext.h" namespace paddle { @@ -20,7 +20,7 @@ namespace custom_device { namespace metax { // ============================================================ -// 1. Runtime Source (JIT 源码头文件 - Device 端代码) +// 1. Runtime Source (JIT Source Header - Device-side Code) // ============================================================ static const char* kMacaRuntimeSource = R"MACA_SOURCE( #pragma once @@ -40,7 +40,7 @@ typedef int int32_t; typedef long long int64_t; #endif -// 兼容 CINN 生成代码中对 __half 的引用 +// Compatible with __half references in CINN-generated code typedef __half float16; #define CINN_UINT8_MIN 0 @@ -293,7 +293,7 @@ __device__ inline float16 FN_FP16(max)(float16 a, float16 b) { return __hgt(a, b __device__ inline float16 FN_FP16(min)(float16 a, float16 b) { return __hlt(a, b) ? 
a : b; } // =============================================================== -// Warp Shuffle Functions (用于 Reduce 算子) +// Warp Shuffle Functions (used by reduce operators) // =============================================================== #define FN_SHUFFLE(func) cinn_custom_device_##func __device__ inline float FN_SHUFFLE(warp_shuffle_xor_fp32)(float v, int factor) { @@ -316,7 +316,7 @@ __device__ inline int FN_SHUFFLE(warp_shuffle_down_int32)(int v, int factor) { return __shfl_down(v, factor); } -// MACA/CUDA 的 shfl 指令通常只支持 32位,__half 需要强转或使用 intrinsics +// MACA/CUDA shfl intrinsics only support 32-bit natively; __half requires bitcast or intrinsics __device__ inline __half FN_SHUFFLE(warp_shuffle_xor_fp16)(__half v, int factor) { unsigned short val = __half_as_ushort(v); unsigned short res = (unsigned short)__shfl_xor((int)val, factor); @@ -436,16 +436,16 @@ __device__ inline bool cinn_all(const bool left, const bool right) { return left __device__ inline bool cinn_any(const bool left, const bool right) { return left || right; } // --- FP16 (Half) --- -// 注意:必须使用 __hadd 等 intrinsics,不能直接用 + +// Note: must use __hadd and similar intrinsics; direct + operator is not supported __device__ inline float16 cinn_sum_fp16(const float16 left, const float16 right) { return __hadd(left, right); } __device__ inline float16 cinn_prod_fp16(const float16 left, const float16 right) { return __hmul(left, right); } __device__ inline float16 cinn_max_fp16(const float16 left, const float16 right) { return __hgt(left, right) ? left : right; } __device__ inline float16 cinn_min_fp16(const float16 left, const float16 right) { return __hlt(left, right) ? 
left : right; } // --- BF16 (BFloat16) --- -// 【注意】如果 mxcc 不支持 __nv_bfloat16,这部分需要注释掉或报错 -#if defined(__MACACC__) || defined(__CUDACC__) // 假设支持 -// 暂时留空,如果报错请注释掉 BF16 部分 +// [Note] If mxcc does not support __nv_bfloat16, this section should be commented out or produce an error +#if defined(__MACACC__) || defined(__CUDACC__) // Assuming support is available +// Placeholder: comment out the BF16 section if compilation errors occur // __device__ inline __nv_bfloat16 cinn_sum_bf16(...) ... #endif @@ -489,7 +489,7 @@ __device__ inline float16 cinn_min_fp16(const float16 left, const float16 right) MACRO(all, true, bool, ##__VA_ARGS__) \ MACRO(any, false, bool, ##__VA_ARGS__) -// FP16 初始值 (使用 hex 转换) +// FP16 initial values (using hex conversion) #define EXPAND_REDUCE_FP16_MACRO(MACRO, ...) \ MACRO(sum_fp16, 0.0, float16, ##__VA_ARGS__) \ MACRO(prod_fp16, 1.0, float16, ##__VA_ARGS__) \ @@ -501,30 +501,20 @@ __device__ inline float16 cinn_min_fp16(const float16 left, const float16 right) // 4. Warp Shuffle Wrappers (Using Legacy API & Full Down Strategy) // =============================================================== -// 【核心修复】Warp Reduce 逻辑重写 -// 1. 弃用 XOR 模式:因为在 64-thread warp 下,跨 32 边界的 XOR 可能存在未定义行为或硬件 bug。 -// 2. 统一使用 DOWN 模式:__shfl_down 是单向规约,Lane 0 总是能收集到数据的,更加稳健。 -// 3. 
严格的边界检查:确保 fetch 的来源线程在 Block 范围内,否则使用 INIT_VAL 填充。 - #define CINN_WARP_SHUFFLE_INTERNAL_IMPL(REDUCE_TYPE, INIT_VAL, DTYPE) \ __device__ inline DTYPE cinn_warp_shuffle_##REDUCE_TYPE##_internal( \ const DTYPE value) { \ DTYPE tmp_val = value; \ unsigned int thread_id = threadIdx.x; \ - unsigned int lane_id = thread_id % WARP_SIZE; /* 获取在当前 Warp 内的局部 ID */ \ + unsigned int lane_id = thread_id % WARP_SIZE; /* Get local lane ID within current warp */ \ unsigned int block_dim = blockDim.x; \ - /* 始终使用 Down Shuffle 进行规约 (Log2 复杂度) */ \ + /* Always use down-shuffle for reduction (O(log N) complexity) */ \ for (unsigned int offset = WARP_SIZE / 2; offset >= 1; offset /= 2) { \ DTYPE shfl_res = cinn_warp_shuffle_down_##DTYPE##_wrapper(tmp_val, offset); \ - /* 检查数据来源是否有效:当前线程+offset 必须还在 Block 范围内 */ \ - /* 如果 Block 大小不是 WARP_SIZE 的倍数,这一步至关重要 */ \ - /* 【核心修复】不仅不能超出 block,且目标 Lane 也不能超出 WARP_SIZE */ \ bool is_valid = (lane_id + offset < WARP_SIZE) && (thread_id + offset < block_dim); \ DTYPE neighbor = is_valid ? 
shfl_res : (DTYPE)(INIT_VAL); \ tmp_val = cinn_##REDUCE_TYPE(tmp_val, neighbor); \ } \ - /* 广播:虽然 Down Shuffle 只有 Lane 0 结果正确,但这里为了兼容 XOR 语义 */ \ - /* 我们用 shfl 0 把 Lane 0 的结果广播给所有人 (CINN Block Reduce 需要) */ \ return cinn_warp_shuffle_idx_##DTYPE##_wrapper(tmp_val, 0); \ } @@ -566,7 +556,7 @@ __device__ inline welford_fp64 cinn_warp_shuffle_down_welford_fp64_wrapper(welfo return welford_fp64(m, m2, w); } -// 广播类型的 Idx 包装函数 (最后返回阶段使用 shfl_sync(var, 0)) +// Broadcast-type idx wrapper functions (used in final return stage via shfl(var, 0)) __device__ inline float cinn_warp_shuffle_idx_float_wrapper(float v, int lane) { return __shfl(v, lane); } __device__ inline int cinn_warp_shuffle_idx_int_wrapper(int v, int lane) { return __shfl(v, lane); } __device__ inline bool cinn_warp_shuffle_idx_bool_wrapper(bool v, int lane) { return __shfl(v, lane); } @@ -590,7 +580,7 @@ __device__ inline int64_t cinn_warp_shuffle_idx_int64_t_wrapper(int64_t v, int l return ((int64_t)hi << 32) | (unsigned int)lo; } -// === 新增:Welford 的 Idx (广播) 包装函数 === +// === Welford idx (broadcast) wrapper functions === __device__ inline welford_fp32 cinn_warp_shuffle_idx_welford_fp32_wrapper(welford_fp32 v, int lane) { float m = __shfl(v.mean, lane); float m2 = __shfl(v.m2, lane); @@ -617,42 +607,42 @@ EXPAND_REDUCE_FP16_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL) // =============================================================== #define CINN_BLOCK_REDUCE_IMPL(DTYPE, INIT_VAL, cinn_warp_shuffle_internal) \ - /* 1. 单个 Warp 内部规约 */ \ + /* 1. 
Intra-warp reduction */ \ DTYPE tmp_val = cinn_warp_shuffle_internal(value); \ if (return_warp || blockDim.x <= WARP_SIZE) { \ return tmp_val; \ } \ __syncthreads(); \ \ - /* 【核心修复】:计算 2D/3D 线程块的专属共享显存偏移量 */ \ - /* row_id 代表当前线程属于哪一个独立的空间行 */ \ + /* Compute per-row shared memory offset for 2D/3D thread blocks */ \ + /* row_id identifies which independent spatial row the current thread belongs to */ \ int row_id = threadIdx.y + threadIdx.z * blockDim.y; \ int warps_per_row = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; \ - /* row_shm 是当前行专属的共享显存指针,彻底杜绝越行踩踏 */ \ + /* row_shm is the per-row shared memory pointer, preventing cross-row data corruption */ \ DTYPE* row_shm = shm + (row_id * warps_per_row); \ \ - /* 2. 每个 Warp 的 0 号线程把结果写入自己行的专属 SHM */ \ + /* 2. Lane 0 of each warp writes its result to its row's dedicated shared memory slot */ \ if (threadIdx.x % WARP_SIZE == 0) { \ row_shm[threadIdx.x / WARP_SIZE] = tmp_val; \ } \ __syncthreads(); \ \ - /* 3. 跨 Warp 规约合并 (仅限每个行的前 WARP_SIZE 个线程执行) */ \ + /* 3. Cross-warp reduction (only the first WARP_SIZE threads per row participate) */ \ if (threadIdx.x < WARP_SIZE) { \ - /* 闲置线程用初始值 (比如 0) 填充 */ \ + /* Idle threads are filled with the identity value */ \ DTYPE reduce_val = (DTYPE)(INIT_VAL); \ if (threadIdx.x < warps_per_row) { \ reduce_val = row_shm[threadIdx.x]; \ } \ - /* 在 Warp 0 内部完成最终规约 */ \ + /* Perform final reduction within warp 0 */ \ reduce_val = cinn_warp_shuffle_internal(reduce_val); \ - /* 写入最终结果到当前行的头部 */ \ + /* Write final result to the head of the current row */ \ if (threadIdx.x == 0) { \ row_shm[0] = reduce_val; \ } \ } \ __syncthreads(); \ - /* 4. 同一行的所有线程都返回正确的最终结果 */ \ + /* 4. 
All threads in the same row return the correct final result */ \ return row_shm[0]; #define CINN_BLOCK_REDUCE_MACRO(REDUCE_TYPE, INIT_VAL, DTYPE) \ @@ -694,6 +684,27 @@ EXPAND_REDUCE_FP64_MACRO(CINN_DISCRETE_REDUCE_MACRO) EXPAND_REDUCE_BOOL_MACRO(CINN_DISCRETE_REDUCE_MACRO) EXPAND_REDUCE_FP16_MACRO(CINN_DISCRETE_REDUCE_MACRO) +// Discrete reduce for argidx types +__device__ inline argidx_fp32_i32 cinn_discrete_reduce_max_argidx_fp32_i32( + const argidx_fp32_i32 value, argidx_fp32_i32 *shm) { + CINN_DISCRETE_REDUCE_IMPL(max_argidx_fp32_i32, value); +} + +__device__ inline argidx_fp32_i64 cinn_discrete_reduce_max_argidx_fp32_i64( + const argidx_fp32_i64 value, argidx_fp32_i64 *shm) { + CINN_DISCRETE_REDUCE_IMPL(max_argidx_fp32_i64, value); +} + +__device__ inline argidx_fp32_i32 cinn_discrete_reduce_min_argidx_fp32_i32( + const argidx_fp32_i32 value, argidx_fp32_i32 *shm) { + CINN_DISCRETE_REDUCE_IMPL(min_argidx_fp32_i32, value); +} + +__device__ inline argidx_fp32_i64 cinn_discrete_reduce_min_argidx_fp32_i64( + const argidx_fp32_i64 value, argidx_fp32_i64 *shm) { + CINN_DISCRETE_REDUCE_IMPL(min_argidx_fp32_i64, value); +} + #define CINN_GRID_REDUCE_IMPL(REDUCE_TYPE, init_value, DTYPE) \ DTYPE tmp_val = init_value; \ for (int y = 0; y < gridDim.y; y++) { \ @@ -961,7 +972,7 @@ __device__ int cinn_custom_device_resize_bicubic(const int *buf, // --- C++ Scope Start --- // arg reduce arg index struct -// 【核心】不定义 operator<,强制走 std::max 重载 +// Do not define operator<; force dispatch through std::max overloads #define ARGIDX_STRUCT_MACRO(TYPENAME, DTYPE, ITYPE, IINIT) \ struct TYPENAME { \ DTYPE value; \ @@ -971,7 +982,7 @@ __device__ int cinn_custom_device_resize_bicubic(const int *buf, __device__ TYPENAME(DTYPE value, ITYPE index) \ : value(value), index(index) {} \ __device__ explicit operator ITYPE() { return index; } \ - /* 赋值运算符支持 */ \ + /* Assignment operator support */ \ __device__ inline TYPENAME& operator=(const TYPENAME& other) { \ value = other.value; \ index = 
other.index; \ @@ -984,7 +995,7 @@ __device__ int cinn_custom_device_resize_bicubic(const int *buf, } \ }; -// 实例化结构体 +// Instantiate structs #ifdef CINN_CUDA_FP16 ARGIDX_STRUCT_MACRO(argidx_fp16_i64, float16, int64_t, 0LL) #endif @@ -998,30 +1009,25 @@ ARGIDX_STRUCT_MACRO(argidx_u8_i64, uint8_t, int64_t, 0LL) ARGIDX_STRUCT_MACRO(argidx_fp32_i32, float, int, 0) ARGIDX_STRUCT_MACRO(argidx_i32_i32, int, int, 0) -// 手写 std::max 重载 +// std::max overloads namespace std { - // --- 之前加的 long long / int64_t 补丁保持不变 --- __device__ __forceinline__ int64_t max(long long a, int64_t b) { return a > b ? a : b; } __device__ __forceinline__ int64_t max(int64_t a, long long b) { return a > b ? a : b; } __device__ __forceinline__ int64_t min(long long a, int64_t b) { return a < b ? a : b; } __device__ __forceinline__ int64_t min(int64_t a, long long b) { return a < b ? a : b; } - // ============================================================== - // 【新增防弹补丁】:解决 CINN 漏打 'f' 后缀导致的 float 和 double 混合报错 - // ============================================================== __device__ __forceinline__ double max(float a, double b) { return a > b ? (double)a : b; } __device__ __forceinline__ double max(double a, float b) { return a > b ? a : (double)b; } __device__ __forceinline__ double min(float a, double b) { return a < b ? (double)a : b; } __device__ __forceinline__ double min(double a, float b) { return a < b ? a : (double)b; } - // 以防万一,解决 CINN 把 0 打印成 int 与 float 混合的报错 (如 std::max(val, 0)) + // As a safeguard, resolve ambiguity when CINN emits int literals mixed with float (e.g., std::max(val, 0)) __device__ __forceinline__ float max(float a, int b) { return a > b ? a : (float)b; } __device__ __forceinline__ float max(int a, float b) { return a > b ? (float)a : b; } __device__ __forceinline__ float min(float a, int b) { return a < b ? a : (float)b; } __device__ __forceinline__ float min(int a, float b) { return a < b ? 
(float)a : b; } - // ============================================================== - // ArgMax 实现 + // ArgMax implementation template __device__ __forceinline__ T max_argidx_impl(const T& a, const T& b) { if (a.value > b.value) return a; @@ -1036,7 +1042,7 @@ __device__ __forceinline__ double max(float a, double b) { return a > b ? (doubl return a.index < b.index ? a : b; } - // Volatile 重载 + // Volatile overloads template __device__ __forceinline__ T max_argidx_volatile_impl(const volatile T& a, const volatile T& b) { T va, vb; @@ -1053,7 +1059,7 @@ __device__ __forceinline__ double max(float a, double b) { return a > b ? (doubl return min_argidx_impl(va, vb); } - // 显式展开 + // Explicit instantiation __device__ __forceinline__ argidx_fp32_i64 max(const argidx_fp32_i64& a, const argidx_fp32_i64& b) { return max_argidx_impl(a, b); } __device__ __forceinline__ argidx_fp32_i64 min(const argidx_fp32_i64& a, const argidx_fp32_i64& b) { return min_argidx_impl(a, b); } @@ -1068,30 +1074,30 @@ __device__ __forceinline__ double max(float a, double b) { return a > b ? (doubl // 9. 
ArgMin/ArgMax Block Reduce Instantiation // =============================================================== -// 【终极修正】支持 2D Block 的行级归约 (Row-wise Reduction) +// Row-wise reduction supporting 2D thread blocks template __device__ inline T cinn_block_reduce_shm_impl(T value, T* shm_discard, Func reduce_func) { - // 获取 2D 维度信息 + // Retrieve 2D block dimensions unsigned int tx = threadIdx.x; unsigned int ty = threadIdx.y; unsigned int bdx = blockDim.x; - // 计算扁平化索引:确保不同行的数据落在 Shared Memory 的不同区域 - // 这样 threadIdx.y=0 和 threadIdx.y=1 就不会打架了 + // Compute flattened index: ensure different rows map to distinct shared memory regions, + // so that threadIdx.y=0 and threadIdx.y=1 do not conflict unsigned int idx = ty * bdx + tx; - // 分配足够大的静态 Shared Memory (1024 够 32x32 的 block 使用) - // 如果你的 block 很大,需要增加这里。但 CINN argmax 通常 block 不大。 + // Allocate sufficient static shared memory (1024 covers up to 32x32 thread blocks). + // Increase this if your block is larger, though CINN argmax blocks are typically small. __shared__ T internal_shm[1024]; - // 1. 写入 (带边界检查) + // 1. Store values (with bounds check) if (idx < 1024) { internal_shm[idx] = value; } __syncthreads(); - // 2. 树状归约 (只在 tx 维度归约) - // 每一行 (ty) 独立进行归约,互不干扰 + // 2. Tree-based reduction (reduce along the tx dimension only) + // Each row (ty) reduces independently without interference for (unsigned int s = bdx / 2; s > 0; s >>= 1) { if (tx < s && (idx + s) < 1024) { internal_shm[idx] = reduce_func(internal_shm[idx], internal_shm[idx + s]); @@ -1099,9 +1105,9 @@ __device__ inline T cinn_block_reduce_shm_impl(T value, T* shm_discard, Func red __syncthreads(); } - // 3. 返回结果 - // 每一行的结果存储在该行的首位 (ty * bdx) - // 广播给该行的所有线程 + // 3. 
Return result + // Each row's result is stored at the head of that row (ty * bdx) + // Broadcast to all threads in the same row return internal_shm[ty * bdx]; } @@ -1134,8 +1140,25 @@ __device__ inline argidx_fp32_i64 cinn_block_reduce_min_argidx_fp32_i64(const ar return cinn_block_reduce_min(value, shm, return_warp); } -__device__ inline argidx_fp32_i64 cinn_block_reduce_max_argidx_fp32_i64(const argidx_fp32_i64 value, argidx_fp32_i64 *shm, bool return_warp = false) { - return cinn_block_reduce_max(value, shm, return_warp); +__device__ inline argidx_fp32_i64 cinn_block_reduce_max_argidx_fp32_i64(const argidx_fp32_i64 value, argidx_fp32_i64 *shm, bool return_warp = false) { + return cinn_block_reduce_max(value, shm, return_warp); +} + +// i32 variants +__device__ inline argidx_fp32_i32 cinn_block_reduce_max(const argidx_fp32_i32 value, argidx_fp32_i32 *shm, bool return_warp = false) { + return cinn_block_reduce_shm_impl(value, shm, ArgIdxMaxOp()); +} + +__device__ inline argidx_fp32_i32 cinn_block_reduce_min(const argidx_fp32_i32 value, argidx_fp32_i32 *shm, bool return_warp = false) { + return cinn_block_reduce_shm_impl(value, shm, ArgIdxMinOp()); +} + +__device__ inline argidx_fp32_i32 cinn_block_reduce_min_argidx_fp32_i32(const argidx_fp32_i32 value, argidx_fp32_i32 *shm, bool return_warp = false) { + return cinn_block_reduce_min(value, shm, return_warp); +} + +__device__ inline argidx_fp32_i32 cinn_block_reduce_max_argidx_fp32_i32(const argidx_fp32_i32 value, argidx_fp32_i32 *shm, bool return_warp = false) { + return cinn_block_reduce_max(value, shm, return_warp); } } // extern "C" @@ -1143,10 +1166,10 @@ __device__ inline argidx_fp32_i64 cinn_block_reduce_max_argidx_fp32_i64(const ar // ============================================================ -// 2. 接口实现 +// 2. 
Interface Implementation // ============================================================ -// 全局原子计数器,确保文件名唯一 +// Global atomic counter to ensure unique filenames static std::atomic g_compile_counter{0}; const char* MetaxGetRuntimeSource(void* dev_ptr) { @@ -1154,22 +1177,19 @@ const char* MetaxGetRuntimeSource(void* dev_ptr) { } C_Status MetaxCompile(void* dev_ptr, const char* code, char* out_path, size_t len) { - // 0. 生成随机文件名 - // 【关键修复】使用 进程ID + 原子计数器 生成唯一文件名 - // 彻底解决多线程编译时的文件名冲突问题 + // 0. Generate unique filename + // Use PID + atomic counter to generate unique filenames, + // completely resolving filename collisions during concurrent compilation uint64_t file_id = g_compile_counter.fetch_add(1); std::string file_prefix = "cinn_metax_" + std::to_string(getpid()) + "_" + std::to_string(file_id); - // 生成临时文件路径 + // Generate temporary file paths std::string src_path = "/tmp/" + file_prefix + ".cu"; std::string obj_path = "/tmp/" + file_prefix + ".co"; - // 注意:即使 CINN 传了 out_path 进来,通常也是空的或者期望我们填写的 - // 所以我们尽量使用自己生成的 obj_path,最后再拷贝回去 - - // 1. 写入源码 + // 1. Write source code { - // 使用 truncate 模式打开,虽然文件名唯一,但以防万一 + // Open in truncate mode; although the filename is unique, this is a safety measure std::ofstream src_file(src_path, std::ios::trunc); if (!src_file.is_open()) { std::cerr << "[MetaX] Failed to open temp file: " << src_path << std::endl; @@ -1180,7 +1200,7 @@ C_Status MetaxCompile(void* dev_ptr, const char* code, char* out_path, size_t le src_file.close(); } - // 2. 准备编译器路径 + // 2. Resolve compiler binary path const char* maca_path_env = std::getenv("MACA_PATH"); std::string maca_path = maca_path_env ? std::string(maca_path_env) : "/opt/maca"; @@ -1190,15 +1210,14 @@ C_Status MetaxCompile(void* dev_ptr, const char* code, char* out_path, size_t le if (access(mxcc_cmd.c_str(), X_OK) != 0) mxcc_cmd = "mxcc"; } - // 3. 构建编译命令 - // 注意:加了空格防止粘连 + // 3. 
Build compilation command std::string cmd = mxcc_cmd + " -O3 -std=c++17 -w --fatbin --offload-arch=native -fvisibility=default"; cmd += " -I" + maca_path + "/include"; cmd += " -I" + maca_path + "/tools/cu-bridge/include"; cmd += " -o " + obj_path; cmd += " " + src_path; - // 4. 执行 + // 4. Execute compilation std::cout << "Command: " << cmd << std::endl; int ret = std::system(cmd.c_str()); if (ret != 0) { @@ -1207,27 +1226,27 @@ C_Status MetaxCompile(void* dev_ptr, const char* code, char* out_path, size_t le return C_Status::C_FAILED; } - // 5. 确保文件存在 + // 5. Verify output file exists if (access(obj_path.c_str(), F_OK) != 0) { std::cerr << "[MetaX] Output file missing: " << obj_path << std::endl; return C_Status::C_FAILED; } // ================================================================= - // 6. 【关键修复】将生成的二进制路径回填给 CINN 框架 + // 6. Write back the generated binary path to the CINN framework // ================================================================= if (out_path && len > 0) { - // 使用 strncpy 安全拷贝 + // Use strncpy for safe copy std::strncpy(out_path, obj_path.c_str(), len - 1); - out_path[len - 1] = '\0'; // 确保 null 结尾 - // 打印调试信息,确认回填成功 + out_path[len - 1] = '\0'; // Ensure null-termination + // Print debug info to confirm write-back succeeded std::cout << "[MetaX Success] Compiled: " << out_path << std::endl; } else { std::cerr << "[MetaX Error] Invalid out_path buffer!" << std::endl; return C_Status::C_FAILED; } - // 7. 清理源码 (调试成功后可开启) + // 7. Clean up source file (enable after debugging is complete) std::remove(src_path.c_str()); return C_Status::C_SUCCESS; diff --git a/backends/metax_gpu/cinn/passes/pass_manager.cc b/backends/metax_gpu/cinn/passes/pass_manager.cc index a2a90a1430..0caa6033d9 100644 --- a/backends/metax_gpu/cinn/passes/pass_manager.cc +++ b/backends/metax_gpu/cinn/passes/pass_manager.cc @@ -1,3 +1,17 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "paddle/phi/backends/device_ext.h" #include @@ -5,8 +19,8 @@ namespace paddle { namespace custom_device { namespace metax { -// 负责应用自定义的图优化 Pass -// 目前阶段先留空,直接返回成功 +// Applies custom graph optimization passes. +// Currently a no-op stub; returns success immediately. C_Status MetaxApplyCustomPass(void* dev_ptr, void* ir_module) { // VLOG(3) << "[MetaX] MetaxApplyCustomPass called (No-op)"; return C_Status::C_SUCCESS; @@ -14,4 +28,4 @@ C_Status MetaxApplyCustomPass(void* dev_ptr, void* ir_module) { } // namespace metax } // namespace custom_device -} // namespace paddle \ No newline at end of file +} // namespace paddle diff --git a/backends/metax_gpu/cinn/runtime/cinn_runtime.cc b/backends/metax_gpu/cinn/runtime/cinn_runtime.cc index 3b24de402e..ea12b1e359 100644 --- a/backends/metax_gpu/cinn/runtime/cinn_runtime.cc +++ b/backends/metax_gpu/cinn/runtime/cinn_runtime.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ // limitations under the License. 
#include "paddle/phi/backends/device_ext.h" -#include +#include #include #include #include @@ -22,44 +22,44 @@ namespace paddle { namespace custom_device { namespace metax { -// 【实现1】加载模块:相当于 cudaModuleLoad +// Load module: equivalent to cuModuleLoad C_Status MetaxModuleLoad(void* dev_ptr, const char* path, void** mod_out) { CUmodule module; CUresult err = cuModuleLoad(&module, path); if (err != CUDA_SUCCESS) return C_Status::C_FAILED; - + *mod_out = (void*)module; return C_Status::C_SUCCESS; } -// 【实现2】卸载模块 +// Unload module C_Status MetaxModuleUnload(void* dev_ptr, void* module_handle) { cuModuleUnload((CUmodule)module_handle); return C_Status::C_SUCCESS; } -// 【实现3】获取函数地址:相当于 cudaModuleGetFunction +// Get kernel function address: equivalent to cuModuleGetFunction C_Status MetaxGetKernelAddress(void* dev_ptr, void* module_handle, const char* func_name, void** func_out) { CUfunction func; CUresult err = cuModuleGetFunction(&func, (CUmodule)module_handle, func_name); if (err != CUDA_SUCCESS) return C_Status::C_FAILED; - + *func_out = (void*)func; return C_Status::C_SUCCESS; } -// 【实现4】启动核函数:相当于 cudaLaunchKernel +// Launch kernel: equivalent to cuLaunchKernel C_Status MetaxLaunchKernel(void* dev_ptr, void* func_ptr, void** args, int num_args, - int gx, int gy, int gz, - int bx, int by, int bz, + int gx, int gy, int gz, + int bx, int by, int bz, int shm, void* stream) { - // 注意:args 这里通常是 void*[],可能需要处理一下参数封装 + // Note: args is typically a void*[] and may require argument marshaling CUresult err = cuLaunchKernel((CUfunction)func_ptr, - gx, gy, gz, + gx, gy, gz, bx, by, bz, - shm, - (CUstream)stream, - args, + shm, + (CUstream)stream, + args, nullptr); if (err != CUDA_SUCCESS) return C_Status::C_FAILED; return C_Status::C_SUCCESS; @@ -67,4 +67,4 @@ C_Status MetaxLaunchKernel(void* dev_ptr, void* func_ptr, void** args, int num_a } // namespace metax } // namespace custom_device -} // namespace paddle \ No newline at end of file +} // namespace paddle diff --git 
a/backends/metax_gpu/kernels/impl/conv_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/conv_grad_kernel_impl.h index 6066720ab0..3987729648 100644 --- a/backends/metax_gpu/kernels/impl/conv_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/conv_grad_kernel_impl.h @@ -160,7 +160,7 @@ void ConvGradKernel(const Context& dev_ctx, if (is_expand) { set_zero(dev_ctx, &transformed_input_grad, static_cast(0)); } - phi::funcs::Col2ImFunctor col2im; + phi::funcs::Col2ImFunctor col2im; phi::funcs::Col2VolFunctor col2vol; for (int i = 0; i < batch_size; i++) { @@ -214,7 +214,7 @@ void ConvGradKernel(const Context& dev_ctx, Tensor filter_grad_ = *filter_grad; filter_grad_.Resize(filter_matrix_shape); set_zero(dev_ctx, filter_grad, static_cast(0)); - phi::funcs::Im2ColFunctor im2col; + phi::funcs::Im2ColFunctor im2col; phi::funcs::Vol2ColFunctor vol2col; for (int i = 0; i < batch_size; i++) { DenseTensor out_grad_batch = @@ -391,7 +391,7 @@ void ConvGradGradKernel(const Context& dev_ctx, if (is_expand) { set_zero(dev_ctx, &transformed_dX, static_cast(0)); } - phi::funcs::Col2ImFunctor col2im; + phi::funcs::Col2ImFunctor col2im; phi::funcs::Col2VolFunctor col2vol; for (int i = 0; i < batch_size; i++) { @@ -436,7 +436,7 @@ void ConvGradGradKernel(const Context& dev_ctx, set_zero(dev_ctx, dW, static_cast(0)); DenseTensor dW_arr = *dW; dW_arr.Resize(filter_matrix_shape); - phi::funcs::Im2ColFunctor im2col; + phi::funcs::Im2ColFunctor im2col; phi::funcs::Vol2ColFunctor vol2col; for (int i = 0; i < batch_size; ++i) { DenseTensor dy_batch = @@ -483,7 +483,7 @@ void ConvGradGradKernel(const Context& dev_ctx, } set_zero(dev_ctx, &transformed_ddY, static_cast(0)); - phi::funcs::Im2ColFunctor im2col; + phi::funcs::Im2ColFunctor im2col; phi::funcs::Vol2ColFunctor vol2col; for (int i = 0; i < batch_size; ++i) { DenseTensor ddy_batch = diff --git a/backends/metax_gpu/kernels/impl/conv_kernel_impl.h b/backends/metax_gpu/kernels/impl/conv_kernel_impl.h index 4395e5d578..0fe5ffbf9a 
100644 --- a/backends/metax_gpu/kernels/impl/conv_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/conv_kernel_impl.h @@ -140,7 +140,7 @@ void ConvKernelImpl(const Context& dev_ctx, int in_step = static_cast(transformed_input.dims()[1]) / groups; int out_step = static_cast(transformed_output.dims()[1]) / groups; - phi::funcs::Im2ColFunctor im2col; + phi::funcs::Im2ColFunctor im2col; phi::funcs::Vol2ColFunctor vol2col; auto blas = phi::funcs::GetBlas(dev_ctx); diff --git a/backends/metax_gpu/tests/tmp_save/unittest/test_elementwise_pow_op_metax.py b/backends/metax_gpu/tests/tmp_save/unittest/test_elementwise_pow_op_metax.py index 574691a02d..2c979fca97 100755 --- a/backends/metax_gpu/tests/tmp_save/unittest/test_elementwise_pow_op_metax.py +++ b/backends/metax_gpu/tests/tmp_save/unittest/test_elementwise_pow_op_metax.py @@ -17,14 +17,11 @@ import unittest import numpy as np -import sys -sys.path.insert(0, '/home/sw/Baidu-xuyuhan/PaddleCustomDevice/python/tests/') from op_test import OpTest, convert_float_to_uint16, skip_check_grad_ci import paddle from paddle import base from paddle.base import core -import paddle.profiler as profiler def pow_grad(x, y, dout): @@ -32,6 +29,7 @@ def pow_grad(x, y, dout): dy = dout * np.log(x) * np.power(x, y) return dx, dy + class TestElementwisePowOp(OpTest): def setUp(self): self.op_type = "elementwise_pow" @@ -45,61 +43,23 @@ def setUp(self): self.outputs = {"Out": np.power(self.inputs["X"], self.inputs["Y"])} def test_check_output(self): - # 定义输出路径 (会在当前目录下生成 profiler_log 文件夹) - # 1. 
确保目录存在 - output_dir = "./profiler_log" - if not os.path.exists(output_dir): - os.makedirs(output_dir) - output_path = "./profiler_log/check_output" - - # 定义回调函数,用于导出性能数据 - def my_on_trace_ready(prof): - prof.export(path=output_path, format="json") - - # 初始化 Profiler - # 注意:对于 MetaX 这类 CustomDevice,通常 target 选 CPU 即可捕获 Host 端调度 - # 如果 MetaX 插件实现了 Profiler 接口,选 GPU 或 CUSTOM_DEVICE 可能捕获设备端信息 - with profiler.Profiler( - targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.CUSTOM_DEVICE], - scheduler=profiler.make_scheduler(closed=0, ready=0, record=1, repeat=1), - on_trace_ready=my_on_trace_ready - ) as p: - # === 将原始测试逻辑包裹在这里 === - if hasattr(self, "attrs"): - self.check_output(check_dygraph=False) - else: - self.check_output(check_pir=True, check_symbol_infer=False) - # ============================== - - p.step() # 通知 Profiler 一个 step 结束 + if hasattr(self, "attrs"): + self.check_output(check_dygraph=False) + else: + self.check_output(check_pir=True, check_symbol_infer=False) def test_check_grad_normal(self): - # 1. 
确保目录存在 - output_dir = "./profiler_log" - if not os.path.exists(output_dir): - os.makedirs(output_dir) - output_path = "./profiler_log/check_grad" - def my_on_trace_ready(prof): - prof.export(path=output_path, format="json") - - with profiler.Profiler( - targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.CUSTOM_DEVICE], - scheduler=profiler.make_scheduler(closed=0, ready=0, record=1, repeat=1), - on_trace_ready=my_on_trace_ready - ) as p: - # === 将原始测试逻辑包裹在这里 === - if hasattr(self, "attrs"): - self.check_grad(["X", "Y"], "Out", check_prim=True, check_dygraph=False) - else: - self.check_grad( - ["X", "Y"], - "Out", - check_prim=True, - check_prim_pir=True, - check_pir=True, - ) - # ============================== - p.step() + if hasattr(self, "attrs"): + self.check_grad(["X", "Y"], "Out", check_prim=True, check_dygraph=False) + else: + self.check_grad( + ["X", "Y"], + "Out", + check_prim=True, + check_prim_pir=True, + check_pir=True, + ) + class TestElementwisePowOp_ZeroDim1(TestElementwisePowOp): def setUp(self): @@ -114,7 +74,7 @@ def setUp(self): } self.outputs = {"Out": np.power(self.inputs["X"], self.inputs["Y"])} -''' + class TestElementwisePowOp_ZeroDim2(TestElementwisePowOp): def setUp(self): self.op_type = "elementwise_pow" @@ -128,6 +88,7 @@ def setUp(self): } self.outputs = {"Out": np.power(self.inputs["X"], self.inputs["Y"])} + class TestElementwisePowOp_ZeroDim3(TestElementwisePowOp): def setUp(self): self.op_type = "elementwise_pow" @@ -494,6 +455,7 @@ def test_check_grad(self): only_check_prim=True, check_prim_pir=True, ) -''' + + if __name__ == "__main__": unittest.main() From 8d90319033366e5e254e755cd512b84c87d0f7b6 Mon Sep 17 00:00:00 2001 From: Yuhan Xu Date: Tue, 31 Mar 2026 11:45:08 +0000 Subject: [PATCH 14/17] Recover CINN irrelevant code. 
--- backends/metax_gpu/build.sh | 7 +++---- backends/metax_gpu/change_patch.sh | 2 +- backends/metax_gpu/compile.sh | 2 +- .../kernels/custom_kernel/custom_context.h | 3 --- .../kernels/impl/conv_grad_kernel_impl.h | 10 +++++----- .../metax_gpu/kernels/impl/conv_kernel_impl.h | 2 +- .../kernels/impl/conv_transpose_kernel_impl.h | 2 +- backends/metax_gpu/tests/run_test.sh | 16 ++-------------- .../tests/tmp_save/gpudnn/conv_cudnn_v7.h | 6 +++--- 9 files changed, 17 insertions(+), 33 deletions(-) diff --git a/backends/metax_gpu/build.sh b/backends/metax_gpu/build.sh index 9eaa73e571..223baa2a9f 100755 --- a/backends/metax_gpu/build.sh +++ b/backends/metax_gpu/build.sh @@ -18,13 +18,12 @@ set -e # install requirement.txt -# pip install -r requirement.txt -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple +pip install -r requirement.txt -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple # uninstall paddle -# pip uninstall paddlepaddle -y +pip uninstall paddlepaddle -y - -# python -m pip install --pre paddlepaddle -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/ +python -m pip install --pre paddlepaddle -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/ # apply patch diff --git a/backends/metax_gpu/change_patch.sh b/backends/metax_gpu/change_patch.sh index faa5adac66..3fa9a64761 100644 --- a/backends/metax_gpu/change_patch.sh +++ b/backends/metax_gpu/change_patch.sh @@ -24,6 +24,6 @@ cp -r patch/eigen3/ ../../Paddle/third_party/eigen3 rm -r patch/eigen3 # cp patch/tmp/mixed_vector* ../../Paddle/paddle/phi/core cd ../../Paddle/ -git apply --verbose /home/sw/Baidu-xuyuhan/PaddleCustomDevice/backends/metax_gpu/patch/paddle.patch +git apply --verbose ../backends/metax_gpu/patch/paddle.patch cd - # cp -r patch/intrinsics.cuh ../../Paddle/third_party/warpctc/include/contrib/moderngpu/include/device/ diff --git a/backends/metax_gpu/compile.sh b/backends/metax_gpu/compile.sh index e77dad77d0..a15be1ced9 100644 --- a/backends/metax_gpu/compile.sh +++ 
b/backends/metax_gpu/compile.sh @@ -28,7 +28,7 @@ export LD_LIBRARY_PATH=${MACA_PATH}/lib:${MACA_PATH}/mxgpu_llvm/lib:${LD_LIBRARY export PADDLE_VERSION="3.3.0.dev$(date +%Y%m%d)" export MACA_AI_VERSION=$(cat /opt/maca/Version.txt | cut -d':' -f2) if [ ! -d build ]; then -echo "build directory not found, creating..." + echo "build directory not found, creating..." mkdir build fi diff --git a/backends/metax_gpu/kernels/custom_kernel/custom_context.h b/backends/metax_gpu/kernels/custom_kernel/custom_context.h index dd232f841f..19035992ea 100644 --- a/backends/metax_gpu/kernels/custom_kernel/custom_context.h +++ b/backends/metax_gpu/kernels/custom_kernel/custom_context.h @@ -29,7 +29,6 @@ #include "paddle/phi/core/device_context.h" namespace phi { - // class DnnWorkspaceHandle { // public: // inline DnnWorkspaceHandle(Allocator* allocator, gpuStream_t stream) @@ -101,7 +100,6 @@ namespace phi { // } // } // namespace - namespace dynload { inline bool HasCUSOLVER() { @@ -158,6 +156,5 @@ inline cusolverDnHandle_t GetCusolverDnHandle(gpuStream_t stream, Place place) { // const gpuStream_t& stream) { // return DnnWorkspaceHandle(alloactor, stream); // } - } // namespace phi #endif // BACKENDS_METAX_GPU_KERNELS_CUSTOM_KERNEL_CUSTOM_CONTEXT_H_ diff --git a/backends/metax_gpu/kernels/impl/conv_grad_kernel_impl.h b/backends/metax_gpu/kernels/impl/conv_grad_kernel_impl.h index 3987729648..6066720ab0 100644 --- a/backends/metax_gpu/kernels/impl/conv_grad_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/conv_grad_kernel_impl.h @@ -160,7 +160,7 @@ void ConvGradKernel(const Context& dev_ctx, if (is_expand) { set_zero(dev_ctx, &transformed_input_grad, static_cast(0)); } - phi::funcs::Col2ImFunctor col2im; + phi::funcs::Col2ImFunctor col2im; phi::funcs::Col2VolFunctor col2vol; for (int i = 0; i < batch_size; i++) { @@ -214,7 +214,7 @@ void ConvGradKernel(const Context& dev_ctx, Tensor filter_grad_ = *filter_grad; filter_grad_.Resize(filter_matrix_shape); set_zero(dev_ctx, 
filter_grad, static_cast(0)); - phi::funcs::Im2ColFunctor im2col; + phi::funcs::Im2ColFunctor im2col; phi::funcs::Vol2ColFunctor vol2col; for (int i = 0; i < batch_size; i++) { DenseTensor out_grad_batch = @@ -391,7 +391,7 @@ void ConvGradGradKernel(const Context& dev_ctx, if (is_expand) { set_zero(dev_ctx, &transformed_dX, static_cast(0)); } - phi::funcs::Col2ImFunctor col2im; + phi::funcs::Col2ImFunctor col2im; phi::funcs::Col2VolFunctor col2vol; for (int i = 0; i < batch_size; i++) { @@ -436,7 +436,7 @@ void ConvGradGradKernel(const Context& dev_ctx, set_zero(dev_ctx, dW, static_cast(0)); DenseTensor dW_arr = *dW; dW_arr.Resize(filter_matrix_shape); - phi::funcs::Im2ColFunctor im2col; + phi::funcs::Im2ColFunctor im2col; phi::funcs::Vol2ColFunctor vol2col; for (int i = 0; i < batch_size; ++i) { DenseTensor dy_batch = @@ -483,7 +483,7 @@ void ConvGradGradKernel(const Context& dev_ctx, } set_zero(dev_ctx, &transformed_ddY, static_cast(0)); - phi::funcs::Im2ColFunctor im2col; + phi::funcs::Im2ColFunctor im2col; phi::funcs::Vol2ColFunctor vol2col; for (int i = 0; i < batch_size; ++i) { DenseTensor ddy_batch = diff --git a/backends/metax_gpu/kernels/impl/conv_kernel_impl.h b/backends/metax_gpu/kernels/impl/conv_kernel_impl.h index 0fe5ffbf9a..4395e5d578 100644 --- a/backends/metax_gpu/kernels/impl/conv_kernel_impl.h +++ b/backends/metax_gpu/kernels/impl/conv_kernel_impl.h @@ -140,7 +140,7 @@ void ConvKernelImpl(const Context& dev_ctx, int in_step = static_cast(transformed_input.dims()[1]) / groups; int out_step = static_cast(transformed_output.dims()[1]) / groups; - phi::funcs::Im2ColFunctor im2col; + phi::funcs::Im2ColFunctor im2col; phi::funcs::Vol2ColFunctor vol2col; auto blas = phi::funcs::GetBlas(dev_ctx); diff --git a/backends/metax_gpu/kernels/impl/conv_transpose_kernel_impl.h b/backends/metax_gpu/kernels/impl/conv_transpose_kernel_impl.h index f7d7b75a29..aadc5d2b8a 100644 --- a/backends/metax_gpu/kernels/impl/conv_transpose_kernel_impl.h +++ 
b/backends/metax_gpu/kernels/impl/conv_transpose_kernel_impl.h @@ -142,7 +142,7 @@ void ConvTransposeRawKernel(const Context& dev_ctx, (data_layout != DataLayout::kNHWC ? static_cast(out_dims[1]) / groups : static_cast(out_dims[out_dims.size() - 1]) / groups); - phi::funcs::Col2ImFunctor col2im; + phi::funcs::Col2ImFunctor col2im; phi::funcs::Col2VolFunctor col2vol; funcs::ConcatFunctor concat_functor; diff --git a/backends/metax_gpu/tests/run_test.sh b/backends/metax_gpu/tests/run_test.sh index ae7f64c29a..6ad3c1f653 100755 --- a/backends/metax_gpu/tests/run_test.sh +++ b/backends/metax_gpu/tests/run_test.sh @@ -23,18 +23,6 @@ TEST_PATH2="${SCRIPT_DIR}/../../../python/tests" export PYTHONPATH="${LEGACY_TEST_PATH}:${PYTHONPATH}:${TEST_PATH1}:${TEST_PATH2}" export PADDLE_XCCL_BACKEND=metax_gpu export CUDA_VISIBLE_DEVICES=0 - -PYTHONUNBUFFERED=1 -# 以下三条为运行CINN必开 -FLAGS_prim_all=true -FLAGS_prim_enable_dynamic=true -FLAGS_use_cinn=true -# 关闭多线程编译,调试时用 -FLAGS_enable_cinn_compile_cache=false -# 打印log,调试时用 -FLAGS_print_ir=true -GLOG_v=1 - # export # sleep 1000000 @@ -93,8 +81,8 @@ done export GLOG_v=$TEST_LOG_LEVEL -cmake .. -DTEST_LIST_FILE=$TEST_LIST_FILE -DLOG_OUTPUT_DIR=$TEST_LOG_OUTPUT_DIR -DIGNORE_BLOCKS="$IGNORE_BLOCKS" -DWITH_CINN=ON +cmake .. -DTEST_LIST_FILE=$TEST_LIST_FILE -DLOG_OUTPUT_DIR=$TEST_LOG_OUTPUT_DIR -DIGNORE_BLOCKS="$IGNORE_BLOCKS" cmake --build . 
-GLOG_v=1 FLAGS_print_ir=1 ctest -j$TEST_PARALLEL_NUM --output-on-failure +ctest -j$TEST_PARALLEL_NUM --output-on-failure diff --git a/backends/metax_gpu/tests/tmp_save/gpudnn/conv_cudnn_v7.h b/backends/metax_gpu/tests/tmp_save/gpudnn/conv_cudnn_v7.h index 4923d802e9..be89898e68 100644 --- a/backends/metax_gpu/tests/tmp_save/gpudnn/conv_cudnn_v7.h +++ b/backends/metax_gpu/tests/tmp_save/gpudnn/conv_cudnn_v7.h @@ -227,7 +227,7 @@ struct SearchAlgorithmBase { // auto workspace_handle = dev_ctx.cudnn_workspace_handle(); auto workspace_handle = GetDnnWorkspace( - const_cast(&(dev_ctx.GetAllocator())), dev_ctx.stream(), dev_ctx.GetPlace()); + const_cast(&(dev_ctx.GetAllocator())), dev_ctx.stream()); // auto handle = GetDnnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); @@ -416,7 +416,7 @@ struct SearchAlgorithmBase { // auto workspace_handle = dev_ctx.cudnn_workspace_handle(); auto workspace_handle = GetDnnWorkspace( - const_cast(&(dev_ctx.GetAllocator())), dev_ctx.stream(), dev_ctx.GetPlace()); + const_cast(&(dev_ctx.GetAllocator())), dev_ctx.stream()); workspace_handle.RunFuncSync( cudnn_find_func, max_workspace_size, UseFixedWorkspace()); @@ -569,7 +569,7 @@ struct SearchAlgorithmBase { CalcWorkspaceLimitInBytes(UseFixedWorkspace()); // auto workspace_handle = dev_ctx.cudnn_workspace_handle(); auto workspace_handle = GetDnnWorkspace( - const_cast(&(dev_ctx.GetAllocator())), dev_ctx.stream(), dev_ctx.GetPlace()); + const_cast(&(dev_ctx.GetAllocator())), dev_ctx.stream()); if (phi::backends::gpu::CudnnDataType::type != CUDNN_DATA_HALF) { size_t max_workspace_size = GetMaxWorkspaceSize(args, workspace_size_limit); From be5100462e6b6e76aa37e32bb41d815005e17f95 Mon Sep 17 00:00:00 2001 From: Yuhan Xu Date: Wed, 1 Apr 2026 02:48:22 +0000 Subject: [PATCH 15/17] Fix Code-style. 
--- backends/metax_gpu/CMakeLists.txt | 15 +- backends/metax_gpu/cinn/CMakeLists.txt | 56 ++--- backends/metax_gpu/cinn/cinn_interface.cc | 94 ++++---- backends/metax_gpu/cinn/cinn_interface.h | 9 +- backends/metax_gpu/cinn/compiler/compiler.cc | 226 ++++++++++-------- .../metax_gpu/cinn/passes/pass_manager.cc | 13 +- .../metax_gpu/cinn/runtime/cinn_runtime.cc | 81 ++++--- backends/metax_gpu/runtime/runtime.cc | 49 ++-- 8 files changed, 295 insertions(+), 248 deletions(-) diff --git a/backends/metax_gpu/CMakeLists.txt b/backends/metax_gpu/CMakeLists.txt index d95a982000..93298ab1ee 100755 --- a/backends/metax_gpu/CMakeLists.txt +++ b/backends/metax_gpu/CMakeLists.txt @@ -44,9 +44,9 @@ endif() include(paddle) if(WITH_CINN) - message(STATUS "[MetaX] CINN enabled, adding subdirectory: cinn") - add_definitions(-DWITH_CINN) - add_subdirectory(cinn) + message(STATUS "[MetaX] CINN enabled, adding subdirectory: cinn") + add_definitions(-DWITH_CINN) + add_subdirectory(cinn) endif() set(THIRD_PARTY_PATH @@ -802,9 +802,8 @@ set(CMAKE_CUCC_FLAGS "-I ${MACA_PATH}/tools/cu-bridge/include/") add_library(${TARGET_NAME} SHARED ${CUSTOM_DEVICE_SRCS}) if(WITH_CINN) - target_include_directories(${TARGET_NAME} PRIVATE - "${CMAKE_CURRENT_SOURCE_DIR}/cinn" - ) + target_include_directories(${TARGET_NAME} + PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/cinn") endif() target_include_directories( @@ -837,8 +836,8 @@ target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmcFlashAttn.so) target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmcpti.so) if(WITH_CINN) - message(STATUS "[MetaX] Linking CINN object library") - target_link_libraries(${TARGET_NAME} $) + message(STATUS "[MetaX] Linking CINN object library") + target_link_libraries(${TARGET_NAME} $) endif() include_directories(BEFORE ${PADDLE_SOURCE_DIR}) diff --git a/backends/metax_gpu/cinn/CMakeLists.txt b/backends/metax_gpu/cinn/CMakeLists.txt index ea35bea8e3..0c122669a3 100644 --- a/backends/metax_gpu/cinn/CMakeLists.txt +++ 
b/backends/metax_gpu/cinn/CMakeLists.txt @@ -2,45 +2,47 @@ # CINN Plugin for MetaX (MACA) Backend # ============================================================================= -# 1. Locate MACA SDK path -# To allow #include in runtime/cinn_runtime.cc or compiler.cc, -# we need to add the MetaX SDK header search path. +# 1. Locate MACA SDK path To allow #include in +# runtime/cinn_runtime.cc or compiler.cc, we need to add the MetaX SDK header +# search path. set(MACA_PATH $ENV{MACA_PATH}) if(NOT MACA_PATH) - set(MACA_PATH "/opt/maca") # Default fallback path - message(STATUS "[MetaX CINN] MACA_PATH not set, using default: ${MACA_PATH}") + set(MACA_PATH "/opt/maca") # Default fallback path + message(STATUS "[MetaX CINN] MACA_PATH not set, using default: ${MACA_PATH}") else() - message(STATUS "[MetaX CINN] Found MACA_PATH: ${MACA_PATH}") + message(STATUS "[MetaX CINN] Found MACA_PATH: ${MACA_PATH}") endif() -# 2. Define source file list -# All .cc files involved in the CINN implementation must be included here. +# 1. Define source file list All .cc files involved in the CINN implementation +# must be included here. set(CINN_SRCS - cinn_interface.cc # Main entry point, responsible for InitCinnInterface - compiler/compiler.cc # Implements MetaxCompile and MetaxGetRuntimeSource - runtime/cinn_runtime.cc # Implements MetaxModuleLoad, MetaxLaunchKernel - passes/pass_manager.cc # Implements MetaxApplyCustomPass + cinn_interface.cc # Main entry point, responsible for InitCinnInterface + compiler/compiler.cc # Implements MetaxCompile and MetaxGetRuntimeSource + runtime/cinn_runtime.cc # Implements MetaxModuleLoad, MetaxLaunchKernel + passes/pass_manager.cc # Implements MetaxApplyCustomPass ) -# 3. Create OBJECT library -# Use OBJECT mode to compile into .o files only, without generating .a or .so. -# This allows the parent CMake to directly collect these .o files and link them -# into the final plugin.so. +# 1. 
Create OBJECT library Use OBJECT mode to compile into .o files only, without +# generating .a or .so. This allows the parent CMake to directly collect these +# .o files and link them into the final plugin.so. add_library(metax_cinn_obj OBJECT ${CINN_SRCS}) -# 4. Configure header search paths -target_include_directories(metax_cinn_obj PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR} # Allow referencing headers in current directory (cinn_interface.h) - ${CMAKE_CURRENT_SOURCE_DIR}/../ # Allow referencing parent-level headers (e.g., common/) - ${MACA_PATH}/include # Allow referencing - ${PADDLE_SOURCE_DIR} # Allow referencing paddle/phi/... headers - # Paddle header paths are typically auto-included via the external environment (Paddle_DIR) +# 1. Configure header search paths +target_include_directories( + metax_cinn_obj + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} # Allow referencing headers in current + # directory (cinn_interface.h) + ${CMAKE_CURRENT_SOURCE_DIR}/../ # Allow referencing parent-level + # headers (e.g., common/) + ${MACA_PATH}/include # Allow referencing + ${PADDLE_SOURCE_DIR} # Allow referencing paddle/phi/... headers + # Paddle header paths are typically auto-included via the external + # environment (Paddle_DIR) ) -# 5. Compiler options -# The CINN component typically requires C++17 standard +# 1. 
Compiler options The CINN component typically requires C++17 standard set_property(TARGET metax_cinn_obj PROPERTY CXX_STANDARD 17) -# Enable PIC (Position Independent Code) -# Required because these .o files will ultimately be linked into a shared library +# Enable PIC (Position Independent Code) Required because these .o files will +# ultimately be linked into a shared library set_property(TARGET metax_cinn_obj PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/backends/metax_gpu/cinn/cinn_interface.cc b/backends/metax_gpu/cinn/cinn_interface.cc index a65ce16832..a01bd0e67e 100644 --- a/backends/metax_gpu/cinn/cinn_interface.cc +++ b/backends/metax_gpu/cinn/cinn_interface.cc @@ -12,8 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "cinn_interface.h" -#include // For memset +#include "cinn/cinn_interface.h" + +#include // For memset #include namespace paddle { @@ -22,11 +23,13 @@ namespace metax { // ============================================================ // External Function Declarations -// These functions must be implemented in the corresponding subdirectory files (.cc). +// These functions must be implemented in the corresponding subdirectory files +// (.cc). 
// ============================================================ // --- From compiler/compiler.cc --- -// Invokes the mxcc toolchain to compile CINN-generated source code into a binary +// Invokes the mxcc toolchain to compile CINN-generated source code into a +// binary extern C_Status MetaxCompile(void* dev_ptr, const char* code, char* out_path, @@ -35,7 +38,6 @@ extern C_Status MetaxCompile(void* dev_ptr, // Provides the MetaX GPU device runtime source code extern const char* MetaxGetRuntimeSource(void* dev_ptr); - // --- From runtime/cinn_runtime.cc --- // Loads a compiled binary module (.mx / .so) extern C_Status MetaxModuleLoad(void* dev_ptr, @@ -43,8 +45,7 @@ extern C_Status MetaxModuleLoad(void* dev_ptr, void** mod_out); // Unloads a module -extern C_Status MetaxModuleUnload(void* dev_ptr, - void* module_handle); +extern C_Status MetaxModuleUnload(void* dev_ptr, void* module_handle); // Retrieves the kernel function address from a loaded module extern C_Status MetaxGetKernelAddress(void* dev_ptr, @@ -57,17 +58,18 @@ extern C_Status MetaxLaunchKernel(void* dev_ptr, void* func_ptr, void** args, int num_args, - int gx, int gy, int gz, - int bx, int by, int bz, + int gx, + int gy, + int gz, + int bx, + int by, + int bz, int shm, void* stream); - // --- From passes/pass_manager.cc --- // Applies custom graph optimization passes -extern C_Status MetaxApplyCustomPass(void* dev_ptr, - void* ir_module); - +extern C_Status MetaxApplyCustomPass(void* dev_ptr, void* ir_module); // ============================================================ // Interface Initialization @@ -77,37 +79,39 @@ extern C_Status MetaxApplyCustomPass(void* dev_ptr, static C_CinnInterface metax_cinn_impl; void InitCinnInterface(C_DeviceInterface* device_interface) { - // 1. Zero-initialize for safety - std::memset(&metax_cinn_impl, 0, sizeof(C_CinnInterface)); - - // 2. Set struct size (used for version validation) - metax_cinn_impl.size = sizeof(C_CinnInterface); - - // 3. 
Set context pointer (optional) - // Point to a global state struct if your implementation needs one; otherwise nullptr - metax_cinn_impl.dev_ptr = nullptr; - - // 4. Register Compiler Toolchain interface - metax_cinn_impl.compile = MetaxCompile; - metax_cinn_impl.get_runtime_source = MetaxGetRuntimeSource; - - // 5. Register Runtime Strategy interface - metax_cinn_impl.module_load = MetaxModuleLoad; - metax_cinn_impl.module_unload = MetaxModuleUnload; - metax_cinn_impl.get_kernel_address = MetaxGetKernelAddress; - metax_cinn_impl.launch_kernel = MetaxLaunchKernel; - - // 6. Register Compilation Strategy interface - metax_cinn_impl.apply_custom_pass = MetaxApplyCustomPass; - - // 7. Attach the populated dispatch table to the Paddle device interface - if (device_interface) { - device_interface->cinn_interface = &metax_cinn_impl; - } else { - std::cerr << "[MetaX] Error: device_interface is null during CINN init." << std::endl; - } + // 1. Zero-initialize for safety + std::memset(&metax_cinn_impl, 0, sizeof(C_CinnInterface)); + + // 2. Set struct size (used for version validation) + metax_cinn_impl.size = sizeof(C_CinnInterface); + + // 3. Set context pointer (optional) + // Point to a global state struct if your implementation needs one; otherwise + // nullptr + metax_cinn_impl.dev_ptr = nullptr; + + // 4. Register Compiler Toolchain interface + metax_cinn_impl.compile = MetaxCompile; + metax_cinn_impl.get_runtime_source = MetaxGetRuntimeSource; + + // 5. Register Runtime Strategy interface + metax_cinn_impl.module_load = MetaxModuleLoad; + metax_cinn_impl.module_unload = MetaxModuleUnload; + metax_cinn_impl.get_kernel_address = MetaxGetKernelAddress; + metax_cinn_impl.launch_kernel = MetaxLaunchKernel; + + // 6. Register Compilation Strategy interface + metax_cinn_impl.apply_custom_pass = MetaxApplyCustomPass; + + // 7. 
Attach the populated dispatch table to the Paddle device interface + if (device_interface) { + device_interface->cinn_interface = &metax_cinn_impl; + } else { + std::cerr << "[MetaX] Error: device_interface is null during CINN init." + << std::endl; + } } -} // namespace metax -} // namespace custom_device -} // namespace paddle +} // namespace metax +} // namespace custom_device +} // namespace paddle diff --git a/backends/metax_gpu/cinn/cinn_interface.h b/backends/metax_gpu/cinn/cinn_interface.h index 330c172b5e..ca224dd764 100644 --- a/backends/metax_gpu/cinn/cinn_interface.h +++ b/backends/metax_gpu/cinn/cinn_interface.h @@ -28,10 +28,11 @@ namespace metax { * It populates device_interface->cinn_interface with the compiler * and runtime function pointers implemented under metax_gpu/cinn. * - * @param device_interface The device interface pointer passed from the Paddle host side. + * @param device_interface The device interface pointer passed from the Paddle + * host side. */ void InitCinnInterface(C_DeviceInterface* device_interface); -} // namespace metax -} // namespace custom_device -} // namespace paddle +} // namespace metax +} // namespace custom_device +} // namespace paddle diff --git a/backends/metax_gpu/cinn/compiler/compiler.cc b/backends/metax_gpu/cinn/compiler/compiler.cc index 51a04736da..cafc766bf7 100644 --- a/backends/metax_gpu/cinn/compiler/compiler.cc +++ b/backends/metax_gpu/cinn/compiler/compiler.cc @@ -1,16 +1,31 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + // PaddleCustomDevice/backends/metax_gpu/cinn/compiler/compiler.cc +#include +#include + +#include +#include +#include +#include +#include #include #include #include #include -#include -#include -#include -#include -#include -#include -#include // Host-side header, used only by compiler.cc #include "paddle/phi/backends/device_ext.h" @@ -26,6 +41,7 @@ static const char* kMacaRuntimeSource = R"MACA_SOURCE( #pragma once #include #include + #include extern "C" { @@ -96,7 +112,7 @@ __device__ inline int16_t FN_INT16(bitwise_not)(int16_t a) { return ~a; } __device__ inline int16_t FN_INT16(logical_right_shift)(int16_t a, int16_t b) { return ((uint16_t)a >> b); } // =============================================================== -// 6. Standard Math Functions +// 6. Standard Math Functions // =============================================================== // =============================================================== // Float64 (Double) Math Functions @@ -206,7 +222,7 @@ __device__ inline float FN_FP32(tanh_approx)(float x) { __device__ inline int FN_INT32(bitwise_not)(int a) { return ~a; } __device__ inline int FN_INT32(clz)(int a) { return __clz(a); } __device__ inline int FN_INT32(popc)(int a) { return __popc(a); } -__device__ inline int FN_INT32(mod)(int a, int b) { +__device__ inline int FN_INT32(mod)(int a, int b) { int res = a % b; if ((res != 0) && ((b ^ res) < 0)) res += b; return res; @@ -965,7 +981,7 @@ __device__ int cinn_custom_device_resize_bicubic(const int *buf, return value; } } // extern "C" - + // =============================================================== // 8. ArgMin/ArgMax Support (ArgIdx Structures & Shuffles) // =============================================================== @@ -1020,22 +1036,22 @@ __device__ __forceinline__ double max(float a, double b) { return a > b ? 
(doubl __device__ __forceinline__ double max(double a, float b) { return a > b ? a : (double)b; } __device__ __forceinline__ double min(float a, double b) { return a < b ? (double)a : b; } __device__ __forceinline__ double min(double a, float b) { return a < b ? a : (double)b; } - + // As a safeguard, resolve ambiguity when CINN emits int literals mixed with float (e.g., std::max(val, 0)) __device__ __forceinline__ float max(float a, int b) { return a > b ? a : (float)b; } __device__ __forceinline__ float max(int a, float b) { return a > b ? (float)a : b; } __device__ __forceinline__ float min(float a, int b) { return a < b ? a : (float)b; } __device__ __forceinline__ float min(int a, float b) { return a < b ? (float)a : b; } - + // ArgMax implementation - template + template __device__ __forceinline__ T max_argidx_impl(const T& a, const T& b) { if (a.value > b.value) return a; if (a.value < b.value) return b; return a.index < b.index ? a : b; } - - template + + template __device__ __forceinline__ T min_argidx_impl(const T& a, const T& b) { if (a.value < b.value) return a; if (a.value > b.value) return b; @@ -1043,15 +1059,15 @@ __device__ __forceinline__ double max(float a, double b) { return a > b ? (doubl } // Volatile overloads - template + template __device__ __forceinline__ T max_argidx_volatile_impl(const volatile T& a, const volatile T& b) { T va, vb; va.value = a.value; va.index = a.index; vb.value = b.value; vb.index = b.index; return max_argidx_impl(va, vb); } - - template + + template __device__ __forceinline__ T min_argidx_volatile_impl(const volatile T& a, const volatile T& b) { T va, vb; va.value = a.value; va.index = a.index; @@ -1062,7 +1078,7 @@ __device__ __forceinline__ double max(float a, double b) { return a > b ? 
(doubl // Explicit instantiation __device__ __forceinline__ argidx_fp32_i64 max(const argidx_fp32_i64& a, const argidx_fp32_i64& b) { return max_argidx_impl(a, b); } __device__ __forceinline__ argidx_fp32_i64 min(const argidx_fp32_i64& a, const argidx_fp32_i64& b) { return min_argidx_impl(a, b); } - + __device__ __forceinline__ argidx_fp32_i64 max(const volatile argidx_fp32_i64& a, const volatile argidx_fp32_i64& b) { return max_argidx_volatile_impl(a, b); } __device__ __forceinline__ argidx_fp32_i64 min(const volatile argidx_fp32_i64& a, const volatile argidx_fp32_i64& b) { return min_argidx_volatile_impl(a, b); } @@ -1070,8 +1086,8 @@ __device__ __forceinline__ double max(float a, double b) { return a > b ? (doubl __device__ __forceinline__ argidx_fp32_i32 min(const argidx_fp32_i32& a, const argidx_fp32_i32& b) { return min_argidx_impl(a, b); } } -// =============================================================== -// 9. ArgMin/ArgMax Block Reduce Instantiation +// =============================================================== +// 9. ArgMin/ArgMax Block Reduce Instantiation // =============================================================== // Row-wise reduction supporting 2D thread blocks @@ -1088,7 +1104,7 @@ __device__ inline T cinn_block_reduce_shm_impl(T value, T* shm_discard, Func red // Allocate sufficient static shared memory (1024 covers up to 32x32 thread blocks). // Increase this if your block is larger, though CINN argmax blocks are typically small. - __shared__ T internal_shm[1024]; + __shared__ T internal_shm[1024]; // 1. 
Store values (with bounds check) if (idx < 1024) { @@ -1128,16 +1144,16 @@ struct ArgIdxMinOp { extern "C" { -__device__ inline argidx_fp32_i64 cinn_block_reduce_max(const argidx_fp32_i64 value, argidx_fp32_i64 *shm, bool return_warp = false) { +__device__ inline argidx_fp32_i64 cinn_block_reduce_max(const argidx_fp32_i64 value, argidx_fp32_i64 *shm, bool return_warp = false) { return cinn_block_reduce_shm_impl(value, shm, ArgIdxMaxOp()); } -__device__ inline argidx_fp32_i64 cinn_block_reduce_min(const argidx_fp32_i64 value, argidx_fp32_i64 *shm, bool return_warp = false) { +__device__ inline argidx_fp32_i64 cinn_block_reduce_min(const argidx_fp32_i64 value, argidx_fp32_i64 *shm, bool return_warp = false) { return cinn_block_reduce_shm_impl(value, shm, ArgIdxMinOp()); } -__device__ inline argidx_fp32_i64 cinn_block_reduce_min_argidx_fp32_i64(const argidx_fp32_i64 value, argidx_fp32_i64 *shm, bool return_warp = false) { - return cinn_block_reduce_min(value, shm, return_warp); +__device__ inline argidx_fp32_i64 cinn_block_reduce_min_argidx_fp32_i64(const argidx_fp32_i64 value, argidx_fp32_i64 *shm, bool return_warp = false) { + return cinn_block_reduce_min(value, shm, return_warp); } __device__ inline argidx_fp32_i64 cinn_block_reduce_max_argidx_fp32_i64(const argidx_fp32_i64 value, argidx_fp32_i64 *shm, bool return_warp = false) { @@ -1161,10 +1177,9 @@ __device__ inline argidx_fp32_i32 cinn_block_reduce_max_argidx_fp32_i32(const ar return cinn_block_reduce_max(value, shm, return_warp); } -} // extern "C" +} // extern "C" )MACA_SOURCE"; - // ============================================================ // 2. 
Interface Implementation // ============================================================ @@ -1172,86 +1187,93 @@ __device__ inline argidx_fp32_i32 cinn_block_reduce_max_argidx_fp32_i32(const ar // Global atomic counter to ensure unique filenames static std::atomic g_compile_counter{0}; -const char* MetaxGetRuntimeSource(void* dev_ptr) { - return kMacaRuntimeSource; -} - -C_Status MetaxCompile(void* dev_ptr, const char* code, char* out_path, size_t len) { - // 0. Generate unique filename - // Use PID + atomic counter to generate unique filenames, - // completely resolving filename collisions during concurrent compilation - uint64_t file_id = g_compile_counter.fetch_add(1); - std::string file_prefix = "cinn_metax_" + std::to_string(getpid()) + "_" + std::to_string(file_id); - - // Generate temporary file paths - std::string src_path = "/tmp/" + file_prefix + ".cu"; - std::string obj_path = "/tmp/" + file_prefix + ".co"; - - // 1. Write source code - { - // Open in truncate mode; although the filename is unique, this is a safety measure - std::ofstream src_file(src_path, std::ios::trunc); - if (!src_file.is_open()) { - std::cerr << "[MetaX] Failed to open temp file: " << src_path << std::endl; - return C_Status::C_FAILED; - } - src_file << kMacaRuntimeSource << "\n"; - src_file << code; - src_file.close(); +const char* MetaxGetRuntimeSource(void* dev_ptr) { return kMacaRuntimeSource; } + +C_Status MetaxCompile(void* dev_ptr, + const char* code, + char* out_path, + size_t len) { + // 0. Generate unique filename + // Use PID + atomic counter to generate unique filenames, + // completely resolving filename collisions during concurrent compilation + uint64_t file_id = g_compile_counter.fetch_add(1); + std::string file_prefix = + "cinn_metax_" + std::to_string(getpid()) + "_" + std::to_string(file_id); + + // Generate temporary file paths + std::string src_path = "/tmp/" + file_prefix + ".cu"; + std::string obj_path = "/tmp/" + file_prefix + ".co"; + + // 1. 
Write source code + { + // Open in truncate mode; although the filename is unique, this is a safety + // measure + std::ofstream src_file(src_path, std::ios::trunc); + if (!src_file.is_open()) { + std::cerr << "[MetaX] Failed to open temp file: " << src_path + << std::endl; + return C_Status::C_FAILED; } + src_file << kMacaRuntimeSource << "\n"; + src_file << code; + src_file.close(); + } - // 2. Resolve compiler binary path - const char* maca_path_env = std::getenv("MACA_PATH"); - std::string maca_path = maca_path_env ? std::string(maca_path_env) : "/opt/maca"; - - std::string mxcc_cmd = maca_path + "/mxgpu_llvm/bin/mxcc"; - if (access(mxcc_cmd.c_str(), X_OK) != 0) { - mxcc_cmd = maca_path + "/bin/mxcc"; - if (access(mxcc_cmd.c_str(), X_OK) != 0) mxcc_cmd = "mxcc"; - } + // 2. Resolve compiler binary path + const char* maca_path_env = std::getenv("MACA_PATH"); + std::string maca_path = + maca_path_env ? std::string(maca_path_env) : "/opt/maca"; - // 3. Build compilation command - std::string cmd = mxcc_cmd + " -O3 -std=c++17 -w --fatbin --offload-arch=native -fvisibility=default"; - cmd += " -I" + maca_path + "/include"; - cmd += " -I" + maca_path + "/tools/cu-bridge/include"; - cmd += " -o " + obj_path; - cmd += " " + src_path; - - // 4. Execute compilation - std::cout << "Command: " << cmd << std::endl; - int ret = std::system(cmd.c_str()); - if (ret != 0) { - std::cerr << "[MetaX] JIT Compilation Failed! Code: " << ret << std::endl; - std::cerr << "Command: " << cmd << std::endl; - return C_Status::C_FAILED; - } + std::string mxcc_cmd = maca_path + "/mxgpu_llvm/bin/mxcc"; + if (access(mxcc_cmd.c_str(), X_OK) != 0) { + mxcc_cmd = maca_path + "/bin/mxcc"; + if (access(mxcc_cmd.c_str(), X_OK) != 0) mxcc_cmd = "mxcc"; + } - // 5. Verify output file exists - if (access(obj_path.c_str(), F_OK) != 0) { - std::cerr << "[MetaX] Output file missing: " << obj_path << std::endl; - return C_Status::C_FAILED; - } + // 3. 
Build compilation command + std::string cmd = + mxcc_cmd + + " -O3 -std=c++17 -w --fatbin --offload-arch=native -fvisibility=default"; + cmd += " -I" + maca_path + "/include"; + cmd += " -I" + maca_path + "/tools/cu-bridge/include"; + cmd += " -o " + obj_path; + cmd += " " + src_path; + + // 4. Execute compilation + std::cout << "Command: " << cmd << std::endl; + int ret = std::system(cmd.c_str()); + if (ret != 0) { + std::cerr << "[MetaX] JIT Compilation Failed! Code: " << ret << std::endl; + std::cerr << "Command: " << cmd << std::endl; + return C_Status::C_FAILED; + } - // ================================================================= - // 6. Write back the generated binary path to the CINN framework - // ================================================================= - if (out_path && len > 0) { - // Use strncpy for safe copy - std::strncpy(out_path, obj_path.c_str(), len - 1); - out_path[len - 1] = '\0'; // Ensure null-termination - // Print debug info to confirm write-back succeeded - std::cout << "[MetaX Success] Compiled: " << out_path << std::endl; - } else { - std::cerr << "[MetaX Error] Invalid out_path buffer!" << std::endl; - return C_Status::C_FAILED; - } + // 5. Verify output file exists + if (access(obj_path.c_str(), F_OK) != 0) { + std::cerr << "[MetaX] Output file missing: " << obj_path << std::endl; + return C_Status::C_FAILED; + } + + // ================================================================= + // 6. Write back the generated binary path to the CINN framework + // ================================================================= + if (out_path && len > 0) { + // Use strncpy for safe copy + std::strncpy(out_path, obj_path.c_str(), len - 1); + out_path[len - 1] = '\0'; // Ensure null-termination + // Print debug info to confirm write-back succeeded + std::cout << "[MetaX Success] Compiled: " << out_path << std::endl; + } else { + std::cerr << "[MetaX Error] Invalid out_path buffer!" 
<< std::endl; + return C_Status::C_FAILED; + } - // 7. Clean up source file (enable after debugging is complete) - std::remove(src_path.c_str()); + // 7. Clean up source file (enable after debugging is complete) + std::remove(src_path.c_str()); - return C_Status::C_SUCCESS; + return C_Status::C_SUCCESS; } -} // namespace metax -} // namespace custom_device -} // namespace paddle \ No newline at end of file +} // namespace metax +} // namespace custom_device +} // namespace paddle diff --git a/backends/metax_gpu/cinn/passes/pass_manager.cc b/backends/metax_gpu/cinn/passes/pass_manager.cc index 0caa6033d9..15d73d0738 100644 --- a/backends/metax_gpu/cinn/passes/pass_manager.cc +++ b/backends/metax_gpu/cinn/passes/pass_manager.cc @@ -12,9 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/backends/device_ext.h" #include +#include "paddle/phi/backends/device_ext.h" + namespace paddle { namespace custom_device { namespace metax { @@ -22,10 +23,10 @@ namespace metax { // Applies custom graph optimization passes. // Currently a no-op stub; returns success immediately. C_Status MetaxApplyCustomPass(void* dev_ptr, void* ir_module) { - // VLOG(3) << "[MetaX] MetaxApplyCustomPass called (No-op)"; - return C_Status::C_SUCCESS; + // VLOG(3) << "[MetaX] MetaxApplyCustomPass called (No-op)"; + return C_Status::C_SUCCESS; } -} // namespace metax -} // namespace custom_device -} // namespace paddle +} // namespace metax +} // namespace custom_device +} // namespace paddle diff --git a/backends/metax_gpu/cinn/runtime/cinn_runtime.cc b/backends/metax_gpu/cinn/runtime/cinn_runtime.cc index ea12b1e359..7f19db35e4 100644 --- a/backends/metax_gpu/cinn/runtime/cinn_runtime.cc +++ b/backends/metax_gpu/cinn/runtime/cinn_runtime.cc @@ -12,11 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/phi/backends/device_ext.h" #include + #include -#include #include +#include + +#include "paddle/phi/backends/device_ext.h" namespace paddle { namespace custom_device { @@ -24,47 +26,62 @@ namespace metax { // Load module: equivalent to cuModuleLoad C_Status MetaxModuleLoad(void* dev_ptr, const char* path, void** mod_out) { - CUmodule module; - CUresult err = cuModuleLoad(&module, path); - if (err != CUDA_SUCCESS) return C_Status::C_FAILED; + CUmodule module; + CUresult err = cuModuleLoad(&module, path); + if (err != CUDA_SUCCESS) return C_Status::C_FAILED; - *mod_out = (void*)module; - return C_Status::C_SUCCESS; + *mod_out = reinterpret_cast(module); + return C_Status::C_SUCCESS; } // Unload module C_Status MetaxModuleUnload(void* dev_ptr, void* module_handle) { - cuModuleUnload((CUmodule)module_handle); - return C_Status::C_SUCCESS; + cuModuleUnload((CUmodule)module_handle); + return C_Status::C_SUCCESS; } // Get kernel function address: equivalent to cuModuleGetFunction -C_Status MetaxGetKernelAddress(void* dev_ptr, void* module_handle, const char* func_name, void** func_out) { - CUfunction func; - CUresult err = cuModuleGetFunction(&func, (CUmodule)module_handle, func_name); - if (err != CUDA_SUCCESS) return C_Status::C_FAILED; +C_Status MetaxGetKernelAddress(void* dev_ptr, + void* module_handle, + const char* func_name, + void** func_out) { + CUfunction func; + CUresult err = cuModuleGetFunction(&func, (CUmodule)module_handle, func_name); + if (err != CUDA_SUCCESS) return C_Status::C_FAILED; - *func_out = (void*)func; - return C_Status::C_SUCCESS; + *func_out = reinterpret_cast(func); + return C_Status::C_SUCCESS; } // Launch kernel: equivalent to cuLaunchKernel -C_Status MetaxLaunchKernel(void* dev_ptr, void* func_ptr, void** args, int num_args, - int gx, int gy, int gz, - int bx, int by, int bz, - int shm, void* stream) { - // Note: args is typically a void*[] and may require argument marshaling - CUresult err = 
cuLaunchKernel((CUfunction)func_ptr, - gx, gy, gz, - bx, by, bz, - shm, - (CUstream)stream, - args, - nullptr); - if (err != CUDA_SUCCESS) return C_Status::C_FAILED; - return C_Status::C_SUCCESS; +C_Status MetaxLaunchKernel(void* dev_ptr, + void* func_ptr, + void** args, + int num_args, + int gx, + int gy, + int gz, + int bx, + int by, + int bz, + int shm, + void* stream) { + // Note: args is typically a void*[] and may require argument marshaling + CUresult err = cuLaunchKernel((CUfunction)func_ptr, + gx, + gy, + gz, + bx, + by, + bz, + shm, + (CUstream)stream, + args, + nullptr); + if (err != CUDA_SUCCESS) return C_Status::C_FAILED; + return C_Status::C_SUCCESS; } -} // namespace metax -} // namespace custom_device -} // namespace paddle +} // namespace metax +} // namespace custom_device +} // namespace paddle diff --git a/backends/metax_gpu/runtime/runtime.cc b/backends/metax_gpu/runtime/runtime.cc index 0500a5a332..800bd2a6f3 100644 --- a/backends/metax_gpu/runtime/runtime.cc +++ b/backends/metax_gpu/runtime/runtime.cc @@ -35,6 +35,7 @@ #include #include +#include "../cinn/cinn_interface.h" #include "glog/logging.h" #include "paddle/fluid/platform/profiler/cuda_tracer.h" #include "paddle/fluid/platform/profiler/cupti_data_process.h" @@ -53,10 +54,8 @@ #include "paddle/phi/core/platform/device/gpu/gpu_info.h" #include "paddle/phi/core/platform/profiler/utils.cc" //NOLINT #include "paddle/phi/core/platform/profiler/utils.h" -#include "paddle/phi/backends/device_ext.h" #include "passes/pattern_passes.h" #include "runtime/process_cupti_data.cc" //NOLINT -#include "../cinn/cinn_interface.h" #include "unsupported/Eigen/CXX11/Tensor" #define MEMORY_FRACTION 0.5f @@ -69,10 +68,10 @@ const char *const SubDeviceType = "v0.1"; namespace paddle { namespace custom_device { namespace metax { - void InitCinnInterface(C_DeviceInterface* interface); -} -} +void InitCinnInterface(C_DeviceInterface *interface); } +} // namespace custom_device +} // namespace paddle #endif 
namespace phi { @@ -406,7 +405,7 @@ C_Status GetMaxThreadsPerBlock(const C_Device device, } C_Status GetMaxSharedMemPerBlock(const C_Device device, - size_t *shared_mem_per_block) { + size_t *shared_mem_per_block) { int id = device->id; int count = 0; cudaError_t status = @@ -415,28 +414,26 @@ C_Status GetMaxSharedMemPerBlock(const C_Device device, return C_SUCCESS; } -C_Status GetWarpSize(const C_Device device, - size_t *warp_size) { +C_Status GetWarpSize(const C_Device device, size_t *warp_size) { int id = device->id; int size = 0; - cudaError_t status = - cudaDeviceGetAttribute(&size, cudaDevAttrWarpSize, id); + cudaError_t status = cudaDeviceGetAttribute(&size, cudaDevAttrWarpSize, id); *warp_size = size; return C_SUCCESS; } C_Status GetMaxRegistersPerMultiProcessor(const C_Device device, - size_t *registers_per_mp) { + size_t *registers_per_mp) { int id = device->id; int count = 0; - cudaError_t status = - cudaDeviceGetAttribute(&count, cudaDevAttrMaxRegistersPerMultiprocessor, id); + cudaError_t status = cudaDeviceGetAttribute( + &count, cudaDevAttrMaxRegistersPerMultiprocessor, id); *registers_per_mp = count; return C_SUCCESS; } C_Status GetPreferredVectorWidth(const C_Device device, - size_t *vector_alignment) { + size_t *vector_alignment) { int id = device->id; // int count = 0; // cudaError_t status = @@ -445,9 +442,9 @@ C_Status GetPreferredVectorWidth(const C_Device device, *vector_alignment = 128; return C_SUCCESS; } - + C_Status GetMaxBlocksPerMultiProcessor(const C_Device device, - size_t *blocks_per_mp) { + size_t *blocks_per_mp) { int id = device->id; int count = 0; cudaError_t status = @@ -473,15 +470,18 @@ C_Status GetMaxGridDimSize(const C_Device device, } C_Status GetMaxBlockDimSize(const C_Device device, - std::array *block_dim_size) { + std::array *block_dim_size) { int id = device->id; std::array ret = {}; int size; - auto error_code_x = cudaDeviceGetAttribute(&size, cudaDevAttrMaxBlockDimX, id); + auto error_code_x = + 
cudaDeviceGetAttribute(&size, cudaDevAttrMaxBlockDimX, id); ret[0] = size; - auto error_code_y = cudaDeviceGetAttribute(&size, cudaDevAttrMaxBlockDimY, id); + auto error_code_y = + cudaDeviceGetAttribute(&size, cudaDevAttrMaxBlockDimY, id); ret[1] = size; - auto error_code_z = cudaDeviceGetAttribute(&size, cudaDevAttrMaxBlockDimZ, id); + auto error_code_z = + cudaDeviceGetAttribute(&size, cudaDevAttrMaxBlockDimZ, id); ret[2] = size; *block_dim_size = ret; @@ -1549,7 +1549,8 @@ void InitPlugin(CustomRuntimeParams *params) { params->interface->get_max_shared_mem_per_block = GetMaxSharedMemPerBlock; params->interface->get_max_blocks_per_mp = GetMaxBlocksPerMultiProcessor; params->interface->get_warp_size = GetWarpSize; - params->interface->get_max_registers_per_mp = GetMaxRegistersPerMultiProcessor; + params->interface->get_max_registers_per_mp = + GetMaxRegistersPerMultiProcessor; params->interface->get_vector_width = GetPreferredVectorWidth; params->interface->get_max_grid_dim_size = GetMaxGridDimSize; params->interface->get_max_block_dim_size = GetMaxBlockDimSize; @@ -1636,12 +1637,12 @@ void InitPlugin(CustomRuntimeParams *params) { // PIR pass pipeline params->pir_default_passes = reinterpret_cast( const_cast *>(GetPirMetaxGpuPasses())); - + // CINN interface init #ifdef WITH_CINN if (params->interface) { - paddle::custom_device::metax::InitCinnInterface(params->interface); - LOG(INFO) << "[MetaX] CINN Interface registered successfully."; + paddle::custom_device::metax::InitCinnInterface(params->interface); + LOG(INFO) << "[MetaX] CINN Interface registered successfully."; } #endif } From 1cf361bb40eaff35e04e714654847761a50ad80f Mon Sep 17 00:00:00 2001 From: Yuhan Xu Date: Wed, 1 Apr 2026 08:15:26 +0000 Subject: [PATCH 16/17] Update Paddle. 
--- Paddle | 2 +- backends/metax_gpu/CMakeLists.txt | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/Paddle b/Paddle index 56be465924..638c7c3988 160000 --- a/Paddle +++ b/Paddle @@ -1 +1 @@ -Subproject commit 56be465924264e1251cf127dbff56d17a7554d01 +Subproject commit 638c7c39881e982d607ab266c380cbdab0fc767e diff --git a/backends/metax_gpu/CMakeLists.txt b/backends/metax_gpu/CMakeLists.txt index 93298ab1ee..9b04f32a9e 100755 --- a/backends/metax_gpu/CMakeLists.txt +++ b/backends/metax_gpu/CMakeLists.txt @@ -26,6 +26,20 @@ set(PYTHON_VERSION ${PY_VERSION}) set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake") message(STATUS "CMAKE_MODULE_PATH: ${CMAKE_MODULE_PATH}") + +if(NOT DEFINED PADDLE_WARP_SIZE) + set(PADDLE_WARP_SIZE 32) +endif() +math(EXPR PADDLE_WARP_MASK "${PADDLE_WARP_SIZE} - 1") +if(PADDLE_WARP_SIZE EQUAL 64) + set(PADDLE_WARP_SHIFT 6) +else() + set(PADDLE_WARP_SHIFT 5) +endif() +add_definitions(-DPADDLE_WARP_SIZE=${PADDLE_WARP_SIZE}) +add_definitions(-DPADDLE_WARP_MASK=${PADDLE_WARP_MASK}) +add_definitions(-DPADDLE_WARP_SHIFT=${PADDLE_WARP_SHIFT}) + set(WITH_MKLML ON) if(WITH_ARM) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC") From 75c6b8e58c7e9a2df5e46fbac1d3707d76d5d3ea Mon Sep 17 00:00:00 2001 From: Yuhan Xu Date: Thu, 2 Apr 2026 08:31:01 +0000 Subject: [PATCH 17/17] Fix CMakeLists.txt PADDLE_WARP_SIZE 32->64. Fix argidx_fp32_i32 forward reference error in MetaX runtime. 
--- backends/metax_gpu/CMakeLists.txt | 2 +- backends/metax_gpu/cinn/compiler/compiler.cc | 100 +++++++++++-------- 2 files changed, 61 insertions(+), 41 deletions(-) diff --git a/backends/metax_gpu/CMakeLists.txt b/backends/metax_gpu/CMakeLists.txt index 9b04f32a9e..e02964445b 100755 --- a/backends/metax_gpu/CMakeLists.txt +++ b/backends/metax_gpu/CMakeLists.txt @@ -28,7 +28,7 @@ set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake") message(STATUS "CMAKE_MODULE_PATH: ${CMAKE_MODULE_PATH}") if(NOT DEFINED PADDLE_WARP_SIZE) - set(PADDLE_WARP_SIZE 32) + set(PADDLE_WARP_SIZE 64) endif() math(EXPR PADDLE_WARP_MASK "${PADDLE_WARP_SIZE} - 1") if(PADDLE_WARP_SIZE EQUAL 64) diff --git a/backends/metax_gpu/cinn/compiler/compiler.cc b/backends/metax_gpu/cinn/compiler/compiler.cc index cafc766bf7..b65f73e6e4 100644 --- a/backends/metax_gpu/cinn/compiler/compiler.cc +++ b/backends/metax_gpu/cinn/compiler/compiler.cc @@ -700,6 +700,65 @@ EXPAND_REDUCE_FP64_MACRO(CINN_DISCRETE_REDUCE_MACRO) EXPAND_REDUCE_BOOL_MACRO(CINN_DISCRETE_REDUCE_MACRO) EXPAND_REDUCE_FP16_MACRO(CINN_DISCRETE_REDUCE_MACRO) +// =============================================================== +// ArgMin/ArgMax Support (ArgIdx Structures & Combine Functions) +// Must be defined before discrete/block/grid reduce functions that use them +// =============================================================== + +// arg reduce arg index struct +// Do not define operator<; force dispatch through std::max overloads +#define ARGIDX_STRUCT_MACRO(TYPENAME, DTYPE, ITYPE, IINIT) \ + struct TYPENAME { \ + DTYPE value; \ + ITYPE index; \ + __device__ TYPENAME() {} \ + __device__ explicit TYPENAME(DTYPE value) : value(value), index(IINIT) {} \ + __device__ TYPENAME(DTYPE value, ITYPE index) \ + : value(value), index(index) {} \ + __device__ explicit operator ITYPE() { return index; } \ + /* Assignment operator support */ \ + __device__ inline TYPENAME& operator=(const TYPENAME& other) { \ + value = other.value; \ + index = 
other.index; \ + return *this; \ + } \ + __device__ inline volatile TYPENAME& operator=(const volatile TYPENAME& other) volatile { \ + value = other.value; \ + index = other.index; \ + return *this; \ + } \ + }; + +// Instantiate structs +#ifdef CINN_CUDA_FP16 +ARGIDX_STRUCT_MACRO(argidx_fp16_i64, float16, int64_t, 0LL) +#endif +ARGIDX_STRUCT_MACRO(argidx_fp32_i64, float, int64_t, 0LL) +ARGIDX_STRUCT_MACRO(argidx_fp64_i64, double, int64_t, 0LL) +ARGIDX_STRUCT_MACRO(argidx_i16_i64, int16_t, int64_t, 0LL) +ARGIDX_STRUCT_MACRO(argidx_i32_i64, int, int64_t, 0LL) +ARGIDX_STRUCT_MACRO(argidx_i64_i64, int64_t, int64_t, 0LL) +ARGIDX_STRUCT_MACRO(argidx_u8_i64, uint8_t, int64_t, 0LL) + +ARGIDX_STRUCT_MACRO(argidx_fp32_i32, float, int, 0) +ARGIDX_STRUCT_MACRO(argidx_i32_i32, int, int, 0) + +// cinn_max_argidx / cinn_min_argidx combine functions +// These are called by CINN_DISCRETE_REDUCE_IMPL via cinn_##REDUCE_TYPE token pasting +#define ARGIDX_COMBINE_MACRO(TYPENAME) \ + __device__ TYPENAME cinn_min_##TYPENAME(TYPENAME a, TYPENAME b) { \ + return a.value == b.value ? (a.index < b.index ? a : b) \ + : (a.value < b.value ? a : b); \ + } \ + __device__ TYPENAME cinn_max_##TYPENAME(TYPENAME a, TYPENAME b) { \ + return a.value == b.value ? (a.index < b.index ? a : b) \ + : (a.value > b.value ? a : b); \ + } + +ARGIDX_COMBINE_MACRO(argidx_fp32_i32) +ARGIDX_COMBINE_MACRO(argidx_fp32_i64) +ARGIDX_COMBINE_MACRO(argidx_i32_i32) + // Discrete reduce for argidx types __device__ inline argidx_fp32_i32 cinn_discrete_reduce_max_argidx_fp32_i32( const argidx_fp32_i32 value, argidx_fp32_i32 *shm) { @@ -983,47 +1042,8 @@ __device__ int cinn_custom_device_resize_bicubic(const int *buf, } // extern "C" // =============================================================== -// 8. ArgMin/ArgMax Support (ArgIdx Structures & Shuffles) +// 8. 
ArgMin/ArgMax std::max/min Overloads & Block Reduce // =============================================================== -// --- C++ Scope Start --- - -// arg reduce arg index struct -// Do not define operator<; force dispatch through std::max overloads -#define ARGIDX_STRUCT_MACRO(TYPENAME, DTYPE, ITYPE, IINIT) \ - struct TYPENAME { \ - DTYPE value; \ - ITYPE index; \ - __device__ TYPENAME() {} \ - __device__ explicit TYPENAME(DTYPE value) : value(value), index(IINIT) {} \ - __device__ TYPENAME(DTYPE value, ITYPE index) \ - : value(value), index(index) {} \ - __device__ explicit operator ITYPE() { return index; } \ - /* Assignment operator support */ \ - __device__ inline TYPENAME& operator=(const TYPENAME& other) { \ - value = other.value; \ - index = other.index; \ - return *this; \ - } \ - __device__ inline volatile TYPENAME& operator=(const volatile TYPENAME& other) volatile { \ - value = other.value; \ - index = other.index; \ - return *this; \ - } \ - }; - -// Instantiate structs -#ifdef CINN_CUDA_FP16 -ARGIDX_STRUCT_MACRO(argidx_fp16_i64, float16, int64_t, 0LL) -#endif -ARGIDX_STRUCT_MACRO(argidx_fp32_i64, float, int64_t, 0LL) -ARGIDX_STRUCT_MACRO(argidx_fp64_i64, double, int64_t, 0LL) -ARGIDX_STRUCT_MACRO(argidx_i16_i64, int16_t, int64_t, 0LL) -ARGIDX_STRUCT_MACRO(argidx_i32_i64, int, int64_t, 0LL) -ARGIDX_STRUCT_MACRO(argidx_i64_i64, int64_t, int64_t, 0LL) -ARGIDX_STRUCT_MACRO(argidx_u8_i64, uint8_t, int64_t, 0LL) - -ARGIDX_STRUCT_MACRO(argidx_fp32_i32, float, int, 0) -ARGIDX_STRUCT_MACRO(argidx_i32_i32, int, int, 0) // std::max overloads namespace std {