Skip to content

feat: migrate pipeline to nnx#2885

Open
mesakhcienet wants to merge 4 commits intoAI-Hypercomputer:mainfrom
CIeNET-International:test/pipeline-scan-nnx
Open

feat: migrate pipeline to nnx#2885
mesakhcienet wants to merge 4 commits intoAI-Hypercomputer:mainfrom
CIeNET-International:test/pipeline-scan-nnx

Conversation

@mesakhcienet
Copy link
Copy Markdown
Contributor

@mesakhcienet mesakhcienet commented Dec 24, 2025

Description

implement nnx-based pipeline.

This PR extends PR#2831

Main changes:

  1. nnx_decoders.py: implementing the missing pipeline logic in nnx_decoders.py.
  2. pipeline.py : add a new class NNXPipeline, which is a nnx-based pipeline class.

Tests

we run the pipeline process with command below:

MODEL_NAME=llama2-7b
python -m MaxText.train src/maxtext/configs/base.yml \
    run_name=pipeline_test_${MODEL_NAME}_nnx \
    base_output_directory=/dev/shm/pipeline_test_nnx \
    model_name=${MODEL_NAME}\
    dataset_type=synthetic \
    steps=15 \
    debug_sharding=true \
    per_device_batch_size=2 \
    max_target_length=32 \
    ici_pipeline_parallelism=2 \
    num_pipeline_microbatches=4 \
    num_layers_per_pipeline_stage=2 \
    enable_checkpointing=false \
    enable_nnx=true \
    pure_nnx_decoder=true \
    scan_layers_per_stage=false \
    async_checkpointing=false > nnx-porting-log/pipeline/custom_${MODEL_NAME}.log 2>&1

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@mesakhcienet mesakhcienet changed the title core: migrate pipeline to nnx feat: migrate pipeline to nnx Dec 24, 2025
@mesakhcienet mesakhcienet force-pushed the test/pipeline-scan-nnx branch 8 times, most recently from 6875da8 to f34b1a3 Compare January 15, 2026 23:43
@codecov
Copy link
Copy Markdown

codecov Bot commented Jan 19, 2026

Codecov Report

❌ Patch coverage is 25.66667% with 223 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/layers/nnx_decoders.py 21.60% 173 Missing and 23 partials ⚠️
src/maxtext/layers/decoders.py 44.44% 21 Missing and 4 partials ⚠️
src/maxtext/layers/nnx_wrappers.py 60.00% 1 Missing and 1 partial ⚠️

📢 Thoughts on this report? Let us know!

@mesakhcienet mesakhcienet force-pushed the test/pipeline-scan-nnx branch 4 times, most recently from 12a3907 to 2c16599 Compare January 28, 2026 08:04
@mesakhcienet mesakhcienet force-pushed the test/pipeline-scan-nnx branch 2 times, most recently from 64dc147 to 9e4518e Compare February 2, 2026 01:58
@mesakhcienet mesakhcienet force-pushed the test/pipeline-scan-nnx branch from 631a73e to ac97a1d Compare March 2, 2026 08:48
@mesakhcienet mesakhcienet changed the base branch from main to xibin/nnx_all March 2, 2026 08:48
@ecnal-cienet ecnal-cienet force-pushed the xibin/nnx_all branch 12 times, most recently from 1849f0b to 669dc01 Compare March 3, 2026 19:59
@mesakhcienet mesakhcienet force-pushed the test/pipeline-scan-nnx branch 7 times, most recently from 618de58 to e7656b2 Compare March 11, 2026 06:29
Copy link
Copy Markdown
Collaborator

@bvandermoon bvandermoon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gobbleturk what testing do you recommend for migrating pipeline parallelism to NNX? I'll send over an internal doc @hsuan-lun-chiang, @mesakhcienet, and others put together that shows the tests they have already run

@bvandermoon
Copy link
Copy Markdown
Collaborator

@gobbleturk what testing do you recommend for migrating pipeline parallelism to NNX? I'll send over an internal doc @hsuan-lun-chiang, @mesakhcienet, and others put together that shows the tests they have already run

