Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
cb7f66b
adding rocm_jax_0.7.1 reqs
gulsumgudukbay Feb 3, 2026
53f94f3
Revert "removing CI workflows for now to upstream decoupling changes"
gulsumgudukbay Feb 3, 2026
070adee
skip ring attention test on ROCm
gulsumgudukbay Feb 9, 2026
87a13a8
[DOWNSTREAM-ONLY] update schedule for build_and_test_maxtext
gulsumgudukbay Feb 11, 2026
65d3f9d
adding jax 0.8.2 requirements
gulsumgudukbay Feb 11, 2026
d1fd5f0
update configs in tests to use helper functions
gulsumgudukbay Feb 12, 2026
280fd87
adding TE build and upload CI workflow
gulsumgudukbay Feb 16, 2026
b40ce36
Adding CI workflow changes for ROCm and JAX 0.8.2 requirements files
gulsumgudukbay Feb 3, 2026
78d464b
update te wheel consumption
gulsumgudukbay Feb 20, 2026
d6ecca7
refactoring TE wheel release workflow
gulsumgudukbay Feb 16, 2026
63f59dd
update runner labels to mi355
gulsumgudukbay Feb 24, 2026
86ebc19
fix te wheel selection
gulsumgudukbay Mar 10, 2026
77372cf
Remove ROCm fused-attention backend variables
gulsumgudukbay Mar 11, 2026
79eb75b
refactor requirements location change
gulsumgudukbay Mar 12, 2026
fd70099
fix TE build workflow, add rocm torch dependency to env
gulsumgudukbay Mar 31, 2026
67f4d19
fix ci
gulsumgudukbay Apr 8, 2026
65affee
update requirements
gulsumgudukbay May 5, 2026
a95c831
refactor req files in CI
gulsumgudukbay May 5, 2026
f8ac7f3
fix index url usage in reqs
gulsumgudukbay May 6, 2026
7190f0f
update testing and requirements
gulsumgudukbay May 6, 2026
9d55814
Add ROCm benchmark configs and requirements for MaxText
psanal35 May 11, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 57 additions & 18 deletions .github/workflows/build_and_test_maxtext.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,23 @@ on:
pull_request:
workflow_call:
workflow_dispatch:
inputs:
rocm_only:
description: 'Run only ROCm jobs (manual runs only)'
type: boolean
required: false
default: false
schedule:
# Run the job every 4 hours
- cron: '0 */4 * * *'
- cron: '0 3 * * *' # Daily 03:00 UTC

concurrency:
# Dedup pull requests (canceling previous runs of the same workflow for same PR), and scheduled runs but nothing else
# Cancel previous runs for same PR (all actors), scheduled runs,
# and manual runs per (branch + actor).
group: >
${{
github.event_name == 'pull_request' && format('{0}-pr-{1}', github.workflow, github.event.pull_request.number) ||
github.event_name == 'schedule' && format('{0}-schedule', github.workflow) ||
github.event_name == 'workflow_dispatch' && format('{0}-manual-{1}-{2}', github.workflow, github.ref, github.actor) ||
github.run_id
}}
cancel-in-progress: true
Expand Down Expand Up @@ -118,18 +125,16 @@ jobs:
build_and_upload_maxtext_package:
needs: analyze_code_changes
# Run if either tests or notebooks need to run
if: |
needs.analyze_code_changes.outputs.run_tests == 'true' ||
needs.analyze_code_changes.outputs.run_notebooks == 'true'
if: ${{ vars.ROCM_ONLY == 'true' || needs.analyze_code_changes.outputs.run_tests == 'true' || needs.analyze_code_changes.outputs.run_notebooks == 'true' }}
uses: ./.github/workflows/build_package.yml
with:
device_type: tpu
device_name: v4-8
cloud_runner: linux-x86-n2-16-buildkit
device_type: ${{ vars.ROCM_ONLY == 'true' && 'rocm' || 'tpu' }}
device_name: ${{ vars.ROCM_ONLY == 'true' && 'mi355' || 'v4-8' }}
cloud_runner: ${{ vars.ROCM_ONLY == 'true' && 'linux-x86-64-4gpu-amd' || 'linux-x86-n2-16-buildkit' }}

