Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
fe61a0c
adding rocm_jax_0.7.1 reqs
gulsumgudukbay Feb 3, 2026
bc13c85
Revert "removing CI workflows for now to upstream decoupling changes"
gulsumgudukbay Feb 3, 2026
cad0e39
skip ring attention test on ROCm
gulsumgudukbay Feb 9, 2026
9a840ad
[DOWNSTREAM-ONLY] update schedule for build_and_test_maxtext
gulsumgudukbay Feb 11, 2026
5d8a1d8
adding jax 0.8.2 requirements
gulsumgudukbay Feb 11, 2026
2b4b7a6
update configs in tests to use helper functions
gulsumgudukbay Feb 12, 2026
5c59338
adding TE build and upload CI workflow
gulsumgudukbay Feb 16, 2026
568f2d4
Adding CI workflow changes for ROCm and JAX 0.8.2 requirements files
gulsumgudukbay Feb 3, 2026
2c770ad
update te wheel consumption
gulsumgudukbay Feb 20, 2026
65eff61
refactoring TE wheel release workflow
gulsumgudukbay Feb 16, 2026
def5679
update runner labels to mi355
gulsumgudukbay Feb 24, 2026
cec342e
fix te wheel selection
gulsumgudukbay Mar 10, 2026
62b9b2d
Remove ROCm fused-attention backend variables
gulsumgudukbay Mar 11, 2026
9f6a367
refactor requirements location change
gulsumgudukbay Mar 12, 2026
23da49f
fix TE build workflow, add rocm torch dependency to env
gulsumgudukbay Mar 31, 2026
2eb63a2
fix ci
gulsumgudukbay Apr 8, 2026
97bc532
Merge branch 'AI-Hypercomputer:main' into rocm-main
gulsumgudukbay Apr 23, 2026
ec4b5b9
update with a reduction of temp mem usage
Apr 22, 2026
985ab79
upate to reduce temp mem usage
cj401-amd Apr 23, 2026
4fceae4
update for reduce temp memory usage for ep=2 pp=4
cj401-amd Apr 28, 2026
f14671b
Apply deepseek.py temp mem fix from cj-reduce-tmp-mem_rocm-main
Apr 30, 2026
8168931
Fix jax.jit decorator syntax for JAX 0.7.1 compatibility
Apr 30, 2026
720619e
Guard jax_remove_size_one_mesh_axis_from_type for JAX 0.7.1 compat
Apr 30, 2026
e18114c
Fix pp=8 ep=1 temp memory regression (36.8 GB -> ~31 GB)
Apr 30, 2026
3b493c8
Add skip_trivial_specs to pipeline._maybe_shard_with_logical
Apr 30, 2026
c2916c4
Remove unconditional nan_to_num on all float gradients
Apr 30, 2026
73ef4a8
Clean up debug artifacts before PR
Apr 30, 2026
091420b
Fix depth_scaling regression and attention_op.py TE compatibility issues
Apr 30, 2026
38b5602
Fix issues 5/6/7: moe tile comment, embeddings axis, float32_weight_s…
Apr 30, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 57 additions & 18 deletions .github/workflows/build_and_test_maxtext.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,23 @@ on:
pull_request:
workflow_call:
workflow_dispatch:
inputs:
rocm_only:
description: 'Run only ROCm jobs (manual runs only)'
type: boolean
required: false
default: false
schedule:
# Run the job every 4 hours
- cron: '0 */4 * * *'
- cron: '0 3 * * *' # Daily 03:00 UTC

concurrency:
# Dedup pull requests (canceling previous runs of the same workflow for same PR), and scheduled runs but nothing else
# Cancel previous runs for same PR (all actors), scheduled runs,
# and manual runs per (branch + actor).
group: >
${{
github.event_name == 'pull_request' && format('{0}-pr-{1}', github.workflow, github.event.pull_request.number) ||
github.event_name == 'schedule' && format('{0}-schedule', github.workflow) ||
github.event_name == 'workflow_dispatch' && format('{0}-manual-{1}-{2}', github.workflow, github.ref, github.actor) ||
github.run_id
}}
cancel-in-progress: true
Expand Down Expand Up @@ -118,18 +125,16 @@ jobs:
build_and_upload_maxtext_package:
needs: analyze_code_changes
# Run if either tests or notebooks need to run
if: |
needs.analyze_code_changes.outputs.run_tests == 'true' ||
needs.analyze_code_changes.outputs.run_notebooks == 'true'
if: ${{ vars.ROCM_ONLY == 'true' || needs.analyze_code_changes.outputs.run_tests == 'true' || needs.analyze_code_changes.outputs.run_notebooks == 'true' }}
uses: ./.github/workflows/build_package.yml
with:
device_type: tpu
device_name: v4-8
cloud_runner: linux-x86-n2-16-buildkit
device_type: ${{ vars.ROCM_ONLY == 'true' && 'rocm' || 'tpu' }}
device_name: ${{ vars.ROCM_ONLY == 'true' && 'mi355' || 'v4-8' }}
cloud_runner: ${{ vars.ROCM_ONLY == 'true' && 'linux-x86-64-4gpu-amd' || 'linux-x86-n2-16-buildkit' }}

maxtext_jupyter_notebooks:
needs: build_and_upload_maxtext_package
if: needs.analyze_code_changes.outputs.run_notebooks == 'true'
if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_notebooks == 'true' }}
uses: ./.github/workflows/run_jupyter_notebooks.yml
strategy:
fail-fast: false
Expand All @@ -145,7 +150,7 @@ jobs:
tpu-tests:
name: ${{ matrix.flavor }} tests
needs: [build_and_upload_maxtext_package]
if: needs.analyze_code_changes.outputs.run_tests == 'true'
if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }}
uses: ./.github/workflows/run_tests_coordinator.yml
strategy:
fail-fast: false
Expand All @@ -160,7 +165,7 @@ jobs:
gpu-tests:
name: ${{ matrix.flavor }} tests
needs: [build_and_upload_maxtext_package]
if: needs.analyze_code_changes.outputs.run_tests == 'true'
if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }}
strategy:
fail-fast: false
matrix:
Expand All @@ -175,7 +180,7 @@ jobs:
cpu-tests:
name: ${{ matrix.flavor }} tests
needs: [build_and_upload_maxtext_package]
if: needs.analyze_code_changes.outputs.run_tests == 'true'
if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }}
uses: ./.github/workflows/run_tests_coordinator.yml
strategy:
fail-fast: false
Expand All @@ -189,7 +194,7 @@ jobs:

maxtext_tpu_pathways_unit_tests:
needs: build_and_upload_maxtext_package
if: needs.analyze_code_changes.outputs.run_tests == 'true'
if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }}
uses: ./.github/workflows/run_pathways_tests.yml
strategy:
fail-fast: false
Expand All @@ -207,7 +212,7 @@ jobs:

maxtext_tpu_pathways_integration_tests:
needs: build_and_upload_maxtext_package
if: needs.analyze_code_changes.outputs.run_tests == 'true'
if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }}
uses: ./.github/workflows/run_pathways_tests.yml
strategy:
fail-fast: false
Expand All @@ -223,9 +228,39 @@ jobs:
is_scheduled_run: ${{ github.event_name == 'schedule' }}
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}

rocm-tests:
name: ${{ matrix.flavor }} tests
needs: [build_and_upload_maxtext_package]
if: ${{ vars.ROCM_ONLY != 'true' && needs.analyze_code_changes.outputs.run_tests == 'true' }}
uses: ./.github/workflows/run_tests_coordinator.yml
strategy:
fail-fast: false
matrix:
flavor: [rocm-unit]
with:
flavor: ${{ matrix.flavor }}
base_image: 'rocm-placeholder'
is_scheduled_run: ${{ github.event_name == 'schedule' }}
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}

rocm-decoupled-tests:
name: ${{ matrix.flavor }} tests
needs: [build_and_upload_maxtext_package]
if: ${{ vars.ROCM_ONLY == 'true' || (github.event_name == 'workflow_dispatch' && inputs.rocm_only) || needs.analyze_code_changes.outputs.run_tests == 'true' }}
uses: ./.github/workflows/run_tests_coordinator.yml
strategy:
fail-fast: false
matrix:
flavor: [rocm-decoupled]
with:
flavor: ${{ matrix.flavor }}
base_image: 'rocm-placeholder'
is_scheduled_run: ${{ github.event_name == 'schedule' }}
maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}

all_tests_passed:
name: All Required Tests Passed
needs: [tpu-tests, gpu-tests, cpu-tests, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests]
needs: [analyze_code_changes, build_and_upload_maxtext_package, tpu-tests, gpu-tests, cpu-tests, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, rocm-tests, rocm-decoupled-tests]
if: always()
runs-on: ubuntu-latest
steps:
Expand All @@ -243,6 +278,8 @@ jobs:
echo "CPU Tests (Matrix) result: ${NEEDS_CPU_TESTS_RESULT}"
echo "Pathways Unit result: ${NEEDS_MAXTEXT_TPU_PATHWAYS_UNIT_TESTS_RESULT}"
echo "Pathways Integration result: ${NEEDS_MAXTEXT_TPU_PATHWAYS_INTEGRATION_TESTS_RESULT}"
echo "ROCm Tests (Matrix) result: ${NEEDS_ROCM_TESTS_RESULT}"
echo "ROCm Decoupled Tests (Matrix) result: ${NEEDS_ROCM_DECOUPLED_TESTS_RESULT}"

# Fail only if any job failed or was cancelled (skipped is OK)
if [ "${{ contains(needs.*.result, 'failure') }}" == "true" ] || [ "${{ contains(needs.*.result, 'cancelled') }}" == "true" ]; then
Expand All @@ -259,11 +296,13 @@ jobs:
NEEDS_GPU_TESTS_RESULT: ${{ needs.gpu-tests.result }}
NEEDS_MAXTEXT_TPU_PATHWAYS_UNIT_TESTS_RESULT: ${{ needs.maxtext_tpu_pathways_unit_tests.result }}
NEEDS_MAXTEXT_TPU_PATHWAYS_INTEGRATION_TESTS_RESULT: ${{ needs.maxtext_tpu_pathways_integration_tests.result }}
NEEDS_ROCM_TESTS_RESULT: ${{ needs.rocm-tests.result }}
NEEDS_ROCM_DECOUPLED_TESTS_RESULT: ${{ needs.rocm-decoupled-tests.result }}

all_notebooks_passed:
name: All Notebooks Passed
needs: [analyze_code_changes, build_and_upload_maxtext_package, maxtext_jupyter_notebooks]
if: always()
if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && always() }}
runs-on: ubuntu-latest
steps:
- name: Check notebooks results
Expand Down Expand Up @@ -291,7 +330,7 @@ jobs:

notify_failure:
name: Notify failed build # creates an issue or modifies last open existing issue for failed build
needs: [tpu-tests, gpu-tests, cpu-tests, maxtext_jupyter_notebooks, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests]
needs: [tpu-tests, gpu-tests, cpu-tests, maxtext_jupyter_notebooks, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, rocm-tests, rocm-decoupled-tests]
if: ${{ always() }}
runs-on: ubuntu-latest
permissions:
Expand Down
202 changes: 202 additions & 0 deletions .github/workflows/build_rocm_transformer_engine_wheel_weekly.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
name: Build ROCm TransformerEngine wheel (weekly)

on:
workflow_dispatch:
schedule:
# Weekly at night (02:00 UTC every Monday), 2 hours ahead of scheduled tests.
- cron: "0 2 * * 1"

permissions:
contents: write

jobs:
build_upload_prune:
# AMD GPU runner (GitHub-hosted large runner label).
runs-on: linux-x86-64-4gpu-amd
container:
image: ghcr.io/rocm/jax-base-ubu24.rocm720:latest
options: >-
--device=/dev/kfd --device=/dev/dri --group-add video
--ipc=host --shm-size 64g
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined
--privileged
env:
XLA_PYTHON_CLIENT_MEM_FRACTION: "0.9"
NVTE_FUSED_ATTN_AOTRITON: "0"
env:
TE_WHEELS_KEEP_DAYS: "21"

steps:
- name: Checkout
uses: actions/checkout@v5

- name: Setup build environment (deps + venv)
shell: bash
run: |
set -euo pipefail
apt-get update
apt-get install -y --no-install-recommends git build-essential python3-dev
python3 -m pip install -U uv
python3 -m uv venv --seed
source .venv/bin/activate
uv pip install -U pip setuptools wheel pybind11 cmake

- name: Install ROCm JAX/JAXlib wheels (build against CI stack)
shell: bash
run: |
set -euo pipefail
source .venv/bin/activate
uv pip install -r src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt

- name: Install PyTorch ROCm (build-time dep for aiter JIT)
shell: bash
run: |
set -euo pipefail
source .venv/bin/activate
uv pip install torch --index-url https://download.pytorch.org/whl/rocm7.2

- name: Detect ROCm version and Python tag
shell: bash
run: |
set -euo pipefail
source .venv/bin/activate

# Detect ROCm version number from JAX backend string when available (e.g. 'rocm 70200').
ROCM_NUM="$([ -f /opt/rocm/.info/version ] && head -n1 /opt/rocm/.info/version | tr -d ' \t\r' || echo unknown)"
echo "Detected ROCm version: ${ROCM_NUM}"
echo "ROCM_NUM=${ROCM_NUM}" >> "${GITHUB_ENV}"

PYTAG="cp$(python3 -c 'import sys; print(f"{sys.version_info.major}{sys.version_info.minor}")')"
if [ "${PYTAG}" != "cp312" ]; then
echo "Expected Python 3.12 (cp312) for ROCm CI wheels, got ${PYTAG}."
exit 1
fi
echo "PYTAG=${PYTAG}" >> "${GITHUB_ENV}"
echo "REL_SCRIPT=.github/workflows/utils/te_wheels_release.py" >> "${GITHUB_ENV}"

- name: Clone ROCm/TransformerEngine (dev)
shell: bash
run: |
set -euo pipefail
rm -rf TransformerEngine
git clone --recursive --branch dev https://github.com/ROCm/TransformerEngine.git
cd TransformerEngine
git submodule update --init --recursive
TE_SHA="$(git rev-parse --short=12 HEAD)"
echo "TE_SHA=${TE_SHA}" >> "${GITHUB_ENV}"

- name: Select TE wheel arch for runner (mi300/mi355)
shell: bash
run: |
set -euo pipefail
source .venv/bin/activate

TE_WHEEL_ARCH="$(python3 .github/workflows/utils/install_te_rocm_wheel.py --print-arch)"
echo "Resolved TE wheel arch for runner: ${TE_WHEEL_ARCH}"
echo "TE_WHEEL_ARCH=${TE_WHEEL_ARCH}" >> "${GITHUB_ENV}"

# Build ONLY for the ROCm arch present on this CI runner (mi300 or mi355).
if [ "${TE_WHEEL_ARCH}" = "mi355" ]; then
SELECTOR="mi355"
GFX_ARCH="gfx950"
else
SELECTOR="mi300"
GFX_ARCH="gfx942;gfx941"
fi
echo "SELECTOR=${SELECTOR}" >> "${GITHUB_ENV}"
echo "GFX_ARCH=${GFX_ARCH}" >> "${GITHUB_ENV}"

- name: Build TE wheel
shell: bash
run: |
set -euo pipefail
source .venv/bin/activate

chmod +x "${REL_SCRIPT}" || true

export USE_ROCM=1
export HIP_PATH=/opt/rocm
export NVTE_FRAMEWORK=jax
export CMAKE_BUILD_PARALLEL_LEVEL=64
export NVTE_USE_ROCM=1
export NVTE_FUSED_ATTN_AOTRITON=0
export NVTE_BUILD_MAX_JOBS=180

echo "=== Building TE wheel for ${SELECTOR} (gfx=${GFX_ARCH}) ==="
pushd TransformerEngine >/dev/null
rm -rf build dist
export PYTHONPATH="$(pwd)/3rdparty/hipify_torch${PYTHONPATH:+:${PYTHONPATH}}"
export PYTORCH_ROCM_ARCH="${GFX_ARCH}"
export NVTE_ROCM_ARCH="${GFX_ARCH}"
python3 setup.py bdist_wheel
wheel_path="$(
python3 -c 'import glob; m=sorted(glob.glob("dist/transformer_engine-*.whl")); print(m[0] if m else "")'
)"
if [ -z "${wheel_path}" ]; then
echo "No wheel produced in dist/ (selector=${SELECTOR})."
exit 1
fi
wheel_base="$(basename "${wheel_path}")"
if [[ "${wheel_base}" == *"-1.${SELECTOR}-${PYTAG}-${PYTAG}-linux_x86_64.whl" ]]; then
asset_name="${wheel_base}"
else
asset_name="${wheel_base/-${PYTAG}-${PYTAG}-linux_x86_64.whl/-1.${SELECTOR}-${PYTAG}-${PYTAG}-linux_x86_64.whl}"
if [ "${asset_name}" = "${wheel_base}" ]; then
echo "Failed to rename wheel for selector=${SELECTOR}: ${wheel_base}"
exit 1
fi
fi
cp -f "${wheel_path}" "../${asset_name}"
popd >/dev/null

ls -lh "${asset_name}"
echo "TE_WHEEL_FILE=${asset_name}" >> "${GITHUB_ENV}"

- name: Upload wheel to rolling release tag (te-rocm-wheels)
shell: bash
env:
GITHUB_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
source .venv/bin/activate
python3 "${REL_SCRIPT}" upload --no-prerelease --tag "te-rocm-wheels" --title "ROCm TransformerEngine wheels (latest)" --body "Rolling release for latest weekly-built ROCm TransformerEngine wheels used by CI." --file "${TE_WHEEL_FILE}"

- name: Prune old assets from rolling tag
shell: bash
env:
GITHUB_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
source .venv/bin/activate
echo "Pruning rolling-tag assets older than ${TE_WHEELS_KEEP_DAYS} days"
python3 "${REL_SCRIPT}" prune-assets --tag "te-rocm-wheels" --keep-days "${TE_WHEELS_KEEP_DAYS}"

- name: Publish wheel to dated weekly release tag
shell: bash
env:
GITHUB_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
source .venv/bin/activate

DATE_UTC="$(date -u +%Y-%m-%d)"
WEEKLY_TAG="te-rocm-wheels-${DATE_UTC}-${TE_SHA}"
WEEKLY_TITLE="ROCm TransformerEngine wheels ${DATE_UTC} (TE ${TE_SHA})"
# Keep this YAML-safe (no unindented heredocs inside `run: |`).
WEEKLY_BODY="$(
printf '%s\n\nROCm: %s\nPython: %s\nArch: %s (gfx=%s)\n' \
"Built from ROCm/TransformerEngine dev @ ${TE_SHA} on ${DATE_UTC}." \
"${ROCM_NUM}" "${PYTAG}" "${SELECTOR}" "${GFX_ARCH}"
)"

python3 "${REL_SCRIPT}" upload --no-prerelease --tag "${WEEKLY_TAG}" --title "${WEEKLY_TITLE}" --body "${WEEKLY_BODY}" --file "${TE_WHEEL_FILE}"

- name: Prune old weekly releases
shell: bash
env:
GITHUB_TOKEN: ${{ github.token }}
run: |
set -euo pipefail
source .venv/bin/activate
echo "Pruning weekly release pages older than ${TE_WHEELS_KEEP_DAYS} days"
python3 "${REL_SCRIPT}" prune-releases --prefix "te-rocm-wheels-" --keep-days "${TE_WHEELS_KEEP_DAYS}"
Loading
Loading