@NuojCheng any thoughts here?

Copy link
Copy Markdown
Collaborator

@NuojCheng NuojCheng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some additional train compile test for pipeline NNX migration:

@NuojCheng
Copy link
Copy Markdown
Collaborator

There are also some linen usage in pipeline_utils.py, e.g.

I don't see them get updated in this PR but I think they probably should be updated?

Another thing is the usage of function in

# TODO(chengnuojin) Remove this function and its usage after pipeline nnx migration
def remove_logically_partition(weights):
"""Removes LogicallyPartitioned wrapper from weights."""
def _remove_logically_partition_leaf(v):
return getattr(v, "value") if isinstance(v, LogicallyPartitioned) else v
return jax.tree.map(
_remove_logically_partition_leaf,
weights,
is_leaf=lambda v: isinstance(v, LogicallyPartitioned),
)
. I suspect NNX migration can help us get rid of using this function since it is mostly dealing with linen wrapper troubles. Take a look on this if you can. Thank you for the hard work!

@mesakhcienet
Copy link
Copy Markdown
Contributor Author

There are also some linen usage in pipeline_utils.py, e.g.

I don't see them get updated in this PR but I think they probably should be updated?

Another thing is the usage of function in

# TODO(chengnuojin) Remove this function and its usage after pipeline nnx migration
def remove_logically_partition(weights):
"""Removes LogicallyPartitioned wrapper from weights."""
def _remove_logically_partition_leaf(v):
return getattr(v, "value") if isinstance(v, LogicallyPartitioned) else v
return jax.tree.map(
_remove_logically_partition_leaf,
weights,
is_leaf=lambda v: isinstance(v, LogicallyPartitioned),
)

. I suspect NNX migration can help us get rid of using this function since it is mostly dealing with linen wrapper troubles. Take a look on this if you can. Thank you for the hard work!

As far as I know, the current objective is to migrate the Linen pipeline to NNX while preserving the current Linen version. Please advise if any additional progress is required at this time. Thanks!

@NuojCheng
Copy link
Copy Markdown
Collaborator

There are also some linen usage in pipeline_utils.py, e.g.

I don't see them get updated in this PR but I think they probably should be updated?
Another thing is the usage of function in

# TODO(chengnuojin) Remove this function and its usage after pipeline nnx migration
def remove_logically_partition(weights):
"""Removes LogicallyPartitioned wrapper from weights."""
def _remove_logically_partition_leaf(v):
return getattr(v, "value") if isinstance(v, LogicallyPartitioned) else v
return jax.tree.map(
_remove_logically_partition_leaf,
weights,
is_leaf=lambda v: isinstance(v, LogicallyPartitioned),
)

. I suspect NNX migration can help us get rid of using this function since it is mostly dealing with linen wrapper troubles. Take a look on this if you can. Thank you for the hard work!

As far as I know, the current objective is to migrate the Linen pipeline to NNX while preserving the current Linen version. Please advise if any additional progress is required at this time. Thanks!

Shouldn't we have a nnx version of functions in pipeline_utils.py as well?

@bvandermoon
Copy link
Copy Markdown
Collaborator

There are also some linen usage in pipeline_utils.py, e.g.

I don't see them get updated in this PR but I think they probably should be updated?
Another thing is the usage of function in

# TODO(chengnuojin) Remove this function and its usage after pipeline nnx migration
def remove_logically_partition(weights):
"""Removes LogicallyPartitioned wrapper from weights."""
def _remove_logically_partition_leaf(v):
return getattr(v, "value") if isinstance(v, LogicallyPartitioned) else v
return jax.tree.map(
_remove_logically_partition_leaf,
weights,
is_leaf=lambda v: isinstance(v, LogicallyPartitioned),
)

. I suspect NNX migration can help us get rid of using this function since it is mostly dealing with linen wrapper troubles. Take a look on this if you can. Thank you for the hard work!

