diff --git a/src/maxtext/checkpoint_conversion/reshard_checkpoint.py b/src/maxtext/checkpoint_conversion/reshard_checkpoint.py new file mode 100644 index 0000000000..0cf0714e4a --- /dev/null +++ b/src/maxtext/checkpoint_conversion/reshard_checkpoint.py @@ -0,0 +1,149 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script re-shards a MaxText checkpoint on CPU, assuming linen format. +- It loads a checkpoint, shards it according to the simulated mesh, and saves to a new checkpoint. +- The goal is to pre-shard checkpoints (source) to accelerate loading speeds on TPUs (target) by reducing re-sharding overhead. + For example, when target sharding is fsdp=64, loading time for checkpoint with source sharding: fsdp=64 < fsdp=16 < ep=16 + +Key Parameters: +- `--simulated_cpu_devices_count` (default to 16). Examples: + - Suitable for most cases: `--simulated_cpu_devices_count=16 ici_fsdp_parallelism=16` + - More customization: `--simulated_cpu_devices_count=32 ici_fsdp_parallelism=16 ici_expert_parallelism=2` +- `weight_dtype`: The dtype used to load and save the checkpoint. Highly recommend using `weight_dtype=bfloat16`. +- `load_parameters_path`: The input checkpoint path (GCS or local). +- `base_output_directory`: The output directory (GCS or local). + - The output checkpoint path will be `/0/items` + +Memory Requirements: +- For a model with X billion parameters, it needs at least 2X GB RAM (each parameter takes 2 bytes with `weight_dtype=bfloat16`). +- Example: DeepSeek-V3 with MTP layers has ~685B parameters and requires at least 1.4 TB of RAM. +- Note: The input checkpoint is re-sharded in place, so we avoid holding two full copies in memory. However, JAX requires additional buffer memory during the resharding operations. + +Example Commands: + +python3 -m maxtext.checkpoint_conversion.reshard_checkpoint \ + model_name=deepseek2-16b attention=dot_product mla_naive_kvcache=false \ + scan_layers=True load_parameters_path= \ + base_output_directory= \ + weight_dtype=bfloat16 \ + checkpoint_storage_concurrent_gb=1024 checkpoint_storage_use_ocdbt=True checkpoint_storage_use_zarr3=True \ + skip_jax_distributed_system=True ici_fsdp_parallelism=16 \ + --simulated_cpu_devices_count=16 + +python3 -m maxtext.checkpoint_conversion.reshard_checkpoint \ + model_name=deepseek3-671b mtp_num_layers=1 mtp_loss_scaling_factor=0.1 attention=dot_product mla_naive_kvcache=false \ + scan_layers=True load_parameters_path= \ + base_output_directory= \ + weight_dtype=bfloat16 \ + checkpoint_storage_concurrent_gb=1024 checkpoint_storage_use_ocdbt=True checkpoint_storage_use_zarr3=True \ + skip_jax_distributed_system=True ici_fsdp_parallelism=16 ici_expert_parallelism=2 \ + --simulated_cpu_devices_count=32 +""" + + +import argparse +import os +import sys +import time +from typing import Sequence +from absl import app + +import jax +from flax.training import train_state + +from maxtext.configs import pyconfig +from maxtext.inference.maxengine import maxengine +from maxtext.utils import max_utils, max_logging +from maxtext.common import checkpointing +from maxtext.checkpoint_conversion.utils.utils import print_peak_memory + + +def main(argv: Sequence[str]) -> None: + config = pyconfig.initialize(argv) + max_utils.print_system_information() + max_logging.log(f"Load and save checkpoint with weight dtype: {config.weight_dtype}") + + # 1. Engine sets up the mesh based on config + engine = maxengine.MaxEngine(config) + rng = jax.random.PRNGKey(1234) + rng, rng_load_params = jax.random.split(rng) + + # 2. Load parameters and reshard with the mesh + start = time.time() + params = engine.load_params(rng_load_params) + max_logging.log(f"Elapse for checkpoint load (with reshard): {(time.time() - start) / 60:.2f} min") + + # 3. Save checkpoint + start = time.time() + save_ckpt_directory = config.base_output_directory + + # Dummy configs for the checkpoint_manager + step_number = 0 + enable_checkpointing = True + async_checkpointing = False + save_interval_steps = 1 + + checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( + save_ckpt_directory, + enable_checkpointing, + async_checkpointing, + save_interval_steps, + use_ocdbt=config.checkpoint_storage_use_ocdbt, + use_zarr3=config.checkpoint_storage_use_zarr3, + ) + if checkpoint_manager is None: + raise RuntimeError("Failed to create Orbax checkpoint manager.") + + state_new = train_state.TrainState( + step=step_number, apply_fn=None, params=params, tx=None, opt_state={} # type: ignore + ) + + if checkpointing.save_checkpoint(checkpoint_manager, step_number, state_new): + save_ckpt_path = os.path.join(save_ckpt_directory, str(step_number), "items") + max_logging.log(f"Saved checkpoint: {save_ckpt_path}") + # Upon preemption, exit when and only when all ongoing saves are complete. + checkpoint_manager.wait_until_finished() + + max_logging.log(f"Elapse for checkpoint save: {(time.time() - start) / 60:.2f} min") + print_peak_memory() + + +if __name__ == "__main__": + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" # Suppress TensorFlow logging + + # Define local parser + parser = argparse.ArgumentParser() + parser.add_argument( + "--simulated_cpu_devices_count", + type=int, + required=False, + default=16, + help="Number of simulated CPU devices for sharding the checkpoint", + ) + + # Parse known args returns the namespace AND the list of remaining arguments + local_args, remaining_args = parser.parse_known_args() + + # Reconstruct model_args (script name + the args MaxText needs) + model_args = [sys.argv[0]] + remaining_args + + # Set JAX environment + jax.config.update("jax_platforms", "cpu") + # Simulate CPU devices as virtual mesh + os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}" + + app.run(main, argv=model_args) diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/llama_or_mistral_ckpt.py b/src/maxtext/checkpoint_conversion/standalone_scripts/llama_or_mistral_ckpt.py index c3a8d7d05a..8770db5c3e 100644 --- a/src/maxtext/checkpoint_conversion/standalone_scripts/llama_or_mistral_ckpt.py +++ b/src/maxtext/checkpoint_conversion/standalone_scripts/llama_or_mistral_ckpt.py @@ -1760,17 +1760,18 @@ def save_weights_to_checkpoint( use_ocdbt=use_ocdbt, use_zarr3=use_zarr3, ) + if checkpoint_manager is None: + raise RuntimeError("Failed to create Orbax checkpoint manager.") state_new = train_state.TrainState( - step=0, apply_fn=None, params={"params": jax_weights}, tx=None, opt_state={} # type: ignore + step=step_number_to_save_new_ckpt, apply_fn=None, params={"params": jax_weights}, tx=None, opt_state={} # type: ignore ) logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) - if checkpoint_manager is not None: - if checkpointing.save_checkpoint(checkpoint_manager, step_number_to_save_new_ckpt, state_new): - max_logging.log(f"saved a checkpoint at step {step_number_to_save_new_ckpt}") - # Upon preemption, exit when and only when all ongoing saves are complete. - checkpoint_manager.wait_until_finished() + if checkpointing.save_checkpoint(checkpoint_manager, step_number_to_save_new_ckpt, state_new): + max_logging.log(f"saved a checkpoint at step {step_number_to_save_new_ckpt}") + # Upon preemption, exit when and only when all ongoing saves are complete. + checkpoint_manager.wait_until_finished() max_logging.log(f"Elapse for checkpoint save: {(time.time() - start) / 60:.2f} min")