Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/rocm-wheels-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ jobs:
3rdparty/aotriton \
3rdparty/aiter \
3rdparty/QoLA \
3rdparty/ck_jit \
3rdparty/hipify_torch

- name: Derive Docker image tag
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,6 @@
[submodule "3rdparty/QoLA"]
path = 3rdparty/QoLA
url = https://github.com/ROCm/QoLA.git
[submodule "3rdparty/ck_jit"]
path = 3rdparty/ck_jit
url = https://github.com/ROCm/ck-jit.git
1 change: 1 addition & 0 deletions 3rdparty/ck_jit
Submodule ck_jit added at 62ce69
43 changes: 41 additions & 2 deletions ci/_utils.sh
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,15 @@ start_message() {
python3 --version
}

configure_omp_threads() {
get_cpu_count() {
n_vcpus=$(lscpu | grep "^CPU(s):" | awk '{print $2}')
cpus_per_core=$(lscpu | grep "Thread(s) per core:" | awk '{print $NF}')

n_physical_cores=$((n_vcpus / cpus_per_core))
echo $((n_vcpus / cpus_per_core))
}

configure_omp_threads() {
n_physical_cores=`get_cpu_count`
n_parallel_jobs=$1

if [ -z ${OMP_NUM_THREADS} ]; then
Expand Down Expand Up @@ -270,3 +274,38 @@ pytest_run() {
test $? -eq 0 || test_run_error "[$_test_variant_tag] $1"
echo "Done [$_test_variant_tag] $1 in `time_elapsed $_start_ts`"
}

PYTHON_TE_IMPORT="import sys; sys.path[:] = [p for p in sys.path if p not in ['', '.']]; import transformer_engine"
ck_jit_prebuild() {
_prebuild_list="${TE_PATH}ci/ck_jit_prebuild.txt"
if [ ! -f "$_prebuild_list" ]; then
script_error "ck_jit_prebuild: blob list not found: $_prebuild_list"
return 1
fi
_gpu_arch=$(rocminfo | grep -E "^ *Name: *gfx" | head -1 | sed "s/.*gfx/gfx/;s/ .*//" 2>/dev/null)
if [ -n "$_gpu_arch" ]; then
_arch_arg="--arch $_gpu_arch"
else
script_error "ck_jit_prebuild: GPU architecture not detected, omitting --arch"
_arch_arg=""
fi
_te_install_dir=$(python -c "${PYTHON_TE_IMPORT}; import os; print(os.path.dirname(transformer_engine.__file__))" 2>/dev/null)
if [ -z "$_te_install_dir" ]; then
script_error "ck_jit_prebuild: failed to determine transformer_engine installation directory"
return 1
fi
_prebuild_py="$_te_install_dir/lib/ck_jit/ck_jit_prebuild.py"
if [ ! -f "$_prebuild_py" ]; then
script_error "ck_jit_prebuild: prebuild script not found: $_prebuild_py"
return 1
fi
_cpu_count=$(get_cpu_count)
if [ -n "$_cpu_count" -a "$_cpu_count" != "0" ]; then
_jobs_arg="--jobs $((_cpu_count/2))"
fi
if [ "$1" = "build" ]; then
echo "Building CK JIT cache for arch=${_gpu_arch:-<not detected>}..."
python "$_prebuild_py" build --blob-list "$_prebuild_list" $_arch_arg $_jobs_arg > /dev/null
fi
python "$_prebuild_py" cache | grep Cache
}
520 changes: 520 additions & 0 deletions ci/ck_jit_prebuild.txt

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions ci/jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ install_prerequisites
pip list | egrep "flax|fidle|jax|ml_dtypes|numpy|transformer_e|typing_ext"
#check_test_jobs_requested
#test $? -eq 0 && init_test_jobs `python -c "import jax; print(len([d for d in jax.devices() if 'rocm' in d.client.platform_version]))"`
ck_jit_prebuild build || exit $?

for _fus_attn in auto ck aotriton; do
configure_fused_attn_env $_fus_attn || continue
Expand Down Expand Up @@ -139,4 +140,6 @@ if [ -n "$TEST_JOBS_MODE" -a -n "$TEST_MGPU" ]; then
configure_fused_attn_env $_fus_attn && run_test_config_mgpu
done
fi

ck_jit_prebuild list
return_run_results
16 changes: 16 additions & 0 deletions ci/pytorch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,22 @@ if [ -n "$SINGLE_CONFIG" ]; then
exit $?
fi

check_flash_attn_installed() {
_result=$(python -c "${PYTHON_TE_IMPORT}; from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils; print(FlashAttentionUtils.is_installed)" 2>/dev/null)
if [ "$_result" = "True" ]; then
return 0
else
echo "Flash attention is not installed" >&2
return 1
fi
}

#Master script mode: prepare testing prerequisites first
start_message
install_prerequisites
pip list | egrep "flash|ml_dtypes|numpy|torch|transformer_e|typing_ext"
#check_test_jobs_requested && init_test_jobs `python -c "import torch; print(torch.cuda.device_count())"`
ck_jit_prebuild build || exit $?

for _fus_attn in auto flash ck aotriton unfused; do
configure_fused_attn_env $_fus_attn || continue
Expand All @@ -163,6 +174,10 @@ for _fus_attn in auto flash ck aotriton unfused; do
_DEFAULT_FUSED_ATTN="auto"
fi

if [ $_fus_attn = flash ]; then
check_flash_attn_installed || continue
fi

if [ -n "$TEST_JOBS_MODE" ]; then
test -n "$TEST_SGPU" && run_test_job "$_fus_attn"
else
Expand All @@ -185,4 +200,5 @@ if [ $TEST_LEVEL -ge 3 ]; then
fi
fi

ck_jit_prebuild list
return_run_results
12 changes: 4 additions & 8 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,12 @@ def setup_common_extension() -> CMakeExtension:
os.getenv("MPI_HOME") is not None
), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")

if rocm_build():
cmake_flags.append("-DUSE_ROCM=ON")

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This hunk silently drops the NVTE_AOTRITON_PATH -> -DAOTRITON_PATH=… and NVTE_CK_FUSED_ATTN_PATH -> -DAITER_MHA_PATH=… translations.

Downstream the CMake files now read $ENV{AOTRITON_PATH} and $ENV{AITER_MHA_PATH} directly (see aotriton/CMakeLists.txt:11 and ck_fused_attn/CMakeLists.txt:78), so users have to rename their env vars from NVTE_AOTRITON_PATH / NVTE_CK_FUSED_ATTN_PATH to the bare names. Anyone with existing scripts setting the NVTE_* form will get a silent "ignored" failure — the build will quietly fall back to the default behavior instead of using their prebuilt path.

PR is marked as breaking, but the description doesn't call this rename out. Two reasonable options:

  • keep the NVTE_* aliases here as a thin shim (a couple of lines) so old usage still works; or
  • explicitly document the env-var rename in the PR description / release notes.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this an intentional part of this change? Seems unrelated?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing of NVTE_AOTRITON_PATH is just a reflection of earlier changes in CMakeLists.txt to use ENV{AOTRITON_PATH}, so NVTE_AOTRITON_PATH is actually unused.
It is not directly related but just aligned to corresponding NVTE_CK_FUSED_ATTN_PATH. The removal of latter in this PR has few reasons:

  • Old scheme cached value of AITER_MHA_PATH define and the only way to unset it w/o full reconfiguration and rebuild was manual CMakeCache edit. It always made using of this env variable inconvenient, alternative changing would be always set -DAITER_MHA_PATH ON/OFF depending on env var.
  • The env var was very important and useful before introducing of AITER pre-build cache. It also supposed that whoever use it, knows what do they do. With quick evolving of pre-built AITER format after QoLA introducing it started causing building problems for internal customers who historically used the env var. With CK_JIT it might also require deeper understanding of building architecture of TE MHA libs. So I thought it is more practical to get rid of the old var. New var is introduced for developers, similar to AOTriton.
    It is, however, can be discussed separately, not as a part of this PR

if os.getenv("NVTE_AOTRITON_PATH"):
aotriton_path = Path(os.getenv("NVTE_AOTRITON_PATH"))
cmake_flags.append(f"-DAOTRITON_PATH={aotriton_path}")
cmake_flags.append(f"-DCK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT={os.getenv('NVTE_CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT', 3)}")
if os.getenv("NVTE_CK_FUSED_ATTN_PATH"):
ck_path = Path(os.getenv("NVTE_CK_FUSED_ATTN_PATH"))
cmake_flags.append(f"-DAITER_MHA_PATH={ck_path}")
cmake_flags.append(
f"-DCK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT={os.getenv('NVTE_CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT', '3')}"
)

if int(os.getenv("NVTE_FUSED_ATTN_AOTRITON", "1"))==0 or int(os.getenv("NVTE_FUSED_ATTN", "1"))==0:
cmake_flags.append("-DUSE_FUSED_ATTN_AOTRITON=OFF")
Expand Down
90 changes: 66 additions & 24 deletions transformer_engine/common/ck_fused_attn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ set(CMAKE_CXX_STANDARD 17)
project(ck_fused_attn LANGUAGES HIP CXX)


set(AITER_MHA_INSTALL_PREFIX "transformer_engine" CACHE STRING "aiter mha shared lib install prefix in TE")
set(AITER_MHA_INSTALL_DIR "${CMAKE_INSTALL_PREFIX}/transformer_engine/lib")

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: this replaces the previous AITER_MHA_INSTALL_PREFIX CACHE STRING with a hardcoded path. The old form let a downstream consumer override the install layout with -DAITER_MHA_INSTALL_PREFIX=…; the new form bakes transformer_engine/lib in. Probably no one was actually using the override, but if you don't need configurability the CACHE STRING doc-string was worth keeping for grep-ability. Not blocking.


#Corresponding runtime check is in nvte_get_fused_attn_backend()
list(FIND CMAKE_HIP_ARCHITECTURES "gfx1250" _gfx1250_idx)
Expand Down Expand Up @@ -67,22 +67,56 @@ else()
message(WARNING "Python interpreter not found; skipping AITER API validation.")
endif()

if(DEFINED AITER_MHA_PATH)
message(STATUS "[AITER-BUILD] Using AITER_MHA_PATH=${AITER_MHA_PATH}")
# use pre-built te_libmha_fwd.so te_libmha_bwd.so
set(__AITER_MHA_PATH ${AITER_MHA_PATH})
set(__AITER_CACHE_DIR "")
set(__AITER_MHA_PATH "")
set(__QOLA_INCLUDE_DIR "")
if(NOT "$ENV{NVTE_CK_JIT}" STREQUAL "0")
set(__USE_CK_JIT TRUE)
else()
set(__AITER_MHA_PATH "")
set(__USE_CK_JIT FALSE)
endif()
if(DEFINED ENV{AITER_MHA_PATH})
message(STATUS "[AITER-BUILD] Using AITER_MHA_PATH=$ENV{AITER_MHA_PATH}")
# use pre-built libraries and includes from a location specified by the user
set(__AITER_CACHE_DIR $ENV{AITER_MHA_PATH})
elseif(NOT __USE_CK_JIT) #disable for CK_JIT for now
# use pre-built cache
include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake")
get_prebuilt_aiter(__AITER_MHA_PATH)
get_prebuilt_aiter(__AITER_CACHE_DIR)
elseif(DEFINED ENV{NVTE_AITER_PREBUILT_BASE_URL})
message(WARNING "[AITER-BUILD] NVTE_AITER_PREBUILT_BASE_URL is set but will be ignored because CK_JIT is enabled.")
endif()
Comment on lines +73 to +88

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two concerns about the default-on behavior:

  1. Default-on with a fail-closed signal. __USE_CK_JIT is TRUE whenever NVTE_CK_JIT is anything except literal "0" — unset, empty, "off", "false", etc. all enable CK_JIT. If you want default-on, fine, but the check NOT "$ENV{NVTE_CK_JIT}" STREQUAL "0" makes "opt out" only work for NVTE_CK_JIT=0. A clearer pattern is to use a CMake option() plus an env-var override so the choice surfaces in the configure summary and is documented.

  2. Default-on skips the prebuilt-cache path entirely. The elseif(NOT __USE_CK_JIT) on line 82 means that with the default settings (no env vars set) we never call get_prebuilt_aiter(), so the existing on-disk cache and NVTE_AITER_PREBUILT_BASE_URL download path are bypassed and we always rebuild via CK_JIT. That's a non-trivial CI/dev-loop regression vs. the previous default. If CK_JIT is meant to also honor that cache, this branch should still call into aiter_prebuilt.cmake first; if not, the PR description should call this out as an intentional behavior change. Also: NVTE_CK_JIT / NVTE_CK_JIT_DIR aren't documented anywhere — please add a short note (README or env-var table) since this is the new default.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto on point 2, I think we should still use the pre-built cache if available as first priority.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using of pre-built AITER cache was disabled for now because the idea of CK_JIT is reduce building time so the cache is not needed. While in general we may still have benefits of cache, enabling of it with CK_JIT requires at minimum uploading of CK_JIT built AITER libs and to avoid conflict with other (non CK_JIT) branches different AITER commit is needed for CK_JIT. So till this logistic task is completed and for wider testing, CK_JIT was made mutually exclusive with AITER pre-built cache.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant more so the idea that if there's a pre-built cache already, we don't need to use CK JIT at all -- not that we need to cache CK JIT. If we explicitly use CK JIT, then absolutely no need to worry about the prebuilt cache, but it's the case where we don't specify build strategy that I'm concerned about. In that case, I think we still need to be able to use a pre-built cache by default if available. What do you think?


if(__AITER_MHA_PATH STREQUAL "")
# If not available, fallback: Build from source via QoLA
list(JOIN CMAKE_HIP_ARCHITECTURES ";" GPU_ARCHS_STR)
message(STATUS "[AITER-BUILD] Building AITER kernels for ${GPU_ARCHS_STR} via QoLA.")
set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA")
if(__AITER_CACHE_DIR STREQUAL "")
# If not available or not requested, build from source via QoLA
list(JOIN CMAKE_HIP_ARCHITECTURES ";" GPU_ARCHS_STR)
message(STATUS "[AITER-BUILD] Building AITER kernels for ${GPU_ARCHS_STR} via QoLA.")
set(__QOLA_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/QoLA")
set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml")
if(__USE_CK_JIT)
message(STATUS "[AITER-BUILD] CK_JIT is enabled; will build CK kernels via CK_JIT.")
set(__CK_JIT_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/ck_jit")
set(__QOLA_BUILD_DIR "${__CK_JIT_BUILD_DIR}/qola") #Need it under ck_jit to clean on full build
if(DEFINED ENV{NVTE_CK_JIT_DIR})
set(__CK_JIT_SOURCE_DIR $ENV{NVTE_CK_JIT_DIR})
else()
set(__CK_JIT_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/ck_jit")
endif()
execute_process(
COMMAND ${Python_EXECUTABLE} "${__CK_JIT_SOURCE_DIR}/ck_jit_build.py" full
--with-qola
--qola-dir ${__QOLA_DIR}
--qola-manifest ${__QOLA_MANIFEST}
--qola-output "${__QOLA_BUILD_DIR}"
--gpu-archs "${GPU_ARCHS_STR}"
--aiter-dir ${__AITER_SOURCE_DIR}
--tmp-dir "${__CK_JIT_BUILD_DIR}"
--install-dir ${AITER_MHA_INSTALL_DIR}
--jit-name "te_ck_jit"
RESULT_VARIABLE QOLA_BUILD_RESULT
)
else()
set(__QOLA_BUILD_DIR "${__QOLA_DIR}/build")
set(__QOLA_MANIFEST "${CMAKE_CURRENT_LIST_DIR}/qola_manifest.toml")
execute_process(
COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${__QOLA_DIR}:$ENV{PYTHONPATH}"
${Python_EXECUTABLE} -m qola.cli build
Expand All @@ -92,22 +126,29 @@ else()
--arch "${GPU_ARCHS_STR}"
RESULT_VARIABLE QOLA_BUILD_RESULT
)
if(NOT QOLA_BUILD_RESULT EQUAL 0)
message(FATAL_ERROR "[AITER-BUILD] QoLA build failed.")
endif()
endif()
if(NOT QOLA_BUILD_RESULT EQUAL 0)
message(FATAL_ERROR "[AITER-BUILD] QoLA build failed.")
endif()

if(__USE_CK_JIT)
set(__AITER_MHA_PATH ${AITER_MHA_INSTALL_DIR})
set(__QOLA_INCLUDE_DIR "${__QOLA_BUILD_DIR}/include")
else()
# Copy the final .so libs and exported public headers into the aiter
# prebuilt cache so downstream consumers see a self-contained tree.
get_default_aiter_cache_dir(__QOLA_CACHE_DIR)
set(__QOLA_CACHE_LIB "${__QOLA_CACHE_DIR}/lib")
get_default_aiter_cache_dir(__AITER_CACHE_DIR)
set(__QOLA_CACHE_LIB "${__AITER_CACHE_DIR}/lib")
file(MAKE_DIRECTORY ${__QOLA_CACHE_LIB})
file(GLOB __QOLA_BUILT_LIBS "${__QOLA_BUILD_DIR}/lib/*.so")
file(COPY ${__QOLA_BUILT_LIBS} DESTINATION ${__QOLA_CACHE_LIB})
file(COPY "${__QOLA_BUILD_DIR}/include" DESTINATION "${__QOLA_CACHE_DIR}")
file(COPY "${__QOLA_BUILD_DIR}/include" DESTINATION "${__AITER_CACHE_DIR}")
set(__AITER_MHA_PATH "${__QOLA_CACHE_LIB}")
else()
message(STATUS "[AITER-BUILD] Using pre-built AITER from ${__AITER_MHA_PATH}")
set(__QOLA_INCLUDE_DIR "${__AITER_CACHE_DIR}/include")
endif()
else()
set(__AITER_MHA_PATH "${__AITER_CACHE_DIR}/lib")
set(__QOLA_INCLUDE_DIR "${__AITER_CACHE_DIR}/include")
endif()

set(ck_fused_attn_SOURCES)
Expand All @@ -129,7 +170,6 @@ list(APPEND CK_FUSED_ATTN_COMPILE_OPTIONS
# Public QoLA headers ship alongside the .so libs in ${__AITER_MHA_PATH}/../include
# (emitted by qola.cli build, or copied from the QoLA build dir above for the
# source-build path).
set(__QOLA_INCLUDE_DIR "${__AITER_MHA_PATH}/../include")
if(NOT EXISTS "${__QOLA_INCLUDE_DIR}/qola_config.h")
message(FATAL_ERROR "Could not find QoLA public headers at ${__QOLA_INCLUDE_DIR}.")
endif()
Expand All @@ -146,5 +186,7 @@ target_link_libraries(ck_fused_attn PUBLIC ${ck_fused_attn_LINKER_LIBS})
target_compile_options(ck_fused_attn PRIVATE ${CK_FUSED_ATTN_COMPILE_OPTIONS})
set_target_properties(ck_fused_attn PROPERTIES INSTALL_RPATH "$ORIGIN")

install(FILES ${__AITER_MHA_PATH}/te_libmha_fwd.so ${__AITER_MHA_PATH}/te_libmha_bwd.so DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib)
install(TARGETS ck_fused_attn DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib)
if (NOT "${__AITER_MHA_PATH}" STREQUAL "${AITER_MHA_INSTALL_DIR}")
install(FILES ${__AITER_MHA_PATH}/te_libmha_fwd.so ${__AITER_MHA_PATH}/te_libmha_bwd.so DESTINATION ${AITER_MHA_INSTALL_DIR})
endif()
install(TARGETS ck_fused_attn DESTINATION ${AITER_MHA_INSTALL_DIR})
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ function(get_prebuilt_aiter PREBUILT_DIR_VAR)
is_aiter_cache_valid("${ROCM_VER_PARAM}" RESULT)
if(RESULT)
get_aiter_cache_key("${ROCM_VER_PARAM}" _UNUSED CACHE_DIR)
set(${PREBUILT_DIR_VAR} "${CACHE_DIR}/lib" PARENT_SCOPE)
set(${PREBUILT_DIR_VAR} "${CACHE_DIR}" PARENT_SCOPE)
return()
endif()
endforeach()
Expand All @@ -62,7 +62,7 @@ function(get_prebuilt_aiter PREBUILT_DIR_VAR)
download_aiter_prebuilt("${ROCM_VER_PARAM}" RESULT)
if(RESULT)
get_aiter_cache_key("${ROCM_VER_PARAM}" _UNUSED CACHE_DIR)
set(${PREBUILT_DIR_VAR} "${CACHE_DIR}/lib" PARENT_SCOPE)
set(${PREBUILT_DIR_VAR} "${CACHE_DIR}" PARENT_SCOPE)
return()
endif()
endforeach()
Expand Down
Loading