As far as I know, the current objective is to migrate the Linen pipeline to NNX while preserving the current Linen version. Please advise if any additional progress is required at this time. Thanks!

Are we able to bridge the NNX version back to Linen at a higher layer? If so, then I think we could get rid of the old Linen code that is no longer used

Comment thread src/maxtext/layers/nnx_decoders.py
@mesakhcienet
Copy link
Copy Markdown
Contributor Author

mesakhcienet commented Mar 31, 2026

There are also some linen usage in pipeline_utils.py, e.g.

I don't see them get updated in this PR but I think they probably should be updated?
Another thing is the usage of function in

# TODO(chengnuojin) Remove this function and its usage after pipeline nnx migration
def remove_logically_partition(weights):
"""Removes LogicallyPartitioned wrapper from weights."""
def _remove_logically_partition_leaf(v):
return getattr(v, "value") if isinstance(v, LogicallyPartitioned) else v
return jax.tree.map(
_remove_logically_partition_leaf,
weights,
is_leaf=lambda v: isinstance(v, LogicallyPartitioned),
)

. I suspect NNX migration can help us get rid of using this function since it is mostly dealing with linen wrapper troubles. Take a look on this if you can. Thank you for the hard work!

As far as I know, the current objective is to migrate the Linen pipeline to NNX while preserving the current Linen version. Please advise if any additional progress is required at this time. Thanks!

Are we able to bridge the NNX version back to Linen at a higher layer? If so, then I think we could get rid of the old Linen code that is no longer used

@bvandermoon
Yes we are able to brigde the NNX version. Meanwhile, we are considering two solutions and would like your input on which direction to take:

Option 1: If we use nnx_wrappers.ToLinen to wrap NNX model back to Linen at a higher layer, we can completely remove the old Linen code, but we will need to conduct more tests.

Option 2: Delay the full migration until nnx_decoders.py is stable. In the meantime, we retain the current Linen pipeline and use the enable_nnx and pure_nnx_decoder flags to trigger the NNX version for testing. After the migration is stable, we can remove the flags and linen version of pipeline.

Please let me know which of these two solutions you prefer, thank you.

@NuojCheng
Apologize i missed the first questions you asked.

The NNX pipeline classes (NNXPipeline, NNXCircularPipeline) already handle these internally with JAX-native equivalents:

  • nn.rematjax.checkpoint (pipeline.py L1551, L1870)
  • nn.scanjax.lax.scan + nnx.split/nnx.merge (pipeline.py L1556, L1876)
  • remove_logically_partition()→ inline jax.tree.map unboxing (pipeline.py L1519-1522)

So no NNX versions of those functions are needed — the NNX path bypasses them entirely. maybe you have some suggestions or any part that i am wrong? Please let me know. Thank you.

Comment thread src/maxtext/layers/nnx_decoders.py Outdated
@bvandermoon
Copy link
Copy Markdown
Collaborator

There are also some linen usage in pipeline_utils.py, e.g.

I don't see them get updated in this PR but I think they probably should be updated?
Another thing is the usage of function in

# TODO(chengnuojin) Remove this function and its usage after pipeline nnx migration
def remove_logically_partition(weights):
"""Removes LogicallyPartitioned wrapper from weights."""
def _remove_logically_partition_leaf(v):
return getattr(v, "value") if isinstance(v, LogicallyPartitioned) else v
return jax.tree.map(
_remove_logically_partition_leaf,
weights,
is_leaf=lambda v: isinstance(v, LogicallyPartitioned),
)

. I suspect NNX migration can help us get rid of using this function since it is mostly dealing with linen wrapper troubles. Take a look on this if you can. Thank you for the hard work!

As far as I know, the current objective is to migrate the Linen pipeline to NNX while preserving the current Linen version. Please advise if any additional progress is required at this time. Thanks!

Are we able to bridge the NNX version back to Linen at a higher layer? If so, then I think we could get rid of the old Linen code that is no longer used

@bvandermoon Yes we are able to brigde the NNX version. Meanwhile, we are considering two solutions and would like your input on which direction to take:

Option 1: If we use nnx_wrappers.ToLinen to wrap NNX model back to Linen at a higher layer, we can completely remove the old Linen code, but we will need to conduct more tests.

Option 2: Delay the full migration until nnx_decoders.py is stable. In the meantime, we retain the current Linen pipeline and use the enable_nnx and pure_nnx_decoder flags to trigger the NNX version for testing. After the migration is stable, we can remove the flags and linen version of pipeline.

Please let me know which of these two solutions you prefer, thank you.

@NuojCheng Apologize i missed the first questions you asked.

The NNX pipeline classes (NNXPipeline, NNXCircularPipeline) already handle these internally with JAX-native equivalents:

  • nn.rematjax.checkpoint (pipeline.py L1551, L1870)
  • nn.scanjax.lax.scan + nnx.split/nnx.merge (pipeline.py L1556, L1876)
  • remove_logically_partition()→ inline jax.tree.map unboxing (pipeline.py L1519-1522)

So no NNX versions of those functions are needed — the NNX path bypasses them entirely. maybe you have some suggestions or any part that i am wrong? Please let me know. Thank you.

Thank you @mesakhcienet. Let's go with option 1 please. That way we can continue running unit tests along the way, and we don't need to worry about the Linen/NNX versions diverging before the migration is fully done

Comment thread src/maxtext/layers/pipeline.py Outdated
@mesakhcienet
Copy link
Copy Markdown
Contributor Author

@bvandermoon

Update on the pipeline migration approach:

After further investigation, we've adjusted from the original Option 1 (remove Linen entirely) to a hybrid approach. Here's why:

The current branch keeps both Linen (PipelineLinen/CircularPipelineLinen) and NNX (NNXPipeline/NNXCircularPipeline) pipeline classes in pipeline.py, gated by a use_nnx_pipeline config flag (default: False).

Reasoning:

  • The NNX pipeline wrapped via ToLinen still needs validation against the full pipeline test suite (circular pipeline, FSDP+AG, multi-repeat, etc.) to confirm numerical equivalence with the original Linen version.
  • Keeping the Linen version as a reference allows us to A/B test and verify correctness before removing it.
  • The Linen classes are verified line-by-line identical to main branch -- no divergence risk.
  • nnx_decoders.py uses create_nnx_pipeline() directly for the pure NNX path.
  • decoders.py uses create_pipeline() which dispatches on the flag -- False gives native Linen, True gives NNX-wrapped-in-ToLinen.

Plan to converge to Option 1:
Once the NNX pipeline passes all integration tests, we'll remove the Linen classes and flip the default, effectively completing Option 1.

Let me know if you'd prefer we prioritize removing the Linen code sooner. Thank you.

@bvandermoon
Copy link
Copy Markdown
Collaborator

@bvandermoon

Update on the pipeline migration approach:

After further investigation, we've adjusted from the original Option 1 (remove Linen entirely) to a hybrid approach. Here's why:

The current branch keeps both Linen (PipelineLinen/CircularPipelineLinen) and NNX (NNXPipeline/NNXCircularPipeline) pipeline classes in pipeline.py, gated by a use_nnx_pipeline config flag (default: False).

Reasoning:

  • The NNX pipeline wrapped via ToLinen still needs validation against the full pipeline test suite (circular pipeline, FSDP+AG, multi-repeat, etc.) to confirm numerical equivalence with the original Linen version.
  • Keeping the Linen version as a reference allows us to A/B test and verify correctness before removing it.
  • The Linen classes are verified line-by-line identical to main branch -- no divergence risk.
  • nnx_decoders.py uses create_nnx_pipeline() directly for the pure NNX path.
  • decoders.py uses create_pipeline() which dispatches on the flag -- False gives native Linen, True gives NNX-wrapped-in-ToLinen.

Plan to converge to Option 1: Once the NNX pipeline passes all integration tests, we'll remove the Linen classes and flip the default, effectively completing Option 1.

