Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
8f48034
Set local_devices from CUDA_VISIBLE_DEVICES when initializing distrib…
gabeweisz Dec 9, 2025
3473e4b
adding rocm_jax_0.7.1 reqs
gulsumgudukbay Feb 3, 2026
7093ef8
Revert "removing CI workflows for now to upstream decoupling changes"
gulsumgudukbay Feb 3, 2026
27c577a
Merge branch 'AI-Hypercomputer:main' into rocm-main
gulsumgudukbay Feb 5, 2026
82f9c1d
skip ring attention test on ROCm
gulsumgudukbay Feb 9, 2026
267f80c
[DOWNSTREAM-ONLY] update schedule for build_and_test_maxtext
gulsumgudukbay Feb 11, 2026
ce49fbb
Merge branch 'AI-Hypercomputer:main' into rocm-main
gulsumgudukbay Feb 11, 2026
fe391d7
Merge branch 'AI-Hypercomputer:main' into rocm-main
gulsumgudukbay Feb 11, 2026
599c824
adding jax 0.8.2 requirements
gulsumgudukbay Feb 11, 2026
e465c0d
update configs in tests to use helper functions
gulsumgudukbay Feb 12, 2026
ba10994
update ici parallelism for decoupled mode
gulsumgudukbay Feb 13, 2026
7bd554d
adding TE build and upload CI workflow
gulsumgudukbay Feb 16, 2026
b38ac0e
Adding CI workflow changes for ROCm and JAX 0.8.2 requirements files
gulsumgudukbay Feb 3, 2026
d22eebb
update te wheel consumption
gulsumgudukbay Feb 20, 2026
4f593d7
refactoring TE wheel release workflow
gulsumgudukbay Feb 16, 2026
b47f33f
Merge branch 'AI-Hypercomputer:main' into rocm-main
gulsumgudukbay Feb 24, 2026
9d40754
update runner labels to mi355
gulsumgudukbay Feb 24, 2026
f701281
[DOWNSTREAM-ONLY] fix ROCm version finding for TE release wheels
gulsumgudukbay Feb 24, 2026
46f1378
fix syntax
gulsumgudukbay Feb 24, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 48 additions & 20 deletions .github/workflows/build_and_test_maxtext.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,23 @@ on:
pull_request:
workflow_call:
workflow_dispatch:
inputs:
rocm_only:
description: 'Run only ROCm jobs (manual runs only)'
type: boolean
required: false
default: false
schedule:
# Run the job every 4 hours
- cron: '0 */4 * * *'
- cron: '0 3 * * *' # Daily 03:00 UTC

concurrency:
# Dedup pull requests (canceling previous runs of the same workflow for same PR), and scheduled runs but nothing else
# Cancel previous runs for same PR (all actors), scheduled runs,
# and manual runs per (branch + actor).
group: >
${{
github.event_name == 'pull_request' && format('{0}-pr-{1}', github.workflow, github.event.pull_request.number) ||
github.event_name == 'schedule' && format('{0}-schedule', github.workflow) ||
github.event_name == 'workflow_dispatch' && format('{0}-manual-{1}-{2}', github.workflow, github.ref, github.actor) ||
github.run_id
}}
cancel-in-progress: true
Expand Down Expand Up @@ -93,18 +100,16 @@ jobs:
build_and_upload_maxtext_package:
needs: doc_only_check
# Run if either tests or notebooks need to run
if: |
needs.doc_only_check.outputs.run_tests == 'true' ||
needs.doc_only_check.outputs.run_notebooks == 'true'
if: ${{ vars.ROCM_ONLY == 'true' || needs.doc_only_check.outputs.run_tests == 'true' || needs.doc_only_check.outputs.run_notebooks == 'true' }}
uses: ./.github/workflows/build_package.yml
with:
device_type: tpu
device_name: v4-8
cloud_runner: linux-x86-n2-16-buildkit
device_type: ${{ vars.ROCM_ONLY == 'true' && 'rocm' || 'tpu' }}
device_name: ${{ vars.ROCM_ONLY == 'true' && 'mi355' || 'v4-8' }}
cloud_runner: ${{ vars.ROCM_ONLY == 'true' && 'linux-x86-64-4gpu-amd' || 'linux-x86-n2-16-buildkit' }}

