From 1630b801ef4c702cc5ec6aefbd0c1f443002899e Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 24 Apr 2026 22:05:58 +0000 Subject: [PATCH 1/4] NNX: add TrainState, model creation utilities, and training loop support - Add TrainStateNNX (layers/train_state_nnx.py) with checkpoint and unit tests - Refactor model_creation_utils with create_nnx_abstract_model(); add NNX support to muon_utils - Add get_abstract_state_nnx() and get_nnx_named_sharding_with_scan_axis() to maxtext_utils.py - Wire NNX train state into train.py and train_utils.py with pure_nnx dispatch --- src/maxtext/common/checkpointing.py | 36 +- src/maxtext/layers/nnx_decoders.py | 8 + src/maxtext/trainers/pre_train/train.py | 503 +++++++++++------- src/maxtext/utils/gradient_accumulation.py | 35 +- src/maxtext/utils/maxtext_utils.py | 236 ++++++-- src/maxtext/utils/model_creation_utils.py | 172 +++--- src/maxtext/utils/muon_utils.py | 60 ++- src/maxtext/utils/sharding.py | 121 ++++- src/maxtext/utils/train_utils.py | 54 +- .../integration/setup_train_loop_nnx_test.py | 140 +++++ tests/unit/checkpointing_nnx_load_test.py | 106 ++++ tests/unit/gradient_accumulation_nnx_test.py | 159 ++++++ tests/unit/maxtext_utils_test.py | 263 ++++++++- tests/unit/muon_utils_test.py | 224 ++++++++ tests/unit/nnx_decoders_test.py | 89 +++- tests/unit/optimizers_test.py | 116 +++- tests/unit/sharding_nnx_test.py | 161 ++++++ tests/unit/train_nnx_test.py | 239 +++++++++ tests/unit/train_state_nnx_checkpoint_test.py | 399 ++++++++++++++ tests/unit/train_state_nnx_test.py | 90 ++++ tests/unit/train_utils_nnx_test.py | 149 ++++++ 21 files changed, 2984 insertions(+), 376 deletions(-) create mode 100644 tests/integration/setup_train_loop_nnx_test.py create mode 100644 tests/unit/checkpointing_nnx_load_test.py create mode 100644 tests/unit/gradient_accumulation_nnx_test.py create mode 100644 tests/unit/muon_utils_test.py create mode 100644 tests/unit/sharding_nnx_test.py create mode 100644 tests/unit/train_nnx_test.py create mode 100644 tests/unit/train_state_nnx_checkpoint_test.py create mode 100644 tests/unit/train_state_nnx_test.py create mode 100644 tests/unit/train_utils_nnx_test.py diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index 75535fae29..e67329ecbd 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -20,6 +20,7 @@ from absl import flags import datetime from etils import epath +from flax import nnx from flax.training import train_state import jax from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE @@ -536,7 +537,7 @@ def load_state_if_possible( load_parameters_from_path: str, load_full_state_from_path: str, checkpoint_storage_concurrent_gb: int, - abstract_unboxed_pre_state: train_state.TrainState, + abstract_unboxed_pre_state: train_state.TrainState | nnx.State, enable_single_replica_ckpt_restoring: bool | None = False, dataset_type: str | None = "tfds", step: int = -1, # -1 means latest @@ -604,9 +605,14 @@ def map_to_pspec(data): ) ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True) - restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state) + # Convert nnx.State to pure dict to match how checkpoints are saved for NNX + restore_target = abstract_unboxed_pre_state + if isinstance(abstract_unboxed_pre_state, nnx.State): + restore_target = abstract_unboxed_pre_state.to_pure_dict() + + restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target) checkpoint_args = ocp.args.PyTreeRestore( - item=abstract_unboxed_pre_state, + item=restore_target, restore_args=restore_args, partial_restore=True, ) @@ -620,9 +626,7 @@ def map_to_pspec(data): (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager), ): return ( - checkpoint_manager.restore( - step, args=Composite(state=checkpoint_args) - ).state, + checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state, None, ) # Case 2: Matches if dataset type is "grain" and the data iterator is not a @@ -647,9 +651,14 @@ def map_to_pspec(data): return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None) if load_parameters_from_path != "": + if isinstance(abstract_unboxed_pre_state, nnx.State): + _, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...) + else: + params = abstract_unboxed_pre_state.params + restored_params = load_params_from_path( load_parameters_from_path, - abstract_unboxed_pre_state.params, + params, checkpoint_storage_concurrent_gb, use_ocdbt=use_ocdbt, use_zarr3=use_zarr3, @@ -741,7 +750,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step # Determine the effective step for saving a checkpoint. # If 'step' is not provided, this call is for a potential final checkpoint # and use the last completed step from the state. - actual_step = (int(state.step) - 1) if step is None else int(step) + if step is not None: + actual_step = int(step) + else: + if config.pure_nnx: + actual_step = int(state.optimizer.step) - 1 + else: + # Linen TrainState has .step attribute + actual_step = int(state.step) - 1 + + if config.pure_nnx: + # Convert nnx.State to dict. + state = state.to_pure_dict() # Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic. # This occurs if this function was called: diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 23767ae741..42cf44cab0 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -35,6 +35,7 @@ MODEL_MODE_TRAIN, Config, DecoderBlockType, + MultimodalInput, ShardMode, ) from maxtext.inference import page_manager @@ -963,7 +964,14 @@ def __call__( audio_embeddings: None | jnp.ndarray = None, audio_masks: None | jnp.ndarray = None, deepstack_visual_embeds: None | list[jnp.ndarray] = None, + multimodal_input: None | MultimodalInput = None, ): + if multimodal_input is not None: + image_embeddings = multimodal_input.image_embeddings + image_masks = multimodal_input.image_masks + audio_embeddings = multimodal_input.audio_embeddings + audio_masks = multimodal_input.audio_masks + bidirectional_mask = multimodal_input.bidirectional_mask cfg = self.config assert decoder_input_tokens.ndim == 2 # [batch, len] diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 0e94e0c8ba..6038cad7b6 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -35,8 +35,9 @@ import jax import jax.numpy as jnp +from jax.sharding import NamedSharding -from flax import linen as nn +from flax import linen as nn, nnx from flax.linen import partitioning as nn_partitioning from maxtext.configs import pyconfig @@ -67,6 +68,7 @@ from maxtext.utils import maxtext_utils from maxtext.utils import qk_clip_utils from maxtext.utils import sharding +from maxtext.utils import maxtext_utils_nnx from maxtext.utils import train_utils from maxtext.utils.gradient_accumulation import gradient_accumulation_loss_and_grad from maxtext.utils.vocabulary_tiling import vocab_tiling_linen_loss @@ -91,11 +93,11 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr """loss_fn for both train and eval. Args: - model: A nn.Module + model: A nn.Module (Linen) or nnx.Module (NNX). config: Config of parameters data: Batch of data to apply to the model - dropout_rng: A key to use to generate rng for dropout - params: Model params + dropout_rng: A key to use to generate rng for dropout (Linen); unused for NNX. + params: Model params (Linen); unused for NNX (params are part of the model). is_train: True for train_step and False for eval_step Returns: @@ -182,7 +184,7 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr xent_sum = jnp.sum(xent) total_z_loss = jnp.sum(z_loss) else: - # Flax NNX model + # Flax NNX model: logits = model( decoder_input_tokens=data["inputs"], decoder_positions=data["inputs_position"], @@ -193,7 +195,11 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr decoder_target_tokens=data["targets"], decoder_target_mask=data["targets_segmentation"], ) - intermediate_outputs = {} + # Capture NNX intermediates (MoE losses, hidden states, etc.) + intermediate_outputs = nnx.state(model, nnx.Intermediate).to_pure_dict() + + if config.num_vocab_tiling > 1: + raise NotImplementedError("Vocab tiling for NNX modules has not been implemented.") if (config.use_indexer and not config.indexer_sparse_training) and is_train: # In Dense Warm-up stage, we skip main model loss calculation for efficiency. @@ -285,74 +291,111 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr return loss, aux -def train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng): - """ +def train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng=None): + """Training step for both Linen and NNX models. Args: - model: A nn.Module - state: A pytree of the current state of the model - data: Batch of data to apply to the model - dropout_rng: A key to use to generate rng for dropout + model: A nn.Module (Linen) or nnx.GraphDef of the TrainStateNNX (NNX). + config: Hyperparameters. + state_mesh_shardings: PyTree of PartitionSpecs for the train state. + params_shardings: PyTree of PartitionSpecs for model parameters, used for gradient accumulation. + state: Linen TrainState or NNX pure State. + data: Training data batch. + dropout_rng: A key to use to generate rng for dropout (Linen); unused for NNX. Returns: - new_state: Same format as state. + new_state: Updated Linen TrainState or NNX pure State. metrics: Dictionary of model metrics such as loss, training rate, etc. - rng2: A new rng key that can be used in future calls. - """ - reference_params, reference_params_sharding, extra_dpo_args, _loss_fn = ( - [], - [], - [], - loss_fn, - ) - if config.use_dpo: - state, reference_params = _split_dpo_state(state) - state_mesh_shardings, reference_params_sharding = _split_dpo_state(state_mesh_shardings) - extra_dpo_args = [reference_params] - _loss_fn = dpo_loss_fn + # --- Per-path initialization --- + if isinstance(model, nn.Module): + reference_params, reference_params_sharding, extra_dpo_args, _loss_fn = [], [], [], loss_fn + if config.use_dpo: + state, reference_params = _split_dpo_state(state) + state_mesh_shardings, reference_params_sharding = _split_dpo_state(state_mesh_shardings) + extra_dpo_args = [reference_params] + _loss_fn = dpo_loss_fn + params = state.params + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = _loss_fn, model, params, dropout_rng, extra_dpo_args + else: + if config.use_dpo: + raise NotImplementedError("DPO for NNX modules has not been implemented.") + state = nnx.merge(model, state) # reconstruct TrainStateNNX + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, [] - params = state.params + # --- Gradient computation --- if config.gradient_accumulation_steps > 1: loss, aux, raw_grads = gradient_accumulation_loss_and_grad( - _loss_fn, + ga_fn, config, - model, - params, + ga_model, + ga_params, params_shardings, data, - dropout_rng, - extra_dpo_args, + ga_rng, + ga_dpo, ) else: - if config.optimizer_memory_host_offload: - if config.use_dpo: + if isinstance(model, nn.Module): + if config.optimizer_memory_host_offload and config.use_dpo: reference_params = jax.device_put( reference_params, max_utils.with_memory_kind(reference_params_sharding, "device"), ) extra_dpo_args = [reference_params] - if config.shard_optimizer_over_data: - params = jax.tree.map( - functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), - params, - params_shardings, + if config.shard_optimizer_over_data: + params = jax.tree.map( + functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), + params, + params_shardings, + ) + sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m + pure_params = params["params"] if sparsity_enabled else params + batch_stats = params.get("batch_stats", {}) + + grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) + (loss, aux), raw_grads = grad_func( + model, + config, + data, + dropout_rng, + pure_params, + *extra_dpo_args, + sparsity_state=batch_stats, + is_train=True, ) - sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m - pure_params = params["params"] if sparsity_enabled else params - batch_stats = params.get("batch_stats", {}) + else: + model_graphdef, curr_params, rest = nnx.split(state.model, nnx.Param, ...) + if config.parameter_memory_host_offload: + # Params are kept on host (pinned_host) in in_shardings. Move only Param + # variables to device before the forward/backward pass so that all dot_general + # operands share the same memory space (XLA on GPU requires this). + # Using params_shardings (Param-only) avoids Shardy rank mismatches that + # occur when applying PartitionSpec() (rank-0 in SDY) to rank-1 RNG key tensors. + device_param_shardings = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_device, + params_shardings, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + curr_params = jax.device_put(curr_params, device_param_shardings) + nnx.update(state.model, curr_params) # ensure state.model has device params for optimizer update + if config.shard_optimizer_over_data: + curr_params = jax.tree.map( + functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), + curr_params, + params_shardings, + ) + nnx.update(state.model, curr_params) - grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) - (loss, aux), raw_grads = grad_func( - model, - config, - data, - dropout_rng, - pure_params, - *extra_dpo_args, - sparsity_state=batch_stats, - is_train=True, - ) + def diff_wrapper(param, rest, config, data): + local_model = nnx.merge(model_graphdef, param, rest, copy=True) + loss, aux = loss_fn(local_model, config, data, None, None, is_train=True) + _, _, new_rest = nnx.split(local_model, nnx.Param, ...) + return loss, (aux, new_rest) + + grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True) + (loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, config, data) + nnx.update(state.model, new_rest) raw_grads = jax.tree_util.tree_map( lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, @@ -363,6 +406,8 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat raw_grads, max_utils.with_memory_kind(params_shardings, "device"), ) + + # Extract aux fields into locals intermediate_outputs = aux["intermediate_outputs"] xent_sum = aux["xent_sum"] total_weights = aux["total_weights"] @@ -372,67 +417,90 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat moe_bias_updates = aux.get("moe_bias_updates") mtp_loss = aux.get("mtp_loss", 0.0) - if config.gradient_clipping_threshold > 0: - grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold) - else: - grads = raw_grads - - if config.optimizer_memory_host_offload: - state = state.replace( - opt_state=jax.device_put( - state.opt_state, - jax.tree_util.tree_map( - lambda x: x.with_memory_kind(kind="device"), - state_mesh_shardings.opt_state, - ), - ) - ) - # Move all parameters to device before optimizer update - if config.parameter_memory_host_offload: - max_logging.log("\nMoving all parameters to device before optimizer update") - - def move(path, value): - max_logging.log(f"train.py: Moving f{path} to device") - return value.with_memory_kind(kind="device") - - state = state.replace( - params=jax.device_put( - state.params, - jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params), - ) - ) - # Re-wrap grads to match state.params structure if it's a dict of collections - sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m - if sparsity_enabled: - full_grads = {"params": grads} - if sparsity_enabled and "batch_stats" in state.params: - batch_stats_grads = jax.tree_util.tree_map(jnp.zeros_like, state.params.get("batch_stats", {})) - full_grads["batch_stats"] = batch_stats_grads - full_grads = max_utils.unbox_logicallypartioned(full_grads) - else: - full_grads = grads - - if getattr(config, "skip_step_on_spikes", False): - grad_norm = max_utils.l2norm_pytree(grads) - # TrainState.apply_gradients doesn't pass **kwargs to tx.update, so we unpack it manually. - updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params, loss=loss, grad_norm=grad_norm) - new_params = optax.apply_updates(state.params, updates) - - new_state = state.replace( - step=state.step + 1, - params=new_params, - opt_state=new_opt_state, - ) + if isinstance(model, nn.Module): + if config.gradient_clipping_threshold > 0: + grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold) + else: + grads = raw_grads + if config.optimizer_memory_host_offload: + state = state.replace( + opt_state=jax.device_put( + state.opt_state, + jax.tree_util.tree_map( + lambda x: x.with_memory_kind(kind="device"), + state_mesh_shardings.opt_state, + ), + ) + ) + # Move all parameters to device before optimizer update + if config.parameter_memory_host_offload: + max_logging.log("\nMoving all parameters to device before optimizer update") + + def move(path, value): + max_logging.log(f"train.py: Moving f{path} to device") + return value.with_memory_kind(kind="device") + + state = state.replace( + params=jax.device_put( + state.params, + jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params), + ) + ) + # Re-wrap grads to match state.params structure if it's a dict of collections + # (when weight_sparsity is enabled, params has both 'params' and 'batch_stats' keys). + sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m + if sparsity_enabled: + full_grads = {"params": grads} + if "batch_stats" in state.params: + batch_stats_grads = jax.tree_util.tree_map(jnp.zeros_like, state.params.get("batch_stats", {})) + full_grads["batch_stats"] = batch_stats_grads + full_grads = max_utils.unbox_logicallypartioned(full_grads) + else: + full_grads = grads + + if getattr(config, "skip_step_on_spikes", False): + grad_norm = max_utils.l2norm_pytree(grads) + # TrainState.apply_gradients doesn't pass **kwargs to tx.update, so we unpack it manually. + updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params, loss=loss, grad_norm=grad_norm) + new_params = optax.apply_updates(state.params, updates) + + new_state = state.replace( + step=state.step + 1, + params=new_params, + opt_state=new_opt_state, + ) + else: + new_state = state.apply_gradients(grads=full_grads) + + # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family + if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: + target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias") + # Updates the shape to be aligned with state. + moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose() + new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates) else: - new_state = state.apply_gradients(grads=full_grads) + if config.gradient_clipping_threshold > 0: + grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, config.gradient_clipping_threshold) + else: + grads = raw_grads + if config.optimizer_memory_host_offload: + # state.optimizer is an NNX Optimizer module; state_mesh_shardings.optimizer + # is an NNX State. Use nnx.state() to get a compatible State for device_put. + device_opt_shardings = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_device, + state_mesh_shardings.optimizer, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + opt_state = nnx.state(state.optimizer) + new_opt_state = jax.device_put(opt_state, device_opt_shardings) + nnx.update(state.optimizer, new_opt_state) + state.apply_gradients(grads) + new_state = state - # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family - if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: - target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias") - # Flax 'sow' returns a tuple, so we take the first element [0]. - # Updates the shape to be aligned with state. - moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose() - new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates) + # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family + if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: + target_bias = new_state.model.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias + target_bias.value = target_bias.value + jnp.array(moe_bias_updates[0]).transpose() lm_loss = xent_sum / (total_weights + EPS) scalar_metrics = { @@ -446,8 +514,9 @@ def move(path, value): "learning/total_weights": total_weights, } if config.use_qk_clip: - # Apply QK-Clip - new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config) + # Apply QK-Clip (Linen path only; NNX uses different state layout — TODO: implement for NNX) + if isinstance(model, nn.Module): + new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config) # Report max_logits metric global_max_logit = qk_clip_utils.calculate_max_logit_metric(intermediate_outputs) @@ -457,7 +526,11 @@ def move(path, value): if not config.optimizer_memory_host_offload: scalar_metrics["learning/grad_norm"] = max_utils.l2norm_pytree(grads) scalar_metrics["learning/raw_grad_norm"] = max_utils.l2norm_pytree(raw_grads) - scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_state.params) + if isinstance(model, nn.Module): + scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_state.params) + else: + _, model_params, _ = nnx.split(new_state.model, nnx.Param, ...) + scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(model_params) if config.use_dpo: scalar_metrics["learning/dpo_loss"] = aux["dpo_loss"] scalar_metrics["learning/dpo_reward_accuracy"] = aux["reward_accuracy"] @@ -465,31 +538,34 @@ def move(path, value): "scalar": scalar_metrics, "scalars": {}, } - if config.record_internal_nn_metrics: record_activation_metrics(metrics, intermediate_outputs, config) - if config.use_dpo: - new_state = _merge_dpo_state(new_state, reference_params) - - return new_state, metrics + if isinstance(model, nn.Module): + if config.use_dpo: + new_state = _merge_dpo_state(new_state, reference_params) + return new_state, metrics + return nnx.state(new_state), metrics -def eval_step(model, config, state, data, dropout_rng): +def eval_step(model, config, state, data, dropout_rng=None): """eval_step no backprop and new state compared with train_step.""" + if isinstance(model, nn.Module): + reference_params, extra_dpo_args, _loss_fn = [], [], loss_fn + if config.use_dpo: + state, reference_params = _split_dpo_state(state) + extra_dpo_args = [reference_params] + _loss_fn = dpo_loss_fn - reference_params, extra_dpo_args, _loss_fn = [], [], loss_fn - if config.use_dpo: - state, reference_params = _split_dpo_state(state) - extra_dpo_args = [reference_params] - _loss_fn = dpo_loss_fn - - sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m - pure_params = state.params["params"] if sparsity_enabled else state.params - batch_stats = state.params.get("batch_stats", {}) + sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m + pure_params = state.params["params"] if sparsity_enabled else state.params + batch_stats = state.params.get("batch_stats", {}) - eval_loss_fn = functools.partial(_loss_fn, model, config, data, dropout_rng, is_train=False) - loss, aux = eval_loss_fn(pure_params, *extra_dpo_args, sparsity_state=batch_stats) + eval_loss_fn = functools.partial(_loss_fn, model, config, data, dropout_rng, is_train=False) + loss, aux = eval_loss_fn(pure_params, *extra_dpo_args, sparsity_state=batch_stats) + else: + state = nnx.merge(model, state) # reconstruct TrainStateNNX + loss, aux = loss_fn(state.model, config, data, None, None, is_train=False) mtp_acceptance_rate = 0.0 if config.mtp_eval_target_module > 0: @@ -517,7 +593,7 @@ def eval_step(model, config, state, data, dropout_rng): "evaluation/mtp_acceptance_rate_percent": mtp_acceptance_rate, }, } - if config.use_dpo: + if isinstance(model, nn.Module) and config.use_dpo: metrics["scalar"]["evaluation/dpo_reward_accuracy"] = aux["reward_accuracy"] return metrics @@ -539,32 +615,46 @@ def train_loop(config, recorder, state=None): state, ) = train_utils.setup_train_loop(config, recorder) - if config.use_dpo: - if "reference_params" not in state.params: - reference_params = jax.tree.map(jnp.copy, state.params["params"]) - state = _merge_dpo_state(state, reference_params) - state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) + if isinstance(model, nn.Module): + if config.use_dpo: + if "reference_params" not in state.params: + reference_params = jax.tree.map(jnp.copy, state.params["params"]) + state = _merge_dpo_state(state, reference_params) + state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) + jit_model = model + else: + if config.use_dpo: + raise NotImplementedError("DPO is not supported for NNX models.") + jit_model, state = nnx.split(state) params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) + p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( + config, + jit_model, + mesh, + state, + state_mesh_shardings, + train_step, + eval_step, + eval_data_iterator, + params_shardings, + ) + with jax.set_mesh(mesh), mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( - config, - model, - mesh, - state, - state_mesh_shardings, - train_step, - eval_step, - eval_data_iterator, - params_shardings, - ) shaped_batch = maxtext_utils.get_shaped_batch(config) - if config.shard_optimizer_over_data: + if config.shard_optimizer_over_data and isinstance(model, nn.Module): state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) - maxtext_utils.maybe_dump_jaxpr(config, p_train_step, (state, shaped_batch, init_rng)) + elif config.shard_optimizer_over_data: + # NNX: reshard state so params match the data-sharded in_shardings (Zero-1 layout) + state = jax.device_put(state, state_mesh_shardings) + if isinstance(model, nn.Module): + lower_args = (state, shaped_batch, init_rng) + else: + lower_args = (state, shaped_batch) + maxtext_utils.maybe_dump_jaxpr(config, p_train_step, lower_args) if config.compiled_trainstep_file == "": # compile only when there is no pre-compiled file loaded - compiled = p_train_step.lower(state, shaped_batch, init_rng).compile() + compiled = p_train_step.lower(*lower_args).compile() compiled_stats = compiled.memory_analysis() max_utils.print_compiled_memory_stats(compiled_stats) @@ -573,7 +663,11 @@ def train_loop(config, recorder, state=None): metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) # Write train config params, num model params, and XLA flags to tensorboard - metric_logger.write_setup_info_to_tensorboard(state.params) + if isinstance(model, nn.Module): + setup_params = state.params + else: + _, setup_params, _ = nnx.split(state.model, nnx.Param, ...) + metric_logger.write_setup_info_to_tensorboard(setup_params) _job_completed_gracefully = False try: @@ -583,59 +677,62 @@ def train_loop(config, recorder, state=None): with jax.profiler.StepTraceAnnotation("train", step_num=step): example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager) - # pylint: disable=not-callable - nextrng = jax.jit(jax.random.fold_in)(init_rng, step) + if isinstance(model, nn.Module): + # pylint: disable=not-callable + step_rng_args = (jax.jit(jax.random.fold_in)(init_rng, step),) + else: + step_rng_args = () with maybe_record_goodput(recorder, GoodputEvent.STEP, step): with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): - if config.shard_optimizer_over_data: + if config.shard_optimizer_over_data and isinstance(model, nn.Module): state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) - state, metrics = p_train_step(state, example_batch, nextrng) - - step_time_delta = datetime.datetime.now() - last_step_completion - last_step_completion = datetime.datetime.now() - - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] - checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) - - if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): - jax.block_until_ready(state) # Ensure compilation has finished. - gcs_utils.upload_dump( - config.dump_hlo_local_dir, - config.dump_hlo_gcs_dir, - module_name=config.dump_hlo_module_name, - delete_local_after=config.dump_hlo_delete_local_after, - all_host_upload=config.dump_hlo_upload_all, - ) - - if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: - assert eval_data_iterator - # Explicitly reset the eval iterator and counters before starting the eval loop - eval_data_iterator.reset() - metric_logger.reset_eval_metrics() - - eval_step_count = 0 - # pylint: disable=not-callable - for eval_batch in eval_data_iterator: - # Shard input eval data - eval_batch = jax.device_put(eval_batch, sharding.get_input_data_sharding(config, mesh)) - if config.eval_steps > 0 and eval_step_count >= config.eval_steps: - break - with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): - eval_metrics = p_eval_step(state, eval_batch, nextrng) - metric_logger.record_eval_metrics(step, metrics=eval_metrics) - max_logging.log(f"Completed eval step {eval_step_count}") - eval_step_count += 1 - metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count) - if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss: - prof.deactivate() - raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.") - - prof.maybe_deactivate_profiler(step, state) - - if step == start_step: - max_utils.print_mem_stats("After params initialized") - - metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) + state, metrics = p_train_step(state, example_batch, *step_rng_args) + + step_time_delta = datetime.datetime.now() - last_step_completion + last_step_completion = datetime.datetime.now() + + state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) + + if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): + jax.block_until_ready(state) # Ensure compilation has finished. + gcs_utils.upload_dump( + config.dump_hlo_local_dir, + config.dump_hlo_gcs_dir, + module_name=config.dump_hlo_module_name, + delete_local_after=config.dump_hlo_delete_local_after, + all_host_upload=config.dump_hlo_upload_all, + ) + + if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: + assert eval_data_iterator + # Explicitly reset the eval iterator and counters before starting the eval loop + eval_data_iterator.reset() + metric_logger.reset_eval_metrics() + + eval_step_count = 0 + # pylint: disable=not-callable + for eval_batch in eval_data_iterator: + # Shard input eval data + eval_batch = jax.device_put(eval_batch, sharding.get_input_data_sharding(config, mesh)) + if config.eval_steps > 0 and eval_step_count >= config.eval_steps: + break + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): + eval_metrics = p_eval_step(state, eval_batch, *step_rng_args) + metric_logger.record_eval_metrics(step, metrics=eval_metrics) + max_logging.log(f"Completed eval step {eval_step_count}") + eval_step_count += 1 + metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count) + if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss: + prof.deactivate() + raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.") + + prof.maybe_deactivate_profiler(step, state) + + if step == start_step: + max_utils.print_mem_stats("After params initialized") + + metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) if config.save_checkpoint_on_completion: state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] diff --git a/src/maxtext/utils/gradient_accumulation.py b/src/maxtext/utils/gradient_accumulation.py index 9bad1cfb35..e1699647c6 100644 --- a/src/maxtext/utils/gradient_accumulation.py +++ b/src/maxtext/utils/gradient_accumulation.py @@ -17,6 +17,7 @@ import jax import jax.numpy as jnp from jax.sharding import NamedSharding +from flax import nnx from maxtext.common.common_types import ShardMode from maxtext.utils.sharding import maybe_shard_with_name @@ -49,7 +50,8 @@ def gradient_accumulation_loss_and_grad( config: Model and training configuration object. Must contain `gradient_accumulation_steps` and `shard_optimizer_over_data`. model: The model module. - params: The model parameters (PyTree). + params: The model parameters (PyTree). This is only used for Linen. For NNX, + we can get the params from the model. params_shardings: The sharding constraints for the parameters (PyTree). data: A PyTree of batched data. The leading dimension is assumed to be the total batch size (microbatch_size * num_accumulations). @@ -67,12 +69,18 @@ def _maybe_shard_with_name(inputs, sharding_names): """Wrapper of maybe_shard_with_name with fixed shard_mode""" return maybe_shard_with_name(inputs, sharding_names, config.shard_mode, debug_sharding=config.debug_sharding) + is_nnx = isinstance(model, nnx.Module) + # For more efficient DP/ZeRO-1 + GA if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1: ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings) grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings) else: ga_params_shardings = grad_shardings = params_shardings + + if is_nnx: + graphdef, params, rest = nnx.split(model, nnx.Param, ...) + # When using Zero-1 optimizer sharding, cast params to lower precision and apply sharding constraints # so that all-gather is done once in the lower precision before the gradient accumulation loop if config.shard_optimizer_over_data: @@ -87,11 +95,27 @@ def convert_to_bf16(param): ga_params = params ga_params = jax.tree.map(_maybe_shard_with_name, ga_params, ga_params_shardings) - grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) + if is_nnx: + grad_func = nnx.value_and_grad(_loss_fn, argnums=0, has_aux=True) + else: + grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) def accumulate_gradient(acc_grad_and_loss, data): ga_params = acc_grad_and_loss["ga_params"] - (_, aux), cur_batch_gradient = grad_func(model, config, data, dropout_rng, ga_params, *extra_dpo_args, is_train=True) + if is_nnx: + # Reconstruct the model using the fixed parameters (ga_params) + # and the advancing non-parameter state (RNGs) from the carry. + local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"]) + (_, aux), cur_batch_gradient = grad_func(local_model, config, data, None, None, *extra_dpo_args, is_train=True) + _, _, next_rest_state = nnx.split(local_model, nnx.Param, ...) + acc_grad_and_loss["rest_state"] = next_rest_state + else: + rng = ( + jax.random.fold_in(dropout_rng, acc_grad_and_loss["total_weights"].astype(jnp.int32)) + if dropout_rng is not None + else None + ) + (_, aux), cur_batch_gradient = grad_func(model, config, data, rng, ga_params, *extra_dpo_args, is_train=True) acc_grad_and_loss["loss"] += aux["xent_sum"] + aux.get("dpo_loss", 0.0) acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"] acc_grad_and_loss["indexer_loss"] += aux["indexer_loss"] @@ -119,6 +143,8 @@ def reshape_to_microbatch_accumulations(batch_arr): "mtp_loss": 0.0, "ga_params": ga_params, } + if is_nnx: + init_grad_and_loss["rest_state"] = rest grad_and_loss, aux = jax.lax.scan( accumulate_gradient, init_grad_and_loss, data, length=config.gradient_accumulation_steps @@ -134,6 +160,9 @@ def reshape_to_microbatch_accumulations(batch_arr): raw_grads = jax.tree_util.tree_map(lambda arr: arr / grad_and_loss["total_weights"], raw_grads) aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux) # pytype: disable=module-attr + if is_nnx: + nnx.update(model, grad_and_loss["rest_state"]) + return loss, aux, raw_grads diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index d0e4d05113..0479d4bfca 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -20,21 +20,20 @@ import os from typing import Sequence -from flax import linen as nn +from flax import nnx, linen as nn +from flax.core.spmd import composite_rules, from_sharding_rules, get_logical_axis_rules from flax.linen import partitioning as nn_partitioning -from flax.training import train_state +from flax.training.train_state import TrainState import numpy as np -from jax.experimental import mesh_utils -from jax.experimental.serialize_executable import deserialize_and_load -from jax.sharding import AxisType, Mesh - import jax import jax.numpy as jnp +from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils +from jax.experimental.serialize_executable import deserialize_and_load import optax - import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager @@ -54,6 +53,7 @@ from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils import sharding +from maxtext.utils import maxtext_utils_nnx OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" @@ -101,7 +101,10 @@ def get_functional_train_with_signature( """Get the shardings (both state and data) for `train_step`.""" functional_train = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings) functional_train.__name__ = "train_step" - in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng + if config.pure_nnx: + in_shardings = (state_mesh_shardings, data_sharding) # State, batch + else: + in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng out_shardings = (state_mesh_shardings, None) # State, metrics static_argnums = () # We partial out the static argnums of model and config donate_argnums = 0 # This is the index of the state - we allow the compiler to make use of this memory. @@ -112,7 +115,10 @@ def get_functional_eval_with_signature(eval_step, data_sharding, state_mesh_shar """Get the shardings (both state and data) for `eval_step`.""" functional_eval = functools.partial(eval_step, model, config) functional_eval.__name__ = "eval_step" - in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng + if config.pure_nnx: + in_shardings = (state_mesh_shardings, data_sharding) # State, batch (NNX: no rng) + else: + in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng out_shardings = None # metrics static_argnums = () # We partial out the static argnums of model, config donate_argnums = () # state will be kept instead of being donated in eval_step @@ -1231,15 +1237,15 @@ def _apply_update(path, param): return state.replace(params=new_params) -def init_decode_state(apply_fn, params) -> train_state.TrainState: +def init_decode_state(apply_fn, params) -> TrainState: """Init train state with null opt state for decode.""" - state = train_state.TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={}) # type: ignore + state = TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={}) # type: ignore return state def init_training_state(apply_fn, params, tx): """Init train state with null opt state for decode.""" - state = train_state.TrainState.create(apply_fn=apply_fn, params=params, tx=tx) + state = TrainState.create(apply_fn=apply_fn, params=params, tx=tx) return state @@ -1367,7 +1373,7 @@ def setup_initial_state( is_training: True to initialize training state, False for decode state Returns: - state: the initialized train state + train_state: the initialized train state. For NNX, this is a TrainStateNNX instance state_mesh_annotations: the mesh annotations for the train state """ @@ -1406,29 +1412,48 @@ def setup_initial_state( else: # The update of data_iterator state happens in place, no need to assign explicitly state = restored["items"] + + # For NNX, convert the pure dict to nnx.State using the abstract state as template + if config.pure_nnx: + nnx.replace_by_pure_dict(unboxed_abstract_state, state) + state = unboxed_abstract_state else: init_state_partial = init_state_fn init_state_partial.__name__ = "initialize_state" - # pylint: disable=not-callable - state = jax.jit( - init_state_partial, - in_shardings=None, - out_shardings=state_mesh_shardings, - )() - sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m - if sparsity_enabled and raw_params: # If we loaded a partial state, we need to merge it. - - def _merge_params(p_raw, p_init): - if isinstance(p_raw, jax.ShapeDtypeStruct): - return p_init - return p_raw - - merged_params = jax.tree_util.tree_map(_merge_params, raw_params, state.params) - state = state.replace(params=merged_params) - elif raw_params: - state = state.replace(params=raw_params) - - state = max_utils.unbox_logicallypartioned(state) + if config.pure_nnx: + state = jax.jit( + lambda: nnx.state(init_state_partial()), # Get state only, mapping to out_sharding structure + in_shardings=None, + out_shardings=state_mesh_shardings, + )() + else: + # pylint: disable=not-callable + state = jax.jit( + init_state_partial, + in_shardings=None, + out_shardings=state_mesh_shardings, + )() + if raw_params: # If we loaded a partial state, we need to merge it. + if config.pure_nnx: + # raw_params should have the same sharding info as in the model + nnx.update(state.model, raw_params) + else: + sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m + if sparsity_enabled: + # Sparsity-init keeps freshly initialized params for any leaf still + # represented as an abstract ShapeDtypeStruct in raw_params (i.e. not + # actually restored), and uses the restored value otherwise. + def _merge_params(p_raw, p_init): + if isinstance(p_raw, jax.ShapeDtypeStruct): + return p_init + return p_raw + + merged_params = jax.tree_util.tree_map(_merge_params, raw_params, state.params) + state = state.replace(params=merged_params) + else: + state = state.replace(params=raw_params) + if not config.pure_nnx: + state = max_utils.unbox_logicallypartioned(state) return state, state_mesh_annotations, state_mesh_shardings, data_iterator @@ -1443,6 +1468,9 @@ def get_logical_annotations(config, mesh, init_state_fn): def get_abstract_state(config, mesh, init_state_fn, is_training=True): """Get a shaped abstraction of the state (including optimizer)""" + if config.pure_nnx: + return get_abstract_state_nnx(config, mesh, init_state_fn, is_training) + init_state_partial = init_state_fn with nn_partitioning.axis_rules(config.logical_axis_rules): @@ -1486,6 +1514,148 @@ def move(path, x): ) +def get_nnx_named_sharding_with_scan_axis(abs_var_state: nnx.State, mesh) -> nnx.State: + """Compute NamedSharding for each NNX variable, correctly handling the scan (stacked layers) axis. + + Unlike flax.nnx.spmd.get_var_pspec (used inside nnx.get_abstract_model), this function also + inserts the partition_name axis at the correct scan_axis position for parameters created by + _create_scanned_layers. Without this, scanned parameters get a 2D partition spec applied to a + 3D tensor, placing sharding on the stacked-layers dimension instead of the embedding dimension. + + Args: + abs_var_state: NNX abstract variable state from nnx.split(nnx.eval_shape(...)). + mesh: JAX physical mesh. + + Returns: + Same tree structure as abs_var_state but each Variable's value replaced with NamedSharding. + """ + + def _make_named_sharding(v): + val = v.get_value() + if not hasattr(val, "shape"): + # Non-tensor value (e.g., optax MaskedNode for non-trainable params). Preserve + # as-is so the treedef matches abs_var_state in the downstream jax.tree.map. + return v + metadata = v.get_metadata() + out_sharding = metadata.get("out_sharding") or metadata.get("sharding_names") or metadata.get("sharding") + if not out_sharding: + pspec = PartitionSpec() + else: + # Insert the scan axis for parameters created by _create_scanned_layers. + # _add_scan_metadata stores the axis name in nnx.PARTITION_NAME and the + # axis index in "param_scan_axis". flax.nnx.spmd.get_var_pspec ignores these. + if nnx.PARTITION_NAME in metadata: + partition_name = metadata[nnx.PARTITION_NAME] + # Always use param_scan_axis from metadata. OptVariable (optimizer state) inherits + # param_scan_axis=1 from the model Param via to_opt_state(), so we must not hardcode + # scan_axis=0 for non-Param types. stacked_rest non-Param variables have + # param_scan_axis=0 set explicitly by _add_scan_metadata, so this is always correct. + scan_axis = metadata.get("param_scan_axis", 0) + out_sharding = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding) + # Guard against double-insertion: Flax 0.12.6 _remap_sharding_metadata renames + # 'sharding' -> 'out_sharding', so _add_scan_metadata may have already inserted + # the scan axis. Only insert if not already present. + if partition_name not in out_sharding: + out_sharding.insert(scan_axis, partition_name) + out_sharding = tuple(out_sharding) + # Convert logical axis names to physical mesh axes using current context rules. + context_rules = get_logical_axis_rules() + local_rules = metadata.get("sharding_rules", ()) + if context_rules or local_rules: + rules = composite_rules(context_rules, local_rules) + pspec = PartitionSpec(*from_sharding_rules(out_sharding, rules)) + else: + pspec = PartitionSpec(*out_sharding) + return v.replace(NamedSharding(mesh, pspec)) + + return jax.tree.map(_make_named_sharding, abs_var_state, is_leaf=lambda x: isinstance(x, nnx.Variable)) + + +def get_abstract_state_nnx(config, mesh, nnx_init_trainstate_fn, is_training=True): + """Calculates the abstract sharded state and memory placement for an NNX TrainState. + + This function performs an abstract trace of the NNX model and optimizer using + `nnx.get_abstract_model`. It resolves logical sharding annotations into physical + JAX shardings and applies memory placement optimizations such as optimizer + sharding and host memory offloading (pinning to CPU RAM). + + Args: + config: Configuration object containing sharding and offloading hyperparameters + (e.g., shard_optimizer_over_data, optimizer_memory_host_offload). + mesh: JAX physical mesh used to resolve logical axis names to physical devices. + nnx_init_trainstate_fn: A zero-argument factory function that produces a + TrainStateNNX instance during the abstract trace. + is_training: Boolean indicating if the state is for training. If True, + optimizer state is processed and memory offloading strategies are applied. + + Returns: + A tuple containing (abstract_sharded_state, None, state_mesh_shardings): + abstract_sharded_state: An nnx.State containing ShapeDtypeStructs with + fully resolved physical sharding and memory_kind metadata. + state_mesh_annotations: An nnx.State tree consisting of the raw PartitionSpec + objects corresponding to each parameter/variable. + state_mesh_shardings: An nnx.State tree consisting of the raw JAX + Sharding objects corresponding to each parameter/variable. + """ + assert nnx_init_trainstate_fn is not None, "get_abstract_state_nnx: init function must be given." + + with nn_partitioning.axis_rules(config.logical_axis_rules): + # Use nnx.eval_shape + nnx.split instead of nnx.get_abstract_model, so we can apply + # get_nnx_named_sharding_with_scan_axis which correctly inserts the stacked-layers + # axis into the partition spec. nnx.get_abstract_model uses get_var_pspec internally + # which ignores nnx.PARTITION_NAME / param_scan_axis metadata set by _create_scanned_layers, + # causing the 2D partition spec to be misapplied to the 3D stacked parameter tensor. + # Do NOT wrap nnx.eval_shape in jax.set_mesh: Flax 0.12.6's _to_variable calls + # var.shape for every variable when a global mesh is active, but masked optimizer + # state variables (e.g. from trainable_parameters_mask) have value=MaskedNode() + # which has no .shape and would raise AttributeError. We handle sharding + # ourselves via get_nnx_named_sharding_with_scan_axis, so auto-assignment is not + # needed here. + abs_model = nnx.eval_shape(nnx_init_trainstate_fn) + _, abs_var_state = nnx.split(abs_model) + named_sharding_state = get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh) + abstract_state = jax.tree.map( + lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s), + abs_var_state, + named_sharding_state, + ) + + state_mesh_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + + if is_training and config.shard_optimizer_over_data: + # Add data to sharding for optimizer state + optimizer_sharding = jax.tree_util.tree_map_with_path( + functools.partial(sharding.add_data_to_sharding, mesh), + abstract_state.optimizer, + state_mesh_shardings.optimizer, + ) + state_mesh_shardings.optimizer = optimizer_sharding + if is_training and config.optimizer_memory_host_offload: + optimizer_sharding = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_host, + state_mesh_shardings.optimizer, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + state_mesh_shardings.optimizer = optimizer_sharding + if is_training and config.parameter_memory_host_offload: + assert config.param_scan_axis == 0, "You must set the scan axis 0 to enable parameter offloading." + _, state_params, _ = nnx.split(state_mesh_shardings, nnx.Param, ...) + state_params = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_host, + state_params, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + nnx.update(state_mesh_shardings, state_params) + + abstract_sharded_state = maxtext_utils_nnx.set_named_sharding_nnx(abstract_state, state_mesh_shardings) + state_mesh_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings) + return ( + abstract_sharded_state, + state_mesh_annotations, + state_mesh_shardings, + ) + + def get_prefill_kv_cache_annotations(model, config, rng, mesh, page_state: None | PageState = None): """Get a shaped abstraction of the state (including optimizer)""" diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index 9cd935730d..7647e86042 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -1,3 +1,17 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Copyright 2023–2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,11 +32,11 @@ import dataclasses import collections from collections.abc import Sequence +from typing import Callable, overload from functools import partial import os import subprocess import sys -from typing import overload from etils import epath from flax import nnx import flax.linen as nn @@ -261,34 +275,99 @@ def create_model(config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rng return model -def create_nnx_abstract_model(config, mesh, model_mode=MODEL_MODE_TRAIN, rng_key=None): - """Returns (_create_model_partial, abstract_model) for AOT compilation. +def get_nnx_create_model_fn(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None) -> Callable: - This does not shard parameters or load checkpoints. It only builds the - abstract shape/dtype structure needed by get_abstract_state and optimizer - construction (e.g. Muon). + def _create_model(): + rngs = maxtext_utils_nnx.create_nnx_rngs(config, model_mode=model_mode, rng_key=rng_key) + return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode) - Args: - config: the configuration - mesh: the device mesh - model_mode: train or inference - rng_key: optional RNG key + return _create_model + + +def create_nnx_abstract_model( + config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None +) -> tuple[Callable, nnx.Module]: + """Creates an abstract NNX model. Returns: - (_create_model_partial, abstract_model) where _create_model_partial() creates - a concrete model instance and abstract_model is the eval_shape result. + A tuple containing (create_model_fn, abstract_model): + create_model_fn: A zero-argument callable that produces a new model instance. + abstract_model: The stateful NNX model instance in an abstract state. """ - def _create_model(rng_key=None): - rngs = maxtext_utils_nnx.create_nnx_rngs(config, model_mode=model_mode, rng_key=rng_key) - return from_config(config, mesh=mesh, rngs=rngs, model_mode=model_mode) + with nn.logical_axis_rules(config.logical_axis_rules): + _create_model = get_nnx_create_model_fn(config, mesh, devices, model_mode, rng_key) + if mesh is None: + _tmp = nnx.eval_shape(_create_model) + mesh = _tmp.mesh + # Use nnx.eval_shape + our scan-axis-aware sharding helper instead of + # nnx.get_abstract_model, which uses get_var_pspec internally and ignores + # param_scan_axis / nnx.PARTITION_NAME metadata set by _create_scanned_layers, + # causing the stacked layers axis to be missing from the PartitionSpec. + with jax.set_mesh(mesh): + abs_model = nnx.eval_shape(_create_model) + graphdef, abs_var_state = nnx.split(abs_model) + named_sharding_state = maxtext_utils.get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh) + abstract_state = jax.tree.map( + lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s), + abs_var_state, + named_sharding_state, + ) + return _create_model, nnx.merge(graphdef, abstract_state) + + +def create_nnx_sharded_model_hybrid(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None): + """Creates a sharded model for hybrid NNX modules containing Linen sub-modules. - _create_model_partial = partial(_create_model, rng_key=rng_key) + DEPRECATED: This function is a transitional utility for the Linen-to-NNX + migration. It should be removed once all model components are ported to + pure NNX modules. + + This function specifically handles the complexity of "mixed" state initialization, + where logical sharding annotations must be resolved for both NNX native + Parameters and legacy Linen variables wrapped via the NNX-Linen bridge. + It ensures that both systems correctly respect the provided mesh and + logical axis rules during the abstraction/sharding planning phase. + """ + _create_model_partial = get_nnx_create_model_fn(config, mesh, devices, model_mode, rng_key) with nn.logical_axis_rules(config.logical_axis_rules): abstract_model = nnx.eval_shape(_create_model_partial) + graphdef, abstract_state = nnx.split(abstract_model) + specs = nnx.get_partition_spec(abstract_state) + + if mesh is None: + mesh = abstract_model.mesh + + # JIT a function that creates the model state with proper sharding from the start. + # By providing out_shardings, we instruct JAX to produce sharded output directly, + # avoiding a large intermediate allocation on a single device. + with nn.logical_axis_rules(config.logical_axis_rules): + out_shardings = nn.logical_to_mesh_sharding(specs, mesh) - return _create_model_partial, abstract_model + @partial(jax.jit, out_shardings=out_shardings) + def create_sharded_state(): + # This will be JIT-compiled. JAX knows the output sharding and can + # initialize the parameters directly on the target devices in a sharded way. + model = _create_model_partial() + return nnx.state(model) + + with mesh: + # Create the model with sharded parameters. + with nn.logical_axis_rules(config.logical_axis_rules): + sharded_state = create_sharded_state() + model = nnx.merge(graphdef, sharded_state) + + # print weights sharding info under debug sharding mode + if config.debug_sharding: + max_utils.print_non_trivial_mesh_axis(model.mesh) + maxtext_utils.print_shardings_params( + params=sharded_state, + params_sharding=out_shardings, + mesh=model.mesh, + logical_annotations=specs, + ) + return model def setup_configs_and_devices(argv: list[str] | None = None, kwargs: dict | None = None, **extra_kwargs): @@ -473,60 +552,19 @@ def from_pretrained( ) config = pyconfig.HyperParameters(new_config) - def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, rng_key: jax.Array | None = None): - rngs = maxtext_utils_nnx.create_nnx_rngs(config, model_mode=model_mode, rng_key=rng_key) - return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode) - - _create_model_partial = partial(_create_model, mesh=mesh, model_mode=model_mode, rng_key=rng_key) + if config.pure_nnx: + _create_model, abstract_model = create_nnx_abstract_model(config, mesh, devices, model_mode, rng_key) + model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model, mesh=mesh) + # TODO: print debug_sharding info + else: + model = create_nnx_sharded_model_hybrid(config, mesh, devices, model_mode, rng_key) - with nn.logical_axis_rules(config.logical_axis_rules): - abstract_model = nnx.eval_shape(_create_model_partial) - graphdef, abstract_state = nnx.split(abstract_model) - specs = nnx.get_partition_spec(abstract_state) + sharded_state = nnx.state(model) if mesh is None: - mesh = abstract_model.mesh - - # Note for pure_nnx: - # Currently, the NNX model returned has a linen decoder wrapped to NNX. So it is not a pure NNX model and - # we still need to use nn.logical_axis_rules(config.logical_axis_rules) to get the out sharding from the linen - # LogicallyPartitioned structure. - # In the future if the pure NNX model is used, with pure NNX's eager sharding, there will be no LogicallyPartitioned - # structure in the abstract state and we can get the sharded state with the following code: - # graphdef, state = nnx.get_abstract_model(_create_model_partial, mesh) - # abstract_model = nnx.merge(graphdef, state) - # model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model_partial, mesh=mesh) - # sharded_state = nnx.state(model) - - # JIT a function that creates the model state with proper sharding from the start. - # By providing out_shardings, we instruct JAX to produce sharded output directly, - # avoiding a large intermediate allocation on a single device. - with nn.logical_axis_rules(config.logical_axis_rules): - out_shardings = nn.logical_to_mesh_sharding(specs, mesh) - - @partial(jax.jit, out_shardings=out_shardings) - def create_sharded_state(): - # This will be JIT-compiled. JAX knows the output sharding and can - # initialize the parameters directly on the target devices in a sharded way. - model = _create_model_partial() - return nnx.state(model) + mesh = model.mesh with mesh: - # Create the model with sharded parameters. - with nn.logical_axis_rules(config.logical_axis_rules): - sharded_state = create_sharded_state() - model = nnx.merge(graphdef, sharded_state) - - # print weights sharding info under debug sharding mode - if config.debug_sharding: - max_utils.print_non_trivial_mesh_axis(model.mesh) - maxtext_utils.print_shardings_params( - params=sharded_state, - params_sharding=out_shardings, - mesh=model.mesh, - logical_annotations=specs, - ) - if config.load_parameters_path: try: ckptr = ocp.Checkpointer( diff --git a/src/maxtext/utils/muon_utils.py b/src/maxtext/utils/muon_utils.py index 3ba60d7371..3bd2b186b1 100644 --- a/src/maxtext/utils/muon_utils.py +++ b/src/maxtext/utils/muon_utils.py @@ -24,25 +24,23 @@ python3 -m maxtext.utils.muon_utils qwen3-4b True """ - import os import sys from typing import Optional, Tuple import flax.linen as nn +from flax import nnx import jax from maxtext.configs import pyconfig from maxtext.utils.globals import MAXTEXT_PKG_DIR from maxtext.layers import quantizations from maxtext.models import models -from maxtext.utils import maxtext_utils +from maxtext.utils import maxtext_utils, model_creation_utils from optax.contrib._muon import MuonDimensionNumbers as mdn -Transformer = models.transformer_as_linen - - def _is_path_contain_any(tuples, path): + """Checks if any element in 'tuples' is present in 'path'.""" return any(x in path for x in tuples) @@ -107,10 +105,25 @@ def get_transform_tree(tree, path=()): def get_muon_weight_dimension_numbers(model, config, verbose=False): """Extract muon dimension number from model structure.""" - # quickly get param structure without materialization - abstract_param = maxtext_utils.get_abstract_param(model, config) - # get muon dimension number from param - muon_weight_dimension_numbers = get_transform_tree(abstract_param) + + if isinstance(model, nnx.Module): + _, abstract_param, _ = nnx.split(model, nnx.Param, ...) + + def apply_transform_nnx(path: Tuple[jax.tree_util.KeyEntry, ...], leaf): + # Convert jax.tree_util.KeyEntry path to Tuple[str, ...] + path_strings = tuple(p.key for p in path if isinstance(p, jax.tree_util.DictKey)) + return transform_logic(path_strings) + + # Use jax.tree_util.tree_map_with_path for NNX's potentially complex PyTree structure. + # This is different with linen where abstract_param is a dict-based tree with nn.LogicallyPartitioned leaves. + muon_weight_dimension_numbers = jax.tree_util.tree_map_with_path(apply_transform_nnx, abstract_param) + + else: # Linen + # quickly get param structure without materialization + abstract_param = maxtext_utils.get_abstract_param(model, config) + # get muon dimension number from param + muon_weight_dimension_numbers = get_transform_tree(abstract_param) + if verbose: _print_structure_debug(abstract_param, muon_weight_dimension_numbers) return muon_weight_dimension_numbers @@ -118,19 +131,30 @@ def get_muon_weight_dimension_numbers(model, config, verbose=False): def _print_structure_debug(abstract_param, muon_weight_dimension_numbers): """Prints the model structure and the resulting Muon config.""" - # Access the shape from the inner ShapeDtypeStruct and names from the wrapper - # Return a new tree with the same structure containing only shapes/names + + def get_leaf_info(leaf): + # For linen: + # Access the shape from the inner ShapeDtypeStruct and names from the wrapper + # Return a new tree with the same structure containing only shapes/names + if isinstance(leaf, nn.LogicallyPartitioned): + return {"shape": leaf.value.shape, "names": leaf.names} + # For nnx: + # Only return the shape because it doesn't have a wrapper. + elif isinstance(leaf, jax.ShapeDtypeStruct): + return {"shape": leaf.shape} + return {"shape": "N/A"} + info_tree = jax.tree_util.tree_map( - lambda leaf: {"shape": leaf.value.shape, "names": leaf.names}, + get_leaf_info, abstract_param, - is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned), + is_leaf=lambda x: isinstance(x, (nn.LogicallyPartitioned, jax.ShapeDtypeStruct)), ) print(f"\n=== Model Structure ===\n{info_tree}") print(f"\n=== Muon Dimension Numbers ===\n{muon_weight_dimension_numbers}") print("\nIs this reasonable?") -def get_model_mdn(model_name, scan_layers=True, verbose=False): +def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=False): """Initializes a model and retrieves its Muon dimension numbers. This function sets up the configuration for a given model, initializes the @@ -154,13 +178,17 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False): f"model_name={model_name}", f"scan_layers={scan_layers}", "attention=dot_product", + f"pure_nnx={pure_nnx}", ] config = pyconfig.initialize(argv) # Setup model devices_array = maxtext_utils.create_device_mesh(config) mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) quant = quantizations.configure_quantization(config) - model = Transformer(config, mesh=mesh, quant=quant) + if pure_nnx: + _, model = model_creation_utils.create_nnx_abstract_model(config, mesh) + else: + model = models.transformer_as_linen(config, mesh=mesh, quant=quant) # Get dimension number muon_weight_dimension_numbers = get_muon_weight_dimension_numbers(model, config, verbose=verbose) return muon_weight_dimension_numbers @@ -172,4 +200,4 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False): sys.exit(1) model_name_arg = sys.argv[1] scan_layers_arg = sys.argv[2].lower() == "true" - get_model_mdn(model_name_arg, scan_layers_arg, verbose=True) + get_model_mdn(model_name_arg, scan_layers_arg, verbose=True, pure_nnx=False) diff --git a/src/maxtext/utils/sharding.py b/src/maxtext/utils/sharding.py index d4bb64f016..4a500e2fe1 100644 --- a/src/maxtext/utils/sharding.py +++ b/src/maxtext/utils/sharding.py @@ -15,7 +15,7 @@ # pylint: disable=line-too-long, disable=bare-except, consider-using-generator """ Utils that are only interesting to MaxText and sharding related. """ -from flax import linen as nn +from flax import linen as nn, nnx from collections.abc import Iterable @@ -25,6 +25,7 @@ import optax +from maxtext.configs import pyconfig from maxtext.common.common_types import ShardMode from maxtext.utils import max_logging from maxtext.utils import max_utils @@ -483,6 +484,8 @@ def maybe_update_params_sharding_with_opt(config, state_mesh_shardings): - updated_state_mesh_shardings: State mesh shardings with updated params field (unchanged if shard_optimizer_over_data is False) """ + if config.pure_nnx: + return maybe_update_params_sharding_with_opt_nnx(config, state_mesh_shardings) prev_params_shardings = state_mesh_shardings.params if config.shard_optimizer_over_data: if isinstance(state_mesh_shardings.opt_state, optax.ScaleByAdamState): @@ -501,6 +504,122 @@ def maybe_update_params_sharding_with_opt(config, state_mesh_shardings): return prev_params_shardings, state_mesh_shardings +def maybe_update_params_sharding_with_opt_nnx( + config: pyconfig.HyperParameters, state_mesh_shardings: nnx.State +) -> tuple[nnx.State, nnx.State]: + """ + NNX version of parameter sharding update. Updates parameter sharding configuration + when optimizer state sharding is enabled. + + When shard_optimizer_over_data is enabled (Zero-1 style sharding), this function + extracts the optimizer state shardings from the Adam optimizer's first moment (mu) + and merges them with the parameter shardings. This ensures parameter sharding is + consistent with how the optimizer state is distributed across the compute mesh. + + Args: + config: Configuration with shard_optimizer_over_data flag. + state_mesh_shardings: The sharding state for a TrainStateNNX container. + + Returns: + A tuple of (prev_params_shardings, updated_state_mesh_shardings): + - prev_params_shardings: Original parameter shardings before the update + - updated_state_mesh_shardings: State mesh shardings with updated params field + (unchanged if shard_optimizer_over_data is False)""" + # In TrainStateNNX, parameters are under 'model' + model_shardings = state_mesh_shardings.model + + def _extract_param_only(state): + """Recursively extract nnx.Param variables from an nnx.State into a nested plain dict. + + Constructs nnx.State({'key': nested_dict, ...}) which produces the same pytree + structure as nnx.split(model, nnx.Param, ...)[1], enabling jax.tree.map + to work correctly between ga_params (Param-only) and params_shardings. + """ + result = {} + for k, v in state.items(): + if isinstance(v, nnx.Param): + result[k] = v + elif isinstance(v, nnx.Variable): + pass # skip non-Param variables (RngKey, RngCount, OptVariable, etc.) + elif hasattr(v, "items"): + sub = _extract_param_only(v) + if sub: + result[k] = sub + return result + + # prev_params_shardings must match the pytree structure of ga_params from + # nnx.split(model, nnx.Param, ...) — Param variables only, no rngs. + prev_params_shardings = nnx.State(_extract_param_only(model_shardings)) + + if not config.shard_optimizer_over_data: + return prev_params_shardings, state_mesh_shardings + + sharded_fp32_params = None + # Check if the optimizer has any state at all (stateless optimizers like SGD omit this key) + if "opt_state" in state_mesh_shardings.optimizer: + # Access the optimizer branch to find the optax state + # state_mesh_shardings.optimizer contains the sharding for the nnx.Optimizer + opt_state = state_mesh_shardings.optimizer.opt_state + + def find_adam_mu(obj): + # 1. Direct hit on ScaleByAdamState (Linen path or unflattened NNX) + if isinstance(obj, optax.ScaleByAdamState): + return obj.mu + + # 2. Check for flattened ScaleByAdamState (nnx.State/dict) + # These nodes contain 'mu', 'nu', and 'count' as keys. + if hasattr(obj, "__getitem__") and "mu" in obj and "nu" in obj: + return obj["mu"] + + # 3. Recursive search through containers (nnx.State, dict, list, tuple) + values = None + if hasattr(obj, "values"): # Handles nnx.State and dict + values = obj.values() + elif isinstance(obj, (list, tuple)): + values = obj + + if values: + for v in values: + res = find_adam_mu(v) + if res is not None: + return res + return None + + sharded_fp32_params = find_adam_mu(opt_state) + if sharded_fp32_params is None: + actual_type = type(state_mesh_shardings.optimizer.get("opt_state", "None")) + raise NotImplementedError(f"Could not find Adam optimizer state in: {actual_type}") + + # Update model parameter sharding to match the mu (first moment) sharding. + # This ensures parameter sharding is consistent with the Zero-1 distributed layout. + # Build a path → new_PS lookup from sharded_fp32_params (mu), then update model_shardings + # at those paths while preserving rngs and any other non-Param variables. + mu_leaves_with_paths = list( + jax.tree_util.tree_leaves_with_path(sharded_fp32_params, is_leaf=lambda x: isinstance(x, nnx.Variable)) + ) + mu_lookup = {path: mu_var.get_value() for path, mu_var in mu_leaves_with_paths} + + def _update_model_var(path, var): + if path in mu_lookup: + return var.replace(mu_lookup[path]) + return var + + new_model_shardings = jax.tree_util.tree_map_with_path( + _update_model_var, model_shardings, is_leaf=lambda x: isinstance(x, nnx.Variable) + ) + # Use jax.tree_util.tree_map (identity) to create a new nnx.State via JAX's unflatten + # mechanism (not the nnx.State constructor). This is critical because: + # 1. nnx.State({...}) constructor recursively converts nested plain dicts to nnx.State, + # causing a pytree type mismatch with the actual state from nnx.split (which stores + # nested module states as plain dicts). JAX's unflatten preserves the original types. + # 2. copy.deepcopy fails because NamedSharding contains non-picklable jaxlib.Device objects. + # Direct __setattr__ assignment stores new_model_shardings as-is (no type conversion). + updated_state = jax.tree_util.tree_map(lambda x: x, state_mesh_shardings, is_leaf=lambda x: isinstance(x, nnx.Variable)) + updated_state.model = new_model_shardings + + return prev_params_shardings, updated_state + + def logical_axis_rules_pp_act_as_dp(logical_rules): """Add stage as a physical axes before data for each rule, so stage acts just like data instead of PP. This is used when we want to pipeline only a subset of layers, and leave the rest like DP. diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index 906a597728..ca90550630 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -15,12 +15,14 @@ # pylint: disable=bare-except, consider-using-generator """Utils that are only interesting for training in MaxText.""" +import functools import os from functools import partial import jax -import functools +from flax import nnx from flax.linen import partitioning as nn_partitioning +from maxtext.layers import train_state_nnx from maxtext.common import checkpointing from maxtext.common.data_loader import create_dataloader from maxtext.common.goodput import GoodputEvent, maybe_record_goodput @@ -205,7 +207,7 @@ def setup_train_loop(config, recorder, devices=None): data_iterator: data_loader: rampup_manager: the class managing rampup batch sizes - state: the initialized train state + train_state: the initialized train state. For NNX, this is a TrainStateNNX instance """ # pylint: disable=import-outside-toplevel from maxtext.input_pipeline.input_pipeline_interface import create_data_iterator @@ -213,16 +215,22 @@ def setup_train_loop(config, recorder, devices=None): with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): is_training = True init_rng = jax.random.PRNGKey(config.init_weights_seed) + mesh = maxtext_utils.get_mesh_from_config(config, devices) if config.pure_nnx: # Create abstract NNX model. - raise NotImplementedError("Pure NNX support has not been implemented yet.") + _create_model_partial, model = model_creation_utils.create_nnx_abstract_model(config, mesh, devices) else: model = model_creation_utils.from_config(config, devices) - mesh = model.mesh learning_rate_schedule, tx = create_training_optimizer(config, model) + if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") + # For NNX, the train state is wrapped in the TrainStateNNX module. + def create_train_state_fn(): + model = _create_model_partial() + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(model, optimizer) + + init_state_fn = create_train_state_fn else: init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, is_training, init_rng) checkpoint_manager = create_checkpoint_manager(config, mesh, init_state_fn) @@ -266,6 +274,15 @@ def setup_train_loop(config, recorder, devices=None): state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( data_iterator, config, mesh, checkpoint_manager, init_state_fn ) + if config.pure_nnx: + with nn_partitioning.axis_rules(config.logical_axis_rules): + # train_state is instance of TrainStateNNX + state_graphdef, _ = nnx.get_abstract_model(init_state_fn, mesh) + _, state_params, _ = nnx.split(state.model, nnx.Param, ...) + _, state_mesh_shardings_params, _ = nnx.split(state_mesh_shardings.model, nnx.Param, ...) + else: + state_params = state.params + state_mesh_shardings_params = state_mesh_shardings.params if config.enable_diloco: with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): @@ -283,17 +300,24 @@ def setup_train_loop(config, recorder, devices=None): # TODO(aireenmei, hengtaoguo): support sharding in vit for multimodal if not config.using_pipeline_parallelism and not config.use_multimodal: # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage - sharding.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance) + sharding.assert_params_sufficiently_sharded(state_params, mesh, config.sharding_tolerance) # print weights sharding info under debug sharding mode if config.debug_sharding: - logical_annotations = maxtext_utils.get_logical_annotations(config, mesh, init_state_fn) + if config.pure_nnx: + # TODO: Study how to get logical annotations of NNX module. Because of eager sharding, we + # probably already lost the logical partition info at this moment. + logical_annotations_params = None + else: + logical_annotations = maxtext_utils.get_logical_annotations(config, mesh, init_state_fn) + logical_annotations_params = logical_annotations.params + max_utils.print_non_trivial_mesh_axis(model.mesh) - maxtext_utils.print_shardings_params( - state.params, state_mesh_shardings.params, model.mesh, logical_annotations.params - ) + maxtext_utils.print_shardings_params(state_params, state_mesh_shardings_params, mesh, logical_annotations_params) if config.use_dpo: + if config.pure_nnx: + raise NotImplementedError("DPO is not supported yet by NNX models.") abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training) max_logging.log( "Restoring reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" @@ -318,12 +342,18 @@ def setup_train_loop(config, recorder, devices=None): except FileNotFoundError: step0_restored = None if step0_restored is not None: + # TODO: For pure_nnx, the dpo state manipulation is different. reference_params = step0_restored["items"].params["params"] state = _merge_dpo_state(state, reference_params) else: max_logging.log( "Could not restore reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" ) + if config.pure_nnx: + train_state = nnx.merge(state_graphdef, state) + model = train_state.model + else: + train_state = state return ( init_rng, @@ -336,7 +366,7 @@ def setup_train_loop(config, recorder, devices=None): data_loader, rampup_manager, eval_data_iterator, - state, + train_state, ) diff --git a/tests/integration/setup_train_loop_nnx_test.py b/tests/integration/setup_train_loop_nnx_test.py new file mode 100644 index 0000000000..c15c59fd3b --- /dev/null +++ b/tests/integration/setup_train_loop_nnx_test.py @@ -0,0 +1,140 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration test for setup_train_loop with pure_nnx=True. + +setup_train_loop wires together create_nnx_abstract_model, the training optimizer, +the checkpoint manager, the data iterator, and finally nnx.split / nnx.merge to +return a fully-formed TrainStateNNX. This test exercises that wiring end-to-end +on a tiny synthetic config — the goal is to cover the integration glue that the +unit tests in tests/unit/train_utils_nnx_test.py cannot reach. +""" + +import os +import sys +import unittest + +import pytest + +import jax +from flax import nnx + +from maxtext.configs import pyconfig +from maxtext.layers import train_state_nnx +from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT +from maxtext.utils.train_utils import setup_train_loop +from tests.utils.test_helpers import get_test_config_path + + +def _tiny_nnx_pyconfig(**overrides): + """Build a tiny pyconfig suitable for a single-host setup_train_loop run.""" + init_kwargs = { + "run_name": "setup_train_loop_nnx_test", + "enable_checkpointing": False, + "dataset_type": "synthetic", + "model_name": "default", + "pure_nnx": True, + "per_device_batch_size": 1.0, + "base_emb_dim": 8, + "base_num_query_heads": 4, + "base_num_kv_heads": 4, + "base_mlp_dim": 32, + "base_num_decoder_layers": 2, + "head_dim": 128, + "max_target_length": 128, + "vocab_size": 256, + "steps": 1, + "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"), + "enable_goodput_recording": False, + "enable_checkpoint_cloud_logger": False, + "monitor_goodput": False, + } + init_kwargs.update(overrides) + return pyconfig.initialize([sys.argv[0], get_test_config_path()], **init_kwargs) + + +@pytest.mark.integration_test +@pytest.mark.tpu_only +class SetupTrainLoopNNXIntegrationTest(unittest.TestCase): + """End-to-end check that setup_train_loop returns a usable TrainStateNNX.""" + + def test_pure_nnx_setup_returns_train_state_nnx(self): + config = _tiny_nnx_pyconfig() + + ( + init_rng, + checkpoint_manager, + state_mesh_shardings, + model, + mesh, + learning_rate_schedule, + data_iterator, + data_loader, + rampup_manager, + eval_data_iterator, + train_state, + ) = setup_train_loop(config, recorder=None) + + # The NNX path returns a fully-merged TrainStateNNX (lines 352-354 in train_utils.py). + self.assertIsInstance(train_state, train_state_nnx.TrainStateNNX) + # Optimizer.step starts at 0 for a fresh init. + self.assertEqual(int(train_state.optimizer.step.get_value()), 0) + # The returned model is train_state.model, an NNX module. + self.assertIsInstance(model, nnx.Module) + self.assertIs(model, train_state.model) + + # Sanity for sibling outputs: + self.assertIsNotNone(init_rng) + self.assertIsNotNone(mesh) + self.assertTrue(callable(learning_rate_schedule)) + # data_loader is mandatory; data_iterator may be wrapped/unwrapped. + self.assertIsNotNone(data_loader) + self.assertIsNotNone(data_iterator) + + # state_mesh_shardings (NNX) is an nnx.State and contains a 'model' branch. + self.assertIsInstance(state_mesh_shardings, nnx.State) + self.assertIn("model", state_mesh_shardings) + + # Cleanup: the rest are not asserted on but referenced so linters don't + # flag them as unused — they're part of the public return contract. + del checkpoint_manager, rampup_manager, eval_data_iterator + + def test_pure_nnx_setup_param_only_split_matches_model(self): + """nnx.split(state.model, nnx.Param, ...) must yield a non-empty Param tree + whose structure matches state_mesh_shardings.model after the same split.""" + config = _tiny_nnx_pyconfig() + *_, state_mesh_shardings, model, _, _, _, _, _, _, train_state = setup_train_loop(config, recorder=None) + + _, params, _ = nnx.split(train_state.model, nnx.Param, ...) + _, params_shardings, _ = nnx.split(state_mesh_shardings.model, nnx.Param, ...) + + # Same key-set after nnx.split — this is what setup_train_loop relies on at + # train_utils.py:281-282 to pair state_params with state_mesh_shardings_params. + self.assertEqual(jax.tree_util.tree_structure(params), jax.tree_util.tree_structure(params_shardings)) + self.assertGreater(len(jax.tree.leaves(params)), 0) + + del model + + def test_pure_nnx_dpo_raises_not_implemented(self): + """The use_dpo branch (train_utils.py:319-320) must raise for NNX.""" + # use_dpo requires a few prerequisites; the simplest is to set the flag and + # let setup_train_loop reach the NotImplementedError check before the more + # involved DPO path runs. + config = _tiny_nnx_pyconfig(use_dpo=True) + with self.assertRaises(NotImplementedError): + setup_train_loop(config, recorder=None) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/checkpointing_nnx_load_test.py b/tests/unit/checkpointing_nnx_load_test.py new file mode 100644 index 0000000000..622f19323a --- /dev/null +++ b/tests/unit/checkpointing_nnx_load_test.py @@ -0,0 +1,106 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX branches of load_state_if_possible.""" + +import unittest +from unittest import mock + +import jax +import jax.numpy as jnp +import optax +from flax import nnx + +from maxtext.common import checkpointing +from maxtext.layers import train_state_nnx + + +class _Model(nnx.Module): + """Tiny single-linear NNX model for restore tests.""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + +def _abstract_nnx_state(): + """Build an nnx.State from a TrainStateNNX — same shape that pre_train passes in.""" + model = _Model(rngs=nnx.Rngs(0)) + optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + return nnx.state(train_state_nnx.TrainStateNNX(model, optimizer)) + + +class TestLoadStateIfPossibleNNX(unittest.TestCase): + """Cover the NNX branches in load_state_if_possible.""" + + def test_load_parameters_from_path_splits_nnx_state_for_param_view(self): + """When abstract_unboxed_pre_state is an nnx.State, the function must call + nnx.split(model, nnx.Param, ...) to get the params and forward them to load_params_from_path.""" + abstract = _abstract_nnx_state() + sentinel_restored = {"linear": {"kernel": jnp.ones((2, 1)), "bias": jnp.zeros((1,))}} + + with mock.patch.object(checkpointing, "load_params_from_path", return_value=sentinel_restored) as m: + full, params = checkpointing.load_state_if_possible( + checkpoint_manager=None, + data_iterator=None, + load_parameters_from_path="gs://does-not-exist/params", + load_full_state_from_path="", + checkpoint_storage_concurrent_gb=8, + abstract_unboxed_pre_state=abstract, + ) + + self.assertIsNone(full) + self.assertIs(params, sentinel_restored) + m.assert_called_once() + forwarded_params = m.call_args[0][1] # second positional arg = abstract_unboxed_params + # The forwarded params come from nnx.split(..., nnx.Param, ...) — same key shape as the model. + leaves = jax.tree.leaves(forwarded_params) + self.assertEqual(len(leaves), 2) # linear.kernel + linear.bias + + def test_load_parameters_from_path_uses_state_params_for_linen(self): + """For Linen TrainState, the function must use state.params (not nnx.split).""" + fake_state = mock.Mock(spec=["params"]) + fake_state.params = {"layer": {"kernel": jnp.ones((2, 2))}} + sentinel = object() + + with mock.patch.object(checkpointing, "load_params_from_path", return_value=sentinel) as m: + full, params = checkpointing.load_state_if_possible( + checkpoint_manager=None, + data_iterator=None, + load_parameters_from_path="gs://does-not-exist/params", + load_full_state_from_path="", + checkpoint_storage_concurrent_gb=8, + abstract_unboxed_pre_state=fake_state, + ) + + self.assertIsNone(full) + self.assertIs(params, sentinel) + forwarded_params = m.call_args[0][1] + self.assertIs(forwarded_params, fake_state.params) + + def test_no_paths_returns_none_none(self): + """Sanity: with no checkpoint manager and no load paths, the function returns (None, None).""" + full, params = checkpointing.load_state_if_possible( + checkpoint_manager=None, + data_iterator=None, + load_parameters_from_path="", + load_full_state_from_path="", + checkpoint_storage_concurrent_gb=8, + abstract_unboxed_pre_state=_abstract_nnx_state(), + ) + self.assertIsNone(full) + self.assertIsNone(params) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/gradient_accumulation_nnx_test.py b/tests/unit/gradient_accumulation_nnx_test.py new file mode 100644 index 0000000000..6353f02397 --- /dev/null +++ b/tests/unit/gradient_accumulation_nnx_test.py @@ -0,0 +1,159 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX branch of gradient_accumulation_loss_and_grad.""" + +import unittest +from dataclasses import dataclass + +import jax +import jax.numpy as jnp +import numpy as np +from flax import nnx +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from maxtext.common.common_types import ShardMode +from maxtext.utils import gradient_accumulation + + +@dataclass +class _Cfg: + gradient_accumulation_steps: int = 2 + shard_optimizer_over_data: bool = False + shard_mode: int = ShardMode.AUTO + ici_data_parallelism: int = 1 + debug_sharding: bool = False + + +class _TinyNNX(nnx.Module): + """Single linear layer NNX model.""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + def __call__(self, x): + return self.linear(x) + + +def _fake_loss_fn(model, config, data, dropout_rng, params, is_train=True): + """A loss_fn shaped like the production loss_fn but for a tiny linear model. + + Returns (loss, aux) where aux follows the schema gradient_accumulation_loss_and_grad + reads from: xent_sum / total_weights / moe_lb_loss / indexer_loss / mtp_loss. + """ + del config, dropout_rng, params, is_train + pred = model(data["inputs"]) + per_sample_loss = jnp.mean((pred - data["targets"]) ** 2, axis=-1) + xent_sum = jnp.sum(per_sample_loss) + total_weights = jnp.array(per_sample_loss.shape[0], dtype=jnp.float32) + aux = { + "xent_sum": xent_sum, + "total_weights": total_weights, + "moe_lb_loss": jnp.array(0.0), + "indexer_loss": jnp.array(0.0), + "mtp_loss": jnp.array(0.0), + } + return xent_sum / total_weights, aux + + +class TestGradientAccumulationNNX(unittest.TestCase): + """Cover the NNX path of gradient_accumulation_loss_and_grad.""" + + def setUp(self): + self.model = _TinyNNX(rngs=nnx.Rngs(0)) + self.cfg = _Cfg(gradient_accumulation_steps=2) + # 4 examples → 2 microbatches of 2 each + self.data = { + "inputs": jnp.arange(8.0).reshape(4, 2), + "targets": jnp.zeros((4, 1)), + } + + def _params_shardings(self): + """Build a per-leaf NamedSharding tree shaped like nnx.split(model, nnx.Param, ...)[1]. + + Uses a trivial single-device mesh so jax.lax.with_sharding_constraint accepts the + sharding without contradicting the actual device topology. + """ + _, params, _ = nnx.split(self.model, nnx.Param, ...) + mesh = Mesh( + np.array(jax.local_devices()[:1]).reshape( + 1, + ), + ("x",), + ) + ns = NamedSharding(mesh, PartitionSpec()) + return jax.tree.map(lambda _: ns, params) + + def test_nnx_path_runs_and_returns_grad_for_every_param(self): + """The NNX branch must call nnx.value_and_grad and return one gradient per Param.""" + loss, aux, raw_grads = gradient_accumulation.gradient_accumulation_loss_and_grad( + _fake_loss_fn, + self.cfg, + self.model, + params=None, # NNX branch ignores params + params_shardings=self._params_shardings(), + data=self.data, + dropout_rng=None, + extra_dpo_args=[], + ) + self.assertTrue(jnp.isfinite(loss)) + self.assertIn("xent_sum", aux) + self.assertIn("total_weights", aux) + grad_leaves = jax.tree.leaves(raw_grads) + self.assertEqual(len(grad_leaves), 2) # linear.kernel + linear.bias + for g in grad_leaves: + self.assertTrue(jnp.all(jnp.isfinite(g))) + + def test_nnx_path_updates_model_rest_state_after_scan(self): + """After accumulation, nnx.update is called on the model with the rest_state from the scan. + + For a TinyNNX (no rngs/dropout), the rest tree is empty but the call path must still + succeed end-to-end without raising — covering the `if is_nnx: nnx.update(...)` branch. + """ + pre_kernel = self.model.linear.kernel.value.copy() + gradient_accumulation.gradient_accumulation_loss_and_grad( + _fake_loss_fn, + self.cfg, + self.model, + params=None, + params_shardings=self._params_shardings(), + data=self.data, + dropout_rng=None, + extra_dpo_args=[], + ) + # The kernel itself is a Param — gradient_accumulation_loss_and_grad does not apply + # gradients to params, so the value should be untouched. + self.assertTrue(jnp.allclose(self.model.linear.kernel.value, pre_kernel)) + + def test_nnx_with_shard_optimizer_over_data_casts_to_bf16(self): + """Zero-1 path must convert fp32 params to bf16 before the scan loop.""" + self.cfg.shard_optimizer_over_data = True + # Should not raise; just verify the function runs and returns sensible outputs. + loss, _, raw_grads = gradient_accumulation.gradient_accumulation_loss_and_grad( + _fake_loss_fn, + self.cfg, + self.model, + params=None, + params_shardings=self._params_shardings(), + data=self.data, + dropout_rng=None, + extra_dpo_args=[], + ) + self.assertTrue(jnp.isfinite(loss)) + for g in jax.tree.leaves(raw_grads): + self.assertTrue(jnp.all(jnp.isfinite(g))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/maxtext_utils_test.py b/tests/unit/maxtext_utils_test.py index 7a09750a86..2ef825c9a7 100644 --- a/tests/unit/maxtext_utils_test.py +++ b/tests/unit/maxtext_utils_test.py @@ -15,11 +15,13 @@ """Tests for the common MaxText utilities""" import functools -from typing import Any, Sequence from collections.abc import Callable +from typing import Any, Sequence import unittest from unittest.mock import MagicMock, Mock, patch from dataclasses import dataclass, field +import numpy as np +import optax from flax import linen as nn from flax import nnx @@ -29,6 +31,7 @@ from jax import random, vmap import jax.numpy as jnp from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils from maxtext.configs import pyconfig from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_TRAIN, ShardMode from maxtext.inference import inference_utils @@ -39,8 +42,7 @@ from maxtext.utils import sharding from maxtext.utils.sharding import assert_params_sufficiently_sharded, get_formatted_sharding_annotations from tests.utils.test_helpers import get_test_config_path, get_decoupled_parallelism_overrides -import numpy as np -import optax +from maxtext.utils import maxtext_utils_nnx Transformer = models.transformer_as_linen @@ -179,11 +181,7 @@ def setUp(self): "decoder": {"gate": {"bias": jnp.array([0.5, 0.5])}}, } self.state = train_state.TrainState( - step=0, - apply_fn=self.model.apply, - params=self.initial_params, - tx=None, - opt_state={}, + step=0, apply_fn=self.model.apply, params=self.initial_params, tx=None, opt_state={} ) def test_update_mode_add(self): @@ -196,10 +194,10 @@ def test_update_mode_add(self): self.assertTrue(jnp.allclose(actual, expected)) # Other values are untouched - original_layer_0 = self.state.params["layers"]["layer_0"]["bias"] + original_layer_0 = self.state.params["layers"]["layer_0"]["bias"] # pylint: disable=unsubscriptable-object new_layer_0 = new_state.params["layers"]["layer_0"]["bias"] self.assertTrue(jnp.array_equal(original_layer_0, new_layer_0)) - original_layer_1 = self.state.params["layers"]["layer_1"]["bias"] + original_layer_1 = self.state.params["layers"]["layer_1"]["bias"] # pylint: disable=unsubscriptable-object new_layer_1 = new_state.params["layers"]["layer_1"]["bias"] self.assertTrue(jnp.array_equal(original_layer_1, new_layer_1)) @@ -264,7 +262,7 @@ def test_init_training_state(self): @nnx.register_variable_name("special_variables") -class SpecialVariables(nnx.Variable): +class SpecialVariables(nnx.Variable): # pylint: disable=abstract-method pass @@ -281,7 +279,7 @@ def __call__(self, x, y, encoder_images=None, nnx_method=None, model_mode=None): return x -class TrainState(train_state.TrainState): +class TrainState(train_state.TrainState): # pylint: disable=abstract-method other_variables: nnx.State @@ -993,49 +991,63 @@ def train_step(_model, _config, _state_shardings, _params_shardings, state, _bat return train_step + def _make_mock_config(self, pure_nnx=False): + cfg = MagicMock() + cfg.pure_nnx = pure_nnx + return cfg + def test_returns_five_tuple(self): step = self._make_mock_step() result = maxtext_utils.get_functional_train_with_signature( - step, "data_sharding", "state_shardings", "model", "config" + step, "data_sharding", "state_shardings", "model", self._make_mock_config() ) self.assertEqual(len(result), 5) def test_functional_train_has_correct_name(self): step = self._make_mock_step() fn, _, _, _, _ = maxtext_utils.get_functional_train_with_signature( - step, "data_sharding", "state_shardings", "model", "config" + step, "data_sharding", "state_shardings", "model", self._make_mock_config() ) self.assertEqual(fn.__name__, "train_step") - def test_in_shardings_structure(self): + def test_linen_in_shardings_includes_rng(self): + """pure_nnx=False: in_shardings should be (state, batch, rng).""" step = self._make_mock_step() _, in_shardings, _, _, _ = maxtext_utils.get_functional_train_with_signature( - step, "data_sharding", "state_shardings", "model", "config" + step, "data_sharding", "state_shardings", "model", self._make_mock_config(pure_nnx=False) ) - # (state, batch, rng) self.assertEqual(len(in_shardings), 3) self.assertIsNone(in_shardings[2]) # rng sharding is None + def test_nnx_in_shardings_excludes_rng(self): + """pure_nnx=True: in_shardings should be (state, batch) — no rng slot.""" + step = self._make_mock_step() + _, in_shardings, _, _, _ = maxtext_utils.get_functional_train_with_signature( + step, "data_sharding", "state_shardings", "model", self._make_mock_config(pure_nnx=True) + ) + self.assertEqual(len(in_shardings), 2) + def test_donate_argnums_is_zero(self): step = self._make_mock_step() _, _, _, _, donate_argnums = maxtext_utils.get_functional_train_with_signature( - step, "data_sharding", "state_shardings", "model", "config" + step, "data_sharding", "state_shardings", "model", self._make_mock_config() ) self.assertEqual(donate_argnums, 0) def test_functional_train_is_partial(self): """functional_train should partially apply model and config.""" received = {} + cfg = self._make_mock_config() def train_step(model, config, _state_shardings, _params_shardings, state, _batch, _rng=None): received["model"] = model received["config"] = config return state, {} - fn, _, _, _, _ = maxtext_utils.get_functional_train_with_signature(train_step, "ds", "ss", "my_model", "my_config") + fn, _, _, _, _ = maxtext_utils.get_functional_train_with_signature(train_step, "ds", "ss", "my_model", cfg) fn("state", "batch") self.assertEqual(received["model"], "my_model") - self.assertEqual(received["config"], "my_config") + self.assertEqual(received["config"], cfg) class TestGetFunctionalEvalWithSignature(unittest.TestCase): @@ -1047,26 +1059,51 @@ def eval_step(_model, _config, _state, _batch, _rng=None): return eval_step + def _make_mock_config(self, pure_nnx=False): + cfg = MagicMock() + cfg.pure_nnx = pure_nnx + return cfg + def test_returns_five_tuple(self): step = self._make_mock_eval_step() - result = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") + result = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", self._make_mock_config()) self.assertEqual(len(result), 5) def test_functional_eval_has_correct_name(self): step = self._make_mock_eval_step() - fn, _, _, _, _ = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") + fn, _, _, _, _ = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", self._make_mock_config()) self.assertEqual(fn.__name__, "eval_step") def test_out_shardings_is_none(self): step = self._make_mock_eval_step() - _, _, out_shardings, _, _ = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") + _, _, out_shardings, _, _ = maxtext_utils.get_functional_eval_with_signature( + step, "ds", "ss", "model", self._make_mock_config() + ) self.assertIsNone(out_shardings) def test_donate_argnums_is_empty(self): step = self._make_mock_eval_step() - _, _, _, _, donate_argnums = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") + _, _, _, _, donate_argnums = maxtext_utils.get_functional_eval_with_signature( + step, "ds", "ss", "model", self._make_mock_config() + ) self.assertEqual(donate_argnums, ()) + def test_nnx_in_shardings_excludes_rng(self): + """pure_nnx=True: in_shardings should be (state, batch) — no rng slot.""" + step = self._make_mock_eval_step() + _, in_shardings, _, _, _ = maxtext_utils.get_functional_eval_with_signature( + step, "batch_sharding", "state_sharding", "model", self._make_mock_config(pure_nnx=True) + ) + self.assertEqual(len(in_shardings), 2) + + def test_linen_in_shardings_includes_rng(self): + """pure_nnx=False: in_shardings should be (state, batch, rng).""" + step = self._make_mock_eval_step() + _, in_shardings, _, _, _ = maxtext_utils.get_functional_eval_with_signature( + step, "batch_sharding", "state_sharding", "model", self._make_mock_config(pure_nnx=False) + ) + self.assertEqual(len(in_shardings), 3) + class TestGetShapedBatch(unittest.TestCase): """Tests for get_shaped_batch.""" @@ -1414,5 +1451,183 @@ def test_runs_without_logical_annotations(self): maxtext_utils.print_shardings_params(params, param_sharding, mesh=self.mesh, logical_annotations=None) +class TestNNXAbstractState(unittest.TestCase): + """Test the get_abstract_state_nnx func.""" + + @dataclass + class MockConfig: + init_weights_seed: int = 42 + shard_optimizer_over_data: bool = False + optimizer_memory_host_offload: bool = False + parameter_memory_host_offload: bool = False + param_scan_axis: int = 0 + logical_axis_rules: list = field(default_factory=lambda: [["data", ["data"]]]) + + class MockTrainState(nnx.Module): + """Simulates a TrainState with params and optimizer state.""" + + def __init__(self, rngs: nnx.Rngs): + # Model parameters + device_num = len(jax.local_devices()) + self.params = nnx.Linear( + 2, 4, kernel_init=nnx.with_partitioning(nnx.initializers.ones, sharding=("model",)), rngs=rngs + ) + # Simulated optimizer state + self.optimizer = nnx.Variable(jnp.zeros((device_num,)), sharding=("model",)) + + def setUp(self): + # Create a real 1D mesh on local devices + devices = jax.local_devices() + self.mesh = Mesh(mesh_utils.create_device_mesh((len(devices), 1)), axis_names=("model", "data")) + self.config = self.MockConfig() + + def nnx_init_trainstate_wrapper(self): + """Wrapper to initialize the mock NNX model.""" + rngs = maxtext_utils_nnx.create_nnx_rngs(self.config) + return self.MockTrainState(rngs) + + def test_basic_abstraction(self): + """Verifies the basic return structure and partition spec extraction.""" + abstract_state, annotations, shardings = maxtext_utils.get_abstract_state_nnx( + self.config, self.mesh, self.nnx_init_trainstate_wrapper + ) + + # Check return types + self.assertIsInstance(abstract_state, nnx.State) + self.assertIsInstance(annotations, nnx.State) + self.assertIsInstance(shardings, nnx.State) + + # Verify PartitionSpec was extracted correctly from the mock model's annotations + # Path: params -> kernel -> spec + self.assertEqual( + annotations.params.kernel.get_value(), + PartitionSpec( + "model", + ), + ) + + def test_shard_optimizer_over_data(self): + """Verifies that 'data' is added to optimizer sharding using the real utility.""" + self.config.shard_optimizer_over_data = True + + _, annotations, _ = maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, self.nnx_init_trainstate_wrapper) + + # Original Pspec for optimizer was PartitionSpec(None). + # add_data_to_sharding should find that dim 0 is compatible with mesh 'data' + # and update it to PartitionSpec(('data',)). + opt_spec = annotations.optimizer.get_value() + + # Verify 'data' is now in the spec + self.assertEqual(opt_spec, PartitionSpec(("data", "model"))) + + def test_optimizer_host_offload(self): + """Verifies that optimizer memory is moved to host when configured.""" + self.config.optimizer_memory_host_offload = True + + _, _, shardings = maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, self.nnx_init_trainstate_wrapper) + + # Optimizer state should be pinned to host + opt_sharding = shardings.optimizer.get_value() + self.assertEqual(opt_sharding.memory_kind, "pinned_host") + + # Params should still be on default memory (usually device) + param_sharding = shardings.params.kernel.get_value() + self.assertNotEqual(param_sharding.memory_kind, "pinned_host") + + def test_parameter_host_offload(self): + """Verifies that parameter memory is moved to host when configured.""" + self.config.parameter_memory_host_offload = True + self.config.param_scan_axis = 0 + + _, _, shardings = maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, self.nnx_init_trainstate_wrapper) + + # Parameters should be pinned to host + param_sharding = shardings.params.kernel.get_value() + self.assertEqual(param_sharding.memory_kind, "pinned_host") + + def test_invalid_init_fn(self): + """Ensures function raises error if no init function is provided.""" + with self.assertRaises(AssertionError): + maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, None) + + +class TestGetNnxNamedShardingWithScanAxis(unittest.TestCase): + """Unit tests for get_nnx_named_sharding_with_scan_axis covering every branch. + + The helper resolves a NamedSharding for each NNX Variable and — unlike + flax.nnx.spmd.get_var_pspec — also inserts the `nnx.PARTITION_NAME` axis at + `param_scan_axis` when scanned-layers metadata is present. + """ + + def setUp(self): + # Mesh needs to contain every axis name the tests reference in partition specs. + self.mesh = Mesh(np.array(jax.local_devices()[:1]).reshape(1, 1), ("fsdp", "layers")) + + def _build_state(self, **variables): + """Wrap a dict of {key: nnx.Variable} in an nnx.State for tree traversal.""" + return nnx.State(variables) + + def _run(self, state): + return maxtext_utils.get_nnx_named_sharding_with_scan_axis(state, self.mesh) + + def test_scan_axis_inserted_at_param_scan_axis(self): + """When PARTITION_NAME is present, the partition name is inserted at `param_scan_axis`.""" + with jax.set_mesh(self.mesh): + v = nnx.Param( + jnp.zeros((3, 4, 8)), + out_sharding=(None, "fsdp"), + **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 1}, + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + self.assertIsInstance(result_sharding, NamedSharding) + # 'layers' must be inserted at position 1 (param_scan_axis=1). + self.assertEqual(result_sharding.spec, PartitionSpec(None, "layers", "fsdp")) + + def test_scan_axis_not_inserted_when_already_present(self): + """Guard against double-insertion when partition_name is already in out_sharding.""" + with jax.set_mesh(self.mesh): + v = nnx.Param( + jnp.zeros((2, 2, 2)), + out_sharding=("layers", None, "fsdp"), + **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 0}, + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + # 'layers' must appear exactly once — the same PartitionSpec we started with. + self.assertEqual(result_sharding.spec, PartitionSpec("layers", None, "fsdp")) + + def test_masked_node_preserved_as_is(self): + """Values without a .shape attribute (e.g., optax.MaskedNode) are returned unchanged.""" + masked = nnx.Variable(optax.MaskedNode()) + state = self._build_state(masked=masked) + out = self._run(state) + # The leaf must be the original Variable, not a NamedSharding wrapper. + self.assertIs(out["masked"], masked) + + def test_empty_out_sharding_yields_empty_pspec(self): + """A Variable without any sharding metadata should resolve to PartitionSpec().""" + with jax.set_mesh(self.mesh): + # No out_sharding/sharding_names/sharding metadata → falsy → PartitionSpec() + v = nnx.Param(jnp.zeros((4,))) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + self.assertIsInstance(result_sharding, NamedSharding) + self.assertEqual(result_sharding.spec, PartitionSpec()) + + def test_string_out_sharding_is_wrapped_into_tuple(self): + """A single-string out_sharding value should still produce a valid PartitionSpec.""" + with jax.set_mesh(self.mesh): + v = nnx.Param( + jnp.zeros((4,)), + out_sharding="fsdp", + **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 0}, + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + # The single string 'fsdp' is turned into a list, and 'layers' is prepended. + self.assertEqual(result_sharding.spec, PartitionSpec("layers", "fsdp")) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/muon_utils_test.py b/tests/unit/muon_utils_test.py new file mode 100644 index 0000000000..9570257eee --- /dev/null +++ b/tests/unit/muon_utils_test.py @@ -0,0 +1,224 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for muon_utils.py.""" + +# pylint: disable=protected-access + +import io +import contextlib +import unittest +from unittest import mock + +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax import nnx +from optax.contrib._muon import MuonDimensionNumbers as mdn + +from maxtext.utils import muon_utils + + +class TestIsPathContainAny(unittest.TestCase): + """Tests for _is_path_contain_any helper.""" + + def test_returns_true_when_any_element_in_path(self): + self.assertTrue(muon_utils._is_path_contain_any(("bias", "scale"), ("decoder", "bias"))) + + def test_returns_false_when_no_element_in_path(self): + self.assertFalse(muon_utils._is_path_contain_any(("bias", "scale"), ("decoder", "kernel"))) + + def test_empty_tuples_returns_false(self): + self.assertFalse(muon_utils._is_path_contain_any((), ("decoder", "kernel"))) + + +class TestTransformLogic(unittest.TestCase): + """Tests for transform_logic: covers every branch of the mapping.""" + + # --- 1. Exclusions --- + def test_scale_is_excluded(self): + self.assertIsNone(muon_utils.transform_logic(("decoder", "norm", "scale"))) + + def test_bias_is_excluded(self): + self.assertIsNone(muon_utils.transform_logic(("decoder", "dense", "bias"))) + + def test_embedding_is_excluded(self): + self.assertIsNone(muon_utils.transform_logic(("token_embedder", "embedding"))) + + def test_logits_dense_is_excluded(self): + self.assertIsNone(muon_utils.transform_logic(("decoder", "logits_dense", "kernel"))) + + # --- 2.1 MoE --- + def test_moe_wi_0_uses_last_two_axes(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "MoeBlock_0", "wi_0")), mdn((-2,), (-1,))) + + def test_moe_wi_1_uses_last_two_axes(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "MoeBlock_0", "wi_1")), mdn((-2,), (-1,))) + + def test_moe_wo_uses_last_two_axes(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "MoeBlock_0", "wo")), mdn((-2,), (-1,))) + + def test_moe_gate_falls_through_to_standard(self): + # 'gate' is inside MoeBlock_0 but not one of (wi_0, wi_1, wo) → standard. + self.assertEqual(muon_utils.transform_logic(("decoder", "MoeBlock_0", "gate", "kernel")), mdn((0,), (-1,))) + + # --- 2.2 Self-attention --- + def test_self_attention_out_projection(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "out")), mdn((0, -2), (-1,))) + + def test_self_attention_query_projection(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "query")), mdn((0,), (-2, -1))) + + def test_self_attention_key_projection(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "key")), mdn((0,), (-2, -1))) + + def test_self_attention_value_projection(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "value")), mdn((0,), (-2, -1))) + + def test_self_attention_wq_b_and_wkv_b(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "wq_b")), mdn((0,), (-2, -1))) + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "wkv_b")), mdn((0,), (-2, -1))) + + def test_self_attention_mla_wq_a_is_excluded_from_special(self): + # wq_a / wkv_a are MLA down-projections; they fall through the self_attention branch + # without matching anything, so the function returns the default standard mdn((0,), (-1,)). + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "wq_a")), mdn((0,), (-1,))) + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "wkv_a")), mdn((0,), (-1,))) + + # --- 3. Standard --- + def test_standard_weight(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "mlp", "kernel")), mdn((0,), (-1,))) + + +class TestGetTransformTree(unittest.TestCase): + """Tests for get_transform_tree: recursive dict walk that applies transform_logic.""" + + def test_nested_dict_is_walked(self): + tree = {"decoder": {"self_attention": {"out": 0}, "mlp": {"kernel": 0}}} + result = muon_utils.get_transform_tree(tree) + self.assertEqual(result["decoder"]["self_attention"]["out"], mdn((0, -2), (-1,))) + self.assertEqual(result["decoder"]["mlp"]["kernel"], mdn((0,), (-1,))) + + def test_excluded_leaves_become_none(self): + tree = {"decoder": {"norm": {"scale": 0}}} + self.assertIsNone(muon_utils.get_transform_tree(tree)["decoder"]["norm"]["scale"]) + + def test_non_dict_leaf_at_root_returns_transform(self): + # If the tree itself is a leaf, path=() and transform_logic returns the standard mdn. + self.assertEqual(muon_utils.get_transform_tree(0), mdn((0,), (-1,))) + + +class _MoeLikeNNXModel(nnx.Module): + """Small NNX model whose param paths exercise the NNX branch of get_muon_weight_dimension_numbers.""" + + def __init__(self, rngs): + # Names are chosen so transform_logic matches each of the three meaningful branches: + # - w_standard: default mdn + # - self_attention_out: attention-out mdn + # - scale: excluded (None) + self.w_standard = nnx.Param(jnp.ones((4, 8))) + self.self_attention_out = nnx.Param(jnp.ones((4, 8))) + self.scale = nnx.Param(jnp.ones((8,))) + + +class TestGetMuonWeightDimensionNumbersNNX(unittest.TestCase): + """Covers the NNX branch of get_muon_weight_dimension_numbers (isinstance(model, nnx.Module)).""" + + def setUp(self): + self.model = _MoeLikeNNXModel(rngs=nnx.Rngs(0)) + + def test_nnx_model_dispatches_to_tree_map_with_path(self): + """NNX branch should produce an nnx.State tree with transform_logic applied per leaf.""" + result = muon_utils.get_muon_weight_dimension_numbers(self.model, config=None) + + # Result is an nnx.State whose top-level keys mirror the model attributes. + self.assertIn("w_standard", result) + self.assertIn("self_attention_out", result) + self.assertIn("scale", result) + + # NNX Variables are walked by jax.tree_util.tree_map_with_path, so the returned + # tree replaces each Variable's value with transform_logic(path_strings). + # 'scale' matches the exclusion branch → value is None. + self.assertIsNone(result["scale"].get_value()) + # 'w_standard' does not trigger any special rule → standard mdn. + self.assertEqual(result["w_standard"].get_value(), mdn((0,), (-1,))) + + def test_nnx_verbose_path_executes_print_debug(self): + """verbose=True should also execute _print_structure_debug without raising.""" + buf = io.StringIO() + with contextlib.redirect_stdout(buf): + muon_utils.get_muon_weight_dimension_numbers(self.model, config=None, verbose=True) + self.assertIn("Model Structure", buf.getvalue()) + self.assertIn("Muon Dimension Numbers", buf.getvalue()) + + +class TestGetMuonWeightDimensionNumbersLinen(unittest.TestCase): + """Covers the Linen branch of get_muon_weight_dimension_numbers.""" + + def test_linen_branch_uses_get_abstract_param(self): + """Linen models dispatch to maxtext_utils.get_abstract_param + get_transform_tree.""" + # Build a Linen nn.Module so isinstance(model, nnx.Module) is False. + + class LinenStub(nn.Module): + + @nn.compact + def __call__(self, x): + return x + + model = LinenStub() + + # Mock the heavy get_abstract_param call with a pre-shaped dict that exercises + # both a standard weight path and an excluded path. + fake_abstract_param = { + "params": { + "self_attention": {"out": object()}, + "norm": {"scale": object()}, + }, + } + + with mock.patch.object(muon_utils.maxtext_utils, "get_abstract_param", return_value=fake_abstract_param): + result = muon_utils.get_muon_weight_dimension_numbers(model, config=mock.MagicMock()) + + self.assertEqual(result["params"]["self_attention"]["out"], mdn((0, -2), (-1,))) + self.assertIsNone(result["params"]["norm"]["scale"]) + + +class TestPrintStructureDebug(unittest.TestCase): + """Covers both branches of get_leaf_info inside _print_structure_debug.""" + + def test_handles_logically_partitioned_leaf(self): + """Linen leaves are nn.LogicallyPartitioned; the helper should return {shape, names}.""" + leaf = nn.LogicallyPartitioned(value=jax.ShapeDtypeStruct((4, 8), jnp.float32), names=("embed", "mlp")) + tree = {"params": {"kernel": leaf}} + + buf = io.StringIO() + with contextlib.redirect_stdout(buf): + muon_utils._print_structure_debug(tree, muon_weight_dimension_numbers={"params": {"kernel": mdn((0,), (-1,))}}) + out = buf.getvalue() + self.assertIn("(4, 8)", out) + self.assertIn("embed", out) + + def test_handles_shape_dtype_struct_leaf(self): + """NNX abstract leaves are ShapeDtypeStruct directly; the helper should return {shape}.""" + tree = {"kernel": jax.ShapeDtypeStruct((16, 32), jnp.float32)} + + buf = io.StringIO() + with contextlib.redirect_stdout(buf): + muon_utils._print_structure_debug(tree, muon_weight_dimension_numbers={"kernel": mdn((0,), (-1,))}) + out = buf.getvalue() + self.assertIn("(16, 32)", out) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/nnx_decoders_test.py b/tests/unit/nnx_decoders_test.py index 8979440732..00f761c2cf 100644 --- a/tests/unit/nnx_decoders_test.py +++ b/tests/unit/nnx_decoders_test.py @@ -31,7 +31,12 @@ from flax import nnx from jax.sharding import Mesh -from maxtext.common.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_TRAIN, DecoderBlockType +from maxtext.common.common_types import ( + DECODING_ACTIVE_SEQUENCE_INDICATOR, + MODEL_MODE_TRAIN, + DecoderBlockType, + MultimodalInput, +) from maxtext.configs import pyconfig from maxtext.layers import linears from maxtext.layers.attentions import Attention @@ -507,6 +512,88 @@ def test_logits_are_finite(self): ) self.assertTrue(jnp.all(jnp.isfinite(logits))) + def test_multimodal_input_unpacks_into_individual_fields(self): + """Passing `multimodal_input=...` must forward each field into `_apply_embedding`. + + The decoder accepts either a `MultimodalInput` struct or the individual + image/audio/bidirectional_mask arguments. When both forms are provided, the + unpacked struct takes precedence. This test stubs `_apply_embedding` to + capture the forwarded positional arguments without running the real + embedding path (the test config has `use_multimodal=False`). + """ + ids, segment_ids, positions = self._make_token_inputs() + + # Distinct sentinels so each field can be traced independently. + sentinel_img_emb = jnp.full((1, 1), 11.0) + sentinel_img_mask = jnp.full((1, 1), 22.0) + sentinel_aud_emb = jnp.full((1, 1), 33.0) + sentinel_aud_mask = jnp.full((1, 1), 44.0) + sentinel_bidir = jnp.full((1, 1), 55.0) + + mm_input = MultimodalInput( + image_embeddings=sentinel_img_emb, + image_masks=sentinel_img_mask, + audio_embeddings=sentinel_aud_emb, + audio_masks=sentinel_aud_mask, + bidirectional_mask=sentinel_bidir, + ) + + captured = {} + + def fake_apply_embedding( + _shared_embedding, + _ids, + _positions, + _deterministic, + _model_mode, + image_embeddings, + bidirectional_mask, + image_masks, + audio_embeddings, + audio_masks, + ): + captured.update( + image_embeddings=image_embeddings, + image_masks=image_masks, + audio_embeddings=audio_embeddings, + audio_masks=audio_masks, + bidirectional_mask=bidirectional_mask, + ) + # Return a correctly-shaped tensor so the rest of __call__ can proceed. + batch = self.cfg.global_batch_size_to_train_on + seq_len = self.cfg.max_target_length + emb_dim = self.cfg.emb_dim + return jnp.zeros((batch, seq_len, emb_dim), dtype=self.cfg.dtype) + + self.decoder._apply_embedding = fake_apply_embedding # pylint: disable=protected-access + try: + self.decoder( + self.shared_embedding, + ids, + positions, + decoder_segment_ids=segment_ids, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + # Intentionally pass the individual args as None; multimodal_input must override them. + image_embeddings=None, + image_masks=None, + audio_embeddings=None, + audio_masks=None, + bidirectional_mask=None, + multimodal_input=mm_input, + ) + finally: + # NNX modules bind attributes statefully; remove the override to avoid leaking. + del self.decoder._apply_embedding # pylint: disable=protected-access + + # Every field in the MultimodalInput struct must have been forwarded + # unchanged into _apply_embedding's arguments (not the None overrides). + self.assertTrue(jnp.array_equal(captured["image_embeddings"], sentinel_img_emb)) + self.assertTrue(jnp.array_equal(captured["image_masks"], sentinel_img_mask)) + self.assertTrue(jnp.array_equal(captured["audio_embeddings"], sentinel_aud_emb)) + self.assertTrue(jnp.array_equal(captured["audio_masks"], sentinel_aud_mask)) + self.assertTrue(jnp.array_equal(captured["bidirectional_mask"], sentinel_bidir)) + def test_different_random_seeds_produce_different_logits(self): """Two randomly-initialised decoders should not produce identical logits.""" cfg = self.cfg diff --git a/tests/unit/optimizers_test.py b/tests/unit/optimizers_test.py index 44623f24f3..5194719ce2 100644 --- a/tests/unit/optimizers_test.py +++ b/tests/unit/optimizers_test.py @@ -15,19 +15,19 @@ """ Unit tests for all optimizers. """ import re import unittest -from unittest.mock import patch +from unittest.mock import patch, MagicMock import jax import optax import jax.numpy as jnp import pytest from absl.testing import parameterized +from flax import nnx from optax.contrib import MuonDimensionNumbers as mdn from maxtext.configs import pyconfig from maxtext.optimizers import optimizers -from maxtext.utils import maxtext_utils -from maxtext.utils.muon_utils import get_model_mdn +from maxtext.utils import maxtext_utils, muon_utils from tests.utils.test_helpers import get_test_config_path from typing import NamedTuple @@ -49,6 +49,7 @@ DEEPSEEK2_DIMENSION_NUMBER = { "params": { "decoder": { + "decoder_norm": {"scale": None}, "dense_layers": { "mlp": { "wi_0": {"kernel": mdn((0,), (-1,))}, @@ -57,6 +58,7 @@ }, **_DEEPSEEK2_ATTENTION, }, + "logits_dense": {"kernel": None}, "moe_layers": { "DeepSeekMoeBlock_0": { "MoeBlock_0": { @@ -73,8 +75,6 @@ }, **_DEEPSEEK2_ATTENTION, }, - "decoder_norm": {"scale": None}, - "logits_dense": {"kernel": None}, }, "token_embedder": {"embedding": None}, } @@ -99,6 +99,7 @@ DEEPSEEK3_DIMENSION_NUMBER = { "params": { "decoder": { + "decoder_norm": {"scale": None}, "dense_layers": { "mlp": { "wi_0": {"kernel": mdn((0,), (-1,))}, @@ -107,6 +108,7 @@ }, **_DEEPSEEK3_ATTENTION, }, + "logits_dense": {"kernel": None}, "moe_layers": { "DeepSeekMoeBlock_0": { "MoeBlock_0": { @@ -123,8 +125,6 @@ }, **_DEEPSEEK3_ATTENTION, }, - "decoder_norm": {"scale": None}, - "logits_dense": {"kernel": None}, }, "token_embedder": {"embedding": None}, } @@ -243,7 +243,7 @@ def test_model_integration(self, model_name, expected_output): Initializes the specified MaxText model and asserts that the generated Muon dimension numbers match the hardcoded reference. """ - actual_output = get_model_mdn(model_name, scan_layers=True) + actual_output = muon_utils.get_model_mdn(model_name, scan_layers=True, pure_nnx=False) self.assertEqual(actual_output, expected_output) @@ -483,5 +483,105 @@ def test_no_skip_without_kwargs(self): self.assertEqual(opt_state["count"], 0) +class TestMuonLogic(unittest.TestCase): + """Tests the granular path transformation functions.""" + + def test_is_path_contain_any(self): + # pylint: disable=protected-access + self.assertTrue(muon_utils._is_path_contain_any(("a", "b"), ("x", "a", "z"))) + self.assertFalse(muon_utils._is_path_contain_any(("a", "b"), ("x", "y", "z"))) + + def test_transform_logic_exclusions(self): + self.assertIsNone(muon_utils.transform_logic(("layer_0", "bias"))) + self.assertIsNone(muon_utils.transform_logic(("layer_0", "scale"))) + self.assertIsNone(muon_utils.transform_logic(("embedding", "kernel"))) + + def test_transform_logic_moe(self): + path = ("layers_0", "MoeBlock_0", "wi_0") + result = muon_utils.transform_logic(path) + self.assertEqual(result.reduction_axis, (-2,)) + self.assertEqual(result.output_axis, (-1,)) + + def test_transform_logic_attention(self): + path_out = ("layers_0", "self_attention", "out", "kernel") + self.assertEqual(muon_utils.transform_logic(path_out), mdn((0, -2), (-1,))) + + path_q = ("layers_0", "self_attention", "query", "kernel") + self.assertEqual(muon_utils.transform_logic(path_q), mdn((0,), (-2, -1))) + + def test_get_transform_tree(self): + fake_tree = {"params": {"layer_0": {"kernel": "leaf", "bias": "leaf"}, "MoeBlock_0": {"wi_0": "leaf"}}} + result = muon_utils.get_transform_tree(fake_tree) + self.assertEqual(result["params"]["layer_0"]["kernel"], mdn((0,), (-1,))) + self.assertIsNone(result["params"]["layer_0"]["bias"]) + + def test_get_muon_weight_dimension_numbers_nnx(self): + """Verifies dimension extraction for stateful NNX modules.""" + + class MockNNXModel(nnx.Module): + """Mock NNX Module.""" + + def __init__(self, rngs: nnx.Rngs): + # 1. Standard layer + self.layer1 = nnx.Linear(2, 4, rngs=rngs) + + # 2. MoE specific naming to trigger transform logic. + # The logic expects "MoeBlock_0" AND "wi_0"/"wi_1"/"wo" in the path. + # We nest the linear layer to create the path: ('MoeBlock_0', 'wi_0', 'kernel') + self.MoeBlock_0 = nnx.Module() + self.MoeBlock_0.wi_0 = nnx.Linear(4, 2, rngs=rngs) + + # 3. Exclusion case (scaler/scale) + self.scale = nnx.Param(jnp.ones((1,))) + + # Use eval_shape to create an abstract version of the model. + model = nnx.eval_shape(lambda: MockNNXModel(rngs=nnx.Rngs(0))) + config = MagicMock() + + # Extract dimension numbers using the NNX path in muon_utils + result = muon_utils.get_muon_weight_dimension_numbers(model, config) + + # Verify standard weight path: ('layer1', 'kernel') -> default (0,) + self.assertEqual(result.layer1.kernel.value, mdn((0,), (-1,))) + + # Verify MoE weight path: ('MoeBlock_0', 'wi_0', 'kernel') -> (-2,) + self.assertEqual(result.MoeBlock_0.wi_0.kernel.value, mdn((-2,), (-1,))) + + # Verify exclusion (scalar/scale) + self.assertIsNone(result.scale.value) + + def test_verbose_output_nnx(self): + """Covers lines 128 and 135-154: _print_structure_debug via verbose=True with NNX model.""" + + class SimpleNNXModel(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 4, rngs=rngs) + + model = nnx.eval_shape(lambda: SimpleNNXModel(rngs=nnx.Rngs(0))) + config = MagicMock() + muon_utils.get_muon_weight_dimension_numbers(model, config, verbose=True) + + def test_nnx_deepseek_attention_logic(self): + """Simulates a DeepSeek-like attention structure in NNX.""" + + class DeepSeekAttention(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.self_attention = nnx.Module() + self.self_attention.query = nnx.Linear(8, 8, rngs=rngs) + self.self_attention.out = nnx.Linear(8, 8, rngs=rngs) + + # Use eval_shape to create an abstract version of the model. + model = nnx.eval_shape(lambda: DeepSeekAttention(nnx.Rngs(0))) + config = MagicMock() + result = muon_utils.get_muon_weight_dimension_numbers(model, config) + + # Check attention query: [0] -> [-2, -1] + self.assertEqual(result.self_attention.query.kernel.value, mdn((0,), (-2, -1))) + # Check attention out: [0, -2] -> [-1] + self.assertEqual(result.self_attention.out.kernel.value, mdn((0, -2), (-1,))) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/sharding_nnx_test.py b/tests/unit/sharding_nnx_test.py new file mode 100644 index 0000000000..3cda286c68 --- /dev/null +++ b/tests/unit/sharding_nnx_test.py @@ -0,0 +1,161 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX-specific helpers in maxtext.utils.sharding.""" + +import unittest +from dataclasses import dataclass + +import jax +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from flax import nnx +import numpy as np +import optax + +from maxtext.layers import train_state_nnx +from maxtext.utils import sharding + + +@dataclass +class _Cfg: + pure_nnx: bool = True + shard_optimizer_over_data: bool = False + + +class _LinearNNX(nnx.Module): + """Tiny NNX model with a single Linear layer for sharding tests.""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 4, rngs=rngs) + + +def _build_state_mesh_shardings(model, tx): + """Build an nnx.State of NamedShardings mirroring the TrainStateNNX layout. + + This emulates what get_abstract_state_nnx returns: an nnx.State whose leaves + are nnx.Variable wrappers around NamedSharding objects. + """ + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + state_obj = train_state_nnx.TrainStateNNX(model, optimizer) + state = nnx.state(state_obj) + mesh = Mesh(np.array(jax.local_devices()[:1]).reshape(1, 1), ("data", "model")) + + def _to_sharding(var): + val = var.get_value() + if not hasattr(val, "shape") or val.ndim == 0: + pspec = PartitionSpec() + elif val.ndim == 1: + pspec = PartitionSpec("model") + else: + pspec = PartitionSpec("data", "model") + return var.replace(NamedSharding(mesh, pspec)) + + return jax.tree.map(_to_sharding, state, is_leaf=lambda x: isinstance(x, nnx.Variable)) + + +class TestMaybeUpdateParamsShardingWithOptNNX(unittest.TestCase): + """Cover the NNX branches of maybe_update_params_sharding_with_opt.""" + + def setUp(self): + self.model = _LinearNNX(rngs=nnx.Rngs(0)) + + def test_dispatch_from_main_helper_when_pure_nnx(self): + """maybe_update_params_sharding_with_opt should dispatch to the NNX variant.""" + cfg = _Cfg(pure_nnx=True, shard_optimizer_over_data=False) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.adam(1e-3)) + prev, updated = sharding.maybe_update_params_sharding_with_opt(cfg, state_mesh_shardings) + # prev is the param-only view (no rngs / non-Param nodes) + self.assertIsInstance(prev, nnx.State) + self.assertIn("linear", prev) + # updated is unchanged because shard_optimizer_over_data=False + self.assertIs(updated, state_mesh_shardings) + + def test_extract_param_only_skips_non_param_variables(self): + """prev_params_shardings must contain Params only — RngKey/RngCount/OptVariable filtered out.""" + cfg = _Cfg(shard_optimizer_over_data=False) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.adam(1e-3)) + prev, _ = sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) + leaves = jax.tree.leaves(prev, is_leaf=lambda x: isinstance(x, nnx.Variable)) + # Every surviving leaf is wrapped as an nnx.Param. + self.assertTrue(all(isinstance(leaf, nnx.Param) for leaf in leaves)) + # The model has linear.kernel and linear.bias — exactly two Param leaves. + self.assertEqual(len(leaves), 2) + + def test_returns_unchanged_when_shard_optimizer_over_data_false(self): + """When shard_optimizer_over_data=False, the second return value must be the input object.""" + cfg = _Cfg(shard_optimizer_over_data=False) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.adam(1e-3)) + _, updated = sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) + self.assertIs(updated, state_mesh_shardings) + + def test_zero1_propagates_mu_sharding_to_model_params(self): + """Zero-1: model param shardings must be replaced with the optimizer mu shardings.""" + cfg = _Cfg(shard_optimizer_over_data=True) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.adam(1e-3)) + + # Mutate the optimizer mu leaves in place so the function picks up a distinct PartitionSpec. + mesh = Mesh(np.array(jax.local_devices()[:1]).reshape(1, 1), ("data", "model")) + target_pspec = PartitionSpec(("data", "model")) + new_mu_sharding = NamedSharding(mesh, target_pspec) + + # After _build_state_mesh_shardings, every leaf's .value is a NamedSharding (no .shape), + # so we just override every Variable leaf in mu in place. + # After _build_state_mesh_shardings, every leaf's value is a NamedSharding (no .shape), + # so we just override every Variable leaf in mu in place via set_value (modern API). + mu_state = state_mesh_shardings.optimizer.opt_state[0]["mu"] + for var in jax.tree.leaves(mu_state, is_leaf=lambda x: isinstance(x, nnx.Variable)): + if isinstance(var, nnx.Variable): + var.set_value(new_mu_sharding) + + _, updated = sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) + + # All Param leaves under updated.model must now share the new mu sharding. + param_leaves = jax.tree.leaves(updated.model, is_leaf=lambda x: isinstance(x, nnx.Variable)) + param_leaves = [v for v in param_leaves if isinstance(v, nnx.Param)] + self.assertGreater(len(param_leaves), 0) + for leaf in param_leaves: + self.assertEqual(leaf.get_value().spec, target_pspec) + + def test_raises_when_no_adam_state_present(self): + """Stateless optimizers (e.g., SGD) have no mu — function must raise NotImplementedError.""" + cfg = _Cfg(shard_optimizer_over_data=True) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.sgd(1e-3)) + with self.assertRaises(NotImplementedError): + sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) + + def test_chained_optimizer_recursion_finds_adam_mu(self): + """A nested optax.chain(clip, adam) wraps mu under multiple containers — recursion must find it.""" + cfg = _Cfg(shard_optimizer_over_data=True) + chained = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3)) + state_mesh_shardings = _build_state_mesh_shardings(self.model, chained) + + # Should not raise; verify update happens (params replaced with mu shardings). + prev, updated = sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) + self.assertIsInstance(prev, nnx.State) + self.assertIsInstance(updated, nnx.State) + # Same number of Param leaves before and after. + n_prev = len(jax.tree.leaves(prev, is_leaf=lambda x: isinstance(x, nnx.Variable))) + n_after = len( + [ + v + for v in jax.tree.leaves(updated.model, is_leaf=lambda x: isinstance(x, nnx.Variable)) + if isinstance(v, nnx.Param) + ] + ) + self.assertEqual(n_prev, n_after) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_nnx_test.py b/tests/unit/train_nnx_test.py new file mode 100644 index 0000000000..3495b4c557 --- /dev/null +++ b/tests/unit/train_nnx_test.py @@ -0,0 +1,239 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX paths of loss_fn / train_step / eval_step in pre_train.train. + +These tests exercise the NNX branches without standing up a real Transformer or +data pipeline. We use a tiny NNX module that mimics the call signature the +production loss_fn uses (decoder_input_tokens, decoder_positions, ...). +""" + +import unittest +from dataclasses import dataclass + +import jax.numpy as jnp +import optax +from flax import nnx + +from maxtext.layers import train_state_nnx +from maxtext.trainers.pre_train import train as pre_train + + +@dataclass +class _Cfg: + """Subset of HyperParameters used by loss_fn / train_step / eval_step.""" + + micro_batch_size_to_train_on: int = 2 + micro_batch_size_to_eval_on: int = 2 + vocab_size: int = 8 + z_loss_multiplier: float = 0.0 + enable_dropout: bool = False + use_multimodal: bool = False + use_indexer: bool = False + indexer_sparse_training: bool = False + indexer_loss_scaling_factor: float = 0.0 + num_vocab_tiling: int = 1 + num_experts: int = 1 + routed_bias: bool = False + routed_bias_update_rate: float = 0.0 + mtp_num_layers: int = 0 + mtp_eval_target_module: int = 0 + use_dpo: bool = False + use_qk_clip: bool = False + use_tunix_gradient_accumulation: bool = False + gradient_accumulation_steps: int = 1 + shard_optimizer_over_data: bool = False + optimizer_memory_host_offload: bool = False + parameter_memory_host_offload: bool = False + gradient_clipping_threshold: float = 0.0 + grad_dtype: jnp.dtype = jnp.float32 + record_internal_nn_metrics: bool = False + skip_step_on_spikes: bool = False + shard_mode: int = 0 # ShardMode.AUTO + weight_sparsity_n: int = 0 + weight_sparsity_m: int = 0 + + +class _TinyDecoder(nnx.Module): + """Mimics NNXDecoder.__call__ enough for loss_fn to run end-to-end. + + Returns logits of shape [batch, seq_len, vocab_size]. Ignores all multimodal + / dropout / target arguments — they exist only to match the keyword signature. + """ + + def __init__(self, vocab_size: int, hidden: int, rngs: nnx.Rngs): + self.embed = nnx.Embed(vocab_size, hidden, rngs=rngs) + self.proj = nnx.Linear(hidden, vocab_size, rngs=rngs) + + def __call__( + self, + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=None, + encoder_images=None, + encoder_image_masks=None, + enable_dropout=False, + decoder_target_tokens=None, + decoder_target_mask=None, + ): + del decoder_positions, decoder_segment_ids, encoder_images, encoder_image_masks + del enable_dropout, decoder_target_tokens, decoder_target_mask + h = self.embed(decoder_input_tokens) + return self.proj(h) + + +def _make_data(batch=2, seq=4, vocab=8): + return { + "inputs": jnp.zeros((batch, seq), dtype=jnp.int32), + "inputs_position": jnp.broadcast_to(jnp.arange(seq), (batch, seq)), + "inputs_segmentation": jnp.ones((batch, seq), dtype=jnp.int32), + "targets": jnp.zeros((batch, seq), dtype=jnp.int32), + "targets_segmentation": jnp.ones((batch, seq), dtype=jnp.int32), + } + + +def _build_state(): + cfg = _Cfg() + model = _TinyDecoder(cfg.vocab_size, hidden=4, rngs=nnx.Rngs(0)) + optimizer = nnx.Optimizer(model, optax.sgd(0.01), wrt=nnx.Param) + ts = train_state_nnx.TrainStateNNX(model, optimizer) + return cfg, ts + + +class TestLossFnNNX(unittest.TestCase): + """Cover the NNX branch of loss_fn (lines 178-213).""" + + def test_returns_loss_and_full_aux_dict(self): + cfg, ts = _build_state() + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + loss, aux = pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=True) + self.assertTrue(jnp.isfinite(loss)) + # Aux schema relied on by train_step / eval_step / GA. + for key in ( + "intermediate_outputs", + "xent_sum", + "z_loss", + "total_weights", + "moe_lb_loss", + "indexer_loss", + "moe_bias_updates", + "mtp_loss", + ): + self.assertIn(key, aux) + # NNX intermediates are captured into a pure-dict snapshot, then logits attached. + self.assertIsInstance(aux["intermediate_outputs"], dict) + self.assertIn("logits", aux["intermediate_outputs"]) + + def test_eval_mode_truncates_to_eval_micro_batch(self): + cfg, ts = _build_state() + cfg.micro_batch_size_to_eval_on = 1 + data = _make_data(batch=2, vocab=cfg.vocab_size) + loss, aux = pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=False) + self.assertTrue(jnp.isfinite(loss)) + # eval truncated batch to 1 → total_weights = seq_len * 1 + self.assertEqual(int(aux["total_weights"]), data["targets_segmentation"].shape[1]) + + def test_indexer_dense_warmup_skips_xent(self): + cfg, ts = _build_state() + cfg.use_indexer = True + cfg.indexer_sparse_training = False + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + loss, aux = pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=True) + # When dense warm-up is active the loss_fn skips the main loss entirely. + self.assertEqual(float(aux["xent_sum"]), 0.0) + self.assertEqual(float(loss), 0.0) + + def test_vocab_tiling_raises_not_implemented(self): + cfg, ts = _build_state() + cfg.num_vocab_tiling = 4 + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + with self.assertRaises(NotImplementedError): + pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=True) + + +class TestTrainStepNNX(unittest.TestCase): + """Cover the NNX branch of train_step (the diff_wrapper / nnx.update path).""" + + def test_train_step_returns_state_and_metrics(self): + cfg, ts = _build_state() + state_graphdef, state_pure = nnx.split(ts) + + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + new_state, metrics = pre_train.train_step( + state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + # NNX path returns nnx.State (via nnx.state(new_state)) and a metrics dict. + self.assertIsInstance(new_state, nnx.State) + self.assertIn("scalar", metrics) + self.assertIn("learning/loss", metrics["scalar"]) + self.assertIn("learning/grad_norm", metrics["scalar"]) + self.assertIn("learning/param_norm", metrics["scalar"]) + self.assertTrue(jnp.isfinite(metrics["scalar"]["learning/loss"])) + + def test_train_step_dpo_raises_for_nnx(self): + cfg, ts = _build_state() + cfg.use_dpo = True + state_graphdef, state_pure = nnx.split(ts) + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + with self.assertRaises(NotImplementedError): + pre_train.train_step( + state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + + def test_train_step_increments_optimizer_step(self): + cfg, ts = _build_state() + state_graphdef, state_pure = nnx.split(ts) + pre_step = int(state_pure.optimizer.step.get_value()) + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + new_state, _ = pre_train.train_step( + state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + self.assertEqual(int(new_state.optimizer.step.get_value()), pre_step + 1) + + def test_train_step_with_gradient_clipping(self): + """The clipping branch (gradient_clipping_threshold > 0) must run without raising.""" + cfg, ts = _build_state() + cfg.gradient_clipping_threshold = 1.0 + state_graphdef, state_pure = nnx.split(ts) + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + new_state, metrics = pre_train.train_step( + state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + self.assertIsInstance(new_state, nnx.State) + self.assertTrue(jnp.isfinite(metrics["scalar"]["learning/loss"])) + + +class TestEvalStepNNX(unittest.TestCase): + """Cover the NNX branch of eval_step (lines 568-570).""" + + def test_eval_step_returns_metrics(self): + cfg, ts = _build_state() + state_graphdef, state_pure = nnx.split(ts) + data = _make_data(batch=cfg.micro_batch_size_to_eval_on, vocab=cfg.vocab_size) + metrics = pre_train.eval_step(state_graphdef, cfg, state_pure, data) + self.assertIn("scalar", metrics) + for key in ( + "evaluation/loss", + "evaluation/total_loss", + "evaluation/total_weights", + "evaluation/moe_lb_loss", + ): + self.assertIn(key, metrics["scalar"]) + # NNX path must NOT include DPO eval metric. + self.assertNotIn("evaluation/dpo_reward_accuracy", metrics["scalar"]) + self.assertTrue(jnp.isfinite(metrics["scalar"]["evaluation/loss"])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_state_nnx_checkpoint_test.py b/tests/unit/train_state_nnx_checkpoint_test.py new file mode 100644 index 0000000000..100d3f81e1 --- /dev/null +++ b/tests/unit/train_state_nnx_checkpoint_test.py @@ -0,0 +1,399 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TrainStateNNX checkpoint tests.""" + +import pathlib +import tempfile +import shutil +from types import SimpleNamespace +from unittest import mock + +import unittest +import jax +import jax.numpy as jnp +from flax import nnx, serialization +from flax import linen as nn +from flax.training import train_state +import optax +import orbax.checkpoint as ocp + +from maxtext.common import checkpointing +from maxtext.layers import train_state_nnx + + +class MockModel(nnx.Module): + """A simple model for checkpoint testing.""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + def __call__(self, x): + return self.linear(x) + + +class LinenMockModel(nn.Module): + """The Linen equivalent of the MockModel.""" + + @nn.compact + def __call__(self, x): + # We name the layer 'linear' to match the attribute name in the NNX MockModel + return nn.Dense(features=1, name="linear")(x) + + +class TestTrainStateNNXCheckpoint(unittest.TestCase): + """Class to test NNX checkpoint.""" + + def setUp(self): + self.rngs = nnx.Rngs(0) + self.model = MockModel(rngs=self.rngs) + + # Setup a chained optimizer: Gradient Clipping -> Adam + # Note: optax.adam is also a chain (scale_by_adam + scale_by_learning_rate). + # This creates a nested state structure: (EmptyState, (ScaleByAdamState, EmptyState)) + self.tx = optax.chain( + optax.clip_by_global_norm(max_norm=1.0), + optax.adam(1e-3), + ) + + def test_checkpoint_structure(self): + """Ensures the state object contains both model and optimizer keys.""" + optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.model, optimizer) + + # We use .to_pure_dict() to simulate the format stored in a checkpoint. + # This converts nnx.Variable/State objects into raw arrays and dictionaries. + full_state = nnx.state(state).to_pure_dict() + + # 1. Verify Top-level Keys + self.assertIn("model", full_state) + self.assertIn("optimizer", full_state) + + # 2. Verify Optimizer Internal Structure + opt_inner_state = full_state["optimizer"]["opt_state"] + + # Because we used optax.chain(clip, adam), index 0 is clip, index 1 is adam. + # Since adam is also a chain, index 1 is itself a dictionary/tuple representation. + # Adam's momentum (mu/nu) is in the first element of its own sub-chain. + adam_component = opt_inner_state[1][0] + + self.assertIn("mu", adam_component, "Adam 'mu' buffer not found in pure dict state.") + self.assertIn("nu", adam_component, "Adam 'nu' buffer not found in pure dict state.") + + # In a pure dict, these are nested dictionaries containing arrays, not NNX objects. + self.assertIsInstance(adam_component["mu"], dict) + self.assertIsInstance(adam_component["nu"], dict) + + # To verify a specific leaf, we navigate the dictionary hierarchy: + self.assertIsInstance(adam_component["mu"]["linear"]["kernel"], jax.Array) + + def test_checkpoint_and_restore(self): + """Verifies that the full state can be captured and restored into a new instance.""" + # 1. Initialize original state and optimizer + optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param) + state_original = train_state_nnx.TrainStateNNX(self.model, optimizer) + + # 2. Perform a training step to modify weights and optimizer buffers + def loss_fn(m): + return jnp.mean(m(jnp.ones((1, 2))) ** 2) + + grads = nnx.grad(loss_fn)(state_original.model) + state_original.apply_gradients(grads) + + # Capture state after one step + original_kernel_val = state_original.model.linear.kernel.value + original_step_val = state_original.optimizer.step.value + self.assertEqual(original_step_val, 1) + + # 3. Capture the "Checkpoint" as a pure dictionary + checkpoint_state = nnx.state(state_original).to_pure_dict() + + # 4. Initialize a fresh, different instance + new_rngs = nnx.Rngs(1) + new_model = MockModel(rngs=new_rngs) + new_optimizer = nnx.Optimizer(new_model, self.tx, wrt=nnx.Param) + state_restored = train_state_nnx.TrainStateNNX(new_model, new_optimizer) + + # Check differences before restoration + self.assertEqual(state_restored.optimizer.step.value, 0) + self.assertFalse(jnp.allclose(state_restored.model.linear.kernel.value, original_kernel_val)) + + # 5. Restore the state into the new instance. + # nnx.update supports updating from a pure dictionary. + nnx.update(state_restored, checkpoint_state) + + # 6. Verify restoration + # Check step counter + self.assertEqual(state_restored.optimizer.step.value, original_step_val) + # Check model weights + self.assertTrue(jnp.allclose(state_restored.model.linear.kernel.value, original_kernel_val)) + + # Check that it can still be trained after restoration + new_grads = nnx.grad(loss_fn)(state_restored.model) + state_restored.apply_gradients(new_grads) + self.assertEqual(state_restored.optimizer.step.value, 2) + + def test_restore_from_linen_state(self): + """Verifies a multi-stage migration: Linen CKPT -> Migrate -> NNX CKPT -> Restore.""" + # 1. Setup Linen TrainState (Simulating original training) + linen_model = LinenMockModel() + dummy_input = jnp.ones((1, 2)) + variables = linen_model.init(jax.random.key(42), dummy_input) + + state_linen = train_state.TrainState.create(apply_fn=linen_model.apply, params=variables["params"], tx=self.tx) + + # Perform a step to populate optimizer buffers + grads = jax.tree.map(jnp.ones_like, state_linen.params) + state_linen = state_linen.apply_gradients(grads=grads) + + temp_dir = pathlib.Path(tempfile.mkdtemp()) + try: + # --- PHASE 1: Save Legacy Linen Checkpoint --- + linen_ckpt_dir = temp_dir / "linen_ckpt" + mngr_linen = ocp.CheckpointManager( + linen_ckpt_dir, options=ocp.CheckpointManagerOptions(create=True), item_handlers=ocp.StandardCheckpointHandler() + ) + mngr_linen.save(0, args=ocp.args.StandardSave(state_linen)) + mngr_linen.wait_until_finished() + + # --- PHASE 2: Read Linen CKPT and Convert to NNX Structure --- + # Load it back without knowing the blueprint (reading as a pure PyTree) + restored_linen_obj = mngr_linen.restore(0) + + # Convert the restored object to a pure dictionary structure. + restored_linen_dict = serialization.to_state_dict(restored_linen_obj) + + # Helper to recursively convert string keys back to integers + # and filter out None values. + def recursive_clean(obj): + if isinstance(obj, dict): + return {int(k) if k.isdigit() else k: recursive_clean(v) for k, v in obj.items() if v is not None} + return obj + + # Converted dict - simple PyTree mapping, no NNX Module initialization needed here. + # This simulates a situation where the conversion logic is blueprint-agnostic. + linen_as_nnx_dict = { + "model": restored_linen_dict["params"], + "optimizer": { + "step": jnp.array(restored_linen_dict["step"]), + "opt_state": recursive_clean(restored_linen_dict["opt_state"]), + }, + } + + # --- PHASE 3: Save as Native NNX Checkpoint --- + nnx_ckpt_dir = temp_dir / "nnx_ckpt" + mngr_nnx = ocp.CheckpointManager( + nnx_ckpt_dir, options=ocp.CheckpointManagerOptions(create=True), item_handlers=ocp.StandardCheckpointHandler() + ) + # We save the raw dictionary directly to disk. + mngr_nnx.save(0, args=ocp.args.StandardSave(linen_as_nnx_dict)) + mngr_nnx.wait_until_finished() + + # --- PHASE 4: Restore from NNX Checkpoint to target Model --- + nnx_model = MockModel(rngs=nnx.Rngs(0)) + nnx_optimizer = nnx.Optimizer(nnx_model, self.tx, wrt=nnx.Param) + state_nnx = train_state_nnx.TrainStateNNX(nnx_model, nnx_optimizer) + + # We now restore using the nnx.State as a blueprint. This ensures Orbax + # correctly maps the arrays on disk to the model's structural expectation. + blueprint = nnx.state(state_nnx).to_pure_dict() + restored_nnx_pytree = mngr_nnx.restore(0, args=ocp.args.StandardRestore(item=blueprint)) + nnx.update(state_nnx, restored_nnx_pytree) + + # --- PHASE 5: Verification --- + # 1. Verify Step + self.assertEqual(state_nnx.optimizer.step.value, 1) + + # 2. Verify Weights + self.assertTrue(jnp.allclose(state_nnx.model.linear.kernel.value, state_linen.params["linear"]["kernel"])) + + # 3. Verify Chained Optimizer State (Clip at index 0, Adam at index 1) + self.assertEqual(type(state_nnx.optimizer.opt_state[0]), type(state_linen.opt_state[0])) + + # state_linen.opt_state[1] is the Adam chain state. + # state_linen.opt_state[1][0] is the ScaleByAdamState containing 'mu'. + self.assertTrue( + jnp.allclose( + state_nnx.optimizer.opt_state[1][0].mu["linear"]["kernel"], + state_linen.opt_state[1][0].mu["linear"]["kernel"], + ) + ) + + finally: + # Cleanup temporary directory + shutil.rmtree(temp_dir) + + def test_restore_from_checkpoint_model_params(self): + """Verifies that model parameters can be restored from model params only.""" + # 1. Setup mocked parameters manually (no Linen model needed for setup) + # This structure matches the path model.linear.kernel/bias in the NNX MockModel. + mock_params = {"linear": {"kernel": jnp.ones((2, 1)) * 9.0, "bias": jnp.zeros((1,))}} + + # Simplified checkpoint dictionary using hardcoded mocked params as requested + checkpoint_dict = { + "model": mock_params, + } + + temp_dir = pathlib.Path(tempfile.mkdtemp()) + try: + # --- PHASE 1: Save the partial checkpoint --- + mngr = ocp.CheckpointManager( + temp_dir, options=ocp.CheckpointManagerOptions(create=True), item_handlers=ocp.StandardCheckpointHandler() + ) + mngr.save(0, args=ocp.args.StandardSave(checkpoint_dict)) + mngr.wait_until_finished() + + # --- PHASE 2: Restore into a full TrainStateNNX --- + nnx_model = MockModel(rngs=nnx.Rngs(0)) + nnx_optimizer = nnx.Optimizer(nnx_model, self.tx, wrt=nnx.Param) + state_nnx = train_state_nnx.TrainStateNNX(nnx_model, nnx_optimizer) + + # We use nnx.state to get a full blueprint as a reference. + full_nnx_pure_dict = nnx.state(state_nnx).to_pure_dict() + blueprint = {"model": full_nnx_pure_dict["model"]} + + # If we don't know if the checkpoint on disk has 'optimizer' or not, we simulate + # schema-agnostic restoration by calling restore without a blueprint. + # This avoids Orbax structural mismatch errors while allowing us to see the data. + restored_pytree = mngr.restore(0, args=ocp.args.StandardRestore(item=blueprint)) + + # Use nnx.update to apply the restored data to the stateful NNX object. + # nnx.update is naturally partial: it will update 'model' from the restored dict + # and leave 'optimizer' untouched at its initialized value. + nnx.update(state_nnx, restored_pytree) + + # --- PHASE 3: Verification --- + # Check that weights were restored to the specific mock values + self.assertTrue(jnp.allclose(state_nnx.model.linear.kernel.value, mock_params["linear"]["kernel"])) + # Step remains at its initialized value (0) because it was not in the checkpoint + self.assertEqual(state_nnx.optimizer.step.value, 0) + + # Verify that the optimizer state still exists in the object (initialized) + # even though it was not provided in the checkpoint. + # Adam's state is at index 1 of the chain, and it's a nested structure (tuple). + # We verify that index 0 (ScaleByAdamState) contains the 'mu' State container. + self.assertIsInstance(state_nnx.optimizer.opt_state[1][0].mu, nnx.State) + + finally: + # Cleanup temporary directory + shutil.rmtree(temp_dir) + + +class TestMaybeSaveCheckpointStepAlignment(unittest.TestCase): + """Verify maybe_save_checkpoint's fallback step matches the last completed step. + + When the training loop's final save calls maybe_save_checkpoint without an + explicit `step`, it derives `actual_step` from the state: + - NNX: int(state.optimizer.step) - 1 + - Linen: int(state.step) - 1 + Both TrainStateNNX.apply_gradients (via nnx.Optimizer.update) and Linen + TrainState.apply_gradients increment the counter by 1 per call, so after N + gradient applications the counter is N and the "last completed step" is N-1. + """ + + N_STEPS = 5 + + def setUp(self): + self.tx = optax.adam(1e-3) + + def _build_nnx_state(self, num_steps): + """Build an nnx.State flattened from TrainStateNNX after num_steps gradient applications.""" + model = MockModel(rngs=nnx.Rngs(0)) + optimizer = nnx.Optimizer(model, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(model, optimizer) + + def loss_fn(m): + return jnp.mean(m(jnp.ones((1, 2))) ** 2) + + for _ in range(num_steps): + grads = nnx.grad(loss_fn)(state.model) + state.apply_gradients(grads) + # maybe_save_checkpoint is called with a flat nnx.State in the NNX path + # (train_step returns nnx.state(new_state)). + return nnx.state(state) + + def _build_linen_state(self, num_steps): + """Build a Linen TrainState after num_steps gradient applications.""" + model = LinenMockModel() + variables = model.init(jax.random.key(0), jnp.ones((1, 2))) + state = train_state.TrainState.create(apply_fn=model.apply, params=variables["params"], tx=self.tx) + grads = jax.tree.map(jnp.ones_like, state.params) + for _ in range(num_steps): + state = state.apply_gradients(grads=grads) + return state + + def _invoke_maybe_save(self, state, pure_nnx): + """Call maybe_save_checkpoint with save_checkpoint patched, return {step, state} captured.""" + # checkpoint_period=1 keeps force_ckpt_save False regardless of actual_step. + config = SimpleNamespace(pure_nnx=pure_nnx, checkpoint_period=1, async_checkpointing=False) + mgr = mock.MagicMock() + mgr.reached_preemption.return_value = False + + captured = {} + + def fake_save_checkpoint(_mgr, step, state_arg, *_args, **_kwargs): + captured["step"] = step + captured["state"] = state_arg + return False # no save happened => print_save_message is skipped + + with mock.patch.object(checkpointing, "save_checkpoint", side_effect=fake_save_checkpoint): + checkpointing.maybe_save_checkpoint(mgr, state, config, data_iterator=None, step=None) + return captured + + def test_nnx_final_save_step_is_n_minus_1(self): + state = self._build_nnx_state(self.N_STEPS) + self.assertEqual(int(state.optimizer.step.value), self.N_STEPS) + captured = self._invoke_maybe_save(state, pure_nnx=True) + self.assertEqual(captured["step"], self.N_STEPS - 1) + + def test_linen_final_save_step_is_n_minus_1(self): + state = self._build_linen_state(self.N_STEPS) + self.assertEqual(int(state.step), self.N_STEPS) + captured = self._invoke_maybe_save(state, pure_nnx=False) + self.assertEqual(captured["step"], self.N_STEPS - 1) + + def test_nnx_and_linen_agree_on_actual_step(self): + """TrainStateNNX and Linen TrainState must yield the same fallback actual_step.""" + nnx_state = self._build_nnx_state(self.N_STEPS) + linen_state = self._build_linen_state(self.N_STEPS) + self.assertEqual( + self._invoke_maybe_save(nnx_state, pure_nnx=True)["step"], + self._invoke_maybe_save(linen_state, pure_nnx=False)["step"], + ) + + def test_nnx_state_is_converted_to_pure_dict_before_save(self): + """For pure_nnx=True, maybe_save_checkpoint must pass a plain dict to save_checkpoint, not an nnx.State.""" + state = self._build_nnx_state(self.N_STEPS) + self.assertIsInstance(state, nnx.State) # precondition: NNX train_step returns an nnx.State + + captured = self._invoke_maybe_save(state, pure_nnx=True) + + # save_checkpoint should have received a plain Python dict (the result of + # nnx.State.to_pure_dict()), not the original nnx.State. + self.assertIsInstance(captured["state"], dict) + self.assertNotIsInstance(captured["state"], nnx.State) + # Sanity: the converted dict still mirrors the TrainStateNNX structure. + self.assertIn("model", captured["state"]) + self.assertIn("optimizer", captured["state"]) + + def test_linen_state_is_passed_through_unchanged(self): + """For pure_nnx=False, maybe_save_checkpoint must pass the original TrainState object through.""" + state = self._build_linen_state(self.N_STEPS) + captured = self._invoke_maybe_save(state, pure_nnx=False) + # Linen path must not invoke to_pure_dict(); state is forwarded as-is. + self.assertIs(captured["state"], state) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_state_nnx_test.py b/tests/unit/train_state_nnx_test.py new file mode 100644 index 0000000000..03db77ff63 --- /dev/null +++ b/tests/unit/train_state_nnx_test.py @@ -0,0 +1,90 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TrainStateNNX tests.""" + +import unittest +import jax.numpy as jnp +from flax import nnx +import optax + +from maxtext.layers import train_state_nnx + + +class MockModel(nnx.Module): + """Mocked NNX model""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + def __call__(self, x): + return self.linear(x) + + +class TestTrainStateNNX(unittest.TestCase): + """TrainStateNNX tests.""" + + def setUp(self): + self.rngs = nnx.Rngs(0) + self.model = MockModel(rngs=self.rngs) + self.tx = optax.adam(1e-3) + + def test_init_with_optimizer(self): + """Test init with iptimizer.""" + optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.model, optimizer) + + self.assertEqual(state.model, self.model) + self.assertEqual(state.optimizer, optimizer) + # Access step directly from optimizer + self.assertEqual(state.optimizer.step.value, 0) + + def test_init_without_optimizer(self): + """Test init without optimizer.""" + state = train_state_nnx.TrainStateNNX(self.model, None) + + self.assertEqual(state.model, self.model) + self.assertIsNone(state.optimizer) + + def test_apply_gradients_success(self): + """Test apply gradients can be called successfully.""" + optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.model, optimizer) + + # Create dummy gradients matching the model state structure + def loss_fn(m): + return jnp.mean(m(jnp.ones((1, 2))) ** 2) + + grads = nnx.grad(loss_fn)(state.model) + + # Apply gradients + state.apply_gradients(grads) + + # Verify step incremented (managed by nnx.Optimizer) + self.assertEqual(state.optimizer.step.value, 1) + + def test_apply_gradients_raises_runtime_error(self): + """Test apply gradients without a optimizer.""" + # Initialize without optimizer (inference mode) + state = train_state_nnx.TrainStateNNX(self.model, None) + + dummy_grads = {} + with self.assertRaises(RuntimeError) as cm: + state.apply_gradients(dummy_grads) + + self.assertIn("inference only", str(cm.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_utils_nnx_test.py b/tests/unit/train_utils_nnx_test.py new file mode 100644 index 0000000000..2ff7276fd9 --- /dev/null +++ b/tests/unit/train_utils_nnx_test.py @@ -0,0 +1,149 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX-specific helpers / patterns in train_utils.setup_train_loop. + +setup_train_loop itself is integration territory (it touches data iterators, +checkpoint managers, and a real mesh), so we cover the NNX-only pieces that +have unit-testable contracts: + + 1. The create_train_state_fn closure pattern: builds nnx.Optimizer + TrainStateNNX + from a zero-arg model factory and a transform. + 2. nnx.split(state.model, nnx.Param, ...) returns Param-only state used to + compute state_params / state_mesh_shardings_params. + 3. nnx.merge(state_graphdef, state) reconstitutes a TrainStateNNX from the + pure-state form returned by setup_training_state. +""" + +import unittest +from functools import partial + +import jax +import jax.numpy as jnp +import optax +from flax import nnx + +from maxtext.layers import train_state_nnx + + +class _Model(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + +class TestCreateTrainStateFnClosure(unittest.TestCase): + """Exercise the closure pattern in setup_train_loop: + + def create_train_state_fn(): + model = _create_model_partial() + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(model, optimizer) + """ + + def test_returns_train_state_nnx_with_optimizer(self): + tx = optax.sgd(0.01) + + def _create_model(): + return _Model(rngs=nnx.Rngs(0)) + + def create_train_state_fn(): + model = _create_model() + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(model, optimizer) + + state = create_train_state_fn() + self.assertIsInstance(state, train_state_nnx.TrainStateNNX) + self.assertIsInstance(state.optimizer, nnx.Optimizer) + self.assertEqual(int(state.optimizer.step.get_value()), 0) + + def test_two_invocations_produce_independent_states(self): + """The lambda must call the factory each time (otherwise checkpoint init/restore would alias).""" + tx = optax.sgd(0.01) + counter = {"n": 0} + + def _create_model(): + counter["n"] += 1 + return _Model(rngs=nnx.Rngs(counter["n"])) + + def create_train_state_fn(): + model = _create_model() + return train_state_nnx.TrainStateNNX(model, nnx.Optimizer(model, tx, wrt=nnx.Param)) + + s1 = create_train_state_fn() + s2 = create_train_state_fn() + self.assertEqual(counter["n"], 2) + self.assertIsNot(s1.model, s2.model) + + +class TestSetupTrainLoopNNXTreeOps(unittest.TestCase): + """Cover the nnx.split(state.model, nnx.Param, ...) and nnx.merge round-trip + patterns that setup_train_loop uses to derive Param-only views and rebuild + the full TrainStateNNX before returning.""" + + def setUp(self): + self.tx = optax.sgd(0.01) + self.model = _Model(rngs=nnx.Rngs(0)) + self.state = train_state_nnx.TrainStateNNX(self.model, nnx.Optimizer(self.model, self.tx, wrt=nnx.Param)) + + def test_nnx_split_yields_param_only_state(self): + """state_params used for assert_params_sufficiently_sharded must contain only nnx.Param leaves.""" + _, state_params, _ = nnx.split(self.state.model, nnx.Param, ...) + leaves = jax.tree.leaves(state_params, is_leaf=lambda x: isinstance(x, nnx.Variable)) + self.assertGreater(len(leaves), 0) + for leaf in leaves: + self.assertIsInstance(leaf, nnx.Param) + + def test_nnx_merge_reconstructs_train_state_nnx(self): + """setup_train_loop ends with nnx.merge(state_graphdef, state) — verify that round-trips.""" + state_graphdef, state_pure = nnx.split(self.state) + train_state = nnx.merge(state_graphdef, state_pure) + self.assertIsInstance(train_state, train_state_nnx.TrainStateNNX) + # Same numeric values. + self.assertTrue(jnp.allclose(train_state.model.linear.kernel.value, self.state.model.linear.kernel.value)) + + +class TestInitStateFnIsCallable(unittest.TestCase): + """For the Linen path setup_train_loop builds init_state_fn = partial(...). + + The NNX path uses a closure instead — confirm both forms have the + zero-argument call contract create_checkpoint_manager / setup_training_state expect. + """ + + def test_nnx_init_state_fn_callable_with_no_args(self): + tx = optax.sgd(0.01) + + def _create_model(): + return _Model(rngs=nnx.Rngs(0)) + + def init_state_fn(): + model = _create_model() + return train_state_nnx.TrainStateNNX(model, nnx.Optimizer(model, tx, wrt=nnx.Param)) + + state = init_state_fn() # must not raise / require args + self.assertIsInstance(state, train_state_nnx.TrainStateNNX) + + def test_linen_init_state_fn_is_partial_callable_with_no_args(self): + """Sanity: the Linen-side `partial(init_initial_state, model, tx, config, is_training, init_rng)` form.""" + + def init_initial_state(model, tx, config, is_training, init_rng): + del model, tx, config, is_training, init_rng + return "linen-state" + + init_state_fn = partial(init_initial_state, "model", "tx", "config", True, "rng") + self.assertEqual(init_state_fn(), "linen-state") + + +if __name__ == "__main__": + unittest.main() From bc7793afdc132d1c62f036582d774732e4ca2e6b Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 31 Mar 2026 14:32:29 +0000 Subject: [PATCH 2/4] NNX: add sharding tools, Linen<->NNX checkpoint utilities, and post-training fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Part 1 — sharding diagnostics and Linen<->NNX checkpoint utilities: - modify print_shardings_params to support NNX (maxtext_utils.py) - add --pure_nnx flag to run_sharding_dump.py - add bidirectional Linen<->NNX checkpoint conversion utility (linen_nnx_converter.py) - add checkpoint comparison utility for Linen vs NNX validation (compare_linen_nnx_checkpoint.py) Part 2 — post-training bug fixes: - models.py: unpack MultimodalInput before passing to NNXDecoder (was passing the whole object as multimodal_input= kwarg; NNXDecoder only accepts individual fields) - optimizers.py: guard adam_pax against scalar LR from optax.inject_hyperparams (callable() check before invoking learning_rate_fn) - train_distill.py: fix nested NNX transform issue (nnx.value_and_grad inside nnx.jit raises conflicting outer_index error); refactored to jax.value_and_grad + explicit nnx.split/merge pattern; teacher inference moved outside value_and_grad --- .../compare_linen_nnx_checkpoint.py | 609 ++++++++++++ .../linen_nnx_converter.py | 581 ++++++++++++ src/maxtext/models/models.py | 6 +- src/maxtext/optimizers/optimizers.py | 4 +- .../post_train/distillation/train_distill.py | 78 +- .../trainers/post_train/rl/train_rl.py | 43 +- .../trainers/post_train/sft/train_sft.py | 70 +- src/maxtext/utils/maxtext_utils.py | 47 +- src/maxtext/utils/model_creation_utils.py | 7 + .../unit/distillation_scheduling_test.py | 44 +- .../post_training/unit/train_distill_test.py | 84 +- .../unit/compare_linen_nnx_checkpoint_test.py | 501 ++++++++++ tests/unit/linen_nnx_converter_test.py | 869 ++++++++++++++++++ tests/utils/run_sharding_dump.py | 9 +- 14 files changed, 2863 insertions(+), 89 deletions(-) create mode 100644 src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py create mode 100644 src/maxtext/checkpoint_conversion/linen_nnx_converter.py create mode 100644 tests/unit/compare_linen_nnx_checkpoint_test.py create mode 100644 tests/unit/linen_nnx_converter_test.py diff --git a/src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py b/src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py new file mode 100644 index 0000000000..c103f234ee --- /dev/null +++ b/src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py @@ -0,0 +1,609 @@ +# Copyright 2023-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Compare checkpoint tree structures, shapes, and values. + +Supports comparing any combination of Linen and NNX checkpoints: +- Linen vs NNX (cross-format comparison) +- Linen vs Linen (same-format comparison) +- NNX vs NNX (same-format comparison) + +The script auto-detects the format of each checkpoint and applies the +appropriate normalization. Cross-format transformations (like layer axis +transposition) are only applied when comparing Linen vs NNX. + +Key differences between Linen and NNX checkpoints: +- Linen: params/params/decoder/layers/0/... (per-layer, double nested) +- NNX: model/decoder/layers/... (stacked layers, single nested, {value: array} wrappers) + +The script handles: +- Double 'params' nesting in Linen checkpoints +- 'model' key in NNX checkpoints (vs 'params' in Linen) +- {value: array} wrappers in NNX checkpoints +- Layer axis transposition (NNX stacks layers along axis 0, only for cross-format) +- RNG filtering (NNX has rngs, Linen doesn't) + +Usage: + # Compare Linen vs NNX (structure and shapes only) + python compare_linen_nnx_checkpoint.py \ + --ckpt_path_1="gs://bucket/linen_checkpoint/0/items" \ + --ckpt_path_2="gs://bucket/nnx_checkpoint/0/items" + + # Compare NNX vs NNX + python compare_linen_nnx_checkpoint.py \ + --ckpt_path_1="gs://bucket/nnx_checkpoint_a/0/items" \ + --ckpt_path_2="gs://bucket/nnx_checkpoint_b/0/items" + + # Compare Linen vs Linen + python compare_linen_nnx_checkpoint.py \ + --ckpt_path_1="gs://bucket/linen_checkpoint_a/0/items" \ + --ckpt_path_2="gs://bucket/linen_checkpoint_b/0/items" + + # Compare with value checking + python compare_linen_nnx_checkpoint.py \ + --ckpt_path_1="gs://bucket/checkpoint_a/0/items" \ + --ckpt_path_2="gs://bucket/checkpoint_b/0/items" \ + --compare_values --atol=1e-5 --rtol=1e-5 +""" + +import os +from typing import Any, Dict, Sequence + +# MUST set before importing JAX to force CPU-only mode +os.environ["JAX_PLATFORMS"] = "cpu" + +import jax +import jax.numpy as jnp +from jax.tree_util import tree_flatten_with_path, keystr, tree_structure, tree_map_with_path +import numpy as np +from etils import epath +import orbax.checkpoint as ocp +from absl import app +from absl import flags + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + "ckpt_path_1", + None, + "Path to the first checkpoint items directory. Format is auto-detected.", + required=True, +) +flags.DEFINE_string( + "ckpt_path_2", + None, + "Path to the second checkpoint items directory. Format is auto-detected.", + required=True, +) +flags.DEFINE_boolean( + "verbose", + False, + "Print detailed per-parameter information.", +) +flags.DEFINE_boolean( + "transpose_nnx_layers", + False, + "Transpose NNX layer params from (layers, ...) to (...) for comparison. " + "NNX stacks layers along axis 0, while Linen stores per-layer params. " + "Only applied for cross-format (Linen vs NNX) comparisons.", +) +flags.DEFINE_string( + "compare_only", + "params", + "Which parts to compare: 'params' for params only, 'all' for full state.", +) +flags.DEFINE_boolean( + "ignore_rngs", + True, + "Ignore RNG-related paths in comparison (NNX has rngs, Linen doesn't).", +) +flags.DEFINE_boolean( + "compare_values", + False, + "Also compare parameter values (not just structure and shapes).", +) +flags.DEFINE_float( + "atol", + 1e-5, + "Absolute tolerance for value comparison.", +) +flags.DEFINE_float( + "rtol", + 1e-5, + "Relative tolerance for value comparison.", +) + + +def log(message: str) -> None: + """Log a message with prefix.""" + print(f"[compare_ckpt] {message}") + + +def is_rng_path(path: str) -> bool: + """Check if a path is RNG-related.""" + path_lower = path.lower() + return "rngs" in path_lower or "rng" in path_lower + + +def filter_rngs(tree: Dict[str, Any]) -> Dict[str, Any]: + """Filter out RNG-related keys from a tree.""" + if not isinstance(tree, dict): + return tree + + result = {} + for key, value in tree.items(): + # Skip RNG-related keys + if is_rng_path(key): + continue + # Recursively filter nested dicts + if isinstance(value, dict): + filtered = filter_rngs(value) + if filtered: # Only add if not empty after filtering + result[key] = filtered + else: + result[key] = value + return result + + +def detect_format(state: dict) -> str: + """Detects checkpoint format from state structure ('linen' or 'nnx'). + + Linen format: + - Top-level keys: ['params', 'opt_state', 'step'] + - params/params/decoder/... (double nested) + + NNX format: + - Top-level keys: ['model', 'optimizer'] (nnx.State style) + - model/decoder/... with {value: array} wrappers + """ + # Check for NNX nnx.State format (has 'model' key instead of 'params') + if "model" in state: + return "nnx" + + if "params" not in state: + raise ValueError(f"Checkpoint does not contain 'params' or 'model' key. Found keys: {list(state.keys())}") + + params = state["params"] + + # Check for Linen's double 'params' nesting + if isinstance(params, dict) and "params" in params: + inner = params["params"] + if isinstance(inner, dict) and ("decoder" in inner or "encoder" in inner): + return "linen" + + # Check for NNX's flat structure (params/decoder/...) + if isinstance(params, dict) and ("decoder" in params or "encoder" in params): + return "nnx" + + # Try to detect by looking for {value: array} wrappers (NNX style) + if _has_value_wrappers(params): + return "nnx" + + raise ValueError( + f"Could not detect checkpoint format. params keys: {list(params.keys()) if isinstance(params, dict) else type(params)}" + ) + + +def _has_value_wrappers(tree: Any) -> bool: + """Check if tree contains {value: array} wrappers (NNX style).""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, (np.ndarray, jnp.ndarray)): + return True + for v in tree.values(): + if _has_value_wrappers(v): + return True + return False + + +def _strip_value_wrappers(tree: Any) -> Any: + """Recursively strips {'value': array} wrappers from a tree.""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, (np.ndarray, jnp.ndarray)): + return inner + return {k: _strip_value_wrappers(v) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_strip_value_wrappers(item) for item in tree) + else: + return tree + + +def _normalize_linen_params(params: dict) -> dict: + """Normalize Linen params by removing double 'params' nesting.""" + if isinstance(params, dict) and "params" in params: + inner = params["params"] + if isinstance(inner, dict) and ("decoder" in inner or "encoder" in inner): + return inner + return params + + +def _normalize_nnx_params(params: dict) -> dict: + """Normalize NNX params by stripping {value: array} wrappers.""" + return _strip_value_wrappers(params) + + +def load_checkpoint(checkpoint_path: str, metadata_only: bool = False) -> dict: + """Loads checkpoint from local or GCS path. + + If metadata_only=True, returns a pytree of ArrayMetadata (shape/dtype only) + without downloading any tensor data. This is fast and sufficient for + structure/shape comparison. + """ + log(f"Loading checkpoint from: {checkpoint_path}") + if metadata_only: + log(" Mode: metadata only (no tensor data downloaded)") + + checkpoint_dir = epath.Path(checkpoint_path) + + # Create checkpointer and get metadata + ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler()) + + try: + metadata = ckptr.metadata(checkpoint_dir) + + if metadata_only: + tree = metadata.item_metadata.tree + log(f" Loaded metadata keys: {list(tree.keys())}") + return tree + + # Create a mesh with all available devices for unsharded restoration + devices = np.array(jax.devices()).reshape((-1,)) + single_device_mesh = jax.sharding.Mesh(devices, ("x",)) + unsharded = jax.sharding.NamedSharding(single_device_mesh, jax.sharding.PartitionSpec()) + + # Build restore args that restore arrays without original sharding + restore_args = jax.tree_util.tree_map( + lambda x: ocp.ArrayRestoreArgs(sharding=unsharded) if hasattr(x, "shape") else None, + metadata.item_metadata.tree, + is_leaf=lambda x: hasattr(x, "shape"), + ) + state = ckptr.restore(checkpoint_dir, restore_args=restore_args) + except Exception as e: # pylint: disable=broad-exception-caught + if metadata_only: + log(f" Metadata loading failed: {e}") + raise + # Fallback to simple restore without sharding args + log(f" Falling back to simple restore: {e}") + checkpointer = ocp.PyTreeCheckpointer() + state = checkpointer.restore(checkpoint_path) + + if state is None: + raise ValueError(f"Failed to restore checkpoint from {checkpoint_path}") + + log(f" Loaded keys: {list(state.keys())}") + return state + + +def transform_nnx_params_for_comparison(nnx_params: Dict[str, Any]) -> Dict[str, Any]: + """Transform NNX params to match Linen structure for comparison. + + NNX stacks layer parameters along axis 0 (shape: [num_layers, ...]), + while Linen stores per-layer parameters (shape: [...]). + + This function transposes layer params from (layers, d1, d2, ...) to (d1, layers, d2, ...) + to align with how Linen params would look if stacked. + """ + + def _transform(path, leaf: jax.Array) -> jax.Array: + key_str = keystr(path) + + # Only transform arrays in 'layers' with ndim >= 2 + if "layers" in key_str and hasattr(leaf, "ndim") and leaf.ndim >= 2: + # Transpose from (layers, d1, d2, ...) to (d1, layers, d2, ...) + axes = (1, 0) + tuple(range(2, leaf.ndim)) + result = jnp.transpose(leaf, axes=axes) + if FLAGS.verbose: + log(f" TRANSPOSING: {key_str} shape {leaf.shape} -> {result.shape}") + return result + else: + return leaf + + log("Transforming NNX params (transposing layer dimensions)...") + return tree_map_with_path(_transform, nnx_params) + + +def get_tree_structure_info(tree: Dict[str, Any]) -> Dict[str, tuple]: + """Get structure info as dict of path -> (shape, dtype).""" + flat_with_path, _ = tree_flatten_with_path(tree) + return { + keystr(p): ( + getattr(leaf, "shape", "N/A"), + str(getattr(leaf, "dtype", type(leaf).__name__)), + ) + for p, leaf in flat_with_path + } + + +def print_structure_diff(params1: Dict, params2: Dict, name1: str = "Linen", name2: str = "NNX"): + """Print structural differences between two param trees.""" + info1 = get_tree_structure_info(params1) + info2 = get_tree_structure_info(params2) + keys1, keys2 = set(info1.keys()), set(info2.keys()) + + only_in_1 = sorted(keys1 - keys2) + only_in_2 = sorted(keys2 - keys1) + common = keys1 & keys2 + + if only_in_1: + print(f"\n--- Paths only in {name1} ({len(only_in_1)}) ---") + for k in only_in_1: + shape, dtype = info1[k] + print(f" - {k}: shape={shape}, dtype={dtype}") + + if only_in_2: + print(f"\n--- Paths only in {name2} ({len(only_in_2)}) ---") + for k in only_in_2: + shape, dtype = info2[k] + print(f" + {k}: shape={shape}, dtype={dtype}") + + # Check for shape/dtype mismatches in common paths + shape_mismatches = [] + dtype_mismatches = [] + for k in common: + shape1, dtype1 = info1[k] + shape2, dtype2 = info2[k] + if shape1 != shape2: + shape_mismatches.append((k, shape1, shape2)) + if dtype1 != dtype2: + dtype_mismatches.append((k, dtype1, dtype2)) + + if shape_mismatches: + print(f"\n--- Shape mismatches ({len(shape_mismatches)}) ---") + for k, s1, s2 in shape_mismatches: + print(f" {k}: {name1}={s1}, {name2}={s2}") + + if dtype_mismatches: + print(f"\n--- Dtype mismatches ({len(dtype_mismatches)}) ---") + for k, d1, d2 in dtype_mismatches: + print(f" {k}: {name1}={d1}, {name2}={d2}") + + return only_in_1, only_in_2, shape_mismatches, dtype_mismatches + + +def compare_params( + params1: Dict[str, Any], + params2: Dict[str, Any], + verbose: bool = False, + compare_values: bool = False, + atol: float = 1e-5, + rtol: float = 1e-5, + name1: str = "Ckpt1", + name2: str = "Ckpt2", +) -> bool: + """Compare two parameter trees for structure, shape, and optionally values. + + Returns True if tree structures, shapes, and (optionally) values match. + """ + # First check tree structure + if tree_structure(params1) != tree_structure(params2): + print("\n[✗] Tree structures differ.") + print_structure_diff(params1, params2, name1=name1, name2=name2) + return False + + print("\n[✓] Tree structures are the same.") + + all_match = True + num_params = 0 + shape_mismatches = [] + dtype_mismatches = [] + value_mismatches = [] + value_matches = 0 + + def _compare_leaf(path, x, y): + nonlocal all_match, num_params, shape_mismatches, dtype_mismatches, value_mismatches, value_matches + key_str = keystr(path) + num_params += 1 + + shape1 = getattr(x, "shape", "N/A") + shape2 = getattr(y, "shape", "N/A") + dtype1 = getattr(x, "dtype", type(x).__name__) + dtype2 = getattr(y, "dtype", type(y).__name__) + + # Check shape + shape_match = shape1 == shape2 + if not shape_match: + shape_mismatches.append((key_str, shape1, shape2)) + all_match = False + + # Check dtype + dtype_match = str(dtype1) == str(dtype2) + if not dtype_match: + dtype_mismatches.append((key_str, dtype1, dtype2)) + all_match = False + + # Check values if requested and shapes match + if compare_values and shape_match and hasattr(x, "shape") and hasattr(y, "shape"): + try: + x_arr = np.asarray(x) + y_arr = np.asarray(y) + is_close = bool(np.allclose(x_arr, y_arr, atol=atol, rtol=rtol)) + + if is_close: + value_matches += 1 + if verbose: + print(f" [✓] {key_str} | Shape: {shape1} | Values match") + else: + diff = np.abs(x_arr - y_arr) + mean_diff = float(np.mean(diff)) + max_diff = float(np.max(diff)) + value_mismatches.append((key_str, mean_diff, max_diff)) + all_match = False + if verbose: + print(f" [✗] {key_str} | Shape: {shape1} | Mean diff: {mean_diff:.2e}, Max diff: {max_diff:.2e}") + except Exception as e: # pylint: disable=broad-exception-caught + value_mismatches.append((key_str, f"Error: {e}", "")) + all_match = False + elif verbose and not compare_values: + print(f" {key_str} | Shape: {shape1} | Dtype: {dtype1}") + + tree_map_with_path(_compare_leaf, params1, params2) + + # Print summary + print("\n--- Summary ---") + print(f"Total parameters: {num_params}") + + if shape_mismatches: + print(f"\n[✗] Shape mismatches ({len(shape_mismatches)}):") + for key_str, s1, s2 in shape_mismatches: + print(f" {key_str}: {name1}={s1}, {name2}={s2}") + else: + print("[✓] All shapes match.") + + if dtype_mismatches: + print(f"\n[✗] Dtype mismatches ({len(dtype_mismatches)}):") + for key_str, d1, d2 in dtype_mismatches: + print(f" {key_str}: {name1}={d1}, {name2}={d2}") + else: + print("[✓] All dtypes match.") + + if compare_values: + if value_mismatches: + print(f"\n[✗] Value mismatches ({len(value_mismatches)}):") + for item in value_mismatches[:20]: # Show first 20 + if len(item) == 3: + key_str, mean_diff, max_diff = item + if isinstance(mean_diff, float): + print(f" {key_str}: mean_diff={mean_diff:.2e}, max_diff={max_diff:.2e}") + else: + print(f" {key_str}: {mean_diff}") + if len(value_mismatches) > 20: + print(f" ... and {len(value_mismatches) - 20} more (use --verbose to see all)") + else: + print(f"[✓] All values match (atol={atol}, rtol={rtol}).") + print(f" Values matching: {value_matches}/{num_params}") + + return all_match + + +def _extract_params(state: dict, fmt: str) -> dict: + """Extract params from a checkpoint state based on its detected format.""" + if fmt == "linen": + return state.get("params", {}) + else: + # NNX format: params are in 'model' key + return state.get("model", state.get("params", {})) + + +def _normalize_params(params: dict, fmt: str) -> dict: + """Normalize params based on detected format.""" + if fmt == "linen": + return _normalize_linen_params(params) + else: + return _normalize_nnx_params(params) + + +def main(argv: Sequence[str]): + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + ckpt_path_1 = FLAGS.ckpt_path_1 + ckpt_path_2 = FLAGS.ckpt_path_2 + + print("=" * 80) + print("Checkpoint Comparator") + print("=" * 80) + + print(f"\nCheckpoint 1: {ckpt_path_1}") + print(f"Checkpoint 2: {ckpt_path_2}") + print(f"Transpose NNX layers: {FLAGS.transpose_nnx_layers}") + print(f"Ignore RNGs: {FLAGS.ignore_rngs}") + print(f"Compare values: {FLAGS.compare_values}") + if FLAGS.compare_values: + print(f" Tolerance: atol={FLAGS.atol}, rtol={FLAGS.rtol}") + + # Load checkpoints — use metadata-only when not comparing values to avoid + # downloading tensor data (which can be 100+ GiB and cause XPK timeouts). + metadata_only = not FLAGS.compare_values + print("\n" + "-" * 40) + state_1 = load_checkpoint(ckpt_path_1, metadata_only=metadata_only) + state_2 = load_checkpoint(ckpt_path_2, metadata_only=metadata_only) + + # Detect formats + format_1 = detect_format(state_1) + format_2 = detect_format(state_2) + log(f"Detected checkpoint 1 format: {format_1}") + log(f"Detected checkpoint 2 format: {format_2}") + + is_cross_format = format_1 != format_2 + name_1 = f"Ckpt1({format_1})" + name_2 = f"Ckpt2({format_2})" + + # Extract and normalize params + print("\n" + "-" * 40) + log("Normalizing parameters...") + + if FLAGS.compare_only == "params": + params_1 = _extract_params(state_1, format_1) + params_2 = _extract_params(state_2, format_2) + else: + params_1 = state_1 + params_2 = state_2 + + params_1 = _normalize_params(params_1, format_1) + log(f" Checkpoint 1 ({format_1}): normalized") + params_2 = _normalize_params(params_2, format_2) + log(f" Checkpoint 2 ({format_2}): normalized") + + # Filter out RNG paths if requested + if FLAGS.ignore_rngs: + print("\n" + "-" * 40) + log("Filtering out RNG-related paths...") + params_1 = filter_rngs(params_1) + params_2 = filter_rngs(params_2) + + # Transform NNX params for cross-format comparison (transpose layer dimensions) + # Only apply when comparing Linen vs NNX, not for same-format comparisons + if FLAGS.transpose_nnx_layers and is_cross_format: + print("\n" + "-" * 40) + if format_1 == "nnx": + params_1 = transform_nnx_params_for_comparison(params_1) + if format_2 == "nnx": + params_2 = transform_nnx_params_for_comparison(params_2) + + # Compare + print("\n" + "-" * 40) + log("Comparing parameters...") + + success = compare_params( + params_1, + params_2, + verbose=FLAGS.verbose, + compare_values=FLAGS.compare_values, + atol=FLAGS.atol, + rtol=FLAGS.rtol, + name1=name_1, + name2=name_2, + ) + + # Final verdict + print("\n" + "=" * 80) + if success: + print("CHECKPOINTS MATCH") + if FLAGS.compare_values: + print(" Tree structure, shapes, and values are identical!") + else: + print(" Tree structure and all shapes are identical!") + else: + print("CHECKPOINTS DIFFER") + print(" See details above for mismatches.") + print("=" * 80) + + return 0 if success else 1 + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxtext/checkpoint_conversion/linen_nnx_converter.py b/src/maxtext/checkpoint_conversion/linen_nnx_converter.py new file mode 100644 index 0000000000..015d3b5a56 --- /dev/null +++ b/src/maxtext/checkpoint_conversion/linen_nnx_converter.py @@ -0,0 +1,581 @@ +# Copyright 2023-2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bidirectional conversion between Linen and NNX checkpoint formats. + +Top-level key mapping: + Linen → NNX: + params/params/ → model/ (remove double-nesting, rename, add {value:} wrappers) + opt_state → optimizer/opt_state (remove 'params' level from mu/nu) + step → optimizer/step (move inside optimizer) + + NNX → Linen: + model/ → params/params/ (strip {value:} wrappers, add double-nesting) + optimizer/opt_state → opt_state (add 'params' level to mu/nu) + optimizer/step → step (move to top level) + +Layer structure (--scan_layers): + linen_to_nnx: + scan_layers=True (default): stack layers_N arrays → 'layers' tensor with layer dim at axis 1 + scan_layers=False: rename layers_N → integer-keyed 'layers/{N}' + + nnx_to_linen (auto-detected): + Stacked 'layers' tensor → unstack along axis 1 → layers_N per-layer arrays + Integer-keyed layers/{N} → rename to layers_N + +Usage: + python linen_nnx_converter.py \\ + --source_path="gs://bucket/checkpoint/0/items" \\ + --target_path="gs://bucket/converted/" \\ + --direction=auto +""" + +import argparse +import os +import re +import time +from typing import Any + +# MUST set before importing JAX to force CPU-only mode +os.environ["JAX_PLATFORMS"] = "cpu" + +import jax +import numpy as np +from etils import epath +import orbax.checkpoint as ocp + + +def log(message: str) -> None: + print(f"[linen_nnx_converter] {message}") + + +# ── Format detection ─────────────────────────────────────────────────────────── + + +def detect_format(state: dict) -> str: + """Detects checkpoint format ('linen' or 'nnx') from top-level keys.""" + # NNX: uses 'model' as the top-level params key + if "model" in state: + return "nnx" + + if "params" not in state: + raise ValueError(f"Cannot detect checkpoint format: no 'model' or 'params' key. " f"Found: {list(state.keys())}") + + params = state["params"] + + # Linen: double-nested params/params/decoder + if isinstance(params, dict) and "params" in params: + inner = params["params"] + if isinstance(inner, dict) and ("decoder" in inner or "encoder" in inner): + return "linen" + + # Old NNX format: params/decoder (single-nested with value wrappers) + if isinstance(params, dict) and ("decoder" in params or "encoder" in params): + if _has_value_wrappers(params): + return "nnx" + + if "optimizer" in state: + return "nnx" + if "opt_state" in state: + return "linen" + + raise ValueError( + f"Could not detect checkpoint format. Keys: {list(state.keys())}, " + f"params keys: {list(params.keys()) if isinstance(params, dict) else type(params)}" + ) + + +# ── Value wrapper helpers ────────────────────────────────────────────────────── + + +def _has_value_wrappers(tree: Any) -> bool: + """Returns True if tree contains {value: array} wrappers (NNX style).""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return True + for v in tree.values(): + if _has_value_wrappers(v): + return True + return False + + +def _strip_value_wrappers(tree: Any) -> Any: + """Recursively strips {value: array} wrappers from a tree.""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return inner + return {k: _strip_value_wrappers(v) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_strip_value_wrappers(item) for item in tree) + else: + return tree + + +def _add_value_wrappers(tree: Any) -> Any: + """Recursively wraps leaf arrays in {value: array} (NNX nnx.Param format).""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return tree # Already wrapped + return {k: _add_value_wrappers(v) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_add_value_wrappers(item) for item in tree) + elif hasattr(tree, "shape") or isinstance(tree, np.ndarray): + return {"value": tree} + else: + return tree + + +# ── Layer structure helpers ──────────────────────────────────────────────────── + + +def _stack_layers(decoder: dict) -> tuple[dict, bool]: + """Stacks per-layer parameters (layers_N) into a single 'layers' dict at axis 0. + + Returns (result_dict, was_stacked). + """ + layer_pattern = re.compile(r"^layers_(\d+)$") + layer_indices = {} + other_keys = {} + + for key, value in decoder.items(): + match = layer_pattern.match(key) + if match: + layer_indices[int(match.group(1))] = value + else: + other_keys[key] = value + + if not layer_indices: + return decoder, False + + sorted_indices = sorted(layer_indices.keys()) + num_layers = len(sorted_indices) + log(f" Found {num_layers} individual layers, stacking into 'layers'") + + def stack_arrays(layers_data: list) -> Any: + first = layers_data[0] + if hasattr(first, "shape") or isinstance(first, np.ndarray): + return np.stack([np.asarray(layers_data[i]) for i in range(len(layers_data))], axis=0) + elif isinstance(first, dict): + result = {} + for key in first.keys(): + child_data = [layers_data[i].get(key) for i in range(len(layers_data))] + if all(c is not None for c in child_data): + result[key] = stack_arrays(child_data) + return result + else: + return first + + layers_data = [layer_indices[i] for i in sorted_indices] + stacked = stack_arrays(layers_data) + + result = dict(other_keys) + result["layers"] = stacked + return result, True + + +def _rename_layers_to_integer_keys(decoder: dict) -> dict: + """Converts layers_N keys to integer-keyed dict under 'layers' (no stacking). + + Converts {layers_0: {...}, layers_1: {...}} → {layers: {'0': {...}, '1': {...}}}. + Used for scan_layers=False linen→nnx conversion (Pattern C). + """ + layer_pattern = re.compile(r"^layers_(\d+)$") + layer_indices = {} + other_keys = {} + + for key, value in decoder.items(): + match = layer_pattern.match(key) + if match: + layer_indices[int(match.group(1))] = value + else: + other_keys[key] = value + + if not layer_indices: + return decoder + + sorted_indices = sorted(layer_indices.keys()) + log(f" Found {len(sorted_indices)} individual layers, renaming to integer-keyed 'layers/N'") + result = dict(other_keys) + result["layers"] = {str(i): layer_indices[i] for i in sorted_indices} + return result + + +def _transpose_layers_axes(tree: Any, src_axis: int, dst_axis: int) -> Any: + """Transposes the layers dimension in arrays within a tree (src_axis ↔ dst_axis).""" + if src_axis == dst_axis: + return tree + if isinstance(tree, dict): + return {k: _transpose_layers_axes(v, src_axis, dst_axis) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_transpose_layers_axes(item, src_axis, dst_axis) for item in tree) + elif hasattr(tree, "shape") and len(tree.shape) >= 2: + axes = list(range(len(tree.shape))) + axes[src_axis], axes[dst_axis] = axes[dst_axis], axes[src_axis] + result = np.transpose(np.asarray(tree), axes=axes) + log(f" Transposed: {tree.shape} → {result.shape}") + return result + else: + return tree + + +def _detect_num_layers(tree: Any, scan_axis: int) -> int | None: + """Detects num_layers from the first array with ndim > scan_axis.""" + if hasattr(tree, "shape") or isinstance(tree, np.ndarray): + shape = getattr(tree, "shape", None) or np.asarray(tree).shape + if len(shape) > scan_axis: + return shape[scan_axis] + return None + if isinstance(tree, dict): + for v in tree.values(): + result = _detect_num_layers(v, scan_axis) + if result is not None: + return result + return None + + +def _unstack_single_layer(tree: Any, idx: int, scan_axis: int) -> Any: + """Extracts a single layer by indexing at scan_axis.""" + if hasattr(tree, "shape") or isinstance(tree, np.ndarray): + arr = np.asarray(tree) + if arr.ndim > scan_axis: + return np.take(arr, idx, axis=scan_axis) + return arr + if isinstance(tree, dict): + return {k: _unstack_single_layer(v, idx, scan_axis) for k, v in tree.items()} + if isinstance(tree, (list, tuple)): + return type(tree)(_unstack_single_layer(v, idx, scan_axis) for v in tree) + return tree + + +def _convert_layers_to_linen_format(decoder: dict) -> dict: + """Converts NNX 'layers' back to Linen's layers_N format (auto-detects NNX style). + + Handles: + - Stacked tensor (Pattern B): layers/ + → layers_0, layers_1, ... (unstack along axis 1) + - Integer-keyed (Pattern C): layers/0, layers/1, ... + → layers_0, layers_1, ... (rename) + """ + if "layers" not in decoder: + return decoder + + layers_val = decoder["layers"] + other_keys = {k: v for k, v in decoder.items() if k != "layers"} + + if not isinstance(layers_val, dict): + # Already a non-dict (shouldn't happen normally), keep as-is + return decoder + + # Pattern C: integer-keyed per-layer dict → rename + if all(k.isdigit() for k in layers_val.keys()): + result = dict(other_keys) + for idx_str, layer_data in sorted(layers_val.items(), key=lambda x: int(x[0])): + result[f"layers_{idx_str}"] = layer_data + log(f" Renamed integer-keyed layers/N → layers_N ({len(layers_val)} layers)") + return result + + # Pattern B: stacked tensor (layer dim at axis 1) → unstack + num_layers = _detect_num_layers(layers_val, scan_axis=1) + if num_layers is None: + log(" WARNING: Could not detect num_layers for unstacking, keeping 'layers' as-is") + result = dict(other_keys) + result["layers"] = layers_val + return result + + result = dict(other_keys) + for i in range(num_layers): + result[f"layers_{i}"] = _unstack_single_layer(layers_val, idx=i, scan_axis=1) + log(f" Unstacked scanned 'layers' → layers_N ({num_layers} layers at axis 1)") + return result + + +# ── Optimizer state helpers ──────────────────────────────────────────────────── + + +def _convert_opt_state_linen_to_nnx(opt_state: Any) -> Any: + """Removes 'params' nesting from mu/nu in linen opt_state. + + NNX optimizer state has plain arrays (no {value:} wrappers). + Linen opt_state mirrors the params structure (params/decoder/...), + so we remove the 'params' level to get decoder/... directly. + """ + if isinstance(opt_state, dict): + result = {} + for k, v in opt_state.items(): + if k == "params": + # Remove this level by merging its contents up + converted = _convert_opt_state_linen_to_nnx(v) + if isinstance(converted, dict): + result.update(converted) + else: + result[k] = converted + else: + result[k] = _convert_opt_state_linen_to_nnx(v) + return result + elif isinstance(opt_state, (list, tuple)): + return type(opt_state)(_convert_opt_state_linen_to_nnx(item) for item in opt_state) + else: + return opt_state # Plain array or scalar — no value wrapper for opt_state + + +def _convert_opt_state_nnx_to_linen(opt_state: Any, depth: int = 0) -> Any: + """Adds 'params' nesting to mu/nu, removes any stray {value:} wrappers. + + NNX optimizer mu/nu contains decoder/... directly. + Linen expects mu/params/decoder/... (one 'params' level mirroring the params structure). + """ + if isinstance(opt_state, dict): + # Strip any {value:} wrappers in opt_state (shouldn't be there but handle gracefully) + if set(opt_state.keys()) == {"value"}: + inner = opt_state["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return inner + + result = {} + for k, v in opt_state.items(): + converted = _convert_opt_state_nnx_to_linen(v, depth + 1) + # Add one 'params' level after mu/nu (mirrors linen's params structure) + if k in ("mu", "nu") and isinstance(converted, dict): + result[k] = {"params": converted} + else: + result[k] = converted + return result + elif isinstance(opt_state, (list, tuple)): + return type(opt_state)(_convert_opt_state_nnx_to_linen(item, depth + 1) for item in opt_state) + else: + return opt_state + + +# ── Main conversion functions ────────────────────────────────────────────────── + + +def convert_linen_to_nnx(state: dict, scan_layers: bool = True) -> dict: + """Converts Linen checkpoint to NNX format. + + Args: + state: Linen checkpoint dict with keys ['params', 'opt_state', 'step']. + scan_layers: If True (default), stack per-layer arrays and insert layer + dim at axis 1 (for NNX with scan_layers=True). + If False, rename layers_N → integer-keyed layers/N + (for NNX with scan_layers=False). + """ + result = {} + + if "params" in state: + linen_params = state["params"] + # Remove double 'params' nesting: params/params/decoder → decoder + if isinstance(linen_params, dict) and "params" in linen_params: + nnx_params = linen_params["params"] + log(" params: Removed double 'params' nesting (params/params → model)") + else: + nnx_params = linen_params + log(" params: No double nesting found") + + stripped = _strip_value_wrappers(nnx_params) + + for component in ("decoder", "encoder"): + if component in stripped and isinstance(stripped[component], dict): + if scan_layers: + stripped[component], was_stacked = _stack_layers(stripped[component]) + if was_stacked and "layers" in stripped[component]: + log(f" {component}/layers: Transposing stacked (layers, ...) → (..., layers, ...) at axis 1") + stripped[component]["layers"] = _transpose_layers_axes(stripped[component]["layers"], src_axis=0, dst_axis=1) + else: + stripped[component] = _rename_layers_to_integer_keys(stripped[component]) + + result["model"] = _add_value_wrappers(stripped) + log(" model: Saved with {value:} wrappers under 'model' key") + + # optimizer: move step inside, keep opt_state + optimizer_dict = {} + if "step" in state: + optimizer_dict["step"] = state["step"] + log(f" optimizer/step: Moved from top-level (step={state['step']})") + if "opt_state" in state: + optimizer_dict["opt_state"] = _convert_opt_state_linen_to_nnx(state["opt_state"]) + log(" optimizer/opt_state: Removed 'params' nesting from mu/nu") + if optimizer_dict: + result["optimizer"] = optimizer_dict + + return result + + +def convert_nnx_to_linen(state: dict) -> dict: + """Converts NNX checkpoint to Linen format. + + Reads from 'model'/'optimizer' keys (or falls back to old 'params'/'opt_state' format). + Layer structure is auto-detected (stacked vs integer-keyed). + """ + result = {} + + model_key = "model" if "model" in state else "params" + if model_key in state: + nnx_params = state[model_key] + stripped = _strip_value_wrappers(nnx_params) + log(f" {model_key}: Removed {{value:}} wrappers") + + for component in ("decoder", "encoder"): + if component in stripped and isinstance(stripped[component], dict): + stripped[component] = _convert_layers_to_linen_format(stripped[component]) + + # Add double 'params' nesting: decoder → params/params/decoder + result["params"] = {"params": stripped} + log(" params: Added double 'params' nesting (model → params/params)") + + # optimizer: extract step and opt_state back to top level + if "optimizer" in state: + optimizer = state["optimizer"] + if "step" in optimizer: + result["step"] = optimizer["step"] + log(" step: Extracted from optimizer/step to top level") + if "opt_state" in optimizer: + result["opt_state"] = _convert_opt_state_nnx_to_linen(optimizer["opt_state"]) + log(" opt_state: Added 'params' nesting to mu/nu") + elif "opt_state" in state: + # Backward compat: old format with opt_state at top level + result["opt_state"] = _convert_opt_state_nnx_to_linen(state["opt_state"]) + log(" opt_state: Converted from top-level opt_state (old format)") + + if "step" in state and "step" not in result: + result["step"] = state["step"] + + return result + + +# ── Checkpoint I/O ───────────────────────────────────────────────────────────── + + +def load_checkpoint(checkpoint_path: str) -> dict: + """Loads checkpoint from local or GCS path.""" + log(f"Loading checkpoint from: {checkpoint_path}") + + checkpoint_dir = epath.Path(checkpoint_path) + ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler()) + metadata = ckptr.metadata(checkpoint_dir) + + devices = np.array(jax.devices()).reshape((-1,)) + single_device_mesh = jax.sharding.Mesh(devices, ("x",)) + unsharded = jax.sharding.NamedSharding(single_device_mesh, jax.sharding.PartitionSpec()) + + restore_args = jax.tree_util.tree_map( + lambda x: ocp.ArrayRestoreArgs(sharding=unsharded) if hasattr(x, "shape") else None, + metadata.item_metadata.tree, + is_leaf=lambda x: hasattr(x, "shape"), + ) + + state = ckptr.restore(checkpoint_dir, restore_args=restore_args) + log(f" Loaded keys: {list(state.keys())}") + return state + + +def save_checkpoint(state: dict, output_path: str) -> None: + """Saves checkpoint to local or GCS path.""" + log(f"Saving checkpoint to: {output_path}") + + output_dir = epath.Path(output_path) + output_dir.mkdir(exist_ok=True, parents=True) + + ckptr = ocp.PyTreeCheckpointer() + ckptr.save(output_dir, state, force=True) + log(" Checkpoint saved successfully") + + +# ── CLI ──────────────────────────────────────────────────────────────────────── + + +def main(): + parser = argparse.ArgumentParser( + description="Convert between Linen and NNX checkpoint formats.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--source_path", + type=str, + required=True, + help="Path to source checkpoint items directory (e.g. gs://bucket/ckpt/0/items).", + ) + parser.add_argument( + "--target_path", + type=str, + required=True, + help="Path to save converted checkpoint.", + ) + parser.add_argument( + "--direction", + type=str, + choices=["auto", "linen_to_nnx", "nnx_to_linen"], + default="auto", + help="Conversion direction. 'auto' detects from source format.", + ) + parser.add_argument( + "--scan_layers", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "For linen_to_nnx only: if True (default), stack per-layer arrays into a " + "scanned 'layers' tensor with layer dim at axis 1 (for NNX with scan_layers=True). " + "If False, rename layers_N to integer-keyed layers/N without stacking " + "(for NNX with scan_layers=False)." + ), + ) + + args = parser.parse_args() + + print("=" * 80) + print("Linen <-> NNX Checkpoint Converter") + print("=" * 80) + + start_time = time.time() + + state = load_checkpoint(args.source_path) + + if args.direction == "auto": + source_format = detect_format(state) + target_format = "nnx" if source_format == "linen" else "linen" + log(f"Auto-detected: {source_format} → {target_format}") + else: + source_format = args.direction.split("_to_")[0] + target_format = args.direction.split("_to_")[1] + log(f"Using specified direction: {source_format} → {target_format}") + + log(f"Converting: {source_format} → {target_format}") + if source_format == "linen": + log(f"scan_layers={args.scan_layers}") + + if source_format == "linen" and target_format == "nnx": + converted_state = convert_linen_to_nnx(state, scan_layers=args.scan_layers) + elif source_format == "nnx" and target_format == "linen": + converted_state = convert_nnx_to_linen(state) + else: + raise ValueError(f"Invalid conversion: {source_format} → {target_format}") + + save_checkpoint(converted_state, args.target_path) + + elapsed = time.time() - start_time + print("\n" + "=" * 80) + print(f"Conversion complete in {elapsed:.2f} seconds") + print(f" Source: {args.source_path}") + print(f" Target: {args.target_path}") + print(f" Direction: {source_format} → {target_format}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index f5dd4e6cc3..b4708b10de 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -522,7 +522,11 @@ def __call__( previous_chunk=previous_chunk, slot=slot, page_state=page_state, - multimodal_input=multimodal_input, + image_embeddings=multimodal_input.image_embeddings if multimodal_input is not None else None, + image_masks=multimodal_input.image_masks if multimodal_input is not None else None, + audio_embeddings=multimodal_input.audio_embeddings if multimodal_input is not None else None, + audio_masks=multimodal_input.audio_masks if multimodal_input is not None else None, + bidirectional_mask=multimodal_input.bidirectional_mask if multimodal_input is not None else None, kv_caches=kv_caches, attention_metadata=attention_metadata, deepstack_visual_embeds=deepstack_visual_embeds, diff --git a/src/maxtext/optimizers/optimizers.py b/src/maxtext/optimizers/optimizers.py index 2ae7e5f8e5..9992d7674f 100644 --- a/src/maxtext/optimizers/optimizers.py +++ b/src/maxtext/optimizers/optimizers.py @@ -336,7 +336,9 @@ def _update_momentum(update, mu, nu): else: updates = jax.tree_util.tree_map(lambda x, v: x + weight_decay * v, updates, params) - step_size = -1.0 * learning_rate_fn(count) + # learning_rate_fn may be a callable schedule or a scalar (e.g. when wrapped + # by optax.inject_hyperparams, it is passed as a pre-evaluated scalar). + step_size = -1.0 * (learning_rate_fn(count) if callable(learning_rate_fn) else learning_rate_fn) # Finally, fold in step size. updates = jax.tree_util.tree_map(lambda x: step_size * x, updates) diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index 5de310da90..79e66dae83 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -259,30 +259,45 @@ def wrt_filter(path, x): # Inherits _shard_optimizer from PeftTrainer. def _train_step(self, model, optimizer, inputs): - """Overrides the main JIT block to natively handle ModelBundle module.""" + """Overrides the main JIT block to natively handle ModelBundle module. + Uses jax.value_and_grad with explicit split/merge to avoid nesting + nnx.value_and_grad inside nnx.jit, which causes Flax NNX to assign + conflicting outer_index values and raises: + ValueError: The graph structure of a node added to cached_partial was + mutated inside the transformation. + """ batch = self.gen_model_input_fn(inputs) + student = model.student_model + teacher = model.teacher_model current_step = model.training_step.value - def loss_wrapper(student, teacher, batch): - if "teacher_output" in batch: - teacher_output = batch["teacher_output"] - else: - teacher_output = self.strategy.teacher_forward_fn( - model=teacher, - input_tokens=batch["input_tokens"], - positions=batch["positions"], - attention_mask=batch.get("attention_mask"), - decoder_segment_ids=batch.get("decoder_segment_ids"), - decoder_target_tokens=batch.get("targets", None), - decoder_target_mask=batch.get("targets_segmentation", None), - cache=None, - ) + # Run teacher inference outside of value_and_grad. + # The teacher is frozen (stop_gradient), so its output is a constant + # from the perspective of the student gradient computation. + if "teacher_output" in batch: + teacher_output = batch["teacher_output"] + else: + teacher_output = self.strategy.teacher_forward_fn( + model=teacher, + input_tokens=batch["input_tokens"], + positions=batch["positions"], + attention_mask=batch.get("attention_mask"), + decoder_segment_ids=batch.get("decoder_segment_ids"), + decoder_target_tokens=batch.get("targets", None), + decoder_target_mask=batch.get("targets_segmentation", None), + cache=None, + ) + teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output) - teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output) + # Split student into differentiable params and non-differentiable rest. + # Capture graphdef outside of jax.value_and_grad for stable graph tracking. + student_graphdef, diff_params, rest = nnx.split(student, self.wrt_filter, ...) + def loss_wrapper_pure(diff_params, rest): + local_student = nnx.merge(student_graphdef, diff_params, rest, copy=True) student_output = self.strategy.student_forward_fn( - model=student, + model=local_student, input_tokens=batch["input_tokens"], positions=batch["positions"], attention_mask=batch.get("attention_mask"), @@ -291,30 +306,27 @@ def loss_wrapper(student, teacher, batch): decoder_target_mask=batch.get("targets_segmentation", None), cache=None, ) - # we should apply a mask for labels to disable segment-separator tokens labels = self.strategy.create_labels(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None)) - return self.strategy.compute_loss(student_output, teacher_output, labels, step=current_step) - - # Because student is the 0th argument, argnums=0 guarantees - # we only compute gradients for the student. - grad_fn = nnx.value_and_grad( - loss_wrapper, - argnums=nnx.DiffState(0, self.wrt_filter), - has_aux=True, - ) + loss, aux = self.strategy.compute_loss(student_output, teacher_output, labels, step=current_step) + # Capture updated non-param state (e.g. RNG counters) from local_student. + _, _, new_rest = nnx.split(local_student, self.wrt_filter, ...) + return loss, (aux, new_rest) - out, grads = grad_fn(model.student_model, model.teacher_model, batch) + grad_fn = jax.value_and_grad(loss_wrapper_pure, argnums=0, has_aux=True) + (loss, (aux, new_rest)), grads = grad_fn(diff_params, rest) + + # Propagate updated non-param state back to student. + nnx.update(student, new_rest) + + optimizer.update(student, grads) # Increment step counter after loss computation model.training_step.value = current_step + 1 tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True) - - optimizer.update(model.student_model, grads) - if tunix_expects_grad_norm: - return out[0], out[1], optax.global_norm(grads) - return out[0], out[1] + return loss, aux, optax.global_norm(grads) + return loss, aux def _eval_step(self, model, inputs): """Evaluation only needs the student.""" diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index fda6d1f933..5f1b7a8808 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -55,6 +55,42 @@ import os import pathwaysutils +# JAX 0.9+ changed with_sharding_constraint to assert (not reshard) when all +# mesh axes are Explicit. tpu_inference still expects resharding semantics. +# Patch: try the original (works for Auto axes); on AssertionError (Explicit +# mesh) fall back to jax.sharding.reshard. +_orig_wsc = jax.lax.with_sharding_constraint + + +def _compat_wsc(x, shardings): + try: + return _orig_wsc(x, shardings) + except AssertionError: + return jax.sharding.reshard(x, shardings) + + +jax.lax.with_sharding_constraint = _compat_wsc + +# tpu_inference JaxEinsum defaults param_dtype=float32, so tpu_inference model weights +# initialize as float32. During weight sync, tunix._apply_dtype_cast then upcasts the +# incoming bfloat16 MaxText weights → float32 to match the target. This leaves v_proj +# as float32 while k_proj output appears bfloat16 (due to k_norm dtype promotion), +# causing a dtype mismatch in the ragged paged attention kernel. +# Fix: skip bfloat16→float32 upcasts during weight sync so synced weights stay bfloat16. +import jax.numpy as _jnp +import tunix.generate.utils as _tunix_utils + +_orig_apply_dtype_cast = _tunix_utils._apply_dtype_cast # pylint: disable=protected-access + + +def _no_bf16_to_f32_cast(val, tgt_dtype, src_key): + if hasattr(val, "dtype") and val.dtype == _jnp.bfloat16 and tgt_dtype == _jnp.float32: + return val # keep bfloat16; tpu_inference model dtype is bfloat16 despite float32 init + return _orig_apply_dtype_cast(val, tgt_dtype, src_key) + + +_tunix_utils._apply_dtype_cast = _no_bf16_to_f32_cast # pylint: disable=protected-access + from absl import app from absl import logging as absl_logging from etils import epath @@ -410,6 +446,8 @@ def create_rl_components( "hf_overrides": trainer_config.vllm_hf_overrides, "enable_expert_parallel": sampler_config.enable_expert_parallel, "enable_prefix_caching": True, # Enable prefix caching to speed up generation for long prompts + # Ensures vLLM model initializes with correct dtype (not float32 default) + "dtype": trainer_config.weight_dtype, }, rollout_vllm_sampling_kwargs={ "stop": trainer_config.stop_strings, @@ -555,7 +593,10 @@ def rl_train(argv: Sequence[str], kwargs: dict): max_train_steps = get_max_train_steps(trainer_config) # Create model tokenizer - model_tokenizer = AutoTokenizer.from_pretrained(trainer_config.tokenizer_path) + model_tokenizer = AutoTokenizer.from_pretrained( + trainer_config.tokenizer_path, + token=trainer_config.hf_access_token or None, + ) train_dataset, test_dataset = prepare_datasets(trainer_config, model_tokenizer) diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py index c7c726cec9..a6c80d27dc 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft.py +++ b/src/maxtext/trainers/post_train/sft/train_sft.py @@ -35,7 +35,7 @@ eval_interval=-1 steps=10 profiler=xplane weight_dtype=bfloat16 """ -from typing import Sequence +from typing import Any, Sequence from absl import app import os @@ -43,6 +43,7 @@ import optax import pathwaysutils +from flax import nnx from flax.linen import partitioning as nn_partitioning from orbax import checkpoint as ocp @@ -68,6 +69,70 @@ from maxtext.utils import model_creation_utils +class MaxTextPeftTrainer(peft_trainer.PeftTrainer): + """MaxText-specific PeftTrainer that avoids nested NNX transformations. + + Tunix's default PeftTrainer._train_step creates nnx.value_and_grad inside + nnx.jit. This nesting causes Flax NNX to assign conflicting outer_index + values to graph nodes, resulting in: + ValueError: The graph structure of a node added to cached_partial was + mutated inside the transformation. + + This subclass overrides create_train_step_fn to use jax.value_and_grad + with an explicit split/merge pattern (matching MaxText's pre-training NNX + train_step), which avoids the nested NNX transformation issue entirely. + """ + + def create_train_step_fn(self): + """Creates a train step using jax.value_and_grad with explicit NNX split/merge.""" + loss_fn_ref = self.loss_fn + has_aux = self._has_aux + gen_fn = self.gen_model_input_fn + is_lora_enabled = self._lora_enabled + wrt = nnx.LoRAParam if is_lora_enabled else nnx.Param + tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True) + + # Capture the graphdef once outside of JIT so that split/merge inside + # jax.value_and_grad can use a stable (non-traced) structural descriptor. + graphdef, _, _ = nnx.split(self.model, wrt, ...) + + def train_step(model: nnx.Module, optimizer: nnx.Optimizer, inputs: Any): + inputs = gen_fn(inputs) + + # Split model into differentiable params and non-differentiable rest. + # Using jax.value_and_grad (not nnx.value_and_grad) avoids nesting NNX + # transforms inside nnx.jit, which would corrupt outer_index tracking. + _, diff_params, rest = nnx.split(model, wrt, ...) + + def loss_wrapper(diff_params, rest, **inputs_kw): + local_model = nnx.merge(graphdef, diff_params, rest, copy=True) + out = loss_fn_ref(local_model, **inputs_kw) + # Capture updated non-param state (e.g. RNG counters) from local_model. + _, _, new_rest = nnx.split(local_model, wrt, ...) + if has_aux: + loss, aux = out + return loss, (aux, new_rest) + else: + return out, (None, new_rest) + + grad_fn = jax.value_and_grad(loss_wrapper, argnums=0, has_aux=True) + (out_val, (aux, new_rest)), grads = grad_fn(diff_params, rest, **inputs) + + # Propagate updated non-param state (RNG counters, etc.) back to model. + nnx.update(model, new_rest) + + # Apply optimizer update. grads has the same nnx.State(wrt) structure + # as diff_params, which is compatible with optimizer.update. + optimizer.update(model, grads) + + aux_out = aux if has_aux else None + if tunix_expects_grad_norm: + return out_val, aux_out, optax.global_norm(grads) + return out_val, aux_out + + return train_step + + def get_tunix_config(mt_config): """Gets the Tunix training configurations from the MaxText config. @@ -109,6 +174,7 @@ def get_tunix_config(mt_config): checkpointing_options=checkpointing_options, metrics_logging_options=metrics_logging_options, profiler_options=profiler_options, + data_sharding_axis=tuple(mt_config.data_sharding), ) @@ -162,7 +228,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None): data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder) # Provide rules context so 'norm' is translated to mesh axes during maybe_restore with nn_partitioning.axis_rules(mt_config.logical_axis_rules): - trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config) + trainer = MaxTextPeftTrainer(model, optimizer, tunix_config) trainer.with_training_hooks(training_hooks) trainer.with_data_hooks(data_hooks) trainer = use_maxtext_loss_function(trainer, mt_config) diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 0479d4bfca..e182ac973d 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -1894,26 +1894,41 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No """ Print state shardings comparing Logical Definition vs Physical Result. """ - if not hasattr(params, "params"): - params = {"params": params} - if not hasattr(params_sharding, "params"): - params_sharding = {"params": params_sharding} - if logical_annotations and not hasattr(logical_annotations, "params"): - logical_annotations = {"params": logical_annotations} + if not isinstance(params, nnx.State): + if not hasattr(params, "params"): + params = {"params": params} + if not hasattr(params_sharding, "params"): + params_sharding = {"params": params_sharding} + if logical_annotations and not hasattr(logical_annotations, "params"): + logical_annotations = {"params": logical_annotations} leaves_params, _ = jax.tree_util.tree_flatten_with_path(params) leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding) - leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations) - for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical): - path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) - shape = jax.typeof(leaf_val) - pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) - pspec_str = str(tuple(pspec)) - logical_str = str(leaf_logical_val) - - message = f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}" - max_logging.info(message) + if logical_annotations is not None: + leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations) + for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip( + leaves_params, leaves_sharding, leaves_logical + ): + path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) + shape = jax.typeof(leaf_val) + pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) + pspec_str = str(tuple(pspec)) + logical_str = str(leaf_logical_val) + + message = ( + f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}" + ) + max_logging.info(message) + else: + for (path, leaf_val), (_, leaf_sharding) in zip(leaves_params, leaves_sharding): + path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) + shape = jax.typeof(leaf_val) + pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) + pspec_str = str(tuple(pspec)) + + message = f" {path_str}\n" f" Shape: {shape}\n" f" Physical: {pspec_str}" + max_logging.info(message) print(flush=True) diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index 7647e86042..38be9af460 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -594,6 +594,13 @@ def from_pretrained( "Please check your load_parameters_path." ) + if metadata is None or metadata.item_metadata is None: + raise ValueError( + f"Cannot read checkpoint metadata from '{config.load_parameters_path}'. " + "The checkpoint directory may be empty or the save did not complete " + "(missing _CHECKPOINT_METADATA). Ensure the checkpoint save finished successfully." + ) + def _adjust_target_for_moe_fusion(target, meta_tree, is_nnx): if not hasattr(target, "items") or not hasattr(meta_tree, "items"): return target diff --git a/tests/post_training/unit/distillation_scheduling_test.py b/tests/post_training/unit/distillation_scheduling_test.py index 21e22839b4..24b9b6d721 100644 --- a/tests/post_training/unit/distillation_scheduling_test.py +++ b/tests/post_training/unit/distillation_scheduling_test.py @@ -412,9 +412,15 @@ def __call__(self, x): self.assertEqual(int(bundle.training_step[...]), 2) @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") - @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") - def test_train_step_increments_and_passes_step(self, mock_value_and_grad, mock_global_norm): + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") + def test_train_step_increments_and_passes_step( + self, mock_value_and_grad, mock_split, mock_merge, mock_update, mock_global_norm + ): """_train_step passes pre-increment step to compute_loss and increments after.""" + del mock_merge, mock_update # pylint: disable=no-value-for-parameter trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer) trainer.strategy = mock.Mock() @@ -442,37 +448,54 @@ def test_train_step_increments_and_passes_step(self, mock_value_and_grad, mock_g # Simulate resume from step 5 model_bundle.training_step.set_value(jnp.array(5, dtype=jnp.int32)) - mock_grad_fn = mock.Mock(return_value=((mock.Mock(), {}), mock.Mock())) + # nnx.split returns (graphdef, diff_params, rest); loss_wrapper_pure takes (diff_params, rest). + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + # grad_fn returns ((loss, (aux, new_rest)), grads) + mock_grad_fn = mock.Mock(return_value=((mock.Mock(), ({}, mock.Mock())), mock.Mock())) mock_value_and_grad.return_value = mock_grad_fn mock_global_norm.return_value = mock.Mock() + trainer.strategy.compute_loss.return_value = (mock.Mock(), {}) trainer._train_step(model_bundle, optimizer, mock.Mock()) # Step should have incremented to 6 self.assertEqual(int(model_bundle.training_step[...]), 6) - # Trigger loss_wrapper to verify step=5 was passed to compute_loss + # Trigger loss_wrapper_pure to verify step=5 was passed to compute_loss. + # Signature is (diff_params, rest). loss_wrapper = mock_value_and_grad.call_args[0][0] - loss_wrapper(student_model, teacher_model, mock_batch) + loss_wrapper(mock_diff_params, mock_rest) call_kwargs = trainer.strategy.compute_loss.call_args self.assertIn("step", call_kwargs.kwargs) self.assertEqual(int(call_kwargs.kwargs["step"]), 5) @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") - @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") - def test_consecutive_train_steps_increment(self, mock_value_and_grad, mock_global_norm): + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") + def test_consecutive_train_steps_increment( + self, mock_value_and_grad, mock_split, mock_merge, mock_update, mock_global_norm + ): """training_step increments 0→1→2→3 across consecutive _train_step calls.""" + del mock_merge, mock_update # pylint: disable=no-value-for-parameter trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer) trainer.strategy = mock.Mock() trainer.wrt_filter = lambda path, x: True # type: ignore + # Use a real DistillationForwardOutput so jax.tree.map(stop_gradient, ...) works. + fake_teacher_output = distillation_utils.DistillationForwardOutput( + logits=jnp.zeros((1, 2, 4)), out_projection_activations=None + ) mock_batch = { "input_tokens": mock.Mock(), "positions": mock.Mock(), "targets": mock.Mock(), - "teacher_output": mock.Mock(), + "teacher_output": fake_teacher_output, } trainer.gen_model_input_fn = mock.Mock(return_value=mock_batch) @@ -480,7 +503,10 @@ def test_consecutive_train_steps_increment(self, mock_value_and_grad, mock_globa model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model) optimizer = mock.Mock() - mock_grad_fn = mock.Mock(return_value=((mock.Mock(), {}), mock.Mock())) + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + mock_grad_fn = mock.Mock(return_value=((mock.Mock(), ({}, mock.Mock())), mock.Mock())) mock_value_and_grad.return_value = mock_grad_fn mock_global_norm.return_value = mock.Mock() diff --git a/tests/post_training/unit/train_distill_test.py b/tests/post_training/unit/train_distill_test.py index ca2cbfa91f..b5a8a090b6 100644 --- a/tests/post_training/unit/train_distill_test.py +++ b/tests/post_training/unit/train_distill_test.py @@ -162,9 +162,12 @@ def test_prepare_inputs_logic(self): @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map") - @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") def test_train_step_skips_teacher_forward_when_output_present( - self, mock_value_and_grad, mock_tree_map, mock_global_norm + self, mock_value_and_grad, mock_split, mock_merge, mock_update, mock_tree_map, mock_global_norm ): """Verifies teacher forward is skipped when model_output is already in the batch.""" # 1. Initialize Trainer @@ -189,21 +192,28 @@ def test_train_step_skips_teacher_forward_when_output_present( model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model) optimizer, inputs = mock.Mock(), mock.Mock() - # 4. Configure mocked nnx.value_and_grad + # 4. Configure nnx.split/merge/update mocks + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + # 5. Configure mocked jax.value_and_grad + # _train_step uses: (loss, (aux, new_rest)), grads = grad_fn(diff_params, rest) mock_loss, mock_aux, mock_grads = mock.Mock(), {}, mock.Mock() - mock_grad_fn = mock.Mock(return_value=((mock_loss, mock_aux), mock_grads)) + mock_grad_fn = mock.Mock(return_value=((mock_loss, (mock_aux, mock.Mock())), mock_grads)) mock_value_and_grad.return_value = mock_grad_fn mock_global_norm.return_value = mock.Mock() + trainer.strategy.compute_loss.return_value = (mock.Mock(), {}) - # 5. Execute outer function & trigger inner loss_wrapper + # 6. Execute outer function & trigger inner loss_wrapper_pure trainer._train_step(model_bundle, optimizer, inputs) loss_wrapper = mock_value_and_grad.call_args[0][0] - loss_wrapper(student_model, teacher_model, mock_batch) + # loss_wrapper_pure signature is (diff_params, rest), not (student, teacher, batch) + loss_wrapper(mock_diff_params, mock_rest) - # 6. Assertions + # 7. Assertions trainer.strategy.teacher_forward_fn.assert_not_called() trainer.strategy.student_forward_fn.assert_called_once_with( - model=student_model, + model=mock.ANY, # local_student from nnx.merge, not the original student_model input_tokens=mock_batch["input_tokens"], positions=mock_batch["positions"], attention_mask=mock_batch["attention_mask"], @@ -215,9 +225,12 @@ def test_train_step_skips_teacher_forward_when_output_present( @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map") - @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") def test_train_step_calls_teacher_forward_when_output_missing( - self, mock_value_and_grad, mock_tree_map, mock_global_norm + self, mock_value_and_grad, mock_split, mock_merge, mock_update, mock_tree_map, mock_global_norm ): """Verifies teacher forward is called when model_output is missing from the batch.""" # 1. Initialize Trainer @@ -242,19 +255,27 @@ def test_train_step_calls_teacher_forward_when_output_missing( model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model) optimizer, inputs = mock.Mock(), mock.Mock() - # 4. Configure mocked nnx.value_and_grad + # 4. Configure nnx.split/merge/update mocks + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + # 5. Configure mocked jax.value_and_grad + # _train_step uses: (loss, (aux, new_rest)), grads = grad_fn(diff_params, rest) mock_loss, mock_aux, mock_grads = mock.Mock(), {}, mock.Mock() - mock_grad_fn = mock.Mock(return_value=((mock_loss, mock_aux), mock_grads)) + mock_grad_fn = mock.Mock(return_value=((mock_loss, (mock_aux, mock.Mock())), mock_grads)) mock_value_and_grad.return_value = mock_grad_fn mock_gn = mock.Mock() mock_global_norm.return_value = mock_gn + trainer.strategy.compute_loss.return_value = (mock.Mock(), {}) - # 5. Execute outer function & trigger inner loss_wrapper + # 6. Execute outer function & trigger inner loss_wrapper_pure train_step_out = trainer._train_step(model_bundle, optimizer, inputs) loss_wrapper = mock_value_and_grad.call_args[0][0] - loss_wrapper(student_model, teacher_model, mock_batch) + # loss_wrapper_pure signature is (diff_params, rest), not (student, teacher, batch) + loss_wrapper(mock_diff_params, mock_rest) - # 6. Assertions + # 7. Assertions + # Teacher forward is called OUTSIDE value_and_grad in _train_step trainer.strategy.teacher_forward_fn.assert_called_once_with( model=teacher_model, input_tokens=mock_batch["input_tokens"], @@ -266,8 +287,9 @@ def test_train_step_calls_teacher_forward_when_output_missing( decoder_target_mask=None, ) + # Student forward is called INSIDE loss_wrapper_pure via nnx.merge'd local_student trainer.strategy.student_forward_fn.assert_called_once_with( - model=student_model, + model=mock.ANY, # local_student from nnx.merge, not the original student_model input_tokens=mock_batch["input_tokens"], positions=mock_batch["positions"], attention_mask=mock_batch["attention_mask"], @@ -291,8 +313,13 @@ def test_train_step_calls_teacher_forward_when_output_missing( @mock.patch("maxtext.trainers.post_train.distillation.train_distill.optax.global_norm") @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.tree.map") - @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.value_and_grad") - def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_tree_map, mock_global_norm): + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.update") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.merge") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.nnx.split") + @mock.patch("maxtext.trainers.post_train.distillation.train_distill.jax.value_and_grad") + def test_train_step_passes_targets_segmentation( + self, mock_value_and_grad, mock_split, mock_merge, mock_update, mock_tree_map, mock_global_norm + ): """Verifies strategy callbacks receive decoder_target_tokens and decoder_target_mask.""" # 1. Initialize Trainer # pylint: disable=no-value-for-parameter @@ -317,22 +344,30 @@ def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_ model_bundle = train_distill.ModelBundle(teacher_model=teacher_model, student_model=student_model) optimizer, inputs = mock.Mock(), mock.Mock() - # 4. Configure mocked nnx.value_and_grad - mock_grad_fn = mock.Mock(return_value=((mock.Mock(), {}), mock.Mock())) + # 4. Configure nnx.split/merge/update mocks + mock_graphdef, mock_diff_params, mock_rest = mock.Mock(), mock.Mock(), mock.Mock() + mock_split.return_value = (mock_graphdef, mock_diff_params, mock_rest) + + # 5. Configure mocked jax.value_and_grad + # _train_step uses: (loss, (aux, new_rest)), grads = grad_fn(diff_params, rest) + mock_grad_fn = mock.Mock(return_value=((mock.Mock(), ({}, mock.Mock())), mock.Mock())) mock_value_and_grad.return_value = mock_grad_fn mock_global_norm.return_value = mock.Mock() + trainer.strategy.compute_loss.return_value = (mock.Mock(), {}) - # 5. Execute outer function & trigger inner loss_wrapper + # 6. Execute outer function & trigger inner loss_wrapper_pure trainer._train_step(model_bundle, optimizer, inputs) loss_wrapper = mock_value_and_grad.call_args[0][0] - loss_wrapper(student_model, teacher_model, mock_batch) + # loss_wrapper_pure signature is (diff_params, rest), not (student, teacher, batch) + loss_wrapper(mock_diff_params, mock_rest) - # 6. Assertions + # 7. Assertions trainer.strategy.create_labels.assert_called_once_with( mock_batch["targets"], targets_segmentation=mock_targets_segmentation ) + # Student forward is called INSIDE loss_wrapper_pure via nnx.merge'd local_student trainer.strategy.student_forward_fn.assert_called_once_with( - model=student_model, + model=mock.ANY, # local_student from nnx.merge, not the original student_model input_tokens=mock_batch["input_tokens"], positions=mock_batch["positions"], attention_mask=mock_batch["attention_mask"], @@ -341,6 +376,7 @@ def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_ decoder_target_mask=mock_targets_segmentation, cache=None, ) + # Teacher forward is called OUTSIDE value_and_grad in _train_step trainer.strategy.teacher_forward_fn.assert_called_once_with( model=teacher_model, input_tokens=mock_batch["input_tokens"], diff --git a/tests/unit/compare_linen_nnx_checkpoint_test.py b/tests/unit/compare_linen_nnx_checkpoint_test.py new file mode 100644 index 0000000000..d3d49e6a63 --- /dev/null +++ b/tests/unit/compare_linen_nnx_checkpoint_test.py @@ -0,0 +1,501 @@ +# Copyright 2023-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for compare_linen_nnx_checkpoint utilities.""" + +import io +import unittest +from unittest.mock import patch +import numpy as np + +from absl import flags as absl_flags +from maxtext.checkpoint_conversion.compare_linen_nnx_checkpoint import ( + is_rng_path, + filter_rngs, + detect_format, + _has_value_wrappers, + _strip_value_wrappers, + _normalize_linen_params, + _normalize_nnx_params, + _extract_params, + _normalize_params, + get_tree_structure_info, + print_structure_diff, + compare_params, + transform_nnx_params_for_comparison, +) + + +def _arr(*shape): + """Helper: float32 array of given shape, values 0..prod(shape)-1.""" + return np.arange(int(np.prod(shape)), dtype=np.float32).reshape(shape) + + +def setUpModule(): + # Mark FLAGS as parsed so FLAGS.verbose etc. are accessible without a full + # app.run(). Required flags (ckpt_path_1/2) are not needed in unit tests. + absl_flags.FLAGS.mark_as_parsed() + + +# --------------------------------------------------------------------------- +# is_rng_path +# --------------------------------------------------------------------------- + + +class TestIsRngPath(unittest.TestCase): + """Tests for is_rng_path.""" + + def test_returns_true_for_rngs(self): + self.assertTrue(is_rng_path("model/decoder/rngs/dropout")) + + def test_returns_true_for_rng(self): + self.assertTrue(is_rng_path("model/rngs/params/key")) + + def test_returns_true_case_insensitive(self): + self.assertTrue(is_rng_path("model/RNGs/state")) + self.assertTrue(is_rng_path("model/RNG/state")) + + def test_returns_false_for_normal_path(self): + self.assertFalse(is_rng_path("model/decoder/layers/kernel")) + + def test_returns_false_for_empty_string(self): + self.assertFalse(is_rng_path("")) + + +# --------------------------------------------------------------------------- +# filter_rngs +# --------------------------------------------------------------------------- + + +class TestFilterRngs(unittest.TestCase): + """Tests for filter_rngs.""" + + def test_removes_top_level_rngs_key(self): + tree = {"model": {"kernel": _arr(4)}, "rngs": {"dropout": _arr(2)}} + result = filter_rngs(tree) + self.assertNotIn("rngs", result) + self.assertIn("model", result) + + def test_removes_nested_rngs_key(self): + tree = {"model": {"kernel": _arr(4), "rngs": {"key": _arr(2)}}} + result = filter_rngs(tree) + self.assertNotIn("rngs", result["model"]) + self.assertIn("kernel", result["model"]) + + def test_keeps_empty_parent_when_only_child_is_rng(self): + # After filtering, the parent dict becomes empty and is dropped. + tree = {"model": {"rngs": {"key": _arr(2)}}} + result = filter_rngs(tree) + self.assertNotIn("model", result) + + def test_passthrough_for_non_rng_tree(self): + tree = {"params": {"kernel": _arr(4), "bias": _arr(2)}} + result = filter_rngs(tree) + self.assertEqual(set(result.keys()), {"params"}) + + def test_passthrough_for_non_dict_input(self): + arr = _arr(4) + self.assertIs(filter_rngs(arr), arr) + + +# --------------------------------------------------------------------------- +# _has_value_wrappers +# --------------------------------------------------------------------------- + + +class TestHasValueWrappers(unittest.TestCase): + """Tests for _has_value_wrappers.""" + + def test_returns_true_for_direct_value_wrapper(self): + tree = {"value": _arr(3, 4)} + self.assertTrue(_has_value_wrappers(tree)) + + def test_returns_true_for_nested_wrapper(self): + tree = {"decoder": {"kernel": {"value": _arr(2, 2)}}} + self.assertTrue(_has_value_wrappers(tree)) + + def test_returns_false_for_plain_array(self): + self.assertFalse(_has_value_wrappers(_arr(3))) + + def test_returns_false_for_multi_key_dict(self): + tree = {"value": _arr(2), "extra": _arr(2)} + self.assertFalse(_has_value_wrappers(tree)) + + def test_returns_false_for_value_key_with_non_array(self): + tree = {"value": 42} + self.assertFalse(_has_value_wrappers(tree)) + + +# --------------------------------------------------------------------------- +# _strip_value_wrappers +# --------------------------------------------------------------------------- + + +class TestStripValueWrappers(unittest.TestCase): + """Tests for _strip_value_wrappers.""" + + def test_strips_direct_wrapper(self): + arr = _arr(3, 4) + result = _strip_value_wrappers({"value": arr}) + np.testing.assert_array_equal(result, arr) + + def test_strips_nested_wrappers(self): + arr = _arr(2, 2) + tree = {"decoder": {"kernel": {"value": arr}}} + result = _strip_value_wrappers(tree) + np.testing.assert_array_equal(result["decoder"]["kernel"], arr) + + def test_passthrough_plain_array(self): + arr = _arr(4) + self.assertIs(_strip_value_wrappers(arr), arr) + + def test_handles_list(self): + arr = _arr(2) + result = _strip_value_wrappers([{"value": arr}]) + np.testing.assert_array_equal(result[0], arr) + + def test_handles_tuple(self): + arr = _arr(2) + result = _strip_value_wrappers(({"value": arr},)) + np.testing.assert_array_equal(result[0], arr) + + def test_passthrough_non_array_scalar(self): + self.assertEqual(_strip_value_wrappers(42), 42) + + +# --------------------------------------------------------------------------- +# _normalize_linen_params +# --------------------------------------------------------------------------- + + +class TestNormalizeLinenParams(unittest.TestCase): + """Tests for _normalize_linen_params.""" + + def test_removes_double_nesting(self): + inner = {"decoder": {"layers": {}}} + params = {"params": inner} + result = _normalize_linen_params(params) + self.assertIs(result, inner) + + def test_removes_double_nesting_encoder(self): + inner = {"encoder": {"layers": {}}} + params = {"params": inner} + result = _normalize_linen_params(params) + self.assertIs(result, inner) + + def test_passthrough_when_no_double_nesting(self): + params = {"decoder": {"layers": {}}} + result = _normalize_linen_params(params) + self.assertIs(result, params) + + def test_passthrough_when_inner_has_no_decoder_encoder(self): + params = {"params": {"other_key": {}}} + result = _normalize_linen_params(params) + self.assertIs(result, params) + + +# --------------------------------------------------------------------------- +# _normalize_nnx_params +# --------------------------------------------------------------------------- + + +class TestNormalizeNnxParams(unittest.TestCase): + """Tests for _normalize_nnx_params.""" + + def test_strips_value_wrappers(self): + arr = _arr(2, 3) + params = {"decoder": {"kernel": {"value": arr}}} + result = _normalize_nnx_params(params) + np.testing.assert_array_equal(result["decoder"]["kernel"], arr) + + def test_passthrough_plain_tree(self): + arr = _arr(4) + params = {"decoder": {"kernel": arr}} + result = _normalize_nnx_params(params) + np.testing.assert_array_equal(result["decoder"]["kernel"], arr) + + +# --------------------------------------------------------------------------- +# detect_format +# --------------------------------------------------------------------------- + + +class TestDetectFormat(unittest.TestCase): + """Tests for detect_format.""" + + def test_detects_nnx_via_model_key(self): + state = {"model": {"decoder": {}}, "optimizer": {}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_linen_via_double_nested_decoder(self): + state = {"params": {"params": {"decoder": {}}}} + self.assertEqual(detect_format(state), "linen") + + def test_detects_linen_via_double_nested_encoder(self): + state = {"params": {"params": {"encoder": {}}}} + self.assertEqual(detect_format(state), "linen") + + def test_detects_nnx_via_value_wrappers(self): + arr = _arr(2, 2) + state = {"params": {"decoder": {"kernel": {"value": arr}}}} + self.assertEqual(detect_format(state), "nnx") + + def test_raises_when_no_params_or_model_key(self): + with self.assertRaises(ValueError): + detect_format({"step": 0}) + + def test_raises_on_undetectable_format(self): + with self.assertRaises(ValueError): + detect_format({"params": {"unknown_key": {}}}) + + +# --------------------------------------------------------------------------- +# _extract_params +# --------------------------------------------------------------------------- + + +class TestExtractParams(unittest.TestCase): + """Tests for _extract_params.""" + + def test_extracts_linen_params(self): + params = {"params": {"decoder": {}}} + state = {"params": params, "opt_state": {}} + self.assertIs(_extract_params(state, "linen"), params) + + def test_extracts_nnx_params_from_model_key(self): + model = {"decoder": {}} + state = {"model": model, "optimizer": {}} + self.assertIs(_extract_params(state, "nnx"), model) + + def test_extracts_nnx_params_falls_back_to_params_key(self): + params = {"decoder": {}} + state = {"params": params} + self.assertIs(_extract_params(state, "nnx"), params) + + def test_returns_empty_dict_when_key_missing(self): + state = {"optimizer": {}} + result = _extract_params(state, "linen") + self.assertEqual(result, {}) + + +# --------------------------------------------------------------------------- +# _normalize_params +# --------------------------------------------------------------------------- + + +class TestNormalizeParams(unittest.TestCase): + """Tests for _normalize_params.""" + + def test_dispatches_to_linen(self): + inner = {"decoder": {}} + params = {"params": inner} + result = _normalize_params(params, "linen") + self.assertIs(result, inner) + + def test_dispatches_to_nnx(self): + arr = _arr(2, 2) + params = {"decoder": {"kernel": {"value": arr}}} + result = _normalize_params(params, "nnx") + np.testing.assert_array_equal(result["decoder"]["kernel"], arr) + + +# --------------------------------------------------------------------------- +# get_tree_structure_info +# --------------------------------------------------------------------------- + + +class TestGetTreeStructureInfo(unittest.TestCase): + """Tests for get_tree_structure_info.""" + + def test_returns_shape_and_dtype(self): + tree = {"kernel": _arr(3, 4), "bias": _arr(4)} + info = get_tree_structure_info(tree) + self.assertEqual(info["['kernel']"], ((3, 4), "float32")) + self.assertEqual(info["['bias']"], ((4,), "float32")) + + def test_handles_nested_tree(self): + tree = {"decoder": {"kernel": _arr(2, 2)}} + info = get_tree_structure_info(tree) + self.assertEqual(len(info), 1) + shapes = [v[0] for v in info.values()] + self.assertIn((2, 2), shapes) + + def test_handles_non_array_leaves(self): + tree = {"step": 5} + info = get_tree_structure_info(tree) + self.assertEqual(len(info), 1) + shape, _ = list(info.values())[0] + self.assertEqual(shape, "N/A") + + +# --------------------------------------------------------------------------- +# print_structure_diff +# --------------------------------------------------------------------------- + + +class TestPrintStructureDiff(unittest.TestCase): + """Tests for print_structure_diff.""" + + def _make_params(self, keys_and_shapes): + return {k: _arr(*s) for k, s in keys_and_shapes.items()} + + def test_returns_empty_tuples_when_identical(self): + params = self._make_params({"kernel": (4, 4), "bias": (4,)}) + with patch("sys.stdout", new_callable=io.StringIO): + only1, only2, shape_mm, dtype_mm = print_structure_diff(params, params) + self.assertEqual(only1, []) + self.assertEqual(only2, []) + self.assertEqual(shape_mm, []) + self.assertEqual(dtype_mm, []) + + def test_detects_key_only_in_first(self): + p1 = self._make_params({"kernel": (4, 4), "bias": (4,)}) + p2 = self._make_params({"kernel": (4, 4)}) + with patch("sys.stdout", new_callable=io.StringIO): + only1, only2, _, _ = print_structure_diff(p1, p2) + self.assertEqual(len(only1), 1) + self.assertEqual(only2, []) + + def test_detects_key_only_in_second(self): + p1 = self._make_params({"kernel": (4, 4)}) + p2 = self._make_params({"kernel": (4, 4), "bias": (4,)}) + with patch("sys.stdout", new_callable=io.StringIO): + only1, only2, _, _ = print_structure_diff(p1, p2) + self.assertEqual(only1, []) + self.assertEqual(len(only2), 1) + + def test_detects_shape_mismatch(self): + p1 = {"kernel": _arr(4, 4)} + p2 = {"kernel": _arr(4, 8)} + with patch("sys.stdout", new_callable=io.StringIO): + _, _, shape_mm, _ = print_structure_diff(p1, p2) + self.assertEqual(len(shape_mm), 1) + + def test_detects_dtype_mismatch(self): + p1 = {"kernel": np.zeros((4,), dtype=np.float32)} + p2 = {"kernel": np.zeros((4,), dtype=np.float16)} + with patch("sys.stdout", new_callable=io.StringIO): + _, _, _, dtype_mm = print_structure_diff(p1, p2) + self.assertEqual(len(dtype_mm), 1) + + +# --------------------------------------------------------------------------- +# compare_params +# --------------------------------------------------------------------------- + + +class TestCompareParams(unittest.TestCase): + """Tests for compare_params.""" + + def test_returns_true_for_identical_params(self): + params = {"kernel": _arr(4, 4), "bias": _arr(4)} + with patch("builtins.print"): + result = compare_params(params, params) + self.assertTrue(result) + + def test_returns_false_for_different_structures(self): + p1 = {"kernel": _arr(4, 4)} + p2 = {"kernel": _arr(4, 4), "bias": _arr(4)} + with patch("builtins.print"): + result = compare_params(p1, p2) + self.assertFalse(result) + + def test_returns_false_for_shape_mismatch(self): + p1 = {"kernel": _arr(4, 4)} + p2 = {"kernel": _arr(4, 8)} + with patch("builtins.print"): + result = compare_params(p1, p2) + self.assertFalse(result) + + def test_returns_false_for_dtype_mismatch(self): + p1 = {"kernel": np.zeros((4,), dtype=np.float32)} + p2 = {"kernel": np.zeros((4,), dtype=np.float16)} + with patch("builtins.print"): + result = compare_params(p1, p2) + self.assertFalse(result) + + def test_value_comparison_passes_when_equal(self): + arr = _arr(4) + with patch("builtins.print"): + result = compare_params({"w": arr}, {"w": arr.copy()}, compare_values=True) + self.assertTrue(result) + + def test_value_comparison_fails_when_different(self): + p1 = {"w": np.array([1.0, 2.0], dtype=np.float32)} + p2 = {"w": np.array([1.0, 9.0], dtype=np.float32)} + with patch("builtins.print"): + result = compare_params(p1, p2, compare_values=True, atol=1e-5, rtol=1e-5) + self.assertFalse(result) + + def test_value_comparison_passes_within_tolerance(self): + p1 = {"w": np.array([1.0], dtype=np.float32)} + p2 = {"w": np.array([1.0 + 1e-7], dtype=np.float32)} + with patch("builtins.print"): + result = compare_params(p1, p2, compare_values=True, atol=1e-5, rtol=1e-5) + self.assertTrue(result) + + def test_verbose_mode_does_not_raise(self): + params = {"kernel": _arr(2, 2)} + with patch("builtins.print"): + result = compare_params(params, params, verbose=True, compare_values=True) + self.assertTrue(result) + + def test_nested_params(self): + params = {"decoder": {"kernel": _arr(4, 4), "bias": _arr(4)}} + with patch("builtins.print"): + result = compare_params(params, params) + self.assertTrue(result) + + +# --------------------------------------------------------------------------- +# transform_nnx_params_for_comparison +# --------------------------------------------------------------------------- + + +class TestTransformNnxParamsForComparison(unittest.TestCase): + """Tests for transform_nnx_params_for_comparison.""" + + def test_transposes_layer_array(self): + # Shape (num_layers=3, d=4) -> (d=4, num_layers=3) + arr = _arr(3, 4) + tree = {"layers": {"kernel": arr}} + with patch("builtins.print"): + result = transform_nnx_params_for_comparison(tree) + self.assertEqual(result["layers"]["kernel"].shape, (4, 3)) + + def test_does_not_transpose_non_layer_array(self): + arr = _arr(3, 4) + tree = {"embedding": arr} + with patch("builtins.print"): + result = transform_nnx_params_for_comparison(tree) + self.assertEqual(result["embedding"].shape, (3, 4)) + + def test_does_not_transpose_1d_layer_array(self): + arr = _arr(4) + tree = {"layers": {"bias": arr}} + with patch("builtins.print"): + result = transform_nnx_params_for_comparison(tree) + self.assertEqual(result["layers"]["bias"].shape, (4,)) + + def test_transposes_higher_rank_layer_array(self): + # Shape (num_layers=2, d1=3, d2=5) -> (d1=3, num_layers=2, d2=5) + arr = _arr(2, 3, 5) + tree = {"layers": {"kernel": arr}} + with patch("builtins.print"): + result = transform_nnx_params_for_comparison(tree) + self.assertEqual(result["layers"]["kernel"].shape, (3, 2, 5)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/linen_nnx_converter_test.py b/tests/unit/linen_nnx_converter_test.py new file mode 100644 index 0000000000..808990f8cf --- /dev/null +++ b/tests/unit/linen_nnx_converter_test.py @@ -0,0 +1,869 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for linen_nnx_converter utilities.""" + +import unittest +import numpy as np +from unittest.mock import MagicMock, patch + +from maxtext.checkpoint_conversion.linen_nnx_converter import ( + detect_format, + _has_value_wrappers, + _strip_value_wrappers, + _add_value_wrappers, + _transpose_layers_axes, + _stack_layers, + convert_linen_to_nnx, + convert_nnx_to_linen, + _convert_opt_state_linen_to_nnx, + _convert_opt_state_nnx_to_linen, + load_checkpoint, + save_checkpoint, + main, +) + + +def _make_array(*shape): + """Helper to create a numpy array with given shape.""" + return np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + + +class TestDetectFormat(unittest.TestCase): + """Tests for the detect_format function.""" + + def test_raises_when_no_params_key(self): + with self.assertRaises(ValueError): + detect_format({"step": 0}) + + def test_detects_nnx_format_via_model_key(self): + # NNX: top-level "model" key + state = {"model": {"decoder": {"layers": {}}}, "optimizer": {}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_linen_format_double_nested(self): + state = {"params": {"params": {"decoder": {"layers": {}}}}} + self.assertEqual(detect_format(state), "linen") + + def test_detects_nnx_format_single_nested_with_value_wrappers(self): + # Old NNX format: params/decoder with {value:} wrappers + arr = _make_array(2, 2) + state = {"params": {"decoder": {"kernel": {"value": arr}}}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_linen_via_encoder(self): + state = {"params": {"params": {"encoder": {"layers": {}}}}} + self.assertEqual(detect_format(state), "linen") + + def test_detects_nnx_via_encoder_with_value_wrappers(self): + arr = _make_array(2, 2) + state = {"params": {"encoder": {"kernel": {"value": arr}}}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_nnx_via_optimizer_key(self): + arr = _make_array(2, 2) + state = {"params": {"something": arr}, "optimizer": {"step": 0}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_linen_via_opt_state(self): + arr = _make_array(2, 2) + state = { + "params": {"something": arr}, + "opt_state": {"params": {"mu": {"decoder": {"kernel": arr}}}}, + } + self.assertEqual(detect_format(state), "linen") + + def test_detects_nnx_via_optimizer_over_opt_state(self): + # "optimizer" key takes precedence for NNX detection + arr = _make_array(2, 2) + state = { + "params": {"something": arr}, + "optimizer": {"step": 0, "opt_state": {}}, + } + self.assertEqual(detect_format(state), "nnx") + + def test_raises_on_undetectable_format(self): + state = {"params": {"some_unknown_key": 42}} + with self.assertRaises(ValueError): + detect_format(state) + + +class TestHasValueWrappers(unittest.TestCase): + """Tests for the _has_value_wrappers helper.""" + + def test_returns_true_for_value_wrapper(self): + arr = _make_array(2, 2) + self.assertTrue(_has_value_wrappers({"value": arr})) + + def test_returns_true_for_nested_value_wrapper(self): + arr = _make_array(2, 2) + self.assertTrue(_has_value_wrappers({"mu": {"value": arr}})) + + def test_returns_false_for_plain_array(self): + # A plain array is not a {"value": ...} wrapper dict + self.assertFalse(_has_value_wrappers(_make_array(2, 2))) + + def test_returns_false_for_multi_key_dict(self): + arr = _make_array(2, 2) + self.assertFalse(_has_value_wrappers({"value": arr, "extra": arr})) + + def test_returns_false_for_non_array_value(self): + self.assertFalse(_has_value_wrappers({"value": "string"})) + + +class TestStripValueWrappers(unittest.TestCase): + """Tests for the _strip_value_wrappers helper.""" + + def test_strips_single_wrapper(self): + arr = _make_array(3, 4) + result = _strip_value_wrappers({"value": arr}) + np.testing.assert_array_equal(result, arr) + + def test_strips_nested_wrappers(self): + arr = _make_array(2, 2) + wrapped = {"decoder": {"layers": {"kernel": {"value": arr}}}} + stripped = _strip_value_wrappers(wrapped) + np.testing.assert_array_equal(stripped["decoder"]["layers"]["kernel"], arr) + + def test_passes_through_plain_array(self): + arr = _make_array(2, 3) + result = _strip_value_wrappers(arr) + np.testing.assert_array_equal(result, arr) + + def test_handles_list_and_tuple(self): + arr = _make_array(2) + result_list = _strip_value_wrappers([{"value": arr}]) + result_tuple = _strip_value_wrappers(({"value": arr},)) + np.testing.assert_array_equal(result_list[0], arr) + np.testing.assert_array_equal(result_tuple[0], arr) + + def test_passes_through_non_array_value(self): + # A dict with key "value" but scalar content should not be unwrapped + d = {"value": 42} + result = _strip_value_wrappers(d) + self.assertEqual(result, d) + + +class TestAddValueWrappers(unittest.TestCase): + """Tests for the _add_value_wrappers helper.""" + + def test_wraps_array(self): + arr = _make_array(3, 4) + result = _add_value_wrappers(arr) + self.assertIsInstance(result, dict) + self.assertIn("value", result) + np.testing.assert_array_equal(result["value"], arr) + + def test_wraps_nested_arrays(self): + arr = _make_array(2, 2) + nested = {"decoder": {"layers": {"kernel": arr}}} + wrapped = _add_value_wrappers(nested) + self.assertEqual(set(wrapped["decoder"]["layers"]["kernel"].keys()), {"value"}) + np.testing.assert_array_equal(wrapped["decoder"]["layers"]["kernel"]["value"], arr) + + def test_idempotent_on_already_wrapped(self): + arr = _make_array(2) + already_wrapped = {"value": arr} + result = _add_value_wrappers(already_wrapped) + # Should not double-wrap + self.assertEqual(set(result.keys()), {"value"}) + np.testing.assert_array_equal(result["value"], arr) + + def test_handles_list_and_tuple(self): + arr = _make_array(2) + result_list = _add_value_wrappers([arr]) + result_tuple = _add_value_wrappers((arr,)) + self.assertEqual(set(result_list[0].keys()), {"value"}) + self.assertEqual(set(result_tuple[0].keys()), {"value"}) + + def test_passes_through_non_array_scalars(self): + result = _add_value_wrappers(42) + self.assertEqual(result, 42) + result_str = _add_value_wrappers("text") + self.assertEqual(result_str, "text") + + +class TestTransposeLayersAxes(unittest.TestCase): + """Tests for the _transpose_layers_axes helper.""" + + def test_noop_when_same_axis(self): + arr = _make_array(4, 2, 3) + result = _transpose_layers_axes(arr, src_axis=0, dst_axis=0) + np.testing.assert_array_equal(result, arr) + + def test_transposes_axis_0_to_1(self): + arr = _make_array(4, 2, 3) + result = _transpose_layers_axes(arr, src_axis=0, dst_axis=1) + self.assertEqual(result.shape, (2, 4, 3)) + + def test_transposes_axis_1_to_0(self): + arr = _make_array(2, 4, 3) + result = _transpose_layers_axes(arr, src_axis=1, dst_axis=0) + self.assertEqual(result.shape, (4, 2, 3)) + + def test_transposes_nested_dict(self): + arr = _make_array(4, 2, 3) + tree = {"decoder": {"layers": {"kernel": arr}}} + result = _transpose_layers_axes(tree, src_axis=0, dst_axis=1) + self.assertEqual(result["decoder"]["layers"]["kernel"].shape, (2, 4, 3)) + + def test_passes_through_1d_array(self): + arr = _make_array(5) + result = _transpose_layers_axes(arr, src_axis=0, dst_axis=1) + # 1D array has no axis 1, should be returned unchanged + np.testing.assert_array_equal(result, arr) + + def test_handles_list(self): + arr = _make_array(4, 2, 3) + result = _transpose_layers_axes([arr], src_axis=0, dst_axis=1) + self.assertIsInstance(result, list) + self.assertEqual(result[0].shape, (2, 4, 3)) + + def test_handles_tuple(self): + arr = _make_array(4, 2, 3) + result = _transpose_layers_axes((arr,), src_axis=0, dst_axis=1) + self.assertIsInstance(result, tuple) + self.assertEqual(result[0].shape, (2, 4, 3)) + + +class TestStackLayers(unittest.TestCase): + """Tests for the _stack_layers helper.""" + + def test_stacks_individual_layers(self): + arr0 = _make_array(3, 4) + arr1 = _make_array(3, 4) + decoder = { + "layers_0": {"mlp": {"kernel": arr0}}, + "layers_1": {"mlp": {"kernel": arr1}}, + } + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + self.assertIn("layers", result) + stacked = result["layers"]["mlp"]["kernel"] + self.assertEqual(stacked.shape, (2, 3, 4)) + np.testing.assert_array_equal(stacked[0], arr0) + np.testing.assert_array_equal(stacked[1], arr1) + + def test_noop_when_no_layer_pattern(self): + arr = _make_array(3, 4) + decoder = {"layers": {"mlp": {"kernel": arr}}} + result, was_stacked = _stack_layers(decoder) + self.assertFalse(was_stacked) + self.assertIs(result, decoder) + + def test_preserves_non_layer_keys(self): + norm_weight = _make_array(4) + arr0 = _make_array(3, 4) + decoder = { + "layers_0": {"mlp": {"kernel": arr0}}, + "final_norm": {"scale": norm_weight}, + } + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + self.assertIn("final_norm", result) + np.testing.assert_array_equal(result["final_norm"]["scale"], norm_weight) + + def test_stacks_three_layers(self): + arrays = [_make_array(2, 2) for _ in range(3)] + decoder = {f"layers_{i}": {"w": arrays[i]} for i in range(3)} + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + stacked = result["layers"]["w"] + self.assertEqual(stacked.shape, (3, 2, 2)) + + def test_non_array_non_dict_leaf(self): + # Scalar leaf — stack_arrays returns first element + decoder = {"layers_0": {"count": 1}, "layers_1": {"count": 2}} + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + self.assertIn("layers", result) + + def test_with_missing_key_in_some_layers(self): + arr = _make_array(3, 4) + decoder = { + "layers_0": {"mlp": {"kernel": arr, "bias": arr}}, + "layers_1": {"mlp": {"kernel": arr}}, # no "bias" + } + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + self.assertIn("kernel", result["layers"]["mlp"]) + + +class TestConvertLinenToNNX(unittest.TestCase): + """Tests for the convert_linen_to_nnx function.""" + + def _make_linen_state(self, add_opt_state=False): + """Creates a minimal Linen checkpoint structure.""" + arr = _make_array(2, 4, 3) + state = { + "step": 10, + "params": { + "params": { + "decoder": { + "layers": {"mlp": {"wi": {"kernel": arr}}}, + "decoder_norm": {"scale": _make_array(4)}, + } + } + }, + } + if add_opt_state: + state["opt_state"] = {"params": {"mu": {"decoder": {"layers": {"kernel": arr}}}}} + return state + + def test_converts_step_under_optimizer(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + self.assertEqual(result["optimizer"]["step"], 10) + + def test_step_not_at_top_level(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + self.assertNotIn("step", result) + + def test_params_stored_under_model_key(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + self.assertIn("model", result) + self.assertNotIn("params", result) + + def test_removes_double_nesting(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + # model should have 'decoder' directly, not 'params.decoder' + self.assertIn("decoder", result["model"]) + self.assertNotIn("params", result["model"]) + + def test_adds_value_wrappers(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + # Arrays should be wrapped in {"value": array} + kernel = result["model"]["decoder"]["layers"]["mlp"]["wi"]["kernel"] + self.assertIsInstance(kernel, dict) + self.assertIn("value", kernel) + + def test_converts_opt_state_under_optimizer(self): + state = self._make_linen_state(add_opt_state=True) + result = convert_linen_to_nnx(state) + self.assertIn("opt_state", result["optimizer"]) + # Linen opt_state had nested 'params' level; it should be removed + self.assertNotIn("params", result["optimizer"]["opt_state"]) + + def test_no_step_produces_no_optimizer_step(self): + arr = _make_array(2, 4, 3) + state = {"params": {"params": {"decoder": {"layers": {"kernel": arr}}}}} + result = convert_linen_to_nnx(state) + self.assertNotIn("step", result) + self.assertIn("model", result) + + def test_no_double_nesting_still_converts(self): + # Linen state without double-nesting (unusual but handled) + arr = _make_array(2, 4) + state = {"params": {"decoder": {"layers": {"kernel": arr}}}} + result = convert_linen_to_nnx(state) + self.assertIn("decoder", result["model"]) + + def test_no_params_key_only_step(self): + state = {"step": 3} + result = convert_linen_to_nnx(state) + self.assertEqual(result["optimizer"]["step"], 3) + self.assertNotIn("model", result) + + def test_with_per_layer_params_stacked_and_transposed(self): + # Linen checkpoint with layers_0, layers_1 → stacked + transposed to axis 1 + arr = _make_array(3, 4) + state = { + "params": { + "params": { + "decoder": { + "layers_0": {"mlp": {"kernel": arr}}, + "layers_1": {"mlp": {"kernel": arr}}, + } + } + } + } + result = convert_linen_to_nnx(state) + stacked = result["model"]["decoder"]["layers"]["mlp"]["kernel"]["value"] + # Original (3, 4) stacked → (2, 3, 4), transposed to (3, 2, 4) + self.assertEqual(stacked.shape, (3, 2, 4)) + + +class TestConvertNNXToLinen(unittest.TestCase): + """Tests for the convert_nnx_to_linen function.""" + + def _make_nnx_state(self, add_opt_state=False): + """Creates an NNX checkpoint with 'model' and 'optimizer' keys. + + Uses 'attention' (not 'layers') as the sub-key so _convert_layers_to_linen_format + does not try to unstack the data. + """ + arr = _make_array(2, 4, 3) + state = { + "model": { + "decoder": { + "attention": {"wi": {"kernel": {"value": arr}}}, + "decoder_norm": {"scale": {"value": _make_array(4)}}, + } + }, + "optimizer": {"step": 5}, + } + if add_opt_state: + state["optimizer"]["opt_state"] = { + "mu": {"decoder": {"layers": {"kernel": {"value": arr}}}}, + "nu": {"decoder": {"layers": {"kernel": {"value": arr}}}}, + } + return state + + def test_converts_step(self): + state = self._make_nnx_state() + result = convert_nnx_to_linen(state) + self.assertEqual(result["step"], 5) + + def test_adds_double_nesting(self): + state = self._make_nnx_state() + result = convert_nnx_to_linen(state) + self.assertIn("params", result["params"]) + self.assertIn("decoder", result["params"]["params"]) + + def test_strips_value_wrappers(self): + state = self._make_nnx_state() + result = convert_nnx_to_linen(state) + kernel = result["params"]["params"]["decoder"]["attention"]["wi"]["kernel"] + self.assertIsInstance(kernel, np.ndarray) + + def test_converts_opt_state(self): + state = self._make_nnx_state(add_opt_state=True) + result = convert_nnx_to_linen(state) + self.assertIn("opt_state", result) + # mu/nu should get a 'params' level added + self.assertIn("params", result["opt_state"]["mu"]) + self.assertIn("params", result["opt_state"]["nu"]) + + def test_backward_compat_params_key(self): + # Old NNX format: "params" instead of "model", top-level "step" + arr = _make_array(2, 4, 3) + state = { + "step": 5, + "params": { + "decoder": { + "layers": {"mlp": {"wi": {"kernel": {"value": arr}}}}, + "decoder_norm": {"scale": {"value": _make_array(4)}}, + } + }, + } + result = convert_nnx_to_linen(state) + self.assertEqual(result["step"], 5) + self.assertIn("decoder", result["params"]["params"]) + + def test_no_step(self): + arr = _make_array(2, 4) + state = {"model": {"decoder": {"layers": {"kernel": {"value": arr}}}}} + result = convert_nnx_to_linen(state) + self.assertNotIn("step", result) + self.assertIn("params", result) + + +class TestRoundTrip(unittest.TestCase): + """Verifies that linen->nnx->linen round-trip preserves data.""" + + def test_linen_to_nnx_to_linen(self): + # Use "attention" (not "layers") so _convert_layers_to_linen_format + # does not try to unstack the dict as a stacked-layers tensor. + arr = _make_array(2, 4, 3) + linen_state = { + "step": 42, + "params": { + "params": { + "decoder": { + "attention": {"mlp": {"wi": {"kernel": arr}}}, + "norm": {"scale": _make_array(4)}, + } + } + }, + } + nnx_state = convert_linen_to_nnx(linen_state) + recovered_state = convert_nnx_to_linen(nnx_state) + + self.assertEqual(recovered_state["step"], 42) + recovered_kernel = recovered_state["params"]["params"]["decoder"]["attention"]["mlp"]["wi"]["kernel"] + np.testing.assert_array_equal(recovered_kernel, arr) + + def test_nnx_to_linen_to_nnx(self): + arr = _make_array(2, 4, 3) + nnx_state = { + "model": { + "decoder": { + "layers": {"mlp": {"wi": {"kernel": {"value": arr}}}}, + } + }, + "optimizer": {"step": 7}, + } + linen_state = convert_nnx_to_linen(nnx_state) + recovered_state = convert_linen_to_nnx(linen_state) + + self.assertEqual(recovered_state["optimizer"]["step"], 7) + recovered_kernel = recovered_state["model"]["decoder"]["layers"]["mlp"]["wi"]["kernel"] + self.assertIn("value", recovered_kernel) + np.testing.assert_array_equal(recovered_kernel["value"], arr) + + +class TestConvertOptState(unittest.TestCase): + """Tests for the _convert_opt_state_linen_to_nnx and _convert_opt_state_nnx_to_linen helpers.""" + + def test_linen_to_nnx_removes_params_level(self): + arr = _make_array(3, 4) + opt_state = {"mu": {"params": {"decoder": {"kernel": arr}}}} + result = _convert_opt_state_linen_to_nnx(opt_state) + # 'params' key removed; decoder promoted + self.assertNotIn("params", result["mu"]) + self.assertIn("decoder", result["mu"]) + # Arrays are plain (no value wrappers in NNX opt_state) + np.testing.assert_array_equal(result["mu"]["decoder"]["kernel"], arr) + + def test_linen_to_nnx_handles_list_input(self): + arr = _make_array(2, 2) + opt_state = [{"decoder": {"kernel": arr}}, {"decoder": {"kernel": arr}}] + result = _convert_opt_state_linen_to_nnx(opt_state) + self.assertIsInstance(result, list) + np.testing.assert_array_equal(result[0]["decoder"]["kernel"], arr) + + def test_linen_to_nnx_handles_tuple_input(self): + arr = _make_array(2, 2) + opt_state = ({"decoder": {"kernel": arr}},) + result = _convert_opt_state_linen_to_nnx(opt_state) + self.assertIsInstance(result, tuple) + np.testing.assert_array_equal(result[0]["decoder"]["kernel"], arr) + + def test_linen_to_nnx_handles_non_array_non_dict(self): + # Scalars should be passed through unchanged + result = _convert_opt_state_linen_to_nnx(42) + self.assertEqual(result, 42) + + def test_linen_to_nnx_params_key_with_non_dict_value(self): + # When k == "params" but converted value is not a dict, store it as-is + opt_state = {"params": 99} + result = _convert_opt_state_linen_to_nnx(opt_state) + self.assertIn("params", result) + self.assertEqual(result["params"], 99) + + def test_nnx_to_linen_adds_params_level_and_strips(self): + arr = _make_array(3, 4) + opt_state = { + "mu": {"decoder": {"kernel": {"value": arr}}}, + "nu": {"decoder": {"kernel": {"value": arr}}}, + } + result = _convert_opt_state_nnx_to_linen(opt_state) + # mu/nu should have 'params' nested inside + self.assertIn("params", result["mu"]) + self.assertIn("params", result["nu"]) + # Arrays unwrapped + kernel = result["mu"]["params"]["decoder"]["kernel"] + np.testing.assert_array_equal(kernel, arr) + + def test_nnx_to_linen_handles_list_input(self): + arr = _make_array(2, 2) + opt_state = [{"decoder": {"kernel": {"value": arr}}}] + result = _convert_opt_state_nnx_to_linen(opt_state) + self.assertIsInstance(result, list) + np.testing.assert_array_equal(result[0]["decoder"]["kernel"], arr) + + def test_nnx_to_linen_handles_tuple_input(self): + arr = _make_array(2, 2) + opt_state = ({"decoder": {"kernel": {"value": arr}}},) + result = _convert_opt_state_nnx_to_linen(opt_state) + self.assertIsInstance(result, tuple) + np.testing.assert_array_equal(result[0]["decoder"]["kernel"], arr) + + def test_nnx_to_linen_passes_through_scalars(self): + result = _convert_opt_state_nnx_to_linen("scalar_string") + self.assertEqual(result, "scalar_string") + + def test_nnx_to_linen_value_wrapper_with_non_array_inner(self): + # {"value": scalar} should NOT be unwrapped (only arrays get unwrapped) + d = {"value": 42} + result = _convert_opt_state_nnx_to_linen(d) + self.assertIn("value", result) + self.assertEqual(result["value"], 42) + + +class TestConvertLinenToNNXEncoder(unittest.TestCase): + """Tests encoder path in convert_linen_to_nnx.""" + + def test_converts_encoder_params(self): + arr = _make_array(2, 4, 3) + state = { + "params": { + "params": { + "encoder": { + "layers": {"mlp": {"wi": {"kernel": arr}}}, + } + } + } + } + result = convert_linen_to_nnx(state) + self.assertIn("encoder", result["model"]) + kernel = result["model"]["encoder"]["layers"]["mlp"]["wi"]["kernel"] + self.assertIsInstance(kernel, dict) + self.assertIn("value", kernel) + + def test_converts_encoder_with_per_layer_stacking(self): + arr = _make_array(3, 4) + state = { + "params": { + "params": { + "encoder": { + "layers_0": {"mlp": {"kernel": arr}}, + "layers_1": {"mlp": {"kernel": arr}}, + } + } + } + } + result = convert_linen_to_nnx(state) + stacked = result["model"]["encoder"]["layers"]["mlp"]["kernel"]["value"] + # Stacked at axis 0 → (2, 3, 4), then transposed to (3, 2, 4) + self.assertEqual(stacked.shape, (3, 2, 4)) + + +class TestAdditionalEdgeCases(unittest.TestCase): + """Covers remaining edge cases.""" + + def test_detect_format_params_has_params_but_no_decoder_encoder(self): + # params["params"] exists but inner has no decoder/encoder -> falls through + # no optimizer/opt_state -> should raise + state = {"params": {"params": {"some_other_key": {}}}} + with self.assertRaises(ValueError): + detect_format(state) + + def test_detect_format_opt_state_returns_linen(self): + # Any state with "opt_state" (but no "model"/"optimizer") detects as linen + arr = _make_array(2) + state = { + "params": {"something": arr}, + "opt_state": {"mu": {"decoder": {"kernel": arr}}}, + } + self.assertEqual(detect_format(state), "linen") + + def test_add_value_wrappers_value_key_with_non_array(self): + # {"value": "text"} is not a wrapper (inner is not an array), recurse normally + d = {"value": "not_an_array"} + result = _add_value_wrappers(d) + self.assertEqual(result, {"value": "not_an_array"}) + + def test_convert_nnx_to_linen_no_step(self): + arr = _make_array(2, 4) + state = {"model": {"decoder": {"layers": {"kernel": {"value": arr}}}}} + result = convert_nnx_to_linen(state) + self.assertNotIn("step", result) + self.assertIn("params", result) + + def test_convert_nnx_to_linen_already_has_params_nesting(self): + arr = _make_array(2, 4) + state = {"params": {"params": {"decoder": {"layers": {"kernel": {"value": arr}}}}}} + result = convert_nnx_to_linen(state) + self.assertIn("params", result) + + def test_convert_nnx_to_linen_no_params_key(self): + state = {"optimizer": {"step": 8}} + result = convert_nnx_to_linen(state) + self.assertEqual(result["step"], 8) + self.assertNotIn("params", result) + + +class TestLoadCheckpoint(unittest.TestCase): + """Tests for load_checkpoint with mocked orbax/epath.""" + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.ocp") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.epath") + def test_load_checkpoint_calls_checkpointer_and_returns_state(self, mock_epath, mock_ocp): + arr = _make_array(2, 2) + expected_state = {"params": arr, "step": 0} + + mock_path = MagicMock() + mock_epath.Path.return_value = mock_path + + mock_metadata = MagicMock() + mock_metadata.item_metadata.tree = {"params": arr} + + mock_ckptr = MagicMock() + mock_ckptr.metadata.return_value = mock_metadata + mock_ckptr.restore.return_value = expected_state + mock_ocp.Checkpointer.return_value = mock_ckptr + mock_ocp.ArrayRestoreArgs.return_value = MagicMock() + + result = load_checkpoint("/tmp/test_ckpt") + + mock_epath.Path.assert_called_once_with("/tmp/test_ckpt") + mock_ocp.Checkpointer.assert_called_once() + mock_ckptr.metadata.assert_called_once_with(mock_path) + mock_ckptr.restore.assert_called_once() + self.assertEqual(result, expected_state) + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.ocp") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.epath") + def test_load_checkpoint_with_empty_tree_metadata(self, mock_epath, mock_ocp): + expected_state = {"step": 5} + + mock_path = MagicMock() + mock_epath.Path.return_value = mock_path + + mock_metadata = MagicMock() + mock_metadata.item_metadata.tree = {} + + mock_ckptr = MagicMock() + mock_ckptr.metadata.return_value = mock_metadata + mock_ckptr.restore.return_value = expected_state + mock_ocp.Checkpointer.return_value = mock_ckptr + + result = load_checkpoint("/tmp/empty_ckpt") + + self.assertEqual(result["step"], 5) + + +class TestSaveCheckpoint(unittest.TestCase): + """Tests for save_checkpoint with mocked orbax/epath.""" + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.ocp") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.epath") + def test_save_checkpoint_creates_dir_and_saves(self, mock_epath, mock_ocp): + state = {"params": _make_array(2, 2), "step": 1} + + mock_path = MagicMock() + mock_epath.Path.return_value = mock_path + + mock_ckptr = MagicMock() + mock_ocp.PyTreeCheckpointer.return_value = mock_ckptr + + save_checkpoint(state, "/tmp/output") + + mock_epath.Path.assert_called_once_with("/tmp/output") + mock_path.mkdir.assert_called_once_with(exist_ok=True, parents=True) + mock_ocp.PyTreeCheckpointer.assert_called_once() + mock_ckptr.save.assert_called_once_with(mock_path, state, force=True) + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.ocp") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.epath") + def test_save_checkpoint_passes_state_unchanged(self, mock_epath, mock_ocp): + state = {"step": 99, "params": {"decoder": {}}} + + mock_path = MagicMock() + mock_epath.Path.return_value = mock_path + mock_ckptr = MagicMock() + mock_ocp.PyTreeCheckpointer.return_value = mock_ckptr + + save_checkpoint(state, "/tmp/out2") + + call_args = mock_ckptr.save.call_args + self.assertIs(call_args[0][1], state) + + +class TestMain(unittest.TestCase): + """Tests for the main() CLI entry point.""" + + def _run_main(self, argv): + with patch("sys.argv", ["prog"] + argv): + main() + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint") + def test_main_explicit_linen_to_nnx(self, mock_load, mock_save): + arr = _make_array(2, 4, 3) + mock_load.return_value = { + "step": 1, + "params": {"params": {"decoder": {"layers": {"kernel": arr}}}}, + } + self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=linen_to_nnx"]) + mock_load.assert_called_once_with("/src") + mock_save.assert_called_once() + saved_state = mock_save.call_args[0][0] + # NNX format: decoder at top level of model + self.assertIn("decoder", saved_state["model"]) + self.assertEqual(mock_save.call_args[0][1], "/dst") + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint") + def test_main_explicit_nnx_to_linen(self, mock_load, mock_save): + arr = _make_array(2, 4, 3) + mock_load.return_value = { + "model": {"decoder": {"layers": {"kernel": {"value": arr}}}}, + "optimizer": {"step": 2}, + } + self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=nnx_to_linen"]) + mock_load.assert_called_once_with("/src") + mock_save.assert_called_once() + saved_state = mock_save.call_args[0][0] + # Linen format: double nesting + self.assertIn("params", saved_state["params"]) + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint") + def test_main_auto_detects_linen_converts_to_nnx(self, mock_load, mock_save): + arr = _make_array(2, 4, 3) + mock_load.return_value = { + "step": 3, + "params": {"params": {"decoder": {"layers": {"kernel": arr}}}}, + } + self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=auto"]) + mock_save.assert_called_once() + saved_state = mock_save.call_args[0][0] + # Auto-detected linen → NNX format: model key + self.assertIn("decoder", saved_state["model"]) + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint") + def test_main_auto_detects_nnx_converts_to_linen(self, mock_load, mock_save): + arr = _make_array(2, 4, 3) + mock_load.return_value = { + "model": {"decoder": {"layers": {"kernel": {"value": arr}}}}, + "optimizer": {"step": 4}, + } + self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=auto"]) + mock_save.assert_called_once() + saved_state = mock_save.call_args[0][0] + # Auto-detected nnx → Linen format + self.assertIn("params", saved_state["params"]) + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint") + def test_main_default_direction_is_auto(self, mock_load, mock_save): + arr = _make_array(2, 4, 3) + mock_load.return_value = { + "params": {"params": {"decoder": {"layers": {"kernel": arr}}}}, + } + # No --direction arg -> defaults to "auto" + self._run_main(["--source_path=/src", "--target_path=/dst"]) + mock_save.assert_called_once() + + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint") + @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint") + def test_main_scan_layers_false(self, mock_load, mock_save): + arr = _make_array(3, 4) + mock_load.return_value = { + "params": { + "params": { + "decoder": { + "layers_0": {"mlp": {"kernel": arr}}, + "layers_1": {"mlp": {"kernel": arr}}, + } + } + } + } + self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=linen_to_nnx", "--no-scan_layers"]) + saved_state = mock_save.call_args[0][0] + # With scan_layers=False: integer-keyed layers/N + layers = saved_state["model"]["decoder"]["layers"] + self.assertIsInstance(layers, dict) + self.assertTrue(all(k.isdigit() for k in layers.keys())) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/run_sharding_dump.py b/tests/utils/run_sharding_dump.py index 7d3156fe00..62c71a9b5b 100644 --- a/tests/utils/run_sharding_dump.py +++ b/tests/utils/run_sharding_dump.py @@ -59,9 +59,12 @@ flags.DEFINE_string("topology", None, "Specific topology to dump.") flags.DEFINE_string("num_slice", None, "Specific number of slices to dump.") flags.DEFINE_string("custom_mesh_and_rule", None, "Specific custom_mesh_and_rule to dump.") +flags.DEFINE_bool("pure_nnx", False, "Use pure NNX model.") -def run_single_dump(model_name: str, topology: str, num_slice: str, custom_mesh_and_rule: str, overrides: tuple) -> None: +def run_single_dump( + model_name: str, topology: str, num_slice: str, custom_mesh_and_rule: str, overrides: tuple, pure_nnx: bool = False +) -> None: """Generate sharding json file for one specific model, topology, slice and rule.""" args = [ "python3", @@ -79,6 +82,8 @@ def run_single_dump(model_name: str, topology: str, num_slice: str, custom_mesh_ args.append(f"custom_mesh_and_rule={custom_mesh_and_rule}") if overrides: args.extend(overrides) + if pure_nnx: + args.append("pure_nnx=true") subprocess.run(args, check=True) @@ -117,7 +122,7 @@ def main(argv: Sequence[str]) -> None: print(" -> Sharding files already exist. Regenerating to overwrite.") try: - run_single_dump(model_name, topology, str(num_slice), custom_mesh_and_rule, overrides) + run_single_dump(model_name, topology, str(num_slice), custom_mesh_and_rule, overrides, pure_nnx=FLAGS.pure_nnx) except subprocess.CalledProcessError: print(f"!!! FAILED: {model_name} {topology} {num_slice} {custom_mesh_and_rule} overrides={overrides}") From 4f7763a5a0978e78f7599a5e8b81b07ffed77732 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 28 Apr 2026 21:17:04 +0000 Subject: [PATCH 3/4] NNX: correctness fixes, enable feature paths, and vocab tiling on NNX MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bug fixes (run as no-op while pure_nnx=False stays default): - nnx_wrappers.py: add _refresh_variable_trace_state + is_linen_initializing; call from ToLinen after nnx.update to fix "Cannot extract graph node from different trace level" when grad tracers leak into Variable._trace_state. - gpt_oss.py / olmo3.py: replace inline nn.Dropout(...) with self.dropout = linears.Dropout(...) in __init__ to fix CallCompactUnboundModuleError. - normalizations.py: Qwen3NextRMSNorm signature: eps -> epsilon, accept shard_mode/kernel_axes/parameter_memory_host_offload for callsite parity. - attentions.py / qwen3.py: callsites eps= -> epsilon=. - moe.py: per_expert_scale block moved into the unfused-kernel else branch (was scaling wo even when fused_kernel was active). - models.py: build MTP block as MultiTokenPredictionBlock(...) directly (drop the ToNNX(linen) + lazy_init wrap); pass multimodal_input whole to NNXDecoder instead of unpacking 5 fields. - gradient_accumulation.py: ZeRO-1+GA all-reduce annotation deferred until after lax.scan (reduced/unreduced PartitionSpec is rejected inside scan carry); use nnx.merge(..., copy=True) to avoid Variable reuse. - diloco.py: NNX-aware state handling — state.params -> state.model.filter (nnx.Param), step counter at state.optimizer.step, replace_nnx_model_params helper for jax.lax.cond pytree-structure parity. - train_compile.py: new _collect_nnx_activation_shardings helper (forward pass populates _ACTIVATION_SHARDINGS_DUMP — get_abstract_state_nnx only traces __init__); NNX path now passes 2-arg shaped_train_args (no rng); diloco path patched to handle the 2-vs-3 length difference. - muon_utils.py: get_model_mdn default pure_nnx=True; wrap NNX result as {"params": nnx.to_pure_dict(...)} for parity with Linen tree shape. - nnx_decoders.py: FP8+NNX scan fix — Linen FP8 ops (fp8_nanoo, fp8_gpu) retain tracers in Linen scope across re-traces. Skip jax.checkpoint and use a Python for-loop instead of jax.lax.scan when quantization is FP8. Makes FP8 quantization usable on the NNX path. - train.py (pre-train train_step): return nnx.state(new_state, nnx.Not (nnx.Intermediate)) so sowed forward-pass artifacts (e.g. max_logits for QK-Clip) don't break leaf-count parity with state_mesh_shardings. - llama2.py: pass parameter_memory_host_offload to pre_self_attention_layer _norm RMSNorm (was missing on this norm only). - base.yml: add 4 pipeline-related logical_axis_rules — layers_outside _pipeline, layers_per_stage, num_activations, circular_repeats. Additive, no-op without use_nnx_pipeline=True. NNX feature enablements (clear all 17 "Pure NNX support has not been implemented yet" NotImplementedError sites by routing Linen-coupled utilities to the Linen path; their on-disk format is Linen): - layerwise_quantization.py (2 sites): operates on Linen-format checkpoints via DeepSeek*ToLinen layers. - lora_utils.py (1 site): downstream get_lora_abstract_state expects Linen tree shape; LoRA adapters on disk are Linen. - standalone_checkpointer.py (2 sites): add_entropy_to_checkpoint accesses state.opt_state[0]._replace(mu=..., nu=...) — Linen-only. - generate_param_only_checkpoint.py (3 sites): _possibly_unroll_params and _save_decode_checkpoint use state.params["params"]["decoder"] — Linen. - convert_gpt3_ckpt_from_paxml.py (2 sites): keystr_map targets Linen tree paths (.params['params'], .opt_state.mu['params']). - maxengine.py (3 sites): inference engine uses state.params and serves Linen-format inference checkpoints. - grpo_trainer.py (4 sites): RL trainer is end-to-end Linen-shaped; route to Linen with a clear log warning since NNX-format checkpoints will fail at restore time. Vocab tiling on NNX (real implementation, not just routing): - models.py: add Transformer.logits_from_hidden_states on the NNX Transformer class — wraps NNXDecoder.apply_output_head with the token_embedder; mirrors TransformerLinenPure.logits_from_hidden_states. - vocabulary_tiling.py: add vocab_tiling_nnx_loss — chunks the vocab axis via jax.lax.scan and calls model.logits_from_hidden_states(chunk) per chunk. The NNX model carries its parameters internally so no explicit FSDP gather is needed (unlike the Linen gathered_params pattern). MVP uses default autograd; custom_vjp memory-savings optimization is a follow-up if backward memory becomes a concern. - train.py (NNX loss_fn): replace the NotImplementedError with the call to vocab_tiling_nnx_loss using hidden_states from intermediates. - pyconfig_deprecated.py / configs/types.py: drop the num_vocab_tiling > 1 and enable_nnx validation guards (no longer needed). DPO + NNX retained as NotImplementedError but with a much more informative message (points users at pure_nnx=False workaround). Full implementation is deferred — needs a new TrainState shape carrying both policy and reference NNX models plus an NNX dpo_loss_fn. Stats: 26 source files modified, +406 / -171 lines. Linen invariant verified: pure_nnx / enable_nnx / pure_nnx_decoder still default to False; Linen-path UTs unaffected (3 pre-existing failures on the parent branch remain unchanged — sharding_compare_test::deepseek2-16b, optimizers_test::test_model_integration_kimi-k2-1t, diloco_test::two _slices x2). All "Pure NNX support has not been implemented yet" NotImplementedError sites cleared (was 17, now 0). --- .../convert_gpt3_ckpt_from_paxml.py | 15 +-- src/maxtext/configs/base.yml | 7 ++ src/maxtext/configs/pyconfig_deprecated.py | 3 +- src/maxtext/configs/types.py | 3 +- src/maxtext/experimental/rl/grpo_trainer.py | 37 +++--- src/maxtext/inference/maxengine/maxengine.py | 22 ++-- src/maxtext/layers/attentions.py | 4 +- src/maxtext/layers/moe.py | 4 +- src/maxtext/layers/nnx_decoders.py | 34 +++++- src/maxtext/layers/nnx_wrappers.py | 35 ++++++ src/maxtext/layers/normalizations.py | 14 ++- src/maxtext/models/gpt_oss.py | 5 +- src/maxtext/models/llama2.py | 1 + src/maxtext/models/models.py | 36 +++--- src/maxtext/models/olmo3.py | 4 +- src/maxtext/models/qwen3.py | 4 +- src/maxtext/trainers/diloco/diloco.py | 59 ++++++++-- src/maxtext/trainers/pre_train/train.py | 22 +++- .../trainers/pre_train/train_compile.py | 38 ++++++- .../utils/generate_param_only_checkpoint.py | 26 ++--- src/maxtext/utils/gradient_accumulation.py | 21 +++- src/maxtext/utils/layerwise_quantization.py | 20 ++-- src/maxtext/utils/lora_utils.py | 11 +- src/maxtext/utils/muon_utils.py | 5 +- src/maxtext/utils/standalone_checkpointer.py | 15 +-- src/maxtext/utils/vocabulary_tiling.py | 107 ++++++++++++++++++ 26 files changed, 401 insertions(+), 151 deletions(-) diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py index 9b5f0cfb21..d4d4c39290 100644 --- a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py +++ b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py @@ -87,11 +87,12 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name devices_array = maxtext_utils.create_device_mesh(cfg) mesh = Mesh(devices_array, cfg.mesh_axes) + # This conversion script reads paxml-format weights and emits a Linen-format + # MaxText checkpoint (downstream uses `.params['params']`, `.opt_state.mu['params']`, + # `.opt_state.nu['params']` keystr paths; the keystr_map below targets the Linen + # tree shape). Use the Linen path regardless of pure_nnx. quant = quantizations.configure_quantization(cfg) - if cfg.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg) tx = optimizers.get_optimizer(cfg, learning_rate_schedule) @@ -102,11 +103,7 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name cfg.checkpoint_period, ) - if cfg.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng) state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn) max_logging.log("start") max_utils.print_mem_stats("After params initialized") diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 5e59a0f4be..f9aefd91d7 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -559,6 +559,13 @@ logical_axis_rules: [ ['tokens_per_page', []], ['paged_kv_head_dim_size', []], # ========================================== + # Pipeline Parallelism + # ========================================== + ['layers_outside_pipeline', []], + ['layers_per_stage', []], + ['num_activations', []], + ['circular_repeats', []], + # ========================================== # Deprecated / Scheduled for Removal # ========================================== ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], diff --git a/src/maxtext/configs/pyconfig_deprecated.py b/src/maxtext/configs/pyconfig_deprecated.py index f5e080ed2e..511a39d29d 100644 --- a/src/maxtext/configs/pyconfig_deprecated.py +++ b/src/maxtext/configs/pyconfig_deprecated.py @@ -195,10 +195,9 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) - def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool): + del enable_nnx # NNX vocab tiling supported via vocab_tiling_nnx_loss in vocabulary_tiling.py if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0: raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.") - if num_vocab_tiling > 1 and enable_nnx: # TODO (chengnuojin) enable vocab tiling on NNX after NNX migration - raise ValueError("We currently don't support vocab tiling on NNX module.") def validate_rampup_batch_size(batch_size_start, batch_size_end, batch_size_increment, global_rampup_samples): diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index b463516945..1073114449 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -2736,8 +2736,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0 ): raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.") - if self.num_vocab_tiling > 1 and self.enable_nnx: - raise ValueError("We currently don't support vocab tiling on NNX module.") + # Vocab tiling on NNX is now supported via vocab_tiling_nnx_loss in vocabulary_tiling.py. if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring": if "gpu" not in self.hardware: raise ValueError( diff --git a/src/maxtext/experimental/rl/grpo_trainer.py b/src/maxtext/experimental/rl/grpo_trainer.py index 28eef21cb0..4244d199a8 100644 --- a/src/maxtext/experimental/rl/grpo_trainer.py +++ b/src/maxtext/experimental/rl/grpo_trainer.py @@ -542,29 +542,28 @@ def setup_train_loop( - eval_data_iterator: The iterator for the evaluation dataset (or None). - state: The initialized training state. """ + # GRPO RL trainer is Linen-shaped end-to-end (state.params accesses below, + # state_mesh_shardings.params, and the inference path through MaxEngine which is + # Linen-only). Run on Linen path regardless of pure_nnx; warn the user since + # NNX-format checkpoints will mismatch at restore time. + if config.pure_nnx or config_inference.pure_nnx: + max_logging.log( + "WARNING: GRPO RL trainer does not yet support pure_nnx natively; " + "running on the Linen path. NNX-format checkpoints will not load correctly here." + ) with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): max_logging.log("Training mesh used for the workload") num_inference_devices = config.inference_devices_per_replica * config.inference_replicas training_devices = jax.devices()[num_inference_devices:] - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = mt.from_config(config, devices=training_devices) + model = mt.from_config(config, devices=training_devices) mesh = model.mesh max_logging.log("Inference mesh used for the workload") inference_devices = jax.devices()[:num_inference_devices] - if config_inference.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - inference_model = mt.from_config(config_inference, devices=inference_devices) + inference_model = mt.from_config(config_inference, devices=inference_devices) inference_mesh = inference_model.mesh init_rng = jax.random.PRNGKey(config.init_weights_seed) learning_rate_schedule, tx = train_utils.create_training_optimizer(config, model) - if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION): @@ -573,14 +572,10 @@ def setup_train_loop( data_iterator, config, mesh, checkpoint_manager, init_state_fn ) - # create inference_state_mesh_shardings from inference_mesh - if config_inference.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_inference_state_fn = functools.partial( - maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng - ) + # create inference_state_mesh_shardings from inference_mesh (Linen path; see warning above) + init_inference_state_fn = functools.partial( + maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng + ) inference_state_mesh_shardings = maxtext_utils.get_abstract_state( config_inference, inference_mesh, init_inference_state_fn, is_training=False )[2] diff --git a/src/maxtext/inference/maxengine/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py index 5bb0a87b5a..c00f475e8d 100644 --- a/src/maxtext/inference/maxengine/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -111,12 +111,12 @@ def __init__(self, config: Any, devices: Any | None = None): devices_array = maxtext_utils.create_device_mesh(config=config, devices=devices) self._mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) - # Model and Optimizer definition + # Model and Optimizer definition. + # MaxEngine uses Linen-shaped state (state.params, state_mesh_shardings.params, + # state.opt_state) and serves Linen-format inference checkpoints. Use Linen path + # regardless of pure_nnx — the flag affects training, not inference serving. quant = quantizations.configure_quantization(config) - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) + self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None)) self.abstract_params = None @@ -232,11 +232,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar rng1, rng2, rng3 = jax.random.split(rng, 3) if params: print("Resharding given params") - if self.config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng) _, self.state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state( self.config, self._mesh, init_state_fn, False ) @@ -245,11 +241,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar state = maxtext_utils.init_decode_state(None, params) state = max_utils.unbox_logicallypartioned(state) else: - if self.config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1) state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(self.config, self._mesh, None, init_state_fn) # pylint: disable=isinstance-second-argument-not-valid-type self.abstract_params = jax.tree_util.tree_map( diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index e53de0973a..f2c337f330 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -525,14 +525,14 @@ def __init__( elif self.is_qwen3_next: self.query_norm = Qwen3NextRMSNorm( num_features=self.config.head_dim, - eps=self.config.normalization_layer_epsilon, + epsilon=self.config.normalization_layer_epsilon, dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, rngs=self.rngs, ) self.key_norm = Qwen3NextRMSNorm( num_features=self.config.head_dim, - eps=self.config.normalization_layer_epsilon, + epsilon=self.config.normalization_layer_epsilon, dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, rngs=self.rngs, diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index a08c1d10ff..59822dd3ef 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -2185,8 +2185,8 @@ def __call__( w0_kernel = jnp.asarray(self.wi_0[...], self.dtype) w1_kernel = jnp.asarray(self.wi_1[...], self.dtype) - if self.per_expert_scale is not None: - wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None] + if self.per_expert_scale is not None: + wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None] if self.wi_0_sparsity_module is not None: _, w0_kernel = self.wi_0_sparsity_module(jnp.zeros_like(w0_kernel), w0_kernel) diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 42cf44cab0..8a00262088 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -423,8 +423,16 @@ def pure_layer_fn(state_in, y_in): out = merged_layer(y_in, **kwargs) return out, nnx.state(merged_layer) - checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) - out, new_state = checkpointed_fn(state, y) + # Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen + # mutable scope. jax.checkpoint re-traces the scan body during backward (remat), + # but the Linen scope retains JAX tracers from the first trace, causing + # UnexpectedTracerError. Skip checkpoint for these quantization types. + uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu") + if uses_linen_fp8_mutable_state: + out, new_state = pure_layer_fn(state, y) + else: + checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) + out, new_state = checkpointed_fn(state, y) nnx.update(layer, new_state) return out @@ -503,13 +511,12 @@ def layer_fn(carry, scanned_vars): return new_carry, (new_current_state, updated_kv) return new_carry, new_current_state - layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) - if use_kv: # If kv_caches is provided (e.g., from vLLM), we CANNOT use jax.lax.scan # because scanning requires stacking the kv_caches list, which creates a copy # and breaks the in-place memory updates required by vLLM's PagedAttention. # Therefore, we must unroll the loop statically when kv_caches is provided. + layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) # kv_caches_stacked is actually the original kv_caches list in this new flow kv_caches_list = kv_caches_stacked @@ -531,7 +538,24 @@ def layer_fn(carry, scanned_vars): # inference with vLLM, parameters do not change and we don't need intermediates. return current_carry, layers, None else: - final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state)) + # Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen + # mutable scope. jax.lax.scan traces the body function and Linen's setup() creates + # intermediate tracer values (amax_history float32[1024]) that escape the scan scope, + # causing UnexpectedTracerError. Use a Python for loop instead for these types. + uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu") + if uses_linen_fp8_mutable_state: + carry = x_in + per_layer_states = [] + for i in range(length): + current_params = jax.tree.map(lambda x, i=i: x[i], params) + current_state = jax.tree.map(lambda x, i=i: x[i], state) + carry, new_state_i = layer_fn(carry, (current_params, current_state)) + per_layer_states.append(new_state_i) + final_carry = carry + scanned_state = jax.tree.map(lambda *xs: jnp.stack(list(xs)), *per_layer_states) + else: + layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) + final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state)) returned_kv_stacked = None if scan_axis != 0: diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py index eb81d596d9..2b770b1b0a 100644 --- a/src/maxtext/layers/nnx_wrappers.py +++ b/src/maxtext/layers/nnx_wrappers.py @@ -26,6 +26,7 @@ from flax.core import FrozenDict from flax.core import meta from flax.nnx import graph +from flax.nnx import tracers as nnx_tracers from flax.nnx import variablelib from flax.nnx.bridge import module as bdg_module from flax.nnx.module import Module @@ -170,6 +171,39 @@ def current_linen_module() -> linen.Module | None: return None +def is_linen_initializing() -> bool: + """Check if the current execution context is inside a Linen init() call. + + Returns True when called from within a ``to_linen_class`` wrapper's + ``init()`` path. Uses :func:`current_linen_module` to access the Linen + module stack (private API already used by this module). + + This is used by NNX pipeline modules to short-circuit the full scan + during Linen init, where only the output shape/dtype is needed. + """ + module = current_linen_module() + if module is not None and hasattr(module, "is_initializing") and callable(module.is_initializing): + return module.is_initializing() + return False + + +def _refresh_variable_trace_state(module: Module) -> None: + """Refresh _trace_state for Variables that have stale trace state. + + When nnx.update() is called with tracer values from a JAX transformation + (e.g. jax.grad's LinearizeTracer), it uses _unsafe_bypass_check=True which + updates the raw value but not _trace_state. This leaves Variables with a + stale _trace_state from the outer (Python) context, causing nnx.split() to + fail with "Cannot extract graph node from different trace level" errors. + + This function resets _trace_state on any Variables whose _can_update is False + so that downstream NNX operations (e.g. nnx.split in NNXPipeline) succeed. + """ + for _, v in nnx.graph.iter_graph(module): + if isinstance(v, variablelib.Variable) and not v._can_update: # pylint: disable=protected-access + object.__setattr__(v, "_trace_state", nnx_tracers.TraceState()) + + class ToNNX(Module): """A wrapper to turn any Linen module into an NNX module. @@ -467,6 +501,7 @@ def maybe_unbox(x): warnings.warn(f"Found unknown module paths in incoming state:{paths_str}") nnx.update(module, new_state) + _refresh_variable_trace_state(module) _fix_for_qwix_quantization(module) method_fn = _get_module_method(module, nnx_method) diff --git a/src/maxtext/layers/normalizations.py b/src/maxtext/layers/normalizations.py index 3bce30d44e..c904b0e4e0 100644 --- a/src/maxtext/layers/normalizations.py +++ b/src/maxtext/layers/normalizations.py @@ -114,7 +114,17 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> return y_flat.reshape(input_shape) -def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs): +def Qwen3NextRMSNorm( + num_features: int, + epsilon: float = 1e-6, + dtype: DType = None, + weight_dtype: DType = None, + shard_mode=None, + kernel_axes=None, + parameter_memory_host_offload=None, + *, + rngs: nnx.Rngs, +): """ Used for input and post attention layernorms in Qwen3NextDecoderLayer. @@ -127,7 +137,7 @@ def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: return nnx.data( RMSNorm( num_features=num_features, - epsilon=eps, + epsilon=epsilon, dtype=dtype, weight_dtype=weight_dtype, scale_init=linen_initializers.zeros, diff --git a/src/maxtext/models/gpt_oss.py b/src/maxtext/models/gpt_oss.py index 9401d01d9f..5f4a2f3fb6 100644 --- a/src/maxtext/models/gpt_oss.py +++ b/src/maxtext/models/gpt_oss.py @@ -29,6 +29,7 @@ from maxtext.common.common_types import AttentionType, Config from maxtext.layers import attentions from maxtext.layers import initializers +from maxtext.layers import linears from maxtext.layers import moe from maxtext.layers import nnx_wrappers from maxtext.layers import quantizations @@ -132,6 +133,8 @@ def __init__( rngs=rngs, ) + self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) + def __call__( self, inputs, @@ -189,7 +192,7 @@ def __call__( mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) + layer_output = self.dropout(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( layer_output, diff --git a/src/maxtext/models/llama2.py b/src/maxtext/models/llama2.py index 6a215c5dbe..244eed03bb 100644 --- a/src/maxtext/models/llama2.py +++ b/src/maxtext/models/llama2.py @@ -71,6 +71,7 @@ def __init__( shard_mode=config.shard_mode, kernel_axes=("norm",), epsilon=config.normalization_layer_epsilon, + parameter_memory_host_offload=config.parameter_memory_host_offload, rngs=rngs, ) diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index b4708b10de..5ba365b74b 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -33,7 +33,7 @@ from maxtext.layers.decoders import Decoder from maxtext.layers.embeddings import Embed, embed_as_linen from maxtext.layers.encoders import AudioEncoder, VisionEncoder, audio_encoder_as_linen, vision_encoder_as_linen -from maxtext.layers.multi_token_prediction import multi_token_prediction_block_as_linen +from maxtext.layers.multi_token_prediction import MultiTokenPredictionBlock, multi_token_prediction_block_as_linen from maxtext.layers.quantizations import AqtQuantization as Quant from maxtext.multimodal import processor as mm_processor from maxtext.utils import max_utils @@ -386,31 +386,31 @@ def __init__( # For MTP, we use the DecoderLayer blueprint to ensure architectural consistency. # By convention, this is the last layer in the list. mtp_layer = layer_types[-1] - mtp_block_linen = multi_token_prediction_block_as_linen( + self.mtp_block = MultiTokenPredictionBlock( config=self.config, mesh=self.mesh, transformer_layer_module=mtp_layer, decoder=self.decoder, rngs=rngs, - name="mtp_block", - ) - self.mtp_block = nnx_wrappers.ToNNX(mtp_block_linen, rngs=rngs) - - self.mtp_block.lazy_init( - shared_embedding=self.token_embedder, - main_hidden_state=jnp.ones((1, 1, self.config.emb_dim), dtype=self.config.dtype), - input_ids=jnp.ones((1, 1), dtype=jnp.int32), - target_ids=jnp.ones((1, 1), dtype=jnp.int32), - target_mask=jnp.ones((1, 1), dtype=jnp.int32), - position_ids=jnp.ones((1, 1), dtype=jnp.int32), - decoder_segment_ids=jnp.ones((1, 1), dtype=jnp.int32), - deterministic=True, ) def no_op(self, *args, **kwargs): """A no-op method to allow the model to be used in a lazy context.""" return + def logits_from_hidden_states(self, hidden_states, deterministic, model_mode): + """Compute logits from hidden states (wraps NNXDecoder.apply_output_head). + + Mirrors the Linen TransformerLinenPure.logits_from_hidden_states method; + used by vocabulary tiling to recompute logits from chunked hidden states. + """ + return self.decoder.apply_output_head( + shared_embedding=self.token_embedder, + y=hidden_states, + deterministic=deterministic, + model_mode=model_mode, + ) + def init_cache(self, cache_size: int, batch_size: int, dtype=jnp.float32): """Initializes the KV cache for the Transformer. @@ -522,11 +522,7 @@ def __call__( previous_chunk=previous_chunk, slot=slot, page_state=page_state, - image_embeddings=multimodal_input.image_embeddings if multimodal_input is not None else None, - image_masks=multimodal_input.image_masks if multimodal_input is not None else None, - audio_embeddings=multimodal_input.audio_embeddings if multimodal_input is not None else None, - audio_masks=multimodal_input.audio_masks if multimodal_input is not None else None, - bidirectional_mask=multimodal_input.bidirectional_mask if multimodal_input is not None else None, + multimodal_input=multimodal_input, kv_caches=kv_caches, attention_metadata=attention_metadata, deepstack_visual_embeds=deepstack_visual_embeds, diff --git a/src/maxtext/models/olmo3.py b/src/maxtext/models/olmo3.py index a3a8b6997d..9d68d6a57d 100644 --- a/src/maxtext/models/olmo3.py +++ b/src/maxtext/models/olmo3.py @@ -30,6 +30,7 @@ from maxtext.common.common_types import AttentionType, Config from maxtext.layers import attentions from maxtext.layers import initializers +from maxtext.layers import linears from maxtext.layers import nnx_wrappers from maxtext.layers import quantizations from maxtext.layers.attentions import Attention @@ -140,6 +141,7 @@ def __init__( model_mode=model_mode, rngs=rngs, ) + self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) def __call__( self, @@ -200,7 +202,7 @@ def __call__( mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) + layer_output = self.dropout(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( layer_output, diff --git a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py index 3491663ebe..4f6094516f 100644 --- a/src/maxtext/models/qwen3.py +++ b/src/maxtext/models/qwen3.py @@ -966,7 +966,7 @@ def __init__( # First LayerNorm, applied before the attention block. self.input_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, @@ -991,7 +991,7 @@ def __init__( # Second LayerNorm, applied before the MoE block. self.post_attention_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, diff --git a/src/maxtext/trainers/diloco/diloco.py b/src/maxtext/trainers/diloco/diloco.py index a9ef64631a..39d84a89dc 100644 --- a/src/maxtext/trainers/diloco/diloco.py +++ b/src/maxtext/trainers/diloco/diloco.py @@ -26,6 +26,7 @@ from typing import Any, Callable import drjax +from flax import nnx from flax import struct from flax.training import train_state import jax @@ -153,7 +154,15 @@ def add_diloco_dim(x): momentum=config.diloco_outer_momentum, nesterov=True, ) - outer_opt_state = jax.eval_shape(outer_optimizer.init, abstract_state.params) + # For NNX, model params (Param variables only) live under abstract_state.model; + # for Linen under abstract_state.params. + if config.pure_nnx: + model_params = abstract_state.model.filter(nnx.Param) + model_params_sharding = state_mesh_shardings.model.filter(nnx.Param) + else: + model_params = abstract_state.params + model_params_sharding = state_mesh_shardings.params + outer_opt_state = jax.eval_shape(outer_optimizer.init, model_params) # Create abstract step abstract_step = jax.ShapeDtypeStruct((), jnp.int32) @@ -161,7 +170,7 @@ def add_diloco_dim(x): # Build abstract DiLoCo state diloco_state = DiLoCoTrainState( inner_state=inner_state, - params=abstract_state.params, + params=model_params, outer_opt_state=outer_opt_state, step=abstract_step, ) @@ -171,12 +180,12 @@ def add_diloco_dim(x): # Sharding for outer_opt_state. For SGD with momentum, it is (TraceState(trace=...), EmptyState()) # We shard the momentum trace the same way as the parameters. outer_opt_state_sharding = ( - optax.TraceState(trace=state_mesh_shardings.params), + optax.TraceState(trace=model_params_sharding), optax.EmptyState(), ) diloco_state_shardings = DiLoCoTrainState( inner_state=inner_state_shardings, - params=state_mesh_shardings.params, + params=model_params_sharding, outer_opt_state=outer_opt_state_sharding, step=None, ) @@ -205,11 +214,15 @@ def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]: # mesh automatically when jax.set_mesh is used. inner_state = drjax.broadcast(state, mesh=mesh) # Outer state retains a single copy of the model parameters and optimizer state. - outer_params = state.params + # For NNX, model params (Param variables only) live under state.model; + # for Linen under state.params. + outer_params = state.model.filter(nnx.Param) if config.pure_nnx else state.params outer_opt_state = outer_optimizer.init(outer_params) outer_opt_state_sharding = jax.tree_util.tree_map(lambda x: x.sharding, outer_opt_state) + # For NNX, the step counter lives at state.optimizer.step; for Linen at state.step. + step = state.optimizer.step if config.pure_nnx else state.step return ( - DiLoCoTrainState(inner_state=inner_state, params=outer_params, outer_opt_state=outer_opt_state, step=state.step), + DiLoCoTrainState(inner_state=inner_state, params=outer_params, outer_opt_state=outer_opt_state, step=step), outer_opt_state_sharding, ) @@ -244,7 +257,11 @@ def synchronize(state): # Calculate the delta between the current replica's state and the global # state (since last synchronization). broadcast_outer_params = drjax.broadcast(state.params, mesh=mesh) - model_delta = jax.tree.map(lambda x, y: y - x, state.inner_state.params, broadcast_outer_params) + # For NNX, model Param vars live under inner_state.model; for Linen under inner_state.params. + inner_model_params = ( + nnx.filter_state(state.inner_state.model, nnx.Param) if config.pure_nnx else state.inner_state.params + ) + model_delta = jax.tree.map(lambda x, y: y - x, inner_model_params, broadcast_outer_params) # Treat the average delta as the outer optimizer's gradient and apply to # the global (outer) model params. averaged_pseudo_grad = drjax.reduce_mean(model_delta) @@ -253,7 +270,27 @@ def synchronize(state): # Replace inner model params with the new global model params. # NOTE: inner optimizer state is retained despite the change in parameters, # see section 6.1 in https://arxiv.org/pdf/2311.08105. - new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state, mesh=mesh) + if config.pure_nnx: + # For NNX: merge new Param vars back with the non-Param model vars (e.g. RNG state). + def replace_nnx_model_params(s, new_params): + non_param_model = nnx.filter_state(s.model, nnx.Not(nnx.Param)) + new_model = nnx.merge_state(non_param_model, new_params) + # Build result via __setitem__ so nested States are stored as plain dicts + # internally, matching the pytree structure produced by nnx.state(). + # (Passing State objects via the constructor dict literal stores them + # as-is, causing jax.lax.cond to see mismatched pytree structures.) + result = type(s)({}) + result["model"] = new_model + result["optimizer"] = s["optimizer"] + return result + + new_inner_state = drjax.map_fn( + lambda s: replace_nnx_model_params(s, new_outer_params), + state.inner_state, + mesh=mesh, + ) + else: + new_inner_state = drjax.map_fn(lambda s: s.replace(params=new_outer_params), state.inner_state, mesh=mesh) return state.replace( params=new_outer_params, outer_opt_state=new_opt_state, @@ -271,14 +308,16 @@ def diloco_train_step(state, batch, prng): broadcast_rng = drjax.broadcast(prng, mesh=mesh) inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng), mesh=mesh) avg_metrics = typed_reduce_mean(metrics) + # For NNX, the step counter lives at inner_state.optimizer.step; for Linen at inner_state.step. + new_step = inner_state.optimizer.step[0] if config.pure_nnx else inner_state.step[0] state = state.replace( inner_state=inner_state, - step=inner_state.step[0], + step=new_step, ) # Either synchronize the model, or no-op, depending on whether the current # step falls on the synchronization period. state = jax.lax.cond( - inner_state.step[0] % config.diloco_sync_period == 0, + new_step % config.diloco_sync_period == 0, synchronize, lambda x: x, # no-op state, diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 6038cad7b6..97e043c7f7 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -71,7 +71,7 @@ from maxtext.utils import maxtext_utils_nnx from maxtext.utils import train_utils from maxtext.utils.gradient_accumulation import gradient_accumulation_loss_and_grad -from maxtext.utils.vocabulary_tiling import vocab_tiling_linen_loss +from maxtext.utils.vocabulary_tiling import vocab_tiling_linen_loss, vocab_tiling_nnx_loss _diag_modules = _cloud_diag() diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = _diag_modules @@ -199,9 +199,10 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr intermediate_outputs = nnx.state(model, nnx.Intermediate).to_pure_dict() if config.num_vocab_tiling > 1: - raise NotImplementedError("Vocab tiling for NNX modules has not been implemented.") - - if (config.use_indexer and not config.indexer_sparse_training) and is_train: + hidden_state_key = ("decoder", "hidden_states") + hidden_states = maxtext_utils.get_nested_value(intermediate_outputs, hidden_state_key)[0] + xent_sum, total_z_loss = vocab_tiling_nnx_loss(model, hidden_states, data, config, is_train) + elif (config.use_indexer and not config.indexer_sparse_training) and is_train: # In Dense Warm-up stage, we skip main model loss calculation for efficiency. # The main model parameters are frozen and only the indexer is trained via KL divergence. xent_sum = 0.0 @@ -319,7 +320,12 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat ga_fn, ga_model, ga_params, ga_rng, ga_dpo = _loss_fn, model, params, dropout_rng, extra_dpo_args else: if config.use_dpo: - raise NotImplementedError("DPO for NNX modules has not been implemented.") + raise NotImplementedError( + "DPO is not yet supported for NNX modules. DPO requires a reference model " + "stored alongside the policy model (Linen path uses state.params['reference_params']); " + "the NNX TrainState equivalent has not been wired up. As a workaround, set " + "pure_nnx=False for DPO runs." + ) state = nnx.merge(model, state) # reconstruct TrainStateNNX ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, [] @@ -545,7 +551,11 @@ def move(path, value): if config.use_dpo: new_state = _merge_dpo_state(new_state, reference_params) return new_state, metrics - return nnx.state(new_state), metrics + # Exclude Intermediate variables (e.g., sowed max_logits for QK-Clip) from the + # returned state. Intermediates are transient forward-pass artifacts and must not + # persist across steps: they're absent from the abstract state used to build + # state_mesh_shardings, so including them would cause a leaf-count mismatch in JAX. + return nnx.state(new_state, nnx.Not(nnx.Intermediate)), metrics def eval_step(model, config, state, data, dropout_rng=None): diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index a2981f67ed..c593d3c540 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -30,6 +30,7 @@ from flax import nnx from flax.linen import partitioning as nn_partitioning import jax +import jax.numpy as jnp from jax.experimental.serialize_executable import serialize from jax.experimental.topologies import get_topology_desc from jax.sharding import AxisType, Mesh @@ -92,6 +93,27 @@ def get_topology_mesh(config): return topology_mesh +def _collect_nnx_activation_shardings(create_model_fn, config, mesh): + """Run an NNX forward pass in abstract mode to populate _ACTIVATION_SHARDINGS_DUMP. + + get_abstract_state_nnx uses nnx.eval_shape which only traces model initialization, + not __call__. Activation shardings are only collected during a forward pass. + """ + input_shape = (config.micro_batch_size_to_train_on, config.max_target_length) + + def _nnx_forward(): + model_instance = create_model_fn() + return model_instance( + decoder_input_tokens=jnp.ones(input_shape, dtype=jnp.int32), + decoder_positions=jnp.ones(input_shape, dtype=jnp.int32), + decoder_segment_ids=jnp.ones(input_shape, dtype=jnp.int32), + enable_dropout=False, + ) + + with nn_partitioning.axis_rules(config.logical_axis_rules): + jax.eval_shape(_nnx_forward) + + def get_shaped_inputs(topology_mesh, config): """Get shaped abstractions of inputs to train_step: state, batch and rng""" # Construct the model and optimizer to get shaped versions of the state @@ -129,7 +151,8 @@ def create_train_state_fn(): # For NNX, get_functional_train_with_signature expects the graphdef (static structure), # not the raw model — mirroring how the training loop does nnx.split(train_state). with nn_partitioning.axis_rules(config.logical_axis_rules): - graphdef, _ = nnx.get_abstract_model(init_state_fn, topology_mesh) + abs_train_state = nnx.eval_shape(init_state_fn) + graphdef, _ = nnx.split(abs_train_state) model = graphdef else: # unsharded logical annotations @@ -139,10 +162,17 @@ def create_train_state_fn(): shaped_batch = maxtext_utils.get_shaped_batch(config) if config.pure_nnx: - shaped_train_args = (abstract_state, shaped_batch, None) # NNX doesn't use dropout_rng + shaped_train_args = (abstract_state, shaped_batch) # NNX doesn't use dropout_rng else: shaped_train_args = (abstract_state, shaped_batch, shaped_rng) shaped_train_kwargs = {} + + # Collect activation shardings for NNX by running an abstract forward pass. + # This must happen after get_abstract_state (which uses nnx.eval_shape and only + # traces __init__, not __call__). + if config.debug_sharding and config.pure_nnx: + _collect_nnx_activation_shardings(_create_model_partial, config, topology_mesh) + return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model @@ -280,7 +310,9 @@ def main(argv: Sequence[str]) -> None: diloco_state, state_mesh_shardings, inner_state_shardings = diloco.build_abstract_diloco_state( config, abstract_state, state_mesh_shardings, topology_mesh ) - shaped_train_args = (diloco_state, shaped_train_args[1], shaped_train_args[2]) + # For NNX, shaped_train_args has 2 elements (state, batch) — no rng; pass None for prng. + shaped_rng_arg = shaped_train_args[2] if len(shaped_train_args) > 2 else None + shaped_train_args = (diloco_state, shaped_train_args[1], shaped_rng_arg) # Wrap train_step with diloco train_step_partial = functools.partial(train.train_step, model, config, inner_state_shardings, None) diff --git a/src/maxtext/utils/generate_param_only_checkpoint.py b/src/maxtext/utils/generate_param_only_checkpoint.py index 2fd14b87a2..0f997a6577 100644 --- a/src/maxtext/utils/generate_param_only_checkpoint.py +++ b/src/maxtext/utils/generate_param_only_checkpoint.py @@ -90,20 +90,17 @@ def slice_ith(input_layers): def _read_train_checkpoint(config, checkpoint_manager, mesh): """Read training checkpoint at path defined by load_full_state_path.""" - # Model and Optimizer definition + # Model and Optimizer definition. + # This script reads a Linen-format full state and emits a Linen-format + # parameter-only checkpoint (downstream `_possibly_unroll_params` and + # `_save_decode_checkpoint` access `state.params["params"]["decoder"]` / `state.opt_state`, + # both Linen-only). Use the Linen path regardless of pure_nnx. quant = quantizations.configure_quantization(config) - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) + model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) - if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) state, state_mesh_notations, _, _ = maxtext_utils.setup_training_state( None, config, mesh, checkpoint_manager, init_state_fn ) @@ -114,12 +111,11 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh): def _generate_lora_decode_checkpoints(config, mesh): """Read lora checkpoints checkpoint at path defined by load_full_state_path.""" - # Model and Optimizer definition + # Model and Optimizer definition. + # LoRA adapters and downstream `_save_decode_checkpoint`/`_possibly_unroll_params` + # are Linen-shaped; use the Linen path regardless of pure_nnx. quant = quantizations.configure_quantization(config) - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) + model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) diff --git a/src/maxtext/utils/gradient_accumulation.py b/src/maxtext/utils/gradient_accumulation.py index e1699647c6..cf84577dbd 100644 --- a/src/maxtext/utils/gradient_accumulation.py +++ b/src/maxtext/utils/gradient_accumulation.py @@ -71,10 +71,16 @@ def _maybe_shard_with_name(inputs, sharding_names): is_nnx = isinstance(model, nnx.Module) - # For more efficient DP/ZeRO-1 + GA - if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1: - ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings) - grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings) + # For more efficient DP/ZeRO-1 + GA. + # config.ici_data_parallelism may be -1 (auto-fill: resolved at mesh creation time, but + # the config field remains -1). Treat any value != 1 as "data parallelism is active". + if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism != 1: + # jax.lax.scan traces its body with an AbstractMesh where all axis types are Auto, + # which rejects reduced/unreduced PartitionSpec in scan carry tensors (raises ValueError). + # Use plain params_shardings for ga_params and init_grad in the carry. + # The all-reduce for data parallelism is applied to raw_grads after the scan instead. + ga_params_shardings = params_shardings + grad_shardings = params_shardings else: ga_params_shardings = grad_shardings = params_shardings @@ -105,7 +111,7 @@ def accumulate_gradient(acc_grad_and_loss, data): if is_nnx: # Reconstruct the model using the fixed parameters (ga_params) # and the advancing non-parameter state (RNGs) from the carry. - local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"]) + local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"], copy=True) (_, aux), cur_batch_gradient = grad_func(local_model, config, data, None, None, *extra_dpo_args, is_train=True) _, _, next_rest_state = nnx.split(local_model, nnx.Param, ...) acc_grad_and_loss["rest_state"] = next_rest_state @@ -156,6 +162,11 @@ def reshape_to_microbatch_accumulations(batch_arr): + grad_and_loss["mtp_loss"] / config.gradient_accumulation_steps ) raw_grads = grad_and_loss["grad"] + if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism != 1: + # Apply unreduced annotation after the scan to trigger all-reduce across data-parallel + # devices (reduced/unreduced cannot be used inside jax.lax.scan carry tensors). + unreduced_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings) + raw_grads = jax.tree.map(_maybe_shard_with_name, raw_grads, unreduced_shardings) raw_grads = jax.tree.map(_maybe_shard_with_name, raw_grads, params_shardings) raw_grads = jax.tree_util.tree_map(lambda arr: arr / grad_and_loss["total_weights"], raw_grads) aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux) # pytype: disable=module-attr diff --git a/src/maxtext/utils/layerwise_quantization.py b/src/maxtext/utils/layerwise_quantization.py index 29fa928656..a6c1c07f67 100644 --- a/src/maxtext/utils/layerwise_quantization.py +++ b/src/maxtext/utils/layerwise_quantization.py @@ -173,19 +173,15 @@ def __init__(self, config: Any, rng: PRNGKeyType): devices_array = maxtext_utils.create_device_mesh(config=config) self._mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) - # Model and quantization config + # Model and quantization config. + # This script produces and consumes Linen-format checkpoints (see DeepSeek*ToLinen + # layer classes used in load_and_quantize). Always use the Linen path internally, + # regardless of the pure_nnx flag — the flag affects training, not checkpoint format. self.quant = quantizations.configure_quantization(config) - if self.config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen( - config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN - ) - if self.config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, self.config, False, self.rng) + model = models.transformer_as_linen( + config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN + ) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, self.config, False, self.rng) self.unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(self.config, self._mesh, init_state_fn, False) diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py index 24099ef22a..76cd26d20e 100644 --- a/src/maxtext/utils/lora_utils.py +++ b/src/maxtext/utils/lora_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" Common LoRA utils needed to support LoRA adapters.""" +"""Common LoRA utils needed to support LoRA adapters.""" from functools import partial import json @@ -167,11 +167,10 @@ def setup_initial_lora_state(model, data_iterator, tx, config, rng, mesh, checkp if lora_adapter_path: max_logging.log(f"Setting initial state of LoRA with lora_adapter_path = {lora_adapter_path}") - if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) + # LoRA adapters are Linen-format on disk (downstream `get_lora_abstract_state` expects + # `unboxed_abstract_state.params` Linen tree shape; `lora_state.replace(params=...)` + # uses Linen TrainState API). Use the Linen init path regardless of the pure_nnx flag. + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, True) lora_config_path = lora_adapter_path + "adapter_config.json" diff --git a/src/maxtext/utils/muon_utils.py b/src/maxtext/utils/muon_utils.py index 3bd2b186b1..049a084979 100644 --- a/src/maxtext/utils/muon_utils.py +++ b/src/maxtext/utils/muon_utils.py @@ -116,6 +116,7 @@ def apply_transform_nnx(path: Tuple[jax.tree_util.KeyEntry, ...], leaf): # Use jax.tree_util.tree_map_with_path for NNX's potentially complex PyTree structure. # This is different with linen where abstract_param is a dict-based tree with nn.LogicallyPartitioned leaves. + # The result is an nnx.State with the same structure, where each Param's value holds the mdn result. muon_weight_dimension_numbers = jax.tree_util.tree_map_with_path(apply_transform_nnx, abstract_param) else: # Linen @@ -154,7 +155,7 @@ def get_leaf_info(leaf): print("\nIs this reasonable?") -def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=False): +def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=True): """Initializes a model and retrieves its Muon dimension numbers. This function sets up the configuration for a given model, initializes the @@ -191,6 +192,8 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=False): model = models.transformer_as_linen(config, mesh=mesh, quant=quant) # Get dimension number muon_weight_dimension_numbers = get_muon_weight_dimension_numbers(model, config, verbose=verbose) + if pure_nnx: + muon_weight_dimension_numbers = {"params": nnx.to_pure_dict(muon_weight_dimension_numbers)} return muon_weight_dimension_numbers diff --git a/src/maxtext/utils/standalone_checkpointer.py b/src/maxtext/utils/standalone_checkpointer.py index ba6b148b04..2fc2b09e25 100644 --- a/src/maxtext/utils/standalone_checkpointer.py +++ b/src/maxtext/utils/standalone_checkpointer.py @@ -52,18 +52,15 @@ def checkpoint_loop(config, state=None): Returns: """ - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = from_config(config) + # Standalone checkpointer is a save/restore exerciser that uses + # add_entropy_to_checkpoint() to populate Linen-shaped optimizer state + # (state.opt_state, state.params). Use the Linen path regardless of pure_nnx — + # the flag affects training, not this checkpoint test harness. + model = from_config(config) mesh = model.mesh init_rng = jax.random.PRNGKey(config.init_weights_seed) _, tx = train_utils.create_training_optimizer(config, model) - if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training=True) diff --git a/src/maxtext/utils/vocabulary_tiling.py b/src/maxtext/utils/vocabulary_tiling.py index e7b155416c..6a61f9ed23 100644 --- a/src/maxtext/utils/vocabulary_tiling.py +++ b/src/maxtext/utils/vocabulary_tiling.py @@ -247,3 +247,110 @@ def _bwd_scan_body(grad_params_acc, chunk_data): ) return total_loss, total_z_loss + + +def vocab_tiling_nnx_loss(model, hidden_states, data, config, is_train): + """Calculates cross-entropy loss using vocab tiling for NNX models. + + NNX equivalent of `vocab_tiling_linen_loss`. Iterates the vocab dimension via + `jax.lax.scan` with `model.logits_from_hidden_states` per chunk; the model + carries its parameters internally so no explicit gather is needed. + + This is a memory-efficient forward (chunked logits) but uses the default + autograd path (no custom_vjp), so backward memory savings vs. the Linen + custom_vjp path are not yet realized. TODO: add a custom_vjp using + `nnx.split`/`nnx.merge` if backward memory becomes a concern. + + Args: + model: The NNX model instance (must implement `logits_from_hidden_states`). + hidden_states: The final hidden states from the decoder. + data: A dictionary containing the input data, including 'targets' and 'targets_segmentation'. + config: The model and training configuration. + is_train: A boolean indicating if the model is in training mode. + + Returns: + A tuple (total_loss, total_z_loss). + """ + labels = data["targets"] + segmentation = data["targets_segmentation"] + deterministic = not config.enable_dropout if is_train else True + model_mode = "train" + + hidden_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch", "activation_length", "activation_embed"), + ) + label_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch", "activation_length"), + ) + reshaped_hidden_spec = create_sharding( + model.mesh, + ("num_tile", "activation_embed_and_logits_batch_sequence", "activation_embed"), + ) + reshaped_data_spec = create_sharding( + model.mesh, + ("num_tile", "activation_embed_and_logits_batch_sequence"), + ) + chunked_hidden_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch_sequence", "activation_embed"), + ) + chunked_data_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch_sequence",), + ) + chunked_logits_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch_sequence", "activation_vocab"), + ) + + _maybe_shard_with_name = functools.partial( + maybe_shard_with_name, + shard_mode=config.shard_mode, + debug_sharding=config.debug_sharding, + extra_stack_level=1, + ) + + def _reshape(inputs, out_shape, out_sharding): + reshape_out_sharding = out_sharding if config.shard_mode == ShardMode.EXPLICIT else None + inputs = jax.lax.reshape(inputs, out_shape, out_sharding=reshape_out_sharding) + return _maybe_shard_with_name(inputs, out_sharding) + + hidden_states = _maybe_shard_with_name(hidden_states, hidden_spec) + labels = _maybe_shard_with_name(labels, label_spec) + segmentation = _maybe_shard_with_name(segmentation, label_spec) + + batch_size, seq_len, emb_dim = hidden_states.shape + vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling + + reshaped_hidden_states = _reshape( + hidden_states, (config.num_vocab_tiling, vocab_tile_size, emb_dim), reshaped_hidden_spec + ) + reshaped_labels = _reshape(labels, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec) + reshaped_segmentation = _reshape(segmentation, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec) + + def _scan_body(accumulators, chunk_data): + loss_accumulator, z_loss_accumulator = accumulators + hidden_chunk, label_chunk, segmentation_chunk = chunk_data + hidden_chunk = _maybe_shard_with_name(hidden_chunk, chunked_hidden_spec) + label_chunk = _maybe_shard_with_name(label_chunk, chunked_data_spec) + segmentation_chunk = _maybe_shard_with_name(segmentation_chunk, chunked_data_spec) + + chunk_logits = model.logits_from_hidden_states(hidden_chunk, deterministic, model_mode) + chunk_logits = _maybe_shard_with_name(chunk_logits, chunked_logits_spec) + one_hot_label_chunk = jax.nn.one_hot(label_chunk, config.vocab_size) + chunk_xent, chunk_z_loss = max_utils.cross_entropy_with_logits( + chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier + ) + + masked_xent = jnp.sum(chunk_xent * (segmentation_chunk != 0)) + masked_z_loss = jnp.sum(chunk_z_loss * (segmentation_chunk != 0)) + + return (loss_accumulator + masked_xent, z_loss_accumulator + masked_z_loss), None + + initial_acc = (jnp.zeros((), dtype=hidden_states.dtype), jnp.zeros((), dtype=hidden_states.dtype)) + (total_loss, total_z_loss), _ = jax.lax.scan( + _scan_body, initial_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation) + ) + return total_loss, total_z_loss From a36c658e515b7248f246a5174ed447ae94e0acaa Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 29 Apr 2026 16:07:35 +0000 Subject: [PATCH 4/4] NNX: native DPO (TrainStateNNX.reference_model + dpo_loss_fn_nnx) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements NNX-native DPO so that the pure_nnx=True training path no longer raises NotImplementedError on use_dpo runs. The Linen DPO overlay pattern (model.apply(params=..., reference_params=...)) does not translate to NNX modules, which carry their parameters internally. Instead the policy and reference models are held as separate nnx.Module instances on TrainStateNNX, and a new dpo_loss_fn_nnx runs both forwards with stop_gradient on the reference logits. TrainStateNNX: - Add optional `reference_model: nnx.Module` field. apply_gradients continues to update only `self.model`, leaving `self.reference_model` bit-identical across steps. dpo_utils.py: - Add dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True). Signature mirrors the Linen dpo_loss_fn so it slots into gradient_accumulation_loss_and_grad's dispatcher (dropout_rng / params slots are unused for NNX; carried for parity, and reference_model is passed as the single extra_dpo_args entry). With nnx.value_and_grad(..., argnums=0) over the policy, no gradient flows to the reference model's nnx.Param leaves; the explicit jax.lax.stop_gradient on ref_logits is a belt-and-braces guard. - Both dpo_loss_fn (Linen) and dpo_loss_fn_nnx (NNX) now include indexer_loss=0.0 and mtp_loss=0.0 in aux so the gradient_accumulation aux pytree shape matches the non-DPO loss_fn. train.py: - Drop the NotImplementedError in train_step's NNX branch. When use_dpo, dispatch to dpo_loss_fn_nnx with state.reference_model as extra_dpo_args; otherwise use the regular loss_fn. eval_step gains the same dispatch. - diff_wrapper picks _loss_fn / extra_dpo_args from the per-path init block, so both the GA and non-GA NNX paths route DPO identically. - Checkpoint-save _split_dpo_state stripping is now Linen-only; TrainStateNNX saves whole (reference_model included) — the step-0 reload later overwrites reference_model from the step-0 checkpoint. train_utils.py: - NNX init_state_fn materializes a frozen reference_model alongside the policy when config.use_dpo. Both are constructed by _create_model_partial() with config.init_weights_seed, so they start identical (standard DPO practice) until the step-0 reload. - Step-0 checkpoint reload: copy step0_state["model"] into state["reference_model"]. Linen path unchanged. Tests: - New tests/unit/dpo_nnx_test.py (7 tests): TrainStateNNX reference_model init/hasattr semantics; apply_gradients leaves reference bit-identical; aux key set; identical policy/reference yields loss=log(2) and reward_accuracy=0.0 (strict > on equal logratios); dropout_rng/params slots are signature-compat only; nnx.value_and_grad(argnums=0) over the policy yields finite grads on policy params only. - train_nnx_test.py: drop the two stale negative tests (vocab_tiling_raises_not_implemented, train_step_dpo_raises_for_nnx) — both features are now real. Stats: 4 source files + 2 test files, +199/-22 source lines. Linen DPO path behaviorally unchanged (only adds two harmless aux-dict keys); NNX non-DPO path unchanged (all changes gated on config.use_dpo). --- src/maxtext/layers/train_state_nnx.py | 24 +- .../trainers/post_train/dpo/dpo_utils.py | 139 +++++++++++ src/maxtext/trainers/pre_train/train.py | 34 +-- src/maxtext/utils/train_utils.py | 24 +- tests/unit/dpo_nnx_test.py | 215 ++++++++++++++++++ tests/unit/train_nnx_test.py | 17 -- 6 files changed, 412 insertions(+), 41 deletions(-) create mode 100644 tests/unit/dpo_nnx_test.py diff --git a/src/maxtext/layers/train_state_nnx.py b/src/maxtext/layers/train_state_nnx.py index 9ef0e6dffd..3f9ee1ce29 100644 --- a/src/maxtext/layers/train_state_nnx.py +++ b/src/maxtext/layers/train_state_nnx.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" The NNX Unified TrainState. """ +"""The NNX Unified TrainState.""" from typing import Any @@ -25,20 +25,34 @@ class TrainStateNNX(nnx.Module): This replaces Linen's TrainState for checkpointing. Linen TrainState pytree: - {“params”: {...}, “opt_state”: {}...} + {"params": {...}, "opt_state": {}...} TrainStateNNX state pytree: - {“model”: {...}, “optimizer”: {“opt_state”: {...}} + {"model": {...}, "optimizer": {"opt_state": {...}}} + + For DPO (Direct Preference Optimization), an optional `reference_model` + carries a frozen copy of the same architecture used to compute reference + log-probabilities. Only `model` is updated by `apply_gradients`; the + reference is held alongside so it is sharded, jit-traced, and checkpointed + with the rest of the train state. """ - def __init__(self, model: nnx.Module, optimizer: nnx.Optimizer | None): + def __init__( + self, + model: nnx.Module, + optimizer: nnx.Optimizer | None, + reference_model: nnx.Module | None = None, + ): self.model = model self.optimizer = optimizer + if reference_model is not None: + self.reference_model = reference_model def apply_gradients(self, grads: Any): """ Mimics the Linen apply_gradients function. Updates the optimizer state, applies updates to parameters, - and increments the step counter. + and increments the step counter. Only updates `self.model`; + `self.reference_model` (if present) is left untouched. """ if self.optimizer is None: raise RuntimeError( diff --git a/src/maxtext/trainers/post_train/dpo/dpo_utils.py b/src/maxtext/trainers/post_train/dpo/dpo_utils.py index eeda1c1a7f..fd5faa5c9c 100644 --- a/src/maxtext/trainers/post_train/dpo/dpo_utils.py +++ b/src/maxtext/trainers/post_train/dpo/dpo_utils.py @@ -19,6 +19,8 @@ import jax import jax.numpy as jnp +from flax import nnx + from maxtext.utils import maxtext_utils @@ -148,6 +150,8 @@ def dpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_t "total_weights": total_weights, "moe_lb_loss": moe_lb_loss, "reward_accuracy": reward_accuracy, + "indexer_loss": 0.0, # for gradient_accumulation aux pytree compatibility + "mtp_loss": 0.0, # for gradient_accumulation aux pytree compatibility } return loss, aux @@ -155,3 +159,138 @@ def dpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_t def _merge_dpo_state(state, reference_params): """Merge reference parameters back into DPO state.""" return state.replace(params=dict(state.params, reference_params=reference_params)) + + +# NNX DPO has no split/merge counterpart: the Linen path overlays +# `reference_params` inside `state.params`, so it must be peeled off and +# reattached around `apply_gradients`. The NNX path holds the reference as a +# sibling field `TrainStateNNX.reference_model`; `apply_gradients` already +# only touches `self.model`, so no split/merge is needed. + + +def dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True): + """NNX DPO loss_fn for both train and eval. + + Signature mirrors the Linen `dpo_loss_fn` so it slots into the same + dispatcher in `gradient_accumulation_loss_and_grad`: + `(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True)` + + Differences from the Linen `dpo_loss_fn`: + * `policy_model` is an `nnx.Module` (carries its own params + RNG state). + * `dropout_rng` and `params` are unused for NNX (kept positional for + signature parity; NNX models manage these internally). + * The 6th arg (the `extra_dpo_args[0]`) is a frozen reference + `nnx.Module`, not a `reference_params` pytree. + * Reference forward is wrapped in `jax.lax.stop_gradient`; combined with + `nnx.value_and_grad(..., argnums=0)` over the policy, no gradient flows + to the reference's `nnx.Param` leaves. + + Args: + policy_model: Policy `nnx.Module` (the model being trained). + config: Config of parameters. + data: Batch of preference data with `chosen` / `rejected` fields. + dropout_rng: Unused for NNX (kept for signature parity with Linen). + params: Unused for NNX (kept for signature parity with Linen). + reference_model: Frozen reference `nnx.Module` for DPO logratio computation. + is_train: True for train_step and False for eval_step. + + Returns: + loss: DPO preference loss + MoE load balance loss (if applicable). + aux: dict with intermediate_outputs, xent_sum (always 0.0), dpo_loss, + total_weights, moe_lb_loss, reward_accuracy. + """ + del dropout_rng, params # unused for NNX + # decimate proportion of data when per_device_batch_size<1 + if is_train: + for k, v in data.items(): + data[k] = v[: config.micro_batch_size_to_train_on, :] + + # for DPO we don't support packed sequences (they shouldn't be present in the first place) + data["chosen_segmentation"] = (data["chosen_segmentation"] == 1).astype(jnp.int32) + data["rejected_segmentation"] = (data["rejected_segmentation"] == 1).astype(jnp.int32) + data["chosen_position"] = data["chosen_position"] * (data["chosen_segmentation"] == 1) + data["rejected_position"] = data["rejected_position"] * (data["rejected_segmentation"] == 1) + + # concatenated policy/reference forward pass + inputs = jnp.concatenate([data["chosen"], data["rejected"]], 0) + inputs_position = jnp.concatenate([data["chosen_position"], data["rejected_position"]], 0) + inputs_segmentation = jnp.concatenate([data["chosen_segmentation"], data["rejected_segmentation"]], 0) + + logits = policy_model( + decoder_input_tokens=inputs, + decoder_positions=inputs_position, + decoder_segment_ids=inputs_segmentation, + enable_dropout=config.enable_dropout if is_train else False, + ) + intermediate_outputs = nnx.state(policy_model, nnx.Intermediate).to_pure_dict() + + ref_logits = reference_model( + decoder_input_tokens=inputs, + decoder_positions=inputs_position, + decoder_segment_ids=inputs_segmentation, + enable_dropout=False, + ) + ref_logits = jax.lax.stop_gradient(ref_logits) + + # extract token ids, segmentation and logits for chosen and rejected sequences + chosen_ids = data["chosen"][..., 1:] + rejected_ids = data["rejected"][..., 1:] + chosen_segmentation = data["chosen_segmentation"][..., 1:] + rejected_segmentation = data["rejected_segmentation"][..., 1:] + n_logits = logits.shape[-3] // 2 # [B, S, E] - [batch, sequence, embedding/vocab] + chosen_logits, rejected_logits = logits[:n_logits, :, :], logits[n_logits:, :, :] + chosen_ref_logits, rejected_ref_logits = ref_logits[:n_logits, :, :], ref_logits[n_logits:, :, :] + + # common subsequence and padding mask + common_prefix_mask = jnp.cumsum(chosen_ids != rejected_ids, axis=-1) == 0 # [B, S] + valid_seq_mask = (chosen_segmentation != 0) & (rejected_segmentation != 0) & ~common_prefix_mask # [B, S] + + # compute logratios from the sequence-reduced observed token log-probability + chosen_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(chosen_logits[..., :-1, :], axis=-1), chosen_ids[..., None], axis=-1 + )[..., 0] + chosen_logps = jnp.sum(chosen_logps_seq * valid_seq_mask, axis=-1) # [B] + chosen_ref_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(chosen_ref_logits[..., :-1, :], axis=-1), chosen_ids[..., None], axis=-1 + )[..., 0] + chosen_ref_logps = jnp.sum(chosen_ref_logps_seq * valid_seq_mask, axis=-1) # [B] + chosen_logratios = chosen_logps - chosen_ref_logps # [B] + + rejected_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(rejected_logits[..., :-1, :], axis=-1), rejected_ids[..., None], axis=-1 + )[..., 0] + rejected_logps = jnp.sum(rejected_logps_seq * valid_seq_mask, axis=-1) # [B] + rejected_ref_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(rejected_ref_logits[..., :-1, :], axis=-1), rejected_ids[..., None], axis=-1 + )[..., 0] + rejected_ref_logps = jnp.sum(rejected_ref_logps_seq * valid_seq_mask, axis=-1) # [B] + rejected_logratios = rejected_logps - rejected_ref_logps # [B] + + # DPO loss from chosen and rejected logratios + LABEL_SMOOTHING, BETA = config.dpo_label_smoothing, config.dpo_beta + logratios_delta = BETA * (chosen_logratios - rejected_logratios) # [B] + losses = ( # [B] + -jax.nn.log_sigmoid(BETA * logratios_delta) * (1 - LABEL_SMOOTHING) + - jax.nn.log_sigmoid(-BETA * logratios_delta) * LABEL_SMOOTHING + ) + total_loss, total_weights = jnp.mean(losses), losses.shape[0] + loss = total_loss + + moe_lb_loss = 0.0 + if config.num_experts > 1: + moe_lb_losses = maxtext_utils.collect_intermediates_by_suffix(intermediate_outputs, "moe_lb_loss") + if moe_lb_losses: + moe_lb_loss = jnp.mean(jnp.concatenate(moe_lb_losses)) + loss += moe_lb_loss + reward_accuracy = jnp.mean(chosen_logratios > rejected_logratios) + aux = { + "intermediate_outputs": intermediate_outputs, + "xent_sum": 0.0, # DPO has no per-token cross-entropy sum; set to 0 for train_step compatibility + "dpo_loss": total_loss, # pure preference loss before MoE lb, analogous to lm_loss in pre-training + "total_weights": total_weights, + "moe_lb_loss": moe_lb_loss, + "reward_accuracy": reward_accuracy, + "indexer_loss": 0.0, # for gradient_accumulation aux pytree compatibility + "mtp_loss": 0.0, # for gradient_accumulation aux pytree compatibility + } + return loss, aux diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 97e043c7f7..34f67ba266 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -60,7 +60,7 @@ from maxtext.common.gcloud_stub import cloud_diagnostics as _cloud_diag, is_decoupled from maxtext.common.gcloud_stub import vertex_tensorboard_modules from maxtext.common.metric_logger import MetricLogger, record_activation_metrics -from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn +from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn, dpo_loss_fn_nnx from maxtext.utils import exceptions from maxtext.utils import gcs_utils from maxtext.utils import max_logging @@ -319,15 +319,15 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat params = state.params ga_fn, ga_model, ga_params, ga_rng, ga_dpo = _loss_fn, model, params, dropout_rng, extra_dpo_args else: - if config.use_dpo: - raise NotImplementedError( - "DPO is not yet supported for NNX modules. DPO requires a reference model " - "stored alongside the policy model (Linen path uses state.params['reference_params']); " - "the NNX TrainState equivalent has not been wired up. As a workaround, set " - "pure_nnx=False for DPO runs." - ) state = nnx.merge(model, state) # reconstruct TrainStateNNX - ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, [] + if config.use_dpo: + # NNX DPO: reference_model is a sibling field on TrainStateNNX (set up by + # init_initial_state when config.use_dpo=True). dpo_loss_fn_nnx mirrors + # the Linen dpo_loss_fn signature, so it slots into the same dispatcher + # with reference_model passed as the single extra_dpo_args entry. + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = (dpo_loss_fn_nnx, state.model, None, None, [state.reference_model]) + else: + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, [] # --- Gradient computation --- if config.gradient_accumulation_steps > 1: @@ -393,9 +393,14 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat ) nnx.update(state.model, curr_params) + # `ga_fn` and `ga_dpo` were set up earlier (loss_fn vs dpo_loss_fn_nnx; + # ga_dpo carries the frozen reference_model when use_dpo, else empty). + _nnx_loss_fn = ga_fn + _nnx_extra_dpo_args = ga_dpo + def diff_wrapper(param, rest, config, data): local_model = nnx.merge(model_graphdef, param, rest, copy=True) - loss, aux = loss_fn(local_model, config, data, None, None, is_train=True) + loss, aux = _nnx_loss_fn(local_model, config, data, None, None, *_nnx_extra_dpo_args, is_train=True) _, _, new_rest = nnx.split(local_model, nnx.Param, ...) return loss, (aux, new_rest) @@ -575,7 +580,10 @@ def eval_step(model, config, state, data, dropout_rng=None): loss, aux = eval_loss_fn(pure_params, *extra_dpo_args, sparsity_state=batch_stats) else: state = nnx.merge(model, state) # reconstruct TrainStateNNX - loss, aux = loss_fn(state.model, config, data, None, None, is_train=False) + if config.use_dpo: + loss, aux = dpo_loss_fn_nnx(state.model, config, data, None, None, state.reference_model, is_train=False) + else: + loss, aux = loss_fn(state.model, config, data, None, None, is_train=False) mtp_acceptance_rate = 0.0 if config.mtp_eval_target_module > 0: @@ -701,7 +709,7 @@ def train_loop(config, recorder, state=None): step_time_delta = datetime.datetime.now() - last_step_completion last_step_completion = datetime.datetime.now() - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + state_to_save = state if not (config.use_dpo and not config.pure_nnx) else _split_dpo_state(state)[0] checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): @@ -745,7 +753,7 @@ def train_loop(config, recorder, state=None): metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) if config.save_checkpoint_on_completion: - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + state_to_save = state if not (config.use_dpo and not config.pure_nnx) else _split_dpo_state(state)[0] checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator) if checkpoint_manager is not None: # in case the last checkpoint_period checkpoint is still in progress diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index ca90550630..80229b05be 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -225,10 +225,16 @@ def setup_train_loop(config, recorder, devices=None): if config.pure_nnx: # For NNX, the train state is wrapped in the TrainStateNNX module. + # When DPO is enabled, also materialize a frozen reference model alongside + # the policy. Both are constructed by `_create_model_partial()` (which uses + # `config.init_weights_seed`), so the reference starts identical to the + # policy — standard DPO practice. The reference is later overwritten by + # the step-0 checkpoint in `setup_post_setup_state` below. def create_train_state_fn(): model = _create_model_partial() optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) - return train_state_nnx.TrainStateNNX(model, optimizer) + reference_model = _create_model_partial() if config.use_dpo else None + return train_state_nnx.TrainStateNNX(model, optimizer, reference_model=reference_model) init_state_fn = create_train_state_fn else: @@ -316,8 +322,6 @@ def create_train_state_fn(): maxtext_utils.print_shardings_params(state_params, state_mesh_shardings_params, mesh, logical_annotations_params) if config.use_dpo: - if config.pure_nnx: - raise NotImplementedError("DPO is not supported yet by NNX models.") abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training) max_logging.log( "Restoring reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" @@ -342,9 +346,17 @@ def create_train_state_fn(): except FileNotFoundError: step0_restored = None if step0_restored is not None: - # TODO: For pure_nnx, the dpo state manipulation is different. - reference_params = step0_restored["items"].params["params"] - state = _merge_dpo_state(state, reference_params) + if config.pure_nnx: + # step0_restored["items"] is the flat nnx.State of the step-0 TrainStateNNX + # (typically from a non-DPO pre-training run, so its top-level fields are + # `model` and `optimizer` — no `reference_model`). Copy its `model` substate + # into our current state's `reference_model` slot. + step0_state = step0_restored["items"] + step0_model_substate = step0_state["model"] if "model" in step0_state else step0_state + state["reference_model"] = step0_model_substate + else: + reference_params = step0_restored["items"].params["params"] + state = _merge_dpo_state(state, reference_params) else: max_logging.log( "Could not restore reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" diff --git a/tests/unit/dpo_nnx_test.py b/tests/unit/dpo_nnx_test.py new file mode 100644 index 0000000000..461c3cb2aa --- /dev/null +++ b/tests/unit/dpo_nnx_test.py @@ -0,0 +1,215 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""NNX DPO unit tests. + +Covers the NNX-native DPO surface: + * `TrainStateNNX(model, optimizer, reference_model=...)` — reference model + sits alongside policy and is not touched by `apply_gradients`. + * `dpo_loss_fn_nnx(policy, config, data, None, None, reference, is_train)` — + aux structure, identical-model invariant (loss = log(2), reward_accuracy = 0.5). +""" + +import math +import types +import unittest + +import jax +import jax.numpy as jnp +import optax +from flax import nnx + +from maxtext.layers import train_state_nnx +from maxtext.trainers.post_train.dpo import dpo_utils + + +class _MockTransformer(nnx.Module): + """Tiny NNX transformer-shaped module for DPO tests. + + Accepts the same keyword args that `dpo_loss_fn_nnx` passes: + `decoder_input_tokens`, `decoder_positions`, `decoder_segment_ids`, + `enable_dropout`. Other args are tolerated via **kwargs. + """ + + def __init__(self, vocab_size: int, embed_dim: int, rngs: nnx.Rngs): + self.embed = nnx.Embed(vocab_size, embed_dim, rngs=rngs) + self.proj = nnx.Linear(embed_dim, vocab_size, rngs=rngs) + + def __call__( + self, + decoder_input_tokens, + decoder_positions=None, + decoder_segment_ids=None, + enable_dropout=False, + **kwargs, + ): + del decoder_positions, decoder_segment_ids, enable_dropout, kwargs + return self.proj(self.embed(decoder_input_tokens)) + + +def _make_dpo_config(**overrides): + """Build the minimal config surface that `dpo_loss_fn_nnx` reads.""" + base = { + "dpo_label_smoothing": 0.0, + "dpo_beta": 0.1, + "enable_dropout": False, + "num_experts": 1, + "micro_batch_size_to_train_on": 2, + } + base.update(overrides) + return types.SimpleNamespace(**base) + + +def _make_dpo_batch(batch_size=2, seq_len=5): + """Build a tiny DPO-shaped batch. + + `chosen` and `rejected` share the first 2 tokens (common prefix is masked + out in the loss), differ at positions 2 and 3, and are padded at position 4. + """ + chosen = jnp.array([[1, 2, 3, 4, 0]] * batch_size, dtype=jnp.int32) + rejected = jnp.array([[1, 2, 5, 6, 0]] * batch_size, dtype=jnp.int32) + positions = jnp.tile(jnp.arange(seq_len, dtype=jnp.int32), (batch_size, 1)) + segmentation = jnp.array([[1, 1, 1, 1, 0]] * batch_size, dtype=jnp.int32) + return { + "chosen": chosen, + "rejected": rejected, + "chosen_position": positions, + "rejected_position": positions, + "chosen_segmentation": segmentation, + "rejected_segmentation": segmentation, + } + + +class TestTrainStateNNXWithReferenceModel(unittest.TestCase): + """`TrainStateNNX(reference_model=...)` semantics.""" + + def setUp(self): + self.policy = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + self.reference = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(1)) + self.tx = optax.adam(1e-3) + + def test_init_with_reference(self): + optimizer = nnx.Optimizer(self.policy, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.policy, optimizer, reference_model=self.reference) + self.assertIs(state.model, self.policy) + self.assertIs(state.reference_model, self.reference) + self.assertEqual(state.optimizer.step.value, 0) + + def test_init_without_reference_omits_attribute(self): + optimizer = nnx.Optimizer(self.policy, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.policy, optimizer) + self.assertFalse(hasattr(state, "reference_model")) + + def test_apply_gradients_does_not_touch_reference(self): + """Gradient update on policy must leave reference model bit-identical.""" + optimizer = nnx.Optimizer(self.policy, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.policy, optimizer, reference_model=self.reference) + + ref_kernel_before = jnp.asarray(state.reference_model.proj.kernel.value).copy() + + def policy_loss(m): + return jnp.mean(m(jnp.array([[1, 2]])) ** 2) + + grads = nnx.grad(policy_loss)(state.model) + state.apply_gradients(grads) + + ref_kernel_after = jnp.asarray(state.reference_model.proj.kernel.value) + self.assertTrue(jnp.array_equal(ref_kernel_before, ref_kernel_after)) + + +class TestDPOLossFnNNX(unittest.TestCase): + """`dpo_loss_fn_nnx` numerical and structural sanity checks.""" + + def setUp(self): + self.policy = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + # Reference initialized with the same seed to make policy and reference + # bit-identical at construction time. + self.reference = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + self.config = _make_dpo_config() + self.data = _make_dpo_batch() + + def test_aux_has_expected_keys(self): + _, aux = dpo_utils.dpo_loss_fn_nnx( + self.policy, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + expected_keys = { + "intermediate_outputs", + "xent_sum", + "dpo_loss", + "total_weights", + "moe_lb_loss", + "reward_accuracy", + "indexer_loss", + "mtp_loss", + } + self.assertEqual(set(aux.keys()), expected_keys) + self.assertEqual(aux["xent_sum"], 0.0) + self.assertEqual(aux["moe_lb_loss"], 0.0) # num_experts=1 + self.assertEqual(aux["total_weights"], self.data["chosen"].shape[0]) + + def test_identical_policy_and_reference_yields_log2_loss(self): + """When policy == reference, all logratios are 0; with label_smoothing=0 + the per-example loss is `-log(sigmoid(0)) = log(2)`. `reward_accuracy` + uses strict `chosen > rejected`, so equal logratios score 0.0 (no example + is strictly preferred). + """ + loss, aux = dpo_utils.dpo_loss_fn_nnx( + self.policy, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + self.assertAlmostEqual(float(loss), math.log(2.0), places=4) + self.assertAlmostEqual(float(aux["dpo_loss"]), math.log(2.0), places=4) + self.assertAlmostEqual(float(aux["reward_accuracy"]), 0.0, places=4) + + def test_dropout_rng_and_params_args_are_unused(self): + """The 4th and 5th positional args are signature-compat slots for the + Linen dispatcher; passing arbitrary values must not affect the result. + """ + loss_a, _ = dpo_utils.dpo_loss_fn_nnx( + self.policy, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + loss_b, _ = dpo_utils.dpo_loss_fn_nnx( + self.policy, + self.config, + dict(self.data), + jax.random.PRNGKey(123), # dropout_rng — unused + {"params": "garbage"}, # params — unused + self.reference, + is_train=True, + ) + self.assertAlmostEqual(float(loss_a), float(loss_b), places=6) + + def test_value_and_grad_argnums0_only_diffs_policy(self): + """`nnx.value_and_grad(..., argnums=0)` over the policy should produce + finite grads on policy params and not require reference grads. + """ + + def _loss(policy_module): + loss, _ = dpo_utils.dpo_loss_fn_nnx( + policy_module, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + return loss + + grad_fn = nnx.value_and_grad(_loss, argnums=0) + loss, grads = grad_fn(self.policy) + self.assertTrue(jnp.isfinite(loss)) + # Grads is an nnx.State of the policy's nnx.Param leaves; check at least one + # leaf is finite and non-trivially shaped. + leaves = jax.tree_util.tree_leaves(grads) + self.assertGreater(len(leaves), 0) + for leaf in leaves: + self.assertTrue(jnp.all(jnp.isfinite(leaf))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_nnx_test.py b/tests/unit/train_nnx_test.py index 3495b4c557..4340d4e22a 100644 --- a/tests/unit/train_nnx_test.py +++ b/tests/unit/train_nnx_test.py @@ -154,13 +154,6 @@ def test_indexer_dense_warmup_skips_xent(self): self.assertEqual(float(aux["xent_sum"]), 0.0) self.assertEqual(float(loss), 0.0) - def test_vocab_tiling_raises_not_implemented(self): - cfg, ts = _build_state() - cfg.num_vocab_tiling = 4 - data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) - with self.assertRaises(NotImplementedError): - pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=True) - class TestTrainStepNNX(unittest.TestCase): """Cover the NNX branch of train_step (the diff_wrapper / nnx.update path).""" @@ -181,16 +174,6 @@ def test_train_step_returns_state_and_metrics(self): self.assertIn("learning/param_norm", metrics["scalar"]) self.assertTrue(jnp.isfinite(metrics["scalar"]["learning/loss"])) - def test_train_step_dpo_raises_for_nnx(self): - cfg, ts = _build_state() - cfg.use_dpo = True - state_graphdef, state_pure = nnx.split(ts) - data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) - with self.assertRaises(NotImplementedError): - pre_train.train_step( - state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data - ) - def test_train_step_increments_optimizer_step(self): cfg, ts = _build_state() state_graphdef, state_pure = nnx.split(ts)