Let me know if you'd prefer we prioritize removing the Linen code sooner. Thank you.

Thanks @mesakhcienet. I feel we should push strongly to remove the Linen logic if possible. Even though the NNX and Linen versions are equivalent now, that could change if someone updates the Linen version and misses that the NNX one needs to be updated also. Is testing differences between Linen/NNX the main reason to fork these two? Maybe we could test before/after this commit if needed?

@mesakhcienet
Copy link
Copy Markdown
Contributor Author

mesakhcienet commented May 4, 2026

@bvandermoon
Update on the pipeline migration approach:
After further investigation, we've adjusted from the original Option 1 (remove Linen entirely) to a hybrid approach. Here's why:
The current branch keeps both Linen (PipelineLinen/CircularPipelineLinen) and NNX (NNXPipeline/NNXCircularPipeline) pipeline classes in pipeline.py, gated by a use_nnx_pipeline config flag (default: False).
Reasoning:

  • The NNX pipeline wrapped via ToLinen still needs validation against the full pipeline test suite (circular pipeline, FSDP+AG, multi-repeat, etc.) to confirm numerical equivalence with the original Linen version.
  • Keeping the Linen version as a reference allows us to A/B test and verify correctness before removing it.
  • The Linen classes are verified line-by-line identical to main branch -- no divergence risk.
  • nnx_decoders.py uses create_nnx_pipeline() directly for the pure NNX path.
  • decoders.py uses create_pipeline() which dispatches on the flag -- False gives native Linen, True gives NNX-wrapped-in-ToLinen.

Plan to converge to Option 1: Once the NNX pipeline passes all integration tests, we'll remove the Linen classes and flip the default, effectively completing Option 1.
Let me know if you'd prefer we prioritize removing the Linen code sooner. Thank you.

Thanks @mesakhcienet. I feel we should push strongly to remove the Linen logic if possible. Even though the NNX and Linen versions are equivalent now, that could change if someone updates the Linen version and misses that the NNX one needs to be updated also. Is testing differences between Linen/NNX the main reason to fork these two? Maybe we could test before/after this commit if needed?

Thank you for your response. A few clarifications on the sequencing:

  1. The NNX path is opt-in, not the default. This PR adds a new use_nnx_pipeline flag specifically for choosing between the Linen and NNX pipeline implementations, and it defaults to false. Combined with enable_nnx and pure_nnx_decoder (both also false by default), existing users stay on Linen and aren't affected by NNX changes. Linen remains the source of truth until we explicitly flip the defaults, so divergence is much less risky than if NNX were the live path.

  2. Removing Linen fully right now carries unknown risk. I'm not yet confident what breaks downstream if the Linen pipeline classes are deleted outright — there may be call sites or configurations that depend on them in ways the current test coverage doesn't surface. Keeping both lets us de-risk the removal incrementally.

  3. NNX isn't unverified. I've run it against the train-compile configs Nuoj asked about plus the end-to-end run in the PR description. It's just not under unit-test coverage yet.

  4. Existing unit tests are intentionally left on the Linen path so the battle-tested version keeps rigorous coverage during the migration. Plan is to flip them to NNX once the decoder migration is done.

  5. Decoder NNX migration is still in flight (open PRs adding missing models + their unit tests). Adding NNX pipeline unit tests now would likely need rework once that lands.

  6. This PR is already large — I'd rather keep it focused on the migration and handle the test flip in a follow-up.

So my preference is to stick with the hybrid for now and converge to NNX-only once the decoder PRs land and we've validated the removal is safe. Open to revisiting if you'd rather sequence it differently. Thank you!

@bvandermoon
Copy link
Copy Markdown
Collaborator

