Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ mtp_loss_scaling_factor: 0.1
# acceptance rate during evaluation. for example, a value of `1` targets the
# first auxiliary prediction head. set to 0 to disable this specific evaluation
mtp_eval_target_module: 0
# the type of decoder layer to use for mtp blocks. if empty, defaults to the last
# decoder layer type of the base model.
mtp_decoder_type: ''

# mixture of experts (moe)
num_experts: 1
Expand Down
4 changes: 4 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,10 @@ class MTP(BaseModel):
0,
description="Specifies which MTP layer is used to calculate metrics.",
)
mtp_decoder_type: str = Field(
"",
description="MTP decoder layer type. Defaults to base model's last layer if empty.",
)


class Logits(BaseModel):
Expand Down
5 changes: 3 additions & 2 deletions src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,13 +438,14 @@ def get_remat_policy(self):
policy = None
return policy

def get_decoder_layers(self):
def get_decoder_layers(self, decoder_block_type=None):
"""Retrieves a list of decoder layer classes based on the `decoder_block` config.

Returns:
A list containing one or more `nn.Module` classes for the decoder.
"""
match self.config.decoder_block:
block_type = decoder_block_type or self.config.decoder_block
match block_type:
case DecoderBlockType.DEFAULT:
return [DecoderLayer]
case DecoderBlockType.LLAMA2:
Expand Down
9 changes: 5 additions & 4 deletions src/maxtext/layers/nnx_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,9 +481,10 @@ def layer_fn(carry, scanned_vars):

return final_carry, nnx.merge(graphdef, scanned_state)

def get_decoder_layers(self):
def get_decoder_layers(self, decoder_block_type=None):
"""Retrieves decoder layer classes based on config using a dictionary lookup."""
cfg = self.config
block_type = decoder_block_type or cfg.decoder_block

def get_scannable(normal_cls, scannable_cls):
return [scannable_cls] if cfg.scan_layers else [normal_cls]
Expand Down Expand Up @@ -513,10 +514,10 @@ def get_deepseek():
DecoderBlockType.OLMO3: get_scannable(olmo3.Olmo3DecoderLayer, olmo3.Olmo3ScannableBlock),
}

if cfg.decoder_block not in layer_map:
raise ValueError(f"Incorrect decoder_block name {cfg.decoder_block.value=}")
if block_type not in layer_map:
raise ValueError(f"Incorrect decoder_block name {block_type.value=}")

return layer_map[cfg.decoder_block]
return layer_map[block_type]

def minimal_policy(self, with_context=False, with_quantization=False):
"""Helper for creating minimal checkpoint policies."""
Expand Down
25 changes: 20 additions & 5 deletions src/maxtext/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
# pylint: disable=no-name-in-module

from typing import Any
import logging

logger = logging.getLogger(__name__)

import jax
import jax.numpy as jnp
Expand All @@ -27,7 +30,7 @@

