diff --git a/.github/workflows/build_and_test_maxtext.yml b/.github/workflows/build_and_test_maxtext.yml index 6949bf8d7e..6f971d1b5a 100644 --- a/.github/workflows/build_and_test_maxtext.yml +++ b/.github/workflows/build_and_test_maxtext.yml @@ -20,16 +20,23 @@ on: pull_request: workflow_call: workflow_dispatch: + inputs: + rocm_only: + description: 'Run only ROCm jobs (manual runs only)' + type: boolean + required: false + default: false schedule: - # Run the job every 4 hours - - cron: '0 */4 * * *' + - cron: '0 3 * * *' # Daily 03:00 UTC concurrency: - # Dedup pull requests (canceling previous runs of the same workflow for same PR), and scheduled runs but nothing else + # Cancel previous runs for same PR (all actors), scheduled runs, + # and manual runs per (branch + actor). group: > ${{ github.event_name == 'pull_request' && format('{0}-pr-{1}', github.workflow, github.event.pull_request.number) || github.event_name == 'schedule' && format('{0}-schedule', github.workflow) || + github.event_name == 'workflow_dispatch' && format('{0}-manual-{1}-{2}', github.workflow, github.ref, github.actor) || github.run_id }} cancel-in-progress: true @@ -118,18 +125,16 @@ jobs: build_and_upload_maxtext_package: needs: analyze_code_changes # Run if either tests or notebooks need to run - if: | - needs.analyze_code_changes.outputs.run_tests == 'true' || - needs.analyze_code_changes.outputs.run_notebooks == 'true' + if: ${{ vars.ROCM_ONLY == 'true' || needs.analyze_code_changes.outputs.run_tests == 'true' || needs.analyze_code_changes.outputs.run_notebooks == 'true' }} uses: ./.github/workflows/build_package.yml with: - device_type: tpu - device_name: v4-8 - cloud_runner: linux-x86-n2-16-buildkit + device_type: ${{ vars.ROCM_ONLY == 'true' && 'rocm' || 'tpu' }} + device_name: ${{ vars.ROCM_ONLY == 'true' && 'mi355' || 'v4-8' }} + cloud_runner: ${{ vars.ROCM_ONLY == 'true' && 'linux-x86-64-4gpu-amd' || 'linux-x86-n2-16-buildkit' }} maxtext_jupyter_notebooks: needs: build_and_upload_maxtext_package - if: needs.analyze_code_changes.outputs.run_notebooks == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_notebooks == 'true' }} uses: ./.github/workflows/run_jupyter_notebooks.yml strategy: fail-fast: false @@ -145,7 +150,7 @@ jobs: tpu-tests: name: ${{ matrix.flavor }} tests needs: [build_and_upload_maxtext_package] - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }} uses: ./.github/workflows/run_tests_coordinator.yml strategy: fail-fast: false @@ -160,7 +165,7 @@ jobs: gpu-tests: name: ${{ matrix.flavor }} tests needs: [build_and_upload_maxtext_package] - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }} strategy: fail-fast: false matrix: @@ -175,7 +180,7 @@ jobs: cpu-tests: name: ${{ matrix.flavor }} tests needs: [build_and_upload_maxtext_package] - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }} uses: ./.github/workflows/run_tests_coordinator.yml strategy: fail-fast: false @@ -189,7 +194,7 @@ jobs: maxtext_tpu_pathways_unit_tests: needs: build_and_upload_maxtext_package - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }} uses: ./.github/workflows/run_pathways_tests.yml strategy: fail-fast: false @@ -208,7 +213,7 @@ jobs: maxtext_tpu_pathways_integration_tests: needs: build_and_upload_maxtext_package - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }} uses: ./.github/workflows/run_pathways_tests.yml strategy: fail-fast: false @@ -225,9 +230,39 @@ jobs: is_scheduled_run: ${{ github.event_name == 'schedule' }} maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} + rocm-tests: + name: ${{ matrix.flavor }} tests + needs: [build_and_upload_maxtext_package] + if: ${{ vars.ROCM_ONLY != 'true' && needs.analyze_code_changes.outputs.run_tests == 'true' }} + uses: ./.github/workflows/run_tests_coordinator.yml + strategy: + fail-fast: false + matrix: + flavor: [rocm-unit] + with: + flavor: ${{ matrix.flavor }} + base_image: 'rocm-placeholder' + is_scheduled_run: ${{ github.event_name == 'schedule' }} + maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} + + rocm-decoupled-tests: + name: ${{ matrix.flavor }} tests + needs: [build_and_upload_maxtext_package] + if: ${{ vars.ROCM_ONLY == 'true' || (github.event_name == 'workflow_dispatch' && inputs.rocm_only) || needs.analyze_code_changes.outputs.run_tests == 'true' }} + uses: ./.github/workflows/run_tests_coordinator.yml + strategy: + fail-fast: false + matrix: + flavor: [rocm-decoupled] + with: + flavor: ${{ matrix.flavor }} + base_image: 'rocm-placeholder' + is_scheduled_run: ${{ github.event_name == 'schedule' }} + maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} + all_tests_passed: name: All Required Tests Passed - needs: [tpu-tests, gpu-tests, cpu-tests, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests] + needs: [analyze_code_changes, build_and_upload_maxtext_package, tpu-tests, gpu-tests, cpu-tests, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, rocm-tests, rocm-decoupled-tests] if: always() runs-on: ubuntu-latest steps: @@ -245,6 +280,8 @@ jobs: echo "CPU Tests (Matrix) result: ${NEEDS_CPU_TESTS_RESULT}" echo "Pathways Unit result: ${NEEDS_MAXTEXT_TPU_PATHWAYS_UNIT_TESTS_RESULT}" echo "Pathways Integration result: ${NEEDS_MAXTEXT_TPU_PATHWAYS_INTEGRATION_TESTS_RESULT}" + echo "ROCm Tests (Matrix) result: ${NEEDS_ROCM_TESTS_RESULT}" + echo "ROCm Decoupled Tests (Matrix) result: ${NEEDS_ROCM_DECOUPLED_TESTS_RESULT}" # Fail only if any job failed or was cancelled (skipped is OK) if [ "${{ contains(needs.*.result, 'failure') }}" == "true" ] || [ "${{ contains(needs.*.result, 'cancelled') }}" == "true" ]; then @@ -261,11 +298,13 @@ jobs: NEEDS_GPU_TESTS_RESULT: ${{ needs.gpu-tests.result }} NEEDS_MAXTEXT_TPU_PATHWAYS_UNIT_TESTS_RESULT: ${{ needs.maxtext_tpu_pathways_unit_tests.result }} NEEDS_MAXTEXT_TPU_PATHWAYS_INTEGRATION_TESTS_RESULT: ${{ needs.maxtext_tpu_pathways_integration_tests.result }} + NEEDS_ROCM_TESTS_RESULT: ${{ needs.rocm-tests.result }} + NEEDS_ROCM_DECOUPLED_TESTS_RESULT: ${{ needs.rocm-decoupled-tests.result }} all_notebooks_passed: name: All Notebooks Passed needs: [analyze_code_changes, build_and_upload_maxtext_package, maxtext_jupyter_notebooks] - if: always() + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && always() }} runs-on: ubuntu-latest steps: - name: Check notebooks results @@ -293,7 +332,7 @@ jobs: notify_failure: name: Notify failed build # creates an issue or modifies last open existing issue for failed build - needs: [tpu-tests, gpu-tests, cpu-tests, maxtext_jupyter_notebooks, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests] + needs: [tpu-tests, gpu-tests, cpu-tests, maxtext_jupyter_notebooks, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, rocm-tests, rocm-decoupled-tests] if: ${{ always() }} runs-on: ubuntu-latest permissions: diff --git a/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml new file mode 100644 index 0000000000..c43b78114d --- /dev/null +++ b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml @@ -0,0 +1,202 @@ +name: Build ROCm TransformerEngine wheel (weekly) + +on: + workflow_dispatch: + schedule: + # Weekly at night (02:00 UTC every Monday), 2 hours ahead of scheduled tests. + - cron: "0 2 * * 1" + +permissions: + contents: write + +jobs: + build_upload_prune: + # AMD GPU runner (GitHub-hosted large runner label). + runs-on: linux-x86-64-4gpu-amd + container: + image: ghcr.io/rocm/jax-base-ubu24.rocm720:latest + options: >- + --device=/dev/kfd --device=/dev/dri --group-add video + --ipc=host --shm-size 64g + --cap-add=SYS_PTRACE --security-opt seccomp=unconfined + --privileged + env: + XLA_PYTHON_CLIENT_MEM_FRACTION: "0.9" + NVTE_FUSED_ATTN_AOTRITON: "0" + env: + TE_WHEELS_KEEP_DAYS: "21" + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Setup build environment (deps + venv) + shell: bash + run: | + set -euo pipefail + apt-get update + apt-get install -y --no-install-recommends git build-essential python3-dev + python3 -m pip install -U uv + python3 -m uv venv --seed + source .venv/bin/activate + uv pip install -U pip setuptools wheel pybind11 cmake + + - name: Install ROCm JAX/JAXlib wheels (build against CI stack) + shell: bash + run: | + set -euo pipefail + source .venv/bin/activate + uv pip install -r src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt + + - name: Install PyTorch ROCm (build-time dep for aiter JIT) + shell: bash + run: | + set -euo pipefail + source .venv/bin/activate + uv pip install torch --index-url https://download.pytorch.org/whl/rocm7.2 + + - name: Detect ROCm version and Python tag + shell: bash + run: | + set -euo pipefail + source .venv/bin/activate + + # Detect ROCm version number from JAX backend string when available (e.g. 'rocm 70200'). + ROCM_NUM="$([ -f /opt/rocm/.info/version ] && head -n1 /opt/rocm/.info/version | tr -d ' \t\r' || echo unknown)" + echo "Detected ROCm version: ${ROCM_NUM}" + echo "ROCM_NUM=${ROCM_NUM}" >> "${GITHUB_ENV}" + + PYTAG="cp$(python3 -c 'import sys; print(f"{sys.version_info.major}{sys.version_info.minor}")')" + if [ "${PYTAG}" != "cp312" ]; then + echo "Expected Python 3.12 (cp312) for ROCm CI wheels, got ${PYTAG}." + exit 1 + fi + echo "PYTAG=${PYTAG}" >> "${GITHUB_ENV}" + echo "REL_SCRIPT=.github/workflows/utils/te_wheels_release.py" >> "${GITHUB_ENV}" + + - name: Clone ROCm/TransformerEngine (dev) + shell: bash + run: | + set -euo pipefail + rm -rf TransformerEngine + git clone --recursive --branch dev https://github.com/ROCm/TransformerEngine.git + cd TransformerEngine + git submodule update --init --recursive + TE_SHA="$(git rev-parse --short=12 HEAD)" + echo "TE_SHA=${TE_SHA}" >> "${GITHUB_ENV}" + + - name: Select TE wheel arch for runner (mi300/mi355) + shell: bash + run: | + set -euo pipefail + source .venv/bin/activate + + TE_WHEEL_ARCH="$(python3 .github/workflows/utils/install_te_rocm_wheel.py --print-arch)" + echo "Resolved TE wheel arch for runner: ${TE_WHEEL_ARCH}" + echo "TE_WHEEL_ARCH=${TE_WHEEL_ARCH}" >> "${GITHUB_ENV}" + + # Build ONLY for the ROCm arch present on this CI runner (mi300 or mi355). + if [ "${TE_WHEEL_ARCH}" = "mi355" ]; then + SELECTOR="mi355" + GFX_ARCH="gfx950" + else + SELECTOR="mi300" + GFX_ARCH="gfx942;gfx941" + fi + echo "SELECTOR=${SELECTOR}" >> "${GITHUB_ENV}" + echo "GFX_ARCH=${GFX_ARCH}" >> "${GITHUB_ENV}" + + - name: Build TE wheel + shell: bash + run: | + set -euo pipefail + source .venv/bin/activate + + chmod +x "${REL_SCRIPT}" || true + + export USE_ROCM=1 + export HIP_PATH=/opt/rocm + export NVTE_FRAMEWORK=jax + export CMAKE_BUILD_PARALLEL_LEVEL=64 + export NVTE_USE_ROCM=1 + export NVTE_FUSED_ATTN_AOTRITON=0 + export NVTE_BUILD_MAX_JOBS=180 + + echo "=== Building TE wheel for ${SELECTOR} (gfx=${GFX_ARCH}) ===" + pushd TransformerEngine >/dev/null + rm -rf build dist + export PYTHONPATH="$(pwd)/3rdparty/hipify_torch${PYTHONPATH:+:${PYTHONPATH}}" + export PYTORCH_ROCM_ARCH="${GFX_ARCH}" + export NVTE_ROCM_ARCH="${GFX_ARCH}" + python3 setup.py bdist_wheel + wheel_path="$( + python3 -c 'import glob; m=sorted(glob.glob("dist/transformer_engine-*.whl")); print(m[0] if m else "")' + )" + if [ -z "${wheel_path}" ]; then + echo "No wheel produced in dist/ (selector=${SELECTOR})." + exit 1 + fi + wheel_base="$(basename "${wheel_path}")" + if [[ "${wheel_base}" == *"-1.${SELECTOR}-${PYTAG}-${PYTAG}-linux_x86_64.whl" ]]; then + asset_name="${wheel_base}" + else + asset_name="${wheel_base/-${PYTAG}-${PYTAG}-linux_x86_64.whl/-1.${SELECTOR}-${PYTAG}-${PYTAG}-linux_x86_64.whl}" + if [ "${asset_name}" = "${wheel_base}" ]; then + echo "Failed to rename wheel for selector=${SELECTOR}: ${wheel_base}" + exit 1 + fi + fi + cp -f "${wheel_path}" "../${asset_name}" + popd >/dev/null + + ls -lh "${asset_name}" + echo "TE_WHEEL_FILE=${asset_name}" >> "${GITHUB_ENV}" + + - name: Upload wheel to rolling release tag (te-rocm-wheels) + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + set -euo pipefail + source .venv/bin/activate + python3 "${REL_SCRIPT}" upload --no-prerelease --tag "te-rocm-wheels" --title "ROCm TransformerEngine wheels (latest)" --body "Rolling release for latest weekly-built ROCm TransformerEngine wheels used by CI." --file "${TE_WHEEL_FILE}" + + - name: Prune old assets from rolling tag + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + set -euo pipefail + source .venv/bin/activate + echo "Pruning rolling-tag assets older than ${TE_WHEELS_KEEP_DAYS} days" + python3 "${REL_SCRIPT}" prune-assets --tag "te-rocm-wheels" --keep-days "${TE_WHEELS_KEEP_DAYS}" + + - name: Publish wheel to dated weekly release tag + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + set -euo pipefail + source .venv/bin/activate + + DATE_UTC="$(date -u +%Y-%m-%d)" + WEEKLY_TAG="te-rocm-wheels-${DATE_UTC}-${TE_SHA}" + WEEKLY_TITLE="ROCm TransformerEngine wheels ${DATE_UTC} (TE ${TE_SHA})" + # Keep this YAML-safe (no unindented heredocs inside `run: |`). + WEEKLY_BODY="$( + printf '%s\n\nROCm: %s\nPython: %s\nArch: %s (gfx=%s)\n' \ + "Built from ROCm/TransformerEngine dev @ ${TE_SHA} on ${DATE_UTC}." \ + "${ROCM_NUM}" "${PYTAG}" "${SELECTOR}" "${GFX_ARCH}" + )" + + python3 "${REL_SCRIPT}" upload --no-prerelease --tag "${WEEKLY_TAG}" --title "${WEEKLY_TITLE}" --body "${WEEKLY_BODY}" --file "${TE_WHEEL_FILE}" + + - name: Prune old weekly releases + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + set -euo pipefail + source .venv/bin/activate + echo "Pruning weekly release pages older than ${TE_WHEELS_KEEP_DAYS} days" + python3 "${REL_SCRIPT}" prune-releases --prefix "te-rocm-wheels-" --keep-days "${TE_WHEELS_KEEP_DAYS}" diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml index fd4a59fbba..9eacb4c1d3 100644 --- a/.github/workflows/run_tests_against_package.yml +++ b/.github/workflows/run_tests_against_package.yml @@ -40,7 +40,7 @@ on: default: '' is_scheduled_run: required: true - type: string + type: boolean xla_python_client_mem_fraction: required: true type: string @@ -65,6 +65,17 @@ on: description: 'Git SHA to checkout if MaxText is not pre-installed' required: false type: string + decoupled_mode: + required: false + type: boolean + default: false + requirements_file: + required: false + type: string + default: '' + extra_pip_deps_file: + required: false + type: string # Flag to skip source checkout and wheel installation maxtext_installed: description: 'If false, maxtext_sha must be provided for checkout' @@ -78,19 +89,24 @@ permissions: contents: read jobs: run: - runs-on: ${{ inputs.cloud_runner != '' && inputs.cloud_runner || fromJson(format('["self-hosted", "{0}", "{1}"]', inputs.device_type, inputs.device_name)) }} + runs-on: ${{ inputs.cloud_runner != '' && inputs.cloud_runner || (inputs.device_type == 'rocm' && fromJson('["self-hosted","linux-x86-64-4gpu-amd"]')) || fromJson(format('["self-hosted", "{0}", "{1}"]', inputs.device_type, inputs.device_name)) }} + timeout-minutes: ${{ inputs.device_type == 'rocm' && 90 || 360 }} container: - image: gcr.io/${{ vars.PROJECT_NAME }}/${{ inputs.base_image }} + image: ${{ inputs.device_type == 'rocm' && 'ghcr.io/rocm/jax-base-ubu24.rocm720:latest' || format('gcr.io/{0}/{1}', vars.PROJECT_NAME, inputs.base_image) }} env: XLA_PYTHON_CLIENT_MEM_FRACTION: ${{ inputs.xla_python_client_mem_fraction }} TF_FORCE_GPU_ALLOW_GROWTH: ${{ inputs.tf_force_gpu_allow_growth }} TPU_SKIP_MDS_QUERY: ${{ inputs.device_type == 'cpu' && '1' || '' }} + # ROCm installs the wheel with --no-deps plus a pinned requirements file; never use TPU wheel extras. MAXTEXT_PACKAGE_EXTRA: >- ${{ - !contains(inputs.pytest_marker, 'not post_training') && 'tpu-post-train' + inputs.device_type == 'rocm' && 'rocm' + || !contains(inputs.pytest_marker, 'not post_training') && 'tpu-post-train' || (inputs.device_type == 'cpu' && 'tpu' || inputs.device_type) }} ALLOW_MULTIPLE_LIBTPU_LOAD: ${{ inputs.device_type == 'cpu' && 'true' || '' }} # bypass /tmp/libtpu_lockfile check for cpu tests, which don't actually use accelerators (to allow concurrency) + DECOUPLE_GCLOUD: ${{ inputs.decoupled_mode && 'TRUE' || '' }} + LOCAL_GCLOUD_PROJECT: ${{ inputs.decoupled_mode && 'ci-decoupled' || '' }} options: ${{ inputs.container_resource_option }} steps: - name: Checkout MaxText @@ -107,20 +123,79 @@ jobs: if: ${{ !inputs.maxtext_installed }} shell: bash run: | + if [ "${{ inputs.device_type }}" = "rocm" ]; then + python3 -m pip install -U uv + fi python3 -m uv venv --seed source .venv/bin/activate maxtext_wheel=$(ls maxtext-*-py3-none-any.whl 2>/dev/null) echo "Installing ${maxtext_wheel} for ${MAXTEXT_PACKAGE_EXTRA}..." - uv pip install ${maxtext_wheel}[${MAXTEXT_PACKAGE_EXTRA}] --resolution=lowest - if [ "${MAXTEXT_PACKAGE_EXTRA}" == "tpu-post-train" ]; then - install_tpu_post_train_extra_deps + if [ "${{ inputs.device_type }}" = "rocm" ]; then + if [ -n "${{ inputs.requirements_file }}" ]; then + echo "Installing requirements from ${{ inputs.requirements_file }}" + uv pip install -r "${{ inputs.requirements_file }}" + fi + uv pip install ${maxtext_wheel} --no-deps + # When a requirements file is set (e.g. decoupled or rocm-unit JAX stacks), it already + # carries test/runtime pins; otherwise add GitHub-sourced pins (no TPU extra / scripts). + if [ -z "${{ inputs.requirements_file }}" ]; then + uv pip install -r src/dependencies/extra_deps/pre_train_github_deps.txt --no-deps + fi else - install_tpu_pre_train_extra_deps + if [ -n "${{ inputs.requirements_file }}" ]; then + echo "Installing requirements from ${{ inputs.requirements_file }}" + uv pip install -r "${{ inputs.requirements_file }}" + uv pip install ${maxtext_wheel} --no-deps + else + uv pip install ${maxtext_wheel}[${MAXTEXT_PACKAGE_EXTRA}] --resolution=lowest + fi + if [ "${MAXTEXT_PACKAGE_EXTRA}" == "tpu-post-train" ]; then + install_tpu_post_train_extra_deps + else + install_tpu_pre_train_extra_deps + fi fi python3 --version python3 -m pip freeze + - name: Install extra pip deps + if: inputs.extra_pip_deps_file != '' + shell: bash + run: | + source .venv/bin/activate + uv pip install -r ${{ inputs.extra_pip_deps_file }} + + - name: Select ROCm arch (mi300/mi355) + if: ${{ inputs.device_type == 'rocm' }} + shell: bash + run: | + set -euo pipefail + echo "=== ROCm arch selection (from install_te_rocm_wheel.py) ===" + if [ -x .venv/bin/python3 ]; then + te_arch="$( + .venv/bin/python3 .github/workflows/utils/install_te_rocm_wheel.py --print-arch + )" + echo "[te detect_arch] ${te_arch}" + echo "TE_WHEEL_ARCH=${te_arch}" >> "${GITHUB_ENV}" + else + echo "No .venv python found; skipping arch selection." + fi + echo "=========================================================" + + - name: Install Transformer Engine wheel (ROCm) + if: ${{ inputs.device_type == 'rocm' }} + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + source .venv/bin/activate + set -euo pipefail + + python3 .github/workflows/utils/install_te_rocm_wheel.py + + uv pip install --no-deps --force-reinstall transformer_engine-*.whl + - name: Copy test assets files - if: ${{ !inputs.maxtext_installed }} + if: ${{ !inputs.maxtext_installed && !inputs.decoupled_mode }} run : gcloud storage cp gs://maxtext-test-assets/* tests/assets - name: Run Tests shell: bash @@ -153,8 +228,8 @@ jobs: export MAXTEXT_ASSETS_ROOT=$(pwd)/src/maxtext/assets export MAXTEXT_TEST_ASSETS_ROOT=$(pwd)/tests/assets export MAXTEXT_PKG_DIR=$(pwd)/src/maxtext - # omit this libtpu init args for gpu tests - if [ "${INPUTS_DEVICE_TYPE}" != "cuda12" ]; then + # omit this libtpu init args for gpu tests (cuda + rocm) + if [ "${INPUTS_DEVICE_TYPE}" != "cuda12" ] && [ "${INPUTS_DEVICE_TYPE}" != "rocm" ]; then export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536' else # For cuda12, explicitly point to the pip-installed CUDA libraries @@ -165,6 +240,9 @@ jobs: echo "Warning: Could not find pinned nvidia libraries in .venv." fi fi + if [ "${INPUTS_DEVICE_TYPE}" = "rocm" ]; then + ulimit -c 0 + fi if [ "${INPUTS_TOTAL_WORKERS}" -gt 1 ]; then $PYTHON_EXE -m pip install --quiet pytest-split pytest-xdist SPLIT_ARGS="--splits ${INPUTS_TOTAL_WORKERS} --group ${INPUTS_WORKER_GROUP} -n auto" diff --git a/.github/workflows/run_tests_coordinator.yml b/.github/workflows/run_tests_coordinator.yml index 935b7fbda0..4fea6207aa 100644 --- a/.github/workflows/run_tests_coordinator.yml +++ b/.github/workflows/run_tests_coordinator.yml @@ -27,7 +27,9 @@ on: tpu-post-training-unit, tpu-post-training-integration, gpu-unit, gpu-integration, cpu-unit, - cpu-post-training-unit + cpu-post-training-unit, + rocm-unit, + rocm-decoupled ) required: true type: string @@ -80,7 +82,9 @@ jobs: "gpu-unit": "cuda12", "gpu-integration": "cuda12", "cpu-unit": "cpu", - "cpu-post-training-unit": "cpu" + "cpu-post-training-unit": "cpu", + "rocm-unit": "rocm", + "rocm-decoupled": "rocm" }')[inputs.flavor] }} device_name: >- @@ -92,7 +96,9 @@ jobs: "gpu-unit": "a100-40gb-4", "gpu-integration": "a100-40gb-4", "cpu-unit": "X64", - "cpu-post-training-unit": "X64" + "cpu-post-training-unit": "X64", + "rocm-unit": "mi355", + "rocm-decoupled": "mi355" }')[inputs.flavor] }} cloud_runner: >- @@ -104,7 +110,9 @@ jobs: "gpu-unit": "linux-x86-a2-48-a100-4gpu", "gpu-integration": "linux-x86-a2-48-a100-4gpu", "cpu-unit": "linux-x86-n2-32", - "cpu-post-training-unit": "linux-x86-n2-32" + "cpu-post-training-unit": "linux-x86-n2-32", + "rocm-unit": "linux-x86-64-4gpu-amd", + "rocm-decoupled": "linux-x86-64-4gpu-amd" }')[inputs.flavor] }} # Pytest Marker Mapping pytest_marker: >- @@ -116,7 +124,9 @@ jobs: "gpu-unit": "not cpu_only and not tpu_only and not integration_test and not post_training", "gpu-integration": "not cpu_only and not tpu_only and integration_test and not post_training", "cpu-unit": "cpu_only and not post_training", - "cpu-post-training-unit": "cpu_only and post_training" + "cpu-post-training-unit": "cpu_only and post_training", + "rocm-unit": "not cpu_only and not tpu_only and not post_training and not decoupled", + "rocm-decoupled": "not cpu_only and not tpu_only and not post_training and decoupled" }')[inputs.flavor] }} pytest_addopts: >- @@ -128,7 +138,9 @@ jobs: "gpu-unit": "", "gpu-integration": "", "cpu-unit": "", - "cpu-post-training-unit": "tests/post_training/unit tests/unit" + "cpu-post-training-unit": "tests/post_training/unit tests/unit", + "rocm-unit": "", + "rocm-decoupled": "" }')[inputs.flavor] }} pytest_extra_args: >- @@ -140,18 +152,44 @@ jobs: "gpu-unit": "--ignore=tests/post_training", "gpu-integration": "--ignore=tests/post_training", "cpu-unit": "--ignore=tests/post_training", - "cpu-post-training-unit": "" + "cpu-post-training-unit": "", + "rocm-unit": "--ignore=tests/post_training", + "rocm-decoupled": "--ignore=tests/post_training" }')[inputs.flavor] }} ${{ inputs.additional_pytest_args }} # Resource Scaling - xla_python_client_mem_fraction: "${{ contains(inputs.flavor, 'gpu') && '0.65' || '0.75' }}" - tf_force_gpu_allow_growth: "${{ contains(inputs.flavor, 'gpu') && 'true' || 'false' }}" + xla_python_client_mem_fraction: >- + ${{ fromJSON('{ + "gpu-unit": "0.65", + "gpu-integration": "0.65", + "rocm-unit": "0.9", + "rocm-decoupled": "0.9" + }')[inputs.flavor] || '0.75' }} + + tf_force_gpu_allow_growth: >- + ${{ fromJSON('{ + "gpu-unit": "true", + "gpu-integration": "true", + "rocm-unit": "true", + "rocm-decoupled": "true" + }')[inputs.flavor] || 'false' }} container_resource_option: >- - ${{ contains(inputs.flavor, 'gpu') - && '--shm-size 2g --runtime=nvidia --gpus all --privileged' - || '--privileged' }} + ${{ fromJSON('{ + "gpu-unit": "--shm-size 2g --runtime=nvidia --gpus all --privileged", + "gpu-integration": "--shm-size 2g --runtime=nvidia --gpus all --privileged", + "rocm-unit": "--device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 64g --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --privileged", + "rocm-decoupled": "--device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 64g --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --privileged" + }')[inputs.flavor] || '--privileged' }} + + # ROCm-specific parameters + decoupled_mode: ${{ contains(inputs.flavor, 'decoupled') }} + requirements_file: >- + ${{ fromJSON('{ + "rocm-unit": "src/dependencies/requirements/requirements_rocm_jax_0_9_1.txt", + "rocm-decoupled": "src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt" + }')[inputs.flavor] || '' }} # Metadata base_image: ${{ inputs.base_image }} diff --git a/.github/workflows/upstream_sync.yml b/.github/workflows/upstream_sync.yml new file mode 100644 index 0000000000..d19fe20053 --- /dev/null +++ b/.github/workflows/upstream_sync.yml @@ -0,0 +1,51 @@ +name: Upstream Sync + +on: + workflow_dispatch: {} + schedule: + - cron: '30 4 * * *' # Daily 04:30 UTC + +permissions: + contents: write + +jobs: + sync: + runs-on: ubuntu-latest + if: github.ref == 'refs/heads/main' + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + persist-credentials: false + + - name: Configure git + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + + - name: Fetch upstream + run: | + git remote add upstream https://github.com/google/maxtext.git 2>/dev/null || true + git fetch upstream main + git checkout main + git pull --ff-only origin main || true + + - name: Merge upstream/main + id: merge_step + run: | + set -e + git merge --no-edit upstream/main || { echo "::error::Merge conflict - resolve manually"; git merge --abort || true; exit 1; } + if git diff --quiet origin/main..main; then echo "no_changes=true" >> $GITHUB_OUTPUT; else echo "no_changes=false" >> $GITHUB_OUTPUT; fi + + - name: Push (if changed) + if: steps.merge_step.outputs.no_changes == 'false' + env: + PAT: ${{ secrets.UPSTREAM_SYNC_TOKEN }} + run: | + [ -z "$PAT" ] && echo "::error::Missing UPSTREAM_SYNC_TOKEN secret" && exit 1 + git push https://x-access-token:$PAT@github.com/${{ github.repository }}.git main + + - name: Result + run: | + if [ "${{ steps.merge_step.outputs.no_changes }}" = "true" ]; then echo "Up to date"; else echo "Synced upstream"; fi diff --git a/.github/workflows/utils/install_te_rocm_wheel.py b/.github/workflows/utils/install_te_rocm_wheel.py new file mode 100644 index 0000000000..8f601884fd --- /dev/null +++ b/.github/workflows/utils/install_te_rocm_wheel.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +""" +ROCm CI helper: +- Detect MI300 vs MI355 +- Prefer wheel from this repo's 'te-rocm-wheels' release assets +- Fallback to pinned ROCm/maxtext release assets (arch-specific) + +This script only downloads the wheel into the current working directory. +The caller should then install it (e.g. `uv pip install transformer_engine-*.whl`). +""" + +from __future__ import annotations + +import argparse +import json +import os +import re +import subprocess +import urllib.error +import urllib.request + + +def _run(cmd: str) -> str: + try: + return subprocess.check_output(["bash", "-lc", cmd], text=True, stderr=subprocess.STDOUT) + except (subprocess.CalledProcessError, FileNotFoundError, OSError): + return "" + + +def detect_arch() -> str: + """ + Return wheel arch selector: 'mi300' or 'mi355'. + + Notes: + - MI300 family commonly reports gfx942/gfx941. + - MI350/MI355 family commonly reports gfx950 and product strings like MI350X. + """ + override = os.environ.get("TE_WHEEL_ARCH", "").strip().lower() + if override in {"mi300", "mi355"}: + return override + + rocm_smi = _run("command -v rocm-smi >/dev/null 2>&1 && rocm-smi --showproductname || true") or _run( + "[ -x /opt/rocm/bin/rocm-smi ] && /opt/rocm/bin/rocm-smi --showproductname || true" + ) + rocminfo = _run("command -v rocminfo >/dev/null 2>&1 && rocminfo || true") or _run( + "[ -x /opt/rocm/bin/rocminfo ] && /opt/rocm/bin/rocminfo || true" + ) + + blob = f"{rocm_smi}\n{rocminfo}".lower() + gfxs = sorted(set(re.findall(r"gfx[0-9a-f]+", blob))) + + # Prefer explicit gfx IDs when available. + if "gfx950" in gfxs: + return "mi355" + if "gfx942" in gfxs or "gfx941" in gfxs: + return "mi300" + + # Fall back to product string checks. + if "mi355" in blob or "mi350" in blob: + return "mi355" + if "mi300" in blob: + return "mi300" + + # Safe default. + return "mi355" + + +def _headers() -> dict[str, str]: + token = os.environ.get("GITHUB_TOKEN", "") + headers: dict[str, str] = { + "Accept": "application/vnd.github+json", + "User-Agent": "maxtext-ci", + } + if token: + headers["Authorization"] = f"Bearer {token}" + return headers + + +def download(url: str, out_name: str) -> None: + req = urllib.request.Request(url, headers=_headers()) + with urllib.request.urlopen(req) as r, open(out_name, "wb") as f: + f.write(r.read()) + print(f"[te wheel] downloaded {out_name}", flush=True) + + +def try_download_from_te_rocm_wheels(repo: str, arch: str) -> bool: + """Download from this repo's 'te-rocm-wheels' release tag if present.""" + api = f"https://api.github.com/repos/{repo}/releases/tags/te-rocm-wheels" + req = urllib.request.Request(api, headers=_headers()) + with urllib.request.urlopen(req) as r: + rel = json.loads(r.read().decode("utf-8")) + + assets = rel.get("assets", []) + name_re = re.compile(rf"^transformer_engine-.*-1\.{arch}-cp312-cp312-linux_x86_64\.whl$") + matches = [a for a in assets if name_re.match(a.get("name", ""))] + if not matches: + return False + + # Rolling tag keeps many wheels; select newest matching asset. + hit = max(matches, key=lambda a: a.get("created_at", "")) + print( + "[te wheel] selected latest te-rocm-wheels asset: " + f"{hit.get('name', '')} (created_at={hit.get('created_at', 'unknown')})", + flush=True, + ) + + download(hit["browser_download_url"], hit["name"]) + return True + + +def main(argv: list[str] | None = None) -> int: + """Entry point.""" + parser = argparse.ArgumentParser(add_help=True) + parser.add_argument( + "--print-arch", + action="store_true", + help="Print resolved wheel arch (mi300/mi355) and exit.", + ) + args = parser.parse_args(argv) + + if args.print_arch: + print(detect_arch(), flush=True) + return 0 + + repo = os.environ.get("GITHUB_REPOSITORY", "") + if not repo: + print("[te wheel] GITHUB_REPOSITORY not set; skipping.", flush=True) + return 0 + + arch = detect_arch() + print(f"[te wheel] arch={arch}", flush=True) + + # 1) Prefer: this repo's te-rocm-wheels assets. + try: + if try_download_from_te_rocm_wheels(repo, arch): + return 0 + print(f"[te wheel] no te-rocm-wheels asset for arch={arch}", flush=True) + except urllib.error.HTTPError as e: + print(f"[te wheel] te-rocm-wheels not available ({e.code})", flush=True) + except (urllib.error.URLError, json.JSONDecodeError, KeyError, ValueError) as e: + print(f"[te wheel] te-rocm-wheels lookup failed ({e})", flush=True) + + # 2) Fallback: pinned ROCm/maxtext release assets. + arch_tag = f"1.{arch}" + pinned_name = f"transformer_engine-2.8.0.dev0+2776c337-{arch_tag}-cp312-cp312-linux_x86_64.whl" + pinned = "https://github.com/ROCm/maxtext/releases/download/rocm-maxtext-v0.1.1/" f"{pinned_name}" + download(pinned, pinned_name) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/.github/workflows/utils/te_wheels_release.py b/.github/workflows/utils/te_wheels_release.py new file mode 100644 index 0000000000..a2ae1368e2 --- /dev/null +++ b/.github/workflows/utils/te_wheels_release.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +""" +Publish and prune ROCm TransformerEngine wheel assets on GitHub Releases. + +Intended for use in GitHub Actions. Requires: +- GITHUB_TOKEN +- GITHUB_REPOSITORY (owner/repo) +""" + +from __future__ import annotations + +import argparse +import datetime as dt +import json +import os +import sys +import urllib.error +import urllib.parse +import urllib.request + +API = "https://api.github.com" + + +def _env_token() -> str: + token = os.environ.get("GITHUB_TOKEN") + if not token: + raise SystemExit("GITHUB_TOKEN is not set.") + return token + + +def _env_repo() -> tuple[str, str]: + repo = os.environ.get("GITHUB_REPOSITORY") + if not repo or "/" not in repo: + raise SystemExit("GITHUB_REPOSITORY is not set (expected 'owner/repo').") + owner, name = repo.split("/", 1) + return owner, name + + +def _headers(token: str) -> dict[str, str]: + return { + "Authorization": f"Bearer {token}", + "Accept": "application/vnd.github+json", + "User-Agent": "maxtext-ci", + } + + +def _request_json(method: str, url: str, token: str, body: dict | None = None): + data = None + if body is not None: + data = json.dumps(body).encode("utf-8") + req = urllib.request.Request(url, data=data, method=method, headers=_headers(token)) + with urllib.request.urlopen(req) as r: + raw = r.read() + return json.loads(raw.decode("utf-8")) if raw else None + + +def _request_raw(method: str, url: str, token: str, data: bytes, content_type: str): + h = _headers(token) + h["Content-Type"] = content_type + req = urllib.request.Request(url, data=data, method=method, headers=h) + with urllib.request.urlopen(req) as r: + return r.read() + + +def get_or_create_release(tag: str, title: str, body: str, prerelease: bool) -> dict: + """Get a release by tag, or create it if missing. + + Args: + tag: Release tag (e.g. 'te-rocm-wheels'). + title: Release title. + body: Release body text. + prerelease: Whether to mark the release as a prerelease. + + Returns: + The GitHub release object JSON. + """ + token = _env_token() + owner, name = _env_repo() + try: + rel = _request_json("GET", f"{API}/repos/{owner}/{name}/releases/tags/{tag}", token) + if rel: + return rel + except urllib.error.HTTPError as e: + if e.code != 404: + raise + return _request_json( + "POST", + f"{API}/repos/{owner}/{name}/releases", + token, + {"tag_name": tag, "name": title, "body": body, "prerelease": prerelease}, + ) + + +def upload_asset(tag: str, title: str, body: str, file_path: str, prerelease: bool) -> None: + """Upload (replace) a release asset under the given tag. + + If an asset with the same filename already exists, it is deleted first. + """ + token = _env_token() + owner, name = _env_repo() + + rel = get_or_create_release(tag, title, body, prerelease) + release_id = rel["id"] + upload_url = rel["upload_url"].split("{", 1)[0] + + assets = _request_json("GET", f"{API}/repos/{owner}/{name}/releases/{release_id}", token)["assets"] + file_name = os.path.basename(file_path) + for a in assets: + if a.get("name") == file_name: + _request_json("DELETE", f"{API}/repos/{owner}/{name}/releases/assets/{a['id']}", token) + + with open(file_path, "rb") as f: + data = f.read() + up = f"{upload_url}?{urllib.parse.urlencode({'name': file_name})}" + _request_raw("POST", up, token, data, "application/octet-stream") + print(f"Uploaded {file_name} to release tag {tag}", flush=True) + + +def prune_assets(tag: str, keep_days: int) -> None: + """Delete assets older than `keep_days` from the given release tag.""" + token = _env_token() + owner, name = _env_repo() + try: + rel = _request_json("GET", f"{API}/repos/{owner}/{name}/releases/tags/{tag}", token) + except urllib.error.HTTPError as e: + if e.code == 404: + print(f"No release for tag {tag}; skipping asset prune.", flush=True) + return + raise + + cutoff = dt.datetime.now(dt.timezone.utc) - dt.timedelta(days=keep_days) + pruned = 0 + for a in rel.get("assets", []): + created = dt.datetime.fromisoformat(a["created_at"].replace("Z", "+00:00")) + if created < cutoff: + _request_json("DELETE", f"{API}/repos/{owner}/{name}/releases/assets/{a['id']}", token) + pruned += 1 + print(f"Pruned old asset: {a['name']} (created_at={a['created_at']})", flush=True) + print(f"Asset prune complete for {tag}. Deleted {pruned} assets older than {keep_days} days.", flush=True) + + +def prune_releases(prefix: str, keep_days: int) -> None: + """Delete releases with tag names starting with `prefix` older than `keep_days` days.""" + token = _env_token() + owner, name = _env_repo() + cutoff = dt.datetime.now(dt.timezone.utc) - dt.timedelta(days=keep_days) + + page = 1 + deleted = 0 + while True: + rels = _request_json("GET", f"{API}/repos/{owner}/{name}/releases?per_page=100&page={page}", token) + if not rels: + break + for rel in rels: + tag = rel.get("tag_name", "") + if not tag.startswith(prefix): + continue + created = dt.datetime.fromisoformat(rel["created_at"].replace("Z", "+00:00")) + if created < cutoff: + _request_json("DELETE", f"{API}/repos/{owner}/{name}/releases/{rel['id']}", token) + deleted += 1 + print(f"Deleted old release {tag} (created_at={rel['created_at']})", flush=True) + page += 1 + print(f"Release prune complete. Deleted {deleted} releases older than {keep_days} days.", flush=True) + + +def main(argv: list[str]) -> int: + p = argparse.ArgumentParser() + sub = p.add_subparsers(dest="cmd", required=True) + + up = sub.add_parser("upload") + up.add_argument("--tag", required=True) + up.add_argument("--title", required=True) + up.add_argument("--body", required=True) + up.add_argument("--file", required=True) + prg = up.add_mutually_exclusive_group() + prg.add_argument("--prerelease", dest="prerelease", action="store_true") + prg.add_argument("--no-prerelease", dest="prerelease", action="store_false") + up.set_defaults(prerelease=True) + + pa = sub.add_parser("prune-assets") + pa.add_argument("--tag", required=True) + pa.add_argument("--keep-days", type=int, required=True) + + pr = sub.add_parser("prune-releases") + pr.add_argument("--prefix", required=True) + pr.add_argument("--keep-days", type=int, required=True) + + args = p.parse_args(argv) + + if args.cmd == "upload": + upload_asset(args.tag, args.title, args.body, args.file, prerelease=args.prerelease) + return 0 + if args.cmd == "prune-assets": + prune_assets(args.tag, args.keep_days) + return 0 + if args.cmd == "prune-releases": + prune_releases(args.prefix, args.keep_days) + return 0 + raise AssertionError("unreachable") + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv[1:])) diff --git a/src/dependencies/requirements/requirements_decoupled_jax_0_7.1.txt b/src/dependencies/requirements/requirements_decoupled_jax_0_7.1.txt deleted file mode 100644 index 8f904a3641..0000000000 --- a/src/dependencies/requirements/requirements_decoupled_jax_0_7.1.txt +++ /dev/null @@ -1,44 +0,0 @@ -absl_py>=2.3.1 -aqtp>=0.9.0 -chex>=0.1.90 -datasets>=4.2.0 -etils>=1.13.0 -evaluate>=0.4.6 -flax -grain>=0.2.12 -grpcio>=1.75.1 -huggingface_hub>=0.35.3 -jax==0.7.1 -jaxtyping>=0.3.3 -jsonlines>=4.0.0 -matplotlib>=3.10.3 -ml_collections>=1.1.0 -ml_dtypes>=0.5.3 -nltk>=3.9.2 -numpy>=2.0.2 -omegaconf>=2.3.0 -optax>=0.2.6 -orbax-checkpoint>=0.11.25 -pandas>=2.3.3 -parameterized==0.9.0 -pathwaysutils>=0.1.7 -pillow>=11.3.0 -protobuf>=5.29.5 -psutil>=7.0.0 -pytest>=8.4.1 -PyYAML>=6.0.3 -Requests>=2.32.5 -qwix>=0.1.1 -safetensors>=0.6.2 -sentencepiece>=0.2.1 -setuptools>=80.9.0 -tabulate>=0.9.0 -tensorflow>=2.19.1 -tensorflow_text>=2.19.0 -tensorflow_datasets>=4.9.9 -tensorstore>=0.1.76 -tiktoken>=0.12.0 -tqdm>=4.67.1 -transformers>=4.57.0 -urllib3>=2.5.0 -git+https://github.com/google/tunix.git diff --git a/src/dependencies/requirements/requirements_decoupled_jax_0_9_1.txt b/src/dependencies/requirements/requirements_decoupled_jax_0_9_1.txt new file mode 100644 index 0000000000..20cef3eba3 --- /dev/null +++ b/src/dependencies/requirements/requirements_decoupled_jax_0_9_1.txt @@ -0,0 +1,39 @@ +absl-py +aqtp +array-record +datasets +flax +grain[parquet] +huggingface_hub +jax==0.9.1 +jaxtyping +jsonlines +ml-collections +numpy +omegaconf +optax +orbax-checkpoint +pathwaysutils +pillow +pre-commit +protobuf +pydantic +pyink +pylint +pytest +parameterized +pytype +sentencepiece +tensorboardx +tensorflow-datasets +tensorflow-text<2.20 +tensorflow<2.20 +tiktoken +numba>=0.59.0 +google-tunix +tokamax @ git+https://github.com/openxla/tokamax.git@69b328121c3ee120d8e54cf26d1565cea189617f +transformers>=4.57.3,<5 +qwix +mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip +chex>=0.1.91 +drjax>=0.1.4 diff --git a/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt b/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt new file mode 100644 index 0000000000..95479d30be --- /dev/null +++ b/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_9_1.txt @@ -0,0 +1,46 @@ +absl-py +aqtp +array-record +datasets +flax +grain[parquet] +huggingface_hub + +jax==0.9.1 + +# ROCm JAX wheels +jax-rocm7-plugin==0.9.1.* + +math-verify + +jaxtyping +jsonlines +ml-collections +numpy +omegaconf +optax +orbax-checkpoint +pathwaysutils +pillow +pre-commit +protobuf +pydantic +pyink +pylint +pytest +parameterized +pytype +sentencepiece +tensorboardx +tensorflow-datasets +tensorflow-text<2.20 +tensorflow<2.20 +tiktoken +numba>=0.59.0 +google-tunix +tokamax @ git+https://github.com/openxla/tokamax.git@69b328121c3ee120d8e54cf26d1565cea189617f +transformers>=4.57.3,<5 +qwix +mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip +chex>=0.1.91 +drjax>=0.1.4 diff --git a/src/dependencies/requirements/requirements_rocm_benchmark.txt b/src/dependencies/requirements/requirements_rocm_benchmark.txt new file mode 100644 index 0000000000..7dfded642e --- /dev/null +++ b/src/dependencies/requirements/requirements_rocm_benchmark.txt @@ -0,0 +1,52 @@ +absl-py +aqtp +array-record +datasets +flax +grain[parquet] +huggingface_hub +jaxtyping +jsonlines +ml-collections +numpy +omegaconf +optax +orbax-checkpoint +pathwaysutils +pillow +pre-commit +protobuf +pydantic +pyink +pylint +pytest +parameterized +pytype +sentencepiece +tensorboardx +tensorflow-datasets +tensorflow-text<2.20 +tensorflow<2.20 +tiktoken +numba>=0.59.0 +google-tunix +tokamax @ git+https://github.com/openxla/tokamax.git@69b328121c3ee120d8e54cf26d1565cea189617f +transformers>=4.57.3,<5 +qwix +mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip + + +# GCP / Cloud deps (restored for non-decoupled ROCm runs) +cloud-accelerator-diagnostics +cloud-tpu-diagnostics +gcsfs +google-api-python-client +google-cloud-aiplatform +google-cloud-mldiagnostics +google-cloud-monitoring +ml-goodput-measurement +tensorboard-plugin-profile +google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip +chex>=0.1.91 +drjax>=0.1.4 + diff --git a/src/dependencies/requirements/requirements_rocm_jax_0_9_1.txt b/src/dependencies/requirements/requirements_rocm_jax_0_9_1.txt new file mode 100644 index 0000000000..b32d6d1380 --- /dev/null +++ b/src/dependencies/requirements/requirements_rocm_jax_0_9_1.txt @@ -0,0 +1,60 @@ +absl-py +aqtp +array-record +datasets +flax +grain[parquet] +huggingface_hub + +jax==0.9.1 + +# ROCm JAX wheels +jax-rocm7-plugin==0.9.1.* + +math-verify + + +jaxtyping +jsonlines +ml-collections +numpy +omegaconf +optax +orbax-checkpoint +pathwaysutils +pillow +pre-commit +protobuf +pydantic +pyink +pylint +pytest +parameterized +pytype +sentencepiece +tensorboardx +tensorflow-datasets +tensorflow-text<2.20 +tensorflow<2.20 +tiktoken +numba>=0.59.0 +google-tunix +tokamax @ git+https://github.com/openxla/tokamax.git@69b328121c3ee120d8e54cf26d1565cea189617f +transformers>=4.57.3,<5 +qwix +mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip + + +# GCP / Cloud deps (restored for non-decoupled ROCm runs) +cloud-accelerator-diagnostics +cloud-tpu-diagnostics +gcsfs +google-api-python-client +google-cloud-aiplatform +google-cloud-mldiagnostics +google-cloud-monitoring +ml-goodput-measurement +tensorboard-plugin-profile +google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip +chex>=0.1.91 +drjax>=0.1.4 diff --git a/src/maxtext/configs/gpu/models/gemma3-4b-rocm.yml b/src/maxtext/configs/gpu/models/gemma3-4b-rocm.yml new file mode 100644 index 0000000000..92a5cc8d80 --- /dev/null +++ b/src/maxtext/configs/gpu/models/gemma3-4b-rocm.yml @@ -0,0 +1,38 @@ +# model config for gemma3-4b-rocm +base_config: "base.yml" + +steps: 15 +hardware: "gpu" +model_name: "gemma3-4b" +dataset_type: "synthetic" +enable_checkpointing: False +max_segments_per_seq: 32 + +base_num_decoder_layers: 34 +base_emb_dim: 2560 +base_num_query_heads: 8 +base_num_kv_heads: 4 +base_mlp_dim: 10240 +head_dim: 256 +mlp_activations: ["gelu","linear"] +vocab_size: 262_144 +decoder_block: "gemma3" +normalization_layer_epsilon: 1e-6 +logits_via_embedding: True +sliding_window_size: 1024 +use_post_attn_norm: true +use_post_ffw_norm: true +local_rope_max_timescale: 10_000 +rope_max_timescale: 1_000_000 +rope_linear_scaling_factor: 8.0 + +# Multimodal flags (need to set use_multimodal=true) +image_size_for_vit: 896 +num_channels_for_vit: 3 +patch_size_for_vit: 14 +conv_stride_for_vit: 14 +hidden_size_for_vit: 1152 +intermediate_size_for_vit: 4304 +num_hidden_layers_for_vit: 27 +num_attention_heads_for_vit: 16 + diff --git a/tests/integration/train_tests.py b/tests/integration/train_tests.py index fc27753abe..9d10f62686 100644 --- a/tests/integration/train_tests.py +++ b/tests/integration/train_tests.py @@ -18,6 +18,7 @@ import pytest import jax +from jax._src import test_util as jtu from absl.testing import absltest from maxtext.common.gcloud_stub import is_decoupled from maxtext.trainers.pre_train.train import main as train_main diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 2eea9025c7..e4ad9f1cd9 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -122,11 +122,23 @@ def get_test_base_output_directory(cloud_path=None): return cloud_path or "gs://runner-maxtext-logs" +def get_test_config_path_for(relative_path: str): + """Return absolute path for a config under the configs directory. + + Uses the same decoupled-vs-non-decoupled selection logic as + `get_test_config_path()` when `relative_path` is `base.yml`. + """ + if relative_path == "base.yml": + return get_test_config_path() + return os.path.join(MAXTEXT_CONFIGS_DIR, relative_path) + + __all__ = [ "get_test_base_output_directory", "get_decoupled_parallelism_overrides", "is_rocm_backend", "get_test_config_path", "get_post_train_test_config_path", + "get_test_config_path_for", "get_test_dataset_path", ]