maxtext_jupyter_notebooks:
needs: build_and_upload_maxtext_package
if: needs.analyze_code_changes.outputs.run_notebooks == 'true'
if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_notebooks == 'true' }}
uses: ./.github/workflows/run_jupyter_notebooks.yml
strategy:
fail-fast: false
Expand All @@ -145,7 +150,7 @@ jobs:
tpu-tests:
name: ${{ matrix.flavor }} tests
needs: [build_and_upload_maxtext_package]
if: needs.analyze_code_changes.outputs.run_tests == 'true'
if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }}
uses: ./.github/workflows/run_tests_coordinator.yml
strategy:
fail-fast: false
Expand All @@ -160,7 +165,7 @@ jobs:
gpu-tests:
name: ${{ matrix.flavor }} tests
needs: [build_and_upload_maxtext_package]
if: needs.analyze_code_changes.outputs.run_tests == 'true'
if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }}
strategy:
fail-fast: false
matrix:
Expand All @@ -175,7 +180,7 @@ jobs:
cpu-tests:
name: ${{ matrix.flavor }} tests
needs: [build_and_upload_maxtext_package]
if: needs.analyze_code_changes.outputs.run_tests == 'true'
if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }}
uses: ./.github/workflows/run_tests_coordinator.yml
strategy:
fail-fast: false
Expand All @@ -189,7 +194,7 @@ jobs:

maxtext_tpu_pathways_unit_tests:
needs: build_and_upload_maxtext_package
if: needs.analyze_code_changes.outputs.run_tests == 'true'
if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }}
uses: ./.github/workflows/run_pathways_tests.yml
strategy:
fail-fast: false
Expand All @@ -208,7 +213,7 @@ jobs:

maxtext_tpu_pathways_integration_tests:
needs: build_and_upload_maxtext_package
if: needs.analyze_code_changes.outputs.run_tests == 'true'
if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }}
uses: ./.github/workflows/run_pathways_tests.yml
strategy:
fail-fast: false
Expand All @@ -225,9 +230,39 @@ jobs:
is_scheduled_run: ${{ github.event_name == 'schedule' }}
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}

rocm-tests:
name: ${{ matrix.flavor }} tests
needs: [build_and_upload_maxtext_package]
if: ${{ vars.ROCM_ONLY != 'true' && needs.analyze_code_changes.outputs.run_tests == 'true' }}
uses: ./.github/workflows/run_tests_coordinator.yml
strategy:
fail-fast: false
matrix:
flavor: [rocm-unit]
with:
flavor: ${{ matrix.flavor }}
base_image: 'rocm-placeholder'
is_scheduled_run: ${{ github.event_name == 'schedule' }}
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}

rocm-decoupled-tests:
name: ${{ matrix.flavor }} tests
needs: [build_and_upload_maxtext_package]
if: ${{ vars.ROCM_ONLY == 'true' || (github.event_name == 'workflow_dispatch' && inputs.rocm_only) || needs.analyze_code_changes.outputs.run_tests == 'true' }}
uses: ./.github/workflows/run_tests_coordinator.yml
strategy:
fail-fast: false
matrix:
flavor: [rocm-decoupled]
with:
flavor: ${{ matrix.flavor }}
base_image: 'rocm-placeholder'
is_scheduled_run: ${{ github.event_name == 'schedule' }}
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}

all_tests_passed:
name: All Required Tests Passed
needs: [tpu-tests, gpu-tests, cpu-tests, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests]
needs: [analyze_code_changes, build_and_upload_maxtext_package, tpu-tests, gpu-tests, cpu-tests, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, rocm-tests, rocm-decoupled-tests]
if: always()
runs-on: ubuntu-latest
steps:
Expand All @@ -245,6 +280,8 @@ jobs:
echo "CPU Tests (Matrix) result: ${NEEDS_CPU_TESTS_RESULT}"
echo "Pathways Unit result: ${NEEDS_MAXTEXT_TPU_PATHWAYS_UNIT_TESTS_RESULT}"
echo "Pathways Integration result: ${NEEDS_MAXTEXT_TPU_PATHWAYS_INTEGRATION_TESTS_RESULT}"
echo "ROCm Tests (Matrix) result: ${NEEDS_ROCM_TESTS_RESULT}"
echo "ROCm Decoupled Tests (Matrix) result: ${NEEDS_ROCM_DECOUPLED_TESTS_RESULT}"

# Fail only if any job failed or was cancelled (skipped is OK)
if [ "${{ contains(needs.*.result, 'failure') }}" == "true" ] || [ "${{ contains(needs.*.result, 'cancelled') }}" == "true" ]; then
Expand All @@ -261,11 +298,13 @@ jobs:
NEEDS_GPU_TESTS_RESULT: ${{ needs.gpu-tests.result }}
NEEDS_MAXTEXT_TPU_PATHWAYS_UNIT_TESTS_RESULT: ${{ needs.maxtext_tpu_pathways_unit_tests.result }}
NEEDS_MAXTEXT_TPU_PATHWAYS_INTEGRATION_TESTS_RESULT: ${{ needs.maxtext_tpu_pathways_integration_tests.result }}
NEEDS_ROCM_TESTS_RESULT: ${{ needs.rocm-tests.result }}
NEEDS_ROCM_DECOUPLED_TESTS_RESULT: ${{ needs.rocm-decoupled-tests.result }}

