diff --git a/.github/workflows/build_and_test_maxtext.yml b/.github/workflows/build_and_test_maxtext.yml index 9a2d778bb6..df01819109 100644 --- a/.github/workflows/build_and_test_maxtext.yml +++ b/.github/workflows/build_and_test_maxtext.yml @@ -20,16 +20,23 @@ on: pull_request: workflow_call: workflow_dispatch: + inputs: + rocm_only: + description: 'Run only ROCm jobs (manual runs only)' + type: boolean + required: false + default: false schedule: - # Run the job every 4 hours - - cron: '0 */4 * * *' + - cron: '0 3 * * *' # Daily 03:00 UTC concurrency: - # Dedup pull requests (canceling previous runs of the same workflow for same PR), and scheduled runs but nothing else + # Cancel previous runs for same PR (all actors), scheduled runs, + # and manual runs per (branch + actor). group: > ${{ github.event_name == 'pull_request' && format('{0}-pr-{1}', github.workflow, github.event.pull_request.number) || github.event_name == 'schedule' && format('{0}-schedule', github.workflow) || + github.event_name == 'workflow_dispatch' && format('{0}-manual-{1}-{2}', github.workflow, github.ref, github.actor) || github.run_id }} cancel-in-progress: true @@ -93,18 +100,16 @@ jobs: build_and_upload_maxtext_package: needs: doc_only_check # Run if either tests or notebooks need to run - if: | - needs.doc_only_check.outputs.run_tests == 'true' || - needs.doc_only_check.outputs.run_notebooks == 'true' + if: ${{ vars.ROCM_ONLY == 'true' || needs.doc_only_check.outputs.run_tests == 'true' || needs.doc_only_check.outputs.run_notebooks == 'true' }} uses: ./.github/workflows/build_package.yml with: - device_type: tpu - device_name: v4-8 - cloud_runner: linux-x86-n2-16-buildkit + device_type: ${{ vars.ROCM_ONLY == 'true' && 'rocm' || 'tpu' }} + device_name: ${{ vars.ROCM_ONLY == 'true' && 'mi355' || 'v4-8' }} + cloud_runner: ${{ vars.ROCM_ONLY == 'true' && 'linux-x86-64-4gpu-amd' || 'linux-x86-n2-16-buildkit' }} maxtext_jupyter_notebooks: needs: build_and_upload_maxtext_package - if: needs.doc_only_check.outputs.run_notebooks == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && ((!(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && (needs.doc_only_check.outputs.run_notebooks == 'true'))) }} uses: ./.github/workflows/run_jupyter_notebooks.yml strategy: fail-fast: false @@ -121,7 +126,7 @@ jobs: maxtext_cpu_unit_tests: needs: build_and_upload_maxtext_package - if: needs.doc_only_check.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && (!(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && (needs.doc_only_check.outputs.run_tests == 'true')) }} uses: ./.github/workflows/run_tests_against_package.yml strategy: fail-fast: false # don't cancel all jobs on failure @@ -144,7 +149,7 @@ jobs: maxtext_tpu_unit_tests: needs: build_and_upload_maxtext_package - if: needs.doc_only_check.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && (!(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && (needs.doc_only_check.outputs.run_tests == 'true')) }} uses: ./.github/workflows/run_tests_against_package.yml strategy: fail-fast: false @@ -164,7 +169,7 @@ jobs: maxtext_tpu_integration_tests: needs: build_and_upload_maxtext_package - if: needs.doc_only_check.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && (!(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && (needs.doc_only_check.outputs.run_tests == 'true')) }} uses: ./.github/workflows/run_tests_against_package.yml strategy: fail-fast: false @@ -184,7 +189,7 @@ jobs: maxtext_tpu_pathways_unit_tests: needs: build_and_upload_maxtext_package - if: needs.doc_only_check.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && ((!(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && (needs.doc_only_check.outputs.run_tests == 'true'))) }} uses: ./.github/workflows/run_pathways_tests.yml strategy: fail-fast: false @@ -204,7 +209,7 @@ jobs: maxtext_tpu_pathways_integration_tests: needs: build_and_upload_maxtext_package - if: needs.doc_only_check.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && ((!(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && (needs.doc_only_check.outputs.run_tests == 'true'))) }} uses: ./.github/workflows/run_pathways_tests.yml strategy: fail-fast: false @@ -224,7 +229,7 @@ jobs: maxtext_gpu_unit_tests: needs: build_and_upload_maxtext_package - if: needs.doc_only_check.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && (!(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && (needs.doc_only_check.outputs.run_tests == 'true')) }} uses: ./.github/workflows/run_tests_against_package.yml strategy: fail-fast: false @@ -245,7 +250,7 @@ jobs: maxtext_gpu_integration_tests: needs: build_and_upload_maxtext_package - if: needs.doc_only_check.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && (!(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && (needs.doc_only_check.outputs.run_tests == 'true')) }} uses: ./.github/workflows/run_tests_against_package.yml strategy: fail-fast: false @@ -264,9 +269,31 @@ jobs: is_scheduled_run: ${{ github.event_name == 'schedule' }} maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} + maxtext_rocm_decoupled_unit_tests: + needs: build_and_upload_maxtext_package + if: ${{ vars.ROCM_ONLY == 'true' || (github.event_name == 'workflow_dispatch' && inputs.rocm_only) || needs.doc_only_check.outputs.run_tests == 'true' }} + uses: ./.github/workflows/run_tests_against_package.yml + strategy: + fail-fast: false + matrix: + image_type: ["py312"] + with: + device_type: rocm + device_name: mi355 + image_type: ${{ matrix.image_type }} + cloud_runner: linux-x86-64-4gpu-amd + pytest_marker: 'decoupled' + xla_python_client_mem_fraction: 0.9 + tf_force_gpu_allow_growth: true + container_resource_option: "--device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 64g --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --privileged" + requirements_file: "dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt" + decoupled_mode: true + is_scheduled_run: ${{ github.event_name == 'schedule' }} + maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} + all_tests_passed: name: All Required Tests Passed - needs: [doc_only_check, build_and_upload_maxtext_package, maxtext_cpu_unit_tests, maxtext_tpu_unit_tests, maxtext_tpu_integration_tests, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, maxtext_gpu_unit_tests, maxtext_gpu_integration_tests] + needs: [doc_only_check, build_and_upload_maxtext_package, maxtext_cpu_unit_tests, maxtext_tpu_unit_tests, maxtext_tpu_integration_tests, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, maxtext_gpu_unit_tests, maxtext_gpu_integration_tests, maxtext_rocm_decoupled_unit_tests] if: always() runs-on: ubuntu-latest steps: @@ -287,6 +314,7 @@ jobs: echo "TPU pathways integration: ${{ needs.maxtext_tpu_pathways_integration_tests.result }}" echo "GPU tests: ${{ needs.maxtext_gpu_unit_tests.result }}" echo "GPU integration: ${{ needs.maxtext_gpu_integration_tests.result }}" + echo "ROCm decoupled tests: ${{ needs.maxtext_rocm_decoupled_unit_tests.result }}" # Fail only if any job failed or was cancelled (skipped is OK) if [ "${{ contains(needs.*.result, 'failure') }}" == "true" ] || [ "${{ contains(needs.*.result, 'cancelled') }}" == "true" ]; then @@ -299,7 +327,7 @@ jobs: all_notebooks_passed: name: All Notebooks Passed needs: [doc_only_check, build_and_upload_maxtext_package, maxtext_jupyter_notebooks] - if: always() + if: ${{ vars.ROCM_ONLY != 'true' && ((!(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && (always()))) }} runs-on: ubuntu-latest steps: - name: Check notebooks results @@ -323,7 +351,7 @@ jobs: notify_failure: name: Notify failed build # creates an issue or modifies last open existing issue for failed build - needs: [maxtext_jupyter_notebooks, maxtext_cpu_unit_tests, maxtext_tpu_unit_tests, maxtext_tpu_integration_tests, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, maxtext_gpu_unit_tests, maxtext_gpu_integration_tests] + needs: [maxtext_jupyter_notebooks, maxtext_cpu_unit_tests, maxtext_tpu_unit_tests, maxtext_tpu_integration_tests, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, maxtext_gpu_unit_tests, maxtext_gpu_integration_tests, maxtext_rocm_decoupled_unit_tests] if: ${{ always() }} runs-on: ubuntu-latest permissions: diff --git a/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml new file mode 100644 index 0000000000..af988bf12a --- /dev/null +++ b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml @@ -0,0 +1,191 @@ +name: Build ROCm TransformerEngine wheel (weekly) + +on: + workflow_dispatch: + schedule: + # Weekly at night (02:00 UTC every Monday), 2 hours ahead of scheduled tests. + - cron: "0 2 * * 1" + +permissions: + contents: write + +jobs: + build_upload_prune: + # AMD GPU runner (GitHub-hosted large runner label). + runs-on: linux-x86-64-4gpu-amd + container: + image: ghcr.io/rocm/jax-base-ubu24.rocm720:latest + options: >- + --device=/dev/kfd --device=/dev/dri --group-add video + --ipc=host --shm-size 64g + --cap-add=SYS_PTRACE --security-opt seccomp=unconfined + --privileged + env: + XLA_PYTHON_CLIENT_MEM_FRACTION: "0.9" + NVTE_FUSED_ATTN_AOTRITON: "0" + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Setup build environment (deps + venv) + shell: bash + run: | + set -euo pipefail + apt-get update + apt-get install -y --no-install-recommends git build-essential python3-dev + python3 -m pip install -U uv + python3 -m uv venv --seed + source .venv/bin/activate + uv pip install -U pip setuptools wheel pybind11 cmake + + - name: Install ROCm JAX/JAXlib wheels (build against CI stack) + shell: bash + run: | + set -euo pipefail + source .venv/bin/activate + uv pip install -r dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt + + - name: Detect ROCm version and Python tag + shell: bash + run: | + set -euo pipefail + source .venv/bin/activate + + ROCM_NUM="$([ -f /opt/rocm/.info/version ] && head -n1 /opt/rocm/.info/version | tr -d '[:space:]' || echo unknown)" + echo "Detected ROCm version: ${ROCM_NUM}" + echo "ROCM_NUM=${ROCM_NUM}" >> "${GITHUB_ENV}" + + PYTAG="cp$(python3 -c 'import sys; print(f"{sys.version_info.major}{sys.version_info.minor}")')" + if [ "${PYTAG}" != "cp312" ]; then + echo "Expected Python 3.12 (cp312) for ROCm CI wheels, got ${PYTAG}." + exit 1 + fi + echo "PYTAG=${PYTAG}" >> "${GITHUB_ENV}" + echo "REL_SCRIPT=.github/workflows/utils/te_wheels_release.py" >> "${GITHUB_ENV}" + + - name: Clone ROCm/TransformerEngine (dev) + shell: bash + run: | + set -euo pipefail + rm -rf TransformerEngine + git clone --recursive --branch dev https://github.com/ROCm/TransformerEngine.git + cd TransformerEngine + git submodule update --init --recursive + TE_SHA="$(git rev-parse --short=12 HEAD)" + echo "TE_SHA=${TE_SHA}" >> "${GITHUB_ENV}" + + - name: Select TE wheel arch for runner (mi300/mi355) + shell: bash + run: | + set -euo pipefail + source .venv/bin/activate + + TE_WHEEL_ARCH="$(python3 .github/workflows/utils/install_te_rocm_wheel.py --print-arch)" + echo "Resolved TE wheel arch for runner: ${TE_WHEEL_ARCH}" + echo "TE_WHEEL_ARCH=${TE_WHEEL_ARCH}" >> "${GITHUB_ENV}" + + # Build ONLY for the ROCm arch present on this CI runner (mi300 or mi355). + if [ "${TE_WHEEL_ARCH}" = "mi355" ]; then + SELECTOR="mi355" + GFX_ARCH="gfx950" + else + SELECTOR="mi300" + GFX_ARCH="gfx942;gfx941" + fi + echo "SELECTOR=${SELECTOR}" >> "${GITHUB_ENV}" + echo "GFX_ARCH=${GFX_ARCH}" >> "${GITHUB_ENV}" + + - name: Build TE wheel + shell: bash + run: | + set -euo pipefail + source .venv/bin/activate + + chmod +x "${REL_SCRIPT}" || true + + export USE_ROCM=1 + export HIP_PATH=/opt/rocm + export NVTE_FRAMEWORK=jax + export CMAKE_BUILD_PARALLEL_LEVEL=64 + export NVTE_USE_ROCM=1 + export NVTE_FUSED_ATTN_AOTRITON=0 + export NVTE_BUILD_MAX_JOBS=180 + #export NVTE_AITER_PREBUILT_BASE_URL=https://compute-artifactory.amd.com:5000/artifactory/rocm-generic-local/te-ci/aiter-prebuilts + + echo "=== Building TE wheel for ${SELECTOR} (gfx=${GFX_ARCH}) ===" + pushd TransformerEngine >/dev/null + rm -rf build dist + export PYTHONPATH="$(pwd)/3rdparty/hipify_torch${PYTHONPATH:+:${PYTHONPATH}}" + export PYTORCH_ROCM_ARCH="${GFX_ARCH}" + export NVTE_ROCM_ARCH="${GFX_ARCH}" + python3 setup.py bdist_wheel + wheel_path="$( + python3 -c 'import glob; m=sorted(glob.glob("dist/transformer_engine-*.whl")); print(m[0] if m else "")' + )" + if [ -z "${wheel_path}" ]; then + echo "No wheel produced in dist/ (selector=${SELECTOR})." + exit 1 + fi + wheel_base="$(basename "${wheel_path}")" + if [[ "${wheel_base}" == *"-1.${SELECTOR}-${PYTAG}-${PYTAG}-linux_x86_64.whl" ]]; then + asset_name="${wheel_base}" + else + asset_name="${wheel_base/-${PYTAG}-${PYTAG}-linux_x86_64.whl/-1.${SELECTOR}-${PYTAG}-${PYTAG}-linux_x86_64.whl}" + if [ "${asset_name}" = "${wheel_base}" ]; then + echo "Failed to rename wheel for selector=${SELECTOR}: ${wheel_base}" + exit 1 + fi + fi + cp -f "${wheel_path}" "../${asset_name}" + popd >/dev/null + + ls -lh "${asset_name}" + echo "TE_WHEEL_FILE=${asset_name}" >> "${GITHUB_ENV}" + + - name: Upload wheel to rolling release tag (te-rocm-wheels) + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + set -euo pipefail + source .venv/bin/activate + python3 "${REL_SCRIPT}" upload --no-prerelease --tag "te-rocm-wheels" --title "ROCm TransformerEngine wheels (latest)" --body "Rolling release for latest weekly-built ROCm TransformerEngine wheels used by CI." --file "${TE_WHEEL_FILE}" + + - name: Prune old assets from rolling tag + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + set -euo pipefail + source .venv/bin/activate + python3 "${REL_SCRIPT}" prune-assets --tag "te-rocm-wheels" --keep-days 21 + + - name: Publish wheel to dated weekly release tag + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + set -euo pipefail + source .venv/bin/activate + + DATE_UTC="$(date -u +%Y-%m-%d)" + WEEKLY_TAG="te-rocm-wheels-${DATE_UTC}-${TE_SHA}" + WEEKLY_TITLE="ROCm TransformerEngine wheels ${DATE_UTC} (TE ${TE_SHA})" + # Keep this YAML-safe (no unindented heredocs inside `run: |`). + WEEKLY_BODY="$( + printf '%s\n\nROCm: %s\nPython: %s\nArch: %s (gfx=%s)\n' \ + "Built from ROCm/TransformerEngine dev @ ${TE_SHA} on ${DATE_UTC}." \ + "${ROCM_NUM}" "${PYTAG}" "${SELECTOR}" "${GFX_ARCH}" + )" + + python3 "${REL_SCRIPT}" upload --no-prerelease --tag "${WEEKLY_TAG}" --title "${WEEKLY_TITLE}" --body "${WEEKLY_BODY}" --file "${TE_WHEEL_FILE}" + + - name: Prune old weekly releases + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + set -euo pipefail + source .venv/bin/activate + python3 "${REL_SCRIPT}" prune-releases --prefix "te-rocm-wheels-" --keep-days 21 diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml index 6afda428ce..6de07e6004 100644 --- a/.github/workflows/run_tests_against_package.yml +++ b/.github/workflows/run_tests_against_package.yml @@ -37,7 +37,7 @@ on: default: '' is_scheduled_run: required: true - type: string + type: boolean xla_python_client_mem_fraction: required: true type: string @@ -61,20 +61,34 @@ on: maxtext_sha: required: true type: string + decoupled_mode: + required: false + type: boolean + default: false + requirements_file: + required: false + type: string + default: '' permissions: contents: read jobs: run: - runs-on: ${{ inputs.cloud_runner != '' && inputs.cloud_runner || fromJson(format('["self-hosted", "{0}", "{1}"]', inputs.device_type, inputs.device_name)) }} + runs-on: ${{ inputs.cloud_runner != '' && inputs.cloud_runner || (inputs.device_type == 'rocm' && fromJson('["self-hosted","linux-x86-64-4gpu-amd"]')) || fromJson(format('["self-hosted", "{0}", "{1}"]', inputs.device_type, inputs.device_name)) }} + timeout-minutes: ${{ inputs.device_type == 'rocm' && 90 || 360 }} container: - image: gcr.io/tpu-prod-env-multipod/maxtext-unit-test-${{ inputs.device_type == 'cpu' && 'tpu' || inputs.device_type }}:${{ inputs.image_type != '' && inputs.image_type }} + image: ${{ inputs.device_type == 'rocm' && 'ghcr.io/rocm/jax-base-ubu24.rocm720:latest' || format('gcr.io/tpu-prod-env-multipod/maxtext-unit-test-{0}:{1}', inputs.device_type == 'cpu' && 'tpu' || inputs.device_type, inputs.image_type) }} env: XLA_PYTHON_CLIENT_MEM_FRACTION: ${{ inputs.xla_python_client_mem_fraction }} TF_FORCE_GPU_ALLOW_GROWTH: ${{ inputs.tf_force_gpu_allow_growth }} TPU_SKIP_MDS_QUERY: ${{ inputs.device_type == 'cpu' && '1' || '' }} MAXTEXT_PACKAGE_EXTRA: ${{ inputs.device_type == 'cpu' && 'tpu' || inputs.device_type }} ALLOW_MULTIPLE_LIBTPU_LOAD: ${{ inputs.device_type == 'cpu' && 'true' || '' }} # bypass /tmp/libtpu_lockfile check for cpu tests, which don't actually use accelerators (to allow concurrency) + DECOUPLE_GCLOUD: ${{ inputs.decoupled_mode && 'TRUE' || '' }} + LOCAL_GCLOUD_PROJECT: ${{ inputs.decoupled_mode && 'ci-decoupled' || '' }} + # ROCm: prefer CK fused-attention backend over AOTriton for stability. + NVTE_FUSED_ATTN_CK: ${{ inputs.device_type == 'rocm' && '1' || '' }} + NVTE_FUSED_ATTN_AOTRITON: ${{ inputs.device_type == 'rocm' && '0' || '' }} options: ${{ inputs.container_resource_option }} steps: - name: Checkout MaxText @@ -88,15 +102,64 @@ jobs: - name: Install the maxtext wheel shell: bash run: | + if [ "${{ inputs.device_type }}" = "rocm" ]; then + python3 -m pip install -U uv + fi python3 -m uv venv --seed source .venv/bin/activate maxtext_wheel=$(ls maxtext-*-py3-none-any.whl 2>/dev/null) - uv pip install ${maxtext_wheel}[${MAXTEXT_PACKAGE_EXTRA}] --resolution=lowest + if [ "${{ inputs.device_type }}" = "rocm" ]; then + if [ -n "${{ inputs.requirements_file }}" ]; then + echo "Installing requirements from ${{ inputs.requirements_file }}" + uv pip install -r "${{ inputs.requirements_file }}" + fi + uv pip install ${maxtext_wheel} --no-deps + else + if [ -n "${{ inputs.requirements_file }}" ]; then + echo "Installing requirements from ${{ inputs.requirements_file }}" + uv pip install -r "${{ inputs.requirements_file }}" + uv pip install ${maxtext_wheel} --no-deps + else + uv pip install ${maxtext_wheel}[${MAXTEXT_PACKAGE_EXTRA}] --resolution=lowest + fi + fi uv pip install -r src/install_maxtext_extra_deps/extra_deps_from_github.txt python3 --version python3 -m pip freeze uv pip install pytest-cov + + - name: Select ROCm arch (mi300/mi355) + if: ${{ inputs.device_type == 'rocm' }} + shell: bash + run: | + set -euo pipefail + echo "=== ROCm arch selection (from install_te_rocm_wheel.py) ===" + if [ -x .venv/bin/python3 ]; then + te_arch="$( + .venv/bin/python3 .github/workflows/utils/install_te_rocm_wheel.py --print-arch + )" + echo "[te detect_arch] ${te_arch}" + echo "TE_WHEEL_ARCH=${te_arch}" >> "${GITHUB_ENV}" + else + echo "No .venv python found; skipping arch selection." + fi + echo "=========================================================" + + - name: Install Transformer Engine wheel (ROCm) + if: ${{ inputs.device_type == 'rocm' }} + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + source .venv/bin/activate + set -euo pipefail + + python3 .github/workflows/utils/install_te_rocm_wheel.py + + uv pip install --no-deps --force-reinstall transformer_engine-*.whl + - name: Copy test assets files + if: ${{ !inputs.decoupled_mode }} run : gcloud storage cp gs://maxtext-test-assets/* tests/assets - name: Run Tests shell: bash @@ -111,10 +174,13 @@ jobs: export MAXTEXT_ASSETS_ROOT=$(pwd)/src/maxtext/assets export MAXTEXT_TEST_ASSETS_ROOT=$(pwd)/tests/assets export MAXTEXT_PKG_DIR=$(pwd)/src/maxtext - # omit this libtpu init args for gpu tests - if [ "${{ inputs.device_type }}" != "cuda12" ]; then + # omit this libtpu init args for gpu tests (cuda + rocm) + if [ "${{ inputs.device_type }}" != "cuda12" ] && [ "${{ inputs.device_type }}" != "rocm" ]; then export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536' fi + if [ "${{ inputs.device_type }}" = "rocm" ]; then + ulimit -c 0 + fi if [ "${{ inputs.total_workers }}" -gt 1 ]; then .venv/bin/python3 -m pip install --quiet pytest-split pytest-xdist SPLIT_ARGS="--splits ${{ inputs.total_workers }} --group ${{ inputs.worker_group }} -n auto" diff --git a/.github/workflows/upstream_sync.yml b/.github/workflows/upstream_sync.yml new file mode 100644 index 0000000000..d19fe20053 --- /dev/null +++ b/.github/workflows/upstream_sync.yml @@ -0,0 +1,51 @@ +name: Upstream Sync + +on: + workflow_dispatch: {} + schedule: + - cron: '30 4 * * *' # Daily 04:30 UTC + +permissions: + contents: write + +jobs: + sync: + runs-on: ubuntu-latest + if: github.ref == 'refs/heads/main' + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + persist-credentials: false + + - name: Configure git + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + + - name: Fetch upstream + run: | + git remote add upstream https://github.com/google/maxtext.git 2>/dev/null || true + git fetch upstream main + git checkout main + git pull --ff-only origin main || true + + - name: Merge upstream/main + id: merge_step + run: | + set -e + git merge --no-edit upstream/main || { echo "::error::Merge conflict - resolve manually"; git merge --abort || true; exit 1; } + if git diff --quiet origin/main..main; then echo "no_changes=true" >> $GITHUB_OUTPUT; else echo "no_changes=false" >> $GITHUB_OUTPUT; fi + + - name: Push (if changed) + if: steps.merge_step.outputs.no_changes == 'false' + env: + PAT: ${{ secrets.UPSTREAM_SYNC_TOKEN }} + run: | + [ -z "$PAT" ] && echo "::error::Missing UPSTREAM_SYNC_TOKEN secret" && exit 1 + git push https://x-access-token:$PAT@github.com/${{ github.repository }}.git main + + - name: Result + run: | + if [ "${{ steps.merge_step.outputs.no_changes }}" = "true" ]; then echo "Up to date"; else echo "Synced upstream"; fi diff --git a/.github/workflows/utils/install_te_rocm_wheel.py b/.github/workflows/utils/install_te_rocm_wheel.py new file mode 100644 index 0000000000..7a3f70f75b --- /dev/null +++ b/.github/workflows/utils/install_te_rocm_wheel.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +""" +ROCm CI helper: +- Detect MI300 vs MI355 +- Prefer wheel from this repo's 'te-rocm-wheels' release assets +- Fallback to pinned ROCm/maxtext release assets (arch-specific) + +This script only downloads the wheel into the current working directory. +The caller should then install it (e.g. `uv pip install transformer_engine-*.whl`). +""" + +from __future__ import annotations + +import argparse +import json +import os +import re +import subprocess +import urllib.error +import urllib.request + + +def _run(cmd: str) -> str: + try: + return subprocess.check_output(["bash", "-lc", cmd], text=True, stderr=subprocess.STDOUT) + except (subprocess.CalledProcessError, FileNotFoundError, OSError): + return "" + + +def detect_arch() -> str: + """ + Return wheel arch selector: 'mi300' or 'mi355'. + + Notes: + - MI300 family commonly reports gfx942/gfx941. + - MI350/MI355 family commonly reports gfx950 and product strings like MI350X. + """ + override = os.environ.get("TE_WHEEL_ARCH", "").strip().lower() + if override in {"mi300", "mi355"}: + return override + + rocm_smi = _run("command -v rocm-smi >/dev/null 2>&1 && rocm-smi --showproductname || true") or _run( + "[ -x /opt/rocm/bin/rocm-smi ] && /opt/rocm/bin/rocm-smi --showproductname || true" + ) + rocminfo = _run("command -v rocminfo >/dev/null 2>&1 && rocminfo || true") or _run( + "[ -x /opt/rocm/bin/rocminfo ] && /opt/rocm/bin/rocminfo || true" + ) + + blob = f"{rocm_smi}\n{rocminfo}".lower() + gfxs = sorted(set(re.findall(r"gfx[0-9a-f]+", blob))) + + # Prefer explicit gfx IDs when available. + if "gfx950" in gfxs: + return "mi355" + if "gfx942" in gfxs or "gfx941" in gfxs: + return "mi300" + + # Fall back to product string checks. + if "mi355" in blob or "mi350" in blob: + return "mi355" + if "mi300" in blob: + return "mi300" + + # Safe default. + return "mi355" + + +def _headers() -> dict[str, str]: + token = os.environ.get("GITHUB_TOKEN", "") + headers: dict[str, str] = { + "Accept": "application/vnd.github+json", + "User-Agent": "maxtext-ci", + } + if token: + headers["Authorization"] = f"Bearer {token}" + return headers + + +def download(url: str, out_name: str) -> None: + req = urllib.request.Request(url, headers=_headers()) + with urllib.request.urlopen(req) as r, open(out_name, "wb") as f: + f.write(r.read()) + print(f"[te wheel] downloaded {out_name}", flush=True) + + +def try_download_from_te_rocm_wheels(repo: str, arch: str) -> bool: + """Download from this repo's 'te-rocm-wheels' release tag if present.""" + api = f"https://api.github.com/repos/{repo}/releases/tags/te-rocm-wheels" + req = urllib.request.Request(api, headers=_headers()) + with urllib.request.urlopen(req) as r: + rel = json.loads(r.read().decode("utf-8")) + + assets = rel.get("assets", []) + name_re = re.compile(rf"^transformer_engine-.*-1\.{arch}-cp312-cp312-linux_x86_64\.whl$") + hit = next((a for a in assets if name_re.match(a.get("name", ""))), None) + if not hit: + return False + + download(hit["browser_download_url"], hit["name"]) + return True + + +def main(argv: list[str] | None = None) -> int: + """Entry point.""" + parser = argparse.ArgumentParser(add_help=True) + parser.add_argument( + "--print-arch", + action="store_true", + help="Print resolved wheel arch (mi300/mi355) and exit.", + ) + args = parser.parse_args(argv) + + if args.print_arch: + print(detect_arch(), flush=True) + return 0 + + repo = os.environ.get("GITHUB_REPOSITORY", "") + if not repo: + print("[te wheel] GITHUB_REPOSITORY not set; skipping.", flush=True) + return 0 + + arch = detect_arch() + print(f"[te wheel] arch={arch}", flush=True) + + # 1) Prefer: this repo's te-rocm-wheels assets. + try: + if try_download_from_te_rocm_wheels(repo, arch): + return 0 + print(f"[te wheel] no te-rocm-wheels asset for arch={arch}", flush=True) + except urllib.error.HTTPError as e: + print(f"[te wheel] te-rocm-wheels not available ({e.code})", flush=True) + except (urllib.error.URLError, json.JSONDecodeError, KeyError, ValueError) as e: + print(f"[te wheel] te-rocm-wheels lookup failed ({e})", flush=True) + + # 2) Fallback: pinned ROCm/maxtext release assets. + arch_tag = f"1.{arch}" + pinned_name = f"transformer_engine-2.8.0.dev0+2776c337-{arch_tag}-cp312-cp312-linux_x86_64.whl" + pinned = "https://github.com/ROCm/maxtext/releases/download/rocm-maxtext-v0.1.1/" f"{pinned_name}" + download(pinned, pinned_name) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/.github/workflows/utils/te_wheels_release.py b/.github/workflows/utils/te_wheels_release.py new file mode 100644 index 0000000000..a2ae1368e2 --- /dev/null +++ b/.github/workflows/utils/te_wheels_release.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +""" +Publish and prune ROCm TransformerEngine wheel assets on GitHub Releases. + +Intended for use in GitHub Actions. Requires: +- GITHUB_TOKEN +- GITHUB_REPOSITORY (owner/repo) +""" + +from __future__ import annotations + +import argparse +import datetime as dt +import json +import os +import sys +import urllib.error +import urllib.parse +import urllib.request + +API = "https://api.github.com" + + +def _env_token() -> str: + token = os.environ.get("GITHUB_TOKEN") + if not token: + raise SystemExit("GITHUB_TOKEN is not set.") + return token + + +def _env_repo() -> tuple[str, str]: + repo = os.environ.get("GITHUB_REPOSITORY") + if not repo or "/" not in repo: + raise SystemExit("GITHUB_REPOSITORY is not set (expected 'owner/repo').") + owner, name = repo.split("/", 1) + return owner, name + + +def _headers(token: str) -> dict[str, str]: + return { + "Authorization": f"Bearer {token}", + "Accept": "application/vnd.github+json", + "User-Agent": "maxtext-ci", + } + + +def _request_json(method: str, url: str, token: str, body: dict | None = None): + data = None + if body is not None: + data = json.dumps(body).encode("utf-8") + req = urllib.request.Request(url, data=data, method=method, headers=_headers(token)) + with urllib.request.urlopen(req) as r: + raw = r.read() + return json.loads(raw.decode("utf-8")) if raw else None + + +def _request_raw(method: str, url: str, token: str, data: bytes, content_type: str): + h = _headers(token) + h["Content-Type"] = content_type + req = urllib.request.Request(url, data=data, method=method, headers=h) + with urllib.request.urlopen(req) as r: + return r.read() + + +def get_or_create_release(tag: str, title: str, body: str, prerelease: bool) -> dict: + """Get a release by tag, or create it if missing. + + Args: + tag: Release tag (e.g. 'te-rocm-wheels'). + title: Release title. + body: Release body text. + prerelease: Whether to mark the release as a prerelease. + + Returns: + The GitHub release object JSON. + """ + token = _env_token() + owner, name = _env_repo() + try: + rel = _request_json("GET", f"{API}/repos/{owner}/{name}/releases/tags/{tag}", token) + if rel: + return rel + except urllib.error.HTTPError as e: + if e.code != 404: + raise + return _request_json( + "POST", + f"{API}/repos/{owner}/{name}/releases", + token, + {"tag_name": tag, "name": title, "body": body, "prerelease": prerelease}, + ) + + +def upload_asset(tag: str, title: str, body: str, file_path: str, prerelease: bool) -> None: + """Upload (replace) a release asset under the given tag. + + If an asset with the same filename already exists, it is deleted first. + """ + token = _env_token() + owner, name = _env_repo() + + rel = get_or_create_release(tag, title, body, prerelease) + release_id = rel["id"] + upload_url = rel["upload_url"].split("{", 1)[0] + + assets = _request_json("GET", f"{API}/repos/{owner}/{name}/releases/{release_id}", token)["assets"] + file_name = os.path.basename(file_path) + for a in assets: + if a.get("name") == file_name: + _request_json("DELETE", f"{API}/repos/{owner}/{name}/releases/assets/{a['id']}", token) + + with open(file_path, "rb") as f: + data = f.read() + up = f"{upload_url}?{urllib.parse.urlencode({'name': file_name})}" + _request_raw("POST", up, token, data, "application/octet-stream") + print(f"Uploaded {file_name} to release tag {tag}", flush=True) + + +def prune_assets(tag: str, keep_days: int) -> None: + """Delete assets older than `keep_days` from the given release tag.""" + token = _env_token() + owner, name = _env_repo() + try: + rel = _request_json("GET", f"{API}/repos/{owner}/{name}/releases/tags/{tag}", token) + except urllib.error.HTTPError as e: + if e.code == 404: + print(f"No release for tag {tag}; skipping asset prune.", flush=True) + return + raise + + cutoff = dt.datetime.now(dt.timezone.utc) - dt.timedelta(days=keep_days) + pruned = 0 + for a in rel.get("assets", []): + created = dt.datetime.fromisoformat(a["created_at"].replace("Z", "+00:00")) + if created < cutoff: + _request_json("DELETE", f"{API}/repos/{owner}/{name}/releases/assets/{a['id']}", token) + pruned += 1 + print(f"Pruned old asset: {a['name']} (created_at={a['created_at']})", flush=True) + print(f"Asset prune complete for {tag}. Deleted {pruned} assets older than {keep_days} days.", flush=True) + + +def prune_releases(prefix: str, keep_days: int) -> None: + """Delete releases with tag names starting with `prefix` older than `keep_days` days.""" + token = _env_token() + owner, name = _env_repo() + cutoff = dt.datetime.now(dt.timezone.utc) - dt.timedelta(days=keep_days) + + page = 1 + deleted = 0 + while True: + rels = _request_json("GET", f"{API}/repos/{owner}/{name}/releases?per_page=100&page={page}", token) + if not rels: + break + for rel in rels: + tag = rel.get("tag_name", "") + if not tag.startswith(prefix): + continue + created = dt.datetime.fromisoformat(rel["created_at"].replace("Z", "+00:00")) + if created < cutoff: + _request_json("DELETE", f"{API}/repos/{owner}/{name}/releases/{rel['id']}", token) + deleted += 1 + print(f"Deleted old release {tag} (created_at={rel['created_at']})", flush=True) + page += 1 + print(f"Release prune complete. Deleted {deleted} releases older than {keep_days} days.", flush=True) + + +def main(argv: list[str]) -> int: + p = argparse.ArgumentParser() + sub = p.add_subparsers(dest="cmd", required=True) + + up = sub.add_parser("upload") + up.add_argument("--tag", required=True) + up.add_argument("--title", required=True) + up.add_argument("--body", required=True) + up.add_argument("--file", required=True) + prg = up.add_mutually_exclusive_group() + prg.add_argument("--prerelease", dest="prerelease", action="store_true") + prg.add_argument("--no-prerelease", dest="prerelease", action="store_false") + up.set_defaults(prerelease=True) + + pa = sub.add_parser("prune-assets") + pa.add_argument("--tag", required=True) + pa.add_argument("--keep-days", type=int, required=True) + + pr = sub.add_parser("prune-releases") + pr.add_argument("--prefix", required=True) + pr.add_argument("--keep-days", type=int, required=True) + + args = p.parse_args(argv) + + if args.cmd == "upload": + upload_asset(args.tag, args.title, args.body, args.file, prerelease=args.prerelease) + return 0 + if args.cmd == "prune-assets": + prune_assets(args.tag, args.keep_days) + return 0 + if args.cmd == "prune-releases": + prune_releases(args.prefix, args.keep_days) + return 0 + raise AssertionError("unreachable") + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv[1:])) diff --git a/dependencies/requirements/requirements_decoupled_jax_0_8_2.txt b/dependencies/requirements/requirements_decoupled_jax_0_8_2.txt new file mode 100644 index 0000000000..9479797cb6 --- /dev/null +++ b/dependencies/requirements/requirements_decoupled_jax_0_8_2.txt @@ -0,0 +1,40 @@ +absl-py +aqtp +array-record +datasets +flax +grain[parquet] +huggingface_hub +jax==0.8.2 +jaxlib==0.8.2 +jaxtyping +jsonlines +ml-collections +numpy +omegaconf +optax +orbax-checkpoint +pathwaysutils +pillow +pre-commit +protobuf +pydantic +pyink +pylint +pytest +parameterized +pytype +sentencepiece +tensorboardx +tensorflow-datasets +tensorflow-text<2.20 +tensorflow<2.20 +tiktoken +numba>=0.59.0 +google-tunix +tokamax @ git+https://github.com/openxla/tokamax.git@69b328121c3ee120d8e54cf26d1565cea189617f +transformers>=4.57.3,<5 +qwix +mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip +chex>=0.1.91 +drjax>=0.1.4 diff --git a/dependencies/requirements/requirements_decoupled_rocm_jax_0_7.1.txt b/dependencies/requirements/requirements_decoupled_rocm_jax_0_7.1.txt new file mode 100644 index 0000000000..4d1732349c --- /dev/null +++ b/dependencies/requirements/requirements_decoupled_rocm_jax_0_7.1.txt @@ -0,0 +1,47 @@ +absl_py>=2.3.1 +aqtp>=0.9.0 +chex>=0.1.90 +datasets>=4.2.0 +etils>=1.13.0 +evaluate>=0.4.6 +flax +grain>=0.2.12 +grpcio>=1.75.1 +huggingface_hub>=0.35.3 +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.7.1/jaxlib-0.7.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl +jax==0.7.1 +jax-rocm7-pjrt==0.7.1 +jax-rocm7-plugin==0.7.1 +jaxtyping>=0.3.3 +jsonlines>=4.0.0 +matplotlib>=3.10.3 +ml_collections>=1.1.0 +ml_dtypes>=0.5.3 +nltk>=3.9.2 +numpy>=2.0.2 +omegaconf>=2.3.0 +optax>=0.2.6 +orbax-checkpoint>=0.11.25 +pandas>=2.3.3 +parameterized==0.9.0 +pathwaysutils>=0.1.3 +pillow>=11.3.0 +protobuf>=5.29.5 +psutil>=7.0.0 +pytest>=8.4.1 +PyYAML>=6.0.3 +Requests>=2.32.5 +qwix>=0.1.1 +safetensors>=0.6.2 +sentencepiece>=0.2.1 +setuptools>=80.9.0 +tabulate>=0.9.0 +tensorflow>=2.19.1 +tensorflow_text>=2.19.0 +tensorflow_datasets>=4.9.9 +tensorstore>=0.1.76 +tiktoken>=0.12.0 +tqdm>=4.67.1 +transformers>=4.57.0 +urllib3>=2.5.0 +git+https://github.com/google/tunix.git diff --git a/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt b/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt new file mode 100644 index 0000000000..f821160f03 --- /dev/null +++ b/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt @@ -0,0 +1,47 @@ +absl-py +aqtp +array-record +datasets +flax +grain[parquet] +huggingface_hub + + +# ROCm JAX 0.8.2 (py312) wheels (install order matters) +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jax_rocm7_pjrt-0.8.2+rocm7.2.0-py3-none-manylinux_2_28_x86_64.whl +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jax_rocm7_plugin-0.8.2+rocm7.2.0-cp312-cp312-manylinux_2_28_x86_64.whl +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jaxlib-0.8.2+rocm7-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl +jax==0.8.2 + + +jaxtyping +jsonlines +ml-collections +numpy +omegaconf +optax +orbax-checkpoint +pathwaysutils +pillow +pre-commit +protobuf +pydantic +pyink +pylint +pytest +parameterized +pytype +sentencepiece +tensorboardx +tensorflow-datasets +tensorflow-text<2.20 +tensorflow<2.20 +tiktoken +numba>=0.59.0 +google-tunix +tokamax @ git+https://github.com/openxla/tokamax.git@69b328121c3ee120d8e54cf26d1565cea189617f +transformers>=4.57.3,<5 +qwix +mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip +chex>=0.1.91 +drjax>=0.1.4 diff --git a/dependencies/requirements/requirements_rocm_jax_0.8.2.txt b/dependencies/requirements/requirements_rocm_jax_0.8.2.txt new file mode 100644 index 0000000000..bffcc3a203 --- /dev/null +++ b/dependencies/requirements/requirements_rocm_jax_0.8.2.txt @@ -0,0 +1,60 @@ +absl-py +aqtp +array-record +datasets +flax +grain[parquet] +huggingface_hub + + +# ROCm JAX 0.8.2 (py312) wheels (install order matters) +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jax_rocm7_pjrt-0.8.2+rocm7.2.0-py3-none-manylinux_2_28_x86_64.whl +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jax_rocm7_plugin-0.8.2+rocm7.2.0-cp312-cp312-manylinux_2_28_x86_64.whl +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jaxlib-0.8.2+rocm7-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl +jax==0.8.2 + + +jaxtyping +jsonlines +ml-collections +numpy +omegaconf +optax +orbax-checkpoint +pathwaysutils +pillow +pre-commit +protobuf +pydantic +pyink +pylint +pytest +parameterized +pytype +sentencepiece +tensorboardx +tensorflow-datasets +tensorflow-text<2.20 +tensorflow<2.20 +tiktoken +numba>=0.59.0 +google-tunix +tokamax @ git+https://github.com/openxla/tokamax.git@69b328121c3ee120d8e54cf26d1565cea189617f +transformers>=4.57.3,<5 +qwix +mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip + + +# GCP / Cloud deps (restored for non-decoupled ROCm runs) +cloud-accelerator-diagnostics +cloud-tpu-diagnostics +gcsfs +google-api-python-client +google-cloud-aiplatform +google-cloud-mldiagnostics +google-cloud-monitoring +ml-goodput-measurement +tensorboard-plugin-profile +google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip +chex>=0.1.91 +drjax>=0.1.4 diff --git a/src/maxtext/utils/max_utils.py b/src/maxtext/utils/max_utils.py index 765122478e..39b11cc653 100644 --- a/src/maxtext/utils/max_utils.py +++ b/src/maxtext/utils/max_utils.py @@ -248,11 +248,20 @@ def initialize_jax_for_gpu(raw_keys): if os.environ.get("JAX_COORDINATOR_IP") is not None: coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP")) coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT")) + devices = os.getenv("CUDA_VISIBLE_DEVICES") + if devices is not None: + try: + devices = [int(x) for x in devices.split(",")] + except (ValueError, TypeError) as e: + max_logging.log(f"Error parsing CUDA_VISIBLE_DEVICES: {e}") + devices = None + jax.distributed.initialize( coordinator_address=f"{coordinator_ip}:{coordinator_port}", num_processes=int(os.getenv("NNODES")), process_id=int(os.getenv("NODE_RANK")), initialization_timeout=raw_keys["jax_distributed_initialization_timeout"], + local_device_ids=devices, ) max_logging.log(f"JAX global devices: {jax.devices()}") diff --git a/tests/integration/smoke/train_gpu_smoke_test.py b/tests/integration/smoke/train_gpu_smoke_test.py index 76f09c7b0d..84c5d8ce8f 100644 --- a/tests/integration/smoke/train_gpu_smoke_test.py +++ b/tests/integration/smoke/train_gpu_smoke_test.py @@ -20,8 +20,8 @@ from maxtext.common.gcloud_stub import is_decoupled from maxtext.trainers.pre_train.train import main as train_main -from MaxText.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_PKG_DIR -from tests.utils.test_helpers import get_test_dataset_path, get_test_base_output_directory +from MaxText.globals import MAXTEXT_ASSETS_ROOT +from tests.utils.test_helpers import get_test_dataset_path, get_test_base_output_directory, get_test_config_path_for class Train(unittest.TestCase): @@ -43,7 +43,7 @@ def test_tiny_config(self): train_main( [ None, - os.path.join(MAXTEXT_PKG_DIR, "configs", "gpu", "gpu_smoke_test.yml"), + get_test_config_path_for("gpu/gpu_smoke_test.yml"), # pylint: disable=f-string-without-interpolation f"base_output_directory={self.base_output_directory}", "run_name=runner_test", diff --git a/tests/integration/train_tests.py b/tests/integration/train_tests.py index 77b3157f30..45dedfa00c 100644 --- a/tests/integration/train_tests.py +++ b/tests/integration/train_tests.py @@ -17,7 +17,7 @@ import unittest import pytest import jax - +from jax._src import test_util as jtu from absl.testing import absltest from maxtext.common.gcloud_stub import is_decoupled from maxtext.trainers.pre_train.train import main as train_main @@ -563,6 +563,8 @@ def test_gpu_packed_attention(self): @pytest.mark.integration_test @pytest.mark.gpu_only def test_gpu_ring_attention(self): + if jtu.is_device_rocm(): + pytest.skip("TE ring attention context parallelism not supported on ROCm.") os.environ["NVTE_FUSED_ATTN"] = "1" # Enable fused attention os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0" # Disable scan for ring attention ring_attention = [ # tests base config on GPU with ring attention diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py index 14257b6fa7..6ed418a7a6 100644 --- a/tests/unit/attention_test.py +++ b/tests/unit/attention_test.py @@ -1285,7 +1285,10 @@ def test_projection_initialization(self): mla_extra_args = {"ici_fsdp_parallelism": jax.device_count()} if is_decoupled() else {} mla_config_args.update(mla_extra_args) _, mla_layer = self.init_mla(mla_config_args, rope_type="default") - _, mla_layer = self.init_mla(self.config_arguments, rope_type="default") + # Also test the baseline init path; in decoupled mode inject mesh override. + mla_baseline_args = self.config_arguments.copy() + mla_baseline_args.update(mla_extra_args) + _, mla_layer = self.init_mla(mla_baseline_args, rope_type="default") # 4. Assert that the MLA layer DOES NOT HAVE the base projections self.assertFalse(hasattr(mla_layer, "query"), "MLA should not have 'query' projection.") diff --git a/tests/unit/engram_vs_reference_test.py b/tests/unit/engram_vs_reference_test.py index 2243f753d3..f575e357c3 100644 --- a/tests/unit/engram_vs_reference_test.py +++ b/tests/unit/engram_vs_reference_test.py @@ -28,7 +28,6 @@ from typing import List from dataclasses import dataclass, field import math -import os import unittest from absl.testing import parameterized @@ -45,9 +44,9 @@ import jax.numpy as jnp from jax.sharding import Mesh -from MaxText.globals import MAXTEXT_PKG_DIR from MaxText import pyconfig from MaxText import maxtext_utils +from tests.utils.test_helpers import get_test_config_path from maxtext.layers.engram import CompressedTokenizer as CompressedTokenizerJAX from maxtext.layers.engram import NgramHashMapping as NgramHashMappingJAX @@ -464,7 +463,7 @@ def init_torch_weights(module, std=1): def get_cfg_and_mesh(config): """Returns MaxText configuration and mesh.""" cfg = pyconfig.initialize( - [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + [None, get_test_config_path()], run_name="", enable_checkpointing=False, model_name="default", diff --git a/tests/unit/maxtext_utils_test.py b/tests/unit/maxtext_utils_test.py index 8ac66d3421..4d5208cc4a 100644 --- a/tests/unit/maxtext_utils_test.py +++ b/tests/unit/maxtext_utils_test.py @@ -958,8 +958,10 @@ class TestGetAbstractState(unittest.TestCase): """Test class for get_abstract_state.""" def setUp(self): + extra_args = {"ici_fsdp_parallelism": jax.device_count()} if is_decoupled() else {} self.config = pyconfig.initialize( [None, get_test_config_path()], + **extra_args, enable_checkpointing=False, model_name="llama3.1-8b", per_device_batch_size=1, diff --git a/tests/unit/mhc_test.py b/tests/unit/mhc_test.py index 619899ee81..410982dde5 100644 --- a/tests/unit/mhc_test.py +++ b/tests/unit/mhc_test.py @@ -14,7 +14,6 @@ """Test for DeepSeek Manifold-Constrained Hyper Connections (mHC).""" -import os.path import unittest import pytest @@ -27,11 +26,12 @@ from MaxText import pyconfig from MaxText.common_types import HyperConnectionType -from MaxText.globals import MAXTEXT_PKG_DIR from maxtext.layers import attention_mla, linears, mhc, moe from maxtext.layers.initializers import nd_dense_init from maxtext.layers.normalizations import RMSNorm from maxtext.utils import maxtext_utils +from maxtext.common.gcloud_stub import is_decoupled +from tests.utils.test_helpers import get_test_config_path class TestExpandReduce(unittest.TestCase): @@ -92,8 +92,10 @@ class TestMHC(unittest.TestCase): def setUp(self): self.dim = 16 + extra_args = {"ici_fsdp_parallelism": jax.device_count()} if is_decoupled() else {} self.config = pyconfig.initialize( - [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + [None, get_test_config_path()], + **extra_args, run_name="test_mhc", enable_checkpointing=False, model_name="deepseek-custom", diff --git a/tests/unit/profiler_test.py b/tests/unit/profiler_test.py index a25aecfe43..8ba9df2792 100644 --- a/tests/unit/profiler_test.py +++ b/tests/unit/profiler_test.py @@ -13,7 +13,6 @@ # limitations under the License. """Profiler tests.""" -import os import sys import unittest from unittest.mock import MagicMock, patch @@ -21,7 +20,6 @@ import pytest from MaxText import pyconfig -from MaxText.globals import MAXTEXT_PKG_DIR from maxtext.configs import types from maxtext.common import profiler from tests.utils.test_helpers import get_test_config_path @@ -48,7 +46,7 @@ def tearDown(self): def test_profiler_options_populated_from_config(self): """Verifies that Profiler initializes jax.profiler.ProfileOptions from config.""" config = pyconfig.initialize( - [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + [sys.argv[0], get_test_config_path()], enable_checkpointing=False, run_name="test_profiler_options", base_output_directory="/tmp", @@ -83,7 +81,7 @@ def test_profiler_options_populated_from_config(self): def test_profiler_activate_passes_options(self, mock_start_trace): """Verifies that activate() passes the profiling_options to jax.profiler.start_trace.""" config = pyconfig.initialize( - [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + [sys.argv[0], get_test_config_path()], enable_checkpointing=False, run_name="test_profiler_options", base_output_directory="/tmp", diff --git a/tests/unit/sharding_compare_test.py b/tests/unit/sharding_compare_test.py index 054e9656c6..fb35cb06d1 100644 --- a/tests/unit/sharding_compare_test.py +++ b/tests/unit/sharding_compare_test.py @@ -23,12 +23,11 @@ from MaxText import optimizers from MaxText import pyconfig # import optax - -from MaxText.globals import MAXTEXT_PKG_DIR from maxtext.layers import quantizations from maxtext.models import models from maxtext.trainers.pre_train.train_compile import get_shaped_inputs, get_topology_mesh, validate_config from tests.utils.sharding_dump import TEST_CASES, load_json, named_shardings_to_json, partition_specs_to_json +from tests.utils.test_helpers import get_test_config_path import pytest Transformer = models.transformer_as_linen @@ -117,7 +116,7 @@ def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str) """ params = [ "/deps/MaxText/tests/unit/sharding_compare_test", - os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + get_test_config_path(), f"compile_topology={topology}", f"compile_topology_num_slices={num_slice}", f"model_name={model_name}", @@ -183,7 +182,7 @@ def abstract_state_and_shardings(request): print(f"Testing model: {model_name}, topology: {topology}, num_slices: {num_slice}", flush=True) params = [ "/deps/MaxText/tests/unit/sharding_compare_test", - os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + get_test_config_path(), f"compile_topology={topology}", f"compile_topology_num_slices={num_slice}", f"model_name={model_name}", diff --git a/tests/utils/run_sharding_dump.py b/tests/utils/run_sharding_dump.py index 70b8e7bc1b..fd158e8b85 100644 --- a/tests/utils/run_sharding_dump.py +++ b/tests/utils/run_sharding_dump.py @@ -46,9 +46,9 @@ from typing import Sequence -from MaxText.globals import MAXTEXT_PKG_DIR, MAXTEXT_REPO_ROOT +from MaxText.globals import MAXTEXT_REPO_ROOT from tests.utils.sharding_dump import TEST_CASES -import os +from tests.utils.test_helpers import get_test_config_path import subprocess from absl import app, flags from pathlib import Path @@ -67,7 +67,7 @@ def run_single_dump(model_name: str, topology: str, num_slice: str) -> None: "python3", "-m", "tests.utils.sharding_dump", - os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + get_test_config_path(), f"compile_topology={topology}", f"compile_topology_num_slices={num_slice}", f"model_name={model_name}", diff --git a/tests/utils/test_helper.py b/tests/utils/test_helper.py deleted file mode 100644 index a35ea2d780..0000000000 --- a/tests/utils/test_helper.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Test utilities file for helper for test configuration path selection. - -Provides a single helper to return the absolute path to a test config. When -running in decoupled mode (DECOUPLE_GCLOUD=TRUE) the decoupled test config is -returned. -""" - -import os -from maxtext.common.gcloud_stub import is_decoupled -from MaxText.globals import MAXTEXT_PKG_DIR - - -def get_test_config_path(): - """Return absolute path to the chosen test config file. - - Returns `decoupled_base_test.yml` when decoupled, otherwise `base.yml`. - """ - base_cfg = "base.yml" - if is_decoupled(): - base_cfg = "decoupled_base_test.yml" - return os.path.join(MAXTEXT_PKG_DIR, "configs", base_cfg) - - -__all__ = ["get_test_config_path"] diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 656e0e1c37..82d3574459 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -76,9 +76,21 @@ def get_test_base_output_directory(cloud_path=None): return cloud_path or "gs://runner-maxtext-logs" +def get_test_config_path_for(relative_path: str): + """Return absolute path for a config under the configs directory. + + Uses the same decoupled-vs-non-decoupled selection logic as + `get_test_config_path()` when `relative_path` is `base.yml`. + """ + if relative_path == "base.yml": + return get_test_config_path() + return os.path.join(MAXTEXT_CONFIGS_DIR, relative_path) + + __all__ = [ "get_test_base_output_directory", "get_test_config_path", "get_post_train_test_config_path", + "get_test_config_path_for", "get_test_dataset_path", ]