diff --git a/.github/workflows/build_and_test_maxtext.yml b/.github/workflows/build_and_test_maxtext.yml index d54f98fbec..92cd54386a 100644 --- a/.github/workflows/build_and_test_maxtext.yml +++ b/.github/workflows/build_and_test_maxtext.yml @@ -20,16 +20,23 @@ on: pull_request: workflow_call: workflow_dispatch: + inputs: + rocm_only: + description: 'Run only ROCm jobs (manual runs only)' + type: boolean + required: false + default: false schedule: - # Run the job every 4 hours - - cron: '0 */4 * * *' + - cron: '0 3 * * *' # Daily 03:00 UTC concurrency: - # Dedup pull requests (canceling previous runs of the same workflow for same PR), and scheduled runs but nothing else + # Cancel previous runs for same PR (all actors), scheduled runs, + # and manual runs per (branch + actor). group: > ${{ github.event_name == 'pull_request' && format('{0}-pr-{1}', github.workflow, github.event.pull_request.number) || github.event_name == 'schedule' && format('{0}-schedule', github.workflow) || + github.event_name == 'workflow_dispatch' && format('{0}-manual-{1}-{2}', github.workflow, github.ref, github.actor) || github.run_id }} cancel-in-progress: true @@ -118,18 +125,16 @@ jobs: build_and_upload_maxtext_package: needs: analyze_code_changes # Run if either tests or notebooks need to run - if: | - needs.analyze_code_changes.outputs.run_tests == 'true' || - needs.analyze_code_changes.outputs.run_notebooks == 'true' + if: ${{ vars.ROCM_ONLY == 'true' || needs.analyze_code_changes.outputs.run_tests == 'true' || needs.analyze_code_changes.outputs.run_notebooks == 'true' }} uses: ./.github/workflows/build_package.yml with: - device_type: tpu - device_name: v4-8 - cloud_runner: linux-x86-n2-16-buildkit + device_type: ${{ vars.ROCM_ONLY == 'true' && 'rocm' || 'tpu' }} + device_name: ${{ vars.ROCM_ONLY == 'true' && 'mi355' || 'v4-8' }} + cloud_runner: ${{ vars.ROCM_ONLY == 'true' && 'linux-x86-64-4gpu-amd' || 'linux-x86-n2-16-buildkit' }} maxtext_jupyter_notebooks: needs: build_and_upload_maxtext_package - if: needs.analyze_code_changes.outputs.run_notebooks == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_notebooks == 'true' }} uses: ./.github/workflows/run_jupyter_notebooks.yml strategy: fail-fast: false @@ -145,7 +150,7 @@ jobs: tpu-tests: name: ${{ matrix.flavor }} tests needs: [build_and_upload_maxtext_package] - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }} uses: ./.github/workflows/run_tests_coordinator.yml strategy: fail-fast: false @@ -160,7 +165,7 @@ jobs: gpu-tests: name: ${{ matrix.flavor }} tests needs: [build_and_upload_maxtext_package] - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }} strategy: fail-fast: false matrix: @@ -175,7 +180,7 @@ jobs: cpu-tests: name: ${{ matrix.flavor }} tests needs: [build_and_upload_maxtext_package] - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }} uses: ./.github/workflows/run_tests_coordinator.yml strategy: fail-fast: false @@ -189,7 +194,7 @@ jobs: maxtext_tpu_pathways_unit_tests: needs: build_and_upload_maxtext_package - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }} uses: ./.github/workflows/run_pathways_tests.yml strategy: fail-fast: false @@ -207,7 +212,7 @@ jobs: maxtext_tpu_pathways_integration_tests: needs: build_and_upload_maxtext_package - if: needs.analyze_code_changes.outputs.run_tests == 'true' + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && needs.analyze_code_changes.outputs.run_tests == 'true' }} uses: ./.github/workflows/run_pathways_tests.yml strategy: fail-fast: false @@ -223,9 +228,39 @@ jobs: is_scheduled_run: ${{ github.event_name == 'schedule' }} maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} + rocm-tests: + name: ${{ matrix.flavor }} tests + needs: [build_and_upload_maxtext_package] + if: ${{ vars.ROCM_ONLY != 'true' && needs.analyze_code_changes.outputs.run_tests == 'true' }} + uses: ./.github/workflows/run_tests_coordinator.yml + strategy: + fail-fast: false + matrix: + flavor: [rocm-unit] + with: + flavor: ${{ matrix.flavor }} + base_image: 'rocm-placeholder' + is_scheduled_run: ${{ github.event_name == 'schedule' }} + maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} + + rocm-decoupled-tests: + name: ${{ matrix.flavor }} tests + needs: [build_and_upload_maxtext_package] + if: ${{ vars.ROCM_ONLY == 'true' || (github.event_name == 'workflow_dispatch' && inputs.rocm_only) || needs.analyze_code_changes.outputs.run_tests == 'true' }} + uses: ./.github/workflows/run_tests_coordinator.yml + strategy: + fail-fast: false + matrix: + flavor: [rocm-decoupled] + with: + flavor: ${{ matrix.flavor }} + base_image: 'rocm-placeholder' + is_scheduled_run: ${{ github.event_name == 'schedule' }} + maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }} + all_tests_passed: name: All Required Tests Passed - needs: [tpu-tests, gpu-tests, cpu-tests, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests] + needs: [analyze_code_changes, build_and_upload_maxtext_package, tpu-tests, gpu-tests, cpu-tests, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, rocm-tests, rocm-decoupled-tests] if: always() runs-on: ubuntu-latest steps: @@ -243,6 +278,8 @@ jobs: echo "CPU Tests (Matrix) result: ${NEEDS_CPU_TESTS_RESULT}" echo "Pathways Unit result: ${NEEDS_MAXTEXT_TPU_PATHWAYS_UNIT_TESTS_RESULT}" echo "Pathways Integration result: ${NEEDS_MAXTEXT_TPU_PATHWAYS_INTEGRATION_TESTS_RESULT}" + echo "ROCm Tests (Matrix) result: ${NEEDS_ROCM_TESTS_RESULT}" + echo "ROCm Decoupled Tests (Matrix) result: ${NEEDS_ROCM_DECOUPLED_TESTS_RESULT}" # Fail only if any job failed or was cancelled (skipped is OK) if [ "${{ contains(needs.*.result, 'failure') }}" == "true" ] || [ "${{ contains(needs.*.result, 'cancelled') }}" == "true" ]; then @@ -259,11 +296,13 @@ jobs: NEEDS_GPU_TESTS_RESULT: ${{ needs.gpu-tests.result }} NEEDS_MAXTEXT_TPU_PATHWAYS_UNIT_TESTS_RESULT: ${{ needs.maxtext_tpu_pathways_unit_tests.result }} NEEDS_MAXTEXT_TPU_PATHWAYS_INTEGRATION_TESTS_RESULT: ${{ needs.maxtext_tpu_pathways_integration_tests.result }} + NEEDS_ROCM_TESTS_RESULT: ${{ needs.rocm-tests.result }} + NEEDS_ROCM_DECOUPLED_TESTS_RESULT: ${{ needs.rocm-decoupled-tests.result }} all_notebooks_passed: name: All Notebooks Passed needs: [analyze_code_changes, build_and_upload_maxtext_package, maxtext_jupyter_notebooks] - if: always() + if: ${{ vars.ROCM_ONLY != 'true' && !(github.event_name == 'workflow_dispatch' && inputs.rocm_only) && always() }} runs-on: ubuntu-latest steps: - name: Check notebooks results @@ -291,7 +330,7 @@ jobs: notify_failure: name: Notify failed build # creates an issue or modifies last open existing issue for failed build - needs: [tpu-tests, gpu-tests, cpu-tests, maxtext_jupyter_notebooks, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests] + needs: [tpu-tests, gpu-tests, cpu-tests, maxtext_jupyter_notebooks, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, rocm-tests, rocm-decoupled-tests] if: ${{ always() }} runs-on: ubuntu-latest permissions: diff --git a/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml new file mode 100644 index 0000000000..9b90eee2a0 --- /dev/null +++ b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml @@ -0,0 +1,202 @@ +name: Build ROCm TransformerEngine wheel (weekly) + +on: + workflow_dispatch: + schedule: + # Weekly at night (02:00 UTC every Monday), 2 hours ahead of scheduled tests. + - cron: "0 2 * * 1" + +permissions: + contents: write + +jobs: + build_upload_prune: + # AMD GPU runner (GitHub-hosted large runner label). + runs-on: linux-x86-64-4gpu-amd + container: + image: ghcr.io/rocm/jax-base-ubu24.rocm720:latest + options: >- + --device=/dev/kfd --device=/dev/dri --group-add video + --ipc=host --shm-size 64g + --cap-add=SYS_PTRACE --security-opt seccomp=unconfined + --privileged + env: + XLA_PYTHON_CLIENT_MEM_FRACTION: "0.9" + NVTE_FUSED_ATTN_AOTRITON: "0" + env: + TE_WHEELS_KEEP_DAYS: "21" + + steps: + - name: Checkout + uses: actions/checkout@v5 + + - name: Setup build environment (deps + venv) + shell: bash + run: | + set -euo pipefail + apt-get update + apt-get install -y --no-install-recommends git build-essential python3-dev + python3 -m pip install -U uv + python3 -m uv venv --seed + source .venv/bin/activate + uv pip install -U pip setuptools wheel pybind11 cmake + + - name: Install ROCm JAX/JAXlib wheels (build against CI stack) + shell: bash + run: | + set -euo pipefail + source .venv/bin/activate + uv pip install -r src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt + + - name: Install PyTorch ROCm (build-time dep for aiter JIT) + shell: bash + run: | + set -euo pipefail + source .venv/bin/activate + uv pip install torch --index-url https://download.pytorch.org/whl/rocm7.2 + + - name: Detect ROCm version and Python tag + shell: bash + run: | + set -euo pipefail + source .venv/bin/activate + + # Detect ROCm version number from JAX backend string when available (e.g. 'rocm 70200'). + ROCM_NUM="$([ -f /opt/rocm/.info/version ] && head -n1 /opt/rocm/.info/version | tr -d ' \t\r' || echo unknown)" + echo "Detected ROCm version: ${ROCM_NUM}" + echo "ROCM_NUM=${ROCM_NUM}" >> "${GITHUB_ENV}" + + PYTAG="cp$(python3 -c 'import sys; print(f"{sys.version_info.major}{sys.version_info.minor}")')" + if [ "${PYTAG}" != "cp312" ]; then + echo "Expected Python 3.12 (cp312) for ROCm CI wheels, got ${PYTAG}." + exit 1 + fi + echo "PYTAG=${PYTAG}" >> "${GITHUB_ENV}" + echo "REL_SCRIPT=.github/workflows/utils/te_wheels_release.py" >> "${GITHUB_ENV}" + + - name: Clone ROCm/TransformerEngine (dev) + shell: bash + run: | + set -euo pipefail + rm -rf TransformerEngine + git clone --recursive --branch dev https://github.com/ROCm/TransformerEngine.git + cd TransformerEngine + git submodule update --init --recursive + TE_SHA="$(git rev-parse --short=12 HEAD)" + echo "TE_SHA=${TE_SHA}" >> "${GITHUB_ENV}" + + - name: Select TE wheel arch for runner (mi300/mi355) + shell: bash + run: | + set -euo pipefail + source .venv/bin/activate + + TE_WHEEL_ARCH="$(python3 .github/workflows/utils/install_te_rocm_wheel.py --print-arch)" + echo "Resolved TE wheel arch for runner: ${TE_WHEEL_ARCH}" + echo "TE_WHEEL_ARCH=${TE_WHEEL_ARCH}" >> "${GITHUB_ENV}" + + # Build ONLY for the ROCm arch present on this CI runner (mi300 or mi355). + if [ "${TE_WHEEL_ARCH}" = "mi355" ]; then + SELECTOR="mi355" + GFX_ARCH="gfx950" + else + SELECTOR="mi300" + GFX_ARCH="gfx942;gfx941" + fi + echo "SELECTOR=${SELECTOR}" >> "${GITHUB_ENV}" + echo "GFX_ARCH=${GFX_ARCH}" >> "${GITHUB_ENV}" + + - name: Build TE wheel + shell: bash + run: | + set -euo pipefail + source .venv/bin/activate + + chmod +x "${REL_SCRIPT}" || true + + export USE_ROCM=1 + export HIP_PATH=/opt/rocm + export NVTE_FRAMEWORK=jax + export CMAKE_BUILD_PARALLEL_LEVEL=64 + export NVTE_USE_ROCM=1 + export NVTE_FUSED_ATTN_AOTRITON=0 + export NVTE_BUILD_MAX_JOBS=180 + + echo "=== Building TE wheel for ${SELECTOR} (gfx=${GFX_ARCH}) ===" + pushd TransformerEngine >/dev/null + rm -rf build dist + export PYTHONPATH="$(pwd)/3rdparty/hipify_torch${PYTHONPATH:+:${PYTHONPATH}}" + export PYTORCH_ROCM_ARCH="${GFX_ARCH}" + export NVTE_ROCM_ARCH="${GFX_ARCH}" + python3 setup.py bdist_wheel + wheel_path="$( + python3 -c 'import glob; m=sorted(glob.glob("dist/transformer_engine-*.whl")); print(m[0] if m else "")' + )" + if [ -z "${wheel_path}" ]; then + echo "No wheel produced in dist/ (selector=${SELECTOR})." + exit 1 + fi + wheel_base="$(basename "${wheel_path}")" + if [[ "${wheel_base}" == *"-1.${SELECTOR}-${PYTAG}-${PYTAG}-linux_x86_64.whl" ]]; then + asset_name="${wheel_base}" + else + asset_name="${wheel_base/-${PYTAG}-${PYTAG}-linux_x86_64.whl/-1.${SELECTOR}-${PYTAG}-${PYTAG}-linux_x86_64.whl}" + if [ "${asset_name}" = "${wheel_base}" ]; then + echo "Failed to rename wheel for selector=${SELECTOR}: ${wheel_base}" + exit 1 + fi + fi + cp -f "${wheel_path}" "../${asset_name}" + popd >/dev/null + + ls -lh "${asset_name}" + echo "TE_WHEEL_FILE=${asset_name}" >> "${GITHUB_ENV}" + + - name: Upload wheel to rolling release tag (te-rocm-wheels) + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + set -euo pipefail + source .venv/bin/activate + python3 "${REL_SCRIPT}" upload --no-prerelease --tag "te-rocm-wheels" --title "ROCm TransformerEngine wheels (latest)" --body "Rolling release for latest weekly-built ROCm TransformerEngine wheels used by CI." --file "${TE_WHEEL_FILE}" + + - name: Prune old assets from rolling tag + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + set -euo pipefail + source .venv/bin/activate + echo "Pruning rolling-tag assets older than ${TE_WHEELS_KEEP_DAYS} days" + python3 "${REL_SCRIPT}" prune-assets --tag "te-rocm-wheels" --keep-days "${TE_WHEELS_KEEP_DAYS}" + + - name: Publish wheel to dated weekly release tag + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + set -euo pipefail + source .venv/bin/activate + + DATE_UTC="$(date -u +%Y-%m-%d)" + WEEKLY_TAG="te-rocm-wheels-${DATE_UTC}-${TE_SHA}" + WEEKLY_TITLE="ROCm TransformerEngine wheels ${DATE_UTC} (TE ${TE_SHA})" + # Keep this YAML-safe (no unindented heredocs inside `run: |`). + WEEKLY_BODY="$( + printf '%s\n\nROCm: %s\nPython: %s\nArch: %s (gfx=%s)\n' \ + "Built from ROCm/TransformerEngine dev @ ${TE_SHA} on ${DATE_UTC}." \ + "${ROCM_NUM}" "${PYTAG}" "${SELECTOR}" "${GFX_ARCH}" + )" + + python3 "${REL_SCRIPT}" upload --no-prerelease --tag "${WEEKLY_TAG}" --title "${WEEKLY_TITLE}" --body "${WEEKLY_BODY}" --file "${TE_WHEEL_FILE}" + + - name: Prune old weekly releases + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + set -euo pipefail + source .venv/bin/activate + echo "Pruning weekly release pages older than ${TE_WHEELS_KEEP_DAYS} days" + python3 "${REL_SCRIPT}" prune-releases --prefix "te-rocm-wheels-" --keep-days "${TE_WHEELS_KEEP_DAYS}" diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml index 7817e2d678..9a21f5a06a 100644 --- a/.github/workflows/run_tests_against_package.yml +++ b/.github/workflows/run_tests_against_package.yml @@ -40,7 +40,7 @@ on: default: '' is_scheduled_run: required: true - type: string + type: boolean xla_python_client_mem_fraction: required: true type: string @@ -65,6 +65,17 @@ on: description: 'Git SHA to checkout if MaxText is not pre-installed' required: false type: string + decoupled_mode: + required: false + type: boolean + default: false + requirements_file: + required: false + type: string + default: '' + extra_pip_deps_file: + required: false + type: string # Flag to skip source checkout and wheel installation maxtext_installed: description: 'If false, maxtext_sha must be provided for checkout' @@ -75,19 +86,24 @@ permissions: contents: read jobs: run: - runs-on: ${{ inputs.cloud_runner != '' && inputs.cloud_runner || fromJson(format('["self-hosted", "{0}", "{1}"]', inputs.device_type, inputs.device_name)) }} + runs-on: ${{ inputs.cloud_runner != '' && inputs.cloud_runner || (inputs.device_type == 'rocm' && fromJson('["self-hosted","linux-x86-64-4gpu-amd"]')) || fromJson(format('["self-hosted", "{0}", "{1}"]', inputs.device_type, inputs.device_name)) }} + timeout-minutes: ${{ inputs.device_type == 'rocm' && 90 || 360 }} container: - image: gcr.io/tpu-prod-env-multipod/${{ inputs.base_image }} + image: ${{ inputs.device_type == 'rocm' && 'ghcr.io/rocm/jax-base-ubu24.rocm720:latest' || format('gcr.io/tpu-prod-env-multipod/{0}', inputs.base_image) }} env: XLA_PYTHON_CLIENT_MEM_FRACTION: ${{ inputs.xla_python_client_mem_fraction }} TF_FORCE_GPU_ALLOW_GROWTH: ${{ inputs.tf_force_gpu_allow_growth }} TPU_SKIP_MDS_QUERY: ${{ inputs.device_type == 'cpu' && '1' || '' }} + # ROCm installs the wheel with --no-deps plus a pinned requirements file; never use TPU wheel extras. MAXTEXT_PACKAGE_EXTRA: >- ${{ - !contains(inputs.pytest_marker, 'not post_training') && 'tpu-post-train' + inputs.device_type == 'rocm' && 'rocm' + || !contains(inputs.pytest_marker, 'not post_training') && 'tpu-post-train' || (inputs.device_type == 'cpu' && 'tpu' || inputs.device_type) }} ALLOW_MULTIPLE_LIBTPU_LOAD: ${{ inputs.device_type == 'cpu' && 'true' || '' }} # bypass /tmp/libtpu_lockfile check for cpu tests, which don't actually use accelerators (to allow concurrency) + DECOUPLE_GCLOUD: ${{ inputs.decoupled_mode && 'TRUE' || '' }} + LOCAL_GCLOUD_PROJECT: ${{ inputs.decoupled_mode && 'ci-decoupled' || '' }} options: ${{ inputs.container_resource_option }} steps: - name: Checkout MaxText @@ -104,20 +120,79 @@ jobs: if: ${{ !inputs.maxtext_installed }} shell: bash run: | + if [ "${{ inputs.device_type }}" = "rocm" ]; then + python3 -m pip install -U uv + fi python3 -m uv venv --seed source .venv/bin/activate maxtext_wheel=$(ls maxtext-*-py3-none-any.whl 2>/dev/null) echo "Installing ${maxtext_wheel} for ${MAXTEXT_PACKAGE_EXTRA}..." - uv pip install ${maxtext_wheel}[${MAXTEXT_PACKAGE_EXTRA}] --resolution=lowest - if [ "${MAXTEXT_PACKAGE_EXTRA}" == "tpu-post-train" ]; then - install_tpu_post_train_extra_deps + if [ "${{ inputs.device_type }}" = "rocm" ]; then + if [ -n "${{ inputs.requirements_file }}" ]; then + echo "Installing requirements from ${{ inputs.requirements_file }}" + uv pip install -r "${{ inputs.requirements_file }}" + fi + uv pip install ${maxtext_wheel} --no-deps + # When a requirements file is set (e.g. decoupled or rocm-unit JAX stacks), it already + # carries test/runtime pins; otherwise add GitHub-sourced pins (no TPU extra / scripts). + if [ -z "${{ inputs.requirements_file }}" ]; then + uv pip install -r src/dependencies/extra_deps/pre_train_github_deps.txt --no-deps + fi else - install_tpu_pre_train_extra_deps + if [ -n "${{ inputs.requirements_file }}" ]; then + echo "Installing requirements from ${{ inputs.requirements_file }}" + uv pip install -r "${{ inputs.requirements_file }}" + uv pip install ${maxtext_wheel} --no-deps + else + uv pip install ${maxtext_wheel}[${MAXTEXT_PACKAGE_EXTRA}] --resolution=lowest + fi + if [ "${MAXTEXT_PACKAGE_EXTRA}" == "tpu-post-train" ]; then + install_tpu_post_train_extra_deps + else + install_tpu_pre_train_extra_deps + fi fi python3 --version python3 -m pip freeze + - name: Install extra pip deps + if: inputs.extra_pip_deps_file != '' + shell: bash + run: | + source .venv/bin/activate + uv pip install -r ${{ inputs.extra_pip_deps_file }} + + - name: Select ROCm arch (mi300/mi355) + if: ${{ inputs.device_type == 'rocm' }} + shell: bash + run: | + set -euo pipefail + echo "=== ROCm arch selection (from install_te_rocm_wheel.py) ===" + if [ -x .venv/bin/python3 ]; then + te_arch="$( + .venv/bin/python3 .github/workflows/utils/install_te_rocm_wheel.py --print-arch + )" + echo "[te detect_arch] ${te_arch}" + echo "TE_WHEEL_ARCH=${te_arch}" >> "${GITHUB_ENV}" + else + echo "No .venv python found; skipping arch selection." + fi + echo "=========================================================" + + - name: Install Transformer Engine wheel (ROCm) + if: ${{ inputs.device_type == 'rocm' }} + shell: bash + env: + GITHUB_TOKEN: ${{ github.token }} + run: | + source .venv/bin/activate + set -euo pipefail + + python3 .github/workflows/utils/install_te_rocm_wheel.py + + uv pip install --no-deps --force-reinstall transformer_engine-*.whl + - name: Copy test assets files - if: ${{ !inputs.maxtext_installed }} + if: ${{ !inputs.maxtext_installed && !inputs.decoupled_mode }} run : gcloud storage cp gs://maxtext-test-assets/* tests/assets - name: Run Tests shell: bash @@ -149,8 +224,8 @@ jobs: export MAXTEXT_ASSETS_ROOT=$(pwd)/src/maxtext/assets export MAXTEXT_TEST_ASSETS_ROOT=$(pwd)/tests/assets export MAXTEXT_PKG_DIR=$(pwd)/src/maxtext - # omit this libtpu init args for gpu tests - if [ "${INPUTS_DEVICE_TYPE}" != "cuda12" ]; then + # omit this libtpu init args for gpu tests (cuda + rocm) + if [ "${INPUTS_DEVICE_TYPE}" != "cuda12" ] && [ "${INPUTS_DEVICE_TYPE}" != "rocm" ]; then export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536' else # For cuda12, explicitly point to the pip-installed CUDA libraries @@ -161,6 +236,9 @@ jobs: echo "Warning: Could not find pinned nvidia libraries in .venv." fi fi + if [ "${INPUTS_DEVICE_TYPE}" = "rocm" ]; then + ulimit -c 0 + fi if [ "${INPUTS_TOTAL_WORKERS}" -gt 1 ]; then $PYTHON_EXE -m pip install --quiet pytest-split pytest-xdist SPLIT_ARGS="--splits ${INPUTS_TOTAL_WORKERS} --group ${INPUTS_WORKER_GROUP} -n auto" diff --git a/.github/workflows/run_tests_coordinator.yml b/.github/workflows/run_tests_coordinator.yml index 163c0d6763..cd06c36172 100644 --- a/.github/workflows/run_tests_coordinator.yml +++ b/.github/workflows/run_tests_coordinator.yml @@ -27,7 +27,9 @@ on: tpu-post-training-unit, tpu-post-training-integration, gpu-unit, gpu-integration, cpu-unit, - cpu-post-training-unit + cpu-post-training-unit, + rocm-unit, + rocm-decoupled ) required: true type: string @@ -71,7 +73,9 @@ jobs: "gpu-unit": "cuda12", "gpu-integration": "cuda12", "cpu-unit": "cpu", - "cpu-post-training-unit": "cpu" + "cpu-post-training-unit": "cpu", + "rocm-unit": "rocm", + "rocm-decoupled": "rocm" }')[inputs.flavor] }} device_name: >- @@ -83,7 +87,9 @@ jobs: "gpu-unit": "a100-40gb-4", "gpu-integration": "a100-40gb-4", "cpu-unit": "X64", - "cpu-post-training-unit": "X64" + "cpu-post-training-unit": "X64", + "rocm-unit": "mi355", + "rocm-decoupled": "mi355" }')[inputs.flavor] }} cloud_runner: >- @@ -95,7 +101,9 @@ jobs: "gpu-unit": "linux-x86-a2-48-a100-4gpu", "gpu-integration": "linux-x86-a2-48-a100-4gpu", "cpu-unit": "linux-x86-n2-16", - "cpu-post-training-unit": "linux-x86-n2-16" + "cpu-post-training-unit": "linux-x86-n2-16", + "rocm-unit": "linux-x86-64-4gpu-amd", + "rocm-decoupled": "linux-x86-64-4gpu-amd" }')[inputs.flavor] }} # Pytest Marker Mapping pytest_marker: >- @@ -107,7 +115,9 @@ jobs: "gpu-unit": "not cpu_only and not tpu_only and not integration_test and not post_training", "gpu-integration": "not cpu_only and not tpu_only and integration_test and not post_training", "cpu-unit": "cpu_only and not post_training", - "cpu-post-training-unit": "cpu_only and post_training" + "cpu-post-training-unit": "cpu_only and post_training", + "rocm-unit": "not cpu_only and not tpu_only and not decoupled", + "rocm-decoupled": "not cpu_only and not tpu_only and decoupled" }')[inputs.flavor] }} pytest_addopts: >- @@ -119,7 +129,9 @@ jobs: "gpu-unit": "", "gpu-integration": "", "cpu-unit": "", - "cpu-post-training-unit": "tests/post_training/unit tests/unit" + "cpu-post-training-unit": "tests/post_training/unit tests/unit", + "rocm-unit": "", + "rocm-decoupled": "" }')[inputs.flavor] }} pytest_extra_args: >- @@ -131,17 +143,43 @@ jobs: "gpu-unit": "--ignore=tests/post_training", "gpu-integration": "--ignore=tests/post_training", "cpu-unit": "--ignore=tests/post_training", - "cpu-post-training-unit": "" + "cpu-post-training-unit": "", + "rocm-unit": "", + "rocm-decoupled": "" }')[inputs.flavor] }} # Resource Scaling - xla_python_client_mem_fraction: "${{ contains(inputs.flavor, 'gpu') && '0.65' || '0.75' }}" - tf_force_gpu_allow_growth: "${{ contains(inputs.flavor, 'gpu') && 'true' || 'false' }}" + xla_python_client_mem_fraction: >- + ${{ fromJSON('{ + "gpu-unit": "0.65", + "gpu-integration": "0.65", + "rocm-unit": "0.9", + "rocm-decoupled": "0.9" + }')[inputs.flavor] || '0.75' }} + + tf_force_gpu_allow_growth: >- + ${{ fromJSON('{ + "gpu-unit": "true", + "gpu-integration": "true", + "rocm-unit": "true", + "rocm-decoupled": "true" + }')[inputs.flavor] || 'false' }} container_resource_option: >- - ${{ contains(inputs.flavor, 'gpu') - && '--shm-size 2g --runtime=nvidia --gpus all --privileged' - || '--privileged' }} + ${{ fromJSON('{ + "gpu-unit": "--shm-size 2g --runtime=nvidia --gpus all --privileged", + "gpu-integration": "--shm-size 2g --runtime=nvidia --gpus all --privileged", + "rocm-unit": "--device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 64g --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --privileged", + "rocm-decoupled": "--device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 64g --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --privileged" + }')[inputs.flavor] || '--privileged' }} + + # ROCm-specific parameters + decoupled_mode: ${{ contains(inputs.flavor, 'rocm') }} + requirements_file: >- + ${{ fromJSON('{ + "rocm-unit": "src/dependencies/requirements/requirements_rocm_jax_0.8.2.txt", + "rocm-decoupled": "src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt" + }')[inputs.flavor] || '' }} # Metadata base_image: ${{ inputs.base_image }} diff --git a/.github/workflows/upstream_sync.yml b/.github/workflows/upstream_sync.yml new file mode 100644 index 0000000000..d19fe20053 --- /dev/null +++ b/.github/workflows/upstream_sync.yml @@ -0,0 +1,51 @@ +name: Upstream Sync + +on: + workflow_dispatch: {} + schedule: + - cron: '30 4 * * *' # Daily 04:30 UTC + +permissions: + contents: write + +jobs: + sync: + runs-on: ubuntu-latest + if: github.ref == 'refs/heads/main' + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + persist-credentials: false + + - name: Configure git + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + + - name: Fetch upstream + run: | + git remote add upstream https://github.com/google/maxtext.git 2>/dev/null || true + git fetch upstream main + git checkout main + git pull --ff-only origin main || true + + - name: Merge upstream/main + id: merge_step + run: | + set -e + git merge --no-edit upstream/main || { echo "::error::Merge conflict - resolve manually"; git merge --abort || true; exit 1; } + if git diff --quiet origin/main..main; then echo "no_changes=true" >> $GITHUB_OUTPUT; else echo "no_changes=false" >> $GITHUB_OUTPUT; fi + + - name: Push (if changed) + if: steps.merge_step.outputs.no_changes == 'false' + env: + PAT: ${{ secrets.UPSTREAM_SYNC_TOKEN }} + run: | + [ -z "$PAT" ] && echo "::error::Missing UPSTREAM_SYNC_TOKEN secret" && exit 1 + git push https://x-access-token:$PAT@github.com/${{ github.repository }}.git main + + - name: Result + run: | + if [ "${{ steps.merge_step.outputs.no_changes }}" = "true" ]; then echo "Up to date"; else echo "Synced upstream"; fi diff --git a/.github/workflows/utils/install_te_rocm_wheel.py b/.github/workflows/utils/install_te_rocm_wheel.py new file mode 100644 index 0000000000..8f601884fd --- /dev/null +++ b/.github/workflows/utils/install_te_rocm_wheel.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +""" +ROCm CI helper: +- Detect MI300 vs MI355 +- Prefer wheel from this repo's 'te-rocm-wheels' release assets +- Fallback to pinned ROCm/maxtext release assets (arch-specific) + +This script only downloads the wheel into the current working directory. +The caller should then install it (e.g. `uv pip install transformer_engine-*.whl`). +""" + +from __future__ import annotations + +import argparse +import json +import os +import re +import subprocess +import urllib.error +import urllib.request + + +def _run(cmd: str) -> str: + try: + return subprocess.check_output(["bash", "-lc", cmd], text=True, stderr=subprocess.STDOUT) + except (subprocess.CalledProcessError, FileNotFoundError, OSError): + return "" + + +def detect_arch() -> str: + """ + Return wheel arch selector: 'mi300' or 'mi355'. + + Notes: + - MI300 family commonly reports gfx942/gfx941. + - MI350/MI355 family commonly reports gfx950 and product strings like MI350X. + """ + override = os.environ.get("TE_WHEEL_ARCH", "").strip().lower() + if override in {"mi300", "mi355"}: + return override + + rocm_smi = _run("command -v rocm-smi >/dev/null 2>&1 && rocm-smi --showproductname || true") or _run( + "[ -x /opt/rocm/bin/rocm-smi ] && /opt/rocm/bin/rocm-smi --showproductname || true" + ) + rocminfo = _run("command -v rocminfo >/dev/null 2>&1 && rocminfo || true") or _run( + "[ -x /opt/rocm/bin/rocminfo ] && /opt/rocm/bin/rocminfo || true" + ) + + blob = f"{rocm_smi}\n{rocminfo}".lower() + gfxs = sorted(set(re.findall(r"gfx[0-9a-f]+", blob))) + + # Prefer explicit gfx IDs when available. + if "gfx950" in gfxs: + return "mi355" + if "gfx942" in gfxs or "gfx941" in gfxs: + return "mi300" + + # Fall back to product string checks. + if "mi355" in blob or "mi350" in blob: + return "mi355" + if "mi300" in blob: + return "mi300" + + # Safe default. + return "mi355" + + +def _headers() -> dict[str, str]: + token = os.environ.get("GITHUB_TOKEN", "") + headers: dict[str, str] = { + "Accept": "application/vnd.github+json", + "User-Agent": "maxtext-ci", + } + if token: + headers["Authorization"] = f"Bearer {token}" + return headers + + +def download(url: str, out_name: str) -> None: + req = urllib.request.Request(url, headers=_headers()) + with urllib.request.urlopen(req) as r, open(out_name, "wb") as f: + f.write(r.read()) + print(f"[te wheel] downloaded {out_name}", flush=True) + + +def try_download_from_te_rocm_wheels(repo: str, arch: str) -> bool: + """Download from this repo's 'te-rocm-wheels' release tag if present.""" + api = f"https://api.github.com/repos/{repo}/releases/tags/te-rocm-wheels" + req = urllib.request.Request(api, headers=_headers()) + with urllib.request.urlopen(req) as r: + rel = json.loads(r.read().decode("utf-8")) + + assets = rel.get("assets", []) + name_re = re.compile(rf"^transformer_engine-.*-1\.{arch}-cp312-cp312-linux_x86_64\.whl$") + matches = [a for a in assets if name_re.match(a.get("name", ""))] + if not matches: + return False + + # Rolling tag keeps many wheels; select newest matching asset. + hit = max(matches, key=lambda a: a.get("created_at", "")) + print( + "[te wheel] selected latest te-rocm-wheels asset: " + f"{hit.get('name', '')} (created_at={hit.get('created_at', 'unknown')})", + flush=True, + ) + + download(hit["browser_download_url"], hit["name"]) + return True + + +def main(argv: list[str] | None = None) -> int: + """Entry point.""" + parser = argparse.ArgumentParser(add_help=True) + parser.add_argument( + "--print-arch", + action="store_true", + help="Print resolved wheel arch (mi300/mi355) and exit.", + ) + args = parser.parse_args(argv) + + if args.print_arch: + print(detect_arch(), flush=True) + return 0 + + repo = os.environ.get("GITHUB_REPOSITORY", "") + if not repo: + print("[te wheel] GITHUB_REPOSITORY not set; skipping.", flush=True) + return 0 + + arch = detect_arch() + print(f"[te wheel] arch={arch}", flush=True) + + # 1) Prefer: this repo's te-rocm-wheels assets. + try: + if try_download_from_te_rocm_wheels(repo, arch): + return 0 + print(f"[te wheel] no te-rocm-wheels asset for arch={arch}", flush=True) + except urllib.error.HTTPError as e: + print(f"[te wheel] te-rocm-wheels not available ({e.code})", flush=True) + except (urllib.error.URLError, json.JSONDecodeError, KeyError, ValueError) as e: + print(f"[te wheel] te-rocm-wheels lookup failed ({e})", flush=True) + + # 2) Fallback: pinned ROCm/maxtext release assets. + arch_tag = f"1.{arch}" + pinned_name = f"transformer_engine-2.8.0.dev0+2776c337-{arch_tag}-cp312-cp312-linux_x86_64.whl" + pinned = "https://github.com/ROCm/maxtext/releases/download/rocm-maxtext-v0.1.1/" f"{pinned_name}" + download(pinned, pinned_name) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/.github/workflows/utils/te_wheels_release.py b/.github/workflows/utils/te_wheels_release.py new file mode 100644 index 0000000000..a2ae1368e2 --- /dev/null +++ b/.github/workflows/utils/te_wheels_release.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +""" +Publish and prune ROCm TransformerEngine wheel assets on GitHub Releases. + +Intended for use in GitHub Actions. Requires: +- GITHUB_TOKEN +- GITHUB_REPOSITORY (owner/repo) +""" + +from __future__ import annotations + +import argparse +import datetime as dt +import json +import os +import sys +import urllib.error +import urllib.parse +import urllib.request + +API = "https://api.github.com" + + +def _env_token() -> str: + token = os.environ.get("GITHUB_TOKEN") + if not token: + raise SystemExit("GITHUB_TOKEN is not set.") + return token + + +def _env_repo() -> tuple[str, str]: + repo = os.environ.get("GITHUB_REPOSITORY") + if not repo or "/" not in repo: + raise SystemExit("GITHUB_REPOSITORY is not set (expected 'owner/repo').") + owner, name = repo.split("/", 1) + return owner, name + + +def _headers(token: str) -> dict[str, str]: + return { + "Authorization": f"Bearer {token}", + "Accept": "application/vnd.github+json", + "User-Agent": "maxtext-ci", + } + + +def _request_json(method: str, url: str, token: str, body: dict | None = None): + data = None + if body is not None: + data = json.dumps(body).encode("utf-8") + req = urllib.request.Request(url, data=data, method=method, headers=_headers(token)) + with urllib.request.urlopen(req) as r: + raw = r.read() + return json.loads(raw.decode("utf-8")) if raw else None + + +def _request_raw(method: str, url: str, token: str, data: bytes, content_type: str): + h = _headers(token) + h["Content-Type"] = content_type + req = urllib.request.Request(url, data=data, method=method, headers=h) + with urllib.request.urlopen(req) as r: + return r.read() + + +def get_or_create_release(tag: str, title: str, body: str, prerelease: bool) -> dict: + """Get a release by tag, or create it if missing. + + Args: + tag: Release tag (e.g. 'te-rocm-wheels'). + title: Release title. + body: Release body text. + prerelease: Whether to mark the release as a prerelease. + + Returns: + The GitHub release object JSON. + """ + token = _env_token() + owner, name = _env_repo() + try: + rel = _request_json("GET", f"{API}/repos/{owner}/{name}/releases/tags/{tag}", token) + if rel: + return rel + except urllib.error.HTTPError as e: + if e.code != 404: + raise + return _request_json( + "POST", + f"{API}/repos/{owner}/{name}/releases", + token, + {"tag_name": tag, "name": title, "body": body, "prerelease": prerelease}, + ) + + +def upload_asset(tag: str, title: str, body: str, file_path: str, prerelease: bool) -> None: + """Upload (replace) a release asset under the given tag. + + If an asset with the same filename already exists, it is deleted first. + """ + token = _env_token() + owner, name = _env_repo() + + rel = get_or_create_release(tag, title, body, prerelease) + release_id = rel["id"] + upload_url = rel["upload_url"].split("{", 1)[0] + + assets = _request_json("GET", f"{API}/repos/{owner}/{name}/releases/{release_id}", token)["assets"] + file_name = os.path.basename(file_path) + for a in assets: + if a.get("name") == file_name: + _request_json("DELETE", f"{API}/repos/{owner}/{name}/releases/assets/{a['id']}", token) + + with open(file_path, "rb") as f: + data = f.read() + up = f"{upload_url}?{urllib.parse.urlencode({'name': file_name})}" + _request_raw("POST", up, token, data, "application/octet-stream") + print(f"Uploaded {file_name} to release tag {tag}", flush=True) + + +def prune_assets(tag: str, keep_days: int) -> None: + """Delete assets older than `keep_days` from the given release tag.""" + token = _env_token() + owner, name = _env_repo() + try: + rel = _request_json("GET", f"{API}/repos/{owner}/{name}/releases/tags/{tag}", token) + except urllib.error.HTTPError as e: + if e.code == 404: + print(f"No release for tag {tag}; skipping asset prune.", flush=True) + return + raise + + cutoff = dt.datetime.now(dt.timezone.utc) - dt.timedelta(days=keep_days) + pruned = 0 + for a in rel.get("assets", []): + created = dt.datetime.fromisoformat(a["created_at"].replace("Z", "+00:00")) + if created < cutoff: + _request_json("DELETE", f"{API}/repos/{owner}/{name}/releases/assets/{a['id']}", token) + pruned += 1 + print(f"Pruned old asset: {a['name']} (created_at={a['created_at']})", flush=True) + print(f"Asset prune complete for {tag}. Deleted {pruned} assets older than {keep_days} days.", flush=True) + + +def prune_releases(prefix: str, keep_days: int) -> None: + """Delete releases with tag names starting with `prefix` older than `keep_days` days.""" + token = _env_token() + owner, name = _env_repo() + cutoff = dt.datetime.now(dt.timezone.utc) - dt.timedelta(days=keep_days) + + page = 1 + deleted = 0 + while True: + rels = _request_json("GET", f"{API}/repos/{owner}/{name}/releases?per_page=100&page={page}", token) + if not rels: + break + for rel in rels: + tag = rel.get("tag_name", "") + if not tag.startswith(prefix): + continue + created = dt.datetime.fromisoformat(rel["created_at"].replace("Z", "+00:00")) + if created < cutoff: + _request_json("DELETE", f"{API}/repos/{owner}/{name}/releases/{rel['id']}", token) + deleted += 1 + print(f"Deleted old release {tag} (created_at={rel['created_at']})", flush=True) + page += 1 + print(f"Release prune complete. Deleted {deleted} releases older than {keep_days} days.", flush=True) + + +def main(argv: list[str]) -> int: + p = argparse.ArgumentParser() + sub = p.add_subparsers(dest="cmd", required=True) + + up = sub.add_parser("upload") + up.add_argument("--tag", required=True) + up.add_argument("--title", required=True) + up.add_argument("--body", required=True) + up.add_argument("--file", required=True) + prg = up.add_mutually_exclusive_group() + prg.add_argument("--prerelease", dest="prerelease", action="store_true") + prg.add_argument("--no-prerelease", dest="prerelease", action="store_false") + up.set_defaults(prerelease=True) + + pa = sub.add_parser("prune-assets") + pa.add_argument("--tag", required=True) + pa.add_argument("--keep-days", type=int, required=True) + + pr = sub.add_parser("prune-releases") + pr.add_argument("--prefix", required=True) + pr.add_argument("--keep-days", type=int, required=True) + + args = p.parse_args(argv) + + if args.cmd == "upload": + upload_asset(args.tag, args.title, args.body, args.file, prerelease=args.prerelease) + return 0 + if args.cmd == "prune-assets": + prune_assets(args.tag, args.keep_days) + return 0 + if args.cmd == "prune-releases": + prune_releases(args.prefix, args.keep_days) + return 0 + raise AssertionError("unreachable") + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv[1:])) diff --git a/src/dependencies/requirements/requirements_decoupled_jax_0_8_2.txt b/src/dependencies/requirements/requirements_decoupled_jax_0_8_2.txt new file mode 100644 index 0000000000..9479797cb6 --- /dev/null +++ b/src/dependencies/requirements/requirements_decoupled_jax_0_8_2.txt @@ -0,0 +1,40 @@ +absl-py +aqtp +array-record +datasets +flax +grain[parquet] +huggingface_hub +jax==0.8.2 +jaxlib==0.8.2 +jaxtyping +jsonlines +ml-collections +numpy +omegaconf +optax +orbax-checkpoint +pathwaysutils +pillow +pre-commit +protobuf +pydantic +pyink +pylint +pytest +parameterized +pytype +sentencepiece +tensorboardx +tensorflow-datasets +tensorflow-text<2.20 +tensorflow<2.20 +tiktoken +numba>=0.59.0 +google-tunix +tokamax @ git+https://github.com/openxla/tokamax.git@69b328121c3ee120d8e54cf26d1565cea189617f +transformers>=4.57.3,<5 +qwix +mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip +chex>=0.1.91 +drjax>=0.1.4 diff --git a/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_7.1.txt b/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_7.1.txt new file mode 100644 index 0000000000..4d1732349c --- /dev/null +++ b/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_7.1.txt @@ -0,0 +1,47 @@ +absl_py>=2.3.1 +aqtp>=0.9.0 +chex>=0.1.90 +datasets>=4.2.0 +etils>=1.13.0 +evaluate>=0.4.6 +flax +grain>=0.2.12 +grpcio>=1.75.1 +huggingface_hub>=0.35.3 +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.7.1/jaxlib-0.7.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl +jax==0.7.1 +jax-rocm7-pjrt==0.7.1 +jax-rocm7-plugin==0.7.1 +jaxtyping>=0.3.3 +jsonlines>=4.0.0 +matplotlib>=3.10.3 +ml_collections>=1.1.0 +ml_dtypes>=0.5.3 +nltk>=3.9.2 +numpy>=2.0.2 +omegaconf>=2.3.0 +optax>=0.2.6 +orbax-checkpoint>=0.11.25 +pandas>=2.3.3 +parameterized==0.9.0 +pathwaysutils>=0.1.3 +pillow>=11.3.0 +protobuf>=5.29.5 +psutil>=7.0.0 +pytest>=8.4.1 +PyYAML>=6.0.3 +Requests>=2.32.5 +qwix>=0.1.1 +safetensors>=0.6.2 +sentencepiece>=0.2.1 +setuptools>=80.9.0 +tabulate>=0.9.0 +tensorflow>=2.19.1 +tensorflow_text>=2.19.0 +tensorflow_datasets>=4.9.9 +tensorstore>=0.1.76 +tiktoken>=0.12.0 +tqdm>=4.67.1 +transformers>=4.57.0 +urllib3>=2.5.0 +git+https://github.com/google/tunix.git diff --git a/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt b/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt new file mode 100644 index 0000000000..f821160f03 --- /dev/null +++ b/src/dependencies/requirements/requirements_decoupled_rocm_jax_0_8_2.txt @@ -0,0 +1,47 @@ +absl-py +aqtp +array-record +datasets +flax +grain[parquet] +huggingface_hub + + +# ROCm JAX 0.8.2 (py312) wheels (install order matters) +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jax_rocm7_pjrt-0.8.2+rocm7.2.0-py3-none-manylinux_2_28_x86_64.whl +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jax_rocm7_plugin-0.8.2+rocm7.2.0-cp312-cp312-manylinux_2_28_x86_64.whl +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jaxlib-0.8.2+rocm7-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl +jax==0.8.2 + + +jaxtyping +jsonlines +ml-collections +numpy +omegaconf +optax +orbax-checkpoint +pathwaysutils +pillow +pre-commit +protobuf +pydantic +pyink +pylint +pytest +parameterized +pytype +sentencepiece +tensorboardx +tensorflow-datasets +tensorflow-text<2.20 +tensorflow<2.20 +tiktoken +numba>=0.59.0 +google-tunix +tokamax @ git+https://github.com/openxla/tokamax.git@69b328121c3ee120d8e54cf26d1565cea189617f +transformers>=4.57.3,<5 +qwix +mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip +chex>=0.1.91 +drjax>=0.1.4 diff --git a/src/dependencies/requirements/requirements_rocm_jax_0.8.2.txt b/src/dependencies/requirements/requirements_rocm_jax_0.8.2.txt new file mode 100644 index 0000000000..bffcc3a203 --- /dev/null +++ b/src/dependencies/requirements/requirements_rocm_jax_0.8.2.txt @@ -0,0 +1,60 @@ +absl-py +aqtp +array-record +datasets +flax +grain[parquet] +huggingface_hub + + +# ROCm JAX 0.8.2 (py312) wheels (install order matters) +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jax_rocm7_pjrt-0.8.2+rocm7.2.0-py3-none-manylinux_2_28_x86_64.whl +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jax_rocm7_plugin-0.8.2+rocm7.2.0-cp312-cp312-manylinux_2_28_x86_64.whl +https://github.com/ROCm/rocm-jax/releases/download/rocm-jax-v0.8.2/jaxlib-0.8.2+rocm7-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl +jax==0.8.2 + + +jaxtyping +jsonlines +ml-collections +numpy +omegaconf +optax +orbax-checkpoint +pathwaysutils +pillow +pre-commit +protobuf +pydantic +pyink +pylint +pytest +parameterized +pytype +sentencepiece +tensorboardx +tensorflow-datasets +tensorflow-text<2.20 +tensorflow<2.20 +tiktoken +numba>=0.59.0 +google-tunix +tokamax @ git+https://github.com/openxla/tokamax.git@69b328121c3ee120d8e54cf26d1565cea189617f +transformers>=4.57.3,<5 +qwix +mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip + + +# GCP / Cloud deps (restored for non-decoupled ROCm runs) +cloud-accelerator-diagnostics +cloud-tpu-diagnostics +gcsfs +google-api-python-client +google-cloud-aiplatform +google-cloud-mldiagnostics +google-cloud-monitoring +ml-goodput-measurement +tensorboard-plugin-profile +google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip +chex>=0.1.91 +drjax>=0.1.4 diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index c88fb4767c..d0205a869e 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -176,7 +176,7 @@ logits_dot_in_fp32: false # whether to use fp32 in logits_dense or shared_embed cast_logits_to_fp32: true # whether to cast the logits to fp32. the higher precision is generally beneficial, but it can vary slightly. float32_qk_product: false # in dot_product attention, whether to cast to fp32 the inputs to qk product float32_logits: false # in dot_product attention, whether to cast to fp32 the inputs to softmax -float32_weight_sum: true # whether to use full fp32 precision to sum expert weights for numerical stability +float32_weight_sum: false # whether to use fp32 for MoE expert weight summation; true adds ~2 GB f32 temp per device float32_gate_logits: false # whether to cast inputs to fp32 to compute MoE gate logits for numerical stability # multi-token prediction configs diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index af809e2b42..7a469031d5 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -691,7 +691,7 @@ class MoEGeneral(BaseModel): description="Enable top-k probability normalization for router weights (Qwen3-specific).", ) float32_weight_sum: bool = Field( - True, + False, description="Whether to use full fp32 precision to sum expert weights for numerical stability.", ) float32_gate_logits: bool = Field( diff --git a/src/maxtext/kernels/gather_reduce_sc.py b/src/maxtext/kernels/gather_reduce_sc.py index cfcf67a089..c68c5ed4ce 100644 --- a/src/maxtext/kernels/gather_reduce_sc.py +++ b/src/maxtext/kernels/gather_reduce_sc.py @@ -55,7 +55,8 @@ def __getitem__(self, shape): _BF16 = VectorTypeHelper(ir.BF16Type.get) -@jax.jit( +@functools.partial( + jax.jit, static_argnames=[ "reduce_group_size", "single_sc", diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index 2252ceceb2..0e3fcaefb3 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -1580,13 +1580,22 @@ def cudnn_flash_attention( dummy_attn_mask = None mask_type = "causal" else: - # Default case: no packing, no context parallelism - dummy_attn_mask = jnp.zeros( - (1, 1, 1, self.max_target_length, self.max_target_length), - dtype=jnp.uint8, - ) - attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) - attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8) + # Default case: no packing, no context parallelism. + # For synthetic data, segment IDs are always all-ones (one segment per sequence), so + # the segment mask is all-True and the combined mask reduces to pure causal masking. + # Use mask_type="causal" directly to avoid materializing f32/s32[seq,seq] tensors that + # XLA loop_broadcast_fusion hoists into the pipeline scan carry (+5 GiB temp memory). + if self.config.dataset_type == "synthetic": + attn_mask = None + dummy_attn_mask = None + mask_type = "causal" + else: + dummy_attn_mask = jnp.zeros( + (1, 1, 1, self.max_target_length, self.max_target_length), + dtype=jnp.uint8, + ) + attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) + attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8) dpa_layer = DotProductAttention( head_dim=head_dim, @@ -1599,12 +1608,13 @@ def cudnn_flash_attention( dtype=self.dtype, float32_logits=self.float32_logits, qkv_layout=qkv_layout, - scale_factor=1.0, + # scale_factor omitted: TE default (None) auto-computes 1/sqrt(head_dim). + # Explicitly passing 1.0 disables QK scaling — see TE DotProductAttention docs. transpose_batch_sequence=False, window_size=sliding_window_size, context_parallel_causal_load_balanced=self.config.context_parallel_load_balance, context_parallel_axis=self.config.context_sharding, - context_parallel_strategy=self.config.context_parallel_strategy, + # context_parallel_strategy omitted: not supported in installed TE 2.6.x. max_segments_per_seq=max_segments_per_seq, ) diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index 8c6a4be596..412ab356e6 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -561,6 +561,7 @@ def __init__( mesh=mesh, shard_mode=config.shard_mode, debug_sharding=config.debug_sharding, + skip_trivial_specs=True, ) def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None: diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index eb5630968f..854944f1c7 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -299,6 +299,7 @@ class Decoder(nn.Module): config: Config mesh: Mesh + shared_embedding: nn.Module | nnx.Module | None = None quant: None | Quant = None model_mode: str = MODEL_MODE_TRAIN @@ -619,7 +620,6 @@ def get_layer_to_pipeline(blocks, cfg): @nn.compact def _apply_embedding( self, - shared_embedding: nn.Module | nnx.Module, decoder_input_tokens, decoder_positions, deterministic, @@ -629,7 +629,7 @@ def _apply_embedding( """Applies token and positional embeddings to the input tokens.""" cfg = self.config - y = shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode) + y = self.shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode) # Merge the image embeddings with the text embeddings for multimodal models if multimodal_input is not None: @@ -690,7 +690,7 @@ def _apply_embedding( return y @nn.compact - def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, deterministic, model_mode): + def apply_output_head(self, y, deterministic, model_mode): """Applies final normalization and projects hidden states to logits.""" cfg = self.config @@ -719,10 +719,10 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi # [batch, length, emb_dim] -> [batch, length, vocab_size] if cfg.logits_via_embedding: # Use the transpose of embedding matrix for logit transform. - if isinstance(shared_embedding, nnx.Module): - embedding_table = shared_embedding.embedding.value + if isinstance(self.shared_embedding, nnx.Module): + embedding_table = self.shared_embedding.embedding.value else: - embedding_table = shared_embedding.variables["params"]["embedding"] + embedding_table = self.shared_embedding.variables["params"]["embedding"] if isinstance(embedding_table, nn.spmd.LogicallyPartitioned): embedding_table = embedding_table.unbox() attend_dtype = jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype @@ -750,6 +750,13 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi out_sharding=out_sharding, ) # We do not quantize the logits matmul. + if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): + logits = nn.with_logical_constraint(logits, (None, None, "activation_vocab")) + else: + logits = nn.with_logical_constraint( + logits, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab") + ) + if self.config.cast_logits_to_fp32: logits = logits.astype(jnp.float32) @@ -758,7 +765,6 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi @nn.compact def __call__( self, - shared_embedding: nn.Module | nnx.Module, decoder_input_tokens, decoder_positions, decoder_segment_ids=None, @@ -778,7 +784,6 @@ def __call__( # [batch, length] -> [batch, length, emb_dim] y = self._apply_embedding( - shared_embedding, decoder_input_tokens, decoder_positions, deterministic, @@ -1111,7 +1116,7 @@ def __call__( # When initializing with vLLM RPA attention, we need to run the output head to # initialize any parameters associated with it. if self.is_initializing() and cfg.attention == "vllm_rpa": - _ = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) + _ = self.apply_output_head(hidden_state, deterministic, model_mode) # When invoking from vLLM with RPA attention, logit computation is deferred to a later stage. if cfg.attention == "vllm_rpa": @@ -1130,7 +1135,7 @@ def __call__( self.sow("intermediates", "hidden_states", hidden_state) else: - logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) + logits = self.apply_output_head(hidden_state, deterministic, model_mode) # The API of the Decoder is now a tuple, providing both the main output # and the raw hidden state needed for auxiliary tasks. diff --git a/src/maxtext/layers/embeddings.py b/src/maxtext/layers/embeddings.py index f933d27440..810bffc9da 100644 --- a/src/maxtext/layers/embeddings.py +++ b/src/maxtext/layers/embeddings.py @@ -22,6 +22,7 @@ import jax.numpy as jnp from jax.sharding import Mesh, NamedSharding +from flax import linen as nn from flax import nnx from maxtext.common.common_types import ShardMode, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, Array, Config, DType @@ -156,22 +157,15 @@ def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array: self.dtype, ) - output_axis_names = ( - ( - "activation_embed_and_logits_batch", - "prefill_activation_length", - "activation_embed", - ) - if model_mode == MODEL_MODE_PREFILL - else ( - "activation_embed_and_logits_batch", - "activation_length", - "activation_embed", - ) - ) - out_pspec = logical_to_mesh_axes(output_axis_names, self.mesh, rules=getattr(self.config, "logical_axis_rules", None)) + output_prefill_axis_names = ("activation_embed_and_logits_batch", "prefill_activation_length", "activation_embed") + output_default_axis_names = ("activation_embed_and_logits_batch", "activation_length", "activation_embed") - out_sharding = NamedSharding(self.mesh, out_pspec) if self.config.shard_mode == ShardMode.EXPLICIT else None + if self.config.shard_mode == ShardMode.EXPLICIT: + output_axis_names = output_prefill_axis_names if model_mode == MODEL_MODE_PREFILL else output_default_axis_names + out_pspec = logical_to_mesh_axes(output_axis_names, self.mesh, rules=getattr(self.config, "logical_axis_rules", None)) + out_sharding = NamedSharding(self.mesh, out_pspec) + else: + out_sharding = None if cfg.use_iota_embed: iota = lax.iota(jnp.int32, self.num_embeddings) @@ -180,6 +174,10 @@ def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array: else: output = embedding.at[inputs].get(out_sharding=out_sharding) + if model_mode == MODEL_MODE_PREFILL: + output = nn.with_logical_constraint(output, output_prefill_axis_names) + else: + output = nn.with_logical_constraint(output, output_default_axis_names) return output def attend(self, query: Array, out_sharding: NamedSharding | None = None) -> Array: diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 7abe0782f3..5ae37cea9e 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -656,7 +656,7 @@ def apply_ffn_activation(self, layer_w0, layer_w1): else: layer_act = self.activation_fn(layer_w0) intermediate_layer = jnp.multiply(layer_act, layer_w1) - return intermediate_layer.astype(self.dtype) + return intermediate_layer def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True, rngs=None, roll_to_expert_id=None): """Permute tokens to group by expert to fit gmm call.""" @@ -1295,27 +1295,21 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): group_sizes=group_sizes, expert_assignments=selected_experts, ) + # 3-tuple (fwd only): used by jax.lax.ragged_dot and megablox forward path. + # The 9-tuple (fwd+dlhs+drhs) is only consumed by the megablox custom VJP + # (_gmm_bwd uses tiling[3:6] for dlhs and tiling[-3:] for drhs). Since + # ds-proxy uses megablox=False / use_tokamax_gmm=False (jax ragged_dot path), + # the extra 6 backward-pass values would be ignored. See base.yml comment: + # "megablox/jax ragged dot - supports forward pass only". wi_tile_size = ( self.config.wi_tile_fwd_batch_seq, # m (LHS batch) self.config.wi_tile_fwd_embed_dim, # k (contracting) self.config.wi_tile_fwd_mlp_dim, # n (RHS batch) - self.config.wi_tile_dlhs_batch_seq, # m (LHS batch) - self.config.wi_tile_dlhs_mlp_dim, # k (contracting) - self.config.wi_tile_dlhs_embed_dim, # n (RHS batch) - self.config.wi_tile_drhs_batch_seq, # Called m in megablox, but this is contracting - self.config.wi_tile_drhs_embed_dim, # Called k in megablox, but this is LHS batch dim - self.config.wi_tile_drhs_mlp_dim, # Called n in megablox, and indeed is RHS batch dim ) wo_tile_size = ( self.config.wo_tile_fwd_batch_seq, # m (LHS batch) self.config.wo_tile_fwd_mlp_dim, # k (contracting) self.config.wo_tile_fwd_embed_dim, # n (RHS batch) - self.config.wo_tile_dlhs_batch_seq, # m (LHS batch) - self.config.wo_tile_dlhs_embed_dim, # k (contracting) - self.config.wo_tile_dlhs_mlp_dim, # n (RHS) - self.config.wo_tile_drhs_batch_seq, # Called m in megablox, but this is contracting - self.config.wo_tile_drhs_mlp_dim, # Called k in megablox, but this is LHS batch dim - self.config.wo_tile_drhs_embed_dim, # Called n in megablox, and indeed is the RHS batch dim ) layer_w0 = gmm_fn( diff --git a/src/maxtext/layers/multi_token_prediction.py b/src/maxtext/layers/multi_token_prediction.py index f9af98fb30..830516fdfa 100644 --- a/src/maxtext/layers/multi_token_prediction.py +++ b/src/maxtext/layers/multi_token_prediction.py @@ -223,7 +223,6 @@ def __init__( def __call__( self, - shared_embedding, main_hidden_state, input_ids, target_ids, @@ -255,7 +254,6 @@ def __call__( rolled_position_id = roll_and_mask(rolled_position_id) target_token_embedding = self.decoder._apply_embedding( - shared_embedding, rolled_input_ids, rolled_position_id, deterministic, @@ -272,7 +270,7 @@ def __call__( model_mode=self.decoder.model_mode, ) - mtp_logits = self.decoder.apply_output_head(shared_embedding, mtp_hidden_state, deterministic, model_mode) + mtp_logits = self.decoder.apply_output_head(mtp_hidden_state, deterministic, model_mode) mtp_xent, _ = max_utils.cross_entropy_with_logits( mtp_logits, jax.nn.one_hot(rolled_target_ids, cfg.vocab_size), 0.0 diff --git a/src/maxtext/layers/normalizations.py b/src/maxtext/layers/normalizations.py index 3bce30d44e..aa1dd7101a 100644 --- a/src/maxtext/layers/normalizations.py +++ b/src/maxtext/layers/normalizations.py @@ -88,8 +88,11 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> scale = jax.device_put(scale, max_utils.device_space()) scale = jnp.asarray(scale, self.dtype) - effective_scale = scale + self.scale_offset - return jnp.einsum("...k,k->...k", y, effective_scale, out_sharding=out_sharding) + effective_scale = scale + self.scale_offset if self.scale_offset != 0.0 else scale + y = y * effective_scale + if out_sharding is not None: + y = jax.lax.with_sharding_constraint(y, out_sharding) + return y class GlobalRMSNorm(RMSNorm): diff --git a/src/maxtext/layers/pipeline.py b/src/maxtext/layers/pipeline.py index 62ea52782b..2a85bca542 100644 --- a/src/maxtext/layers/pipeline.py +++ b/src/maxtext/layers/pipeline.py @@ -118,6 +118,7 @@ def _maybe_shard_with_logical(self, inputs, logical_axes): rules=self.config.logical_axis_rules, debug_sharding=self.config.debug_sharding, extra_stack_level=1, + skip_trivial_specs=True, ) def _maybe_shard_with_name(self, inputs, sharding_name): @@ -139,7 +140,6 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): # Setup potential input from state_io, which has a rotating microbatch index (size of microbatches_per_stage) state_io_batch_idx = loop_iteration % self.microbatches_per_stage state_io_slice = state_io[:, state_io_batch_idx] - shift = self._maybe_shard_with_logical(shift, self.stages_in_logical) if self.use_circ_storage: # Setup potential input from circ_storage, which also has a rotating index for microbatch, @@ -154,7 +154,6 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): # state_io we instead grab from the last stage's output (possibly buffered when num_microbatches > num_stages, e.g. # from circ_storage). first_stage_in = jnp.where(loop_iteration < self.config.num_pipeline_microbatches, state_io_slice, circular_stage_in) - first_stage_in = self._maybe_shard_with_logical(first_stage_in, self.stages_in_logical) # Note that first_stage_in may correspond to bubble computation during the last few iterations. # However, these bubble computation results remain in the shift buffer (do not make it back to state_io) and are @@ -164,11 +163,7 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): def select_state_or_input(first_stage_in, shift): # Selects input for stage 0, shift for other stages - return jnp.where( - jax.lax.broadcasted_iota("int32", shift.shape, 0, out_sharding=self.stages_in_sharding) == 0, - first_stage_in, - shift, - ) + return jnp.where(jax.lax.broadcasted_iota("int32", shift.shape, 0) == 0, first_stage_in, shift) # Selects input (from stream_io) for stage 0, other stages get from shift (the rotated previous output) stages_in = select_state_or_input(first_stage_in, shift) @@ -180,7 +175,6 @@ def get_microbatch_and_repeat_ids(self, loop_iteration): non-circular""" # Stage 0 has processed one microbatch every loop_iter, but Stage 1 is 1 behind due to bubble, etc for other stages microbatches_processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0) - microbatches_processed = self._maybe_shard_with_name(microbatches_processed, NamedSharding(self.mesh, P("stage"))) microbatch_ids = microbatches_processed % self.config.num_pipeline_microbatches repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches return microbatch_ids, repeat_ids @@ -247,16 +241,6 @@ def get_main_vmap_func_for_iterations(self): def func_to_vmap( body_instance, weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode ): - weights = meta.remove_axis( - weights, - 0, - { - nn.PARTITION_NAME: "layers", - "sub_weight_split_dims_mapping": (None,), - "is_initializing": self.is_initializing(), - "x_times": self.num_stages, - }, - ) return body_instance.apply(weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) vmap_func = nn.vmap( @@ -490,9 +474,7 @@ def vmap_gather(self, xs, ids, ids_dim): """ def _gather_one(x, i): - idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim)) - replicated_sharding = NamedSharding(self.mesh, P()) - return x.at[idx].get(out_sharding=replicated_sharding) + return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, i, 1, ids_dim), ids_dim) ids = self.shard_dim_by_stages(ids, 0, physical_partition_spec=None) outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids) @@ -516,21 +498,16 @@ def get_new_loop_state(self, output, loop_state): loop_iteration = loop_state["loop_iteration"] old_prev_outputs = loop_state["prev_outputs"] - @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True) def _rotate_right(arr): - # we use +1 for right shifting - stage_size = jax.lax.axis_size("stage") - perm = [(i, (i + 1) % stage_size) for i in range(stage_size)] - arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm) - return arr + # Use lax.slice to avoid generating a gather. + last = jax.lax.slice_in_dim(arr, self.num_stages - 1, self.num_stages, axis=0) + except_last = jax.lax.slice_in_dim(arr, 0, self.num_stages - 1, axis=0) + return jnp.concatenate([last, except_last], axis=0) - @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True) def _shift_right(arr): - stage_idx = jax.lax.axis_index("stage") - stage_size = jax.lax.axis_size("stage") - perm = [(i, (i + 1) % stage_size) for i in range(stage_size)] - arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm) - return jnp.where(stage_idx == 0, jnp.zeros_like(arr), arr) + padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1) + # Use lax.slice to guarantee the gradient is a pad. + return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape) # Shift either rotates or shifts depending on if the last stage immediately must send to first or not # For non-circular pipelines, the last stage does not need to send to first @@ -574,29 +551,17 @@ def _rotate_right_and_update(circ_storage_mover_in, circ_storage_in): stream_buf_idx = loop_iteration % self.microbatches_per_stage stream_slice = old_state_io[:, stream_buf_idx] - def _rotate_left(arr, stage_size): - # we use -1 for left shifting - perm = [(i, (i - 1) % stage_size) for i in range(stage_size)] - return jax.lax.ppermute(arr, axis_name="stage", perm=perm) - - def _shift_left(arr, stage_size, output): - stage_idx = jax.lax.axis_index("stage") - arr = _rotate_left(arr, stage_size) - return jnp.where(stage_idx == stage_size - 1, output, arr) - - @jax.shard_map( - mesh=self.mesh, - in_specs=(self.state_io_spec, self.stages_in_spec, self.stages_in_spec, P()), - out_specs=self.state_io_spec, - ) - def _update_state_io(state_in, stream_slice, output, stream_buf_idx): + def _update_state_io(state_in, stream_slice, output): # Shift the current slice to the left, then fill the last stage with the final output. - stage_size = jax.lax.axis_size("stage") - stream_slice = _shift_left(stream_slice, stage_size, output) + padding = [[0, 1]] + [[0, 0]] * (stream_slice.ndim - 1) + stream_slice = jax.lax.slice_in_dim(jnp.pad(stream_slice, padding), 1, stream_slice.shape[0] + 1, axis=0) + stream_slice = jnp.where( + jax.lax.broadcasted_iota("int32", stream_slice.shape, 0) == self.num_stages - 1, output, stream_slice + ) stream_slice = jnp.expand_dims(stream_slice, 1) return jax.lax.dynamic_update_slice_in_dim(state_in, stream_slice, stream_buf_idx, axis=1) - new_state = _update_state_io(old_state_io, stream_slice, output, stream_buf_idx) + new_state = _update_state_io(old_state_io, stream_slice, output) new_loop_state = { "state_io": new_state, diff --git a/src/maxtext/models/deepseek.py b/src/maxtext/models/deepseek.py index 7ae5c46b19..6e865c0ead 100644 --- a/src/maxtext/models/deepseek.py +++ b/src/maxtext/models/deepseek.py @@ -43,6 +43,7 @@ from maxtext.utils import max_utils from maxtext.utils.sharding import create_sharding from maxtext.utils.sharding import maybe_shard_with_logical +from maxtext.utils.sharding import remove_size_one_mesh_axis import transformers @@ -419,7 +420,7 @@ def __init__( self.DeepSeekMoeBlock_0 = moe.RoutedAndSharedMoE( config=self.config, mesh=mesh, - kernel_init=initializers.nd_dense_init(self.config.dense_init_scale, "fan_in", "truncated_normal"), + kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), kernel_axes=("embed", None), dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, @@ -491,13 +492,14 @@ def __call__( return outputs, None # bf16 code path - input_sharding = jax.typeof(inputs).sharding - activation_pspec = jax.sharding.PartitionSpec( - ("data", "fsdp", "expert"), - None, - None, + activation_pspec = remove_size_one_mesh_axis( + jax.sharding.PartitionSpec( + ("data", "fsdp", "fsdp_transpose", "expert", "context"), + None, + None, + ), + self.mesh, ) - inputs = jax.reshard(inputs, jax.sharding.NamedSharding(self.mesh, activation_pspec)) yarn_freqs = deepseek_batchsplit.initialize_yarn_freqs( decoder_positions, embedding_dims=self.config.qk_rope_head_dim, @@ -569,7 +571,6 @@ def extract_fn(x): in_specs=([activation_pspec] * self.config.batch_split_factor,), out_specs=activation_pspec, )(outputs) - outputs = jax.reshard(outputs, input_sharding) return outputs, None x = self.with_logical_constraint(inputs) diff --git a/src/maxtext/models/mixtral.py b/src/maxtext/models/mixtral.py index faf69273c6..03c1152197 100644 --- a/src/maxtext/models/mixtral.py +++ b/src/maxtext/models/mixtral.py @@ -18,111 +18,33 @@ from flax import linen as nn -from flax import nnx from jax.ad_checkpoint import checkpoint_name import jax.numpy as jnp from jax.sharding import Mesh from maxtext.common.common_types import Config -from maxtext.layers import initializers, nnx_wrappers +from maxtext.layers import initializers from maxtext.layers import moe from maxtext.layers import quantizations -from maxtext.layers.attentions import Attention -from maxtext.layers.linears import Dropout -from maxtext.layers.normalizations import RMSNorm +from maxtext.layers.attentions import attention_as_linen +from maxtext.layers.normalizations import rms_norm from maxtext.layers.quantizations import AqtQuantization as Quant from maxtext.utils import max_utils +from maxtext.utils.sharding import maybe_shard_with_logical # ----------------------------------------- # The Decoder Layer for Mixtral # ----------------------------------------- -class MixtralDecoderLayer(nnx.Module): +class MixtralDecoderLayer(nn.Module): """Transformer decoder layer that attends to the encoder.""" - @nn.compact - def __init__( - self, - config: Config, - mesh: Mesh, - model_mode: str, - quant: None | Quant = None, - *, - rngs: nnx.Rngs, - ): - self.config = config - self.mesh = mesh - self.model_mode = model_mode - self.quant = quant - self.rngs = rngs - - batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode) - dummy_inputs_shape = (batch_size, seq_len, config.emb_dim) - - self.pre_self_attention_layer_norm = RMSNorm( - num_features=config.emb_dim, - dtype=config.dtype, - weight_dtype=config.weight_dtype, - kernel_axes=("norm",), - epsilon=config.normalization_layer_epsilon, - rngs=self.rngs, - ) - - self.self_attention = Attention( - config=config, - num_query_heads=config.num_query_heads, - num_kv_heads=config.num_kv_heads, - head_dim=config.head_dim, - max_target_length=config.max_target_length, - max_prefill_predict_length=config.max_prefill_predict_length, - attention_kernel=config.attention, - inputs_q_shape=dummy_inputs_shape, - inputs_kv_shape=dummy_inputs_shape, - mesh=mesh, - dtype=config.dtype, - weight_dtype=config.weight_dtype, - dropout_rate=config.dropout_rate, - float32_qk_product=config.float32_qk_product, - float32_logits=config.float32_logits, - quant=self.quant, - kv_quant=quantizations.configure_kv_quant(config), - prefill_cache_axis_order=tuple(map(int, config.prefill_cache_axis_order.split(","))), - ar_cache_axis_order=tuple(map(int, config.ar_cache_axis_order.split(","))), - compute_axis_order=tuple(map(int, config.compute_axis_order.split(","))), - reshape_q=config.reshape_q, - use_ragged_attention=config.use_ragged_attention, - ragged_block_size=config.ragged_block_size, - model_mode=model_mode, - rngs=self.rngs, - ) - - self.post_self_attention_layer_norm = RMSNorm( - num_features=config.emb_dim, - dtype=config.dtype, - weight_dtype=config.weight_dtype, - kernel_axes=("norm",), - epsilon=config.normalization_layer_epsilon, - rngs=self.rngs, - ) - - self.MoeBlock_0 = moe.RoutedMoE( - config=config, - num_experts=config.num_experts, - num_experts_per_tok=config.num_experts_per_tok, - mesh=mesh, - kernel_init=initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"), - kernel_axes=("embed", None), - intermediate_dim=config.mlp_dim, - dtype=config.dtype, - weight_dtype=config.weight_dtype, - quant=self.quant, - rngs=self.rngs, - ) - - self.dropout = Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) - - self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") + config: Config + mesh: Mesh + model_mode: str + quant: None | Quant = None + @nn.compact def __call__( self, inputs, @@ -139,13 +61,61 @@ def __call__( # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache)) if isinstance(inputs, tuple): inputs = inputs[0] - inputs = nn.with_logical_constraint(inputs, self.activation_axis_names) - inputs = checkpoint_name(inputs, "decoder_layer_input") - lnx = self.pre_self_attention_layer_norm(inputs) - lnx = nn.with_logical_constraint(lnx, self.activation_axis_names) + cfg = self.config + mesh = self.mesh + + activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed") - attention_lnx, kv_cache = self.self_attention( + def shard(x): + return maybe_shard_with_logical( + x, activation_axis_names, mesh=mesh, shard_mode=cfg.shard_mode, + rules=cfg.logical_axis_rules, skip_trivial_specs=True, + ) + + inputs = shard(inputs) + inputs = checkpoint_name(inputs, "decoder_layer_input") + + lnx = rms_norm( + num_features=cfg.emb_dim, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="pre_self_attention_layer_norm", + kernel_axes=("norm",), + epsilon=cfg.normalization_layer_epsilon, + )(inputs) + lnx = shard(lnx) + + batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(cfg, model_mode) + dummy_inputs_shape = (batch_size, seq_len, cfg.emb_dim) + + attention_lnx, kv_cache = attention_as_linen( + config=cfg, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + inputs_q_shape=dummy_inputs_shape, + inputs_kv_shape=dummy_inputs_shape, + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + float32_qk_product=cfg.float32_qk_product, + float32_logits=cfg.float32_logits, + quant=self.quant, + kv_quant=quantizations.configure_kv_quant(cfg), + prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))), + ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))), + compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))), + reshape_q=cfg.reshape_q, + use_ragged_attention=cfg.use_ragged_attention, + ragged_block_size=cfg.ragged_block_size, + model_mode=model_mode, + name="self_attention", + )( lnx, lnx, decoder_positions, @@ -157,28 +127,47 @@ def __call__( attention_metadata=attention_metadata, ) - attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names) + attention_lnx = shard(attention_lnx) intermediate_inputs = inputs + attention_lnx # Fully Connected - hidden_states = self.post_self_attention_layer_norm(intermediate_inputs) - hidden_states = nn.with_logical_constraint(hidden_states, self.activation_axis_names) + hidden_states = rms_norm( + num_features=cfg.emb_dim, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="post_self_attention_layer_norm", + kernel_axes=("norm",), + epsilon=cfg.normalization_layer_epsilon, + )(intermediate_inputs) + hidden_states = shard(hidden_states) load_balance_loss = None # NOTE: the naming mismatch here is to ensure reverse compatibility with existing checkpoints. # The `name` represents the weight name in JAX/checkpoints and so the class name # is just for readability. - mlp_lnx, load_balance_loss, _ = self.MoeBlock_0(hidden_states) - mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names) + mlp_lnx, load_balance_loss, _ = moe.get_routed_moe( + config=cfg, + num_experts=cfg.num_experts, + num_experts_per_tok=cfg.num_experts_per_tok, + mesh=mesh, + kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_axes=("embed", None), + intermediate_dim=cfg.mlp_dim, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + quant=self.quant, + name="MoeBlock_0", + )(hidden_states) + mlp_lnx = shard(mlp_lnx) layer_output = mlp_lnx + intermediate_inputs - layer_output = self.dropout(layer_output, deterministic=deterministic) - layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names) + layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) + layer_output = shard(layer_output) - if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None: + if cfg.load_balance_loss_weight > 0.0 and load_balance_loss is not None: self.sow("intermediates", "moe_lb_loss", load_balance_loss) - if self.config.record_internal_nn_metrics: + if cfg.record_internal_nn_metrics: self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) self.sow( @@ -187,13 +176,10 @@ def __call__( jnp.sum(layer_output == 0) / jnp.size(layer_output), ) - if self.config.scan_layers: + if cfg.scan_layers: return layer_output, None else: return layer_output, kv_cache -MixtralDecoderLayerToLinen = nnx_wrappers.to_linen_class( - MixtralDecoderLayer, - base_metadata_fn=initializers.variable_to_logically_partitioned, -) +MixtralDecoderLayerToLinen = MixtralDecoderLayer diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index 23ad166c79..88075a73da 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -88,7 +88,7 @@ def setup(self): ) self.vision_encoder = vision_encoder_as_linen(config=cfg, mesh=mesh) if cfg.use_multimodal else None self.audio_encoder = audio_encoder_as_linen(config=cfg, mesh=mesh) if cfg.use_audio else None - self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode) + self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding, mesh=mesh, quant=self.quant, model_mode=self.model_mode) # If MTP is enabled via config, set up the MTP block. if self.config.mtp_num_layers > 0: @@ -114,7 +114,6 @@ def logits_from_hidden_states(self, hidden_states, deterministic, model_mode): This function is only used for vocabulary tiling. """ logits = self.decoder.apply_output_head( - shared_embedding=self.shared_embedding, y=hidden_states, deterministic=deterministic, model_mode=model_mode, @@ -186,7 +185,6 @@ def __call__( ) logits, hidden_state, kv_caches = self.decoder( - shared_embedding=self.shared_embedding, decoder_input_tokens=decoder_input_tokens, decoder_positions=decoder_positions, decoder_segment_ids=decoder_segment_ids, @@ -222,7 +220,6 @@ def __call__( # Its only effect is to "sow" these losses; it does not alter the primary logits output. if self.config.mtp_num_layers > 0: self.mtp_block( - shared_embedding=self.shared_embedding, main_hidden_state=hidden_state, input_ids=decoder_input_tokens, target_ids=decoder_target_tokens, @@ -345,7 +342,7 @@ def __init__( if cfg.pure_nnx_decoder: self.decoder = NNXDecoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs) else: - decoder_linen = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode) + decoder_linen = Decoder(config=cfg, mesh=mesh, shared_embedding=self.token_embedder, quant=self.quant, model_mode=self.model_mode) self.decoder = nnx_wrappers.ToNNX(decoder_linen, rngs=rngs) self.hidden_states = None @@ -373,7 +370,6 @@ def __init__( if not cfg.pure_nnx_decoder: self.decoder.lazy_init( - shared_embedding=self.token_embedder, decoder_input_tokens=dummy_decoder_input_tokens, decoder_positions=dummy_decoder_positions, attention_metadata=dummy_attention_metadata, @@ -397,7 +393,6 @@ def __init__( self.mtp_block = nnx_wrappers.ToNNX(mtp_block_linen, rngs=rngs) self.mtp_block.lazy_init( - shared_embedding=self.token_embedder, main_hidden_state=jnp.ones((1, 1, self.config.emb_dim), dtype=self.config.dtype), input_ids=jnp.ones((1, 1), dtype=jnp.int32), target_ids=jnp.ones((1, 1), dtype=jnp.int32), @@ -511,7 +506,6 @@ def __call__( if self.config.pure_nnx_decoder: logits, hidden_state, kv_caches = self.decoder( - shared_embedding=self.token_embedder, decoder_input_tokens=decoder_input_tokens, decoder_positions=decoder_positions, decoder_segment_ids=decoder_segment_ids, @@ -527,7 +521,6 @@ def __call__( ) # pytype: disable=wrong-keyword-args else: logits, hidden_state, kv_caches = self.decoder( - shared_embedding=self.token_embedder, decoder_input_tokens=decoder_input_tokens, decoder_positions=decoder_positions, decoder_segment_ids=decoder_segment_ids, @@ -568,7 +561,6 @@ def __call__( # Its only effect is to "sow" these losses; it does not alter the primary logits output. if self.config.mtp_num_layers > 0: self.mtp_block( - shared_embedding=self.token_embedder, main_hidden_state=hidden_state, input_ids=decoder_input_tokens, target_ids=decoder_target_tokens, diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 720354fe4d..0e9273d965 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -36,6 +36,8 @@ import jax import jax.numpy as jnp +import flax +flax.config.update("flax_always_shard_variable", False) from flax import linen as nn from flax.linen import partitioning as nn_partitioning @@ -333,10 +335,11 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) (loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True) - raw_grads = jax.tree_util.tree_map( - lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, - raw_grads, - ) + if config.grad_dtype != jnp.float32: + raw_grads = jax.tree_util.tree_map( + lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, + raw_grads, + ) if config.parameter_memory_host_offload: raw_grads = jax.device_put( raw_grads, @@ -375,12 +378,6 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat fp8_stats = jax.tree_util.tree_map(lambda x: jnp.nan_to_num(x, nan=0.0), fp8_stats) grads = dict(grads) grads[maxtext_utils.OVERWRITE_WITH_GRADIENT] = fp8_stats - # Zero out any remaining NaN in float gradients to prevent param corruption - grads = jax.tree_util.tree_map( - lambda x: jnp.nan_to_num(x, nan=0.0) if jnp.issubdtype(x.dtype, jnp.floating) else x, - grads, - ) - if config.optimizer_memory_host_offload: state = state.replace( opt_state=jax.device_put( @@ -666,7 +663,8 @@ def initialize(argv: Sequence[str]) -> tuple[pyconfig.HyperParameters, Any, Any] max_utils.print_system_information() train_utils.validate_train_config(config) jax.config.update("jax_use_shardy_partitioner", config.shardy) - jax.config.update("jax_remove_size_one_mesh_axis_from_type", config.remove_size_one_mesh_axis_from_type) + if hasattr(jax.config, "jax_remove_size_one_mesh_axis_from_type"): + jax.config.update("jax_remove_size_one_mesh_axis_from_type", config.remove_size_one_mesh_axis_from_type) os.environ["TFDS_DATA_DIR"] = config.dataset_path or "" vertex_tensorboard_manager = VertexTensorboardManager() if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"): diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index a2981f67ed..88e5025824 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -85,7 +85,8 @@ def get_topology_mesh(config): num_slices=config.compile_topology_num_slices, wrap=target_hardware.wrap, ).devices - jax.config.update("jax_remove_size_one_mesh_axis_from_type", config.remove_size_one_mesh_axis_from_type) + if hasattr(jax.config, "jax_remove_size_one_mesh_axis_from_type"): + jax.config.update("jax_remove_size_one_mesh_axis_from_type", config.remove_size_one_mesh_axis_from_type) topology_device_mesh = maxtext_utils.create_device_mesh(config, topology_devices) mesh_axis_type = AxisType.Explicit if config.shard_mode == ShardMode.EXPLICIT else AxisType.Auto topology_mesh = Mesh(topology_device_mesh, config.mesh_axes, axis_types=(mesh_axis_type,) * len(config.mesh_axes)) diff --git a/src/maxtext/utils/sharding.py b/src/maxtext/utils/sharding.py index d4bb64f016..38cdaa707d 100644 --- a/src/maxtext/utils/sharding.py +++ b/src/maxtext/utils/sharding.py @@ -131,7 +131,15 @@ def maybe_shard_with_pspec( def maybe_shard_with_logical( - inputs, logical_axes, mesh, shard_mode, rules=None, debug_sharding=False, extra_stack_level=0, sharding_desc="" + inputs, + logical_axes, + mesh, + shard_mode, + rules=None, + debug_sharding=False, + extra_stack_level=0, + sharding_desc="", + skip_trivial_specs=False, ): """ A wrapper of maybe_shard_with_name when logical axes are inputs @@ -146,6 +154,9 @@ def maybe_shard_with_logical( named_sharding = create_sharding(mesh, logical_axes, rules=rules) + if skip_trivial_specs and all(ax is None or ax == () for ax in named_sharding.spec): + return inputs + return maybe_shard_with_name( inputs, named_sharding, diff --git a/tests/integration/train_tests.py b/tests/integration/train_tests.py index fc27753abe..9d10f62686 100644 --- a/tests/integration/train_tests.py +++ b/tests/integration/train_tests.py @@ -18,6 +18,7 @@ import pytest import jax +from jax._src import test_util as jtu from absl.testing import absltest from maxtext.common.gcloud_stub import is_decoupled from maxtext.trainers.pre_train.train import main as train_main diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 2eea9025c7..e4ad9f1cd9 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -122,11 +122,23 @@ def get_test_base_output_directory(cloud_path=None): return cloud_path or "gs://runner-maxtext-logs" +def get_test_config_path_for(relative_path: str): + """Return absolute path for a config under the configs directory. + + Uses the same decoupled-vs-non-decoupled selection logic as + `get_test_config_path()` when `relative_path` is `base.yml`. + """ + if relative_path == "base.yml": + return get_test_config_path() + return os.path.join(MAXTEXT_CONFIGS_DIR, relative_path) + + __all__ = [ "get_test_base_output_directory", "get_decoupled_parallelism_overrides", "is_rocm_backend", "get_test_config_path", "get_post_train_test_config_path", + "get_test_config_path_for", "get_test_dataset_path", ]