all_notebooks_passed:
name: All Notebooks Passed
needs: [analyze_code_changes, build_and_upload_maxtext_package, maxtext_jupyter_notebooks]
if: always()
if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && always() }}
runs-on: ubuntu-latest
steps:
- name: Check notebooks results
Expand Down Expand Up @@ -293,7 +332,7 @@ jobs:

notify_failure:
name: Notify failed build # creates an issue or modifies last open existing issue for failed build
needs: [tpu-tests, gpu-tests, cpu-tests, maxtext_jupyter_notebooks, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests]
needs: [tpu-tests, gpu-tests, cpu-tests, maxtext_jupyter_notebooks, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, rocm-tests, rocm-decoupled-tests]
if: ${{ always() }}
runs-on: ubuntu-latest
permissions:
Expand Down
202 changes: 202 additions & 0 deletions .github/workflows/build_rocm_transformer_engine_wheel_weekly.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
name: Build ROCm TransformerEngine wheel (weekly)

on:
workflow_dispatch:
schedule:
# Weekly at night (02:00 UTC every Monday), 2 hours ahead of scheduled tests.
- cron: "0 2 * * 1"

permissions:
contents: write

jobs:
build_upload_prune:
# AMD GPU runner (GitHub-hosted large runner label).
runs-on: linux-x86-64-4gpu-amd
container:
image: ghcr.io/rocm/jax-base-ubu24.rocm720:latest
options: >-
--device=/dev/kfd --device=/dev/dri --group-add video
--ipc=host --shm-size 64g
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined
--privileged
env:
XLA_PYTHON_CLIENT_MEM_FRACTION: "0.9"
NVTE_FUSED_ATTN_AOTRITON: "0"
env:
TE_WHEELS_KEEP_DAYS: "21"

steps:
- name: Checkout
uses: actions/checkout@v5

- name: Setup build environment (deps + venv)
shell: bash
run: |
set -euo pipefail
apt-get update
apt-get install -y --no-install-recommends git build-essential python3-dev
python3 -m pip install -U uv
python3 -m uv venv --seed
source .venv/bin/activate
uv pip install -U pip setuptools wheel pybind11 cmake

- name: Install ROCm JAX/JAXlib wheels (build against CI stack)
shell: bash
run: |
set -euo pipefail
source .venv/bin/activate
uv pip install -r src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt

- name: Install PyTorch ROCm (build-time dep for aiter JIT)
shell: bash
run: |
set -euo pipefail
source .venv/bin/activate
uv pip install torch --index-url https://download.pytorch.org/whl/rocm7.2

- name: Detect ROCm version and Python tag
shell: bash
run: |
set -euo pipefail
source .venv/bin/activate

# Detect ROCm version number from JAX backend string when available (e.g. 'rocm 70200').
ROCM_NUM="$([ -f /opt/rocm/.info/version ] && head -n1 /opt/rocm/.info/version | tr -d ' \t\r' || echo unknown)"
echo "Detected ROCm version: ${ROCM_NUM}"
echo "ROCM_NUM=${ROCM_NUM}" >> "${GITHUB_ENV}"

PYTAG="cp$(python3 -c 'import sys; print(f"{sys.version_info.major}{sys.version_info.minor}")')"
if [ "${PYTAG}" != "cp312" ]; then
echo "Expected Python 3.12 (cp312) for ROCm CI wheels, got ${PYTAG}."
exit 1
fi
echo "PYTAG=${PYTAG}" >> "${GITHUB_ENV}"
echo "REL_SCRIPT=.github/workflows/utils/te_wheels_release.py" >> "${GITHUB_ENV}"

- name: Clone ROCm/TransformerEngine (dev)
shell: bash
run: |
set -euo pipefail
rm -rf TransformerEngine
git clone --recursive --branch dev https://github.com/ROCm/TransformerEngine.git
cd TransformerEngine
git submodule update --init --recursive
TE_SHA="$(git rev-parse --short=12 HEAD)"
echo "TE_SHA=${TE_SHA}" >> "${GITHUB_ENV}"

- name: Select TE wheel arch for runner (mi300/mi355)
shell: bash
run: |
set -euo pipefail
source .venv/bin/activate

TE_WHEEL_ARCH="$(python3 .github/workflows/utils/install_te_rocm_wheel.py --print-arch)"
echo "Resolved TE wheel arch for runner: ${TE_WHEEL_ARCH}"
echo "TE_WHEEL_ARCH=${TE_WHEEL_ARCH}" >> "${GITHUB_ENV}"

