Fix explicit-mesh loss sharding assert on the NNX path #4208
Open
ecnal-cienet wants to merge 1 commit into
Codecov Report: ✅ All modified and coverable lines are covered by tests.
21b80d3 to ce97c7d
loss_fn's NNX branch constrained xent/z_loss with raw nn.with_logical_constraint, which keeps the size-1 context axis, while the model shards activations via create_sharding (remove_size_one_mesh_axis drops context). Under an explicit mesh the mismatch is a hard assert. Use sharding.maybe_shard_with_logical, as the Linen branch already does: it builds the sharding via create_sharding (dropping size-1 context, so it matches the array) and reshards instead of asserting under explicit mode.
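To make the mismatch concrete, here is a minimal pure-Python sketch. It does not call MaxText or Flax: `drop_size_one_axes` is a hypothetical stand-in for `remove_size_one_mesh_axis`, `specs_match` imitates the exact-match check that explicit mode enforces, and the mesh sizes for `data`, `fsdp`, and `expert` are made up for illustration (only `context` having size 1 comes from this PR).

```python
# Hypothetical mesh shape; only `context` == 1 is taken from the PR.
MESH_SHAPE = {"data": 2, "fsdp": 4, "expert": 2, "context": 1}

# Mesh axes the logical batch axis resolves to, as quoted in the PR description.
RAW_SPEC = ("data", "fsdp", "expert", "context")


def drop_size_one_axes(spec, mesh_shape):
    """Toy stand-in for remove_size_one_mesh_axis: keep only axes of size > 1."""
    return tuple(axis for axis in spec if mesh_shape[axis] > 1)


def specs_match(constraint_spec, array_spec):
    """Imitates explicit mode, where the constraint must equal the array's spec."""
    return constraint_spec == array_spec


# How the logits (and therefore xent) arrive from the model.
array_spec = drop_size_one_axes(RAW_SPEC, MESH_SHAPE)

# Raw constraint keeps `context`, so it no longer matches the array.
raw_ok = specs_match(RAW_SPEC, array_spec)

# Building the constraint the same way as the model restores the match.
fixed_ok = specs_match(drop_size_one_axes(RAW_SPEC, MESH_SHAPE), array_spec)

print(array_spec)        # ('data', 'fsdp', 'expert')
print(raw_ok, fixed_ok)  # False True
```

Under an auto mesh the first comparison would only be a hint; in this toy model it is the comparison that fails under explicit mode.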
ce97c7d to cd08cff
Problem
Training `deepseek3-671b-batchsplit` (which sets `shard_mode=explicit`) with `pure_nnx=True` crashes at the first `train_step` compile.
Root cause
`loss_fn` has two branches. The Linen branch already constrains `xent`/`z_loss` through MaxText's `sharding.maybe_shard_with_logical`, but the NNX branch (taken when `pure_nnx=True`) was still using raw Flax `nn.with_logical_constraint`.
The two resolve the logical axes differently:
- `nn.with_logical_constraint` maps `activation_embed_and_logits_batch` to `('data','fsdp','expert','context')`, so it keeps the size-1 `context` axis.
- The model shards activations via `create_sharding`, which calls `remove_size_one_mesh_axis` and drops `context` (size 1 here), so logits, and therefore `xent`, arrive sharded `('data','fsdp','expert')`.
Under `auto` axes the constraint is just a hint, so the mismatch was harmless. Under `explicit` axes, `with_sharding_constraint` becomes a hard assert that the spec matches exactly, so the extra `context` axis trips it.
Fix
Convert the NNX branch to `sharding.maybe_shard_with_logical`, mirroring the already-fixed Linen branch. That helper builds the sharding via `create_sharding` (so size-1 `context` is dropped, matching the array) and, under an explicit mesh, calls `jax.sharding.reshard` instead of the asserting constraint. With `context` of size 1 the reshard is a no-op: same physical layout, just a matching spec.
Also updates `tests/unit/train_nnx_test.py`: the NNX branch now reads `model.mesh` and `config.debug_sharding`, so the tiny test stub gets a single-device mesh and the config stub gets `debug_sharding`.
Tests
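For the unit-test change described above, the stubs might look roughly like the following. This is a sketch only, not the code in `tests/unit/train_nnx_test.py`: the names `model_stub` and `config_stub`, the axis name `"data"`, and the value `False` for `debug_sharding` are all assumptions made for illustration.

```python
import types

import jax
import numpy as np
from jax.sharding import Mesh

# Single-device mesh: one device along one axis of size 1.
# The axis name "data" is a placeholder, not taken from the PR.
devices = np.array(jax.devices()[:1])
single_device_mesh = Mesh(devices, ("data",))

# Hypothetical stubs exposing only the attributes the NNX branch now reads.
model_stub = types.SimpleNamespace(mesh=single_device_mesh)
config_stub = types.SimpleNamespace(debug_sharding=False)  # assumed default

print(model_stub.mesh.axis_names, model_stub.mesh.devices.size)
```

A single-device mesh keeps the test independent of the host's accelerator count while still giving the loss code a real `Mesh` object to read.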
Before Fix (NNX, failed): https://cloudlogging.app.goo.gl/XF3eVFANEPx43XBG7
After Fix (NNX, passed): https://cloudlogging.app.goo.gl/DRsc4TkhtNm5nV929
Unit: `tests/unit/train_nnx_test.py` passes (7 passed).
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.