Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
75d10e1
feat: nnx implementation of nnx classes
mesakhcienet May 11, 2026
d866384
Update pipeline.py
mesakhcienet May 12, 2026
11df872
Update pipeline.py
mesakhcienet May 12, 2026
5a312b5
Update pipeline.py
mesakhcienet May 12, 2026
6110806
Update pipeline.py
mesakhcienet May 12, 2026
171d0d5
Update pipeline.py
mesakhcienet May 12, 2026
c3288b8
Update pipeline.py
mesakhcienet May 12, 2026
a5822d2
Update pipeline.py
mesakhcienet May 12, 2026
6ff1d2d
test-vmap-1
mesakhcienet May 12, 2026
d0ec8ee
test-vmap-1
mesakhcienet May 12, 2026
0f0f119
Update pipeline.py
mesakhcienet May 12, 2026
04ea8ec
Update pipeline.py
mesakhcienet May 12, 2026
c97f31e
update implementation
mesakhcienet May 13, 2026
e24cbaa
Update pipeline.py
mesakhcienet May 13, 2026
1b821ce
feat: add diagnostic logging to NNX pipeline for Linen comparison
mesakhcienet May 13, 2026
a300c16
test: add challenger experiment scripts for memory gap investigation
mesakhcienet May 13, 2026
90cc374
fix: move jax.checkpoint INSIDE scan body to match Linen pattern
mesakhcienet May 13, 2026
5931f21
fix: use scan_pipeline_repeats for outer repeat loop (not iterations)
mesakhcienet May 13, 2026
44acdbf
feat: port 0506 nested scan architecture + jax.lax.scan(unroll=N)
mesakhcienet May 13, 2026
c4f667f
fix: use dual-buffer BSW to fix circular pipeline numerical mismatch
mesakhcienet May 13, 2026
f02cc11
fix: dual-buffer BSW + remove unroll for exact 0506 parity
mesakhcienet May 13, 2026
63b233f
fix: single all-gather per repeat with w_curr in carry
mesakhcienet May 13, 2026
5f24558
revert: restore L1/L2/L3 custom_vjp baseline (5a312b548)
mesakhcienet May 13, 2026
94034eb
revert: restore nnx_wrappers.py and decoders.py to baseline
mesakhcienet May 13, 2026
b9d8bde
feat: add diagnostic logging for metrics-as-scan-ys flow
mesakhcienet May 13, 2026
c3cd002
feat: add prevent_cse=True on L1 jax.remat for XLA fission barriers
mesakhcienet May 13, 2026
f0e54be
feat: insert jax.lax.optimization_barrier inside L1 _forward
mesakhcienet May 13, 2026
a7fdcdc
Update pipeline.py
mesakhcienet May 13, 2026
def6575
Update nnx_wrappers.py
mesakhcienet May 13, 2026
851e70d
Update pipeline.py
mesakhcienet May 13, 2026
b4c575e
Update pipeline.py
mesakhcienet May 13, 2026
ce40cb9
Update pipeline.py
mesakhcienet May 13, 2026
23219ab
Update pipeline.py
mesakhcienet May 13, 2026
07155e4
Update pipeline.py
mesakhcienet May 13, 2026
bf38b06
Update pipeline.py
mesakhcienet May 13, 2026
16c8498
test older version
mesakhcienet May 13, 2026
ba96bc0
update
mesakhcienet May 13, 2026
6f2dd8e
update
mesakhcienet May 13, 2026
2e26fbd
update
mesakhcienet May 13, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/run_tests_against_package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ jobs:
run:
runs-on: ${{ inputs.cloud_runner != '' && inputs.cloud_runner || fromJson(format('["self-hosted", "{0}", "{1}"]', inputs.device_type, inputs.device_name)) }}
container:
image: gcr.io/${{ vars.PROJECT_NAME }}/${{ inputs.base_image }}
image: gcr.io/${{ vars.PROJECT_NAME || 'tpu-prod-env-multipod' }}/${{ inputs.base_image }}
env:
XLA_PYTHON_CLIENT_MEM_FRACTION: ${{ inputs.xla_python_client_mem_fraction }}
TF_FORCE_GPU_ALLOW_GROWTH: ${{ inputs.tf_force_gpu_allow_growth }}
Expand Down
139 changes: 99 additions & 40 deletions src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
from maxtext.layers import linears
from maxtext.layers import mhc
from maxtext.layers import normalizations
from maxtext.layers import pipeline
from maxtext.layers import pipeline_0506 as pipeline
from maxtext.layers.nnx_decoders import NNXDecoderLayer, NNXSequentialPipelineStage, NNXScannedPipelineStage
from maxtext.layers import quantizations
from maxtext.layers.attentions import attention_as_linen
from maxtext.layers.embeddings import attend_on_embedding, embed_as_linen, positional_embedding_as_linen
Expand Down Expand Up @@ -263,7 +264,7 @@ def __call__(
page_state=page_state,
)
if self.config.scan_layers:
inputs = inputs[0] # When scan_layers is True the decoder layers return (outputs, None).
inputs = inputs[0] # When scan_layers is True the decoder layers return (outputs, None).
if self.config.scan_layers:
return inputs, None # pytype: disable=bad-return-type
else:
Expand Down Expand Up @@ -308,10 +309,14 @@ def setup(self):
self.decoder_layer = self.get_decoder_layers()
self.norm_layer = self.get_norm_layer(num_features=self.config.emb_dim)
if self.config.using_pipeline_parallelism:
pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer)
remat_policy = self.get_remat_policy()
nnx_blocks = self._get_nnx_decoder_block_classes()

