diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml index 602b5e72fd..bcda98005a 100644 --- a/.github/workflows/run_tests_against_package.yml +++ b/.github/workflows/run_tests_against_package.yml @@ -84,6 +84,7 @@ permissions: jobs: run: runs-on: ${{ inputs.cloud_runner != '' && inputs.cloud_runner || (inputs.device_type == 'rocm' && fromJson('["self-hosted","linux-x86-64-4gpu-amd"]')) || fromJson(format('["self-hosted", "{0}", "{1}"]', inputs.device_type, inputs.device_name)) }} + timeout-minutes: ${{ inputs.device_type == 'rocm' && 90 || 360 }} container: image: ${{ inputs.device_type == 'rocm' && 'ghcr.io/rocm/jax-base-ubu24.rocm720:latest' || format('gcr.io/tpu-prod-env-multipod/{0}', inputs.base_image) }} env: diff --git a/.github/workflows/utils/install_te_rocm_wheel.py b/.github/workflows/utils/install_te_rocm_wheel.py index a2a16b225c..4d119f451a 100644 --- a/.github/workflows/utils/install_te_rocm_wheel.py +++ b/.github/workflows/utils/install_te_rocm_wheel.py @@ -91,7 +91,8 @@ def try_download_from_te_rocm_wheels(repo: str, arch: str) -> bool: rel = json.loads(r.read().decode("utf-8")) assets = rel.get("assets", []) - name_re = re.compile(rf"^transformer_engine-.*-{arch}-cp312-cp312-linux_x86_64\.whl$") + # Wheels published by this repo use the selector format: `-1.-...` (e.g. `-1.mi355-...`). + name_re = re.compile(rf"^transformer_engine-.*-1\.{arch}-cp312-cp312-linux_x86_64\.whl$") hit = next((a for a in assets if name_re.match(a.get("name", ""))), None) if not hit: return False