diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 39ef5b2dee..6eddc283f7 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -190,6 +190,9 @@ mtp_loss_scaling_factor: 0.1 # acceptance rate during evaluation. for example, a value of `1` targets the # first auxiliary prediction head. set to 0 to disable this specific evaluation mtp_eval_target_module: 0 +# the type of decoder layer to use for mtp blocks. if empty, defaults to the last +# decoder layer type of the base model. +mtp_decoder_type: '' # mixture of experts (moe) num_experts: 1 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index af809e2b42..138829a506 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -492,6 +492,10 @@ class MTP(BaseModel): 0, description="Specifies which MTP layer is used to calculate metrics.", ) + mtp_decoder_type: str = Field( + "", + description="MTP decoder layer type. Defaults to base model's last layer if empty.", + ) class Logits(BaseModel): diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index eb5630968f..13c91c2d25 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -438,13 +438,14 @@ def get_remat_policy(self): policy = None return policy - def get_decoder_layers(self): + def get_decoder_layers(self, decoder_block_type=None): """Retrieves a list of decoder layer classes based on the `decoder_block` config. Returns: A list containing one or more `nn.Module` classes for the decoder. 
""" - match self.config.decoder_block: + block_type = decoder_block_type or self.config.decoder_block + match block_type: case DecoderBlockType.DEFAULT: return [DecoderLayer] case DecoderBlockType.LLAMA2: diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 3c8a601201..2693260267 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -481,9 +481,10 @@ def layer_fn(carry, scanned_vars): return final_carry, nnx.merge(graphdef, scanned_state) - def get_decoder_layers(self): + def get_decoder_layers(self, decoder_block_type=None): """Retrieves decoder layer classes based on config using a dictionary lookup.""" cfg = self.config + block_type = decoder_block_type or cfg.decoder_block def get_scannable(normal_cls, scannable_cls): return [scannable_cls] if cfg.scan_layers else [normal_cls] @@ -513,10 +514,10 @@ def get_deepseek(): DecoderBlockType.OLMO3: get_scannable(olmo3.Olmo3DecoderLayer, olmo3.Olmo3ScannableBlock), } - if cfg.decoder_block not in layer_map: - raise ValueError(f"Incorrect decoder_block name {cfg.decoder_block.value=}") + if block_type not in layer_map: + raise ValueError(f"Incorrect decoder_block name {block_type.value=}") - return layer_map[cfg.decoder_block] + return layer_map[block_type] def minimal_policy(self, with_context=False, with_quantization=False): """Helper for creating minimal checkpoint policies.""" diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index 23ad166c79..0585dac3ab 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -17,6 +17,9 @@ # pylint: disable=no-name-in-module from typing import Any +import logging + +logger = logging.getLogger(__name__) import jax import jax.numpy as jnp @@ -27,7 +30,7 @@ from maxtext.common.common_types import Config, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN, MultimodalInput from maxtext.inference import page_manager -from 
maxtext.layers.nnx_decoders import NNXDecoder +from maxtext.layers.nnx_decoders import NNXDecoder, DecoderBlockType from maxtext.layers import initializers from maxtext.layers import nnx_wrappers from maxtext.layers.decoders import Decoder @@ -96,7 +99,14 @@ def setup(self): # For MTP, we use the DecoderLayer blueprint to ensure architectural consistency. # By convention, this is the last layer in the list. layer_types = self.decoder.get_decoder_layers() - mtp_layer_linen = layer_types[-1] + mtp_decoder_type = getattr(self.config, "mtp_decoder_type", None) + if mtp_decoder_type: + mtp_block_enum = DecoderBlockType(mtp_decoder_type) + mtp_layer_linen = self.decoder.get_decoder_layers(decoder_block_type=mtp_block_enum)[-1] + logger.info("MTP using layer type from mtp_decoder_type: %s", mtp_block_enum) + else: + mtp_layer_linen = layer_types[-1] + logger.info("MTP using default fallback layer type: %s", mtp_layer_linen) # UNWRAP: The MTP block is pure NNX. If the decoder returned a Linen wrapper, # extract the native NNX class to preserve parameter tracing/scoping. mtp_layer_nnx = getattr(mtp_layer_linen, "module_class", mtp_layer_linen) @@ -383,9 +393,14 @@ def __init__( if self.config.mtp_num_layers > 0: # Get the list of layer blueprints for the current model. layer_types = self.decoder.get_decoder_layers() - # For MTP, we use the DecoderLayer blueprint to ensure architectural consistency. - # By convention, this is the last layer in the list.
- mtp_layer = layer_types[-1] + mtp_decoder_type = getattr(self.config, "mtp_decoder_type", None) + if mtp_decoder_type: + mtp_block_enum = DecoderBlockType(mtp_decoder_type) + mtp_layer = self.decoder.get_decoder_layers(decoder_block_type=mtp_block_enum)[-1] + logger.info("MTP using layer type from mtp_decoder_type: %s", mtp_block_enum) + else: + mtp_layer = layer_types[-1] + logger.info("MTP using default fallback layer type: %s", mtp_layer) mtp_block_linen = multi_token_prediction_block_as_linen( config=self.config, mesh=self.mesh, diff --git a/tests/unit/multi_token_prediction_test.py b/tests/unit/multi_token_prediction_test.py index ffe30bea6e..a5c5ecd2f6 100644 --- a/tests/unit/multi_token_prediction_test.py +++ b/tests/unit/multi_token_prediction_test.py @@ -120,7 +120,7 @@ def test_multi_token_prediction_layer_output(self): class MTPBlockTestModel(nnx.Module): """A lightweight wrapper model for testing the MTPBlock.""" - def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs): + def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs, transformer_layer_module=NNXDecoderLayer): """Initializes the MTP block and its dependencies for the test.""" self.config = config self.mesh = mesh @@ -156,7 +156,7 @@ def apply_output_head(self, _shared_embedding, hidden_state, _deterministic, mod self.mtp_block = multi_token_prediction.MultiTokenPredictionBlock( config=self.config, mesh=self.mesh, - transformer_layer_module=NNXDecoderLayer, + transformer_layer_module=transformer_layer_module, decoder=self.decoder, rngs=self.rngs, ) @@ -376,5 +376,65 @@ def test_mtp_roll_and_mask_shapes(self): self.assertTrue(jnp.array_equal(rolled_by_0, input_tensor), "A shift of 0 should be a no-op.") +class FlexibleMultiTokenPredictionBlockTest(unittest.TestCase): + """Unit tests for the MultiTokenPredictionBlock with flexible layer types.""" + + def test_instantiate_with_dummy_layer(self): + """Test that we can instantiate MTP block with a different layer type.""" +
extra_args = get_decoupled_parallelism_overrides() + cfg = pyconfig.initialize( + [None, get_test_config_path()], + run_name="flexible_mtp_test", + skip_jax_distributed_system=True, + base_emb_dim=16, + mtp_num_layers=1, + **extra_args, + ) + rng = jax.random.PRNGKey(42) + rngs = nnx.Rngs(params=rng, dropout=rng) + devices_array = maxtext_utils.create_device_mesh(cfg) + mesh = Mesh(devices_array, cfg.mesh_axes) + + # Define a dummy layer that is NOT NNXDecoderLayer + class DummyMtpLayer(nnx.Module): + + def __init__(self, config: Config, mesh: Mesh, model_mode: str, rngs: nnx.Rngs): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.weight = nnx.Param(jnp.ones((1,))) + + def __call__(self, inputs, decoder_segment_ids, decoder_positions, deterministic, model_mode): + return inputs + + test_model = MTPBlockTestModel( + config=cfg, + mesh=mesh, + rngs=rngs, + transformer_layer_module=DummyMtpLayer, # Passing dummy layer here + ) + + # Try a dummy forward pass + batch_size = jax.device_count() + seq_len = 8 + main_hidden_state = jnp.zeros((batch_size, seq_len, cfg.base_emb_dim)) + input_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32) + target_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32) + target_mask = jnp.ones((batch_size, seq_len)) + position_ids = jnp.zeros((batch_size, seq_len), dtype=jnp.int32) + decoder_segment_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + + test_model( + main_hidden_state=main_hidden_state, + input_ids=input_ids, + target_ids=target_ids, + target_mask=target_mask, + position_ids=position_ids, + decoder_segment_ids=decoder_segment_ids, + model_mode=MODEL_MODE_TRAIN, + deterministic=True, + ) + + if __name__ == "__main__": unittest.main()