Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/maxtext/trainers/pre_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,20 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr
one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size)
xent, z_loss = max_utils.cross_entropy_with_logits(logits, one_hot_targets, z_loss=config.z_loss_multiplier)

xent = nn.with_logical_constraint(xent, ("activation_embed_and_logits_batch", "activation_length"))
z_loss = nn.with_logical_constraint(z_loss, ("activation_embed_and_logits_batch", "activation_length"))
xent = sharding.maybe_shard_with_logical(
xent,
("activation_embed_and_logits_batch", "activation_length"),
model.mesh,
config.shard_mode,
debug_sharding=config.debug_sharding,
)
z_loss = sharding.maybe_shard_with_logical(
z_loss,
("activation_embed_and_logits_batch", "activation_length"),
model.mesh,
config.shard_mode,
debug_sharding=config.debug_sharding,
)

# Mask out paddings at the end of each example.
xent = xent * (data["targets_segmentation"] != 0)
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/train_nnx_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import unittest

from flax import nnx
import jax
import jax.numpy as jnp
from maxtext.common import train_state_nnx
from maxtext.trainers.pre_train import train as pre_train
Expand Down Expand Up @@ -59,6 +60,7 @@ class _Cfg:
record_internal_nn_metrics: bool = False
skip_step_on_spikes: bool = False
shard_mode: int = 0 # ShardMode.AUTO
debug_sharding: bool = False
weight_sparsity_n: int = 0
weight_sparsity_m: int = 0

Expand All @@ -73,6 +75,8 @@ class _TinyDecoder(nnx.Module):
def __init__(self, vocab_size: int, hidden: int, rngs: nnx.Rngs):
self.embed = nnx.Embed(vocab_size, hidden, rngs=rngs)
self.proj = nnx.Linear(hidden, vocab_size, rngs=rngs)
# loss_fn shards activations against model.mesh, so the stub needs one.
self.mesh = jax.make_mesh((1, 1, 1, 1), ("data", "fsdp", "expert", "context"))

def __call__(
self,
Expand Down
Loading