# Build ONLY for the ROCm arch present on this CI runner (mi300 or mi355).
if [ "${TE_WHEEL_ARCH}" = "mi355" ]; then
SELECTOR="mi355"
GFX_ARCH="gfx950"
else
SELECTOR="mi300"
GFX_ARCH="gfx942;gfx941"
fi
echo "SELECTOR=${SELECTOR}" >> "${GITHUB_ENV}"
echo "GFX_ARCH=${GFX_ARCH}" >> "${GITHUB_ENV}"

- name: Build TE wheel
shell: bash
run: |
set -euo pipefail
source .venv/bin/activate

chmod +x "${REL_SCRIPT}" || true

export USE_ROCM=1
export HIP_PATH=/opt/rocm
export NVTE_FRAMEWORK=jax
export CMAKE_BUILD_PARALLEL_LEVEL=64
export NVTE_USE_ROCM=1
export NVTE_FUSED_ATTN_AOTRITON=0
export NVTE_BUILD_MAX_JOBS=180

echo "=== Building TE wheel for ${SELECTOR} (gfx=${GFX_ARCH}) ==="
pushd TransformerEngine >/dev/null
rm -rf build dist
export PYTHONPATH="$(pwd)/3rdparty/hipify_torch${PYTHONPATH:+:${PYTHONPATH}}"
export PYTORCH_ROCM_ARCH="${GFX_ARCH}"
export NVTE_ROCM_ARCH="${GFX_ARCH}"
python3 setup.py bdist_wheel
wheel_path="$(
python3 -c 'import glob; m=sorted(glob.glob("dist/transformer_engine-*.whl")); print(m[0] if m else "")'
)"
if [ -z "${wheel_path}" ]; then
echo "No wheel produced in dist/ (selector=${SELECTOR})."
exit 1
fi
wheel_base="$(basename "${wheel_path}")"
if [[ "${wheel_base}" == *"-1.${SELECTOR}-${PYTAG}-${PYTAG}-linux_x86_64.whl" ]]; then
asset_name="${wheel_base}"
else
asset_name="${wheel_base/-${PYTAG}-${PYTAG}-linux_x86_64.whl/-1.${SELECTOR}-${PYTAG}-${PYTAG}-linux_x86_64.whl}"
if [ "${asset_name}" = "${wheel_base}" ]; then
echo "Failed to rename wheel for selector=${SELECTOR}: ${wheel_base}"
exit 1
fi
fi
cp -f "${wheel_path}" "../${asset_name}"
popd >/dev/null

ls -lh "${asset_name}"
echo "TE_WHEEL_FILE=${asset_name}" >> "${GITHUB_ENV}"

- name: Upload wheel to rolling release tag (te-rocm-wheels)
shell: bash
env:
GITHUB_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
source .venv/bin/activate
python3 "${REL_SCRIPT}" upload --no-prerelease --tag "te-rocm-wheels" --title "ROCm TransformerEngine wheels (latest)" --body "Rolling release for latest weekly-built ROCm TransformerEngine wheels used by CI." --file "${TE_WHEEL_FILE}"

- name: Prune old assets from rolling tag
shell: bash
env:
GITHUB_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
source .venv/bin/activate
echo "Pruning rolling-tag assets older than ${TE_WHEELS_KEEP_DAYS} days"
python3 "${REL_SCRIPT}" prune-assets --tag "te-rocm-wheels" --keep-days "${TE_WHEELS_KEEP_DAYS}"

- name: Publish wheel to dated weekly release tag
shell: bash
env:
GITHUB_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
source .venv/bin/activate

DATE_UTC="$(date -u +%Y-%m-%d)"
WEEKLY_TAG="te-rocm-wheels-${DATE_UTC}-${TE_SHA}"
WEEKLY_TITLE="ROCm TransformerEngine wheels ${DATE_UTC} (TE ${TE_SHA})"
# Keep this YAML-safe (no unindented heredocs inside `run: |`).
WEEKLY_BODY="$(
printf '%s\n\nROCm: %s\nPython: %s\nArch: %s (gfx=%s)\n' \
"Built from ROCm/TransformerEngine dev @ ${TE_SHA} on ${DATE_UTC}." \
"${ROCM_NUM}" "${PYTAG}" "${SELECTOR}" "${GFX_ARCH}"
)"

python3 "${REL_SCRIPT}" upload --no-prerelease --tag "${WEEKLY_TAG}" --title "${WEEKLY_TITLE}" --body "${WEEKLY_BODY}" --file "${TE_WHEEL_FILE}"

- name: Prune old weekly releases
shell: bash
env:
GITHUB_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
source .venv/bin/activate
echo "Pruning weekly release pages older than ${TE_WHEELS_KEEP_DAYS} days"
python3 "${REL_SCRIPT}" prune-releases --prefix "te-rocm-wheels-" --keep-days "${TE_WHEELS_KEEP_DAYS}"
Loading
Loading