From cd08cff6f062579447c8d33c0bbef5cc7cfbf9f0 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 19 Jun 2026 20:14:04 +0000 Subject: [PATCH] Reshard loss under explicit mesh on the NNX path loss_fn's NNX branch constrained xent/z_loss with raw nn.with_logical_constraint, which keeps the size-1 context axis, while the model shards activations via create_sharding (remove_size_one_mesh_axis drops context). Under an explicit mesh the mismatch is a hard assert. Use sharding.maybe_shard_with_logical, as the Linen branch already does: it builds the sharding via create_sharding (dropping size-1 context, so it matches the array) and reshards instead of asserting under explicit mode. --- src/maxtext/trainers/pre_train/train.py | 16 ++++++++++++++-- tests/unit/train_nnx_test.py | 4 ++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 047ddb97a8..e6f73928a0 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -221,8 +221,20 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size) xent, z_loss = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier) - xent = nn.with_logical_constraint(xent, ("activation_embed_and_logits_batch", "activation_length")) - z_loss = nn.with_logical_constraint(z_loss, ("activation_embed_and_logits_batch", "activation_length")) + xent = sharding.maybe_shard_with_logical( + xent, + ("activation_embed_and_logits_batch", "activation_length"), + model.mesh, + config.shard_mode, + debug_sharding=config.debug_sharding, + ) + z_loss = sharding.maybe_shard_with_logical( + z_loss, + ("activation_embed_and_logits_batch", "activation_length"), + model.mesh, + config.shard_mode, + debug_sharding=config.debug_sharding, + ) # Mask out paddings at the end of each example. xent = xent * (data["targets_segmentation"] != 0) diff --git a/tests/unit/train_nnx_test.py b/tests/unit/train_nnx_test.py index ebeededbd7..6467c6f196 100644 --- a/tests/unit/train_nnx_test.py +++ b/tests/unit/train_nnx_test.py @@ -23,6 +23,7 @@ import unittest from flax import nnx +import jax import jax.numpy as jnp from maxtext.common import train_state_nnx from maxtext.trainers.pre_train import train as pre_train @@ -59,6 +60,7 @@ class _Cfg: record_internal_nn_metrics: bool = False skip_step_on_spikes: bool = False shard_mode: int = 0 # ShardMode.AUTO + debug_sharding: bool = False weight_sparsity_n: int = 0 weight_sparsity_m: int = 0 @@ -73,6 +75,8 @@ class _TinyDecoder(nnx.Module): def __init__(self, vocab_size: int, hidden: int, rngs: nnx.Rngs): self.embed = nnx.Embed(vocab_size, hidden, rngs=rngs) self.proj = nnx.Linear(hidden, vocab_size, rngs=rngs) + # loss_fn shards activations against model.mesh, so the stub needs one. + self.mesh = jax.make_mesh((1, 1, 1, 1), ("data", "fsdp", "expert", "context")) def __call__( self,