def stage_factory(rngs):
return self._build_nnx_pipeline_stage(nnx_blocks, rngs)

self.pipeline_module = pipeline.create_pipeline(
config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy
config=self.config, layers=stage_factory, mesh=self.mesh, remat_policy=remat_policy
)

def minimal_policy(self, with_context=False, with_quantization=False):
Expand Down Expand Up @@ -499,6 +504,44 @@ def get_decoder_layers(self):
# Default case to handle any unknown decoder block types.
raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")

def _get_nnx_decoder_block_classes(self):
    """Return the NNX decoder block classes used to build a pipeline stage.

    The result is a one-element list for most architectures; DeepSeek yields a
    (dense, MoE) pair. Raises ValueError for an unrecognized decoder_block.
    """
    cfg = self.config
    use_scan = cfg.scan_layers

    def pick(plain_cls, scan_cls):
        # scan_layers selects the scannable variant of a block family.
        return [scan_cls] if use_scan else [plain_cls]

    # DeepSeek ships as a (dense, MoE) pair, chosen per schedule flavor.
    if cfg.use_batch_split_schedule:
        deepseek_pair = [deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer]
    else:
        deepseek_pair = [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer]

    blocks_by_type = {
        DecoderBlockType.DEFAULT: [NNXDecoderLayer],
        DecoderBlockType.LLAMA2: [llama2.LlamaDecoderLayer],
        DecoderBlockType.MISTRAL: [mistral.MistralDecoderLayer],
        DecoderBlockType.MIXTRAL: [mixtral.MixtralDecoderLayer],
        DecoderBlockType.GEMMA: [gemma.GemmaDecoderLayer],
        DecoderBlockType.GEMMA2: [gemma2.Gemma2DecoderLayer],
        DecoderBlockType.GEMMA3: [gemma3.Gemma3DecoderLayer],
        DecoderBlockType.GEMMA4: pick(gemma4.Gemma4DecoderLayer, gemma4.Gemma4ScannableBlock),
        DecoderBlockType.GPT3: [gpt3.Gpt3DecoderLayer],
        DecoderBlockType.GPT_OSS: pick(gpt_oss.GptOssDecoderLayer, gpt_oss.GptOssScannableBlock),
        DecoderBlockType.QWEN2: [qwen2.Qwen2DecoderLayer],
        DecoderBlockType.QWEN3: [qwen3.Qwen3DecoderLayer],
        DecoderBlockType.QWEN3_MOE: [qwen3.Qwen3MoeDecoderLayer],
        DecoderBlockType.QWEN3_NEXT: pick(qwen3.Qwen3NextDecoderLayer, qwen3.Qwen3NextScannableBlock),
        DecoderBlockType.SIMPLE: [simple_layer.SimpleDecoderLayer],
        DecoderBlockType.SIMPLE_MLP: [simple_layer.SimpleMlpDecoderLayer],
        DecoderBlockType.DEEPSEEK: deepseek_pair,
        DecoderBlockType.LLAMA4: pick(llama4.Llama4DecoderLayer, llama4.Llama4ScannableBlock),
        DecoderBlockType.OLMO3: pick(olmo3.Olmo3DecoderLayer, olmo3.Olmo3ScannableBlock),
    }

    chosen = blocks_by_type.get(cfg.decoder_block)
    if chosen is None:
        # NOTE: `cfg` must keep this name — the `=` f-string spec embeds it in the message.
        raise ValueError(f"Incorrect decoder_block name {cfg.decoder_block.value=}")
    return chosen

def set_remat_policy(self, block_layers, policy):
"""Set remat policy"""
RemattedBlockLayers = []
Expand Down Expand Up @@ -527,6 +570,58 @@ def map_fn(path, value):
RemattedBlockLayers.append(layer)
return RemattedBlockLayers

