From 1be9f7b4bf2da98bbe81aa662df233135a8b3298 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Thu, 28 May 2026 06:54:56 +0000 Subject: [PATCH 1/3] CI: integrate ROCm flavors into upstream test workflows Wire ROCm (rocm-unit, rocm-decoupled) jobs into the existing test coordinator and package-test workflows, and add install_te_rocm_wheel.py to fetch the MI355 Transformer Engine wheel during container setup. - build_and_test_maxtext.yml: add ROCm jobs and ROCM_ONLY gating on all sibling jobs; switch the daily schedule to 03:00 UTC; expand concurrency to cover manual dispatch per (branch + actor). - run_tests_against_package.yml: select the ROCm base image for rocm device_type, add decoupled_mode + requirements_file + extra pip deps inputs, install the TE wheel and select arch before running tests, ulimit + libtpu init guards for rocm. - run_tests_coordinator.yml: add rocm-unit / rocm-decoupled flavors with their pytest markers, runner labels, container options and ROCm requirements files; route decoupled_mode for the decoupled flavor. - install_te_rocm_wheel.py: download the MI355 TE wheel from the repo's te-rocm-wheels release, falling back to the pinned ROCm/ maxtext release asset. MI355-only (no MI300 path, no arch detection). --- .github/workflows/build_and_test_maxtext.yml | 73 +++++++++--- .../workflows/run_tests_against_package.yml | 100 ++++++++++++++-- .github/workflows/run_tests_coordinator.yml | 62 ++++++++-- .../workflows/utils/install_te_rocm_wheel.py | 110 ++++++++++++++++++ 4 files changed, 305 insertions(+), 40 deletions(-) create mode 100644 .github/workflows/utils/install_te_rocm_wheel.py diff --git a/.github/workflows/build_and_test_maxtext.yml b/.github/workflows/build_and_test_maxtext.yml index ced53b54ab..cc9b611657 100644 --- a/.github/workflows/build_and_test_maxtext.yml +++ b/.github/workflows/build_and_test_maxtext.yml @@ -20,16 +20,23 @@ on: pull_request: workflow_call: workflow_dispatch: + inputs: + rocm_only: + description: 'Run only ROCm jobs (manual runs only)' + type: boolean + required: false + default: false schedule: - # Run the job every 4 hours - - cron: '0 */4 * * *' + - cron: '0 3 * * *' # Daily 03:00 UTC concurrency: - # Dedup pull requests (canceling previous runs of the same workflow for same PR), and scheduled runs but nothing else + # Cancel previous runs for same PR (all actors), scheduled runs, + # and manual runs per (branch + actor). group: > ${{ github.event_name == 'pull_request' && format('{0}-pr-{1}', github.workflow, github.event.pull_request.number) || github.event_name == 'schedule' && format('{0}-schedule', github.workflow) || + github.event_name == 'workflow_dispatch' && format('{0}-manual-{1}-{2}', github.workflow, github.ref, github.actor) || github.run_id }} cancel-in-progress: true @@ -118,18 +125,16 @@ jobs: build_and_upload_maxtext_package: needs: analyze_code_changes # Run if either tests or notebooks need to run - if: | - needs.analyze_code_changes.outputs.run_tests == 'true' || - needs.analyze_code_changes.outputs.run_notebooks == 'true' + if: ${{ vars.ROCM_ONLY == 'true' || needs.analyze_code_changes.outputs.run_tests == 'true' || needs.analyze_code_changes.outputs.run_notebooks == 'true' }} uses: ./.github/workflows/build_package.yml with: - device_type: tpu - device_name: v4-8 - cloud_runner: linux-x86-n2-16-buildkit + device_type: ${{ vars.ROCM_ONLY == 'true' && 'rocm' || 'tpu' }} + device_name: ${{ vars.ROCM_ONLY == 'true' && 'mi355' || 'v4-8' }} + cloud_runner: ${{ vars.ROCM_ONLY == 'true' && 'linux-x86-64-4gpu-amd' || 'linux-x86-n2-16-buildkit' }} maxtext_jupyter_notebooks: needs: build_and_upload_maxtext_package - if: needs.analyze_code_changes.outputs.run_notebooks == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_notebooks == 'true' }} uses: ./.github/workflows/run_jupyter_notebooks.yml strategy: fail-fast: false @@ -145,7 +150,7 @@ jobs: tpu-tests: name: ${{ matrix.flavor }} tests needs: [build_and_upload_maxtext_package] - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }} uses: ./.github/workflows/run_tests_coordinator.yml strategy: fail-fast: false @@ -160,7 +165,7 @@ jobs: gpu-tests: name: ${{ matrix.flavor }} tests needs: [build_and_upload_maxtext_package] - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }} strategy: fail-fast: false matrix: @@ -175,7 +180,7 @@ jobs: cpu-tests: name: ${{ matrix.flavor }} tests needs: [build_and_upload_maxtext_package] - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }} uses: ./.github/workflows/run_tests_coordinator.yml strategy: fail-fast: false @@ -189,7 +194,7 @@ jobs: maxtext_tpu_pathways_unit_tests: needs: build_and_upload_maxtext_package - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }} uses: ./.github/workflows/run_pathways_tests.yml strategy: fail-fast: false @@ -224,9 +229,39 @@ jobs: is_scheduled_run: ${{ github.event_name == 'schedule' }} maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} + rocm-tests: + name: ${{ matrix.flavor }} tests + needs: [build_and_upload_maxtext_package] + if: ${{ vars.ROCM_ONLY != 'true' && needs.analyze_code_changes.outputs.run_tests == 'true' }} + uses: ./.github/workflows/run_tests_coordinator.yml + strategy: + fail-fast: false + matrix: + flavor: [rocm-unit] + with: + flavor: ${{ matrix.flavor }} + base_image: 'rocm-placeholder' + is_scheduled_run: ${{ github.event_name == 'schedule' }} + maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} + + rocm-decoupled-tests: + name: ${{ matrix.flavor }} tests + needs: [build_and_upload_maxtext_package] + if: ${{ vars.ROCM_ONLY == 'true' || (github.event_name == 'workflow_dispatch' && inputs.rocm_only) || needs.analyze_code_changes.outputs.run_tests == 'true' }} + uses: ./.github/workflows/run_tests_coordinator.yml + strategy: + fail-fast: false + matrix: + flavor: [rocm-decoupled] + with: + flavor: ${{ matrix.flavor }} + base_image: 'rocm-placeholder' + is_scheduled_run: ${{ github.event_name == 'schedule' }} + maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} + all_tests_passed: name: All Required Tests Passed - needs: [tpu-tests, gpu-tests, cpu-tests, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests] + needs: [analyze_code_changes, build_and_upload_maxtext_package, tpu-tests, gpu-tests, cpu-tests, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, rocm-tests, rocm-decoupled-tests] if: always() runs-on: ubuntu-latest steps: @@ -244,6 +279,8 @@ jobs: echo "CPU Tests (Matrix) result: ${NEEDS_CPU_TESTS_RESULT}" echo "Pathways Unit result: ${NEEDS_MAXTEXT_TPU_PATHWAYS_UNIT_TESTS_RESULT}" echo "Pathways Integration result: ${NEEDS_MAXTEXT_TPU_PATHWAYS_INTEGRATION_TESTS_RESULT}" + echo "ROCm Tests (Matrix) result: ${NEEDS_ROCM_TESTS_RESULT}" + echo "ROCm Decoupled Tests (Matrix) result: ${NEEDS_ROCM_DECOUPLED_TESTS_RESULT}" # Fail only if any job failed or was cancelled (skipped is OK) if [ "${{ contains(needs.*.result, 'failure') }}" == "true" ] || [ "${{ contains(needs.*.result, 'cancelled') }}" == "true" ]; then @@ -260,11 +297,13 @@ jobs: NEEDS_GPU_TESTS_RESULT: ${{ needs.gpu-tests.result }} NEEDS_MAXTEXT_TPU_PATHWAYS_UNIT_TESTS_RESULT: ${{ needs.maxtext_tpu_pathways_unit_tests.result }} NEEDS_MAXTEXT_TPU_PATHWAYS_INTEGRATION_TESTS_RESULT: ${{ needs.maxtext_tpu_pathways_integration_tests.result }} + NEEDS_ROCM_TESTS_RESULT: ${{ needs.rocm-tests.result }} + NEEDS_ROCM_DECOUPLED_TESTS_RESULT: ${{ needs.rocm-decoupled-tests.result }} all_notebooks_passed: name: All Notebooks Passed needs: [analyze_code_changes, build_and_upload_maxtext_package, maxtext_jupyter_notebooks] - if: always() + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && always() }} runs-on: ubuntu-latest steps: - name: Check notebooks results @@ -292,7 +331,7 @@ jobs: notify_failure: name: Notify failed build # creates an issue or modifies last open existing issue for failed build - needs: [tpu-tests, gpu-tests, cpu-tests, maxtext_jupyter_notebooks, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests] + needs: [tpu-tests, gpu-tests, cpu-tests, maxtext_jupyter_notebooks, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, rocm-tests, rocm-decoupled-tests] if: ${{ always() }} runs-on: ubuntu-latest permissions: diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml index 5c44e268bc..0dce000799 100644 --- a/.github/workflows/run_tests_against_package.yml +++ b/.github/workflows/run_tests_against_package.yml @@ -40,7 +40,7 @@ on: default: '' is_scheduled_run: required: true - type: string + type: boolean xla_python_client_mem_fraction: required: true type: string @@ -65,6 +65,17 @@ on: description: 'Git SHA to checkout if MaxText is not pre-installed' required: false type: string + decoupled_mode: + required: false + type: boolean + default: false + requirements_file: + required: false + type: string + default: '' + extra_pip_deps_file: + required: false + type: string # Flag to skip source checkout and wheel installation maxtext_installed: description: 'If false, maxtext_sha must be provided for checkout' @@ -78,19 +89,24 @@ permissions: contents: read jobs: run: - runs-on: ${{ inputs.cloud_runner != '' && inputs.cloud_runner || fromJson(format('["self-hosted", "{0}", "{1}"]', inputs.device_type, inputs.device_name)) }} + runs-on: ${{ inputs.cloud_runner != '' && inputs.cloud_runner || (inputs.device_type == 'rocm' && fromJson('["self-hosted","linux-x86-64-4gpu-amd"]')) || fromJson(format('["self-hosted", "{0}", "{1}"]', inputs.device_type, inputs.device_name)) }} + timeout-minutes: ${{ inputs.device_type == 'rocm' && 90 || 360 }} container: - image: gcr.io/${{ vars.PROJECT_NAME || 'tpu-prod-env-multipod' }}/${{ inputs.base_image }} + image: ${{ inputs.device_type == 'rocm' && 'ghcr.io/rocm/jax-base-ubu24.rocm720:latest' || format('gcr.io/{0}/{1}', vars.PROJECT_NAME || 'tpu-prod-env-multipod', inputs.base_image) }} env: XLA_PYTHON_CLIENT_MEM_FRACTION: ${{ inputs.xla_python_client_mem_fraction }} TF_FORCE_GPU_ALLOW_GROWTH: ${{ inputs.tf_force_gpu_allow_growth }} TPU_SKIP_MDS_QUERY: ${{ inputs.device_type == 'cpu' && '1' || '' }} + # ROCm installs the wheel with --no-deps plus a pinned requirements file; never use TPU wheel extras. MAXTEXT_PACKAGE_EXTRA: >- ${{ - !contains(inputs.pytest_marker, 'not post_training') && 'tpu-post-train' + inputs.device_type == 'rocm' && 'rocm' + || !contains(inputs.pytest_marker, 'not post_training') && 'tpu-post-train' || (inputs.device_type == 'cpu' && 'tpu' || inputs.device_type) }} ALLOW_MULTIPLE_LIBTPU_LOAD: ${{ inputs.device_type == 'cpu' && 'true' || '' }} # bypass /tmp/libtpu_lockfile check for cpu tests, which don't actually use accelerators (to allow concurrency) + DECOUPLE_GCLOUD: ${{ inputs.decoupled_mode && 'TRUE' || '' }} + LOCAL_GCLOUD_PROJECT: ${{ inputs.decoupled_mode && 'ci-decoupled' || '' }} options: ${{ inputs.container_resource_option }} steps: - name: Checkout MaxText @@ -107,20 +123,79 @@ jobs: if: ${{ !inputs.maxtext_installed }} shell: bash run: | + if [ "${{ inputs.device_type }}" = "rocm" ]; then + python3 -m pip install -U uv + fi python3 -m uv venv --seed source .venv/bin/activate maxtext_wheel=$(ls maxtext-*-py3-none-any.whl 2>/dev/null) echo "Installing ${maxtext_wheel} for ${MAXTEXT_PACKAGE_EXTRA}..." - uv pip install ${maxtext_wheel}[${MAXTEXT_PACKAGE_EXTRA}] --resolution=lowest - if [ "${MAXTEXT_PACKAGE_EXTRA}" == "tpu-post-train" ]; then - install_tpu_post_train_extra_deps + if [ "${{ inputs.device_type }}" = "rocm" ]; then + if [ -n "${{ inputs.requirements_file }}" ]; then + echo "Installing requirements from ${{ inputs.requirements_file }}" + uv pip install -r "${{ inputs.requirements_file }}" + fi + uv pip install ${maxtext_wheel} --no-deps + # When a requirements file is set (e.g. decoupled or rocm-unit JAX stacks), it already + # carries test/runtime pins; otherwise add GitHub-sourced pins (no TPU extra / scripts). + if [ -z "${{ inputs.requirements_file }}" ]; then + uv pip install -r src/dependencies/extra_deps/pre_train_github_deps.txt --no-deps + fi else - install_tpu_pre_train_extra_deps + if [ -n "${{ inputs.requirements_file }}" ]; then + echo "Installing requirements from ${{ inputs.requirements_file }}" + uv pip install -r "${{ inputs.requirements_file }}" + uv pip install ${maxtext_wheel} --no-deps + else + uv pip install ${maxtext_wheel}[${MAXTEXT_PACKAGE_EXTRA}] --resolution=lowest + fi + if [ "${MAXTEXT_PACKAGE_EXTRA}" == "tpu-post-train" ]; then + install_tpu_post_train_extra_deps + else + install_tpu_pre_train_extra_deps + fi fi python3 --version python3 -m pip freeze + - name: Install extra pip deps + if: inputs.extra_pip_deps_file != '' + shell: bash + run: | + source .venv/bin/activate + uv pip install -r ${{ inputs.extra_pip_deps_file }} + + - name: Select ROCm arch (mi300/mi355) + if: ${{ inputs.device_type == 'rocm' }} + shell: bash + run: | + set -euo pipefail + echo "=== ROCm arch selection (from install_te_rocm_wheel.py) ===" + if [ -x .venv/bin/python3 ]; then + te_arch="$( + .venv/bin/python3 .github/workflows/utils/install_te_rocm_wheel.py --print-arch + )" + echo "[te detect_arch] ${te_arch}" + echo "TE_WHEEL_ARCH=${te_arch}" >> "${GITHUB_ENV}" + else + echo "No .venv python found; skipping arch selection." + fi + echo "=========================================================" + + - name: Install Transformer Engine wheel (ROCm) + if: ${{ inputs.device_type == 'rocm' }} + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + source .venv/bin/activate + set -euo pipefail + + python3 .github/workflows/utils/install_te_rocm_wheel.py + + uv pip install --no-deps --force-reinstall transformer_engine-*.whl + - name: Copy test assets files - if: ${{ !inputs.maxtext_installed }} + if: ${{ !inputs.maxtext_installed && !inputs.decoupled_mode }} run : gcloud storage cp gs://maxtext-test-assets/* tests/assets - name: Run Tests shell: bash @@ -153,8 +228,8 @@ jobs: export MAXTEXT_ASSETS_ROOT=$(pwd)/src/maxtext/assets export MAXTEXT_TEST_ASSETS_ROOT=$(pwd)/tests/assets export MAXTEXT_PKG_DIR=$(pwd)/src/maxtext - # omit this libtpu init args for gpu tests - if [ "${INPUTS_DEVICE_TYPE}" != "cuda12" ]; then + # omit this libtpu init args for gpu tests (cuda + rocm) + if [ "${INPUTS_DEVICE_TYPE}" != "cuda12" ] && [ "${INPUTS_DEVICE_TYPE}" != "rocm" ]; then export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536' else # For cuda12, explicitly point to the pip-installed CUDA libraries @@ -171,6 +246,9 @@ jobs: done fi fi + if [ "${INPUTS_DEVICE_TYPE}" = "rocm" ]; then + ulimit -c 0 + fi if [ "${INPUTS_TOTAL_WORKERS}" -gt 1 ]; then $PYTHON_EXE -m pip install --quiet pytest-split pytest-xdist SPLIT_ARGS="--splits ${INPUTS_TOTAL_WORKERS} --group ${INPUTS_WORKER_GROUP} -n auto" diff --git a/.github/workflows/run_tests_coordinator.yml b/.github/workflows/run_tests_coordinator.yml index c6c15acd3f..0aa3760cf2 100644 --- a/.github/workflows/run_tests_coordinator.yml +++ b/.github/workflows/run_tests_coordinator.yml @@ -27,7 +27,9 @@ on: tpu-post-training-unit, tpu-post-training-integration, gpu-unit, gpu-integration, cpu-unit, - cpu-post-training-unit + cpu-post-training-unit, + rocm-unit, + rocm-decoupled ) required: true type: string @@ -80,7 +82,9 @@ jobs: "gpu-unit": "cuda12", "gpu-integration": "cuda12", "cpu-unit": "cpu", - "cpu-post-training-unit": "cpu" + "cpu-post-training-unit": "cpu", + "rocm-unit": "rocm", + "rocm-decoupled": "rocm" }')[inputs.flavor] }} device_name: >- @@ -92,7 +96,9 @@ jobs: "gpu-unit": "a100-40gb-4", "gpu-integration": "a100-40gb-4", "cpu-unit": "X64", - "cpu-post-training-unit": "X64" + "cpu-post-training-unit": "X64", + "rocm-unit": "mi355", + "rocm-decoupled": "mi355" }')[inputs.flavor] }} cloud_runner: >- @@ -104,7 +110,9 @@ jobs: "gpu-unit": "linux-x86-a2-48-a100-4gpu", "gpu-integration": "linux-x86-a2-48-a100-4gpu", "cpu-unit": "linux-x86-n2-32", - "cpu-post-training-unit": "linux-x86-n2-32" + "cpu-post-training-unit": "linux-x86-n2-32", + "rocm-unit": "linux-x86-64-4gpu-amd", + "rocm-decoupled": "linux-x86-64-4gpu-amd" }')[inputs.flavor] }} # Pytest Marker Mapping pytest_marker: >- @@ -116,7 +124,9 @@ jobs: "gpu-unit": "not cpu_only and not tpu_only and not integration_test and not post_training", "gpu-integration": "not cpu_only and not tpu_only and integration_test and not post_training", "cpu-unit": "cpu_only and not post_training", - "cpu-post-training-unit": "cpu_only and post_training" + "cpu-post-training-unit": "cpu_only and post_training", + "rocm-unit": "not cpu_only and not tpu_only and not post_training and not decoupled", + "rocm-decoupled": "not cpu_only and not tpu_only and not post_training and decoupled" }')[inputs.flavor] }} pytest_addopts: >- @@ -128,7 +138,9 @@ jobs: "gpu-unit": "", "gpu-integration": "", "cpu-unit": "", - "cpu-post-training-unit": "tests/post_training/unit tests/unit" + "cpu-post-training-unit": "tests/post_training/unit tests/unit", + "rocm-unit": "", + "rocm-decoupled": "" }')[inputs.flavor] }} pytest_extra_args: >- @@ -140,18 +152,44 @@ jobs: "gpu-unit": "--ignore=tests/post_training", "gpu-integration": "--ignore=tests/post_training", "cpu-unit": "--ignore=tests/post_training", - "cpu-post-training-unit": "" + "cpu-post-training-unit": "", + "rocm-unit": "--ignore=tests/post_training", + "rocm-decoupled": "--ignore=tests/post_training" }')[inputs.flavor] }} ${{ inputs.additional_pytest_args }} # Resource Scaling - xla_python_client_mem_fraction: "${{ contains(inputs.flavor, 'gpu') && '0.65' || '0.75' }}" - tf_force_gpu_allow_growth: "${{ contains(inputs.flavor, 'gpu') && 'true' || 'false' }}" + xla_python_client_mem_fraction: >- + ${{ fromJSON('{ + "gpu-unit": "0.65", + "gpu-integration": "0.65", + "rocm-unit": "0.9", + "rocm-decoupled": "0.9" + }')[inputs.flavor] || '0.75' }} + + tf_force_gpu_allow_growth: >- + ${{ fromJSON('{ + "gpu-unit": "true", + "gpu-integration": "true", + "rocm-unit": "true", + "rocm-decoupled": "true" + }')[inputs.flavor] || 'false' }} container_resource_option: >- - ${{ contains(inputs.flavor, 'gpu') - && '--shm-size 2g --runtime=nvidia --gpus all --privileged' - || '--privileged' }} + ${{ fromJSON('{ + "gpu-unit": "--shm-size 2g --runtime=nvidia --gpus all --privileged", + "gpu-integration": "--shm-size 2g --runtime=nvidia --gpus all --privileged", + "rocm-unit": "--device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 64g --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --privileged", + "rocm-decoupled": "--device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 64g --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --privileged" + }')[inputs.flavor] || '--privileged' }} + + # ROCm-specific parameters + decoupled_mode: ${{ contains(inputs.flavor, 'decoupled') }} + requirements_file: >- + ${{ fromJSON('{ + "rocm-unit": "src/dependencies/requirements/requirements_rocm_jax_0_9_1.txt", + "rocm-decoupled": "src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt" + }')[inputs.flavor] || '' }} # Metadata base_image: ${{ inputs.base_image }} diff --git a/.github/workflows/utils/install_te_rocm_wheel.py b/.github/workflows/utils/install_te_rocm_wheel.py new file mode 100644 index 0000000000..85dbeaa3f1 --- /dev/null +++ b/.github/workflows/utils/install_te_rocm_wheel.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +""" +ROCm CI helper: +- Resolve the MI355 Transformer Engine wheel from this repo's + 'te-rocm-wheels' release assets, falling back to the pinned ROCm/maxtext + release asset if needed. + +CI runners are MI355 only, so no architecture detection is performed. + +This script only downloads the wheel into the current working directory. +The caller should then install it (e.g. `uv pip install transformer_engine-*.whl`). +""" + +from __future__ import annotations + +import argparse +import json +import os +import re +import urllib.error +import urllib.request + + +# CI runners are always MI355 (gfx950); no detection or fallback needed. +WHEEL_ARCH = "mi355" + + +def _headers() -> dict[str, str]: + token = os.environ.get("GITHUB_TOKEN", "") + headers: dict[str, str] = { + "Accept": "application/vnd.github+json", + "User-Agent": "maxtext-ci", + } + if token: + headers["Authorization"] = f"Bearer {token}" + return headers + + +def download(url: str, out_name: str) -> None: + req = urllib.request.Request(url, headers=_headers()) + with urllib.request.urlopen(req) as r, open(out_name, "wb") as f: + f.write(r.read()) + print(f"[te wheel] downloaded {out_name}", flush=True) + + +def try_download_from_te_rocm_wheels(repo: str) -> bool: + """Download from this repo's 'te-rocm-wheels' release tag if present.""" + api = f"https://api.github.com/repos/{repo}/releases/tags/te-rocm-wheels" + req = urllib.request.Request(api, headers=_headers()) + with urllib.request.urlopen(req) as r: + rel = json.loads(r.read().decode("utf-8")) + + assets = rel.get("assets", []) + name_re = re.compile(rf"^transformer_engine-.*-1\.{WHEEL_ARCH}-cp312-cp312-linux_x86_64\.whl$") + matches = [a for a in assets if name_re.match(a.get("name", ""))] + if not matches: + return False + + # Rolling tag keeps many wheels; select newest matching asset. + hit = max(matches, key=lambda a: a.get("created_at", "")) + print( + "[te wheel] selected latest te-rocm-wheels asset: " + f"{hit.get('name', '')} (created_at={hit.get('created_at', 'unknown')})", + flush=True, + ) + + download(hit["browser_download_url"], hit["name"]) + return True + + +def main(argv: list[str] | None = None) -> int: + """Entry point.""" + parser = argparse.ArgumentParser(add_help=True) + parser.add_argument( + "--print-arch", + action="store_true", + help="Print the wheel arch (always 'mi355') and exit.", + ) + args = parser.parse_args(argv) + + if args.print_arch: + print(WHEEL_ARCH, flush=True) + return 0 + + repo = os.environ.get("GITHUB_REPOSITORY", "") + if not repo: + print("[te wheel] GITHUB_REPOSITORY not set; skipping.", flush=True) + return 0 + + print(f"[te wheel] arch={WHEEL_ARCH}", flush=True) + + # 1) Prefer: this repo's te-rocm-wheels assets. + try: + if try_download_from_te_rocm_wheels(repo): + return 0 + print(f"[te wheel] no te-rocm-wheels asset for arch={WHEEL_ARCH}", flush=True) + except urllib.error.HTTPError as e: + print(f"[te wheel] te-rocm-wheels not available ({e.code})", flush=True) + except (urllib.error.URLError, json.JSONDecodeError, KeyError, ValueError) as e: + print(f"[te wheel] te-rocm-wheels lookup failed ({e})", flush=True) + + # 2) Fallback: pinned ROCm/maxtext release asset. + pinned_name = f"transformer_engine-2.8.0.dev0+2776c337-1.{WHEEL_ARCH}-cp312-cp312-linux_x86_64.whl" + pinned = f"https://github.com/ROCm/maxtext/releases/download/rocm-maxtext-v0.1.1/{pinned_name}" + download(pinned, pinned_name) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From 2e81b045f4e421f5ba624d166554894f77046560 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Thu, 28 May 2026 07:02:46 +0000 Subject: [PATCH 2/3] CI: add upstream-sync workflow Scheduled workflow that keeps rocm-main in sync with AI-Hypercomputer/main. --- .github/workflows/upstream_sync.yml | 51 +++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 .github/workflows/upstream_sync.yml diff --git a/.github/workflows/upstream_sync.yml b/.github/workflows/upstream_sync.yml new file mode 100644 index 0000000000..2b2f91e6f3 --- /dev/null +++ b/.github/workflows/upstream_sync.yml @@ -0,0 +1,51 @@ +name: Upstream Sync + +on: + workflow_dispatch: {} + schedule: + - cron: '30 4 * * *' # Daily 04:30 UTC + +permissions: + contents: write + +jobs: + sync: + runs-on: ubuntu-latest + if: github.ref == 'refs/heads/main' + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + persist-credentials: false + + - name: Configure git + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + + - name: Fetch upstream + run: | + git remote add upstream https://github.com/google/maxtext.git 2>/dev/null || true + git fetch upstream main + git checkout main + git pull --ff-only origin main || true + + - name: Merge upstream/main + id: merge_step + run: | + set -e + git merge --no-edit upstream/main || { echo "::error::Merge conflict - resolve manually"; git merge --abort || true; exit 1; } + if git diff --quiet origin/main..main; then echo "no_changes=true" >> $GITHUB_OUTPUT; else echo "no_changes=false" >> $GITHUB_OUTPUT; fi + + - name: Push (if changed) + if: steps.merge_step.outputs.no_changes == 'false' + env: + PAT: ${{ secrets.UPSTREAM_SYNC_TOKEN }} + run: | + [ -z "$PAT" ] && echo "::error::Missing UPSTREAM_SYNC_TOKEN secret" && exit 1 + git push https://x-access-token:$PAT@github.com/${{ github.repository }}.git main + + - name: Result + run: | + if [ "${{ steps.merge_step.outputs.no_changes }}" = "true" ]; then echo "Up to date"; else echo "Synced upstream"; fi From 698624b35843a728f49c6bc20e50f53a1b2839d4 Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Thu, 28 May 2026 07:21:37 +0000 Subject: [PATCH 3/3] checkpoint_conversion: import GCS via gcloud_stub for decoupled mode The top-level `from google.cloud.storage import Client, transfer_manager` in checkpoint_conversion/utils/utils.py broke pytest collection for the ROCm decoupled tests (DECOUPLE_GCLOUD=TRUE), since the package isn't installed in that environment. - gcloud_stub.gcs_storage(): also import and attach the transfer_manager submodule (it isn't auto-imported by `from google.cloud import storage`); extend _gcs_stubs() with a no-op transfer_manager stub. - checkpoint_conversion/utils/utils.py: drop the direct google.cloud import and bind Client/transfer_manager via gcs_storage(), matching the existing pattern in src/maxtext/utils/gcs_utils.py. --- .../checkpoint_conversion/utils/utils.py | 7 ++++-- src/maxtext/common/gcloud_stub.py | 22 +++++++++++++++++-- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/maxtext/checkpoint_conversion/utils/utils.py b/src/maxtext/checkpoint_conversion/utils/utils.py index 93253cffb0..daff5c2736 100644 --- a/src/maxtext/checkpoint_conversion/utils/utils.py +++ b/src/maxtext/checkpoint_conversion/utils/utils.py @@ -36,8 +36,6 @@ from jax.experimental import multihost_utils from jaxtyping import Array -from google.cloud.storage import Client, transfer_manager - from safetensors import safe_open from safetensors.numpy import save_file as numpy_save_file from safetensors.numpy import save as numpy_save @@ -50,9 +48,14 @@ from flax.training import train_state from maxtext.common import checkpointing +from maxtext.common.gcloud_stub import gcs_storage from maxtext.utils import max_logging import orbax.checkpoint as ocp +_storage = gcs_storage() +Client = _storage.Client +transfer_manager = _storage.transfer_manager + SAFE_TENSORS_CONFIG_FILE = "config.json" SAFE_TENSORS_WEIGHTS_FILE = "model.safetensors" diff --git a/src/maxtext/common/gcloud_stub.py b/src/maxtext/common/gcloud_stub.py index fe6b00cbb5..1936aba5c0 100644 --- a/src/maxtext/common/gcloud_stub.py +++ b/src/maxtext/common/gcloud_stub.py @@ -237,11 +237,27 @@ def bucket(self, *a, **k): # pylint: disable=unused-argument def list_blobs(self, *a, **k): # pylint: disable=unused-argument return iter([]) - return SimpleNamespace(Client=_StubClient, _IS_STUB=True) + def _stub_upload_many_from_filenames(*_a, **_k): + """No-op stub for transfer_manager.upload_many_from_filenames.""" + return [] + + transfer_manager_stub = SimpleNamespace( + upload_many_from_filenames=_stub_upload_many_from_filenames, + _IS_STUB=True, + ) + + return SimpleNamespace(Client=_StubClient, transfer_manager=transfer_manager_stub, _IS_STUB=True) def gcs_storage(): - """Return google.cloud.storage module or stub when decoupled or missing.""" + """Return google.cloud.storage module (with transfer_manager attached) or stub. + + The returned object always exposes both ``.Client`` and ``.transfer_manager`` + so callers can use ``storage.transfer_manager.upload_many_from_filenames(...)`` + without an extra import. ``transfer_manager`` is a submodule of + ``google.cloud.storage`` and is not auto-imported by ``from google.cloud + import storage``; we explicitly import and attach it here. + """ # In decoupled mode always prefer the stub, even if the library is installed, # to avoid accidental GCS calls in tests or local runs. if is_decoupled(): # fast path @@ -250,7 +266,9 @@ def gcs_storage(): try: # pragma: no cover - attempt real import when not decoupled from google.cloud import storage # type: ignore # pylint: disable=import-outside-toplevel + from google.cloud.storage import transfer_manager # type: ignore # pylint: disable=import-outside-toplevel + setattr(storage, "transfer_manager", transfer_manager) setattr(storage, "_IS_STUB", False) return storage except Exception: # ModuleNotFoundError / ImportError for partial installs # pylint: disable=broad-exception-caught