Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions src/MaxText/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,27 +547,47 @@ def map_to_pspec(data):
dtype=data.dtype,
)

# Cache the original ArrayHandler before potentially overriding it.
# This is the same handler used when enable_single_replica_ckpt_restoring=False.
original_array_handler = ocp.type_handlers.get_type_handler(jax.Array)

# Register SingleReplicaArrayHandler globally for restore (if enabled)
if enable_single_replica_ckpt_restoring:
array_handler = ocp.type_handlers.SingleReplicaArrayHandler(
single_replica_handler = ocp.type_handlers.SingleReplicaArrayHandler(
replica_axis_index=0,
broadcast_memory_limit_bytes=1024 * 1024 * 1000, # 1000 MB limit
)
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)
ocp.type_handlers.register_type_handler(jax.Array, single_replica_handler, override=True)

restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state)
checkpoint_args = ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args)

def _restore_original_array_handler():
  """Re-register the jax.Array type handler that was active before the restore.

  No-op unless `enable_single_replica_ckpt_restoring` is truthy. When it is,
  the cached `original_array_handler` object itself (not a freshly built
  instance) is registered again for `jax.Array` with `override=True`,
  replacing the SingleReplicaArrayHandler installed for the restore.

  Rationale carried over from the original note: SingleReplicaArrayHandler is
  meant for restore only; leaving it registered for saves is reported to
  produce missing array_metadatas files and checkpoint failures.
  """
  # Guard clause: nothing was overridden, so there is nothing to put back.
  if not enable_single_replica_ckpt_restoring:
    return

  max_logging.log("Restoring original ArrayHandler after SingleReplicaArrayHandler restore...")
  # Swap the pre-override handler back in for jax.Array.
  ocp.type_handlers.register_type_handler(
      jax.Array,
      original_array_handler,
      override=True,
  )
  max_logging.log("Original ArrayHandler restored successfully.")