@bvandermoon
Update on the pipeline migration approach:
After further investigation, we've adjusted from the original Option 1 (remove Linen entirely) to a hybrid approach. Here's why:
The current branch keeps both Linen (PipelineLinen/CircularPipelineLinen) and NNX (NNXPipeline/NNXCircularPipeline) pipeline classes in pipeline.py, gated by a use_nnx_pipeline config flag (default: False).
Reasoning:

  • The NNX pipeline wrapped via ToLinen still needs validation against the full pipeline test suite (circular pipeline, FSDP+AG, multi-repeat, etc.) to confirm numerical equivalence with the original Linen version.
  • Keeping the Linen version as a reference allows us to A/B test and verify correctness before removing it.
  • The Linen classes are verified line-by-line identical to main branch -- no divergence risk.
  • nnx_decoders.py uses create_nnx_pipeline() directly for the pure NNX path.
  • decoders.py uses create_pipeline() which dispatches on the flag -- False gives native Linen, True gives NNX-wrapped-in-ToLinen.

Plan to converge to Option 1: Once the NNX pipeline passes all integration tests, we'll remove the Linen classes and flip the default, effectively completing Option 1.
Let me know if you'd prefer we prioritize removing the Linen code sooner. Thank you.

Thanks @mesakhcienet. I feel we should push strongly to remove the Linen logic if possible. Even though the NNX and Linen versions are equivalent now, that could change if someone updates the Linen version and misses that the NNX one needs to be updated also. Is testing differences between Linen/NNX the main reason to fork these two? Maybe we could test before/after this commit if needed?

Thank you for your response. A few clarifications on the sequencing:

  1. The NNX path is opt-in, not the default. This PR adds a new use_nnx_pipeline flag specifically for choosing between the Linen and NNX pipeline implementations, and it defaults to false. Combined with enable_nnx and pure_nnx_decoder (both also false by default), existing users stay on Linen and aren't affected by NNX changes. Linen remains the source of truth until we explicitly flip the defaults, so divergence is much less risky than if NNX were the live path.
  2. Removing Linen fully right now carries unknown risk. I'm not yet confident what breaks downstream if the Linen pipeline classes are deleted outright — there may be call sites or configurations that depend on them in ways the current test coverage doesn't surface. Keeping both lets us de-risk the removal incrementally.
  3. NNX isn't unverified. I've run it against the train-compile configs Nuoj asked about plus the end-to-end run in the PR description. It's just not under unit-test coverage yet.
  4. Existing unit tests are intentionally left on the Linen path so the battle-tested version keeps rigorous coverage during the migration. Plan is to flip them to NNX once the decoder migration is done.
  5. Decoder NNX migration is still in flight (open PRs adding missing models + their unit tests). Adding NNX pipeline unit tests now would likely need rework once that lands.
  6. This PR is already large — I'd rather keep it focused on the migration and handle the test flip in a follow-up.

So my preference is to stick with the hybrid for now and converge to NNX-only once the decoder PRs land and we've validated the removal is safe. Open to revisiting if you'd rather sequence it differently. Thank you!

I can be open to this approach since some models are not integrated into the decoder layer. But can you please ensure there is a plan to avoid divergence between the NNX and Linen versions? I am concerned that if the Linen version changes, we could miss some functionality in the NNX version

I am also concerned that if we don't run the unit tests on the NNX version now, it will end up being more painful when we go to make the real cutover. Is there anything we can do to be more confident here?

@mesakhcienet
Copy link
Copy Markdown
Contributor Author

mesakhcienet commented May 6, 2026

@bvandermoon
Update on the pipeline migration approach:
After further investigation, we've adjusted from the original Option 1 (remove Linen entirely) to a hybrid approach. Here's why:
The current branch keeps both Linen (PipelineLinen/CircularPipelineLinen) and NNX (NNXPipeline/NNXCircularPipeline) pipeline classes in pipeline.py, gated by a use_nnx_pipeline config flag (default: False).
Reasoning:

  • The NNX pipeline wrapped via ToLinen still needs validation against the full pipeline test suite (circular pipeline, FSDP+AG, multi-repeat, etc.) to confirm numerical equivalence with the original Linen version.
  • Keeping the Linen version as a reference allows us to A/B test and verify correctness before removing it.
  • The Linen classes are verified line-by-line identical to main branch -- no divergence risk.
  • nnx_decoders.py uses create_nnx_pipeline() directly for the pure NNX path.
  • decoders.py uses create_pipeline() which dispatches on the flag -- False gives native Linen, True gives NNX-wrapped-in-ToLinen.

