diff --git a/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml index 515a3d63ce..aede854e9d 100644 --- a/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml +++ b/.github/workflows/build_rocm_transformer_engine_wheel_weekly.yml @@ -23,6 +23,8 @@ jobs: env: XLA_PYTHON_CLIENT_MEM_FRACTION: "0.9" NVTE_FUSED_ATTN_AOTRITON: "0" + env: + TE_WHEELS_KEEP_DAYS: "21" steps: - name: Checkout @@ -53,7 +55,7 @@ jobs: source .venv/bin/activate # Detect ROCm version number from JAX backend string when available (e.g. 'rocm 70200'). - ROCM_NUM="$(python3 -c 'import re, jax; s=str(jax.devices()[0].client.platform_version); m=re.search(r"rocm\\s+([0-9]+)", s); print(m.group(1) if m else "unknown")')" + ROCM_NUM="$([ -f /opt/rocm/.info/version ] && head -n1 /opt/rocm/.info/version | tr -d ' \t\r' || echo unknown)" echo "Detected ROCm version: ${ROCM_NUM}" echo "ROCM_NUM=${ROCM_NUM}" >> "${GITHUB_ENV}" @@ -160,7 +162,8 @@ jobs: run: | set -euo pipefail source .venv/bin/activate - python3 "${REL_SCRIPT}" prune-assets --tag "te-rocm-wheels" --keep-days 21 + echo "Pruning rolling-tag assets older than ${TE_WHEELS_KEEP_DAYS} days" + python3 "${REL_SCRIPT}" prune-assets --tag "te-rocm-wheels" --keep-days "${TE_WHEELS_KEEP_DAYS}" - name: Publish wheel to dated weekly release tag shell: bash @@ -189,4 +192,5 @@ jobs: run: | set -euo pipefail source .venv/bin/activate - python3 "${REL_SCRIPT}" prune-releases --prefix "te-rocm-wheels-" --keep-days 21 + echo "Pruning weekly release pages older than ${TE_WHEELS_KEEP_DAYS} days" + python3 "${REL_SCRIPT}" prune-releases --prefix "te-rocm-wheels-" --keep-days "${TE_WHEELS_KEEP_DAYS}" diff --git a/.github/workflows/utils/install_te_rocm_wheel.py b/.github/workflows/utils/install_te_rocm_wheel.py index 7a3f70f75b..8f601884fd 100644 --- a/.github/workflows/utils/install_te_rocm_wheel.py +++ b/.github/workflows/utils/install_te_rocm_wheel.py @@ -92,10 +92,18 @@ def try_download_from_te_rocm_wheels(repo: str, arch: str) -> bool: assets = rel.get("assets", []) name_re = re.compile(rf"^transformer_engine-.*-1\.{arch}-cp312-cp312-linux_x86_64\.whl$") - hit = next((a for a in assets if name_re.match(a.get("name", ""))), None) - if not hit: + matches = [a for a in assets if name_re.match(a.get("name", ""))] + if not matches: return False + # Rolling tag keeps many wheels; select newest matching asset. + hit = max(matches, key=lambda a: a.get("created_at", "")) + print( + "[te wheel] selected latest te-rocm-wheels asset: " + f"{hit.get('name', '')} (created_at={hit.get('created_at', 'unknown')})", + flush=True, + ) + download(hit["browser_download_url"], hit["name"]) return True