609 changes: 609 additions & 0 deletions src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py

Large diffs are not rendered by default.

581 changes: 581 additions & 0 deletions src/maxtext/checkpoint_conversion/linen_nnx_converter.py

Large diffs are not rendered by default.

@@ -87,11 +87,12 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
   devices_array = maxtext_utils.create_device_mesh(cfg)
   mesh = Mesh(devices_array, cfg.mesh_axes)

+  # This conversion script reads paxml-format weights and emits a Linen-format
+  # MaxText checkpoint (downstream uses `.params['params']`, `.opt_state.mu['params']`,
+  # `.opt_state.nu['params']` keystr paths; the keystr_map below targets the Linen
+  # tree shape). Use the Linen path regardless of pure_nnx.
   quant = quantizations.configure_quantization(cfg)
-  if cfg.pure_nnx:
-    raise NotImplementedError("Pure NNX support has not been implemented yet.")
-  else:
-    model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
+  model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
   learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
   tx = optimizers.get_optimizer(cfg, learning_rate_schedule)
@@ -102,11 +103,7 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
       cfg.checkpoint_period,
   )

-  if cfg.pure_nnx:
-    # NNX has a different function to init the training state.
-    raise NotImplementedError("Pure NNX support has not been implemented yet.")
-  else:
-    init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
+  init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
   state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn)
   max_logging.log("start")
   max_utils.print_mem_stats("After params initialized")
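Note: a minimal sketch of the Linen tree shape those keystr paths assume, using a toy params dict and optax Adam (not MaxText's real tree; `opt_state[0]` is Adam's `ScaleByAdamState`):

```python
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

params = {"params": {"decoder": {"kernel": jnp.zeros((2, 2))}}}
state = train_state.TrainState.create(apply_fn=None, params=params, tx=optax.adam(1e-3))

# Adam's first moment mirrors the params tree, so keystr paths like
# .opt_state.mu['params'] resolve against this same shape.
mu = state.opt_state[0].mu
print(jax.tree_util.tree_map(jnp.shape, mu["params"]))  # {'decoder': {'kernel': (2, 2)}}
```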
36 changes: 28 additions & 8 deletions src/maxtext/common/checkpointing.py
@@ -20,6 +20,7 @@
 from absl import flags
 import datetime
 from etils import epath
+from flax import nnx
 from flax.training import train_state
 import jax
 from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
@@ -536,7 +537,7 @@ def load_state_if_possible(
     load_parameters_from_path: str,
     load_full_state_from_path: str,
     checkpoint_storage_concurrent_gb: int,
-    abstract_unboxed_pre_state: train_state.TrainState,
+    abstract_unboxed_pre_state: train_state.TrainState | nnx.State,
     enable_single_replica_ckpt_restoring: bool | None = False,
     dataset_type: str | None = "tfds",
     step: int = -1,  # -1 means latest
@@ -604,9 +605,14 @@ def map_to_pspec(data):
       )
       ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)

-    restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state)
+    # Convert nnx.State to pure dict to match how checkpoints are saved for NNX.
+    restore_target = abstract_unboxed_pre_state
+    if isinstance(abstract_unboxed_pre_state, nnx.State):
+      restore_target = abstract_unboxed_pre_state.to_pure_dict()
+
+    restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target)
     checkpoint_args = ocp.args.PyTreeRestore(
-        item=abstract_unboxed_pre_state,
+        item=restore_target,
         restore_args=restore_args,
         partial_restore=True,
     )
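A minimal sketch of the `to_pure_dict()` step above, assuming current `flax.nnx` APIs (toy module, not MaxText's Transformer):

```python
import jax
import jax.numpy as jnp
from flax import nnx

class Tiny(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(2, 2, rngs=rngs)

_, state = nnx.split(Tiny(nnx.Rngs(0)))
pure = state.to_pure_dict()  # plain nested dict of arrays, matching the saved layout
print(jax.tree_util.tree_map(jnp.shape, pure))
```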
@@ -620,9 +626,7 @@ def map_to_pspec(data):
         (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager),
     ):
       return (
-          checkpoint_manager.restore(
-              step, args=Composite(state=checkpoint_args)
-          ).state,
+          checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state,
           None,
       )
   # Case 2: Matches if dataset type is "grain" and the data iterator is not a
@@ -647,9 +651,14 @@ def map_to_pspec(data):
     return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None)

   if load_parameters_from_path != "":
+    if isinstance(abstract_unboxed_pre_state, nnx.State):
+      _, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...)
+    else:
+      params = abstract_unboxed_pre_state.params
+
     restored_params = load_params_from_path(
         load_parameters_from_path,
-        abstract_unboxed_pre_state.params,
+        params,
         checkpoint_storage_concurrent_gb,
         use_ocdbt=use_ocdbt,
         use_zarr3=use_zarr3,
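The `nnx.split(..., nnx.Param, ...)` call filters only trainable params out of the model; a sketch under the same assumption (the trailing `...` filter collects everything else):

```python
from flax import nnx

model = nnx.Linear(2, 2, rngs=nnx.Rngs(0))
graphdef, params, rest = nnx.split(model, nnx.Param, ...)
# `params` plays the role of TrainState.params for the Linen branch above;
# `rest` holds non-Param state (e.g., RNG counts, batch stats) excluded from restore.
```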
@@ -741,7 +750,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step
   # Determine the effective step for saving a checkpoint.
   # If 'step' is not provided, this call is for a potential final checkpoint,
   # so use the last completed step from the state.
-  actual_step = (int(state.step) - 1) if step is None else int(step)
+  if step is not None:
+    actual_step = int(step)
+  else:
+    if config.pure_nnx:
+      actual_step = int(state.optimizer.step) - 1
+    else:
+      # Linen TrainState has a .step attribute.
+      actual_step = int(state.step) - 1
+
+  if config.pure_nnx:
+    # Convert nnx.State to dict.
+    state = state.to_pure_dict()

   # Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic.
   # This occurs if this function was called:
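The step-counter split above exists because the two state flavors track steps differently; a hedged sketch, assuming `nnx.Optimizer` is the NNX-side training state (its signature varies across flax releases):

```python
import optax
from flax import nnx

model = nnx.Linear(2, 2, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.sgd(1e-2), wrt=nnx.Param)

# NNX path: the step counter hangs off the optimizer...
print(int(optimizer.step.value))  # 0 before any update
# ...whereas a Linen TrainState exposes it directly as `state.step`.
```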
7 changes: 7 additions & 0 deletions src/maxtext/configs/base.yml
@@ -559,6 +559,13 @@ logical_axis_rules: [
   ['tokens_per_page', []],
   ['paged_kv_head_dim_size', []],
   # ==========================================
+  # Pipeline Parallelism
+  # ==========================================
+  ['layers_outside_pipeline', []],
+  ['layers_per_stage', []],
+  ['num_activations', []],
+  ['circular_repeats', []],
+  # ==========================================
   # Deprecated / Scheduled for Removal
   # ==========================================
   ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']],
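These new rules map the pipeline-parallel logical axes to no mesh axes, i.e. replicated. A sketch of how such a rule resolves, assuming Flax's logical partitioning helper (toy rule set, not the full MaxText table):

```python
from flax import linen as nn

rules = (("embed", "fsdp"), ("layers_per_stage", ()))
spec = nn.logical_to_mesh_axes(("embed", "layers_per_stage"), rules)
print(spec)  # expected PartitionSpec('fsdp', None): the pipeline axis stays unsharded
```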
3 changes: 1 addition & 2 deletions src/maxtext/configs/pyconfig_deprecated.py
@@ -195,10 +195,9 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) -


def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool):
+  del enable_nnx  # NNX vocab tiling supported via vocab_tiling_nnx_loss in vocabulary_tiling.py
   if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0:
     raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
-  if num_vocab_tiling > 1 and enable_nnx:  # TODO (chengnuojin) enable vocab tiling on NNX after NNX migration
-    raise ValueError("We currently don't support vocab tiling on NNX module.")


def validate_rampup_batch_size(batch_size_start, batch_size_end, batch_size_increment, global_rampup_samples):
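A worked instance of the surviving divisibility check (hypothetical config values):

```python
per_device_batch_size, max_target_length = 4, 2048
tokens_per_device = per_device_batch_size * max_target_length  # 8192 tokens to tile

assert tokens_per_device % 16 == 0  # num_vocab_tiling=16 passes validation
assert tokens_per_device % 5 != 0   # num_vocab_tiling=5 would raise ValueError
```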
3 changes: 1 addition & 2 deletions src/maxtext/configs/types.py
@@ -2736,8 +2736,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
         and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0
     ):
       raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
-    if self.num_vocab_tiling > 1 and self.enable_nnx:
-      raise ValueError("We currently don't support vocab tiling on NNX module.")
+    # Vocab tiling on NNX is now supported via vocab_tiling_nnx_loss in vocabulary_tiling.py.
     if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring":
       if "gpu" not in self.hardware:
         raise ValueError(
37 changes: 16 additions & 21 deletions src/maxtext/experimental/rl/grpo_trainer.py
@@ -542,29 +542,28 @@ def setup_train_loop(
     - eval_data_iterator: The iterator for the evaluation dataset (or None).
     - state: The initialized training state.
   """
+  # The GRPO RL trainer is Linen-shaped end-to-end (state.params accesses below,
+  # state_mesh_shardings.params, and the inference path through MaxEngine, which is
+  # Linen-only). Run on the Linen path regardless of pure_nnx; warn the user, since
+  # NNX-format checkpoints will mismatch at restore time.
+  if config.pure_nnx or config_inference.pure_nnx:
+    max_logging.log(
+        "WARNING: GRPO RL trainer does not yet support pure_nnx natively; "
+        "running on the Linen path. NNX-format checkpoints will not load correctly here."
+    )
   with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT):
     max_logging.log("Training mesh used for the workload")
     num_inference_devices = config.inference_devices_per_replica * config.inference_replicas
     training_devices = jax.devices()[num_inference_devices:]
-    if config.pure_nnx:
-      raise NotImplementedError("Pure NNX support has not been implemented yet.")
-    else:
-      model = mt.from_config(config, devices=training_devices)
+    model = mt.from_config(config, devices=training_devices)
     mesh = model.mesh
     max_logging.log("Inference mesh used for the workload")
     inference_devices = jax.devices()[:num_inference_devices]
-    if config_inference.pure_nnx:
-      raise NotImplementedError("Pure NNX support has not been implemented yet.")
-    else:
-      inference_model = mt.from_config(config_inference, devices=inference_devices)
+    inference_model = mt.from_config(config_inference, devices=inference_devices)
     inference_mesh = inference_model.mesh
     init_rng = jax.random.PRNGKey(config.init_weights_seed)
     learning_rate_schedule, tx = train_utils.create_training_optimizer(config, model)
-    if config.pure_nnx:
-      # NNX has a different function to init the training state.
-      raise NotImplementedError("Pure NNX support has not been implemented yet.")
-    else:
-      init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng)
+    init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng)
     checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn)

   with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION):
@@ -573,14 +572,10 @@ def setup_train_loop(
         data_iterator, config, mesh, checkpoint_manager, init_state_fn
     )

-  # create inference_state_mesh_shardings from inference_mesh
-  if config_inference.pure_nnx:
-    # NNX has a different function to init the training state.
-    raise NotImplementedError("Pure NNX support has not been implemented yet.")
-  else:
-    init_inference_state_fn = functools.partial(
-        maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng
-    )
+  # create inference_state_mesh_shardings from inference_mesh (Linen path; see warning above)
+  init_inference_state_fn = functools.partial(
+      maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng
+  )
   inference_state_mesh_shardings = maxtext_utils.get_abstract_state(
       config_inference, inference_mesh, init_inference_state_fn, is_training=False
   )[2]
22 changes: 7 additions & 15 deletions src/maxtext/inference/maxengine/maxengine.py
@@ -111,12 +111,12 @@ def __init__(self, config: Any, devices: Any | None = None):
     devices_array = maxtext_utils.create_device_mesh(config=config, devices=devices)
     self._mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)

-    # Model and Optimizer definition
+    # Model and Optimizer definition.
+    # MaxEngine uses Linen-shaped state (state.params, state_mesh_shardings.params,
+    # state.opt_state) and serves Linen-format inference checkpoints. Use the Linen path
+    # regardless of pure_nnx; the flag affects training, not inference serving.
     quant = quantizations.configure_quantization(config)
-    if config.pure_nnx:
-      raise NotImplementedError("Pure NNX support has not been implemented yet.")
-    else:
-      self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL)
+    self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL)
     self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None))

     self.abstract_params = None
@@ -232,11 +232,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar
     rng1, rng2, rng3 = jax.random.split(rng, 3)
     if params:
       print("Resharding given params")
-      if self.config.pure_nnx:
-        # NNX has a different function to init the training state.
-        raise NotImplementedError("Pure NNX support has not been implemented yet.")
-      else:
-        init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng)
+      init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng)
       _, self.state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state(
           self.config, self._mesh, init_state_fn, False
       )
@@ -245,11 +241,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar
       state = maxtext_utils.init_decode_state(None, params)
       state = max_utils.unbox_logicallypartioned(state)
     else:
-      if self.config.pure_nnx:
-        # NNX has a different function to init the training state.
-        raise NotImplementedError("Pure NNX support has not been implemented yet.")
-      else:
-        init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1)
+      init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1)
       state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(self.config, self._mesh, None, init_state_fn)
     # pylint: disable=isinstance-second-argument-not-valid-type
     self.abstract_params = jax.tree_util.tree_map(
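The `abstract_params` tree above holds shapes and shardings rather than real weights; a sketch of the underlying idea using `jax.eval_shape` (hypothetical `init_fn`; MaxText's `get_abstract_state` wraps more machinery):

```python
import jax
import jax.numpy as jnp

def init_fn(rng):
  del rng  # a real init would consume this
  return {"kernel": jnp.zeros((4, 4))}

# No arrays are materialized: only shapes and dtypes are traced.
abstract = jax.eval_shape(init_fn, jax.random.PRNGKey(0))
print(abstract)  # {'kernel': ShapeDtypeStruct(shape=(4, 4), dtype=float32)}
```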
4 changes: 2 additions & 2 deletions src/maxtext/layers/attentions.py
@@ -525,14 +525,14 @@ def __init__(
     elif self.is_qwen3_next:
       self.query_norm = Qwen3NextRMSNorm(
           num_features=self.config.head_dim,
-          eps=self.config.normalization_layer_epsilon,
+          epsilon=self.config.normalization_layer_epsilon,
           dtype=self.config.dtype,
           weight_dtype=self.config.weight_dtype,
           rngs=self.rngs,
       )
       self.key_norm = Qwen3NextRMSNorm(
           num_features=self.config.head_dim,
-          eps=self.config.normalization_layer_epsilon,
+          epsilon=self.config.normalization_layer_epsilon,
           dtype=self.config.dtype,
           weight_dtype=self.config.weight_dtype,
           rngs=self.rngs,
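The fix renames the keyword to match `Qwen3NextRMSNorm`'s signature; for reference, a sketch of what `epsilon` guards in an RMSNorm-style op (assumed formula, not the module's exact code):

```python
import jax.numpy as jnp

def rms_norm(x, scale, epsilon=1e-6):
  # epsilon keeps the inverse root finite when the mean square underflows to zero
  mean_sq = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
  return x * scale / jnp.sqrt(mean_sq + epsilon)
```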
4 changes: 2 additions & 2 deletions src/maxtext/layers/moe.py
@@ -2185,8 +2185,8 @@ def __call__(
     w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
     w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)

-    if self.per_expert_scale is not None:
-      wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None]
+    if self.per_expert_scale is not None:
+      wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None]

     if self.wi_0_sparsity_module is not None:
       _, w0_kernel = self.wi_0_sparsity_module(jnp.zeros_like(w0_kernel), w0_kernel)
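A sketch of the broadcast the `[:, None, None]` indexing performs, assuming a `[num_experts, hidden, embed]` kernel layout:

```python
import jax.numpy as jnp

num_experts, hidden, embed = 4, 8, 16
wo_kernel = jnp.ones((num_experts, hidden, embed))
per_expert_scale = jnp.arange(num_experts, dtype=jnp.float32)  # shape [4]
scaled = wo_kernel * per_expert_scale[:, None, None]           # broadcast to [4, 8, 16]
print(scaled[2, 0, 0])  # 2.0: every weight of expert 2 is scaled by its scalar
```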
42 changes: 37 additions & 5 deletions src/maxtext/layers/nnx_decoders.py
@@ -35,6 +35,7 @@
     MODEL_MODE_TRAIN,
     Config,
     DecoderBlockType,
+    MultimodalInput,
     ShardMode,
 )
 from maxtext.inference import page_manager
@@ -422,8 +423,16 @@ def pure_layer_fn(state_in, y_in):
       out = merged_layer(y_in, **kwargs)
       return out, nnx.state(merged_layer)

-    checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
-    out, new_state = checkpointed_fn(state, y)
+    # Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen
+    # mutable scope. jax.checkpoint re-traces the scan body during backward (remat),
+    # but the Linen scope retains JAX tracers from the first trace, causing
+    # UnexpectedTracerError. Skip checkpoint for these quantization types.
+    uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu")
+    if uses_linen_fp8_mutable_state:
+      out, new_state = pure_layer_fn(state, y)
+    else:
+      checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse)
+      out, new_state = checkpointed_fn(state, y)
     nnx.update(layer, new_state)

     return out
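For context, a toy demonstration of the `jax.checkpoint` behavior being toggled: gradients match with and without remat, only the backward-pass memory/recompute trade-off changes (pure function, default policy):

```python
import jax
import jax.numpy as jnp

def layer(w, x):
  return jnp.tanh(x @ w)

w, x = jnp.ones((8, 8)), jnp.ones((8,))
plain = jax.grad(lambda w: layer(w, x).sum())(w)
remat = jax.grad(lambda w: jax.checkpoint(layer)(w, x).sum())(w)
assert jnp.allclose(plain, remat)  # same values; activations recomputed on backward
```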
@@ -502,13 +511,12 @@ def layer_fn(carry, scanned_vars):
           return new_carry, (new_current_state, updated_kv)
         return new_carry, new_current_state

-      layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
-
       if use_kv:
         # If kv_caches is provided (e.g., from vLLM), we CANNOT use jax.lax.scan
         # because scanning requires stacking the kv_caches list, which creates a copy
         # and breaks the in-place memory updates required by vLLM's PagedAttention.
         # Therefore, we must unroll the loop statically when kv_caches is provided.
+        layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)

         # kv_caches_stacked is actually the original kv_caches list in this new flow
         kv_caches_list = kv_caches_stacked
@@ -530,7 +538,24 @@ def layer_fn(carry, scanned_vars):
         # inference with vLLM, parameters do not change and we don't need intermediates.
         return current_carry, layers, None
       else:
-        final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state))
+        # Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen
+        # mutable scope. jax.lax.scan traces the body function and Linen's setup() creates
+        # intermediate tracer values (amax_history float32[1024]) that escape the scan scope,
+        # causing UnexpectedTracerError. Use a Python for loop instead for these types.
+        uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu")
+        if uses_linen_fp8_mutable_state:
+          carry = x_in
+          per_layer_states = []
+          for i in range(length):
+            current_params = jax.tree.map(lambda x, i=i: x[i], params)
+            current_state = jax.tree.map(lambda x, i=i: x[i], state)
+            carry, new_state_i = layer_fn(carry, (current_params, current_state))
+            per_layer_states.append(new_state_i)
+          final_carry = carry
+          scanned_state = jax.tree.map(lambda *xs: jnp.stack(list(xs)), *per_layer_states)
+        else:
+          layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
+          final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state))
         returned_kv_stacked = None

     if scan_axis != 0:
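The unrolled branch relies on a Python loop being semantically equivalent to `jax.lax.scan` over the stacked layer axis; a toy check of that equivalence (3 layers, dummy math):

```python
import jax
import jax.numpy as jnp

def layer_fn(carry, w):
  return jnp.tanh(carry @ w), carry.sum()  # (new_carry, per-layer output)

params = jnp.ones((3, 4, 4))  # leading axis is the layer axis
x = jnp.ones((4,))

scanned_carry, scanned_ys = jax.lax.scan(layer_fn, x, params)

carry, ys = x, []
for i in range(3):  # manual unroll, as in the FP8 fallback path
  carry, y = layer_fn(carry, params[i])
  ys.append(y)

assert jnp.allclose(scanned_carry, carry)
assert jnp.allclose(scanned_ys, jnp.stack(ys))
```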
@@ -963,7 +988,14 @@ def __call__(
       audio_embeddings: None | jnp.ndarray = None,
       audio_masks: None | jnp.ndarray = None,
       deepstack_visual_embeds: None | list[jnp.ndarray] = None,
+      multimodal_input: None | MultimodalInput = None,
   ):
+    if multimodal_input is not None:
+      image_embeddings = multimodal_input.image_embeddings
+      image_masks = multimodal_input.image_masks
+      audio_embeddings = multimodal_input.audio_embeddings
+      audio_masks = multimodal_input.audio_masks
+      bidirectional_mask = multimodal_input.bidirectional_mask
     cfg = self.config
     assert decoder_input_tokens.ndim == 2  # [batch, len]
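The unpacking above implies the container's field names; a hedged sketch of the shape of `MultimodalInput` (the real definition is imported alongside `Config` at the top of the file and may differ):

```python
import dataclasses
import jax.numpy as jnp

@dataclasses.dataclass
class MultimodalInput:
  # One optional array per modality, mirroring the former keyword arguments.
  image_embeddings: jnp.ndarray | None = None
  image_masks: jnp.ndarray | None = None
  audio_embeddings: jnp.ndarray | None = None
  audio_masks: jnp.ndarray | None = None
  bidirectional_mask: jnp.ndarray | None = None
```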