maxtext_jupyter_notebooks:
needs: build_and_upload_maxtext_package
if: needs.doc_only_check.outputs.run_notebooks == 'true'
if: ${{ vars.ROCM_ONLY != 'true' && ((!(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && (needs.doc_only_check.outputs.run_notebooks == 'true'))) }}
uses: ./.github/workflows/run_jupyter_notebooks.yml
strategy:
fail-fast: false
Expand All @@ -121,7 +126,7 @@ jobs:

maxtext_cpu_unit_tests:
needs: build_and_upload_maxtext_package
if: needs.doc_only_check.outputs.run_tests == 'true'
if: ${{ vars.ROCM_ONLY != 'true' && (!(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && (needs.doc_only_check.outputs.run_tests == 'true')) }}
uses: ./.github/workflows/run_tests_against_package.yml
strategy:
fail-fast: false # don't cancel all jobs on failure
Expand All @@ -144,7 +149,7 @@ jobs:

maxtext_tpu_unit_tests:
needs: build_and_upload_maxtext_package
if: needs.doc_only_check.outputs.run_tests == 'true'
if: ${{ vars.ROCM_ONLY != 'true' && (!(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && (needs.doc_only_check.outputs.run_tests == 'true')) }}
uses: ./.github/workflows/run_tests_against_package.yml
strategy:
fail-fast: false
Expand All @@ -164,7 +169,7 @@ jobs:

maxtext_tpu_integration_tests:
needs: build_and_upload_maxtext_package
if: needs.doc_only_check.outputs.run_tests == 'true'
if: ${{ vars.ROCM_ONLY != 'true' && (!(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && (needs.doc_only_check.outputs.run_tests == 'true')) }}
uses: ./.github/workflows/run_tests_against_package.yml
strategy:
fail-fast: false
Expand All @@ -184,7 +189,7 @@ jobs:

maxtext_tpu_pathways_unit_tests:
needs: build_and_upload_maxtext_package
if: needs.doc_only_check.outputs.run_tests == 'true'
if: ${{ vars.ROCM_ONLY != 'true' && ((!(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && (needs.doc_only_check.outputs.run_tests == 'true'))) }}
uses: ./.github/workflows/run_pathways_tests.yml
strategy:
fail-fast: false
Expand All @@ -204,7 +209,7 @@ jobs:

maxtext_tpu_pathways_integration_tests:
needs: build_and_upload_maxtext_package
if: needs.doc_only_check.outputs.run_tests == 'true'
if: ${{ vars.ROCM_ONLY != 'true' && ((!(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && (needs.doc_only_check.outputs.run_tests == 'true'))) }}
uses: ./.github/workflows/run_pathways_tests.yml
strategy:
fail-fast: false
Expand All @@ -224,7 +229,7 @@ jobs:

maxtext_gpu_unit_tests:
needs: build_and_upload_maxtext_package
if: needs.doc_only_check.outputs.run_tests == 'true'
if: ${{ vars.ROCM_ONLY != 'true' && (!(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && (needs.doc_only_check.outputs.run_tests == 'true')) }}
uses: ./.github/workflows/run_tests_against_package.yml
strategy:
fail-fast: false
Expand All @@ -245,7 +250,7 @@ jobs:

maxtext_gpu_integration_tests:
needs: build_and_upload_maxtext_package
if: needs.doc_only_check.outputs.run_tests == 'true'
if: ${{ vars.ROCM_ONLY != 'true' && (!(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && (needs.doc_only_check.outputs.run_tests == 'true')) }}
uses: ./.github/workflows/run_tests_against_package.yml
strategy:
fail-fast: false
Expand All @@ -264,9 +269,31 @@ jobs:
is_scheduled_run: ${{ github.event_name == 'schedule' }}
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}

maxtext_rocm_decoupled_unit_tests:
needs: build_and_upload_maxtext_package
if: ${{ vars.ROCM_ONLY == 'true' || (github.event_name == 'workflow_dispatch' && inputs.rocm_only) || needs.doc_only_check.outputs.run_tests == 'true' }}
uses: ./.github/workflows/run_tests_against_package.yml
strategy:
fail-fast: false
matrix:
image_type: ["py312"]
with:
device_type: rocm
device_name: mi355
image_type: ${{ matrix.image_type }}
cloud_runner: linux-x86-64-4gpu-amd
pytest_marker: 'decoupled'
xla_python_client_mem_fraction: 0.9
tf_force_gpu_allow_growth: true
container_resource_option: "--device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 64g --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --privileged"
requirements_file: "dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt"
decoupled_mode: true
is_scheduled_run: ${{ github.event_name == 'schedule' }}
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}

all_tests_passed:
name: All Required Tests Passed
needs: [doc_only_check, build_and_upload_maxtext_package, maxtext_cpu_unit_tests, maxtext_tpu_unit_tests, maxtext_tpu_integration_tests, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, maxtext_gpu_unit_tests, maxtext_gpu_integration_tests]
needs: [doc_only_check, build_and_upload_maxtext_package, maxtext_cpu_unit_tests, maxtext_tpu_unit_tests, maxtext_tpu_integration_tests, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, maxtext_gpu_unit_tests, maxtext_gpu_integration_tests, maxtext_rocm_decoupled_unit_tests]
if: always()
runs-on: ubuntu-latest
steps:
Expand All @@ -287,6 +314,7 @@ jobs:
echo "TPU pathways integration: ${{ needs.maxtext_tpu_pathways_integration_tests.result }}"
echo "GPU tests: ${{ needs.maxtext_gpu_unit_tests.result }}"
echo "GPU integration: ${{ needs.maxtext_gpu_integration_tests.result }}"
echo "ROCm decoupled tests: ${{ needs.maxtext_rocm_decoupled_unit_tests.result }}"

# Fail only if any job failed or was cancelled (skipped is OK)
if [ "${{ contains(needs.*.result, 'failure') }}" == "true" ] || [ "${{ contains(needs.*.result, 'cancelled') }}" == "true" ]; then
Expand All @@ -299,7 +327,7 @@ jobs:
all_notebooks_passed:
name: All Notebooks Passed
needs: [doc_only_check, build_and_upload_maxtext_package, maxtext_jupyter_notebooks]
if: always()
if: ${{ vars.ROCM_ONLY != 'true' && ((!(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && (always()))) }}
runs-on: ubuntu-latest
steps:
- name: Check notebooks results
Expand All @@ -323,7 +351,7 @@ jobs:

notify_failure:
name: Notify failed build # creates an issue or modifies last open existing issue for failed build
needs: [maxtext_jupyter_notebooks, maxtext_cpu_unit_tests, maxtext_tpu_unit_tests, maxtext_tpu_integration_tests, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, maxtext_gpu_unit_tests, maxtext_gpu_integration_tests]
needs: [maxtext_jupyter_notebooks, maxtext_cpu_unit_tests, maxtext_tpu_unit_tests, maxtext_tpu_integration_tests, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, maxtext_gpu_unit_tests, maxtext_gpu_integration_tests, maxtext_rocm_decoupled_unit_tests]
if: ${{ always() }}
runs-on: ubuntu-latest
permissions:
Expand Down
191 changes: 191 additions & 0 deletions .github/workflows/build_rocm_transformer_engine_wheel_weekly.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
name: Build ROCm TransformerEngine wheel (weekly)

on:
workflow_dispatch:
schedule:
# Weekly at night (02:00 UTC every Monday), 2 hours ahead of scheduled tests.
- cron: "0 2 * * 1"

permissions:
contents: write

jobs:
build_upload_prune:
# AMD GPU runner (GitHub-hosted large runner label).
runs-on: linux-x86-64-4gpu-amd
container:
image: ghcr.io/rocm/jax-base-ubu24.rocm720:latest
options: >-
--device=/dev/kfd --device=/dev/dri --group-add video
--ipc=host --shm-size 64g
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined
--privileged
env:
XLA_PYTHON_CLIENT_MEM_FRACTION: "0.9"
NVTE_FUSED_ATTN_AOTRITON: "0"

steps:
- name: Checkout
uses: actions/checkout@v5

- name: Setup build environment (deps + venv)
shell: bash
run: |
set -euo pipefail
apt-get update
apt-get install -y --no-install-recommends git build-essential python3-dev
python3 -m pip install -U uv
python3 -m uv venv --seed
source .venv/bin/activate
uv pip install -U pip setuptools wheel pybind11 cmake

- name: Install ROCm JAX/JAXlib wheels (build against CI stack)
shell: bash
run: |
set -euo pipefail
source .venv/bin/activate
uv pip install -r dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt

- name: Detect ROCm version and Python tag
shell: bash
run: |
set -euo pipefail
source .venv/bin/activate

ROCM_NUM="$([ -f /opt/rocm/.info/version ] && head -n1 /opt/rocm/.info/version | tr -d '[:space:]' || echo unknown)"
echo "Detected ROCm version: ${ROCM_NUM}"
echo "ROCM_NUM=${ROCM_NUM}" >> "${GITHUB_ENV}"

PYTAG="cp$(python3 -c 'import sys; print(f"{sys.version_info.major}{sys.version_info.minor}")')"
if [ "${PYTAG}" != "cp312" ]; then
echo "Expected Python 3.12 (cp312) for ROCm CI wheels, got ${PYTAG}."
exit 1
fi
echo "PYTAG=${PYTAG}" >> "${GITHUB_ENV}"
echo "REL_SCRIPT=.github/workflows/utils/te_wheels_release.py" >> "${GITHUB_ENV}"

- name: Clone ROCm/TransformerEngine (dev)
shell: bash
run: |
set -euo pipefail
rm -rf TransformerEngine
git clone --recursive --branch dev https://github.com/ROCm/TransformerEngine.git
cd TransformerEngine
git submodule update --init --recursive
TE_SHA="$(git rev-parse --short=12 HEAD)"
echo "TE_SHA=${TE_SHA}" >> "${GITHUB_ENV}"

- name: Select TE wheel arch for runner (mi300/mi355)
shell: bash
run: |
set -euo pipefail
source .venv/bin/activate

TE_WHEEL_ARCH="$(python3 .github/workflows/utils/install_te_rocm_wheel.py --print-arch)"
echo "Resolved TE wheel arch for runner: ${TE_WHEEL_ARCH}"
echo "TE_WHEEL_ARCH=${TE_WHEEL_ARCH}" >> "${GITHUB_ENV}"

# Build ONLY for the ROCm arch present on this CI runner (mi300 or mi355).
if [ "${TE_WHEEL_ARCH}" = "mi355" ]; then
SELECTOR="mi355"
GFX_ARCH="gfx950"
else
SELECTOR="mi300"
GFX_ARCH="gfx942;gfx941"
fi
echo "SELECTOR=${SELECTOR}" >> "${GITHUB_ENV}"
echo "GFX_ARCH=${GFX_ARCH}" >> "${GITHUB_ENV}"

- name: Build TE wheel
shell: bash
run: |
set -euo pipefail
source .venv/bin/activate

chmod +x "${REL_SCRIPT}" || true

export USE_ROCM=1
export HIP_PATH=/opt/rocm
export NVTE_FRAMEWORK=jax
export CMAKE_BUILD_PARALLEL_LEVEL=64
export NVTE_USE_ROCM=1
export NVTE_FUSED_ATTN_AOTRITON=0
export NVTE_BUILD_MAX_JOBS=180
#export NVTE_AITER_PREBUILT_BASE_URL=https://compute-artifactory.amd.com:5000/artifactory/rocm-generic-local/te-ci/aiter-prebuilts

echo "=== Building TE wheel for ${SELECTOR} (gfx=${GFX_ARCH}) ==="
pushd TransformerEngine >/dev/null
rm -rf build dist
export PYTHONPATH="$(pwd)/3rdparty/hipify_torch${PYTHONPATH:+:${PYTHONPATH}}"
export PYTORCH_ROCM_ARCH="${GFX_ARCH}"
export NVTE_ROCM_ARCH="${GFX_ARCH}"
python3 setup.py bdist_wheel
wheel_path="$(
python3 -c 'import glob; m=sorted(glob.glob("dist/transformer_engine-*.whl")); print(m[0] if m else "")'
)"
if [ -z "${wheel_path}" ]; then
echo "No wheel produced in dist/ (selector=${SELECTOR})."
exit 1
fi
wheel_base="$(basename "${wheel_path}")"
if [[ "${wheel_base}" == *"-1.${SELECTOR}-${PYTAG}-${PYTAG}-linux_x86_64.whl" ]]; then
asset_name="${wheel_base}"
else
asset_name="${wheel_base/-${PYTAG}-${PYTAG}-linux_x86_64.whl/-1.${SELECTOR}-${PYTAG}-${PYTAG}-linux_x86_64.whl}"
if [ "${asset_name}" = "${wheel_base}" ]; then
echo "Failed to rename wheel for selector=${SELECTOR}: ${wheel_base}"
exit 1
fi
fi
cp -f "${wheel_path}" "../${asset_name}"
popd >/dev/null

ls -lh "${asset_name}"
echo "TE_WHEEL_FILE=${asset_name}" >> "${GITHUB_ENV}"

- name: Upload wheel to rolling release tag (te-rocm-wheels)
shell: bash
env:
GITHUB_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
source .venv/bin/activate
python3 "${REL_SCRIPT}" upload --no-prerelease --tag "te-rocm-wheels" --title "ROCm TransformerEngine wheels (latest)" --body "Rolling release for latest weekly-built ROCm TransformerEngine wheels used by CI." --file "${TE_WHEEL_FILE}"

- name: Prune old assets from rolling tag
shell: bash
env:
GITHUB_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
source .venv/bin/activate
python3 "${REL_SCRIPT}" prune-assets --tag "te-rocm-wheels" --keep-days 21

- name: Publish wheel to dated weekly release tag
shell: bash
env:
GITHUB_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
source .venv/bin/activate

DATE_UTC="$(date -u +%Y-%m-%d)"
WEEKLY_TAG="te-rocm-wheels-${DATE_UTC}-${TE_SHA}"
WEEKLY_TITLE="ROCm TransformerEngine wheels ${DATE_UTC} (TE ${TE_SHA})"
# Keep this YAML-safe (no unindented heredocs inside `run: |`).
WEEKLY_BODY="$(
printf '%s\n\nROCm: %s\nPython: %s\nArch: %s (gfx=%s)\n' \
"Built from ROCm/TransformerEngine dev @ ${TE_SHA} on ${DATE_UTC}." \
"${ROCM_NUM}" "${PYTAG}" "${SELECTOR}" "${GFX_ARCH}"
)"

python3 "${REL_SCRIPT}" upload --no-prerelease --tag "${WEEKLY_TAG}" --title "${WEEKLY_TITLE}" --body "${WEEKLY_BODY}" --file "${TE_WHEEL_FILE}"

- name: Prune old weekly releases
shell: bash
env:
GITHUB_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
source .venv/bin/activate
python3 "${REL_SCRIPT}" prune-releases --prefix "te-rocm-wheels-" --keep-days 21
Loading
Loading