diff --git a/tests/unit/compile_cache_test.py b/tests/unit/compile_cache_test.py index b94cd0dec6..a56fb8bea9 100644 --- a/tests/unit/compile_cache_test.py +++ b/tests/unit/compile_cache_test.py @@ -96,9 +96,11 @@ def test_train_step_cache_hit(): # Check if cache dir has files cache_files = os.listdir(temp_dir) + train_step_cache_files = [f for f in cache_files if f.startswith("jit_train_step")] print("=== Cache Directory Content ===") print(f"Path: {temp_dir}") print(f"Files: {cache_files}") + print(f"Train step cache files: {train_step_cache_files}") print("===============================") assert len(cache_files) > 0, ( @@ -106,10 +108,12 @@ def test_train_step_cache_hit(): "cache was not writeable or the JAX cache configuration was ignored." ) - assert len(cache_files) == 1, ( - f"Expected exactly 1 JAX compilation cache file, but found {len(cache_files)}: {cache_files}. " - "This indicates a cache miss where AOT compilation and runtime execution generated different keys, " - "causing train_step to be compiled twice (double-compilation regression)." + assert len(train_step_cache_files) == 1, ( + f"Expected exactly 1 jit_train_step JAX compilation cache file, but found " + f"{len(train_step_cache_files)}: {train_step_cache_files}. " + f"All cache files: {cache_files}. " + "This indicates a cache miss where AOT compilation and runtime execution generated " + "different keys, causing train_step to be compiled twice (double-compilation regression)." ) assert "Persistent compilation cache hit for 'jit_train_step'" in captured_logs, (