Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Paddle
Submodule Paddle updated 1296 files
33 changes: 33 additions & 0 deletions backends/metax_gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,29 @@ project(${PROJ_NAME} CXX C CUDA)

set(TARGET_NAME ${PROJ_NAME})

option(WITH_CINN "Compile with CINN support" ON)

find_package(Python3 REQUIRED COMPONENTS Interpreter)
set(PY_VERSION ${Python3_VERSION_MAJOR}.${Python3_VERSION_MINOR})
message(STATUS "Python version detected: ${PY_VERSION}")
set(PYTHON_VERSION ${PY_VERSION})

set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake")
message(STATUS "CMAKE_MODULE_PATH: ${CMAKE_MODULE_PATH}")

if(NOT DEFINED PADDLE_WARP_SIZE)
set(PADDLE_WARP_SIZE 64)
endif()
math(EXPR PADDLE_WARP_MASK "${PADDLE_WARP_SIZE} - 1")
if(PADDLE_WARP_SIZE EQUAL 64)
set(PADDLE_WARP_SHIFT 6)
else()
set(PADDLE_WARP_SHIFT 5)
endif()
add_definitions(-DPADDLE_WARP_SIZE=${PADDLE_WARP_SIZE})
add_definitions(-DPADDLE_WARP_MASK=${PADDLE_WARP_MASK})
add_definitions(-DPADDLE_WARP_SHIFT=${PADDLE_WARP_SHIFT})

set(WITH_MKLML ON)
if(WITH_ARM)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC")
Expand All @@ -40,6 +56,13 @@ if(WITH_ARM)
add_definitions(-DPADDLE_WITH_ARM)
endif()
include(paddle)

if(WITH_CINN)
message(STATUS "[MetaX] CINN enabled, adding subdirectory: cinn")
add_definitions(-DWITH_CINN)
add_subdirectory(cinn)
endif()

set(THIRD_PARTY_PATH
"${PADDLE_SOURCE_DIR}/build/third_party"
CACHE PATH "Third party libraries directory.")
Expand Down Expand Up @@ -792,6 +815,11 @@ set(CMAKE_CUCC_FLAGS "-I ${MACA_PATH}/tools/cu-bridge/include/")

add_library(${TARGET_NAME} SHARED ${CUSTOM_DEVICE_SRCS})

if(WITH_CINN)
target_include_directories(${TARGET_NAME}
PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/cinn")
endif()

target_include_directories(
${TARGET_NAME}
PRIVATE ${PADDLE_SOURCE_DIR}
Expand Down Expand Up @@ -821,6 +849,11 @@ target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmccl.so)
target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmcFlashAttn.so)
target_link_libraries(${TARGET_NAME} ${MACA_PATH}/lib/libmcpti.so)

if(WITH_CINN)
message(STATUS "[MetaX] Linking CINN object library")
target_link_libraries(${TARGET_NAME} $<TARGET_OBJECTS:metax_cinn_obj>)
endif()

include_directories(BEFORE ${PADDLE_SOURCE_DIR})
include_directories(BEFORE ${CMAKE_SOURCE_DIR}/headers)

Expand Down
48 changes: 48 additions & 0 deletions backends/metax_gpu/cinn/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# =============================================================================
# CINN Plugin for MetaX (MACA) Backend
# =============================================================================

# 1. Locate MACA SDK path To allow #include <maca_runtime.h> in
# runtime/cinn_runtime.cc or compiler.cc, we need to add the MetaX SDK header
# search path.
set(MACA_PATH $ENV{MACA_PATH})
if(NOT MACA_PATH)
set(MACA_PATH "/opt/maca") # Default fallback path
message(STATUS "[MetaX CINN] MACA_PATH not set, using default: ${MACA_PATH}")
else()
message(STATUS "[MetaX CINN] Found MACA_PATH: ${MACA_PATH}")
endif()

# 1. Define source file list All .cc files involved in the CINN implementation
# must be included here.
set(CINN_SRCS
cinn_interface.cc # Main entry point, responsible for InitCinnInterface
compiler/compiler.cc # Implements MetaxCompile and MetaxGetRuntimeSource
runtime/cinn_runtime.cc # Implements MetaxModuleLoad, MetaxLaunchKernel
passes/pass_manager.cc # Implements MetaxApplyCustomPass
)

# 1. Create OBJECT library Use OBJECT mode to compile into .o files only, without
# generating .a or .so. This allows the parent CMake to directly collect these
# .o files and link them into the final plugin.so.
add_library(metax_cinn_obj OBJECT ${CINN_SRCS})