Plan to converge to Option 1: Once the NNX pipeline passes all integration tests, we'll remove the Linen classes and flip the default, effectively completing Option 1.
Let me know if you'd prefer we prioritize removing the Linen code sooner. Thank you.

Thanks @mesakhcienet. I feel we should push strongly to remove the Linen logic if possible. Even though the NNX and Linen versions are equivalent now, that could change if someone updates the Linen version and misses that the NNX one needs to be updated also. Is testing differences between Linen/NNX the main reason to fork these two? Maybe we could test before/after this commit if needed?

Thank you for your response. A few clarifications on the sequencing:

  1. The NNX path is opt-in, not the default. This PR adds a new use_nnx_pipeline flag specifically for choosing between the Linen and NNX pipeline implementations, and it defaults to false. Combined with enable_nnx and pure_nnx_decoder (both also false by default), existing users stay on Linen and aren't affected by NNX changes. Linen remains the source of truth until we explicitly flip the defaults, so divergence is much less risky than if NNX were the live path.
  2. Removing Linen fully right now carries unknown risk. I'm not yet confident what breaks downstream if the Linen pipeline classes are deleted outright — there may be call sites or configurations that depend on them in ways the current test coverage doesn't surface. Keeping both lets us de-risk the removal incrementally.
  3. NNX isn't unverified. I've run it against the train-compile configs Nuoj asked about plus the end-to-end run in the PR description. It's just not under unit-test coverage yet.
  4. Existing unit tests are intentionally left on the Linen path so the battle-tested version keeps rigorous coverage during the migration. Plan is to flip them to NNX once the decoder migration is done.
  5. Decoder NNX migration is still in flight (open PRs adding missing models + their unit tests). Adding NNX pipeline unit tests now would likely need rework once that lands.
  6. This PR is already large — I'd rather keep it focused on the migration and handle the test flip in a follow-up.

So my preference is to stick with the hybrid for now and converge to NNX-only once the decoder PRs land and we've validated the removal is safe. Open to revisiting if you'd rather sequence it differently. Thank you!

I can be open to this approach since some models are not integrated into the decoder layer. But can you please ensure there is a plan to avoid divergence between the NNX and Linen versions? I am concerned that if the Linen version changes, we could miss some functionality in the NNX version

I am also concerned that if we don't run the unit tests on the NNX version now, it will end up being more painful when we go to make the real cutover. Is there anything we can do to be more confident here?

@bvandermoon

Thank you sir. Thinking it through, I'd actually like to go with your original suggestion — remove the Linen pipeline entirely in this PR rather than keep both. Maintaining the hybrid with proper safeguards (dual-path CI, equivalence tests, etc.) ends up being a similar amount of work to just doing the cutover with solid NNX tests, and the cutover leaves us in a cleaner place.

Proposed plan for this PR:

  1. Remove the Linen pipeline classes as originally proposed.
  2. Update the existing unit tests (mostly shoud be on pipeline_parallelism_test.py) to run against the NNX path so the pipeline stays under unit-test coverage through the migration.

Does this work for you? I'll need a bit of time to get the NNX tests passing before this is ready for merge. On our side, we'll also need #3114 to land into main branch first before we can merge these changes — happy to adjust the plan if you'd sequence the test work differently.

Comment thread src/maxtext/configs/base.yml Outdated
enable_nnx: False
pure_nnx_decoder: False
pure_nnx: False
use_nnx_pipeline: False # Set to False to use native Linen pipeline (with custom VJP)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if CI tests can pass if use_nnx_pipeline=true?

Copy link
Copy Markdown
Contributor Author

@mesakhcienet mesakhcienet May 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed the pipeline tests and just need to rebase the branch now. I'll ping you when it's ready. Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants