diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 047ddb97a8..e6f73928a0 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -221,8 +221,20 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size) xent, z_loss = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier) - xent = nn.with_logical_constraint(xent, ("activation_embed_and_logits_batch", "activation_length")) - z_loss = nn.with_logical_constraint(z_loss, ("activation_embed_and_logits_batch", "activation_length")) + xent = sharding.maybe_shard_with_logical( + xent, + ("activation_embed_and_logits_batch", "activation_length"), + model.mesh, + config.shard_mode, + debug_sharding=config.debug_sharding, + ) + z_loss = sharding.maybe_shard_with_logical( + z_loss, + ("activation_embed_and_logits_batch", "activation_length"), + model.mesh, + config.shard_mode, + debug_sharding=config.debug_sharding, + ) # Mask out paddings at the end of each example. xent = xent * (data["targets_segmentation"] != 0) diff --git a/tests/unit/train_nnx_test.py b/tests/unit/train_nnx_test.py index ebeededbd7..6467c6f196 100644 --- a/tests/unit/train_nnx_test.py +++ b/tests/unit/train_nnx_test.py @@ -23,6 +23,7 @@ import unittest from flax import nnx +import jax import jax.numpy as jnp from maxtext.common import train_state_nnx from maxtext.trainers.pre_train import train as pre_train @@ -59,6 +60,7 @@ class _Cfg: record_internal_nn_metrics: bool = False skip_step_on_spikes: bool = False shard_mode: int = 0 # ShardMode.AUTO + debug_sharding: bool = False weight_sparsity_n: int = 0 weight_sparsity_m: int = 0 @@ -73,6 +75,8 @@ class _TinyDecoder(nnx.Module): def __init__(self, vocab_size: int, hidden: int, rngs: nnx.Rngs): self.embed = nnx.Embed(vocab_size, hidden, rngs=rngs) self.proj = nnx.Linear(hidden, vocab_size, rngs=rngs) + # loss_fn shards activations against model.mesh, so the stub needs one. + self.mesh = jax.make_mesh((1, 1, 1, 1), ("data", "fsdp", "expert", "context")) def __call__( self,