# 1. Configure header search paths
target_include_directories(
metax_cinn_obj
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} # Allow referencing headers in current
# directory (cinn_interface.h)
${CMAKE_CURRENT_SOURCE_DIR}/../ # Allow referencing parent-level
# headers (e.g., common/)
${MACA_PATH}/include # Allow referencing <maca_runtime.h>
${PADDLE_SOURCE_DIR} # Allow referencing paddle/phi/... headers
# Paddle header paths are typically auto-included via the external
# environment (Paddle_DIR)
)

# 1. Compiler options The CINN component typically requires C++17 standard
set_property(TARGET metax_cinn_obj PROPERTY CXX_STANDARD 17)

# Enable PIC (Position Independent Code) Required because these .o files will
# ultimately be linked into a shared library
set_property(TARGET metax_cinn_obj PROPERTY POSITION_INDEPENDENT_CODE ON)
117 changes: 117 additions & 0 deletions backends/metax_gpu/cinn/cinn_interface.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/cinn_interface.h"

#include <cstring> // For memset
#include <iostream>

namespace paddle {
namespace custom_device {
namespace metax {

// ============================================================
// External Function Declarations
// These functions must be implemented in the corresponding subdirectory files
// (.cc).
// ============================================================

// --- From compiler/compiler.cc ---
// Invokes the mxcc toolchain to compile CINN-generated source code into a
// binary
extern C_Status MetaxCompile(void* dev_ptr,
const char* code,
char* out_path,
size_t len);

// Provides the MetaX GPU device runtime source code
extern const char* MetaxGetRuntimeSource(void* dev_ptr);

// --- From runtime/cinn_runtime.cc ---
// Loads a compiled binary module (.mx / .so)
extern C_Status MetaxModuleLoad(void* dev_ptr,
const char* path,
void** mod_out);

// Unloads a module
extern C_Status MetaxModuleUnload(void* dev_ptr, void* module_handle);

// Retrieves the kernel function address from a loaded module
extern C_Status MetaxGetKernelAddress(void* dev_ptr,
void* module_handle,
const char* func_name,
void** func_out);

// Launches a kernel function
extern C_Status MetaxLaunchKernel(void* dev_ptr,
void* func_ptr,
void** args,
int num_args,
int gx,
int gy,
int gz,
int bx,
int by,
int bz,
int shm,
void* stream);

// --- From passes/pass_manager.cc ---
// Applies custom graph optimization passes
extern C_Status MetaxApplyCustomPass(void* dev_ptr, void* ir_module);

// ============================================================
// Interface Initialization
// ============================================================

// Static instance, valid throughout the plugin lifetime
static C_CinnInterface metax_cinn_impl;

void InitCinnInterface(C_DeviceInterface* device_interface) {
// 1. Zero-initialize for safety
std::memset(&metax_cinn_impl, 0, sizeof(C_CinnInterface));

// 2. Set struct size (used for version validation)
metax_cinn_impl.size = sizeof(C_CinnInterface);

// 3. Set context pointer (optional)
// Point to a global state struct if your implementation needs one; otherwise
// nullptr
metax_cinn_impl.dev_ptr = nullptr;

// 4. Register Compiler Toolchain interface
metax_cinn_impl.compile = MetaxCompile;
metax_cinn_impl.get_runtime_source = MetaxGetRuntimeSource;

// 5. Register Runtime Strategy interface
metax_cinn_impl.module_load = MetaxModuleLoad;
metax_cinn_impl.module_unload = MetaxModuleUnload;
metax_cinn_impl.get_kernel_address = MetaxGetKernelAddress;
metax_cinn_impl.launch_kernel = MetaxLaunchKernel;

// 6. Register Compilation Strategy interface
metax_cinn_impl.apply_custom_pass = MetaxApplyCustomPass;

// 7. Attach the populated dispatch table to the Paddle device interface
if (device_interface) {
device_interface->cinn_interface = &metax_cinn_impl;
} else {
std::cerr << "[MetaX] Error: device_interface is null during CINN init."
<< std::endl;
}
}

} // namespace metax
} // namespace custom_device
} // namespace paddle
38 changes: 38 additions & 0 deletions backends/metax_gpu/cinn/cinn_interface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

// Include the Paddle-defined C interface structures
#include "paddle/phi/backends/device_ext.h"

namespace paddle {
namespace custom_device {
namespace metax {

/**
* @brief Initialize the CINN interface.
*
* This function is called by InitPlugin in runtime.cc.
* It populates device_interface->cinn_interface with the compiler
* and runtime function pointers implemented under metax_gpu/cinn.
*
* @param device_interface The device interface pointer passed from the Paddle
* host side.
*/
void InitCinnInterface(C_DeviceInterface* device_interface);

} // namespace metax
} // namespace custom_device
} // namespace paddle
Loading