from maxtext.common.common_types import Config, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN, MultimodalInput
from maxtext.inference import page_manager
from maxtext.layers.nnx_decoders import NNXDecoder
from maxtext.layers.nnx_decoders import NNXDecoder, DecoderBlockType
from maxtext.layers import initializers
from maxtext.layers import nnx_wrappers
from maxtext.layers.decoders import Decoder
Expand Down Expand Up @@ -96,7 +99,14 @@ def setup(self):
# For MTP, we use the DecoderLayer blueprint to ensure architectural consistency.
# By convention, this is the last layer in the list.
layer_types = self.decoder.get_decoder_layers()
mtp_layer_linen = layer_types[-1]
mtp_decoder_type = getattr(self.config, "mtp_decoder_type", None)
if mtp_decoder_type:
mtp_block_enum = DecoderBlockType(mtp_decoder_type)
mtp_layer_linen = self.decoder.get_decoder_layers(decoder_block_type=mtp_block_enum)[-1]
logger.info("MTP using layer type from decoder_block_type: %s", mtp_block_enum)
else:
mtp_layer_linen = layer_types[-1]
logger.info("MTP using default fallback layer type: %s", mtp_layer_linen)
# UNWRAP: The MTP block is pure NNX. If the decoder returned a Linen wrapper,
# extract the native NNX class to preserve parameter tracing/scoping.
mtp_layer_nnx = getattr(mtp_layer_linen, "module_class", mtp_layer_linen)
Expand Down Expand Up @@ -383,9 +393,14 @@ def __init__(
if self.config.mtp_num_layers > 0:
# Get the list of layer blueprints for the current model.
layer_types = self.decoder.get_decoder_layers()
# For MTP, we use the DecoderLayer blueprint to ensure architectural consistency.
# By convention, this is the last layer in the list.
mtp_layer = layer_types[-1]
mtp_decoder_type = getattr(self.config, "mtp_decoder_type", None)
if mtp_decoder_type:
mtp_block_enum = DecoderBlockType(mtp_decoder_type)
mtp_layer = self.decoder.get_decoder_layers(decoder_block_type=mtp_block_enum)[-1]
logger.info("MTP using layer type from decoder_block_type: %s", mtp_block_enum)
else:
mtp_layer = layer_types[-1]
logger.info("MTP using default fallback layer type: %s", mtp_layer)
mtp_block_linen = multi_token_prediction_block_as_linen(
config=self.config,
mesh=self.mesh,
Expand Down
64 changes: 62 additions & 2 deletions tests/unit/multi_token_prediction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_multi_token_prediction_layer_output(self):
class MTPBlockTestModel(nnx.Module):
"""A lightweight wrapper model for testing the MTPBlock."""

def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs):
def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs, transformer_layer_module=NNXDecoderLayer):
"""Initializes the MTP block and its dependencies for the test."""
self.config = config
self.mesh = mesh
Expand Down Expand Up @@ -156,7 +156,7 @@ def apply_output_head(self, _shared_embedding, hidden_state, _deterministic, mod
self.mtp_block = multi_token_prediction.MultiTokenPredictionBlock(
config=self.config,
mesh=self.mesh,
transformer_layer_module=NNXDecoderLayer,
transformer_layer_module=transformer_layer_module,
decoder=self.decoder,
rngs=self.rngs,
)
Expand Down Expand Up @@ -376,5 +376,65 @@ def test_mtp_roll_and_mask_shapes(self):
self.assertTrue(jnp.array_equal(rolled_by_0, input_tensor), "A shift of 0 should be a no-op.")


class FlexibleMultiTokenPredictionBlockTest(unittest.TestCase):
  """Tests that MultiTokenPredictionBlock accepts arbitrary decoder layer types."""

  def test_instantiate_with_dummy_layer(self):
    """Builds the MTP block around a stand-in layer class and runs one forward pass."""
    overrides = get_decoupled_parallelism_overrides()
    config = pyconfig.initialize(
        [None, get_test_config_path()],
        run_name="flexible_mtp_test",
        skip_jax_distributed_system=True,
        base_emb_dim=16,
        mtp_num_layers=1,
        **overrides,
    )
    key = jax.random.PRNGKey(42)
    nnx_rngs = nnx.Rngs(params=key, dropout=key)
    mesh = Mesh(maxtext_utils.create_device_mesh(config), config.mesh_axes)

    class StandInLayer(nnx.Module):
      """Minimal non-NNXDecoderLayer stand-in; stores a single param and echoes its inputs."""

      def __init__(self, config: Config, mesh: Mesh, model_mode: str, rngs: nnx.Rngs):
        self.config = config
        self.mesh = mesh
        self.model_mode = model_mode
        self.weight = nnx.Param(jnp.ones((1,)))

      def __call__(self, inputs, decoder_segment_ids, decoder_positions, deterministic, model_mode):
        return inputs

    # The key assertion of this test: the wrapper accepts a layer class
    # other than NNXDecoderLayer without raising.
    model = MTPBlockTestModel(
        config=config,
        mesh=mesh,
        rngs=nnx_rngs,
        transformer_layer_module=StandInLayer,
    )

    # Exercise a forward pass with dummy tensors to confirm the block is callable.
    batch = jax.device_count()
    seq = 8
    hidden = jnp.zeros((batch, seq, config.base_emb_dim))
    int_tokens = jnp.zeros((batch, seq), dtype=jnp.int32)

    model(
        main_hidden_state=hidden,
        input_ids=int_tokens,
        target_ids=jnp.zeros((batch, seq), dtype=jnp.int32),
        target_mask=jnp.ones((batch, seq)),
        position_ids=jnp.zeros((batch, seq), dtype=jnp.int32),
        decoder_segment_ids=jnp.ones((batch, seq), dtype=jnp.int32),
        model_mode=MODEL_MODE_TRAIN,
        deterministic=True,
    )


if __name__ == "__main__":
unittest.main()
Loading