match (checkpoint_manager, dataset_type, data_iterator):
# Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager
# or EmergencyReplicatorCheckpointManager. The '_' indicates that 'dataset_type' and
# 'data_iterator' can be any value and aren't used in this pattern.
case (checkpoint_manager, _, _) if isinstance(
checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager)
):
return (
result = (
checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state,
None,
)
_restore_original_array_handler()
return result
# Case 2: Matches if dataset type is "grain" and the data iterator is not a
# PlaceHolderDataIterator or RemoteIterator and a specific checkpoint file exists for the iterator
case (
Expand All @@ -581,13 +601,17 @@ def map_to_pspec(data):
and not _is_remote_iterator(data_iterator)
and (checkpoint_manager.directory / str(step) / "iter").exists()
):
return _restore_grain_iterator(
result = _restore_grain_iterator(
checkpoint_manager, step, data_iterator, checkpoint_args, expansion_factor_real_data
)
_restore_original_array_handler()
return result
# Case 3: Default/Fallback case.
# This case acts as a wildcard ('_') and matches if none of the preceding cases were met.
case _:
return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None)
result = (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None)
_restore_original_array_handler()
return result

if load_parameters_from_path != "":
restored_params = load_params_from_path(
Expand Down
1 change: 1 addition & 0 deletions src/MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ megablox: true
sparse_matmul: true
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
load_balance_loss_weight: 0.0 # weight for the load balance loss
expert_balance: False # whether dense_matmul overrides the router gate logits with a deterministic rotating pattern (hard-coded uniform expert load)
use_random_routing: false # whether to use random routing for debug/test purpose
use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism
Expand Down
1 change: 1 addition & 0 deletions src/MaxText/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,7 @@ class MoEGeneral(BaseModel):
num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
expert_balance: bool = Field(False, description="Whether to use expert balancing.")
use_custom_sort_vjp: bool = Field(
True, description="Whether to use a custom VJP sort for efficient backward pass processing in sparse matmul."
)
Expand Down
7 changes: 5 additions & 2 deletions src/MaxText/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1399,6 +1399,10 @@ def cudnn_flash_attention(
attn_mask = None
dummy_attn_mask = None
mask_type = "causal"
elif self.config.dataset_type == "synthetic":
attn_mask = None
dummy_attn_mask = None
mask_type = "causal"
else:
# Default case: no packing, no context parallelism
dummy_attn_mask = jnp.zeros((1, 1, 1, self.max_target_length, self.max_target_length), dtype=jnp.uint8)
Expand All @@ -1416,12 +1420,11 @@ def cudnn_flash_attention(
dtype=self.dtype,
float32_logits=self.float32_logits,
qkv_layout=qkv_layout,
scale_factor=1.0,
transpose_batch_sequence=False,
window_size=sliding_window_size,
context_parallel_causal_load_balanced=self.config.context_parallel_load_balance,
context_parallel_axis="context",
context_parallel_strategy=self.config.context_parallel_strategy,
# context_parallel_strategy=self.config.context_parallel_strategy,
max_segments_per_seq=max_segments_per_seq,
)

Expand Down
2 changes: 1 addition & 1 deletion src/MaxText/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def init_query_w(self, inputs_q_shape: Tuple) -> nnx.Module:
if self.config.use_qk_norm or (self.query_pre_attn_scalar is not None and self.query_pre_attn_scalar != 1.0):
depth_scaling = 1.0
else:
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
depth_scaling = 1.0

def query_init(*args):
# pylint: disable=no-value-for-parameter
Expand Down
24 changes: 23 additions & 1 deletion src/MaxText/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,7 +1583,7 @@ def get_einsum(
def aqt_einsum(*args, **kwargs): # pylint: disable=unused-argument
# simply skip kwargs, since aqt einsum doesn't support any kwargs
# like precision
is_aqt = not isinstance(self.quant, quantizations.Fp8Quantization)
is_aqt = not ( isinstance(self.quant, quantizations.Fp8Quantization) or isinstance(self.quant, quantizations.NANOOFp8Quantization) )
kw = {"mesh_axes": rhs_mesh_axes} if is_aqt else {"dtype": self.dtype}
return self.quant.einsum(**kw)(*args) # pytype: disable=attribute-error

Expand Down Expand Up @@ -1618,6 +1618,28 @@ def dense_matmul(
wo_bias,
) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]:
"""Dense matrix multiplication."""
if self.config.expert_balance:
######################################################################################################
############################## start hard code for uniform expert ####################################
# Create deterministic rotational pattern for gate logits
batch_size, seq_len, num_experts = gate_logits.shape

# Create base weights for experts (increasing values)
base_weights = jnp.linspace(0.1, 0.1 * num_experts, num_experts, dtype=gate_logits.dtype)

# Create position-based indices matrix [seq_len, num_experts]
# Each row represents which index in base_weights to use after rotation
indices = (jnp.arange(num_experts)[None, :] + jnp.arange(seq_len)[:, None]) % num_experts

# Use advanced indexing to create the rotated weights matrix in one operation
# This takes the appropriate weight for each position based on the rotation pattern
rotated_weights = base_weights[indices]

# Broadcast to batch dimension
gate_logits = jnp.broadcast_to(rotated_weights[None, :, :], (batch_size, seq_len, num_experts))
############################################# end ####################################################
######################################################################################################

# gate_logits: batch, length, expert
gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch", "activation_norm_length", None))
if self.config.model_name.startswith("deepseek3"):
Expand Down
2 changes: 2 additions & 0 deletions src/MaxText/layers/quantizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
"""Returns dot_general configured with aqt params."""
return nn.NANOOFp8DotGeneralOp

def einsum(self, dtype: DType = jnp.float32):
  """Return an Fp8Einsum configured with the fnuz float8 dtypes.

  Args:
    dtype: Forwarded unchanged as the `dtype` argument of `Fp8Einsum`.

  Returns:
    The `Fp8Einsum` built with `e4m3_dtype=jnp.float8_e4m3fnuz` and
    `e5m2_dtype=jnp.float8_e5m2fnuz`.
  """
  # Both float8 formats use the "fnuz" variants here.
  fnuz_dtypes = {
      "e4m3_dtype": jnp.float8_e4m3fnuz,
      "e5m2_dtype": jnp.float8_e5m2fnuz,
  }
  return Fp8Einsum(dtype=dtype, **fnuz_dtypes)

def _get_int8_quant_config(config):
drhs_bits = None
Expand Down
31 changes: 31 additions & 0 deletions src/MaxText/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections.abc import Sequence
import functools
from functools import partial
import json
import os
import socket
import subprocess
Expand Down Expand Up @@ -705,6 +706,36 @@ def print_system_information():
max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}")
max_logging.log(f"System Information: Jax Backend: {jax.extend.backend.get_backend().platform_version}")

devices = jax.devices()
max_logging.log(f"System Information: Number of devices: {len(devices)}, jax path {jax.__file__}")
for i, device in enumerate(devices):
if device.local_hardware_id is not None:
max_logging.log(
f"System Information: Device {i}: {device.id} "
f"(Local id: {device.local_hardware_id}, Process index: {device.process_index})"
)


def save_device_information(config):
"""Convert device information to JSON format."""
devices = jax.devices()
device_info = {'hostname': socket.gethostname(), 'devices': []}

for device in devices:
if device.local_hardware_id is not None:
info = {
"id": device.id,
"local_hardware_id": device.local_hardware_id,
"process_index": device.process_index,
"device_kind": device.device_kind,
"platform_version": jax.extend.backend.get_backend().platform_version,
}
device_info['devices'].append(info)
# Save to JSON file
device_info_path = os.path.join(config.base_output_directory, "device_info.json")
with open(device_info_path, "w") as f:
json.dump(device_info, f, indent=4)


def permute_to_match_maxtext_rope(arr):
"""Permutes the Huggingface Rope to match the MaxText logic."""
Expand Down
2 changes: 2 additions & 0 deletions src/MaxText/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ def train_loop(config, recorder, state=None):
if config.shard_optimizer_over_data:
state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)
state, metrics = p_train_step(state, example_batch, nextrng)
jax.block_until_ready(state)

step_time_delta = datetime.datetime.now() - last_step_completion
last_step_completion = datetime.datetime.now()
Expand Down Expand Up @@ -529,6 +530,7 @@ def initialize(argv: Sequence[str]) -> tuple[pyconfig.HyperParameters, Any, Any]
config = pyconfig.initialize(argv)
max_utils.print_system_information()
validate_train_config(config)
max_utils.save_device_information(config)
jax.config.update("jax_use_shardy_partitioner", config.shardy)
# update explicit sharding-supported config
if config.shard_mode == ShardMode.EXPLICIT:
Expand Down