def _build_nnx_pipeline_stage(self, decoder_blocks, rngs):
    """Creates a single NNX pipeline stage module."""
    cfg = self.config

    # DeepSeek pipelines the second entry of its block pair; every other
    # decoder family uses the first (and only) entry.
    if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
        stage_cls = decoder_blocks[1]
    else:
        stage_cls = decoder_blocks[0]

    layers_per_stage = cfg.num_layers_per_pipeline_stage
    if layers_per_stage == 1:
        # Single-layer stage: instantiate the block directly, no wrapper needed.
        return stage_cls(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs)

    # Multi-layer stage: wrap the block in a scanned or sequential container.
    wrapper_cls = NNXScannedPipelineStage if cfg.scan_layers_per_stage else NNXSequentialPipelineStage
    return wrapper_cls(stage_cls, layers_per_stage, cfg, self.mesh, self.quant, self.model_mode, rngs=rngs)

def get_pipeline_stage_module(self, decoder_blocks):
    """get pipeline stage module"""
    cfg = self.config

    # DeepSeek pipelines its sparse block (index 1); other families use index 0.
    if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
        base_stage = decoder_blocks[1]  # return the sparse block
    else:
        base_stage = decoder_blocks[0]

    if cfg.set_remat_policy_on_layers_per_stage:
        # Apply the remat (checkpointing) policy to the stage layer itself.
        base_stage = self.set_remat_policy([base_stage], self.get_remat_policy())[0]

    if cfg.num_layers_per_pipeline_stage == 1:
        # One layer per stage: the decoder layer is the stage module.
        return base_stage(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode)

    if cfg.scan_layers_per_stage:
        # Scan the layer over the stage's depth; the broadcast in_axes cover
        # the layer's extra call arguments.
        return self.scan_decoder_layers(
            cfg,
            base_stage,
            cfg.num_layers_per_pipeline_stage,
            "layers_per_stage",
            self.mesh,
            in_axes_tuple=(nn.broadcast,) * 4,
            model_mode=self.model_mode,
        )

    # Fallback: unroll the layers explicitly as a sequential block.
    return SequentialBlockDecoderLayers(
        decoder_layer=base_stage,
        num_decoder_layers=cfg.num_layers_per_pipeline_stage,
        config=cfg,
        mesh=self.mesh,
        quant=self.quant,
        model_mode=self.model_mode,
    )

def get_norm_layer(self, num_features: int):
"""get normalization layer (return type inherits from nn.Module)"""
if self.config.decoder_block in (
Expand Down Expand Up @@ -587,42 +682,6 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, me
config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, **kwargs # pytype: disable=wrong-keyword-args
)

def get_pipeline_stage_module(self, decoder_blocks):
    """get pipeline stage module

    Builds the module that one pipeline stage runs: either a single decoder
    layer, a scanned stack of layers, or an explicit sequential stack,
    depending on config.
    """

    # DeepSeek pipelines its sparse block (index 1); all other decoder
    # families use the first entry of `blocks`.
    def get_layer_to_pipeline(blocks, cfg):
        if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
            return blocks[1]  # return the sparse block
        else:
            return blocks[0]

    cfg = self.config
    base_stage = get_layer_to_pipeline(decoder_blocks, cfg)
    if cfg.set_remat_policy_on_layers_per_stage:
        # Wrap the stage layer with the configured remat (checkpointing) policy.
        policy = self.get_remat_policy()
        base_stage = self.set_remat_policy([base_stage], policy)[0]
    if cfg.num_layers_per_pipeline_stage == 1:
        # One layer per stage: instantiate the decoder layer directly.
        stage_module = base_stage(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode)
    elif cfg.scan_layers_per_stage:
        # Scan the layer over the stage depth.
        # NOTE(review): presumably the 4 broadcast in_axes match the layer's
        # non-carry call arguments — confirm against scan_decoder_layers.
        # NOTE(review): unlike the other two branches, model_mode is not
        # forwarded here — confirm the scanned wrapper defaults it correctly.
        stage_module = self.scan_decoder_layers(
            cfg,
            base_stage,
            cfg.num_layers_per_pipeline_stage,
            "layers_per_stage",
            self.mesh,
            in_axes_tuple=(nn.broadcast,) * 4,
        )
    else:
        # Multiple layers per stage without scanning: unroll explicitly.
        stage_module = SequentialBlockDecoderLayers(
            decoder_layer=base_stage,
            num_decoder_layers=cfg.num_layers_per_pipeline_stage,
            config=cfg,
            mesh=self.mesh,
            quant=self.quant,
            model_mode=self.model_mode,
        )
    return stage_module

@nn.compact
def _apply_embedding(
self,
Expand Down
Loading