From cb7f66bab7a46f5d010d6c49cad541bd63da698d Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Tue, 3 Feb 2026 20:01:55 +0000 Subject: [PATCH 01/21] adding rocm_jax_0.7.1 reqs --- .../requirements_decoupled_rocm_jax_0_7.1.txt | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 src/dependencies/requirements/requirements_decoupled_rocm_jax_0_7.1.txt diff --git a/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_7.1.txt b/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_7.1.txt new file mode 100644 index 0000000000..4d1732349c --- /dev/null +++ b/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_7.1.txt @@ -0,0 +1,47 @@ +absl_py>=2.3.1 +aqtp>=0.9.0 +chex>=0.1.90 +datasets>=4.2.0 +etils>=1.13.0 +evaluate>=0.4.6 +flax +grain>=0.2.12 +grpcio>=1.75.1 +huggingface_hub>=0.35.3 +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.7.1/jaxlib-0.7.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl +jax==0.7.1 +jax-rocm7-pjrt==0.7.1 +jax-rocm7-plugin==0.7.1 +jaxtyping>=0.3.3 +jsonlines>=4.0.0 +matplotlib>=3.10.3 +ml_collections>=1.1.0 +ml_dtypes>=0.5.3 +nltk>=3.9.2 +numpy>=2.0.2 +omegaconf>=2.3.0 +optax>=0.2.6 +orbax-checkpoint>=0.11.25 +pandas>=2.3.3 +parameterized==0.9.0 +pathwaysutils>=0.1.3 +pillow>=11.3.0 +protobuf>=5.29.5 +psutil>=7.0.0 +pytest>=8.4.1 +PyYAML>=6.0.3 +Requests>=2.32.5 +qwix>=0.1.1 +safetensors>=0.6.2 +sentencepiece>=0.2.1 +setuptools>=80.9.0 +tabulate>=0.9.0 +tensorflow>=2.19.1 +tensorflow_text>=2.19.0 +tensorflow_datasets>=4.9.9 +tensorstore>=0.1.76 +tiktoken>=0.12.0 +tqdm>=4.67.1 +transformers>=4.57.0 +urllib3>=2.5.0 +git+https://github.com/google/tunix.git From 53f94f31900b29d0df0bd0869f1809764dfd9c5b Mon Sep 17 00:00:00 2001 From: Gulsum Gudukbay Akbulut Date: Tue, 3 Feb 2026 14:07:26 -0600 Subject: [PATCH 02/21] Revert "removing CI workflows for now to upstream decoupling changes" This reverts commit 11a8852882c6b2d39ca9509fbc1fdb129cca4d22. --- .github/workflows/decouple_tests.yml | 183 +++++++++++++++++++++++++++ .github/workflows/upstream_sync.yml | 51 ++++++++ 2 files changed, 234 insertions(+) create mode 100644 .github/workflows/decouple_tests.yml create mode 100644 .github/workflows/upstream_sync.yml diff --git a/.github/workflows/decouple_tests.yml b/.github/workflows/decouple_tests.yml new file mode 100644 index 0000000000..2003441b06 --- /dev/null +++ b/.github/workflows/decouple_tests.yml @@ -0,0 +1,183 @@ +name: Decoupled Offline Tests + +on: + pull_request: + branches: + - rocm-main + paths: + - 'maxtext/**' + - '.github/workflows/decouple_tests.yml' + - 'pyproject.toml' + - 'README.md' + push: + branches: + - rocm-main + paths: + - 'maxtext/**' + - '.github/workflows/decouple_tests.yml' + workflow_dispatch: + +jobs: + decoupled: + runs-on: ubuntu-latest + env: + DECOUPLE_GCLOUD: TRUE + LOCAL_GCLOUD_PROJECT: ci-decoupled + LOCAL_BASE: datasets + LOCAL_BASE_OUTPUT: datasets/gcloud_decoupled_test_logs + KEEPALIVE: sleep infinity + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Pull ROCm JAX base image + run: docker pull rocm/jax:rocm7.0.2-jax0.6.0-py3.12-ubu24 + + - name: Start container + run: | + set -euo pipefail + mkdir -p "${LOCAL_BASE_OUTPUT}" + DOCKER_IMAGE="rocm/jax:rocm7.0.2-jax0.6.0-py3.12-ubu24" + CONTAINER_NAME="decouple_maxtext_tests" + DEVICE_FLAGS="--device=/dev/kfd --device=/dev/dri --group-add video" + RUNTIME_FLAGS="--network=host --ipc=host --shm-size 16G --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + WORKDIR="/maxtext" + MOUNT="-v ${{ github.workspace }}:${WORKDIR}" + ENV_VARS="-e DECOUPLE_GCLOUD=${DECOUPLE_GCLOUD} -e LOCAL_GCLOUD_PROJECT=${LOCAL_GCLOUD_PROJECT} -e LOCAL_BASE_OUTPUT=${LOCAL_BASE_OUTPUT}" + docker rm -f ${CONTAINER_NAME} >/dev/null 2>&1 || true + docker run -d -it --name ${CONTAINER_NAME} -w ${WORKDIR} ${ENV_VARS} ${RUNTIME_FLAGS} ${DEVICE_FLAGS} ${MOUNT} ${DOCKER_IMAGE} bash -c "${KEEPALIVE}" + echo "Container ${CONTAINER_NAME} started with centralized env vars." >> "${LOCAL_BASE_OUTPUT}/ci.log" + + - name: Install Requirements and Maxtext Package + run: | + set -euo pipefail + CONTAINER_NAME="decouple_maxtext_tests" + docker exec ${CONTAINER_NAME} bash -euo pipefail -c ' + mkdir -p "${LOCAL_BASE_OUTPUT}" + apt-get update + apt-get install -y git vim python3.12-venv python3-pip python3-dev build-essential + python3 -m venv .venv || { echo "venv creation failed"; exit 10; } + . .venv/bin/activate + python -m ensurepip --upgrade || true + python -m pip install --upgrade pip + python -m pip install wheel setuptools cmake + python -m pip install -r requirements_decoupled_rocm_jax_0_6_0.txt --no-cache-dir || { echo "Requirements install failed"; exit 13; } + python -m pip install . --no-cache-dir || { echo "Package install failed"; exit 11; } + echo "Maxtext package installed." >> "${LOCAL_BASE_OUTPUT}/ci.log" + python -m pip install pytest pytest-html pytest-csv minio array-record || { echo "Test deps failed"; exit 12; } + python -m pip list >> "${LOCAL_BASE_OUTPUT}/pip_freeze.txt" + echo "Requirements installed." >> "${LOCAL_BASE_OUTPUT}/ci.log" + ' + + - name: Build and Install TE + run: | + set -euo pipefail + CONTAINER_NAME="decouple_maxtext_tests" + docker exec ${CONTAINER_NAME} bash -euo pipefail -c ' + . .venv/bin/activate + if [ ! -e "TransformerEngine" ]; then + git clone https://github.com/ROCm/TransformerEngine.git + fi + cd TransformerEngine + git reset --hard ec21e1f35f5a36c1def9c89c4865eba3947f0665 + git submodule update --init --recursive + export USE_ROCM=1 + export HIP_PATH=/opt/rocm + export NVTE_FRAMEWORK=jax + export CMAKE_BUILD_PARALLEL_LEVEL=64 + export PYTORCH_ROCM_ARCH=gfx942 + export NVTE_ROCM_ARCH=gfx942 + export NVTE_USE_ROCM=1 + export NVTE_FUSED_ATTN_AOTRITON=0 + export NVTE_BUILD_MAX_JOBS=180 + export CU_NUM=304 + if find dist -maxdepth 1 -name "*.whl" -print -quit | grep -q .; then + echo "TE wheels exist, skip building" + else + python setup.py bdist_wheel || { echo "TransformerEngine build failed"; exit 20; } + fi + python -m pip install dist/transformer_engine-*.whl --no-cache-dir --force-reinstall || { echo "TE wheel install failed"; exit 21; } + cd .. + echo "TE built and installed." >> "${LOCAL_BASE_OUTPUT}/ci.log" + ' + + - name: Generate datasets + run: | + set -euo pipefail + CONTAINER_NAME="decouple_maxtext_tests" + docker exec ${CONTAINER_NAME} bash -euo pipefail -c ' + . .venv/bin/activate + echo "Generating minimal ArrayRecord shards" >> "${LOCAL_BASE_OUTPUT}/ci.log" + python ${LOCAL_BASE}/get_minimal_c4_en_dataset.py --force || { echo "ArrayRecord generation failed"; exit 3; } + echo "Generating minimal parquet shards" >> "${LOCAL_BASE_OUTPUT}/ci.log" + python ${LOCAL_BASE}/get_minimal_hf_c4_parquet.py --force --train-rows 200 --val-rows 40 + echo "Minimal dataset setup complete." >> "${LOCAL_BASE_OUTPUT}/ci.log" + echo "Generating minimal tfrecord shards." >> "${LOCAL_BASE_OUTPUT}/ci.log" + python ${LOCAL_BASE}/convert_arrayrecord_to_tfrecord.py --version-dir ${LOCAL_BASE}/c4_en_dataset_minimal/c4/en/3.0.1 --builder-name __local_c4_builder --force + echo "Minimal tfrecord shards generation complete." >> "${LOCAL_BASE_OUTPUT}/ci.log" + cp -r ${LOCAL_BASE}/c4_en_dataset_minimal/c4/en/3.0.1 ${LOCAL_BASE}/c4_en_dataset_minimal/c4/en/3.1.0 + echo "Generating TFDS metadata." >> "${LOCAL_BASE_OUTPUT}/ci.log" + python ${LOCAL_BASE}/generate_tfds_metadata.py --root ${LOCAL_BASE}/c4_en_dataset_minimal --version 3.0.1 --source-version 3.0.1 --force + python ${LOCAL_BASE}/generate_tfds_metadata.py --root ${LOCAL_BASE}/c4_en_dataset_minimal --version 3.1.0 --source-version 3.0.1 --force + echo "Minimal dataset generation complete!" >> "${LOCAL_BASE_OUTPUT}/ci.log" + ' + + - name: Collect & run decoupled tests + run: | + set -euo pipefail + CONTAINER_NAME="decouple_maxtext_tests" + docker exec ${CONTAINER_NAME} bash -euo pipefail -c ' + . .venv/bin/activate + echo "Collecting tests" >> "${LOCAL_BASE_OUTPUT}/ci.log" + pytest tests -m decoupled --collect-only -q | tee "${LOCAL_BASE_OUTPUT}/collect.txt" + TEST_COUNT="$(grep -v "^$" "${LOCAL_BASE_OUTPUT}/collect.txt" | wc -l)" + echo Collected:$TEST_COUNT >> "${LOCAL_BASE_OUTPUT}/ci.log" + if [ "$TEST_COUNT" -eq 0 ]; then echo "ERROR: No decoupled tests found" >> "${LOCAL_BASE_OUTPUT}/ci.log"; exit 2; fi + pytest -v tests -m decoupled -q --csv=decoupled-tests-report.csv --html=decoupled-tests-report.html --self-contained-html + mv decoupled-tests-report.csv decoupled-tests-report.html "${LOCAL_BASE_OUTPUT}/" + EXEC_COUNT="$(grep -E '^[^,]+,' "${LOCAL_BASE_OUTPUT}/decoupled-tests-report.csv" | wc -l)" + echo Executed:$EXEC_COUNT >> "${LOCAL_BASE_OUTPUT}/ci.log" + echo "### Decoupled Test Run Summary" >> $GITHUB_STEP_SUMMARY + echo Collected: $TEST_COUNT >> $GITHUB_STEP_SUMMARY + echo Executed: $EXEC_COUNT >> $GITHUB_STEP_SUMMARY + echo "Reports: CSV & HTML attached as artifacts." >> $GITHUB_STEP_SUMMARY + ' + + echo "Container ${CONTAINER_NAME} still alive." >> datasets/gcloud_decoupled_test_logs/ci.log + echo "Attach to container with: docker exec -it ${CONTAINER_NAME} bash" >> $GITHUB_STEP_SUMMARY + echo "Remove when done: docker rm -f ${CONTAINER_NAME}" >> $GITHUB_STEP_SUMMARY + + - name: Snapshot container image Artifact + if: always() + run: | + set -euo pipefail + CONTAINER_NAME="decouple_maxtext_tests" + IMAGE_TAG="maxtext-decoupled-tests:latest" + IMAGE_FILE="maxtext-decoupled-tests-image.tar.gz" + if ! docker ps -a --format '{{.Names}}' | grep -q "^${CONTAINER_NAME}$"; then + echo "Container ${CONTAINER_NAME} not found; skipping snapshot." >> "${LOCAL_BASE_OUTPUT}/ci.log" + exit 0 + fi + docker commit ${CONTAINER_NAME} ${IMAGE_TAG} + docker save ${IMAGE_TAG} | gzip -c > ${IMAGE_FILE} + ls -lh ${IMAGE_FILE} >> "${LOCAL_BASE_OUTPUT}/ci.log" || true + echo "Image snapshot saved as artifact candidate: ${IMAGE_FILE}" >> "${LOCAL_BASE_OUTPUT}/ci.log" + + - name: Upload snapshot image + if: always() + uses: actions/upload-artifact@v3 + with: + name: decoupled-docker-image + path: maxtext-decoupled-tests-image.tar.gz + if-no-files-found: warn + + - name: Upload decoupled logs + if: always() + uses: actions/upload-artifact@v3 + with: + name: decoupled-logs + path: datasets/gcloud_decoupled_test_logs + if-no-files-found: warn + + + diff --git a/.github/workflows/upstream_sync.yml b/.github/workflows/upstream_sync.yml new file mode 100644 index 0000000000..d19fe20053 --- /dev/null +++ b/.github/workflows/upstream_sync.yml @@ -0,0 +1,51 @@ +name: Upstream Sync + +on: + workflow_dispatch: {} + schedule: + - cron: '30 4 * * *' # Daily 04:30 UTC + +permissions: + contents: write + +jobs: + sync: + runs-on: ubuntu-latest + if: github.ref == 'refs/heads/main' + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + persist-credentials: false + + - name: Configure git + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + + - name: Fetch upstream + run: | + git remote add upstream https://github.com/google/maxtext.git 2>/dev/null || true + git fetch upstream main + git checkout main + git pull --ff-only origin main || true + + - name: Merge upstream/main + id: merge_step + run: | + set -e + git merge --no-edit upstream/main || { echo "::error::Merge conflict - resolve manually"; git merge --abort || true; exit 1; } + if git diff --quiet origin/main..main; then echo "no_changes=true" >> $GITHUB_OUTPUT; else echo "no_changes=false" >> $GITHUB_OUTPUT; fi + + - name: Push (if changed) + if: steps.merge_step.outputs.no_changes == 'false' + env: + PAT: ${{ secrets.UPSTREAM_SYNC_TOKEN }} + run: | + [ -z "$PAT" ] && echo "::error::Missing UPSTREAM_SYNC_TOKEN secret" && exit 1 + git push https://x-access-token:$PAT@github.com/${{ github.repository }}.git main + + - name: Result + run: | + if [ "${{ steps.merge_step.outputs.no_changes }}" = "true" ]; then echo "Up to date"; else echo "Synced upstream"; fi From 070adee2411000dc7dac5cb6eea269476c3466c2 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Mon, 9 Feb 2026 20:23:19 +0000 Subject: [PATCH 03/21] skip ring attention test on ROCm --- tests/integration/train_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration/train_tests.py b/tests/integration/train_tests.py index fc27753abe..9d10f62686 100644 --- a/tests/integration/train_tests.py +++ b/tests/integration/train_tests.py @@ -18,6 +18,7 @@ import pytest import jax +from jax._src import test_util as jtu from absl.testing import absltest from maxtext.common.gcloud_stub import is_decoupled from maxtext.trainers.pre_train.train import main as train_main From 87a13a8c9763b3f3be2c67f8f653f0b8fe2b1f6b Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Wed, 11 Feb 2026 16:53:55 +0000 Subject: [PATCH 04/21] [DOWNSTREAM-ONLY] update schedule for build_and_test_maxtext --- .github/workflows/build_and_test_maxtext.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/build_and_test_maxtext.yml b/.github/workflows/build_and_test_maxtext.yml index 6949bf8d7e..0e228a86ce 100644 --- a/.github/workflows/build_and_test_maxtext.yml +++ b/.github/workflows/build_and_test_maxtext.yml @@ -21,8 +21,7 @@ on: workflow_call: workflow_dispatch: schedule: - # Run the job every 4 hours - - cron: '0 */4 * * *' + - cron: '0 4 * * *' # Daily 04:00 UTC concurrency: # Dedup pull requests (canceling previous runs of the same workflow for same PR), and scheduled runs but nothing else From 65d3f9d0ca366a35cbb16024c86c1c39aa553194 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Wed, 11 Feb 2026 21:53:58 +0000 Subject: [PATCH 05/21] adding jax 0.8.2 requirements --- .../requirements_decoupled_jax_0_8_2.txt | 38 ++++++++++++ .../requirements_decoupled_rocm_jax_0_8_2.txt | 45 ++++++++++++++ .../requirements_rocm_jax_0.8.2.txt | 58 +++++++++++++++++++ 3 files changed, 141 insertions(+) create mode 100644 src/dependencies/requirements/requirements_decoupled_jax_0_8_2.txt create mode 100644 src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt create mode 100644 src/dependencies/requirements/requirements_rocm_jax_0.8.2.txt diff --git a/src/dependencies/requirements/requirements_decoupled_jax_0_8_2.txt b/src/dependencies/requirements/requirements_decoupled_jax_0_8_2.txt new file mode 100644 index 0000000000..8f85aa10e0 --- /dev/null +++ b/src/dependencies/requirements/requirements_decoupled_jax_0_8_2.txt @@ -0,0 +1,38 @@ +absl-py +aqtp +array-record +datasets +flax +grain[parquet] +huggingface_hub +jax==0.8.2 +jaxlib==0.8.2 +jaxtyping +jsonlines +ml-collections +numpy +omegaconf +optax +orbax-checkpoint +pathwaysutils +pillow +pre-commit +protobuf +pydantic +pyink +pylint +pytest +parameterized +pytype +sentencepiece +tensorboardx +tensorflow-datasets +tensorflow-text<2.20 +tensorflow<2.20 +tiktoken +numba>=0.59.0 +google-tunix +tokamax @ git+https://github.com/openxla/tokamax.git@69b328121c3ee120d8e54cf26d1565cea189617f +transformers>=4.57.3,<5 +qwix +mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip diff --git a/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt b/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt new file mode 100644 index 0000000000..7f78ab5749 --- /dev/null +++ b/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt @@ -0,0 +1,45 @@ +absl-py +aqtp +array-record +datasets +flax +grain[parquet] +huggingface_hub + + +# ROCm JAX 0.8.2 (py312) wheels (install order matters) +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jax_rocm7_pjrt-0.8.2+rocm7.2.0-py3-none-manylinux_2_28_x86_64.whl +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jax_rocm7_plugin-0.8.2+rocm7.2.0-cp312-cp312-manylinux_2_28_x86_64.whl +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jaxlib-0.8.2+rocm7-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl +jax==0.8.2 + + +jaxtyping +jsonlines +ml-collections +numpy +omegaconf +optax +orbax-checkpoint +pathwaysutils +pillow +pre-commit +protobuf +pydantic +pyink +pylint +pytest +parameterized +pytype +sentencepiece +tensorboardx +tensorflow-datasets +tensorflow-text<2.20 +tensorflow<2.20 +tiktoken +numba>=0.59.0 +google-tunix +tokamax @ git+https://github.com/openxla/tokamax.git@69b328121c3ee120d8e54cf26d1565cea189617f +transformers>=4.57.3,<5 +qwix +mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip diff --git a/src/dependencies/requirements/requirements_rocm_jax_0.8.2.txt b/src/dependencies/requirements/requirements_rocm_jax_0.8.2.txt new file mode 100644 index 0000000000..aa2817c218 --- /dev/null +++ b/src/dependencies/requirements/requirements_rocm_jax_0.8.2.txt @@ -0,0 +1,58 @@ +absl-py +aqtp +array-record +datasets +flax +grain[parquet] +huggingface_hub + + +# ROCm JAX 0.8.2 (py312) wheels (install order matters) +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jax_rocm7_pjrt-0.8.2+rocm7.2.0-py3-none-manylinux_2_28_x86_64.whl +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jax_rocm7_plugin-0.8.2+rocm7.2.0-cp312-cp312-manylinux_2_28_x86_64.whl +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jaxlib-0.8.2+rocm7-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl +jax==0.8.2 + + +jaxtyping +jsonlines +ml-collections +numpy +omegaconf +optax +orbax-checkpoint +pathwaysutils +pillow +pre-commit +protobuf +pydantic +pyink +pylint +pytest +parameterized +pytype +sentencepiece +tensorboardx +tensorflow-datasets +tensorflow-text<2.20 +tensorflow<2.20 +tiktoken +numba>=0.59.0 +google-tunix +tokamax @ git+https://github.com/openxla/tokamax.git@69b328121c3ee120d8e54cf26d1565cea189617f +transformers>=4.57.3,<5 +qwix +mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip + + +# GCP / Cloud deps (restored for non-decoupled ROCm runs) +cloud-accelerator-diagnostics +cloud-tpu-diagnostics +gcsfs +google-api-python-client +google-cloud-aiplatform +google-cloud-mldiagnostics +google-cloud-monitoring +ml-goodput-measurement +tensorboard-plugin-profile +google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip From d1fd5f0913c4ad851ff62313d19fea9d8634d5ef Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Thu, 12 Feb 2026 16:14:17 +0000 Subject: [PATCH 06/21] update configs in tests to use helper functions --- tests/utils/test_helpers.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 2eea9025c7..e4ad9f1cd9 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -122,11 +122,23 @@ def get_test_base_output_directory(cloud_path=None): return cloud_path or "gs://runner-maxtext-logs" +def get_test_config_path_for(relative_path: str): + """Return absolute path for a config under the configs directory. + + Uses the same decoupled-vs-non-decoupled selection logic as + `get_test_config_path()` when `relative_path` is `base.yml`. + """ + if relative_path == "base.yml": + return get_test_config_path() + return os.path.join(MAXTEXT_CONFIGS_DIR, relative_path) + + __all__ = [ "get_test_base_output_directory", "get_decoupled_parallelism_overrides", "is_rocm_backend", "get_test_config_path", "get_post_train_test_config_path", + "get_test_config_path_for", "get_test_dataset_path", ] From 280fd876e446361972a9590f2d14111a5813c42e Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Mon, 16 Feb 2026 20:10:45 +0000 Subject: [PATCH 07/21] adding TE build and upload CI workflow --- ...d_rocm_transformer_engine_wheel_weekly.yml | 314 ++++++++++++++++++ 1 file changed, 314 insertions(+) create mode 100644 .github/workflows/build_rocm_transformer_engine_wheel_weekly.yml diff --git a/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml new file mode 100644 index 0000000000..693c175d3d --- /dev/null +++ b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml @@ -0,0 +1,314 @@ +name: Build ROCm TransformerEngine wheel (weekly) + +on: + workflow_dispatch: + schedule: + # Weekly at night (02:00 UTC every Monday), 2 hours ahead of scheduled tests. + - cron: "0 2 * * 1" + +permissions: + contents: write + +jobs: + build_upload_prune: + # Same runner label used by ROCm test workflows. + runs-on: ["self-hosted", "linux-x86-64-4gpu-amd"] + container: + image: ghcr.io/rocm/jax-base-ubu24.rocm720:latest + options: >- + --device=/dev/kfd --device=/dev/dri --group-add video + --ipc=host --shm-size 64g + --cap-add=SYS_PTRACE --security-opt seccomp=unconfined + --privileged + env: + XLA_PYTHON_CLIENT_MEM_FRACTION: "0.9" + NVTE_FUSED_ATTN_AOTRITON: "0" + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Build & publish TE wheels (native arch then MI300) + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + set -euo pipefail + apt-get update + apt-get install -y --no-install-recommends git build-essential python3-dev + python3 -m pip install -U uv + python3 -m uv venv --seed + source .venv/bin/activate + uv pip install -U pip setuptools wheel pybind11 cmake + # Install ROCm JAX/JAXlib wheels into the venv (so TE builds against the same stack as CI). + uv pip install -r dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt + + # Detect runner ROCm architecture (avoid generic tokens like gfx9). + PRIMARY_ARCH="$( + (command -v rocminfo >/dev/null 2>&1 && rocminfo || /opt/rocm/bin/rocminfo) \ + | grep -oE 'gfx[0-9a-f]+' \ + | sort -u \ + | awk '{ print length($0), $0 }' \ + | sort -nr \ + | head -n1 \ + | awk '{ print $2 }' + )" + echo "Detected runner ROCm arch: ${PRIMARY_ARCH}" + + # Detect ROCm version number from JAX backend string when available (e.g. 'rocm 70200'). + ROCM_NUM="$(python3 -c 'import re, jax; s=str(jax.devices()[0].client.platform_version); m=re.search(r"rocm\\s+([0-9]+)", s); print(m.group(1) if m else "unknown")')" + echo "Detected ROCm version: ${ROCM_NUM}" + + # Build TE from ROCm/TransformerEngine dev branch. + rm -rf TransformerEngine + git clone --recursive --branch dev https://github.com/ROCm/TransformerEngine.git + cd TransformerEngine + git submodule update --init --recursive + + TE_SHA="$(git rev-parse --short=12 HEAD)" + PYTAG="cp$(python3 -c 'import sys; print(f"{sys.version_info.major}{sys.version_info.minor}")')" + + export USE_ROCM=1 + export HIP_PATH=/opt/rocm + export NVTE_FRAMEWORK=jax + export CMAKE_BUILD_PARALLEL_LEVEL=64 + export NVTE_USE_ROCM=1 + export NVTE_FUSED_ATTN_AOTRITON=0 + export NVTE_BUILD_MAX_JOBS=180 + + # Return to workspace root; build function will enter TE directory. + cd .. + + build_one() { + local arch="$1" + echo "=== Building TE wheel for ${arch} ===" + pushd TransformerEngine >/dev/null + rm -rf build dist + export PYTORCH_ROCM_ARCH="${arch}" + export NVTE_ROCM_ARCH="${arch}" + python3 setup.py bdist_wheel + local wheel_path + wheel_path="$(ls -1 dist/transformer_engine-*.whl | head -n1)" + popd >/dev/null + local asset_name + asset_name="transformer_engine-${TE_SHA}-${PYTAG}-rocm${ROCM_NUM}-${arch}.whl" + cp -f "TransformerEngine/${wheel_path}" "${asset_name}" + echo "${asset_name}" + } + + publish_asset_to_release_tag() { + local tag="$1" + local title="$2" + local body="$3" + local wheel_file="$4" + python3 - "$tag" "$title" "$body" "$wheel_file" <<'PY' + import json, os, sys, urllib.error, urllib.parse, urllib.request + + tag, title, body, wheel_file = sys.argv[1:5] + token = os.environ["GITHUB_TOKEN"] + repo = os.environ.get("GITHUB_REPOSITORY") + if not repo: + raise SystemExit("GITHUB_REPOSITORY is not set") + owner, name = repo.split("/", 1) + + api = "https://api.github.com" + headers = { + "Authorization": f"Bearer {token}", + "Accept": "application/vnd.github+json", + "User-Agent": "maxtext-ci", + } + + def request_json(method: str, url: str, body_obj=None): + data = None + if body_obj is not None: + data = json.dumps(body_obj).encode("utf-8") + req = urllib.request.Request(url, data=data, method=method, headers=headers) + with urllib.request.urlopen(req) as r: + return json.loads(r.read().decode("utf-8")) + + def request_raw(method: str, url: str, data: bytes, extra_headers=None): + h = dict(headers) + if extra_headers: + h.update(extra_headers) + req = urllib.request.Request(url, data=data, method=method, headers=h) + with urllib.request.urlopen(req) as r: + return r.read() + + # Ensure release exists for this tag. + release = None + try: + release = request_json("GET", f"{api}/repos/{owner}/{name}/releases/tags/{tag}") + except urllib.error.HTTPError as e: + if e.code != 404: + raise + + if release is None: + release = request_json( + "POST", + f"{api}/repos/{owner}/{name}/releases", + { + "tag_name": tag, + "name": title, + "body": body, + "prerelease": True, + }, + ) + + release_id = release["id"] + upload_url = release["upload_url"].split("{", 1)[0] + + # Delete any existing asset with same name. + assets = request_json("GET", f"{api}/repos/{owner}/{name}/releases/{release_id}")["assets"] + wheel_name = os.path.basename(wheel_file) + for a in assets: + if a.get("name") == wheel_name: + request_json("DELETE", f"{api}/repos/{owner}/{name}/releases/assets/{a['id']}") + + with open(wheel_file, "rb") as f: + wheel_bytes = f.read() + up = f"{upload_url}?{urllib.parse.urlencode({'name': wheel_name})}" + request_raw("POST", up, wheel_bytes, extra_headers={"Content-Type": "application/octet-stream"}) + print(f"Uploaded {wheel_name} to release tag {tag}", flush=True) + PY + } + + prune_old_weekly_releases() { + local prefix="$1" + local keep_days="$2" + python3 - "$prefix" "$keep_days" <<'PY' + import datetime as dt + import json, os, sys, urllib.request + + prefix = sys.argv[1] + keep_days = int(sys.argv[2]) + token = os.environ["GITHUB_TOKEN"] + repo = os.environ.get("GITHUB_REPOSITORY") + if not repo: + raise SystemExit("GITHUB_REPOSITORY is not set") + + api = "https://api.github.com" + headers = { + "Authorization": f"Bearer {token}", + "Accept": "application/vnd.github+json", + "User-Agent": "maxtext-ci", + } + + def get_json(url: str): + req = urllib.request.Request(url, headers=headers) + with urllib.request.urlopen(req) as r: + return json.loads(r.read().decode("utf-8")) + + def delete(url: str): + req = urllib.request.Request(url, method="DELETE", headers=headers) + with urllib.request.urlopen(req): + return + + cutoff = dt.datetime.now(dt.timezone.utc) - dt.timedelta(days=keep_days) + + # Paginate releases. + page = 1 + deleted = 0 + while True: + releases = get_json(f"{api}/repos/{repo}/releases?per_page=100&page={page}") + if not releases: + break + for rel in releases: + tag = rel.get("tag_name", "") + if not tag.startswith(prefix): + continue + created = dt.datetime.fromisoformat(rel["created_at"].replace("Z", "+00:00")) + if created < cutoff: + delete(f"{api}/repos/{repo}/releases/{rel['id']}") + deleted += 1 + print(f"Deleted old weekly TE release {tag} (created_at={rel['created_at']})", flush=True) + page += 1 + print(f"Weekly TE release prune complete. Deleted {deleted} releases older than {keep_days} days.", flush=True) + PY + } + + prune_old_assets_in_release_tag() { + local tag="$1" + local keep_days="$2" + python3 - "$tag" "$keep_days" <<'PY' + import datetime as dt + import json, os, sys, urllib.error, urllib.request + + tag = sys.argv[1] + keep_days = int(sys.argv[2]) + token = os.environ["GITHUB_TOKEN"] + repo = os.environ.get("GITHUB_REPOSITORY") + if not repo: + raise SystemExit("GITHUB_REPOSITORY is not set") + owner, name = repo.split("/", 1) + + api = "https://api.github.com" + headers = { + "Authorization": f"Bearer {token}", + "Accept": "application/vnd.github+json", + "User-Agent": "maxtext-ci", + } + + def request_json(method: str, url: str): + req = urllib.request.Request(url, method=method, headers=headers) + with urllib.request.urlopen(req) as r: + data = r.read() + return json.loads(data.decode("utf-8")) if data else None + + def delete(url: str): + req = urllib.request.Request(url, method="DELETE", headers=headers) + with urllib.request.urlopen(req): + return + + try: + rel = request_json("GET", f"{api}/repos/{owner}/{name}/releases/tags/{tag}") + except urllib.error.HTTPError as e: + if e.code == 404: + print(f"No release for tag {tag}; skipping asset prune.", flush=True) + raise SystemExit(0) + raise + + cutoff = dt.datetime.now(dt.timezone.utc) - dt.timedelta(days=keep_days) + pruned = 0 + for a in rel.get("assets", []): + created = dt.datetime.fromisoformat(a["created_at"].replace("Z", "+00:00")) + if created < cutoff: + delete(f"{api}/repos/{owner}/{name}/releases/assets/{a['id']}") + pruned += 1 + print(f"Pruned old asset: {a['name']} (created_at={a['created_at']})", flush=True) + print(f"Asset prune complete for {tag}. Deleted {pruned} assets older than {keep_days} days.", flush=True) + PY + } + + # Build runner-native wheel and upload immediately so other CI can pick it up. + NATIVE_WHEEL="$(build_one "${PRIMARY_ARCH}")" + ls -lh "${NATIVE_WHEEL}" + publish_asset_to_release_tag \ + "te-rocm-wheels" \ + "ROCm TransformerEngine wheels (latest)" \ + "Rolling release for latest weekly-built ROCm TransformerEngine wheels used by CI." \ + "${NATIVE_WHEEL}" + + # Build MI300 wheel (gfx942) and upload. + MI300_WHEEL="$(build_one "gfx942")" + ls -lh "${MI300_WHEEL}" + publish_asset_to_release_tag \ + "te-rocm-wheels" \ + "ROCm TransformerEngine wheels (latest)" \ + "Rolling release for latest weekly-built ROCm TransformerEngine wheels used by CI." \ + "${MI300_WHEEL}" + + # Retention: prune rolling release assets older than 3 weeks. + prune_old_assets_in_release_tag "te-rocm-wheels" 21 + + # Also publish both wheels to a new weekly release page. + DATE_UTC="$(date -u +%Y-%m-%d)" + WEEKLY_TAG="te-rocm-wheels-${DATE_UTC}-${TE_SHA}" + WEEKLY_TITLE="ROCm TransformerEngine wheels ${DATE_UTC} (TE ${TE_SHA})" + WEEKLY_BODY="Built from ROCm/TransformerEngine dev @ ${TE_SHA} on ${DATE_UTC}.\nROCm=${ROCM_NUM}, Python=${PYTAG}, arches=${PRIMARY_ARCH} and gfx942." + + publish_asset_to_release_tag "${WEEKLY_TAG}" "${WEEKLY_TITLE}" "${WEEKLY_BODY}" "${NATIVE_WHEEL}" + publish_asset_to_release_tag "${WEEKLY_TAG}" "${WEEKLY_TITLE}" "${WEEKLY_BODY}" "${MI300_WHEEL}" + + # Retention: delete weekly releases older than 3 weeks. + prune_old_weekly_releases "te-rocm-wheels-" 21 + From b40ce36048ec78399be61b108087df3e184a92ff Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Tue, 3 Feb 2026 23:26:45 +0000 Subject: [PATCH 08/21] Adding CI workflow changes for ROCm and JAX 0.8.2 requirements files --- .github/workflows/build_and_test_maxtext.yml | 74 +++++-- ...d_rocm_transformer_engine_wheel_weekly.yml | 1 - .github/workflows/decouple_tests.yml | 183 ------------------ .../workflows/run_tests_against_package.yml | 92 ++++++++- .github/workflows/run_tests_coordinator.yml | 62 ++++-- .../workflows/utils/install_te_rocm_wheel.py | 144 ++++++++++++++ .../requirements_decoupled_jax_0_8_2.txt | 2 + .../requirements_decoupled_rocm_jax_0_8_2.txt | 2 + .../requirements_rocm_jax_0.8.2.txt | 2 + 9 files changed, 340 insertions(+), 222 deletions(-) delete mode 100644 .github/workflows/decouple_tests.yml create mode 100644 .github/workflows/utils/install_te_rocm_wheel.py diff --git a/.github/workflows/build_and_test_maxtext.yml b/.github/workflows/build_and_test_maxtext.yml index 0e228a86ce..6f971d1b5a 100644 --- a/.github/workflows/build_and_test_maxtext.yml +++ b/.github/workflows/build_and_test_maxtext.yml @@ -20,15 +20,23 @@ on: pull_request: workflow_call: workflow_dispatch: + inputs: + rocm_only: + description: 'Run only ROCm jobs (manual runs only)' + type: boolean + required: false + default: false schedule: - - cron: '0 4 * * *' # Daily 04:00 UTC + - cron: '0 3 * * *' # Daily 03:00 UTC concurrency: - # Dedup pull requests (canceling previous runs of the same workflow for same PR), and scheduled runs but nothing else + # Cancel previous runs for same PR (all actors), scheduled runs, + # and manual runs per (branch + actor). group: > ${{ github.event_name == 'pull_request' && format('{0}-pr-{1}', github.workflow, github.event.pull_request.number) || github.event_name == 'schedule' && format('{0}-schedule', github.workflow) || + github.event_name == 'workflow_dispatch' && format('{0}-manual-{1}-{2}', github.workflow, github.ref, github.actor) || github.run_id }} cancel-in-progress: true @@ -117,18 +125,16 @@ jobs: build_and_upload_maxtext_package: needs: analyze_code_changes # Run if either tests or notebooks need to run - if: | - needs.analyze_code_changes.outputs.run_tests == 'true' || - needs.analyze_code_changes.outputs.run_notebooks == 'true' + if: ${{ vars.ROCM_ONLY == 'true' || needs.analyze_code_changes.outputs.run_tests == 'true' || needs.analyze_code_changes.outputs.run_notebooks == 'true' }} uses: ./.github/workflows/build_package.yml with: - device_type: tpu - device_name: v4-8 - cloud_runner: linux-x86-n2-16-buildkit + device_type: ${{ vars.ROCM_ONLY == 'true' && 'rocm' || 'tpu' }} + device_name: ${{ vars.ROCM_ONLY == 'true' && 'mi355' || 'v4-8' }} + cloud_runner: ${{ vars.ROCM_ONLY == 'true' && 'linux-x86-64-4gpu-amd' || 'linux-x86-n2-16-buildkit' }} maxtext_jupyter_notebooks: needs: build_and_upload_maxtext_package - if: needs.analyze_code_changes.outputs.run_notebooks == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_notebooks == 'true' }} uses: ./.github/workflows/run_jupyter_notebooks.yml strategy: fail-fast: false @@ -144,7 +150,7 @@ jobs: tpu-tests: name: ${{ matrix.flavor }} tests needs: [build_and_upload_maxtext_package] - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }} uses: ./.github/workflows/run_tests_coordinator.yml strategy: fail-fast: false @@ -159,7 +165,7 @@ jobs: gpu-tests: name: ${{ matrix.flavor }} tests needs: [build_and_upload_maxtext_package] - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }} strategy: fail-fast: false matrix: @@ -174,7 +180,7 @@ jobs: cpu-tests: name: ${{ matrix.flavor }} tests needs: [build_and_upload_maxtext_package] - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }} uses: ./.github/workflows/run_tests_coordinator.yml strategy: fail-fast: false @@ -188,7 +194,7 @@ jobs: maxtext_tpu_pathways_unit_tests: needs: build_and_upload_maxtext_package - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }} uses: ./.github/workflows/run_pathways_tests.yml strategy: fail-fast: false @@ -207,7 +213,7 @@ jobs: maxtext_tpu_pathways_integration_tests: needs: build_and_upload_maxtext_package - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }} uses: ./.github/workflows/run_pathways_tests.yml strategy: fail-fast: false @@ -224,9 +230,39 @@ jobs: is_scheduled_run: ${{ github.event_name == 'schedule' }} maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} + rocm-tests: + name: ${{ matrix.flavor }} tests + needs: [build_and_upload_maxtext_package] + if: ${{ vars.ROCM_ONLY != 'true' && needs.analyze_code_changes.outputs.run_tests == 'true' }} + uses: ./.github/workflows/run_tests_coordinator.yml + strategy: + fail-fast: false + matrix: + flavor: [rocm-unit] + with: + flavor: ${{ matrix.flavor }} + base_image: 'rocm-placeholder' + is_scheduled_run: ${{ github.event_name == 'schedule' }} + maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} + + rocm-decoupled-tests: + name: ${{ matrix.flavor }} tests + needs: [build_and_upload_maxtext_package] + if: ${{ vars.ROCM_ONLY == 'true' || (github.event_name == 'workflow_dispatch' && inputs.rocm_only) || needs.analyze_code_changes.outputs.run_tests == 'true' }} + uses: ./.github/workflows/run_tests_coordinator.yml + strategy: + fail-fast: false + matrix: + flavor: [rocm-decoupled] + with: + flavor: ${{ matrix.flavor }} + base_image: 'rocm-placeholder' + is_scheduled_run: ${{ github.event_name == 'schedule' }} + maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} + all_tests_passed: name: All Required Tests Passed - needs: [tpu-tests, gpu-tests, cpu-tests, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests] + needs: [analyze_code_changes, build_and_upload_maxtext_package, tpu-tests, gpu-tests, cpu-tests, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, rocm-tests, rocm-decoupled-tests] if: always() runs-on: ubuntu-latest steps: @@ -244,6 +280,8 @@ jobs: echo "CPU Tests (Matrix) result: ${NEEDS_CPU_TESTS_RESULT}" echo "Pathways Unit result: ${NEEDS_MAXTEXT_TPU_PATHWAYS_UNIT_TESTS_RESULT}" echo "Pathways Integration result: ${NEEDS_MAXTEXT_TPU_PATHWAYS_INTEGRATION_TESTS_RESULT}" + echo "ROCm Tests (Matrix) result: ${NEEDS_ROCM_TESTS_RESULT}" + echo "ROCm Decoupled Tests (Matrix) result: ${NEEDS_ROCM_DECOUPLED_TESTS_RESULT}" # Fail only if any job failed or was cancelled (skipped is OK) if [ "${{ contains(needs.*.result, 'failure') }}" == "true" ] || [ "${{ contains(needs.*.result, 'cancelled') }}" == "true" ]; then @@ -260,11 +298,13 @@ jobs: NEEDS_GPU_TESTS_RESULT: ${{ needs.gpu-tests.result }} NEEDS_MAXTEXT_TPU_PATHWAYS_UNIT_TESTS_RESULT: ${{ needs.maxtext_tpu_pathways_unit_tests.result }} NEEDS_MAXTEXT_TPU_PATHWAYS_INTEGRATION_TESTS_RESULT: ${{ needs.maxtext_tpu_pathways_integration_tests.result }} + NEEDS_ROCM_TESTS_RESULT: ${{ needs.rocm-tests.result }} + NEEDS_ROCM_DECOUPLED_TESTS_RESULT: ${{ needs.rocm-decoupled-tests.result }} all_notebooks_passed: name: All Notebooks Passed needs: [analyze_code_changes, build_and_upload_maxtext_package, maxtext_jupyter_notebooks] - if: always() + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && always() }} runs-on: ubuntu-latest steps: - name: Check notebooks results @@ -292,7 +332,7 @@ jobs: notify_failure: name: Notify failed build # creates an issue or modifies last open existing issue for failed build - needs: [tpu-tests, gpu-tests, cpu-tests, maxtext_jupyter_notebooks, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests] + needs: [tpu-tests, gpu-tests, cpu-tests, maxtext_jupyter_notebooks, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, rocm-tests, rocm-decoupled-tests] if: ${{ always() }} runs-on: ubuntu-latest permissions: diff --git a/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml index 693c175d3d..d0bd1c6c79 100644 --- a/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml +++ b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml @@ -311,4 +311,3 @@ jobs: # Retention: delete weekly releases older than 3 weeks. prune_old_weekly_releases "te-rocm-wheels-" 21 - diff --git a/.github/workflows/decouple_tests.yml b/.github/workflows/decouple_tests.yml deleted file mode 100644 index 2003441b06..0000000000 --- a/.github/workflows/decouple_tests.yml +++ /dev/null @@ -1,183 +0,0 @@ -name: Decoupled Offline Tests - -on: - pull_request: - branches: - - rocm-main - paths: - - 'maxtext/**' - - '.github/workflows/decouple_tests.yml' - - 'pyproject.toml' - - 'README.md' - push: - branches: - - rocm-main - paths: - - 'maxtext/**' - - '.github/workflows/decouple_tests.yml' - workflow_dispatch: - -jobs: - decoupled: - runs-on: ubuntu-latest - env: - DECOUPLE_GCLOUD: TRUE - LOCAL_GCLOUD_PROJECT: ci-decoupled - LOCAL_BASE: datasets - LOCAL_BASE_OUTPUT: datasets/gcloud_decoupled_test_logs - KEEPALIVE: sleep infinity - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Pull ROCm JAX base image - run: docker pull rocm/jax:rocm7.0.2-jax0.6.0-py3.12-ubu24 - - - name: Start container - run: | - set -euo pipefail - mkdir -p "${LOCAL_BASE_OUTPUT}" - DOCKER_IMAGE="rocm/jax:rocm7.0.2-jax0.6.0-py3.12-ubu24" - CONTAINER_NAME="decouple_maxtext_tests" - DEVICE_FLAGS="--device=/dev/kfd --device=/dev/dri --group-add video" - RUNTIME_FLAGS="--network=host --ipc=host --shm-size 16G --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" - WORKDIR="/maxtext" - MOUNT="-v ${{ github.workspace }}:${WORKDIR}" - ENV_VARS="-e DECOUPLE_GCLOUD=${DECOUPLE_GCLOUD} -e LOCAL_GCLOUD_PROJECT=${LOCAL_GCLOUD_PROJECT} -e LOCAL_BASE_OUTPUT=${LOCAL_BASE_OUTPUT}" - docker rm -f ${CONTAINER_NAME} >/dev/null 2>&1 || true - docker run -d -it --name ${CONTAINER_NAME} -w ${WORKDIR} ${ENV_VARS} ${RUNTIME_FLAGS} ${DEVICE_FLAGS} ${MOUNT} ${DOCKER_IMAGE} bash -c "${KEEPALIVE}" - echo "Container ${CONTAINER_NAME} started with centralized env vars." >> "${LOCAL_BASE_OUTPUT}/ci.log" - - - name: Install Requirements and Maxtext Package - run: | - set -euo pipefail - CONTAINER_NAME="decouple_maxtext_tests" - docker exec ${CONTAINER_NAME} bash -euo pipefail -c ' - mkdir -p "${LOCAL_BASE_OUTPUT}" - apt-get update - apt-get install -y git vim python3.12-venv python3-pip python3-dev build-essential - python3 -m venv .venv || { echo "venv creation failed"; exit 10; } - . .venv/bin/activate - python -m ensurepip --upgrade || true - python -m pip install --upgrade pip - python -m pip install wheel setuptools cmake - python -m pip install -r requirements_decoupled_rocm_jax_0_6_0.txt --no-cache-dir || { echo "Requirements install failed"; exit 13; } - python -m pip install . --no-cache-dir || { echo "Package install failed"; exit 11; } - echo "Maxtext package installed." >> "${LOCAL_BASE_OUTPUT}/ci.log" - python -m pip install pytest pytest-html pytest-csv minio array-record || { echo "Test deps failed"; exit 12; } - python -m pip list >> "${LOCAL_BASE_OUTPUT}/pip_freeze.txt" - echo "Requirements installed." >> "${LOCAL_BASE_OUTPUT}/ci.log" - ' - - - name: Build and Install TE - run: | - set -euo pipefail - CONTAINER_NAME="decouple_maxtext_tests" - docker exec ${CONTAINER_NAME} bash -euo pipefail -c ' - . .venv/bin/activate - if [ ! -e "TransformerEngine" ]; then - git clone https://github.com/ROCm/TransformerEngine.git - fi - cd TransformerEngine - git reset --hard ec21e1f35f5a36c1def9c89c4865eba3947f0665 - git submodule update --init --recursive - export USE_ROCM=1 - export HIP_PATH=/opt/rocm - export NVTE_FRAMEWORK=jax - export CMAKE_BUILD_PARALLEL_LEVEL=64 - export PYTORCH_ROCM_ARCH=gfx942 - export NVTE_ROCM_ARCH=gfx942 - export NVTE_USE_ROCM=1 - export NVTE_FUSED_ATTN_AOTRITON=0 - export NVTE_BUILD_MAX_JOBS=180 - export CU_NUM=304 - if find dist -maxdepth 1 -name "*.whl" -print -quit | grep -q .; then - echo "TE wheels exist, skip building" - else - python setup.py bdist_wheel || { echo "TransformerEngine build failed"; exit 20; } - fi - python -m pip install dist/transformer_engine-*.whl --no-cache-dir --force-reinstall || { echo "TE wheel install failed"; exit 21; } - cd .. - echo "TE built and installed." >> "${LOCAL_BASE_OUTPUT}/ci.log" - ' - - - name: Generate datasets - run: | - set -euo pipefail - CONTAINER_NAME="decouple_maxtext_tests" - docker exec ${CONTAINER_NAME} bash -euo pipefail -c ' - . .venv/bin/activate - echo "Generating minimal ArrayRecord shards" >> "${LOCAL_BASE_OUTPUT}/ci.log" - python ${LOCAL_BASE}/get_minimal_c4_en_dataset.py --force || { echo "ArrayRecord generation failed"; exit 3; } - echo "Generating minimal parquet shards" >> "${LOCAL_BASE_OUTPUT}/ci.log" - python ${LOCAL_BASE}/get_minimal_hf_c4_parquet.py --force --train-rows 200 --val-rows 40 - echo "Minimal dataset setup complete." >> "${LOCAL_BASE_OUTPUT}/ci.log" - echo "Generating minimal tfrecord shards." >> "${LOCAL_BASE_OUTPUT}/ci.log" - python ${LOCAL_BASE}/convert_arrayrecord_to_tfrecord.py --version-dir ${LOCAL_BASE}/c4_en_dataset_minimal/c4/en/3.0.1 --builder-name __local_c4_builder --force - echo "Minimal tfrecord shards generation complete." >> "${LOCAL_BASE_OUTPUT}/ci.log" - cp -r ${LOCAL_BASE}/c4_en_dataset_minimal/c4/en/3.0.1 ${LOCAL_BASE}/c4_en_dataset_minimal/c4/en/3.1.0 - echo "Generating TFDS metadata." >> "${LOCAL_BASE_OUTPUT}/ci.log" - python ${LOCAL_BASE}/generate_tfds_metadata.py --root ${LOCAL_BASE}/c4_en_dataset_minimal --version 3.0.1 --source-version 3.0.1 --force - python ${LOCAL_BASE}/generate_tfds_metadata.py --root ${LOCAL_BASE}/c4_en_dataset_minimal --version 3.1.0 --source-version 3.0.1 --force - echo "Minimal dataset generation complete!" >> "${LOCAL_BASE_OUTPUT}/ci.log" - ' - - - name: Collect & run decoupled tests - run: | - set -euo pipefail - CONTAINER_NAME="decouple_maxtext_tests" - docker exec ${CONTAINER_NAME} bash -euo pipefail -c ' - . .venv/bin/activate - echo "Collecting tests" >> "${LOCAL_BASE_OUTPUT}/ci.log" - pytest tests -m decoupled --collect-only -q | tee "${LOCAL_BASE_OUTPUT}/collect.txt" - TEST_COUNT="$(grep -v "^$" "${LOCAL_BASE_OUTPUT}/collect.txt" | wc -l)" - echo Collected:$TEST_COUNT >> "${LOCAL_BASE_OUTPUT}/ci.log" - if [ "$TEST_COUNT" -eq 0 ]; then echo "ERROR: No decoupled tests found" >> "${LOCAL_BASE_OUTPUT}/ci.log"; exit 2; fi - pytest -v tests -m decoupled -q --csv=decoupled-tests-report.csv --html=decoupled-tests-report.html --self-contained-html - mv decoupled-tests-report.csv decoupled-tests-report.html "${LOCAL_BASE_OUTPUT}/" - EXEC_COUNT="$(grep -E '^[^,]+,' "${LOCAL_BASE_OUTPUT}/decoupled-tests-report.csv" | wc -l)" - echo Executed:$EXEC_COUNT >> "${LOCAL_BASE_OUTPUT}/ci.log" - echo "### Decoupled Test Run Summary" >> $GITHUB_STEP_SUMMARY - echo Collected: $TEST_COUNT >> $GITHUB_STEP_SUMMARY - echo Executed: $EXEC_COUNT >> $GITHUB_STEP_SUMMARY - echo "Reports: CSV & HTML attached as artifacts." >> $GITHUB_STEP_SUMMARY - ' - - echo "Container ${CONTAINER_NAME} still alive." >> datasets/gcloud_decoupled_test_logs/ci.log - echo "Attach to container with: docker exec -it ${CONTAINER_NAME} bash" >> $GITHUB_STEP_SUMMARY - echo "Remove when done: docker rm -f ${CONTAINER_NAME}" >> $GITHUB_STEP_SUMMARY - - - name: Snapshot container image Artifact - if: always() - run: | - set -euo pipefail - CONTAINER_NAME="decouple_maxtext_tests" - IMAGE_TAG="maxtext-decoupled-tests:latest" - IMAGE_FILE="maxtext-decoupled-tests-image.tar.gz" - if ! docker ps -a --format '{{.Names}}' | grep -q "^${CONTAINER_NAME}$"; then - echo "Container ${CONTAINER_NAME} not found; skipping snapshot." >> "${LOCAL_BASE_OUTPUT}/ci.log" - exit 0 - fi - docker commit ${CONTAINER_NAME} ${IMAGE_TAG} - docker save ${IMAGE_TAG} | gzip -c > ${IMAGE_FILE} - ls -lh ${IMAGE_FILE} >> "${LOCAL_BASE_OUTPUT}/ci.log" || true - echo "Image snapshot saved as artifact candidate: ${IMAGE_FILE}" >> "${LOCAL_BASE_OUTPUT}/ci.log" - - - name: Upload snapshot image - if: always() - uses: actions/upload-artifact@v3 - with: - name: decoupled-docker-image - path: maxtext-decoupled-tests-image.tar.gz - if-no-files-found: warn - - - name: Upload decoupled logs - if: always() - uses: actions/upload-artifact@v3 - with: - name: decoupled-logs - path: datasets/gcloud_decoupled_test_logs - if-no-files-found: warn - - - diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml index fd4a59fbba..ca1d774dd7 100644 --- a/.github/workflows/run_tests_against_package.yml +++ b/.github/workflows/run_tests_against_package.yml @@ -40,7 +40,7 @@ on: default: '' is_scheduled_run: required: true - type: string + type: boolean xla_python_client_mem_fraction: required: true type: string @@ -65,6 +65,17 @@ on: description: 'Git SHA to checkout if MaxText is not pre-installed' required: false type: string + decoupled_mode: + required: false + type: boolean + default: false + requirements_file: + required: false + type: string + default: '' + extra_pip_deps_file: + required: false + type: string # Flag to skip source checkout and wheel installation maxtext_installed: description: 'If false, maxtext_sha must be provided for checkout' @@ -80,7 +91,7 @@ jobs: run: runs-on: ${{ inputs.cloud_runner != '' && inputs.cloud_runner || fromJson(format('["self-hosted", "{0}", "{1}"]', inputs.device_type, inputs.device_name)) }} container: - image: gcr.io/${{ vars.PROJECT_NAME }}/${{ inputs.base_image }} + image: ${{ inputs.device_type == 'rocm' && 'ghcr.io/rocm/jax-base-ubu24.rocm720:latest' || format('gcr.io/{0}/{1}', vars.PROJECT_NAME, inputs.base_image) }} env: XLA_PYTHON_CLIENT_MEM_FRACTION: ${{ inputs.xla_python_client_mem_fraction }} TF_FORCE_GPU_ALLOW_GROWTH: ${{ inputs.tf_force_gpu_allow_growth }} @@ -91,6 +102,11 @@ jobs: || (inputs.device_type == 'cpu' && 'tpu' || inputs.device_type) }} ALLOW_MULTIPLE_LIBTPU_LOAD: ${{ inputs.device_type == 'cpu' && 'true' || '' }} # bypass /tmp/libtpu_lockfile check for cpu tests, which don't actually use accelerators (to allow concurrency) + DECOUPLE_GCLOUD: ${{ inputs.decoupled_mode && 'TRUE' || '' }} + LOCAL_GCLOUD_PROJECT: ${{ inputs.decoupled_mode && 'ci-decoupled' || '' }} + # ROCm: prefer CK fused-attention backend over AOTriton for stability. + NVTE_FUSED_ATTN_CK: ${{ inputs.device_type == 'rocm' && '1' || '' }} + NVTE_FUSED_ATTN_AOTRITON: ${{ inputs.device_type == 'rocm' && '0' || '' }} options: ${{ inputs.container_resource_option }} steps: - name: Checkout MaxText @@ -107,20 +123,75 @@ jobs: if: ${{ !inputs.maxtext_installed }} shell: bash run: | + if [ "${{ inputs.device_type }}" = "rocm" ]; then + python3 -m pip install -U uv + fi python3 -m uv venv --seed source .venv/bin/activate maxtext_wheel=$(ls maxtext-*-py3-none-any.whl 2>/dev/null) echo "Installing ${maxtext_wheel} for ${MAXTEXT_PACKAGE_EXTRA}..." - uv pip install ${maxtext_wheel}[${MAXTEXT_PACKAGE_EXTRA}] --resolution=lowest - if [ "${MAXTEXT_PACKAGE_EXTRA}" == "tpu-post-train" ]; then - install_tpu_post_train_extra_deps + if [ "${{ inputs.device_type }}" = "rocm" ]; then + if [ -n "${{ inputs.requirements_file }}" ]; then + echo "Installing requirements from ${{ inputs.requirements_file }}" + uv pip install -r "${{ inputs.requirements_file }}" + fi + uv pip install ${maxtext_wheel} --no-deps + uv pip install -r src/dependencies/github_deps/pre_train_deps.txt else - install_tpu_pre_train_extra_deps + if [ -n "${{ inputs.requirements_file }}" ]; then + echo "Installing requirements from ${{ inputs.requirements_file }}" + uv pip install -r "${{ inputs.requirements_file }}" + uv pip install ${maxtext_wheel} --no-deps + else + uv pip install ${maxtext_wheel}[${MAXTEXT_PACKAGE_EXTRA}] --resolution=lowest + fi + if [ "${MAXTEXT_PACKAGE_EXTRA}" == "tpu-post-train" ]; then + install_tpu_post_train_extra_deps + else + install_tpu_pre_train_extra_deps + fi fi python3 --version python3 -m pip freeze + - name: Install extra pip deps + if: inputs.extra_pip_deps_file != '' + shell: bash + run: | + source .venv/bin/activate + uv pip install -r ${{ inputs.extra_pip_deps_file }} + + - name: Select ROCm arch (mi300/mi355) + if: ${{ inputs.device_type == 'rocm' }} + shell: bash + run: | + set -euo pipefail + echo "=== ROCm arch selection (from install_te_rocm_wheel.py) ===" + if [ -x .venv/bin/python3 ]; then + te_arch="$( + .venv/bin/python3 .github/workflows/utils/install_te_rocm_wheel.py --print-arch + )" + echo "[te detect_arch] ${te_arch}" + echo "TE_WHEEL_ARCH=${te_arch}" >> "${GITHUB_ENV}" + else + echo "No .venv python found; skipping arch selection." + fi + echo "=========================================================" + + - name: Install Transformer Engine wheel (ROCm) + if: ${{ inputs.device_type == 'rocm' }} + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + source .venv/bin/activate + set -euo pipefail + + python3 .github/workflows/utils/install_te_rocm_wheel.py + + uv pip install --no-deps --force-reinstall transformer_engine-*.whl + - name: Copy test assets files - if: ${{ !inputs.maxtext_installed }} + if: ${{ !inputs.maxtext_installed && !inputs.decoupled_mode }} run : gcloud storage cp gs://maxtext-test-assets/* tests/assets - name: Run Tests shell: bash @@ -153,8 +224,8 @@ jobs: export MAXTEXT_ASSETS_ROOT=$(pwd)/src/maxtext/assets export MAXTEXT_TEST_ASSETS_ROOT=$(pwd)/tests/assets export MAXTEXT_PKG_DIR=$(pwd)/src/maxtext - # omit this libtpu init args for gpu tests - if [ "${INPUTS_DEVICE_TYPE}" != "cuda12" ]; then + # omit this libtpu init args for gpu tests (cuda + rocm) + if [ "${INPUTS_DEVICE_TYPE}" != "cuda12" ] && [ "${INPUTS_DEVICE_TYPE}" != "rocm" ]; then export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536' else # For cuda12, explicitly point to the pip-installed CUDA libraries @@ -165,6 +236,9 @@ jobs: echo "Warning: Could not find pinned nvidia libraries in .venv." fi fi + if [ "${INPUTS_DEVICE_TYPE}" = "rocm" ]; then + ulimit -c 0 + fi if [ "${INPUTS_TOTAL_WORKERS}" -gt 1 ]; then $PYTHON_EXE -m pip install --quiet pytest-split pytest-xdist SPLIT_ARGS="--splits ${INPUTS_TOTAL_WORKERS} --group ${INPUTS_WORKER_GROUP} -n auto" diff --git a/.github/workflows/run_tests_coordinator.yml b/.github/workflows/run_tests_coordinator.yml index 935b7fbda0..c7bc49ee5c 100644 --- a/.github/workflows/run_tests_coordinator.yml +++ b/.github/workflows/run_tests_coordinator.yml @@ -27,7 +27,9 @@ on: tpu-post-training-unit, tpu-post-training-integration, gpu-unit, gpu-integration, cpu-unit, - cpu-post-training-unit + cpu-post-training-unit, + rocm-unit, + rocm-decoupled ) required: true type: string @@ -80,7 +82,9 @@ jobs: "gpu-unit": "cuda12", "gpu-integration": "cuda12", "cpu-unit": "cpu", - "cpu-post-training-unit": "cpu" + "cpu-post-training-unit": "cpu", + "rocm-unit": "rocm", + "rocm-decoupled": "rocm" }')[inputs.flavor] }} device_name: >- @@ -92,7 +96,9 @@ jobs: "gpu-unit": "a100-40gb-4", "gpu-integration": "a100-40gb-4", "cpu-unit": "X64", - "cpu-post-training-unit": "X64" + "cpu-post-training-unit": "X64", + "rocm-unit": "mi355", + "rocm-decoupled": "mi355" }')[inputs.flavor] }} cloud_runner: >- @@ -104,7 +110,9 @@ jobs: "gpu-unit": "linux-x86-a2-48-a100-4gpu", "gpu-integration": "linux-x86-a2-48-a100-4gpu", "cpu-unit": "linux-x86-n2-32", - "cpu-post-training-unit": "linux-x86-n2-32" + "cpu-post-training-unit": "linux-x86-n2-32", + "rocm-unit": "linux-x86-64-4gpu-amd", + "rocm-decoupled": "linux-x86-64-4gpu-amd" }')[inputs.flavor] }} # Pytest Marker Mapping pytest_marker: >- @@ -116,7 +124,9 @@ jobs: "gpu-unit": "not cpu_only and not tpu_only and not integration_test and not post_training", "gpu-integration": "not cpu_only and not tpu_only and integration_test and not post_training", "cpu-unit": "cpu_only and not post_training", - "cpu-post-training-unit": "cpu_only and post_training" + "cpu-post-training-unit": "cpu_only and post_training", + "rocm-unit": "not cpu_only and not tpu_only and not integration_test and not post_training and not decoupled", + "rocm-decoupled": "decoupled" }')[inputs.flavor] }} pytest_addopts: >- @@ -128,7 +138,9 @@ jobs: "gpu-unit": "", "gpu-integration": "", "cpu-unit": "", - "cpu-post-training-unit": "tests/post_training/unit tests/unit" + "cpu-post-training-unit": "tests/post_training/unit tests/unit", + "rocm-unit": "", + "rocm-decoupled": "" }')[inputs.flavor] }} pytest_extra_args: >- @@ -140,18 +152,44 @@ jobs: "gpu-unit": "--ignore=tests/post_training", "gpu-integration": "--ignore=tests/post_training", "cpu-unit": "--ignore=tests/post_training", - "cpu-post-training-unit": "" + "cpu-post-training-unit": "", + "rocm-unit": "--ignore=tests/post_training", + "rocm-decoupled": "" }')[inputs.flavor] }} ${{ inputs.additional_pytest_args }} # Resource Scaling - xla_python_client_mem_fraction: "${{ contains(inputs.flavor, 'gpu') && '0.65' || '0.75' }}" - tf_force_gpu_allow_growth: "${{ contains(inputs.flavor, 'gpu') && 'true' || 'false' }}" + xla_python_client_mem_fraction: >- + ${{ fromJSON('{ + "gpu-unit": "0.65", + "gpu-integration": "0.65", + "rocm-unit": "0.9", + "rocm-decoupled": "0.9" + }')[inputs.flavor] || '0.75' }} + + tf_force_gpu_allow_growth: >- + ${{ fromJSON('{ + "gpu-unit": "true", + "gpu-integration": "true", + "rocm-unit": "true", + "rocm-decoupled": "true" + }')[inputs.flavor] || 'false' }} container_resource_option: >- - ${{ contains(inputs.flavor, 'gpu') - && '--shm-size 2g --runtime=nvidia --gpus all --privileged' - || '--privileged' }} + ${{ fromJSON('{ + "gpu-unit": "--shm-size 2g --runtime=nvidia --gpus all --privileged", + "gpu-integration": "--shm-size 2g --runtime=nvidia --gpus all --privileged", + "rocm-unit": "--device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 64g --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --privileged", + "rocm-decoupled": "--device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 64g --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --privileged" + }')[inputs.flavor] || '--privileged' }} + + # ROCm-specific parameters + decoupled_mode: ${{ contains(inputs.flavor, 'rocm') }} + requirements_file: >- + ${{ fromJSON('{ + "rocm-unit": "src/dependencies/requirements/requirements_rocm_jax_0.8.2.txt", + "rocm-decoupled": "src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt" + }')[inputs.flavor] || '' }} # Metadata base_image: ${{ inputs.base_image }} diff --git a/.github/workflows/utils/install_te_rocm_wheel.py b/.github/workflows/utils/install_te_rocm_wheel.py new file mode 100644 index 0000000000..a2a16b225c --- /dev/null +++ b/.github/workflows/utils/install_te_rocm_wheel.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +""" +ROCm CI helper: +- Detect MI300 vs MI355 +- Prefer wheel from this repo's 'te-rocm-wheels' release assets +- Fallback to pinned ROCm/maxtext release assets (arch-specific) + +This script only downloads the wheel into the current working directory. +The caller should then install it (e.g. `uv pip install transformer_engine-*.whl`). +""" + +from __future__ import annotations + +import argparse +import json +import os +import re +import subprocess +import urllib.error +import urllib.request + + +def _run(cmd: str) -> str: + try: + return subprocess.check_output(["bash", "-lc", cmd], text=True, stderr=subprocess.STDOUT) + except (subprocess.CalledProcessError, FileNotFoundError, OSError): + return "" + + +def detect_arch() -> str: + """ + Return wheel arch selector: 'mi300' or 'mi355'. + + Notes: + - MI300 family commonly reports gfx942/gfx941. + - MI350/MI355 family commonly reports gfx950 and product strings like MI350X. + """ + override = os.environ.get("TE_WHEEL_ARCH", "").strip().lower() + if override in {"mi300", "mi355"}: + return override + + rocm_smi = _run("command -v rocm-smi >/dev/null 2>&1 && rocm-smi --showproductname || true") or _run( + "[ -x /opt/rocm/bin/rocm-smi ] && /opt/rocm/bin/rocm-smi --showproductname || true" + ) + rocminfo = _run("command -v rocminfo >/dev/null 2>&1 && rocminfo || true") or _run( + "[ -x /opt/rocm/bin/rocminfo ] && /opt/rocm/bin/rocminfo || true" + ) + + blob = f"{rocm_smi}\n{rocminfo}".lower() + gfxs = sorted(set(re.findall(r"gfx[0-9a-f]+", blob))) + + # Prefer explicit gfx IDs when available. + if "gfx950" in gfxs: + return "mi355" + if "gfx942" in gfxs or "gfx941" in gfxs: + return "mi300" + + # Fall back to product string checks. + if "mi355" in blob or "mi350" in blob: + return "mi355" + if "mi300" in blob: + return "mi300" + + # Safe default. + return "mi300" + + +def _headers() -> dict[str, str]: + token = os.environ.get("GITHUB_TOKEN", "") + headers: dict[str, str] = { + "Accept": "application/vnd.github+json", + "User-Agent": "maxtext-ci", + } + if token: + headers["Authorization"] = f"Bearer {token}" + return headers + + +def download(url: str, out_name: str) -> None: + req = urllib.request.Request(url, headers=_headers()) + with urllib.request.urlopen(req) as r, open(out_name, "wb") as f: + f.write(r.read()) + print(f"[te wheel] downloaded {out_name}", flush=True) + + +def try_download_from_te_rocm_wheels(repo: str, arch: str) -> bool: + """Download from this repo's 'te-rocm-wheels' release tag if present.""" + api = f"https://api.github.com/repos/{repo}/releases/tags/te-rocm-wheels" + req = urllib.request.Request(api, headers=_headers()) + with urllib.request.urlopen(req) as r: + rel = json.loads(r.read().decode("utf-8")) + + assets = rel.get("assets", []) + name_re = re.compile(rf"^transformer_engine-.*-{arch}-cp312-cp312-linux_x86_64\.whl$") + hit = next((a for a in assets if name_re.match(a.get("name", ""))), None) + if not hit: + return False + + download(hit["browser_download_url"], hit["name"]) + return True + + +def main(argv: list[str] | None = None) -> int: + """Entry point.""" + parser = argparse.ArgumentParser(add_help=True) + parser.add_argument( + "--print-arch", + action="store_true", + help="Print resolved wheel arch (mi300/mi355) and exit.", + ) + args = parser.parse_args(argv) + + if args.print_arch: + print(detect_arch(), flush=True) + return 0 + + repo = os.environ.get("GITHUB_REPOSITORY", "") + if not repo: + print("[te wheel] GITHUB_REPOSITORY not set; skipping.", flush=True) + return 0 + + arch = detect_arch() + print(f"[te wheel] arch={arch}", flush=True) + + # 1) Prefer: this repo's te-rocm-wheels assets. + try: + if try_download_from_te_rocm_wheels(repo, arch): + return 0 + print(f"[te wheel] no te-rocm-wheels asset for arch={arch}", flush=True) + except urllib.error.HTTPError as e: + print(f"[te wheel] te-rocm-wheels not available ({e.code})", flush=True) + except (urllib.error.URLError, json.JSONDecodeError, KeyError, ValueError) as e: + print(f"[te wheel] te-rocm-wheels lookup failed ({e})", flush=True) + + # 2) Fallback: pinned ROCm/maxtext release assets. + arch_tag = f"1.{arch}" + pinned_name = f"transformer_engine-2.8.0.dev0+2776c337-{arch_tag}-cp312-cp312-linux_x86_64.whl" + pinned = "https://github.com/ROCm/maxtext/releases/download/rocm-maxtext-v0.1.1/" f"{pinned_name}" + download(pinned, pinned_name) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/dependencies/requirements/requirements_decoupled_jax_0_8_2.txt b/src/dependencies/requirements/requirements_decoupled_jax_0_8_2.txt index 8f85aa10e0..9479797cb6 100644 --- a/src/dependencies/requirements/requirements_decoupled_jax_0_8_2.txt +++ b/src/dependencies/requirements/requirements_decoupled_jax_0_8_2.txt @@ -36,3 +36,5 @@ tokamax @ git+https://github.com/openxla/tokamax.git@69b328121c3ee120d8e54cf26d1 transformers>=4.57.3,<5 qwix mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip +chex>=0.1.91 +drjax>=0.1.4 diff --git a/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt b/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt index 7f78ab5749..f821160f03 100644 --- a/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt +++ b/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt @@ -43,3 +43,5 @@ tokamax @ git+https://github.com/openxla/tokamax.git@69b328121c3ee120d8e54cf26d1 transformers>=4.57.3,<5 qwix mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip +chex>=0.1.91 +drjax>=0.1.4 diff --git a/src/dependencies/requirements/requirements_rocm_jax_0.8.2.txt b/src/dependencies/requirements/requirements_rocm_jax_0.8.2.txt index aa2817c218..bffcc3a203 100644 --- a/src/dependencies/requirements/requirements_rocm_jax_0.8.2.txt +++ b/src/dependencies/requirements/requirements_rocm_jax_0.8.2.txt @@ -56,3 +56,5 @@ google-cloud-monitoring ml-goodput-measurement tensorboard-plugin-profile google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip +chex>=0.1.91 +drjax>=0.1.4 From 78d464b5dd2ebc6b238679d2eaf0d3ba034f0389 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Fri, 20 Feb 2026 16:12:22 +0000 Subject: [PATCH 09/21] update te wheel consumption --- .github/workflows/run_tests_against_package.yml | 3 ++- .github/workflows/utils/install_te_rocm_wheel.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml index ca1d774dd7..c56e1801b0 100644 --- a/.github/workflows/run_tests_against_package.yml +++ b/.github/workflows/run_tests_against_package.yml @@ -89,7 +89,8 @@ permissions: contents: read jobs: run: - runs-on: ${{ inputs.cloud_runner != '' && inputs.cloud_runner || fromJson(format('["self-hosted", "{0}", "{1}"]', inputs.device_type, inputs.device_name)) }} + runs-on: ${{ inputs.cloud_runner != '' && inputs.cloud_runner || (inputs.device_type == 'rocm' && fromJson('["self-hosted","linux-x86-64-4gpu-amd"]')) || fromJson(format('["self-hosted", "{0}", "{1}"]', inputs.device_type, inputs.device_name)) }} + timeout-minutes: ${{ inputs.device_type == 'rocm' && 90 || 360 }} container: image: ${{ inputs.device_type == 'rocm' && 'ghcr.io/rocm/jax-base-ubu24.rocm720:latest' || format('gcr.io/{0}/{1}', vars.PROJECT_NAME, inputs.base_image) }} env: diff --git a/.github/workflows/utils/install_te_rocm_wheel.py b/.github/workflows/utils/install_te_rocm_wheel.py index a2a16b225c..4d119f451a 100644 --- a/.github/workflows/utils/install_te_rocm_wheel.py +++ b/.github/workflows/utils/install_te_rocm_wheel.py @@ -91,7 +91,8 @@ def try_download_from_te_rocm_wheels(repo: str, arch: str) -> bool: rel = json.loads(r.read().decode("utf-8")) assets = rel.get("assets", []) - name_re = re.compile(rf"^transformer_engine-.*-{arch}-cp312-cp312-linux_x86_64\.whl$") + # Wheels published by this repo use the selector format: `-1.-...` (e.g. `-1.mi355-...`). + name_re = re.compile(rf"^transformer_engine-.*-1\.{arch}-cp312-cp312-linux_x86_64\.whl$") hit = next((a for a in assets if name_re.match(a.get("name", ""))), None) if not hit: return False From d6ecca74abdaf6d80e4f99df72834b094ed4d61e Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Mon, 16 Feb 2026 20:42:46 +0000 Subject: [PATCH 10/21] refactoring TE wheel release workflow --- ...d_rocm_transformer_engine_wheel_weekly.yml | 375 ++++++------------ .../workflows/utils/install_te_rocm_wheel.py | 1 - .github/workflows/utils/te_wheels_release.py | 204 ++++++++++ 3 files changed, 331 insertions(+), 249 deletions(-) create mode 100644 .github/workflows/utils/te_wheels_release.py diff --git a/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml index d0bd1c6c79..515a3d63ce 100644 --- a/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml +++ b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml @@ -11,8 +11,8 @@ permissions: jobs: build_upload_prune: - # Same runner label used by ROCm test workflows. - runs-on: ["self-hosted", "linux-x86-64-4gpu-amd"] + # AMD GPU runner (GitHub-hosted large runner label). + runs-on: linux-x86-64-4gpu-amd container: image: ghcr.io/rocm/jax-base-ubu24.rocm720:latest options: >- @@ -28,10 +28,8 @@ jobs: - name: Checkout uses: actions/checkout@v5 - - name: Build & publish TE wheels (native arch then MI300) + - name: Setup build environment (deps + venv) shell: bash - env: - GITHUB_TOKEN: ${{ github.token }} run: | set -euo pipefail apt-get update @@ -40,33 +38,72 @@ jobs: python3 -m uv venv --seed source .venv/bin/activate uv pip install -U pip setuptools wheel pybind11 cmake - # Install ROCm JAX/JAXlib wheels into the venv (so TE builds against the same stack as CI). + + - name: Install ROCm JAX/JAXlib wheels (build against CI stack) + shell: bash + run: | + set -euo pipefail + source .venv/bin/activate uv pip install -r dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt - # Detect runner ROCm architecture (avoid generic tokens like gfx9). - PRIMARY_ARCH="$( - (command -v rocminfo >/dev/null 2>&1 && rocminfo || /opt/rocm/bin/rocminfo) \ - | grep -oE 'gfx[0-9a-f]+' \ - | sort -u \ - | awk '{ print length($0), $0 }' \ - | sort -nr \ - | head -n1 \ - | awk '{ print $2 }' - )" - echo "Detected runner ROCm arch: ${PRIMARY_ARCH}" + - name: Detect ROCm version and Python tag + shell: bash + run: | + set -euo pipefail + source .venv/bin/activate # Detect ROCm version number from JAX backend string when available (e.g. 'rocm 70200'). ROCM_NUM="$(python3 -c 'import re, jax; s=str(jax.devices()[0].client.platform_version); m=re.search(r"rocm\\s+([0-9]+)", s); print(m.group(1) if m else "unknown")')" echo "Detected ROCm version: ${ROCM_NUM}" + echo "ROCM_NUM=${ROCM_NUM}" >> "${GITHUB_ENV}" - # Build TE from ROCm/TransformerEngine dev branch. + PYTAG="cp$(python3 -c 'import sys; print(f"{sys.version_info.major}{sys.version_info.minor}")')" + if [ "${PYTAG}" != "cp312" ]; then + echo "Expected Python 3.12 (cp312) for ROCm CI wheels, got ${PYTAG}." + exit 1 + fi + echo "PYTAG=${PYTAG}" >> "${GITHUB_ENV}" + echo "REL_SCRIPT=.github/workflows/utils/te_wheels_release.py" >> "${GITHUB_ENV}" + + - name: Clone ROCm/TransformerEngine (dev) + shell: bash + run: | + set -euo pipefail rm -rf TransformerEngine git clone --recursive --branch dev https://github.com/ROCm/TransformerEngine.git cd TransformerEngine git submodule update --init --recursive - TE_SHA="$(git rev-parse --short=12 HEAD)" - PYTAG="cp$(python3 -c 'import sys; print(f"{sys.version_info.major}{sys.version_info.minor}")')" + echo "TE_SHA=${TE_SHA}" >> "${GITHUB_ENV}" + + - name: Select TE wheel arch for runner (mi300/mi355) + shell: bash + run: | + set -euo pipefail + source .venv/bin/activate + + TE_WHEEL_ARCH="$(python3 .github/workflows/utils/install_te_rocm_wheel.py --print-arch)" + echo "Resolved TE wheel arch for runner: ${TE_WHEEL_ARCH}" + echo "TE_WHEEL_ARCH=${TE_WHEEL_ARCH}" >> "${GITHUB_ENV}" + + # Build ONLY for the ROCm arch present on this CI runner (mi300 or mi355). + if [ "${TE_WHEEL_ARCH}" = "mi355" ]; then + SELECTOR="mi355" + GFX_ARCH="gfx950" + else + SELECTOR="mi300" + GFX_ARCH="gfx942;gfx941" + fi + echo "SELECTOR=${SELECTOR}" >> "${GITHUB_ENV}" + echo "GFX_ARCH=${GFX_ARCH}" >> "${GITHUB_ENV}" + + - name: Build TE wheel + shell: bash + run: | + set -euo pipefail + source .venv/bin/activate + + chmod +x "${REL_SCRIPT}" || true export USE_ROCM=1 export HIP_PATH=/opt/rocm @@ -75,239 +112,81 @@ jobs: export NVTE_USE_ROCM=1 export NVTE_FUSED_ATTN_AOTRITON=0 export NVTE_BUILD_MAX_JOBS=180 + #export NVTE_AITER_PREBUILT_BASE_URL=https://compute-artifactory.amd.com:5000/artifactory/rocm-generic-local/te-ci/aiter-prebuilts + + echo "=== Building TE wheel for ${SELECTOR} (gfx=${GFX_ARCH}) ===" + pushd TransformerEngine >/dev/null + rm -rf build dist + export PYTHONPATH="$(pwd)/3rdparty/hipify_torch${PYTHONPATH:+:${PYTHONPATH}}" + export PYTORCH_ROCM_ARCH="${GFX_ARCH}" + export NVTE_ROCM_ARCH="${GFX_ARCH}" + python3 setup.py bdist_wheel + wheel_path="$( + python3 -c 'import glob; m=sorted(glob.glob("dist/transformer_engine-*.whl")); print(m[0] if m else "")' + )" + if [ -z "${wheel_path}" ]; then + echo "No wheel produced in dist/ (selector=${SELECTOR})." + exit 1 + fi + wheel_base="$(basename "${wheel_path}")" + if [[ "${wheel_base}" == *"-1.${SELECTOR}-${PYTAG}-${PYTAG}-linux_x86_64.whl" ]]; then + asset_name="${wheel_base}" + else + asset_name="${wheel_base/-${PYTAG}-${PYTAG}-linux_x86_64.whl/-1.${SELECTOR}-${PYTAG}-${PYTAG}-linux_x86_64.whl}" + if [ "${asset_name}" = "${wheel_base}" ]; then + echo "Failed to rename wheel for selector=${SELECTOR}: ${wheel_base}" + exit 1 + fi + fi + cp -f "${wheel_path}" "../${asset_name}" + popd >/dev/null + + ls -lh "${asset_name}" + echo "TE_WHEEL_FILE=${asset_name}" >> "${GITHUB_ENV}" + + - name: Upload wheel to rolling release tag (te-rocm-wheels) + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + set -euo pipefail + source .venv/bin/activate + python3 "${REL_SCRIPT}" upload --no-prerelease --tag "te-rocm-wheels" --title "ROCm TransformerEngine wheels (latest)" --body "Rolling release for latest weekly-built ROCm TransformerEngine wheels used by CI." --file "${TE_WHEEL_FILE}" - # Return to workspace root; build function will enter TE directory. - cd .. - - build_one() { - local arch="$1" - echo "=== Building TE wheel for ${arch} ===" - pushd TransformerEngine >/dev/null - rm -rf build dist - export PYTORCH_ROCM_ARCH="${arch}" - export NVTE_ROCM_ARCH="${arch}" - python3 setup.py bdist_wheel - local wheel_path - wheel_path="$(ls -1 dist/transformer_engine-*.whl | head -n1)" - popd >/dev/null - local asset_name - asset_name="transformer_engine-${TE_SHA}-${PYTAG}-rocm${ROCM_NUM}-${arch}.whl" - cp -f "TransformerEngine/${wheel_path}" "${asset_name}" - echo "${asset_name}" - } - - publish_asset_to_release_tag() { - local tag="$1" - local title="$2" - local body="$3" - local wheel_file="$4" - python3 - "$tag" "$title" "$body" "$wheel_file" <<'PY' - import json, os, sys, urllib.error, urllib.parse, urllib.request - - tag, title, body, wheel_file = sys.argv[1:5] - token = os.environ["GITHUB_TOKEN"] - repo = os.environ.get("GITHUB_REPOSITORY") - if not repo: - raise SystemExit("GITHUB_REPOSITORY is not set") - owner, name = repo.split("/", 1) - - api = "https://api.github.com" - headers = { - "Authorization": f"Bearer {token}", - "Accept": "application/vnd.github+json", - "User-Agent": "maxtext-ci", - } - - def request_json(method: str, url: str, body_obj=None): - data = None - if body_obj is not None: - data = json.dumps(body_obj).encode("utf-8") - req = urllib.request.Request(url, data=data, method=method, headers=headers) - with urllib.request.urlopen(req) as r: - return json.loads(r.read().decode("utf-8")) - - def request_raw(method: str, url: str, data: bytes, extra_headers=None): - h = dict(headers) - if extra_headers: - h.update(extra_headers) - req = urllib.request.Request(url, data=data, method=method, headers=h) - with urllib.request.urlopen(req) as r: - return r.read() - - # Ensure release exists for this tag. - release = None - try: - release = request_json("GET", f"{api}/repos/{owner}/{name}/releases/tags/{tag}") - except urllib.error.HTTPError as e: - if e.code != 404: - raise - - if release is None: - release = request_json( - "POST", - f"{api}/repos/{owner}/{name}/releases", - { - "tag_name": tag, - "name": title, - "body": body, - "prerelease": True, - }, - ) - - release_id = release["id"] - upload_url = release["upload_url"].split("{", 1)[0] - - # Delete any existing asset with same name. - assets = request_json("GET", f"{api}/repos/{owner}/{name}/releases/{release_id}")["assets"] - wheel_name = os.path.basename(wheel_file) - for a in assets: - if a.get("name") == wheel_name: - request_json("DELETE", f"{api}/repos/{owner}/{name}/releases/assets/{a['id']}") - - with open(wheel_file, "rb") as f: - wheel_bytes = f.read() - up = f"{upload_url}?{urllib.parse.urlencode({'name': wheel_name})}" - request_raw("POST", up, wheel_bytes, extra_headers={"Content-Type": "application/octet-stream"}) - print(f"Uploaded {wheel_name} to release tag {tag}", flush=True) - PY - } - - prune_old_weekly_releases() { - local prefix="$1" - local keep_days="$2" - python3 - "$prefix" "$keep_days" <<'PY' - import datetime as dt - import json, os, sys, urllib.request - - prefix = sys.argv[1] - keep_days = int(sys.argv[2]) - token = os.environ["GITHUB_TOKEN"] - repo = os.environ.get("GITHUB_REPOSITORY") - if not repo: - raise SystemExit("GITHUB_REPOSITORY is not set") - - api = "https://api.github.com" - headers = { - "Authorization": f"Bearer {token}", - "Accept": "application/vnd.github+json", - "User-Agent": "maxtext-ci", - } - - def get_json(url: str): - req = urllib.request.Request(url, headers=headers) - with urllib.request.urlopen(req) as r: - return json.loads(r.read().decode("utf-8")) - - def delete(url: str): - req = urllib.request.Request(url, method="DELETE", headers=headers) - with urllib.request.urlopen(req): - return - - cutoff = dt.datetime.now(dt.timezone.utc) - dt.timedelta(days=keep_days) - - # Paginate releases. - page = 1 - deleted = 0 - while True: - releases = get_json(f"{api}/repos/{repo}/releases?per_page=100&page={page}") - if not releases: - break - for rel in releases: - tag = rel.get("tag_name", "") - if not tag.startswith(prefix): - continue - created = dt.datetime.fromisoformat(rel["created_at"].replace("Z", "+00:00")) - if created < cutoff: - delete(f"{api}/repos/{repo}/releases/{rel['id']}") - deleted += 1 - print(f"Deleted old weekly TE release {tag} (created_at={rel['created_at']})", flush=True) - page += 1 - print(f"Weekly TE release prune complete. Deleted {deleted} releases older than {keep_days} days.", flush=True) - PY - } - - prune_old_assets_in_release_tag() { - local tag="$1" - local keep_days="$2" - python3 - "$tag" "$keep_days" <<'PY' - import datetime as dt - import json, os, sys, urllib.error, urllib.request - - tag = sys.argv[1] - keep_days = int(sys.argv[2]) - token = os.environ["GITHUB_TOKEN"] - repo = os.environ.get("GITHUB_REPOSITORY") - if not repo: - raise SystemExit("GITHUB_REPOSITORY is not set") - owner, name = repo.split("/", 1) - - api = "https://api.github.com" - headers = { - "Authorization": f"Bearer {token}", - "Accept": "application/vnd.github+json", - "User-Agent": "maxtext-ci", - } - - def request_json(method: str, url: str): - req = urllib.request.Request(url, method=method, headers=headers) - with urllib.request.urlopen(req) as r: - data = r.read() - return json.loads(data.decode("utf-8")) if data else None - - def delete(url: str): - req = urllib.request.Request(url, method="DELETE", headers=headers) - with urllib.request.urlopen(req): - return - - try: - rel = request_json("GET", f"{api}/repos/{owner}/{name}/releases/tags/{tag}") - except urllib.error.HTTPError as e: - if e.code == 404: - print(f"No release for tag {tag}; skipping asset prune.", flush=True) - raise SystemExit(0) - raise - - cutoff = dt.datetime.now(dt.timezone.utc) - dt.timedelta(days=keep_days) - pruned = 0 - for a in rel.get("assets", []): - created = dt.datetime.fromisoformat(a["created_at"].replace("Z", "+00:00")) - if created < cutoff: - delete(f"{api}/repos/{owner}/{name}/releases/assets/{a['id']}") - pruned += 1 - print(f"Pruned old asset: {a['name']} (created_at={a['created_at']})", flush=True) - print(f"Asset prune complete for {tag}. Deleted {pruned} assets older than {keep_days} days.", flush=True) - PY - } - - # Build runner-native wheel and upload immediately so other CI can pick it up. - NATIVE_WHEEL="$(build_one "${PRIMARY_ARCH}")" - ls -lh "${NATIVE_WHEEL}" - publish_asset_to_release_tag \ - "te-rocm-wheels" \ - "ROCm TransformerEngine wheels (latest)" \ - "Rolling release for latest weekly-built ROCm TransformerEngine wheels used by CI." \ - "${NATIVE_WHEEL}" - - # Build MI300 wheel (gfx942) and upload. - MI300_WHEEL="$(build_one "gfx942")" - ls -lh "${MI300_WHEEL}" - publish_asset_to_release_tag \ - "te-rocm-wheels" \ - "ROCm TransformerEngine wheels (latest)" \ - "Rolling release for latest weekly-built ROCm TransformerEngine wheels used by CI." \ - "${MI300_WHEEL}" + - name: Prune old assets from rolling tag + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + set -euo pipefail + source .venv/bin/activate + python3 "${REL_SCRIPT}" prune-assets --tag "te-rocm-wheels" --keep-days 21 - # Retention: prune rolling release assets older than 3 weeks. - prune_old_assets_in_release_tag "te-rocm-wheels" 21 + - name: Publish wheel to dated weekly release tag + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + set -euo pipefail + source .venv/bin/activate - # Also publish both wheels to a new weekly release page. DATE_UTC="$(date -u +%Y-%m-%d)" WEEKLY_TAG="te-rocm-wheels-${DATE_UTC}-${TE_SHA}" WEEKLY_TITLE="ROCm TransformerEngine wheels ${DATE_UTC} (TE ${TE_SHA})" - WEEKLY_BODY="Built from ROCm/TransformerEngine dev @ ${TE_SHA} on ${DATE_UTC}.\nROCm=${ROCM_NUM}, Python=${PYTAG}, arches=${PRIMARY_ARCH} and gfx942." + # Keep this YAML-safe (no unindented heredocs inside `run: |`). + WEEKLY_BODY="$( + printf '%s\n\nROCm: %s\nPython: %s\nArch: %s (gfx=%s)\n' \ + "Built from ROCm/TransformerEngine dev @ ${TE_SHA} on ${DATE_UTC}." \ + "${ROCM_NUM}" "${PYTAG}" "${SELECTOR}" "${GFX_ARCH}" + )" - publish_asset_to_release_tag "${WEEKLY_TAG}" "${WEEKLY_TITLE}" "${WEEKLY_BODY}" "${NATIVE_WHEEL}" - publish_asset_to_release_tag "${WEEKLY_TAG}" "${WEEKLY_TITLE}" "${WEEKLY_BODY}" "${MI300_WHEEL}" + python3 "${REL_SCRIPT}" upload --no-prerelease --tag "${WEEKLY_TAG}" --title "${WEEKLY_TITLE}" --body "${WEEKLY_BODY}" --file "${TE_WHEEL_FILE}" - # Retention: delete weekly releases older than 3 weeks. - prune_old_weekly_releases "te-rocm-wheels-" 21 + - name: Prune old weekly releases + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + set -euo pipefail + source .venv/bin/activate + python3 "${REL_SCRIPT}" prune-releases --prefix "te-rocm-wheels-" --keep-days 21 diff --git a/.github/workflows/utils/install_te_rocm_wheel.py b/.github/workflows/utils/install_te_rocm_wheel.py index 4d119f451a..673d8f731e 100644 --- a/.github/workflows/utils/install_te_rocm_wheel.py +++ b/.github/workflows/utils/install_te_rocm_wheel.py @@ -91,7 +91,6 @@ def try_download_from_te_rocm_wheels(repo: str, arch: str) -> bool: rel = json.loads(r.read().decode("utf-8")) assets = rel.get("assets", []) - # Wheels published by this repo use the selector format: `-1.-...` (e.g. `-1.mi355-...`). name_re = re.compile(rf"^transformer_engine-.*-1\.{arch}-cp312-cp312-linux_x86_64\.whl$") hit = next((a for a in assets if name_re.match(a.get("name", ""))), None) if not hit: diff --git a/.github/workflows/utils/te_wheels_release.py b/.github/workflows/utils/te_wheels_release.py new file mode 100644 index 0000000000..a2ae1368e2 --- /dev/null +++ b/.github/workflows/utils/te_wheels_release.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +""" +Publish and prune ROCm TransformerEngine wheel assets on GitHub Releases. + +Intended for use in GitHub Actions. Requires: +- GITHUB_TOKEN +- GITHUB_REPOSITORY (owner/repo) +""" + +from __future__ import annotations + +import argparse +import datetime as dt +import json +import os +import sys +import urllib.error +import urllib.parse +import urllib.request + +API = "https://api.github.com" + + +def _env_token() -> str: + token = os.environ.get("GITHUB_TOKEN") + if not token: + raise SystemExit("GITHUB_TOKEN is not set.") + return token + + +def _env_repo() -> tuple[str, str]: + repo = os.environ.get("GITHUB_REPOSITORY") + if not repo or "/" not in repo: + raise SystemExit("GITHUB_REPOSITORY is not set (expected 'owner/repo').") + owner, name = repo.split("/", 1) + return owner, name + + +def _headers(token: str) -> dict[str, str]: + return { + "Authorization": f"Bearer {token}", + "Accept": "application/vnd.github+json", + "User-Agent": "maxtext-ci", + } + + +def _request_json(method: str, url: str, token: str, body: dict | None = None): + data = None + if body is not None: + data = json.dumps(body).encode("utf-8") + req = urllib.request.Request(url, data=data, method=method, headers=_headers(token)) + with urllib.request.urlopen(req) as r: + raw = r.read() + return json.loads(raw.decode("utf-8")) if raw else None + + +def _request_raw(method: str, url: str, token: str, data: bytes, content_type: str): + h = _headers(token) + h["Content-Type"] = content_type + req = urllib.request.Request(url, data=data, method=method, headers=h) + with urllib.request.urlopen(req) as r: + return r.read() + + +def get_or_create_release(tag: str, title: str, body: str, prerelease: bool) -> dict: + """Get a release by tag, or create it if missing. + + Args: + tag: Release tag (e.g. 'te-rocm-wheels'). + title: Release title. + body: Release body text. + prerelease: Whether to mark the release as a prerelease. + + Returns: + The GitHub release object JSON. + """ + token = _env_token() + owner, name = _env_repo() + try: + rel = _request_json("GET", f"{API}/repos/{owner}/{name}/releases/tags/{tag}", token) + if rel: + return rel + except urllib.error.HTTPError as e: + if e.code != 404: + raise + return _request_json( + "POST", + f"{API}/repos/{owner}/{name}/releases", + token, + {"tag_name": tag, "name": title, "body": body, "prerelease": prerelease}, + ) + + +def upload_asset(tag: str, title: str, body: str, file_path: str, prerelease: bool) -> None: + """Upload (replace) a release asset under the given tag. + + If an asset with the same filename already exists, it is deleted first. + """ + token = _env_token() + owner, name = _env_repo() + + rel = get_or_create_release(tag, title, body, prerelease) + release_id = rel["id"] + upload_url = rel["upload_url"].split("{", 1)[0] + + assets = _request_json("GET", f"{API}/repos/{owner}/{name}/releases/{release_id}", token)["assets"] + file_name = os.path.basename(file_path) + for a in assets: + if a.get("name") == file_name: + _request_json("DELETE", f"{API}/repos/{owner}/{name}/releases/assets/{a['id']}", token) + + with open(file_path, "rb") as f: + data = f.read() + up = f"{upload_url}?{urllib.parse.urlencode({'name': file_name})}" + _request_raw("POST", up, token, data, "application/octet-stream") + print(f"Uploaded {file_name} to release tag {tag}", flush=True) + + +def prune_assets(tag: str, keep_days: int) -> None: + """Delete assets older than `keep_days` from the given release tag.""" + token = _env_token() + owner, name = _env_repo() + try: + rel = _request_json("GET", f"{API}/repos/{owner}/{name}/releases/tags/{tag}", token) + except urllib.error.HTTPError as e: + if e.code == 404: + print(f"No release for tag {tag}; skipping asset prune.", flush=True) + return + raise + + cutoff = dt.datetime.now(dt.timezone.utc) - dt.timedelta(days=keep_days) + pruned = 0 + for a in rel.get("assets", []): + created = dt.datetime.fromisoformat(a["created_at"].replace("Z", "+00:00")) + if created < cutoff: + _request_json("DELETE", f"{API}/repos/{owner}/{name}/releases/assets/{a['id']}", token) + pruned += 1 + print(f"Pruned old asset: {a['name']} (created_at={a['created_at']})", flush=True) + print(f"Asset prune complete for {tag}. Deleted {pruned} assets older than {keep_days} days.", flush=True) + + +def prune_releases(prefix: str, keep_days: int) -> None: + """Delete releases with tag names starting with `prefix` older than `keep_days` days.""" + token = _env_token() + owner, name = _env_repo() + cutoff = dt.datetime.now(dt.timezone.utc) - dt.timedelta(days=keep_days) + + page = 1 + deleted = 0 + while True: + rels = _request_json("GET", f"{API}/repos/{owner}/{name}/releases?per_page=100&page={page}", token) + if not rels: + break + for rel in rels: + tag = rel.get("tag_name", "") + if not tag.startswith(prefix): + continue + created = dt.datetime.fromisoformat(rel["created_at"].replace("Z", "+00:00")) + if created < cutoff: + _request_json("DELETE", f"{API}/repos/{owner}/{name}/releases/{rel['id']}", token) + deleted += 1 + print(f"Deleted old release {tag} (created_at={rel['created_at']})", flush=True) + page += 1 + print(f"Release prune complete. Deleted {deleted} releases older than {keep_days} days.", flush=True) + + +def main(argv: list[str]) -> int: + p = argparse.ArgumentParser() + sub = p.add_subparsers(dest="cmd", required=True) + + up = sub.add_parser("upload") + up.add_argument("--tag", required=True) + up.add_argument("--title", required=True) + up.add_argument("--body", required=True) + up.add_argument("--file", required=True) + prg = up.add_mutually_exclusive_group() + prg.add_argument("--prerelease", dest="prerelease", action="store_true") + prg.add_argument("--no-prerelease", dest="prerelease", action="store_false") + up.set_defaults(prerelease=True) + + pa = sub.add_parser("prune-assets") + pa.add_argument("--tag", required=True) + pa.add_argument("--keep-days", type=int, required=True) + + pr = sub.add_parser("prune-releases") + pr.add_argument("--prefix", required=True) + pr.add_argument("--keep-days", type=int, required=True) + + args = p.parse_args(argv) + + if args.cmd == "upload": + upload_asset(args.tag, args.title, args.body, args.file, prerelease=args.prerelease) + return 0 + if args.cmd == "prune-assets": + prune_assets(args.tag, args.keep_days) + return 0 + if args.cmd == "prune-releases": + prune_releases(args.prefix, args.keep_days) + return 0 + raise AssertionError("unreachable") + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv[1:])) From 63f59dd59fbc4517c2510fae91385ed85d2f4368 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Tue, 24 Feb 2026 20:30:30 +0000 Subject: [PATCH 11/21] update runner labels to mi355 --- .github/workflows/utils/install_te_rocm_wheel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/utils/install_te_rocm_wheel.py b/.github/workflows/utils/install_te_rocm_wheel.py index 673d8f731e..7a3f70f75b 100644 --- a/.github/workflows/utils/install_te_rocm_wheel.py +++ b/.github/workflows/utils/install_te_rocm_wheel.py @@ -62,7 +62,7 @@ def detect_arch() -> str: return "mi300" # Safe default. - return "mi300" + return "mi355" def _headers() -> dict[str, str]: From 86ebc19d64c5bb89237dacf39a1968fbb2a89fa8 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Tue, 10 Mar 2026 05:00:20 +0000 Subject: [PATCH 12/21] fix te wheel selection (cherry picked from commit a2f9860f0a468507b1296cb35cca5e1c5c9d13c7) fix rocm version finding --- .../build_rocm_transformer_engine_wheel_weekly.yml | 10 +++++++--- .github/workflows/utils/install_te_rocm_wheel.py | 12 ++++++++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml index 515a3d63ce..aede854e9d 100644 --- a/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml +++ b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml @@ -23,6 +23,8 @@ jobs: env: XLA_PYTHON_CLIENT_MEM_FRACTION: "0.9" NVTE_FUSED_ATTN_AOTRITON: "0" + env: + TE_WHEELS_KEEP_DAYS: "21" steps: - name: Checkout @@ -53,7 +55,7 @@ jobs: source .venv/bin/activate # Detect ROCm version number from JAX backend string when available (e.g. 'rocm 70200'). - ROCM_NUM="$(python3 -c 'import re, jax; s=str(jax.devices()[0].client.platform_version); m=re.search(r"rocm\\s+([0-9]+)", s); print(m.group(1) if m else "unknown")')" + ROCM_NUM="$([ -f /opt/rocm/.info/version ] && head -n1 /opt/rocm/.info/version | tr -d ' \t\r' || echo unknown)" echo "Detected ROCm version: ${ROCM_NUM}" echo "ROCM_NUM=${ROCM_NUM}" >> "${GITHUB_ENV}" @@ -160,7 +162,8 @@ jobs: run: | set -euo pipefail source .venv/bin/activate - python3 "${REL_SCRIPT}" prune-assets --tag "te-rocm-wheels" --keep-days 21 + echo "Pruning rolling-tag assets older than ${TE_WHEELS_KEEP_DAYS} days" + python3 "${REL_SCRIPT}" prune-assets --tag "te-rocm-wheels" --keep-days "${TE_WHEELS_KEEP_DAYS}" - name: Publish wheel to dated weekly release tag shell: bash @@ -189,4 +192,5 @@ jobs: run: | set -euo pipefail source .venv/bin/activate - python3 "${REL_SCRIPT}" prune-releases --prefix "te-rocm-wheels-" --keep-days 21 + echo "Pruning weekly release pages older than ${TE_WHEELS_KEEP_DAYS} days" + python3 "${REL_SCRIPT}" prune-releases --prefix "te-rocm-wheels-" --keep-days "${TE_WHEELS_KEEP_DAYS}" diff --git a/.github/workflows/utils/install_te_rocm_wheel.py b/.github/workflows/utils/install_te_rocm_wheel.py index 7a3f70f75b..8f601884fd 100644 --- a/.github/workflows/utils/install_te_rocm_wheel.py +++ b/.github/workflows/utils/install_te_rocm_wheel.py @@ -92,10 +92,18 @@ def try_download_from_te_rocm_wheels(repo: str, arch: str) -> bool: assets = rel.get("assets", []) name_re = re.compile(rf"^transformer_engine-.*-1\.{arch}-cp312-cp312-linux_x86_64\.whl$") - hit = next((a for a in assets if name_re.match(a.get("name", ""))), None) - if not hit: + matches = [a for a in assets if name_re.match(a.get("name", ""))] + if not matches: return False + # Rolling tag keeps many wheels; select newest matching asset. + hit = max(matches, key=lambda a: a.get("created_at", "")) + print( + "[te wheel] selected latest te-rocm-wheels asset: " + f"{hit.get('name', '')} (created_at={hit.get('created_at', 'unknown')})", + flush=True, + ) + download(hit["browser_download_url"], hit["name"]) return True From 77372cf1fbdbff308bb132178b8125e4d1741b74 Mon Sep 17 00:00:00 2001 From: Gulsum Gudukbay Akbulut Date: Wed, 11 Mar 2026 07:40:50 -0500 Subject: [PATCH 13/21] Remove ROCm fused-attention backend variables Removed ROCm specific environment variables for fused-attention. --- .github/workflows/run_tests_against_package.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml index c56e1801b0..5d0f91cb76 100644 --- a/.github/workflows/run_tests_against_package.yml +++ b/.github/workflows/run_tests_against_package.yml @@ -105,9 +105,6 @@ jobs: ALLOW_MULTIPLE_LIBTPU_LOAD: ${{ inputs.device_type == 'cpu' && 'true' || '' }} # bypass /tmp/libtpu_lockfile check for cpu tests, which don't actually use accelerators (to allow concurrency) DECOUPLE_GCLOUD: ${{ inputs.decoupled_mode && 'TRUE' || '' }} LOCAL_GCLOUD_PROJECT: ${{ inputs.decoupled_mode && 'ci-decoupled' || '' }} - # ROCm: prefer CK fused-attention backend over AOTriton for stability. - NVTE_FUSED_ATTN_CK: ${{ inputs.device_type == 'rocm' && '1' || '' }} - NVTE_FUSED_ATTN_AOTRITON: ${{ inputs.device_type == 'rocm' && '0' || '' }} options: ${{ inputs.container_resource_option }} steps: - name: Checkout MaxText From 79eb75b1b767eea0a7c49343cf0ce9e863e06fde Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Thu, 12 Mar 2026 22:52:25 +0000 Subject: [PATCH 14/21] refactor requirements location change --- .../workflows/build_rocm_transformer_engine_wheel_weekly.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml index aede854e9d..c6fdd677e7 100644 --- a/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml +++ b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml @@ -46,7 +46,7 @@ jobs: run: | set -euo pipefail source .venv/bin/activate - uv pip install -r dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt + uv pip install -r src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt - name: Detect ROCm version and Python tag shell: bash From fd700991d4723b0d0c495bef177eef8e76d46d4d Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Tue, 31 Mar 2026 17:56:05 +0000 Subject: [PATCH 15/21] fix TE build workflow, add rocm torch dependency to env --- .../build_rocm_transformer_engine_wheel_weekly.yml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml index c6fdd677e7..9b90eee2a0 100644 --- a/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml +++ b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml @@ -48,6 +48,13 @@ jobs: source .venv/bin/activate uv pip install -r src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt + - name: Install PyTorch ROCm (build-time dep for aiter JIT) + shell: bash + run: | + set -euo pipefail + source .venv/bin/activate + uv pip install torch --index-url https://download.pytorch.org/whl/rocm7.2 + - name: Detect ROCm version and Python tag shell: bash run: | @@ -114,7 +121,6 @@ jobs: export NVTE_USE_ROCM=1 export NVTE_FUSED_ATTN_AOTRITON=0 export NVTE_BUILD_MAX_JOBS=180 - #export NVTE_AITER_PREBUILT_BASE_URL=https://compute-artifactory.amd.com:5000/artifactory/rocm-generic-local/te-ci/aiter-prebuilts echo "=== Building TE wheel for ${SELECTOR} (gfx=${GFX_ARCH}) ===" pushd TransformerEngine >/dev/null From 67f4d1987e78f885f7f5da910dcc4d1cfc54ca54 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Wed, 8 Apr 2026 16:02:39 +0000 Subject: [PATCH 16/21] fix ci --- .github/workflows/run_tests_against_package.yml | 10 ++++++++-- .github/workflows/run_tests_coordinator.yml | 6 +++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml index 5d0f91cb76..9eacb4c1d3 100644 --- a/.github/workflows/run_tests_against_package.yml +++ b/.github/workflows/run_tests_against_package.yml @@ -97,9 +97,11 @@ jobs: XLA_PYTHON_CLIENT_MEM_FRACTION: ${{ inputs.xla_python_client_mem_fraction }} TF_FORCE_GPU_ALLOW_GROWTH: ${{ inputs.tf_force_gpu_allow_growth }} TPU_SKIP_MDS_QUERY: ${{ inputs.device_type == 'cpu' && '1' || '' }} + # ROCm installs the wheel with --no-deps plus a pinned requirements file; never use TPU wheel extras. MAXTEXT_PACKAGE_EXTRA: >- ${{ - !contains(inputs.pytest_marker, 'not post_training') && 'tpu-post-train' + inputs.device_type == 'rocm' && 'rocm' + || !contains(inputs.pytest_marker, 'not post_training') && 'tpu-post-train' || (inputs.device_type == 'cpu' && 'tpu' || inputs.device_type) }} ALLOW_MULTIPLE_LIBTPU_LOAD: ${{ inputs.device_type == 'cpu' && 'true' || '' }} # bypass /tmp/libtpu_lockfile check for cpu tests, which don't actually use accelerators (to allow concurrency) @@ -134,7 +136,11 @@ jobs: uv pip install -r "${{ inputs.requirements_file }}" fi uv pip install ${maxtext_wheel} --no-deps - uv pip install -r src/dependencies/github_deps/pre_train_deps.txt + # When a requirements file is set (e.g. decoupled or rocm-unit JAX stacks), it already + # carries test/runtime pins; otherwise add GitHub-sourced pins (no TPU extra / scripts). + if [ -z "${{ inputs.requirements_file }}" ]; then + uv pip install -r src/dependencies/extra_deps/pre_train_github_deps.txt --no-deps + fi else if [ -n "${{ inputs.requirements_file }}" ]; then echo "Installing requirements from ${{ inputs.requirements_file }}" diff --git a/.github/workflows/run_tests_coordinator.yml b/.github/workflows/run_tests_coordinator.yml index c7bc49ee5c..3b62ec5f95 100644 --- a/.github/workflows/run_tests_coordinator.yml +++ b/.github/workflows/run_tests_coordinator.yml @@ -125,8 +125,8 @@ jobs: "gpu-integration": "not cpu_only and not tpu_only and integration_test and not post_training", "cpu-unit": "cpu_only and not post_training", "cpu-post-training-unit": "cpu_only and post_training", - "rocm-unit": "not cpu_only and not tpu_only and not integration_test and not post_training and not decoupled", - "rocm-decoupled": "decoupled" + "rocm-unit": "not cpu_only and not tpu_only and not decoupled", + "rocm-decoupled": "not cpu_only and not tpu_only and decoupled" }')[inputs.flavor] }} pytest_addopts: >- @@ -153,7 +153,7 @@ jobs: "gpu-integration": "--ignore=tests/post_training", "cpu-unit": "--ignore=tests/post_training", "cpu-post-training-unit": "", - "rocm-unit": "--ignore=tests/post_training", + "rocm-unit": "", "rocm-decoupled": "" }')[inputs.flavor] }} ${{ inputs.additional_pytest_args }} From 65affee2377376b1386cb62cb4d4faaa5ce89811 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Tue, 5 May 2026 15:56:31 +0000 Subject: [PATCH 17/21] update requirements --- .../requirements_decoupled_jax_0_7.1.txt | 44 ----------------- ...t => requirements_decoupled_jax_0_9_1.txt} | 3 +- .../requirements_decoupled_rocm_jax_0_7.1.txt | 47 ------------------- ...requirements_decoupled_rocm_jax_0_9_1.txt} | 13 +++-- ....2.txt => requirements_rocm_jax_0_9_1.txt} | 14 ++++-- 5 files changed, 18 insertions(+), 103 deletions(-) delete mode 100644 src/dependencies/requirements/requirements_decoupled_jax_0_7.1.txt rename src/dependencies/requirements/{requirements_decoupled_jax_0_8_2.txt => requirements_decoupled_jax_0_9_1.txt} (96%) delete mode 100644 src/dependencies/requirements/requirements_decoupled_rocm_jax_0_7.1.txt rename src/dependencies/requirements/{requirements_decoupled_rocm_jax_0_8_2.txt => requirements_decoupled_rocm_jax_0_9_1.txt} (56%) rename src/dependencies/requirements/{requirements_rocm_jax_0.8.2.txt => requirements_rocm_jax_0_9_1.txt} (67%) diff --git a/src/dependencies/requirements/requirements_decoupled_jax_0_7.1.txt b/src/dependencies/requirements/requirements_decoupled_jax_0_7.1.txt deleted file mode 100644 index 8f904a3641..0000000000 --- a/src/dependencies/requirements/requirements_decoupled_jax_0_7.1.txt +++ /dev/null @@ -1,44 +0,0 @@ -absl_py>=2.3.1 -aqtp>=0.9.0 -chex>=0.1.90 -datasets>=4.2.0 -etils>=1.13.0 -evaluate>=0.4.6 -flax -grain>=0.2.12 -grpcio>=1.75.1 -huggingface_hub>=0.35.3 -jax==0.7.1 -jaxtyping>=0.3.3 -jsonlines>=4.0.0 -matplotlib>=3.10.3 -ml_collections>=1.1.0 -ml_dtypes>=0.5.3 -nltk>=3.9.2 -numpy>=2.0.2 -omegaconf>=2.3.0 -optax>=0.2.6 -orbax-checkpoint>=0.11.25 -pandas>=2.3.3 -parameterized==0.9.0 -pathwaysutils>=0.1.7 -pillow>=11.3.0 -protobuf>=5.29.5 -psutil>=7.0.0 -pytest>=8.4.1 -PyYAML>=6.0.3 -Requests>=2.32.5 -qwix>=0.1.1 -safetensors>=0.6.2 -sentencepiece>=0.2.1 -setuptools>=80.9.0 -tabulate>=0.9.0 -tensorflow>=2.19.1 -tensorflow_text>=2.19.0 -tensorflow_datasets>=4.9.9 -tensorstore>=0.1.76 -tiktoken>=0.12.0 -tqdm>=4.67.1 -transformers>=4.57.0 -urllib3>=2.5.0 -git+https://github.com/google/tunix.git diff --git a/src/dependencies/requirements/requirements_decoupled_jax_0_8_2.txt b/src/dependencies/requirements/requirements_decoupled_jax_0_9_1.txt similarity index 96% rename from src/dependencies/requirements/requirements_decoupled_jax_0_8_2.txt rename to src/dependencies/requirements/requirements_decoupled_jax_0_9_1.txt index 9479797cb6..20cef3eba3 100644 --- a/src/dependencies/requirements/requirements_decoupled_jax_0_8_2.txt +++ b/src/dependencies/requirements/requirements_decoupled_jax_0_9_1.txt @@ -5,8 +5,7 @@ datasets flax grain[parquet] huggingface_hub -jax==0.8.2 -jaxlib==0.8.2 +jax==0.9.1 jaxtyping jsonlines ml-collections diff --git a/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_7.1.txt b/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_7.1.txt deleted file mode 100644 index 4d1732349c..0000000000 --- a/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_7.1.txt +++ /dev/null @@ -1,47 +0,0 @@ -absl_py>=2.3.1 -aqtp>=0.9.0 -chex>=0.1.90 -datasets>=4.2.0 -etils>=1.13.0 -evaluate>=0.4.6 -flax -grain>=0.2.12 -grpcio>=1.75.1 -huggingface_hub>=0.35.3 -https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.7.1/jaxlib-0.7.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl -jax==0.7.1 -jax-rocm7-pjrt==0.7.1 -jax-rocm7-plugin==0.7.1 -jaxtyping>=0.3.3 -jsonlines>=4.0.0 -matplotlib>=3.10.3 -ml_collections>=1.1.0 -ml_dtypes>=0.5.3 -nltk>=3.9.2 -numpy>=2.0.2 -omegaconf>=2.3.0 -optax>=0.2.6 -orbax-checkpoint>=0.11.25 -pandas>=2.3.3 -parameterized==0.9.0 -pathwaysutils>=0.1.3 -pillow>=11.3.0 -protobuf>=5.29.5 -psutil>=7.0.0 -pytest>=8.4.1 -PyYAML>=6.0.3 -Requests>=2.32.5 -qwix>=0.1.1 -safetensors>=0.6.2 -sentencepiece>=0.2.1 -setuptools>=80.9.0 -tabulate>=0.9.0 -tensorflow>=2.19.1 -tensorflow_text>=2.19.0 -tensorflow_datasets>=4.9.9 -tensorstore>=0.1.76 -tiktoken>=0.12.0 -tqdm>=4.67.1 -transformers>=4.57.0 -urllib3>=2.5.0 -git+https://github.com/google/tunix.git diff --git a/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt b/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt similarity index 56% rename from src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt rename to src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt index f821160f03..f9312c756c 100644 --- a/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt +++ b/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt @@ -6,13 +6,16 @@ flax grain[parquet] huggingface_hub +jax==0.9.1 -# ROCm JAX 0.8.2 (py312) wheels (install order matters) -https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jax_rocm7_pjrt-0.8.2+rocm7.2.0-py3-none-manylinux_2_28_x86_64.whl -https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jax_rocm7_plugin-0.8.2+rocm7.2.0-cp312-cp312-manylinux_2_28_x86_64.whl -https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jaxlib-0.8.2+rocm7-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl -jax==0.8.2 +# ROCm JAX wheels +jax-rocm7-plugin==0.9.1.* +# Post-training deps (needed at pytest collection time for tests/post_training/unit/*). +# vLLM ROCm wheel from AMD's index; pinned to match the rocm7.1.2 / gfx950 stack. +--extra-index-url https://rocm.frameworks.amd.com/whl/gfx950-dcgpu/ +vllm==0.16.1.dev10+g11515110f.d20260324.rocm712 +math-verify jaxtyping jsonlines diff --git a/src/dependencies/requirements/requirements_rocm_jax_0.8.2.txt b/src/dependencies/requirements/requirements_rocm_jax_0_9_1.txt similarity index 67% rename from src/dependencies/requirements/requirements_rocm_jax_0.8.2.txt rename to src/dependencies/requirements/requirements_rocm_jax_0_9_1.txt index bffcc3a203..84930d6b7e 100644 --- a/src/dependencies/requirements/requirements_rocm_jax_0.8.2.txt +++ b/src/dependencies/requirements/requirements_rocm_jax_0_9_1.txt @@ -6,12 +6,16 @@ flax grain[parquet] huggingface_hub +jax==0.9.1 -# ROCm JAX 0.8.2 (py312) wheels (install order matters) -https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jax_rocm7_pjrt-0.8.2+rocm7.2.0-py3-none-manylinux_2_28_x86_64.whl -https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jax_rocm7_plugin-0.8.2+rocm7.2.0-cp312-cp312-manylinux_2_28_x86_64.whl -https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jaxlib-0.8.2+rocm7-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl -jax==0.8.2 +# ROCm JAX wheels +jax-rocm7-plugin==0.9.1.* + +# Post-training deps (needed at pytest collection time for tests/post_training/unit/*). +# vLLM ROCm wheel from AMD's index; pinned to match the rocm7.1.2 / gfx950 stack. +--extra-index-url https://rocm.frameworks.amd.com/whl/gfx950-dcgpu/ +vllm==0.16.1.dev10+g11515110f.d20260324.rocm712 +math-verify jaxtyping From a95c831c38a2cdb4112829539c9c4e5d9dac381b Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Tue, 5 May 2026 18:56:51 +0000 Subject: [PATCH 18/21] refactor req files in CI --- .../workflows/build_rocm_transformer_engine_wheel_weekly.yml | 2 +- .github/workflows/run_tests_coordinator.yml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml index 9b90eee2a0..c43b78114d 100644 --- a/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml +++ b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml @@ -46,7 +46,7 @@ jobs: run: | set -euo pipefail source .venv/bin/activate - uv pip install -r src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt + uv pip install -r src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt - name: Install PyTorch ROCm (build-time dep for aiter JIT) shell: bash diff --git a/.github/workflows/run_tests_coordinator.yml b/.github/workflows/run_tests_coordinator.yml index 3b62ec5f95..ddf100fd8b 100644 --- a/.github/workflows/run_tests_coordinator.yml +++ b/.github/workflows/run_tests_coordinator.yml @@ -187,8 +187,8 @@ jobs: decoupled_mode: ${{ contains(inputs.flavor, 'rocm') }} requirements_file: >- ${{ fromJSON('{ - "rocm-unit": "src/dependencies/requirements/requirements_rocm_jax_0.8.2.txt", - "rocm-decoupled": "src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt" + "rocm-unit": "src/dependencies/requirements/requirements_rocm_jax_0_9_1.txt", + "rocm-decoupled": "src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt" }')[inputs.flavor] || '' }} # Metadata From f8ac7f30da636c043e3eb06f0745e31e0478d354 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Wed, 6 May 2026 05:33:41 +0000 Subject: [PATCH 19/21] fix index url usage in reqs --- .../requirements/requirements_decoupled_rocm_jax_0_9_1.txt | 7 ++++--- .../requirements/requirements_rocm_jax_0_9_1.txt | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt b/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt index f9312c756c..5c785a86b9 100644 --- a/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt +++ b/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt @@ -12,9 +12,10 @@ jax==0.9.1 jax-rocm7-plugin==0.9.1.* # Post-training deps (needed at pytest collection time for tests/post_training/unit/*). -# vLLM ROCm wheel from AMD's index; pinned to match the rocm7.1.2 / gfx950 stack. ---extra-index-url https://rocm.frameworks.amd.com/whl/gfx950-dcgpu/ -vllm==0.16.1.dev10+g11515110f.d20260324.rocm712 +# vLLM ROCm wheel pinned to match the rocm7.1.2 / gfx950 stack. Installed via direct +# URL (instead of --extra-index-url) so AMD's flat index isn't probed for unrelated +# packages (e.g. setuptools during sdist builds), which would 403 and break resolution. +vllm @ https://rocm.frameworks.amd.com/whl/gfx950-dcgpu/vllm-0.16.1.dev10+g11515110f.d20260324.rocm712-cp312-cp312-linux_x86_64.whl math-verify jaxtyping diff --git a/src/dependencies/requirements/requirements_rocm_jax_0_9_1.txt b/src/dependencies/requirements/requirements_rocm_jax_0_9_1.txt index 84930d6b7e..a0f965d322 100644 --- a/src/dependencies/requirements/requirements_rocm_jax_0_9_1.txt +++ b/src/dependencies/requirements/requirements_rocm_jax_0_9_1.txt @@ -12,9 +12,10 @@ jax==0.9.1 jax-rocm7-plugin==0.9.1.* # Post-training deps (needed at pytest collection time for tests/post_training/unit/*). -# vLLM ROCm wheel from AMD's index; pinned to match the rocm7.1.2 / gfx950 stack. ---extra-index-url https://rocm.frameworks.amd.com/whl/gfx950-dcgpu/ -vllm==0.16.1.dev10+g11515110f.d20260324.rocm712 +# vLLM ROCm wheel pinned to match the rocm7.1.2 / gfx950 stack. Installed via direct +# URL (instead of --extra-index-url) so AMD's flat index isn't probed for unrelated +# packages (e.g. setuptools during sdist builds), which would 403 and break resolution. +vllm @ https://rocm.frameworks.amd.com/whl/gfx950-dcgpu/vllm-0.16.1.dev10+g11515110f.d20260324.rocm712-cp312-cp312-linux_x86_64.whl math-verify From 7190f0f9e57b8a21bebcb6d4dcdd3f09a13665fd Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Wed, 6 May 2026 17:48:58 +0000 Subject: [PATCH 20/21] update testing and requirements --- .github/workflows/run_tests_coordinator.yml | 10 +++++----- .../requirements_decoupled_rocm_jax_0_9_1.txt | 5 ----- .../requirements/requirements_rocm_jax_0_9_1.txt | 5 ----- 3 files changed, 5 insertions(+), 15 deletions(-) diff --git a/.github/workflows/run_tests_coordinator.yml b/.github/workflows/run_tests_coordinator.yml index ddf100fd8b..4fea6207aa 100644 --- a/.github/workflows/run_tests_coordinator.yml +++ b/.github/workflows/run_tests_coordinator.yml @@ -125,8 +125,8 @@ jobs: "gpu-integration": "not cpu_only and not tpu_only and integration_test and not post_training", "cpu-unit": "cpu_only and not post_training", "cpu-post-training-unit": "cpu_only and post_training", - "rocm-unit": "not cpu_only and not tpu_only and not decoupled", - "rocm-decoupled": "not cpu_only and not tpu_only and decoupled" + "rocm-unit": "not cpu_only and not tpu_only and not post_training and not decoupled", + "rocm-decoupled": "not cpu_only and not tpu_only and not post_training and decoupled" }')[inputs.flavor] }} pytest_addopts: >- @@ -153,8 +153,8 @@ jobs: "gpu-integration": "--ignore=tests/post_training", "cpu-unit": "--ignore=tests/post_training", "cpu-post-training-unit": "", - "rocm-unit": "", - "rocm-decoupled": "" + "rocm-unit": "--ignore=tests/post_training", + "rocm-decoupled": "--ignore=tests/post_training" }')[inputs.flavor] }} ${{ inputs.additional_pytest_args }} @@ -184,7 +184,7 @@ jobs: }')[inputs.flavor] || '--privileged' }} # ROCm-specific parameters - decoupled_mode: ${{ contains(inputs.flavor, 'rocm') }} + decoupled_mode: ${{ contains(inputs.flavor, 'decoupled') }} requirements_file: >- ${{ fromJSON('{ "rocm-unit": "src/dependencies/requirements/requirements_rocm_jax_0_9_1.txt", diff --git a/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt b/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt index 5c785a86b9..95479d30be 100644 --- a/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt +++ b/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt @@ -11,11 +11,6 @@ jax==0.9.1 # ROCm JAX wheels jax-rocm7-plugin==0.9.1.* -# Post-training deps (needed at pytest collection time for tests/post_training/unit/*). -# vLLM ROCm wheel pinned to match the rocm7.1.2 / gfx950 stack. Installed via direct -# URL (instead of --extra-index-url) so AMD's flat index isn't probed for unrelated -# packages (e.g. setuptools during sdist builds), which would 403 and break resolution. -vllm @ https://rocm.frameworks.amd.com/whl/gfx950-dcgpu/vllm-0.16.1.dev10+g11515110f.d20260324.rocm712-cp312-cp312-linux_x86_64.whl math-verify jaxtyping diff --git a/src/dependencies/requirements/requirements_rocm_jax_0_9_1.txt b/src/dependencies/requirements/requirements_rocm_jax_0_9_1.txt index a0f965d322..b32d6d1380 100644 --- a/src/dependencies/requirements/requirements_rocm_jax_0_9_1.txt +++ b/src/dependencies/requirements/requirements_rocm_jax_0_9_1.txt @@ -11,11 +11,6 @@ jax==0.9.1 # ROCm JAX wheels jax-rocm7-plugin==0.9.1.* -# Post-training deps (needed at pytest collection time for tests/post_training/unit/*). -# vLLM ROCm wheel pinned to match the rocm7.1.2 / gfx950 stack. Installed via direct -# URL (instead of --extra-index-url) so AMD's flat index isn't probed for unrelated -# packages (e.g. setuptools during sdist builds), which would 403 and break resolution. -vllm @ https://rocm.frameworks.amd.com/whl/gfx950-dcgpu/vllm-0.16.1.dev10+g11515110f.d20260324.rocm712-cp312-cp312-linux_x86_64.whl math-verify From 9d55814e17ffcd22e5d160529743556ea10dfa10 Mon Sep 17 00:00:00 2001 From: Pakize Sanal Date: Mon, 11 May 2026 00:55:17 +0000 Subject: [PATCH 21/21] Add ROCm benchmark configs and requirements for MaxText --- .../requirements_rocm_benchmark.txt | 52 +++++++++++++++++++ .../configs/gpu/models/gemma3-4b-rocm.yml | 38 ++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 src/dependencies/requirements/requirements_rocm_benchmark.txt create mode 100644 src/maxtext/configs/gpu/models/gemma3-4b-rocm.yml diff --git a/src/dependencies/requirements/requirements_rocm_benchmark.txt b/src/dependencies/requirements/requirements_rocm_benchmark.txt new file mode 100644 index 0000000000..7dfded642e --- /dev/null +++ b/src/dependencies/requirements/requirements_rocm_benchmark.txt @@ -0,0 +1,52 @@ +absl-py +aqtp +array-record +datasets +flax +grain[parquet] +huggingface_hub +jaxtyping +jsonlines +ml-collections +numpy +omegaconf +optax +orbax-checkpoint +pathwaysutils +pillow +pre-commit +protobuf +pydantic +pyink +pylint +pytest +parameterized +pytype +sentencepiece +tensorboardx +tensorflow-datasets +tensorflow-text<2.20 +tensorflow<2.20 +tiktoken +numba>=0.59.0 +google-tunix +tokamax @ git+https://github.com/openxla/tokamax.git@69b328121c3ee120d8e54cf26d1565cea189617f +transformers>=4.57.3,<5 +qwix +mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip + + +# GCP / Cloud deps (restored for non-decoupled ROCm runs) +cloud-accelerator-diagnostics +cloud-tpu-diagnostics +gcsfs +google-api-python-client +google-cloud-aiplatform +google-cloud-mldiagnostics +google-cloud-monitoring +ml-goodput-measurement +tensorboard-plugin-profile +google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip +chex>=0.1.91 +drjax>=0.1.4 + diff --git a/src/maxtext/configs/gpu/models/gemma3-4b-rocm.yml b/src/maxtext/configs/gpu/models/gemma3-4b-rocm.yml new file mode 100644 index 0000000000..92a5cc8d80 --- /dev/null +++ b/src/maxtext/configs/gpu/models/gemma3-4b-rocm.yml @@ -0,0 +1,38 @@ +# model config for gemma3-4b-rocm +base_config: "base.yml" + +steps: 15 +hardware: "gpu" +model_name: "gemma3-4b" +dataset_type: "synthetic" +enable_checkpointing: False +max_segments_per_seq: 32 + +base_num_decoder_layers: 34 +base_emb_dim: 2560 +base_num_query_heads: 8 +base_num_kv_heads: 4 +base_mlp_dim: 10240 +head_dim: 256 +mlp_activations: ["gelu","linear"] +vocab_size: 262_144 +decoder_block: "gemma3" +normalization_layer_epsilon: 1e-6 +logits_via_embedding: True +sliding_window_size: 1024 +use_post_attn_norm: true +use_post_ffw_norm: true +local_rope_max_timescale: 10_000 +rope_max_timescale: 1_000_000 +rope_linear_scaling_factor: 8.0 + +# Multimodal flags (need to set use_multimodal=true) +image_size_for_vit: 896 +num_channels_for_vit: 3 +patch_size_for_vit: 14 +conv_stride_for_vit: 14 +hidden_size_for_vit: 1152 +intermediate_size_for_vit: 4304 +num_hidden_layers_for_vit: 27 +num_attention_heads_for_vit: 16 +