diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index c48a4e2b6c..b5a3efcf75 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -109,8 +109,6 @@ class DecoderBlockType(enum.Enum): LLAMA4 = "llama4" OLMO3 = "olmo3" - LLAMA2LTI = "llama2_learn_to_init" - class AttentionType(enum.Enum): GLOBAL = "global" # default, with causality diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index b463516945..ad5a24bbdd 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1225,6 +1225,14 @@ class Distillation(BaseModel): default_factory=dict, description="Experimental weight sharing map inside the student model for learn-to-init phase", ) + + attn_module_name: Optional[str] = Field( + None, description="Attention nnx module attribute name to augment with LTI logic" + ) + + lti_layer_indices: Optional[list[int]] = Field( + None, description="List of layer indices to apply LTI modifications. If None, applied to all layers." + ) # --------------------------------------- # --- Distillation freezing filter -- diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index 42644ab262..2b9713eb44 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -449,8 +449,6 @@ def get_decoder_layers(self): return [DecoderLayer] case DecoderBlockType.LLAMA2: return [llama2.LlamaDecoderLayerToLinen] - case DecoderBlockType.LLAMA2LTI: - return [llama2.LlamaLTIDecoderLayerToLinen] case DecoderBlockType.MISTRAL: # TODO(ranran): update to Mistral with sliding window attention return [mistral.MistralDecoderLayerToLinen] @@ -545,7 +543,6 @@ def get_norm_layer(self, num_features: int): DecoderBlockType.SIMPLE_MLP, DecoderBlockType.LLAMA4, DecoderBlockType.OLMO3, - DecoderBlockType.LLAMA2LTI, ): return functools.partial(rms_norm, num_features=num_features, shard_mode=self.config.shard_mode) elif self.config.decoder_block == DecoderBlockType.GPT3: diff --git a/src/maxtext/layers/learn_to_init_layer.py b/src/maxtext/layers/learn_to_init_layer.py index a1070a1684..2530c17336 100644 --- a/src/maxtext/layers/learn_to_init_layer.py +++ b/src/maxtext/layers/learn_to_init_layer.py @@ -18,7 +18,7 @@ from flax import nnx from maxtext.layers import linears, initializers from maxtext.common.common_types import Config -from jax.sharding import Mesh, NamedSharding +from jax.sharding import NamedSharding import jax.numpy as jnp from typing import Iterable, Optional @@ -29,93 +29,96 @@ from maxtext.utils import max_logging, max_utils -class LearnToInitDecoderLayer(nnx.Module): - """ - A generic wrapper that initializes a base decoder layer and dynamically swaps - its DenseGeneral modules for learn-to-init distillation. +LTI_MODIFIED_ATTENTION_PARAM_NAMES = ["query", "key", "value", "out"] +LTI_ORIGINAL_ATTENTION_PARAMS_NAME = "kernel" +LTI_LAYER_PATH_PREFIXES = ("layers_", "dense_layers_", "moe_layers_") - This class instantiates a standard base decoder layer (e.g., LlamaDecoderLayer) - and replaces specific attention projection sub-modules ("query", "key", "value", - "out") with customized `LearnToInitDense` modules. - Attributes: - learn_to_init_wrapper: The instantiated base decoder layer containing the mutable NNX graph. - config: The model configuration parameters. - rngs: The random number generator state used for initialization. - self_attention_module_name: The target name of the attention module to customize. 
+def apply_lti_modification(module: nnx.Module, module_name: str | None = None): + """ + Applies Learn-To-Init structural modifications to an instantiated NNX module. + Checks the config to determine if LTI is enabled. """ - def __init__( - self, - base_layer_cls, - config: Config, - model_mode: str, - mesh: Mesh, - rngs: nnx.Rngs, - quant=None, - **kwargs, - ): - # Instantiate the original layer (e.g., LlamaDecoderLayer) - self.learn_to_init_wrapper = base_layer_cls( - config=config, model_mode=model_mode, mesh=mesh, rngs=rngs, quant=quant, **kwargs - ) - - self.config = config - self.rngs = rngs - - self.self_attention_module_name = "self_attention" - - # replace relevant nnx modules with customized LearnToInit modules - self._customize_attention_modules(self.learn_to_init_wrapper) - - def _customize_attention_modules(self, module: nnx.Module): - """Replaces specific DenseGeneral modules (q, k, v projections) in the attention module.""" - attention_module = getattr(module, self.self_attention_module_name, None) - if attention_module is None: - return - - # Target Q, K, V projections sub module names - target_names = ["query", "key", "value", "out"] - - use_general_linear_map = self.config.lti_use_general_linear_map - teacher_config = self.config.teacher_config - - for name in target_names: - child = getattr(attention_module, name, None) - if isinstance(child, linears.DenseGeneral): - orig_proj_shape = child.kernel.shape - assert len(orig_proj_shape) == 3 - if name in ("query", "key", "value"): - teacher_heads_num = teacher_config.base_num_query_heads if name == "query" else teacher_config.base_num_kv_heads - teacher_shape = (orig_proj_shape[0], teacher_heads_num, teacher_config.head_dim) - elif name == "out": - teacher_shape = (teacher_config.base_num_query_heads, teacher_config.head_dim, orig_proj_shape[2]) + config = getattr(module, "config", None) + if not config or not getattr(config, "learn_to_init_mode", False): + return module + + lti_layer_indices = getattr(config, "lti_layer_indices", None) + if lti_layer_indices is not None and not getattr(config, "scan_layers", True): + if module_name is not None and "layers_" in module_name: + try: + local_idx = int(module_name.split("_")[-1]) + if module_name.startswith("dense_layers_"): + layer_idx = local_idx + elif module_name.startswith("moe_layers_"): + first_dense = getattr(config, "first_num_dense_layers", 0) + layer_idx = first_dense + local_idx else: - max_logging.warning(f"Non handled LTI projection type {name}") - continue - new_module = LearnToInitDense( - in_features_shape=child.in_features_shape, - out_features_shape=child.out_features_shape, - C=jnp.empty(teacher_shape), - axis=child.axis, - weight_dtype=child.weight_dtype, - dtype=child.dtype, - kernel_init=child.kernel_init, - kernel_axes=child.kernel_axes, - quant=child.quant, - use_bias=child.use_bias, - shard_mode=child.shard_mode, - matmul_precision=child.matmul_precision, - is_output_projection=(name == "out"), - use_general_linear_map=use_general_linear_map, - rngs=self.rngs, # Reuse the layer's RNG stream - ) - # Swap the module in the mutable NNX graph - setattr(attention_module, name, new_module) - - def __call__(self, *args, **kwargs): - # Just forward the forward pass arguments to the base layer - return self.learn_to_init_wrapper(*args, **kwargs) + layer_idx = local_idx + + if layer_idx not in lti_layer_indices: + max_logging.info( + f"apply_lti_modification: skipping module={module_name} since its index " + f"{layer_idx} is not in lti_layer_indices." 
+          )
+          return module
+      except ValueError:
+        pass
+
+  attn_module_name = config.attn_module_name
+
+  if attn_module_name:
+    max_logging.info(f"apply_lti_modification: customizing module={attn_module_name} in {module_name}")
+    _customize_attention_modules(config, attn_module_name, module)
+  return module
+
+
+def _customize_attention_modules(config: Config, attn_module_name: str, module: nnx.Module):
+  """Replaces specific DenseGeneral modules (q, k, v projections) in the attention module."""
+  attention_module = getattr(module, attn_module_name, None)
+  if attention_module is None:
+    return
+
+  # Target Q, K, V projection sub-module names
+  target_names = LTI_MODIFIED_ATTENTION_PARAM_NAMES
+
+  use_general_linear_map = config.lti_use_general_linear_map
+  teacher_config = config.teacher_config
+
+  for name in target_names:
+    child = getattr(attention_module, name, None)
+    if isinstance(child, linears.DenseGeneral):
+      orig_proj_shape = child.kernel.shape
+      assert len(orig_proj_shape) == 3
+      if name in ("query", "key", "value"):
+        teacher_heads_num = teacher_config.base_num_query_heads if name == "query" else teacher_config.base_num_kv_heads
+        teacher_shape = (orig_proj_shape[0], teacher_heads_num, teacher_config.head_dim)
+      elif name == "out":
+        teacher_shape = (teacher_config.base_num_query_heads, teacher_config.head_dim, orig_proj_shape[2])
+      else:
+        max_logging.warning(f"Unhandled LTI projection type {name}")
+        continue
+      new_module = LearnToInitDense(
+          in_features_shape=child.in_features_shape,
+          out_features_shape=child.out_features_shape,
+          C=jnp.empty(teacher_shape),
+          axis=child.axis,
+          weight_dtype=child.weight_dtype,
+          dtype=child.dtype,
+          kernel_init=child.kernel_init,
+          kernel_axes=child.kernel_axes,
+          quant=child.quant,
+          use_bias=child.use_bias,
+          shard_mode=child.shard_mode,
+          matmul_precision=child.matmul_precision,
+          is_output_projection=(name == "out"),
+          use_general_linear_map=use_general_linear_map,
+          rngs=attention_module.rngs,  # Reuse the original module RNGs
+      )
+      # Swap the module in the mutable NNX graph
+      setattr(attention_module, name, new_module)
+      max_logging.info(f"Replaced {attn_module_name}.{name} with LearnToInitDense.{name}")
 
 
 class LearnToInitDense(nnx.Module):
@@ -373,69 +376,71 @@ def apply_lti_model_update(student_model, student_config):
   It effectively collapses the learn-to-init parameterization back into
   a standard decoder architecture, modifying the `student_model` in-place.
 
-  NOTE: works for ToNXX decoder model and layer-scan mode only
+  NOTE: works only for ToNNX decoder models
 
   Args:
     student_model: The trained student model to be updated in-place.
    student_config: The configuration of the student model containing parameters like `matmul_precision`.
""" - # Access the nested ToNNX dictionary directly assert isinstance(student_model.decoder, ToNNX), "LTI now only supports ToNNX as the student_model's decoder type" - lti_wrapped_node = student_model.decoder.layers["learn_to_init_wrapper"] - attn_state_dict = lti_wrapped_node["self_attention"] - - # Iterate through known projections and compute final weights - for proj_name in ["query", "key", "value", "out"]: - if proj_name not in attn_state_dict: - raise ValueError("Unsupported structure of LTI-augmented Attention module.") - - proj_params = attn_state_dict[proj_name] - is_output_proj = proj_name == "out" - - C_param = proj_params.get(LearnToInitDense.TENSOR_C) - - if C_param is None: - raise ValueError("Attention LTI-augmented module has no C parameter.") - - if LearnToInitDense.TENSOR_W in proj_params: - max_logging.log(f"Computing final learn-to-init weight (general map) for: {proj_name}") - W_param = proj_params[LearnToInitDense.TENSOR_W] - final_kernel = calculate_attn_weight( - A=None, - B=None, - C=C_param, - general_map=W_param, - is_output_projection=is_output_proj, - matmul_precision=student_config.matmul_precision, - ) - elif LearnToInitDense.TENSOR_A in proj_params and LearnToInitDense.TENSOR_B in proj_params: - max_logging.log(f"Computing final learn-to-init weight for: {proj_name}") - A_param = proj_params[LearnToInitDense.TENSOR_A] - B_param = proj_params[LearnToInitDense.TENSOR_B] - final_kernel = calculate_attn_weight( - A=A_param, - B=B_param, - C=C_param, - is_output_projection=is_output_proj, - matmul_precision=student_config.matmul_precision, - ) - else: - continue - # 3. Overwrite C with the final computed kernel - C_param.set_value(final_kernel) + if student_config.attn_module_name is None: + return - # 4. Standardize the structure by placing it under the 'kernel' key - proj_params["kernel"] = C_param - - # 5. Clean up the LTI-specific parameters using .pop() - # Using pop(key, None) avoids KeyErrors if a tensor was omitted or already shared/deleted - proj_params.pop(LearnToInitDense.TENSOR_W, None) - proj_params.pop(LearnToInitDense.TENSOR_A, None) - proj_params.pop(LearnToInitDense.TENSOR_B, None) - proj_params.pop(LearnToInitDense.TENSOR_C, None) - - # unpack the learn_to_init_wrapper to match the standard model structure - del student_model.decoder.layers["learn_to_init_wrapper"] - student_model.decoder.layers.update(lti_wrapped_node) + if getattr(student_config, "scan_layers", True): + layer_modules = [student_model.decoder.layers] + else: + layer_modules = [] + # Collect all possible layer names (e.g. 
layers_0, dense_layers_0, moe_layers_0)
+    for name, module in vars(student_model.decoder).items():
+      if name.startswith(LTI_LAYER_PATH_PREFIXES):
+        layer_modules.append(module)
+
+  for layer_module in layer_modules:
+    attn_state_dict = layer_module.get(student_config.attn_module_name)
+    if attn_state_dict is None:
+      raise ValueError("LTI: attn_state_dict was not found in the model state dict")
+
+    for proj_name in LTI_MODIFIED_ATTENTION_PARAM_NAMES:
+      proj_params = attn_state_dict.get(proj_name)
+      if proj_params is None:
+        raise ValueError("Unsupported LTI-augmented attention module state.")
+
+      is_output_proj = proj_name == "out"
+      C_param = proj_params.get(LearnToInitDense.TENSOR_C)
+      if C_param is None:
+        continue  # Not an LTI-augmented module
+      if LearnToInitDense.TENSOR_W in proj_params:
+        max_logging.log(f"Computing final learn-to-init weight (general map) for: {proj_name}")
+        W_param = proj_params[LearnToInitDense.TENSOR_W]
+        final_kernel = calculate_attn_weight(
+            A=None,
+            B=None,
+            C=C_param,
+            general_map=W_param,
+            is_output_projection=is_output_proj,
+            matmul_precision=student_config.matmul_precision,
+        )
+      elif LearnToInitDense.TENSOR_A in proj_params and LearnToInitDense.TENSOR_B in proj_params:
+        max_logging.log(f"Computing final learn-to-init weight for: {proj_name}")
+        A_param = proj_params[LearnToInitDense.TENSOR_A]
+        B_param = proj_params[LearnToInitDense.TENSOR_B]
+        final_kernel = calculate_attn_weight(
+            A=A_param,
+            B=B_param,
+            C=C_param,
+            is_output_projection=is_output_proj,
+            matmul_precision=student_config.matmul_precision,
+        )
+      else:
+        raise ValueError("Unsupported LTI-augmented attention module state.")
+
+      C_param.set_value(final_kernel)
+      # Inject as a regular parameter
+      proj_params[LTI_ORIGINAL_ATTENTION_PARAMS_NAME] = C_param
+      # Clean up the LTI-specific parameters
+      proj_params.pop(LearnToInitDense.TENSOR_W, None)
+      proj_params.pop(LearnToInitDense.TENSOR_A, None)
+      proj_params.pop(LearnToInitDense.TENSOR_B, None)
+      proj_params.pop(LearnToInitDense.TENSOR_C, None)
diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py
index eb81d596d9..7e32053a52 100644
--- a/src/maxtext/layers/nnx_wrappers.py
+++ b/src/maxtext/layers/nnx_wrappers.py
@@ -417,6 +417,9 @@ class ToLinen(linen.Module):
   skip_rng: bool = False
   metadata_fn: tp.Callable[[variablelib.VariableState], tp.Any] | None = to_linen_var
 
+  # Generic function to augment the original NNX module (e.g., for learn-to-init distillation)
+  nnx_module_augment_fn: tp.Callable[[Module, str | None], Module] | None = None
+
   @linen.compact
   def __call__(self, *args, nnx_method: tp.Callable[..., Any] | str | None = None, **kwargs):
     def _module_kwargs():
@@ -429,6 +432,10 @@ def _module_kwargs():
     # init codepath
     if self.is_initializing():
       module = self.nnx_class(*self.args, **_module_kwargs())
+
+      if self.nnx_module_augment_fn is not None:
+        module = self.nnx_module_augment_fn(module, self.name)
+
       # TODO: add lazy_init here in case there's an `ToNNX` submodule under `module`.
# update linen variables before call module to save initial state self._update_variables(module) @@ -440,6 +447,11 @@ def _module_kwargs(): # create the nnx module module = self.nnx_class(*self.args, **_module_kwargs()) + # Modify the nnx structure BEFORE loading the state + # This ensures the tree structure matches the state we are about to inject + if self.nnx_module_augment_fn is not None: + module = self.nnx_module_augment_fn(module, self.name) + # update nnx module from linen variables def maybe_unbox(x): if isinstance(x, meta.AxisMetadata): @@ -553,6 +565,7 @@ def to_linen_class( base_nnx_class: type[M], base_metadata_fn: tp.Callable[[variablelib.VariableState], tp.Any] | None = to_linen_var, base_skip_rng: bool = False, + nnx_module_augment_fn: tp.Callable[[Module, str | None], Module] | None = None, **partial_kwargs: tp.Any, ) -> type[ToLinen]: """A dynamically created Linen Module that wraps a specific NNX Module. @@ -595,6 +608,7 @@ def __init__( metadata_fn=None, name=_MISSING, parent=_MISSING, + nnx_module_augment_fn=nnx_module_augment_fn, **other_kwargs, ): linen_kwargs = {} @@ -609,6 +623,7 @@ def __init__( metadata_fn=metadata_fn or base_metadata_fn, skip_rng=skip_rng or base_skip_rng, kwargs=FrozenDict({**partial_kwargs, **(kwargs or {}), **other_kwargs}), + nnx_module_augment_fn=nnx_module_augment_fn, **linen_kwargs, ) diff --git a/src/maxtext/models/llama2.py b/src/maxtext/models/llama2.py index 6a215c5dbe..a75cefc291 100644 --- a/src/maxtext/models/llama2.py +++ b/src/maxtext/models/llama2.py @@ -34,7 +34,7 @@ from maxtext.layers.quantizations import AqtQuantization as Quant from maxtext.utils import max_utils from maxtext.utils.sharding import create_sharding, maybe_shard_with_logical -from maxtext.layers.learn_to_init_layer import LearnToInitDecoderLayer +from maxtext.layers.learn_to_init_layer import apply_lti_modification # ----------------------------------------- # The Decoder Layer specific for Llama2 @@ -228,21 +228,8 @@ def update_cache(cache, val): return layer_output, kv_cache -class LlamaLTIDecoderLayer(LearnToInitDecoderLayer): - """A Type-bounded version of Llama-specific LearnToInitDecoderLayer. - Temporal LTI wrapper before it is generalized for other models. 
- """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, base_layer_cls=LlamaDecoderLayer, **kwargs) - - -LlamaLTIDecoderLayerToLinen = nnx_wrappers.to_linen_class( - LlamaLTIDecoderLayer, - base_metadata_fn=initializers.variable_to_logically_partitioned, -) - LlamaDecoderLayerToLinen = nnx_wrappers.to_linen_class( LlamaDecoderLayer, base_metadata_fn=initializers.variable_to_logically_partitioned, + nnx_module_augment_fn=apply_lti_modification, ) diff --git a/src/maxtext/trainers/post_train/distillation/lti_utils.py b/src/maxtext/trainers/post_train/distillation/lti_utils.py index d618d1873c..895cd4d852 100644 --- a/src/maxtext/trainers/post_train/distillation/lti_utils.py +++ b/src/maxtext/trainers/post_train/distillation/lti_utils.py @@ -16,6 +16,7 @@ from flax import nnx from maxtext.utils import max_logging +import re def _nav_to_attr(root, parts): @@ -77,41 +78,56 @@ def prepare_student_weights( student_graph = {"/".join(map(str, path)): node for path, node in nnx.graph.iter_graph(student_model)} # --- Weight sharing (alias destination -> source Variable) --- - for source_path, dest_path in student_weights_share_map.items(): - source_node = student_graph.get(source_path) - dest_node = student_graph.get(dest_path) - assert ( - source_node is not None - ), f"Student parameter sharing: Could not find source_node model parameter at path: {source_path}" - assert ( - dest_node is not None - ), f"Student parameter sharing: Could not find dest_node model parameter at path: {dest_path}" - - assert source_node.value.shape == dest_node.value.shape, ( - f"Shape mismatch for sharing parameter between {source_path} and {dest_path}: " - f"{source_node.value.shape} vs {dest_node.value.shape}" - ) - max_logging.info(f"Sharing parameter {source_path} with {dest_path}") - dest_parts = dest_path.split("/") - - dest_parent = _nav_to_attr(student_model, dest_parts[:-1]) - dest_attr = dest_parts[-1] - - if hasattr(dest_parent, dest_attr): - setattr(dest_parent, dest_attr, source_node) - else: - dest_parent[dest_attr] = source_node - - for teacher_path, student_path in teacher_weights_copy_map.items(): - teacher_node = teacher_graph.get(teacher_path) - student_node = student_graph.get(student_path) - assert teacher_node is not None, f"Could not find teacher model parameter at path: {teacher_path}" - assert student_node is not None, f"Could not find student model parameter at path: {student_path}" - assert ( - student_node.value.shape == teacher_node.value.shape - ), f"Shape mismatch for {teacher_path}. 
Teacher: {teacher_node.value.shape}, Student: {student_node.value.shape}" - student_node.value = teacher_node.value - max_logging.info(f"Inserted teacher weight parameter {teacher_path} to the student at {student_path}") + for source_pattern, dest_pattern in student_weights_share_map.items(): + matched_any = False + for source_path, source_node in student_graph.items(): + match = re.fullmatch(source_pattern, source_path) + if match: + matched_any = True + dest_path = source_path if source_pattern == dest_pattern else match.expand(dest_pattern) + dest_node = student_graph.get(dest_path) + assert ( + dest_node is not None + ), f"Student parameter sharing: Could not find dest_node model parameter at path: {dest_path}" + assert source_node.get_value().shape == dest_node.get_value().shape, ( + f"Shape mismatch for sharing parameter between {source_path} and {dest_path}: " + f"{source_node.get_value().shape} vs {dest_node.get_value().shape}" + ) + max_logging.info(f"Sharing parameter {source_path} with {dest_path}") + dest_parts = dest_path.split("/") + dest_parent = _nav_to_attr(student_model, dest_parts[:-1]) + dest_attr = dest_parts[-1] + if hasattr(dest_parent, dest_attr): + setattr(dest_parent, dest_attr, source_node) + else: + dest_parent[dest_attr] = source_node + + if not matched_any: + raise ValueError(f"Student parameter sharing: No paths matched the source pattern: {source_pattern}") + + # --- teacher to student weights copying --- + for teacher_pattern, student_pattern in teacher_weights_copy_map.items(): + matched_any = False + for teacher_path, teacher_node in teacher_graph.items(): + match = re.fullmatch(teacher_pattern, teacher_path) + if match: + matched_any = True + student_path = teacher_path if teacher_pattern == student_pattern else match.expand(student_pattern) + student_path_parts = student_path.split("/") + student_node = _nav_to_attr(student_model, student_path_parts) # student_graph.get(student_path) + + assert ( + student_node is not None + ), f"Could not find student model parameter at path: {student_path} for teacher parameter {teacher_path}" + assert student_node.get_value().shape == teacher_node.get_value().shape, ( + f"Shape mismatch for {teacher_path}. 
Teacher: {teacher_node.get_value().shape}," + f" Student: {student_node.get_value().shape}" + ) + student_node.set_value(teacher_node.get_value()) + max_logging.info(f"Inserted teacher weight parameter {teacher_path} to the student at {student_path}") + + if not matched_any: + raise ValueError(f"Teacher weight injection: No paths matched the source pattern: {teacher_pattern}") max_logging.info("Teacher weight injection complete.") diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index 5de310da90..d766ab69be 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -44,6 +44,7 @@ import jax import jax.numpy as jnp import optax +import re from orbax import checkpoint # MaxText Imports @@ -608,11 +609,12 @@ def train_distill( max_logging.log(f"Loading Student from {student_config.load_parameters_path}...") _log_config_details(student_config, "Student") student_model = get_maxtext_model(student_config, mesh) - student_params_to_update = getattr(student_config, "student_params_to_update", []) + student_params_to_update = getattr(student_config, "student_params_to_update", []) or [] + student_param_update_templates = [re.compile(t) for t in student_params_to_update] def student_freeze_param_fn(path) -> bool: path_str = "/".join(str(p) for p in path) - return not any(template in path_str for template in student_params_to_update) + return not any(regex.search(path_str) for regex in student_param_update_templates) # Inject the teacher's frozen weights into the student model if teacher_model: diff --git a/tests/post_training/unit/learn_to_init_test.py b/tests/post_training/unit/learn_to_init_test.py index 42571b8e43..6c757c00c6 100644 --- a/tests/post_training/unit/learn_to_init_test.py +++ b/tests/post_training/unit/learn_to_init_test.py @@ -25,7 +25,7 @@ from maxtext.trainers.post_train.distillation.lti_utils import prepare_student_weights from unittest import mock from maxtext.models.llama2 import LlamaDecoderLayer -from maxtext.layers.learn_to_init_layer import LearnToInitDecoderLayer +from maxtext.layers.learn_to_init_layer import apply_lti_modification # Minimal dummy models for testing @@ -90,11 +90,7 @@ def test_prepare_student_weights_share_and_copy(self): # Since student's layer2 was "shared" from layer1, the copy operation # overwrites student's layer1. self.assertTrue(jnp.array_equal(student.layer1.kernel.value, teacher.layer2.kernel.value)) - - # The actual layer2 of the student remains unchanged because the dictionary - # reference was rerouted. We verify it still has its original initialization. 
- student_original_layer2 = DummyModel(nnx.Rngs(1)).layer2.kernel.value - self.assertTrue(jnp.array_equal(student.layer2.kernel.value, student_original_layer2)) + self.assertTrue(jnp.array_equal(student.layer2.kernel.value, teacher.layer2.kernel.value)) def test_prepare_student_weights_shape_mismatch(self): """Verifies that an error is raised when trying to copy misaligned shapes.""" @@ -214,19 +210,21 @@ def test_qkv_projection_general_map(self): self.assertEqual(out.shape, (batch_size, seq_len, student_heads, student_head_dim)) -class LearnToInitDecoderLayerTest(unittest.TestCase): +class ApplyLtiModificationTest(unittest.TestCase): + """Test LTI module augmentation functionality.""" - def test_llama_lti_decoder_layer_initialization(self): - """Verifies LearnToInitDecoderLayer initializes and modifies LlamaDecoderLayer correctly.""" + def get_mock_config(self): + """Setup mock teacher & student configs.""" - # 1. Setup mock teacher config mock_teacher_config = mock.MagicMock() mock_teacher_config.base_num_query_heads = 4 mock_teacher_config.base_num_kv_heads = 2 mock_teacher_config.head_dim = 16 - # 2. Setup mock student config + # Setup mock student config mock_config = mock.MagicMock() + mock_config.learn_to_init_mode = True + mock_config.attn_module_name = "self_attention" mock_config.lti_use_general_linear_map = False mock_config.teacher_config = mock_teacher_config @@ -259,8 +257,11 @@ def test_llama_lti_decoder_layer_initialization(self): mock_config.scan_layers = False mock_config.ici_context_autoregressive_parallelism = 1 mock_config.fused_qkv = False + return mock_config - # 3. Dummy Jax sharding mesh and NNX Rngs + def test_apply_lti_modification_initialization(self): + """Verifies apply_lti_modification modifies LlamaDecoderLayer correctly.""" + mock_config = self.get_mock_config() mesh = jax.sharding.Mesh(jax.devices(), ("data",)) rngs = nnx.Rngs(0) @@ -271,21 +272,16 @@ def test_llama_lti_decoder_layer_initialization(self): ): # This effectively initializes LlamaLTIDecoderLayer and implicitly calls _customize_attention_modules - layer = LearnToInitDecoderLayer( - base_layer_cls=LlamaDecoderLayer, + layer = LlamaDecoderLayer( config=mock_config, model_mode="train", mesh=mesh, rngs=rngs, ) + layer = apply_lti_modification(layer) - # 4. Verify initialization result - self.assertIsInstance(layer.learn_to_init_wrapper, LlamaDecoderLayer) - self.assertEqual(layer.self_attention_module_name, "self_attention") - - # 5. 
Verify the behavior of _customize_attention_modules - # It should correctly replace query, key, value, and out with LearnToInitDense - attention_module = layer.learn_to_init_wrapper.self_attention + self.assertIsInstance(layer, LlamaDecoderLayer) + attention_module = layer.self_attention for proj_name in ["query", "key", "value", "out"]: child = getattr(attention_module, proj_name) @@ -302,6 +298,30 @@ def test_llama_lti_decoder_layer_initialization(self): # (teacher_heads, head_dim, emb_dim) -> (4, 16, 64) self.assertEqual(child.C.value.shape, (4, 16, 64)) + def test_apply_lti_modification_skips_when_disabled(self): + """Verifies apply_lti_modification does not modify when learn_to_init_mode is False.""" + mock_config = self.get_mock_config() + mock_config.learn_to_init_mode = False + mesh = jax.sharding.Mesh(jax.devices(), ("data",)) + rngs = nnx.Rngs(0) + + with ( + mock.patch("maxtext.utils.max_utils.get_batch_seq_len_for_mode", return_value=(2, 32)), + mock.patch("maxtext.layers.quantizations.configure_kv_quant", return_value=None), + ): + layer = LlamaDecoderLayer( + config=mock_config, + model_mode="train", + mesh=mesh, + rngs=rngs, + ) + layer = apply_lti_modification(layer) + + attention_module = layer.self_attention + for proj_name in ["query", "key", "value", "out"]: + child = getattr(attention_module, proj_name) + self.assertNotIsInstance(child, LearnToInitDense, f"{proj_name} was unexpectedly swapped to LearnToInitDense") + if __name__ == "__main__": absltest.main() diff --git a/tests/post_training/unit/train_distill_test.py b/tests/post_training/unit/train_distill_test.py index ca2cbfa91f..d32e6bff6a 100644 --- a/tests/post_training/unit/train_distill_test.py +++ b/tests/post_training/unit/train_distill_test.py @@ -1064,6 +1064,7 @@ def test_main_offline_mode_skips_teacher_loading( mock_student_cfg.distill_weights_copy_map = {} mock_student_cfg.distill_student_weights_share_map = {} mock_student_cfg.get_keys.return_value = {} + mock_student_cfg.student_params_to_update = [] # Add scheduling attributes mock_student_cfg.distill_alpha_end = None @@ -1162,6 +1163,7 @@ def test_main_online_mode_loads_teacher( mock_student_cfg.distill_weights_copy_map = {} mock_student_cfg.distill_student_weights_share_map = {} mock_student_cfg.get_keys.return_value = {} + mock_student_cfg.student_params_to_update = [] # Add scheduling attributes mock_student_cfg.distill_alpha_end = None
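
Illustrative notes (sketches, not part of the patch):

For unscanned models, apply_lti_modification decides whether a layer gets LTI projections by recovering a global layer index from the NNX module name and checking it against lti_layer_indices. The standalone sketch below mirrors that mapping; the module names and first_num_dense_layers value are illustrative only.

def resolve_layer_index(module_name: str, first_num_dense_layers: int = 0):
  """Returns the global decoder-layer index encoded in an unscanned layer's module name, or None."""
  if "layers_" not in module_name:
    return None
  try:
    local_idx = int(module_name.split("_")[-1])
  except ValueError:
    return None
  if module_name.startswith("moe_layers_"):
    # MoE layers are numbered after the leading dense layers in interleaved models.
    return first_num_dense_layers + local_idx
  # "layers_N" and "dense_layers_N" already carry the global index.
  return local_idx


assert resolve_layer_index("dense_layers_2") == 2
assert resolve_layer_index("moe_layers_1", first_num_dense_layers=4) == 5
assert resolve_layer_index("layers_7") == 7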
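
The weight-sharing and teacher-copy maps consumed by prepare_student_weights are now regular-expression based: each map key is matched with re.fullmatch against every parameter path in the graph, and the destination path is built by expanding the captured groups into the map value. A minimal sketch of that resolution, using hypothetical parameter paths and tensor names:

import re

# Hypothetical map: copy each teacher query kernel into the student's LTI "C" tensor.
teacher_weights_copy_map = {
    r"decoder/layers_(\d+)/self_attention/query/kernel": r"decoder/layers_\1/self_attention/query/C",
}

teacher_paths = [
    "decoder/layers_0/self_attention/query/kernel",
    "decoder/layers_1/self_attention/query/kernel",
]

for pattern, template in teacher_weights_copy_map.items():
  for path in teacher_paths:
    match = re.fullmatch(pattern, path)
    if match:
      # Same resolution rule as lti_utils: an identical pattern and template means
      # "same path on both sides"; otherwise expand the captured groups.
      dest = path if pattern == template else match.expand(template)
      print(f"{path} -> {dest}")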
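
train_distill.py now treats student_params_to_update entries as regular expressions when building the freeze filter. The sketch below reproduces that filter in isolation; the pattern and parameter paths are hypothetical.

import re

# Hypothetical pattern: keep only the attention-projection LTI tensors trainable.
student_params_to_update = [r"self_attention/(query|key|value|out)/(A|B|C|W)$"]
templates = [re.compile(t) for t in student_params_to_update]

def student_freeze_param_fn(path) -> bool:
  """Returns True (freeze) unless some pattern matches the '/'-joined parameter path."""
  path_str = "/".join(str(p) for p in path)
  return not any(regex.search(path_str) for regex in templates)

print(student_freeze_param_fn(("decoder", "layers_0", "self_attention", "query", "A")))  # False -> trainable
print(student_freeze_param_fn(("decoder", "layers_0", "mlp", "wi_0", "kernel")))  # True -> frozen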