From 881dde8c90aaea7cb96deb994caa8e52a1ca52f4 Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Sun, 26 Oct 2025 16:45:15 -0400 Subject: [PATCH 1/9] Update branch --- src/MaxText/checkpointing.py | 2 ++ src/MaxText/configs/base.yml | 3 +++ .../input_pipeline/_hf_data_processing.py | 12 ++++++---- src/MaxText/layers/attention_op.py | 4 ++-- src/MaxText/layers/attentions.py | 3 ++- src/MaxText/layers/moe.py | 23 ++++++++++++++++++- src/MaxText/layers/quantizations.py | 3 +++ src/MaxText/train.py | 1 + src/MaxText/train_utils.py | 1 + 9 files changed, 44 insertions(+), 8 deletions(-) diff --git a/src/MaxText/checkpointing.py b/src/MaxText/checkpointing.py index 7e9df4802f..6448d4f457 100644 --- a/src/MaxText/checkpointing.py +++ b/src/MaxText/checkpointing.py @@ -185,6 +185,7 @@ def create_orbax_checkpoint_manager( orbax_logger: Any = None, # pytype: disable=attribute-error use_ocdbt: bool = True, use_zarr3: bool = True, + max_to_keep: int = 5, ): """Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled.""" if not enable_checkpointing: @@ -213,6 +214,7 @@ def create_orbax_checkpoint_manager( create=True, save_interval_steps=save_interval_steps, enable_async_checkpointing=use_async, + max_to_keep = max_to_keep, ), logger=orbax_logger, ) diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index 5be8df2639..56b6b8f2e2 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -958,3 +958,6 @@ partial_rotary_factor: 1.0 # Use tokamax library for gmm kernel implementation use_tokamax_gmm: false use_tokamax_splash: false + +expert_balance: False +max_to_keep: 5 diff --git a/src/MaxText/input_pipeline/_hf_data_processing.py b/src/MaxText/input_pipeline/_hf_data_processing.py index e056cd972e..82d4641c73 100644 --- a/src/MaxText/input_pipeline/_hf_data_processing.py +++ b/src/MaxText/input_pipeline/_hf_data_processing.py @@ -192,6 +192,7 @@ def preprocessing_pipeline( use_sft=None, sft_train_on_completion_only=True, grain_worker_count=1, # only support 0 or 1 + max_segments_per_seq = 1, # max segments per sequence ): """pipeline for preprocessing HF dataset""" @@ -298,10 +299,11 @@ def lists2array(x): if packing and not use_dpo: length_struct = {col: max_target_length for col in data_column_names} operations.append( - grain.experimental.PackAndBatchOperation( - batch_size=global_batch_size // jax.process_count(), - length_struct=length_struct, - ) + grain.experimental.PackAndBatchOperation( + batch_size=global_batch_size // jax.process_count(), + length_struct=length_struct, + max_sequences_per_bin=max_segments_per_seq, + ) ) operations.append(_input_pipeline_utils.ReformatPacking(data_column_names)) else: @@ -386,6 +388,7 @@ def make_hf_train_iterator( use_sft=config.use_sft, sft_train_on_completion_only=config.sft_train_on_completion_only, chat_template_path=config.chat_template_path, + max_segments_per_seq=config.max_segments_per_seq, ) return train_iter @@ -437,5 +440,6 @@ def make_hf_eval_iterator( use_sft=config.use_sft, sft_train_on_completion_only=config.sft_train_on_completion_only, chat_template_path=config.chat_template_path, + max_segments_per_seq=config.max_segments_per_seq, ) return eval_iter diff --git a/src/MaxText/layers/attention_op.py b/src/MaxText/layers/attention_op.py index c8924e6084..a9c172c79e 100644 --- a/src/MaxText/layers/attention_op.py +++ b/src/MaxText/layers/attention_op.py @@ -1372,7 +1372,7 @@ def cudnn_flash_attention( dummy_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32) dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos(segment_ids=dummy_segment_ids, segment_pos=None) max_segments_per_seq = self.config.max_segments_per_seq - elif using_context_parallelism: + elif using_context_parallelism and self.config.dataset_type != "synthetic": if self.attention_type == AttentionType.LOCAL_SLIDING: raise AssertionError("Sliding window attention is not supported for context parallelism") # Context parallelism without packing: only supports causal masking @@ -1396,7 +1396,7 @@ def cudnn_flash_attention( dtype=self.dtype, float32_logits=self.float32_logits, qkv_layout=qkv_layout, - scale_factor=1.0, + # scale_factor=1.0, transpose_batch_sequence=False, window_size=sliding_window_size, context_parallel_causal_load_balanced=self.config.context_parallel_load_balance, diff --git a/src/MaxText/layers/attentions.py b/src/MaxText/layers/attentions.py index 0586953de0..ad61f89532 100644 --- a/src/MaxText/layers/attentions.py +++ b/src/MaxText/layers/attentions.py @@ -548,7 +548,8 @@ def init_query_w(self, inputs_q_shape: Tuple) -> nnx.Module: if self.config.use_qk_norm or (self.query_pre_attn_scalar is not None and self.query_pre_attn_scalar != 1.0): depth_scaling = 1.0 else: - depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) + # depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) + depth_scaling = 1.0 def query_init(*args): # pylint: disable=no-value-for-parameter diff --git a/src/MaxText/layers/moe.py b/src/MaxText/layers/moe.py index 2efcd22361..11c4a71f27 100644 --- a/src/MaxText/layers/moe.py +++ b/src/MaxText/layers/moe.py @@ -1445,7 +1445,7 @@ def get_einsum( def aqt_einsum(*args, **kwargs): # pylint: disable=unused-argument # simply skip kwargs, since aqt einsum doesn't support any kwargs # like precision - is_aqt = not isinstance(self.quant, quantizations.Fp8Quantization) + is_aqt = not ( isinstance(self.quant, quantizations.Fp8Quantization) or isinstance(self.quant, quantizations.NANOOFp8Quantization) ) kw = {"mesh_axes": rhs_mesh_axes} if is_aqt else {"dtype": self.dtype} return self.quant.einsum(**kw)(*args) # pytype: disable=attribute-error @@ -1480,6 +1480,27 @@ def dense_matmul( wo_bias, ) -> tuple[jax.Array, Optional[jax.Array]]: """Dense matrix multiplication.""" + if self.config.expert_balance: + ###################################################################################################### + ############################## start hard code for uniform expert #################################### + # Create deterministic rotational pattern for gate logits + batch_size, seq_len, num_experts = gate_logits.shape + + # Create base weights for experts (increasing values) + base_weights = jnp.linspace(0.1, 0.1 * num_experts, num_experts, dtype=gate_logits.dtype) + + # Create position-based indices matrix [seq_len, num_experts] + # Each row represents which index in base_weights to use after rotation + indices = (jnp.arange(num_experts)[None, :] + jnp.arange(seq_len)[:, None]) % num_experts + + # Use advanced indexing to create the rotated weights matrix in one operation + # This takes the appropriate weight for each position based on the rotation pattern + rotated_weights = base_weights[indices] + + # Broadcast to batch dimension + gate_logits = jnp.broadcast_to(rotated_weights[None, :, :], (batch_size, seq_len, num_experts)) + ############################################# end #################################################### + ###################################################################################################### # gate_logits: batch, length, expert gate_logits = nn.with_logical_constraint(gate_logits, ("activation_batch", "activation_norm_length", None)) if self.config.model_name.startswith("deepseek3"): diff --git a/src/MaxText/layers/quantizations.py b/src/MaxText/layers/quantizations.py index d0f9353b6c..22095fa960 100644 --- a/src/MaxText/layers/quantizations.py +++ b/src/MaxText/layers/quantizations.py @@ -296,6 +296,9 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()): """Returns dot_general configured with aqt params.""" return nn.NANOOFp8DotGeneralOp + def einsum(self, dtype: DType = jnp.float32): + return Fp8Einsum(dtype=dtype,e4m3_dtype=jnp.float8_e4m3fnuz,e5m2_dtype=jnp.float8_e5m2fnuz) + def _get_int8_quant_config(config): drhs_bits = None diff --git a/src/MaxText/train.py b/src/MaxText/train.py index fb5fcebdbf..998c5a4e92 100644 --- a/src/MaxText/train.py +++ b/src/MaxText/train.py @@ -438,6 +438,7 @@ def train_loop(config, recorder, state=None): if config.shard_optimizer_over_data: state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) state, metrics = p_train_step(state, example_batch, nextrng) + jax.block_until_ready(state) step_time_delta = datetime.datetime.now() - last_step_completion last_step_completion = datetime.datetime.now() diff --git a/src/MaxText/train_utils.py b/src/MaxText/train_utils.py index ce465a56e0..84bee76ccd 100644 --- a/src/MaxText/train_utils.py +++ b/src/MaxText/train_utils.py @@ -73,6 +73,7 @@ def create_training_tools(config, model, mesh): logger, use_ocdbt, use_zarr3, + config.max_to_keep, ) return init_rng, checkpoint_manager, learning_rate_schedule, tx From 55126816a978f8d6168b5b23f9f5b243838b5034 Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Mon, 1 Dec 2025 18:11:54 -0500 Subject: [PATCH 2/9] Add logging --- src/MaxText/max_utils.py | 31 +++++++++++++++++++++++++++++++ src/MaxText/train.py | 1 + 2 files changed, 32 insertions(+) diff --git a/src/MaxText/max_utils.py b/src/MaxText/max_utils.py index 1be2b48c8e..155d654413 100644 --- a/src/MaxText/max_utils.py +++ b/src/MaxText/max_utils.py @@ -19,6 +19,7 @@ from collections.abc import Sequence import functools from functools import partial +import json import os import socket import subprocess @@ -708,6 +709,36 @@ def print_system_information(): max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}") max_logging.log(f"System Information: Jax Backend: {jax.extend.backend.get_backend().platform_version}") + devices = jax.devices() + max_logging.log(f"System Information: Number of devices: {len(devices)}, jax path {jax.__file__}") + for i, device in enumerate(devices): + if device.local_hardware_id is not None: + max_logging.log( + f"System Information: Device {i}: {device.id} " + f"(Local id: {device.local_hardware_id}, Process index: {device.process_index})" + ) + + +def save_device_information(config): + """Convert device information to JSON format.""" + devices = jax.devices() + device_info = {'hostname': socket.gethostname(), 'devices': []} + + for device in devices: + if device.local_hardware_id is not None: + info = { + "id": device.id, + "local_hardware_id": device.local_hardware_id, + "process_index": device.process_index, + "device_kind": device.device_kind, + "platform_version": jax.extend.backend.get_backend().platform_version, + } + device_info['devices'].append(info) + # Save to JSON file + device_info_path = os.path.join(config.base_output_directory, "device_info.json") + with open(device_info_path, "w") as f: + json.dump(device_info, f, indent=4) + def permute_to_match_maxtext_rope(arr): """Permutes the Huggingface Rope to match the MaxText logic.""" diff --git a/src/MaxText/train.py b/src/MaxText/train.py index 998c5a4e92..605cbbdbbf 100644 --- a/src/MaxText/train.py +++ b/src/MaxText/train.py @@ -515,6 +515,7 @@ def initialize(argv: Sequence[str]) -> tuple[pyconfig.HyperParameters, Any, Any] config = pyconfig.initialize(argv) max_utils.print_system_information() validate_train_config(config) + max_utils.save_device_information(config) jax.config.update("jax_use_shardy_partitioner", config.shardy) # update explicit sharding-supported config if config.shard_mode == ShardMode.EXPLICIT: From 64a960317b29e199bc8ef9fcd864fc17730cd29f Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Tue, 2 Dec 2025 18:40:54 -0500 Subject: [PATCH 3/9] Update --- src/MaxText/configs/base.yml | 6 ++---- src/MaxText/configs/types.py | 1 + src/MaxText/layers/attention_op.py | 2 +- src/MaxText/layers/attentions.py | 3 +-- src/MaxText/train_utils.py | 2 +- 5 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index 56b6b8f2e2..1ab1605853 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -176,6 +176,7 @@ megablox: True sparse_matmul: True capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default load_balance_loss_weight: 0.01 # weight for the load balance loss +expert_balance: False # whether or not to do expert balancing use_random_routing: False # whether to use random routing for debug/test purpose use_custom_sort_vjp: True # whether to use a custom sort vjp for sparse matmul ops use_ring_of_experts: False # whether to use ring of experts for sparse matmul expert parallelism @@ -957,7 +958,4 @@ partial_rotary_factor: 1.0 # Use tokamax library for gmm kernel implementation use_tokamax_gmm: false -use_tokamax_splash: false - -expert_balance: False -max_to_keep: 5 +use_tokamax_splash: false \ No newline at end of file diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py index 4f24065341..b81699e5f0 100644 --- a/src/MaxText/configs/types.py +++ b/src/MaxText/configs/types.py @@ -525,6 +525,7 @@ class MoEGeneral(BaseModel): num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.") capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.") load_balance_loss_weight: NonNegativeFloat = Field(0.01, description="Weight for the load balancing auxiliary loss.") + expert_balance: bool = Field(False, description="Whether to use expert balancing.") use_custom_sort_vjp: bool = Field(True, description="Whether to use a custom sort VJP for sparse matmul ops.") use_ring_of_experts: bool = Field( False, description="Whether to use Ring of Experts for sparse matmul expert parallelism." diff --git a/src/MaxText/layers/attention_op.py b/src/MaxText/layers/attention_op.py index a9c172c79e..77f45b1cb9 100644 --- a/src/MaxText/layers/attention_op.py +++ b/src/MaxText/layers/attention_op.py @@ -1396,7 +1396,7 @@ def cudnn_flash_attention( dtype=self.dtype, float32_logits=self.float32_logits, qkv_layout=qkv_layout, - # scale_factor=1.0, + scale_factor=1.0, transpose_batch_sequence=False, window_size=sliding_window_size, context_parallel_causal_load_balanced=self.config.context_parallel_load_balance, diff --git a/src/MaxText/layers/attentions.py b/src/MaxText/layers/attentions.py index ad61f89532..0586953de0 100644 --- a/src/MaxText/layers/attentions.py +++ b/src/MaxText/layers/attentions.py @@ -548,8 +548,7 @@ def init_query_w(self, inputs_q_shape: Tuple) -> nnx.Module: if self.config.use_qk_norm or (self.query_pre_attn_scalar is not None and self.query_pre_attn_scalar != 1.0): depth_scaling = 1.0 else: - # depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) - depth_scaling = 1.0 + depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) def query_init(*args): # pylint: disable=no-value-for-parameter diff --git a/src/MaxText/train_utils.py b/src/MaxText/train_utils.py index 84bee76ccd..ad17cfb971 100644 --- a/src/MaxText/train_utils.py +++ b/src/MaxText/train_utils.py @@ -73,7 +73,7 @@ def create_training_tools(config, model, mesh): logger, use_ocdbt, use_zarr3, - config.max_to_keep, + config.max_num_checkpoints_to_keep, ) return init_rng, checkpoint_manager, learning_rate_schedule, tx From 6183ee187b4112c1e2781c742b9a12b0510dc2fc Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Tue, 2 Dec 2025 21:52:36 -0500 Subject: [PATCH 4/9] remove context_parallel_strategy --- src/MaxText/layers/attention_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MaxText/layers/attention_op.py b/src/MaxText/layers/attention_op.py index 77f45b1cb9..f0da0a4e7b 100644 --- a/src/MaxText/layers/attention_op.py +++ b/src/MaxText/layers/attention_op.py @@ -1401,7 +1401,7 @@ def cudnn_flash_attention( window_size=sliding_window_size, context_parallel_causal_load_balanced=self.config.context_parallel_load_balance, context_parallel_axis="context", - context_parallel_strategy=self.config.context_parallel_strategy, + # context_parallel_strategy=self.config.context_parallel_strategy, max_segments_per_seq=max_segments_per_seq, ) From 6f2d8a0251b158a7cdcd281900f983704c218f06 Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Wed, 3 Dec 2025 12:23:39 -0500 Subject: [PATCH 5/9] update using causal mask for synthetic --- src/MaxText/layers/attention_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MaxText/layers/attention_op.py b/src/MaxText/layers/attention_op.py index f0da0a4e7b..dee93f0ef8 100644 --- a/src/MaxText/layers/attention_op.py +++ b/src/MaxText/layers/attention_op.py @@ -1372,7 +1372,7 @@ def cudnn_flash_attention( dummy_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32) dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos(segment_ids=dummy_segment_ids, segment_pos=None) max_segments_per_seq = self.config.max_segments_per_seq - elif using_context_parallelism and self.config.dataset_type != "synthetic": + elif using_context_parallelism or self.config.dataset_type == "synthetic": if self.attention_type == AttentionType.LOCAL_SLIDING: raise AssertionError("Sliding window attention is not supported for context parallelism") # Context parallelism without packing: only supports causal masking From c236a999f9eb95e62769ccb088c69267a087417e Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Fri, 30 Jan 2026 18:15:18 -0500 Subject: [PATCH 6/9] Add missing query_pre_attn_scalar --- src/MaxText/layers/gemma.py | 1 + src/MaxText/layers/gemma2.py | 2 ++ src/MaxText/layers/llama2.py | 1 + src/MaxText/layers/mistral.py | 1 + src/MaxText/layers/mixtral.py | 1 + 5 files changed, 6 insertions(+) diff --git a/src/MaxText/layers/gemma.py b/src/MaxText/layers/gemma.py index dcd237162c..90f2e1366d 100644 --- a/src/MaxText/layers/gemma.py +++ b/src/MaxText/layers/gemma.py @@ -88,6 +88,7 @@ def __init__( kv_quant=quantizations.configure_kv_quant(config), use_ragged_attention=config.use_ragged_attention, ragged_block_size=config.ragged_block_size, + query_pre_attn_scalar=(config.head_dim**-0.5), model_mode=self.model_mode, rngs=self.rngs, ) diff --git a/src/MaxText/layers/gemma2.py b/src/MaxText/layers/gemma2.py index 3d0d39efeb..ef8e59c2da 100644 --- a/src/MaxText/layers/gemma2.py +++ b/src/MaxText/layers/gemma2.py @@ -90,6 +90,7 @@ def __init__( attention_type=attentions.AttentionType.LOCAL_SLIDING, sliding_window_size=config.sliding_window_size, attn_logits_soft_cap=config.attn_logits_soft_cap, + query_pre_attn_scalar=(config.head_dim**-0.5), model_mode=self.model_mode, rngs=self.rngs, ) @@ -164,6 +165,7 @@ def __init__( kv_quant=quantizations.configure_kv_quant(config), attention_type=attentions.AttentionType.GLOBAL, attn_logits_soft_cap=config.attn_logits_soft_cap, + query_pre_attn_scalar=(config.head_dim**-0.5), model_mode=model_mode, rngs=self.rngs, ) diff --git a/src/MaxText/layers/llama2.py b/src/MaxText/layers/llama2.py index 7148b1a5b9..d74b081b69 100644 --- a/src/MaxText/layers/llama2.py +++ b/src/MaxText/layers/llama2.py @@ -101,6 +101,7 @@ def __init__( reshape_q=config.reshape_q, use_ragged_attention=config.use_ragged_attention, ragged_block_size=config.ragged_block_size, + query_pre_attn_scalar=(config.head_dim**-0.5), model_mode=model_mode, rngs=rngs, ) diff --git a/src/MaxText/layers/mistral.py b/src/MaxText/layers/mistral.py index 643fecaaeb..5dd8d8ab31 100644 --- a/src/MaxText/layers/mistral.py +++ b/src/MaxText/layers/mistral.py @@ -91,6 +91,7 @@ def __init__( reshape_q=config.reshape_q, use_ragged_attention=config.use_ragged_attention, ragged_block_size=config.ragged_block_size, + query_pre_attn_scalar=(config.head_dim**-0.5), model_mode=model_mode, rngs=self.rngs, ) diff --git a/src/MaxText/layers/mixtral.py b/src/MaxText/layers/mixtral.py index 8d23e72d3f..90c69115d7 100644 --- a/src/MaxText/layers/mixtral.py +++ b/src/MaxText/layers/mixtral.py @@ -96,6 +96,7 @@ def __init__( reshape_q=config.reshape_q, use_ragged_attention=config.use_ragged_attention, ragged_block_size=config.ragged_block_size, + query_pre_attn_scalar=(config.head_dim**-0.5), model_mode=model_mode, rngs=self.rngs, ) From d0a9914f9d16edf77ad045086528cae8c3bf37d6 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Wed, 28 Jan 2026 06:50:58 +0000 Subject: [PATCH 7/9] fix "No ArrayMetadata found" error when enable_single_replica_ckpt_restoring=true --- src/MaxText/checkpointing.py | 49 ++++++++++++++++++++++++++++++------ 1 file changed, 42 insertions(+), 7 deletions(-) diff --git a/src/MaxText/checkpointing.py b/src/MaxText/checkpointing.py index 6448d4f457..609f17086c 100644 --- a/src/MaxText/checkpointing.py +++ b/src/MaxText/checkpointing.py @@ -502,8 +502,21 @@ def load_state_if_possible( if step is not None: max_logging.log(f"restoring from this run's directory step {step}") + # Check whether the mesh actually has multiple replicas along axis 0. + # If all devices are in a single replica, SingleReplicaArrayHandler + # raises InvalidShardingError — fall back to normal restore. + _sr_effective = enable_single_replica_ckpt_restoring + if _sr_effective: + _first_leaf = jax.tree_util.tree_leaves(abstract_unboxed_pre_state)[0] + if _first_leaf.sharding.mesh.devices.shape[0] <= 1: + max_logging.log( + "enable_single_replica_ckpt_restoring=True but mesh has only 1 " + f"replica (shape[0]={_first_leaf.sharding.mesh.devices.shape[0]}). " + "Falling back to normal all-replica restore.") + _sr_effective = False + def map_to_pspec(data): - if not enable_single_replica_ckpt_restoring: + if not _sr_effective: return ocp.type_handlers.ArrayRestoreArgs(sharding=data.sharding) pspec = data.sharding.spec mesh = data.sharding.mesh @@ -519,16 +532,32 @@ def map_to_pspec(data): dtype=data.dtype, ) - if enable_single_replica_ckpt_restoring: - array_handler = ocp.type_handlers.SingleReplicaArrayHandler( + # Cache the original ArrayHandler before potentially overriding it. + original_array_handler = ocp.type_handlers.get_type_handler(jax.Array) + + if _sr_effective: + single_replica_handler = ocp.type_handlers.SingleReplicaArrayHandler( replica_axis_index=0, broadcast_memory_limit_bytes=1024 * 1024 * 1000, # 1000 MB limit ) - ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True) + ocp.type_handlers.register_type_handler(jax.Array, single_replica_handler, override=True) restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state) checkpoint_args = ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args) + def _restore_original_array_handler(): + """Restore the original ArrayHandler after SingleReplicaArrayHandler restore. + + This is critical because SingleReplicaArrayHandler is designed for restore only. + Using it for saves will cause missing array_metadatas files and checkpoint failures. + We restore the EXACT handler that was in place before, not a new instance. + """ + if _sr_effective: + max_logging.log("Restoring original ArrayHandler after SingleReplicaArrayHandler restore...") + # Re-register the original handler that was cached before the override + ocp.type_handlers.register_type_handler(jax.Array, original_array_handler, override=True) + max_logging.log("Original ArrayHandler restored successfully.") + match (checkpoint_manager, dataset_type, data_iterator): # Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager # or EmergencyReplicatorCheckpointManager. The '_' indicates that 'dataset_type' and @@ -536,10 +565,12 @@ def map_to_pspec(data): case (checkpoint_manager, _, _) if isinstance( checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager) ): - return ( + result = ( checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state, None, ) + _restore_original_array_handler() + return result # Case 2: Matches if dataset type is "grain" and the data iterator is not a # PlaceHolderDataIterator and a specific checkpoint file exists for the iterator case ( @@ -552,13 +583,17 @@ def map_to_pspec(data): and not isinstance(data_iterator, PlaceHolderDataIterator) and (checkpoint_manager.directory / str(step) / "iter").exists() ): - return _restore_grain_iterator( + result = _restore_grain_iterator( checkpoint_manager, step, data_iterator, checkpoint_args, expansion_factor_real_data ) + _restore_original_array_handler() + return result # Case 3: Default/Fallback case. # This case acts as a wildcard ('_') and matches if none of the preceding cases were met. case _: - return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None) + result = (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None) + _restore_original_array_handler() + return result if load_parameters_from_path != "": restored_params = load_params_from_path( From 116fc03e4241ffec5de5f406eae18711ed1038d8 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Fri, 30 Jan 2026 11:02:22 +0000 Subject: [PATCH 8/9] fix "stop sending heartbeats" error by jax_distributed_heartbeat_timeout_seconds --- src/MaxText/configs/base.yml | 1 + src/MaxText/configs/types.py | 3 +++ src/MaxText/max_utils.py | 13 +++++++++++-- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index 1ab1605853..89e0d823ca 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -617,6 +617,7 @@ log_period: 100 # Flushes Tensorboard jax_distributed_initialization_timeout: 300 # This is the default timeout in https://github.com/jax-ml/jax/blob/main/jax/_src/distributed.py # Note there are two separate initializations - the jax coordination service (aka jax.distributed.initialize) and the backend (e.g. PjRT), the timeout above refers # only to the jax coordination service. +jax_distributed_heartbeat_timeout_seconds: 100 # How long before a missing heartbeat marks a task as dead. Increase for slow NFS checkpoint restores. jax_debug_log_modules: "" # Set this to "jax" to enable jax verbose logging such as for the jax coordination service initialization. skip_jax_distributed_system: False # If True we will not initialize the jax distributed system. # Currently the jax distributed is needed on cloud TPUs for async checkpointing. diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py index b81699e5f0..9e00e13e84 100644 --- a/src/MaxText/configs/types.py +++ b/src/MaxText/configs/types.py @@ -1093,6 +1093,9 @@ class DevelopmentAndDebugging(BaseModel): description="Directory for JAX compilation cache.", ) jax_distributed_initialization_timeout: int = Field(300, description="Timeout for jax.distributed.initialize.") + jax_distributed_heartbeat_timeout_seconds: int = Field( + 100, description="How long before a missing heartbeat marks a task as dead. Increase for slow NFS checkpoint restores." + ) jax_debug_log_modules: str = Field("", description="Set to 'jax' for verbose JAX logging.") skip_jax_distributed_system: bool = Field(False, description="If True, do not initialize the jax distributed system.") enable_single_controller: bool = Field(False, description="Enable single-controller mode (Pathways).") diff --git a/src/MaxText/max_utils.py b/src/MaxText/max_utils.py index 155d654413..6941401ae5 100644 --- a/src/MaxText/max_utils.py +++ b/src/MaxText/max_utils.py @@ -201,11 +201,17 @@ def maybe_initialize_jax_distributed_system(raw_keys): ] == "gpu_multiprocess": max_logging.log("Attempting to initialize the jax distributed system...") if not raw_keys["enable_emergency_checkpoint"]: - jax.distributed.initialize(initialization_timeout=raw_keys["jax_distributed_initialization_timeout"]) + jax.distributed.initialize( + initialization_timeout=raw_keys["jax_distributed_initialization_timeout"], + heartbeat_timeout_seconds=raw_keys["jax_distributed_heartbeat_timeout_seconds"], + ) else: if raw_keys["hardware"] == "gpu_multiprocess": max_logging.log("Initializing jax distribtued to support local checkpointing with" " GPUs...") - jax.distributed.initialize(initialization_timeout=raw_keys["jax_distributed_initialization_timeout"]) + jax.distributed.initialize( + initialization_timeout=raw_keys["jax_distributed_initialization_timeout"], + heartbeat_timeout_seconds=raw_keys["jax_distributed_heartbeat_timeout_seconds"], + ) ocp.multihost.initialize_runtime_to_distributed_ids() ocp.multihost.initialize_distributed_to_device_ids() else: @@ -223,6 +229,7 @@ def initialize_jax_for_gpu(raw_keys): num_processes=int(os.getenv("NNODES")), process_id=int(os.getenv("NODE_RANK")), initialization_timeout=raw_keys["jax_distributed_initialization_timeout"], + heartbeat_timeout_seconds=raw_keys["jax_distributed_heartbeat_timeout_seconds"], ) max_logging.log(f"JAX global devices: {jax.devices()}") @@ -243,6 +250,7 @@ def initialize_jax_for_cpu(raw_keys): process_id=pid, num_processes=int(os.environ.get("JAX_PROCESS_COUNT")), initialization_timeout=raw_keys["jax_distributed_initialization_timeout"], + heartbeat_timeout_seconds=raw_keys["jax_distributed_heartbeat_timeout_seconds"], ) @@ -263,6 +271,7 @@ def initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys): coordinator_address=coordinator_address, process_id=int(process_id), initialization_timeout=raw_keys["jax_distributed_initialization_timeout"], + heartbeat_timeout_seconds=raw_keys["jax_distributed_heartbeat_timeout_seconds"], ) ocp.multihost.initialize_runtime_to_distributed_ids() From 851c0935a4e9c9e4cd0771603e3bdf75a6578ad1 Mon Sep 17 00:00:00 2001 From: Yi Huang Date: Mon, 2 Feb 2026 21:56:13 +0000 Subject: [PATCH 9/9] synchronize hosts before training loop --- src/MaxText/train.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/MaxText/train.py b/src/MaxText/train.py index 605cbbdbbf..553a75e89f 100644 --- a/src/MaxText/train.py +++ b/src/MaxText/train.py @@ -418,6 +418,16 @@ def train_loop(config, recorder, state=None): # Write train config params, num model params, and XLA flags to tensorboard metric_logger.write_setup_info_to_tensorboard(state.params) + # Synchronize all hosts before entering the training loop. + # Without this barrier, timing variance during initialization (JIT compilation, + # profiler/logger setup, etc.) causes hosts to enter the training loop at different + # times. The first collective operation (data sharding in load_next_batch) then + # times out waiting for straggler hosts, resulting in "collective operation timeout" + # or "stop sending heartbeats" errors. + max_logging.log("====== BARRIER: Synchronizing hosts before training loop ======") + jax.experimental.multihost_utils.sync_global_devices("sync_before_training_loop") + max_logging.log("====== BARRIER PASSED: Starting training loop ======") + try: last_step_completion = datetime.datetime.now() for step in np.arange(start_step, config.steps):