Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions src/maxtext/checkpoint_conversion/reshard_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script re-shards a MaxText checkpoint on CPU, assuming linen format.
- It loads a checkpoint, shards it according to the simulated mesh, and saves to a new checkpoint.
- The goal is to pre-shard checkpoints (source) to accelerate loading speeds on TPUs (target) by reducing re-sharding overhead.
For example, when target sharding is fsdp=64, loading time for checkpoint with source sharding: fsdp=64 < fsdp=16 < ep=16

Key Parameters:
- `--simulated_cpu_devices_count` (default to 16). Examples:
- Suitable for most cases: `--simulated_cpu_devices_count=16 ici_fsdp_parallelism=16`
- More customization: `--simulated_cpu_devices_count=32 ici_fsdp_parallelism=16 ici_expert_parallelism=2`
- `weight_dtype`: The dtype used to load and save the checkpoint. Highly recommend using `weight_dtype=bfloat16`.
- `load_parameters_path`: The input checkpoint path (GCS or local).
- `base_output_directory`: The output directory (GCS or local).
- The output checkpoint path will be `<base_output_directory>/0/items`

Memory Requirements:
- For a model with X billion parameters, it needs at least 2X GB RAM (each parameter takes 2 bytes with `weight_dtype=bfloat16`).
- Example: DeepSeek-V3 with MTP layers has ~685B parameters and requires at least 1.4 TB of RAM.
- Note: The input checkpoint is re-sharded in place, so we avoid holding two full copies in memory. However, JAX requires additional buffer memory during the resharding operations.

Example Commands:

python3 -m maxtext.checkpoint_conversion.reshard_checkpoint \
model_name=deepseek2-16b attention=dot_product mla_naive_kvcache=false \
scan_layers=True load_parameters_path=<input_ckpt_path> \
base_output_directory=<out_ckpt_directory> \
weight_dtype=bfloat16 \
checkpoint_storage_concurrent_gb=1024 checkpoint_storage_use_ocdbt=True checkpoint_storage_use_zarr3=True \
skip_jax_distributed_system=True ici_fsdp_parallelism=16 \
--simulated_cpu_devices_count=16

python3 -m maxtext.checkpoint_conversion.reshard_checkpoint \
model_name=deepseek3-671b mtp_num_layers=1 mtp_loss_scaling_factor=0.1 attention=dot_product mla_naive_kvcache=false \
scan_layers=True load_parameters_path=<input_ckpt_path> \
base_output_directory=<out_ckpt_directory> \
weight_dtype=bfloat16 \
checkpoint_storage_concurrent_gb=1024 checkpoint_storage_use_ocdbt=True checkpoint_storage_use_zarr3=True \
skip_jax_distributed_system=True ici_fsdp_parallelism=16 ici_expert_parallelism=2 \
--simulated_cpu_devices_count=32
"""


import argparse
import os
import sys
import time
from typing import Sequence
from absl import app

import jax
from flax.training import train_state

from maxtext.configs import pyconfig
from maxtext.inference.maxengine import maxengine
from maxtext.utils import max_utils, max_logging
from maxtext.common import checkpointing
from maxtext.checkpoint_conversion.utils.utils import print_peak_memory


def main(argv: Sequence[str]) -> None:
config = pyconfig.initialize(argv)
max_utils.print_system_information()
max_logging.log(f"Load and save checkpoint with weight dtype: {config.weight_dtype}")

# 1. Engine sets up the mesh based on config
engine = maxengine.MaxEngine(config)
rng = jax.random.PRNGKey(1234)
rng, rng_load_params = jax.random.split(rng)

# 2. Load parameters and reshard with the mesh
start = time.time()
params = engine.load_params(rng_load_params)
max_logging.log(f"Elapse for checkpoint load (with reshard): {(time.time() - start) / 60:.2f} min")

# 3. Save checkpoint
start = time.time()
save_ckpt_directory = config.base_output_directory

# Dummy configs for the checkpoint_manager
step_number = 0
enable_checkpointing = True
async_checkpointing = False
save_interval_steps = 1

checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
save_ckpt_directory,
enable_checkpointing,
async_checkpointing,
save_interval_steps,
use_ocdbt=config.checkpoint_storage_use_ocdbt,
use_zarr3=config.checkpoint_storage_use_zarr3,
)
if checkpoint_manager is None:
raise RuntimeError("Failed to create Orbax checkpoint manager.")

state_new = train_state.TrainState(
step=step_number, apply_fn=None, params=params, tx=None, opt_state={} # type: ignore
)

if checkpointing.save_checkpoint(checkpoint_manager, step_number, state_new):
save_ckpt_path = os.path.join(save_ckpt_directory, str(step_number), "items")
max_logging.log(f"Saved checkpoint: {save_ckpt_path}")
# Upon preemption, exit when and only when all ongoing saves are complete.
checkpoint_manager.wait_until_finished()

max_logging.log(f"Elapse for checkpoint save: {(time.time() - start) / 60:.2f} min")
print_peak_memory()


if __name__ == "__main__":
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" # Suppress TensorFlow logging

# Define local parser
parser = argparse.ArgumentParser()
parser.add_argument(
"--simulated_cpu_devices_count",
type=int,
required=False,
default=16,
help="Number of simulated CPU devices for sharding the checkpoint",
)

# Parse known args returns the namespace AND the list of remaining arguments
local_args, remaining_args = parser.parse_known_args()

# Reconstruct model_args (script name + the args MaxText needs)
model_args = [sys.argv[0]] + remaining_args

# Set JAX environment
jax.config.update("jax_platforms", "cpu")
# Simulate CPU devices as virtual mesh
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}"

app.run(main, argv=model_args)
Original file line number Diff line number Diff line change
Expand Up @@ -1760,17 +1760,18 @@ def save_weights_to_checkpoint(
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
)
if checkpoint_manager is None:
raise RuntimeError("Failed to create Orbax checkpoint manager.")

state_new = train_state.TrainState(
step=0, apply_fn=None, params={"params": jax_weights}, tx=None, opt_state={} # type: ignore
step=step_number_to_save_new_ckpt, apply_fn=None, params={"params": jax_weights}, tx=None, opt_state={} # type: ignore
)

logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
if checkpoint_manager is not None:
if checkpointing.save_checkpoint(checkpoint_manager, step_number_to_save_new_ckpt, state_new):
max_logging.log(f"saved a checkpoint at step {step_number_to_save_new_ckpt}")
# Upon preemption, exit when and only when all ongoing saves are complete.
checkpoint_manager.wait_until_finished()
if checkpointing.save_checkpoint(checkpoint_manager, step_number_to_save_new_ckpt, state_new):
max_logging.log(f"saved a checkpoint at step {step_number_to_save_new_ckpt}")
# Upon preemption, exit when and only when all ongoing saves are complete.
checkpoint_manager.wait_until_finished()

max_logging.log(f"Elapse for checkpoint save: {(time.time() - start) / 60:.2f} min")

Expand Down
Loading