diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-405b.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-405b.yaml new file mode 100644 index 000000000..760e569d8 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-405b.yaml @@ -0,0 +1,26 @@ +# The name for the entire test suite run. +# Assumes v5p-128 (64 chips) +suite_name: "Llama 3.1 405B" +num_repeats: 20 + + +mesh_config: + mesh_axes: ["data", "fsdp", "tensor"] + ici_parallelism: {"data": 1, "fsdp": 64, "tensor": 1} + +# The checkpoint configuration, shared across all generated benchmarks. +checkpoint_config: + path: "gs://orbax-benchmarks/checkpoints/llama-405b_generate_1-64-1" + sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-405b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json" + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark" + options: + # --- Generator Options --- + # These keys must match the attributes of the `BenchmarkOptions` class + # associated with the `Benchmark` generator. + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_compression: true diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-70b.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-70b.yaml new file mode 100644 index 000000000..5fa44c864 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-70b.yaml @@ -0,0 +1,26 @@ +# The name for the entire test suite run. +# Assumes v5p-32 (16 chips) +suite_name: "Llama 3.1 70B" +num_repeats: 20 + + +mesh_config: + mesh_axes: ["data", "fsdp", "tensor"] + ici_parallelism: {"data": 1, "fsdp": 16, "tensor": 1} + +# The checkpoint configuration, shared across all generated benchmarks. +checkpoint_config: + path: "gs://orbax-benchmarks/checkpoints/llama-70b_generate_1-16-1" + sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-32-data-1-fsdp-16-tensor-1/abstract_state.json" + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark" + options: + # --- Generator Options --- + # These keys must match the attributes of the `BenchmarkOptions` class + # associated with the `Benchmark` generator. + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_compression: true diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-8b.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-8b.yaml new file mode 100644 index 000000000..37fd4bfc2 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/pathways/llama-8b.yaml @@ -0,0 +1,25 @@ +# The name for the entire test suite run. +# Assumes v5p-8 (4 chips) +suite_name: "Llama 3.1 8B" +num_repeats: 20 + +mesh_config: + mesh_axes: ["data", "fsdp", "tensor"] + ici_parallelism: {"data": 1, "fsdp": 4, "tensor": 1} + +# The checkpoint configuration, shared across all generated benchmarks. +checkpoint_config: + path: "gs://orbax-benchmarks/checkpoints/llama-8b_generate_1-4-1" + sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-8b-v5p-8-data-1-fsdp-4-tensor-1/abstract_state.json" + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark" + options: + # --- Generator Options --- + # These keys must match the attributes of the `BenchmarkOptions` class + # associated with the `Benchmark` generator. + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_compression: true diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-405b_generate_1-64-1.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-405b_generate_1-64-1.yaml new file mode 100644 index 000000000..8492da058 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-405b_generate_1-64-1.yaml @@ -0,0 +1,27 @@ +# The name for the entire test suite run. +# Assumes v5p-128 (64 chips) +suite_name: "Llama 3.1 405B" +num_repeats: 1 + + +mesh_config: + mesh_axes: ["data", "fsdp", "tensor"] + # Should match reference_sharding_path. + ici_parallelism: {"data": 1, "fsdp": 64, "tensor": 1} + +# Note: checkpoint_config field not specified. +checkpoint_config: + path: "gs://orbax-benchmarks/checkpoints/llama-3.1-405B-checkpoints/0/items" + sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-405b-v5p-128-data-1-fsdp-64-tensor-1/abstract_state.json" + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark" + options: + # --- Generator Options --- + # These keys must match the attributes of the `BenchmarkOptions` class + # associated with the `Benchmark` generator. + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_compression: true diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-70b_generate_1-16-1.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-70b_generate_1-16-1.yaml new file mode 100644 index 000000000..daa3f5171 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-70b_generate_1-16-1.yaml @@ -0,0 +1,27 @@ +# The name for the entire test suite run. +# Assumes v5p-32 (16 chips) +suite_name: "Llama 3.1 70B" +num_repeats: 1 + + +mesh_config: + mesh_axes: ["data", "fsdp", "tensor"] + # Should match reference_sharding_path. + ici_parallelism: {"data": 1, "fsdp": 16, "tensor": 1} + +# Note: checkpoint_config field not specified. +checkpoint_config: + path: "gs://orbax-benchmarks/checkpoints/llama-3.1-70B-checkpoints/0/items" + sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-70b-v5p-32-data-1-fsdp-16-tensor-1/abstract_state.json" + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark" + options: + # --- Generator Options --- + # These keys must match the attributes of the `BenchmarkOptions` class + # associated with the `Benchmark` generator. + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_compression: true diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-8b_generate_1-4-1.yaml b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-8b_generate_1-4-1.yaml new file mode 100644 index 000000000..a3e3e4228 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/configs/evaluations/resharding/llama-8b_generate_1-4-1.yaml @@ -0,0 +1,27 @@ +# The name for the entire test suite run. +# Assumes v5p-8 (4 chips) +suite_name: "Llama 3.1 8B" +num_repeats: 1 + + +mesh_config: + mesh_axes: ["data", "fsdp", "tensor"] + # Should match reference_sharding_path. + ici_parallelism: {"data": 1, "fsdp": 4, "tensor": 1} + +# Note: checkpoint_config field not specified. +checkpoint_config: + path: "gs://orbax-benchmarks/checkpoints/llama-3.1-8B-checkpoints/0/items" + sharding_config_path: "gs://orbax-benchmarks/sharding-configs/llama3.1-8b-v5p-8-data-1-fsdp-4-tensor-1/abstract_state.json" + +benchmarks: + - generator: "orbax.checkpoint._src.testing.benchmarks.v1.benchmark.Benchmark" + options: + # --- Generator Options --- + # These keys must match the attributes of the `BenchmarkOptions` class + # associated with the `Benchmark` generator. + async_enabled: true + use_ocdbt: true + use_zarr3: true + use_replica_parallel: false + use_compression: true diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py index c470b0f85..bba23ee11 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/run_benchmarks.py @@ -29,7 +29,9 @@ from absl import logging from etils import epath import jax +from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.testing.benchmarks.core import config_parsing +import pathwaysutils # Core Flags @@ -184,7 +186,15 @@ def main(argv: List[str]) -> None: logging.info('run_benchmarks.py started.') - _init_jax_distributed() + pathwaysutils.initialize() + + if multihost.is_pathways_backend(): + logging.info('Detected Pathways backend.') + else: + logging.info( + 'Detected non-Pathways backend. Initialize JAX distributed system.' + ) + _init_jax_distributed() if _ENABLE_HLO_DUMP.value: _configure_hlo_dump(_OUTPUT_DIRECTORY.value) diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py index 80ad89b51..3ab5fb548 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/launch_xpk.py @@ -610,11 +610,10 @@ def construct_workload_command( v_level: int | None, ) -> str: """Constructs the command to run inside the workload.""" - # Environment variables + # Environment variables. if enable_pathways: env_vars = [ - 'export JAX_PLATFORMS=proxy', - 'export ENABLE_PATHWAYS_PERSISTENCE=1', + 'export JAX_PLATFORMS=tpu,cpu', 'export ENABLE_PJRT_COMPATIBILITY=true', ] else: @@ -656,12 +655,13 @@ def construct_workload_command( python_cmd = ' '.join(python_args) if hardware_type == HardwareType.CPU: python_cmd += ' --jax_cpu_collectives_implementation=gloo' - if enable_pathways: - python_cmd = ( - 'python3 -c "import pathwaysutils;' - ' pathwaysutils.initialize()" && ' - + python_cmd - ) + # if enable_pathways: + # python_cmd = ( + # 'echo "Initializing Pathways" && python3 -c "import pathwaysutils;' + # ' import logging; logging.basicConfig(level=logging.DEBUG);' + # ' pathwaysutils.initialize()" && ' + # + python_cmd + # ) return f'{env_cmd}{python_cmd}' @@ -730,7 +730,7 @@ def construct_xpk_command( image_args = [ f'--server-image={_PATHWAYS_SERVER_IMAGE.value}', f'--proxy-server-image={_PATHWAYS_PROXY_IMAGE.value}', - f'--docker-image={_DOCKER_IMAGE.value}' + f'--docker-image={_DOCKER_IMAGE.value}', ] else: diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py index 5f75fb1ff..1e655204d 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py @@ -491,7 +491,9 @@ class PathwaysOptions: checkpointing_impl: The implementation to use for Pathways checkpointing. """ - checkpointing_impl: pathways_types.CheckpointingImpl | None = None + checkpointing_impl: pathways_types.CheckpointingImpl | None = ( + pathways_types.CheckpointingImpl.COLOCATED_PYTHON + ) class CheckpointLayout(enum.Enum): diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registration.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registration.py index 022a22e12..8b573c407 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registration.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/registration.py @@ -26,6 +26,7 @@ Pathways dependencies should not be added to this file. """ +from absl import logging from orbax.checkpoint._src.serialization import jax_array_handlers from orbax.checkpoint._src.serialization import pathways_handler_registry from orbax.checkpoint._src.serialization import pathways_types @@ -57,11 +58,15 @@ def resolve_pathways_checkpointing_impl( except ImportError as e: raise ImportError(_PATHWAYS_IMPORT_ERROR_MSG) from e checkpointing_impl = context.pathways_options.checkpointing_impl - return checkpointing_impl or pathways_types.CheckpointingImpl.from_options( + resolved_checkpointing_impl = checkpointing_impl or pathways_types.CheckpointingImpl.from_options( use_colocated_python=False, # Not enabled unless explicitly requested. use_remote_python=rp.available(), use_persistence_array_handler=True, # Only used as a fallback. ) + logging.info( + 'Resolved Pathways implementation: %s', resolved_checkpointing_impl + ) + return resolved_checkpointing_impl def get_array_handler(