From 3dbb7bb0c47ad1380114fed673d5ef3a5da682d9 Mon Sep 17 00:00:00 2001 From: vlad-karp Date: Wed, 29 Apr 2026 00:42:37 +0000 Subject: [PATCH 1/5] generalized LTI --- src/maxtext/common/common_types.py | 2 - src/maxtext/configs/types.py | 4 + src/maxtext/layers/decoders.py | 3 - src/maxtext/layers/learn_to_init_layer.py | 130 +++++++++++++++++----- src/maxtext/layers/nnx_wrappers.py | 15 +++ src/maxtext/models/llama2.py | 17 +-- 6 files changed, 125 insertions(+), 46 deletions(-) diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index c48a4e2b6c..b5a3efcf75 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -109,8 +109,6 @@ class DecoderBlockType(enum.Enum): LLAMA4 = "llama4" OLMO3 = "olmo3" - LLAMA2LTI = "llama2_learn_to_init" - class AttentionType(enum.Enum): GLOBAL = "global" # default, with causality diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index b463516945..cdbc276104 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1225,6 +1225,10 @@ class Distillation(BaseModel): default_factory=dict, description="Experimental weight sharing map inside the student model for learn-to-init phase", ) + + attn_module_name: Optional[str] = Field( + None, description="Attention nnx module attribute name to augment with LTI logic" + ) # --------------------------------------- # --- Distillation freezing filter -- diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index 42644ab262..2b9713eb44 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -449,8 +449,6 @@ def get_decoder_layers(self): return [DecoderLayer] case DecoderBlockType.LLAMA2: return [llama2.LlamaDecoderLayerToLinen] - case DecoderBlockType.LLAMA2LTI: - return [llama2.LlamaLTIDecoderLayerToLinen] case DecoderBlockType.MISTRAL: # TODO(ranran): update to Mistral with sliding window attention return 
[mistral.MistralDecoderLayerToLinen] @@ -545,7 +543,6 @@ def get_norm_layer(self, num_features: int): DecoderBlockType.SIMPLE_MLP, DecoderBlockType.LLAMA4, DecoderBlockType.OLMO3, - DecoderBlockType.LLAMA2LTI, ): return functools.partial(rms_norm, num_features=num_features, shard_mode=self.config.shard_mode) elif self.config.decoder_block == DecoderBlockType.GPT3: diff --git a/src/maxtext/layers/learn_to_init_layer.py b/src/maxtext/layers/learn_to_init_layer.py index a1070a1684..aee0f98d27 100644 --- a/src/maxtext/layers/learn_to_init_layer.py +++ b/src/maxtext/layers/learn_to_init_layer.py @@ -23,12 +23,32 @@ from typing import Iterable, Optional from maxtext.common.common_types import DType, ShardMode, Array -from maxtext.layers.nnx_wrappers import ToNNX +from maxtext.layers.nnx_wrappers import ToNNX, to_linen_class from maxtext.layers.quantizations import AqtQuantization as Quant from maxtext.layers.initializers import NdInitializer, nd_dense_init from maxtext.utils import max_logging, max_utils +LTI_MODIFIED_ATTENTION_PARAM_NAMES = ["query", "key", "value", "out"] +LTI_ORIGINAL_ATTENTION_PARAMS_NAME = "kernel" + + +def apply_lti_modification(module: nnx.Module): + """ + Applies Learn-To-Init structural modifications to an instantiated NNX module. + Checks the config to determine if LTI is enabled. 
+ """ + + config = getattr(module, "config", None) + if not config or not getattr(config, "learn_to_init_mode", False): + return module + attn_module_name = config.attn_module_name + + if attn_module_name: + _customize_attention_modules(config, attn_module_name, module) + return module + + class LearnToInitDecoderLayer(nnx.Module): """ A generic wrapper that initializes a base decoder layer and dynamically swaps @@ -74,13 +94,11 @@ def _customize_attention_modules(self, module: nnx.Module): if attention_module is None: return - # Target Q, K, V projections sub module names - target_names = ["query", "key", "value", "out"] - use_general_linear_map = self.config.lti_use_general_linear_map teacher_config = self.config.teacher_config - for name in target_names: + # Target Q, K, V projections sub module names + for name in LTI_MODIFIED_ATTENTION_PARAM_NAMES: child = getattr(attention_module, name, None) if isinstance(child, linears.DenseGeneral): orig_proj_shape = child.kernel.shape @@ -118,6 +136,53 @@ def __call__(self, *args, **kwargs): return self.learn_to_init_wrapper(*args, **kwargs) +def _customize_attention_modules(config: Config, attn_module_name: str, module: nnx.Module): + """Replaces specific DenseGeneral modules (q, k, v projections) in the attention module.""" + attention_module = getattr(module, attn_module_name, None) + if attention_module is None: + return + + # Target Q, K, V projections sub module names + target_names = LTI_MODIFIED_ATTENTION_PARAM_NAMES + + use_general_linear_map = config.lti_use_general_linear_map + teacher_config = config.teacher_config + + for name in target_names: + child = getattr(attention_module, name, None) + if isinstance(child, linears.DenseGeneral): + orig_proj_shape = child.kernel.shape + assert len(orig_proj_shape) == 3 + if name in ("query", "key", "value"): + teacher_heads_num = teacher_config.base_num_query_heads if name == "query" else teacher_config.base_num_kv_heads + teacher_shape = (orig_proj_shape[0], 
teacher_heads_num, teacher_config.head_dim) + elif name == "out": + teacher_shape = (teacher_config.base_num_query_heads, teacher_config.head_dim, orig_proj_shape[2]) + else: + max_logging.warning(f"Non handled LTI projection type {name}") + continue + new_module = LearnToInitDense( + in_features_shape=child.in_features_shape, + out_features_shape=child.out_features_shape, + C=jnp.empty(teacher_shape), + axis=child.axis, + weight_dtype=child.weight_dtype, + dtype=child.dtype, + kernel_init=child.kernel_init, + kernel_axes=child.kernel_axes, + quant=child.quant, + use_bias=child.use_bias, + shard_mode=child.shard_mode, + matmul_precision=child.matmul_precision, + is_output_projection=(name == "out"), + use_general_linear_map=use_general_linear_map, + rngs=attention_module.rngs, # Reuse the original module RNGs + ) + # Swap the module in the mutable NNX graph + setattr(attention_module, name, new_module) + max_logging.info(f"Replaced {attn_module_name}.{name} with LearnToInitDense.{name}") + + class LearnToInitDense(nnx.Module): """ A customized Dense layer used exclusively during the learn-to-init phase of distillation. @@ -380,21 +445,18 @@ def apply_lti_model_update(student_model, student_config): student_config: The configuration of the student model containing parameters like `matmul_precision`. 
""" - # Access the nested ToNNX dictionary directly assert isinstance(student_model.decoder, ToNNX), "LTI now only supports ToNNX as the student_model's decoder type" - lti_wrapped_node = student_model.decoder.layers["learn_to_init_wrapper"] - attn_state_dict = lti_wrapped_node["self_attention"] - # Iterate through known projections and compute final weights - for proj_name in ["query", "key", "value", "out"]: - if proj_name not in attn_state_dict: - raise ValueError("Unsupported structure of LTI-augmented Attention module.") + attn_state_dict = student_model.decoder.layers.get(student_config.attn_module_name) + if attn_state_dict is None: + raise ValueError("LTI: attn_state_dict wasn't found in the model state dict") + + # Iterate through known projections and compute final weights + for proj_name in LTI_MODIFIED_ATTENTION_PARAM_NAMES: proj_params = attn_state_dict[proj_name] is_output_proj = proj_name == "out" - C_param = proj_params.get(LearnToInitDense.TENSOR_C) - if C_param is None: raise ValueError("Attention LTI-augmented module has no C parameter.") @@ -421,21 +483,37 @@ def apply_lti_model_update(student_model, student_config): matmul_precision=student_config.matmul_precision, ) else: - continue + raise ValueError("Non LTI supported Attention module state.") - # 3. Overwrite C with the final computed kernel - C_param.set_value(final_kernel) - - # 4. Standardize the structure by placing it under the 'kernel' key - proj_params["kernel"] = C_param - - # 5. 
Clean up the LTI-specific parameters using .pop() - # Using pop(key, None) avoids KeyErrors if a tensor was omitted or already shared/deleted + C_param.value = final_kernel + # inject as a regular parameter + proj_params[LTI_ORIGINAL_ATTENTION_PARAMS_NAME] = C_param + # Clean up the LTI-specific parameters proj_params.pop(LearnToInitDense.TENSOR_W, None) proj_params.pop(LearnToInitDense.TENSOR_A, None) proj_params.pop(LearnToInitDense.TENSOR_B, None) proj_params.pop(LearnToInitDense.TENSOR_C, None) - # unpack the learn_to_init_wrapper to match the standard model structure - del student_model.decoder.layers["learn_to_init_wrapper"] - student_model.decoder.layers.update(lti_wrapped_node) + +def get_decoder_layer_to_linen(config, base_layer_cls): + """ + Returns the appropriate Linen-wrapped decoder layer for any architecture. + Wraps it in LearnToInitDecoderLayer if config.learn_to_init_mode is True. + """ + if getattr(config, "learn_to_init_mode", False): + # Dynamically create the Type-bounded LTI wrapper for the provided base layer + class LTIWrapper(LearnToInitDecoderLayer): + + def __init__(self, *args, **kwargs): + super().__init__(*args, base_layer_cls=base_layer_cls, **kwargs) + + return to_linen_class( + LTIWrapper, + base_metadata_fn=initializers.variable_to_logically_partitioned, + ) + + # Return the standard unwrapped Linen class + return to_linen_class( + base_layer_cls, + base_metadata_fn=initializers.variable_to_logically_partitioned, + ) diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py index eb81d596d9..1f9275d569 100644 --- a/src/maxtext/layers/nnx_wrappers.py +++ b/src/maxtext/layers/nnx_wrappers.py @@ -417,6 +417,9 @@ class ToLinen(linen.Module): skip_rng: bool = False metadata_fn: tp.Callable[[variablelib.VariableState], tp.Any] | None = to_linen_var + # generic function to augment original nnx module (e.g. for learn-to-init distillation) + nxx_module_augment_fn: tp.Callable[[Module], Module] | None = None + 
@linen.compact def __call__(self, *args, nnx_method: tp.Callable[..., Any] | str | None = None, **kwargs): def _module_kwargs(): @@ -429,6 +432,10 @@ def _module_kwargs(): # init codepath if self.is_initializing(): module = self.nnx_class(*self.args, **_module_kwargs()) + + if self.nxx_module_augment_fn is not None: + module = self.nxx_module_augment_fn(module) + # TODO: add lazy_init here in case there's an `ToNNX` submodule under `module`. # update linen variables before call module to save initial state self._update_variables(module) @@ -440,6 +447,11 @@ def _module_kwargs(): # create the nnx module module = self.nnx_class(*self.args, **_module_kwargs()) + # Modify the nnx structure BEFORE loading the state + # This ensures the tree structure matches the state we are about to inject + if self.nxx_module_augment_fn is not None: + module = self.nxx_module_augment_fn(module) + # update nnx module from linen variables def maybe_unbox(x): if isinstance(x, meta.AxisMetadata): @@ -553,6 +565,7 @@ def to_linen_class( base_nnx_class: type[M], base_metadata_fn: tp.Callable[[variablelib.VariableState], tp.Any] | None = to_linen_var, base_skip_rng: bool = False, + nxx_module_augment_fn: tp.Callable[[Module], Module] | None = None, **partial_kwargs: tp.Any, ) -> type[ToLinen]: """A dynamically created Linen Module that wraps a specific NNX Module. 
@@ -595,6 +608,7 @@ def __init__( metadata_fn=None, name=_MISSING, parent=_MISSING, + nxx_module_augment_fn=nxx_module_augment_fn, **other_kwargs, ): linen_kwargs = {} @@ -609,6 +623,7 @@ def __init__( metadata_fn=metadata_fn or base_metadata_fn, skip_rng=skip_rng or base_skip_rng, kwargs=FrozenDict({**partial_kwargs, **(kwargs or {}), **other_kwargs}), + nxx_module_augment_fn=nxx_module_augment_fn, **linen_kwargs, ) diff --git a/src/maxtext/models/llama2.py b/src/maxtext/models/llama2.py index 6a215c5dbe..777efd8c1f 100644 --- a/src/maxtext/models/llama2.py +++ b/src/maxtext/models/llama2.py @@ -34,7 +34,7 @@ from maxtext.layers.quantizations import AqtQuantization as Quant from maxtext.utils import max_utils from maxtext.utils.sharding import create_sharding, maybe_shard_with_logical -from maxtext.layers.learn_to_init_layer import LearnToInitDecoderLayer +from maxtext.layers.learn_to_init_layer import apply_lti_modification # ----------------------------------------- # The Decoder Layer specific for Llama2 @@ -228,21 +228,8 @@ def update_cache(cache, val): return layer_output, kv_cache -class LlamaLTIDecoderLayer(LearnToInitDecoderLayer): - """A Type-bounded version of Llama-specific LearnToInitDecoderLayer. - Temporal LTI wrapper before it is generalized for other models. 
- """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, base_layer_cls=LlamaDecoderLayer, **kwargs) - - -LlamaLTIDecoderLayerToLinen = nnx_wrappers.to_linen_class( - LlamaLTIDecoderLayer, - base_metadata_fn=initializers.variable_to_logically_partitioned, -) - LlamaDecoderLayerToLinen = nnx_wrappers.to_linen_class( LlamaDecoderLayer, base_metadata_fn=initializers.variable_to_logically_partitioned, + nxx_module_augment_fn=apply_lti_modification, ) From 86ba412beeb4a81cffdfb8f2b5fcb8dd6600e9e6 Mon Sep 17 00:00:00 2001 From: vlad-karp Date: Thu, 30 Apr 2026 17:28:19 +0000 Subject: [PATCH 2/5] regexp based matching for distillation freezing and LTI --- src/maxtext/configs/types.py | 4 + src/maxtext/layers/learn_to_init_layer.py | 135 +++++++++++------- src/maxtext/layers/nnx_wrappers.py | 6 +- .../post_train/distillation/lti_utils.py | 86 ++++++----- .../post_train/distillation/train_distill.py | 4 +- .../post_training/unit/learn_to_init_test.py | 126 +++++++++------- 6 files changed, 222 insertions(+), 139 deletions(-) diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index cdbc276104..ad5a24bbdd 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1229,6 +1229,10 @@ class Distillation(BaseModel): attn_module_name: Optional[str] = Field( None, description="Attention nnx module attribute name to augment with LTI logic" ) + + lti_layer_indices: Optional[list[int]] = Field( + None, description="List of layer indices to apply LTI modifications. If None, applied to all layers." 
+ ) # --------------------------------------- # --- Distillation freezing filter -- diff --git a/src/maxtext/layers/learn_to_init_layer.py b/src/maxtext/layers/learn_to_init_layer.py index aee0f98d27..c03bcc4910 100644 --- a/src/maxtext/layers/learn_to_init_layer.py +++ b/src/maxtext/layers/learn_to_init_layer.py @@ -33,7 +33,7 @@ LTI_ORIGINAL_ATTENTION_PARAMS_NAME = "kernel" -def apply_lti_modification(module: nnx.Module): +def apply_lti_modification(module: nnx.Module, module_name: str | None = None): """ Applies Learn-To-Init structural modifications to an instantiated NNX module. Checks the config to determine if LTI is enabled. @@ -42,9 +42,33 @@ def apply_lti_modification(module: nnx.Module): config = getattr(module, "config", None) if not config or not getattr(config, "learn_to_init_mode", False): return module + + lti_layer_indices = getattr(config, "lti_layer_indices", None) + if lti_layer_indices is not None and not getattr(config, "scan_layers", True): + if module_name is not None and "layers_" in module_name: + try: + local_idx = int(module_name.split("_")[-1]) + if module_name.startswith("dense_layers_"): + layer_idx = local_idx + elif module_name.startswith("moe_layers_"): + first_dense = getattr(config, "first_num_dense_layers", 0) + layer_idx = first_dense + local_idx + else: + layer_idx = local_idx + + if layer_idx not in lti_layer_indices: + max_logging.info( + f"apply_lti_modification: skipping module={module_name} since its index " + f"{layer_idx} is not in lti_layer_indices." 
+ ) + return module + except ValueError: + pass + attn_module_name = config.attn_module_name if attn_module_name: + max_logging.info(f"apply_lti_modification: customizing module={attn_module_name} in {module_name}") _customize_attention_modules(config, attn_module_name, module) return module @@ -438,7 +462,7 @@ def apply_lti_model_update(student_model, student_config): It effectively collapses the learn-to-init parameterization back into a standard decoder architecture, modifying the `student_model` in-place. - NOTE: works for ToNXX decoder model and layer-scan mode only + NOTE: works for ToNXX decoder model Args: student_model: The trained student model to be updated in-place. @@ -447,52 +471,67 @@ def apply_lti_model_update(student_model, student_config): assert isinstance(student_model.decoder, ToNNX), "LTI now only supports ToNNX as the student_model's decoder type" - attn_state_dict = student_model.decoder.layers.get(student_config.attn_module_name) - - if attn_state_dict is None: - raise ValueError("LTI: attn_state_dict wasn't found in the model state dict") - - # Iterate through known projections and compute final weights - for proj_name in LTI_MODIFIED_ATTENTION_PARAM_NAMES: - proj_params = attn_state_dict[proj_name] - is_output_proj = proj_name == "out" - C_param = proj_params.get(LearnToInitDense.TENSOR_C) - if C_param is None: - raise ValueError("Attention LTI-augmented module has no C parameter.") - - if LearnToInitDense.TENSOR_W in proj_params: - max_logging.log(f"Computing final learn-to-init weight (general map) for: {proj_name}") - W_param = proj_params[LearnToInitDense.TENSOR_W] - final_kernel = calculate_attn_weight( - A=None, - B=None, - C=C_param, - general_map=W_param, - is_output_projection=is_output_proj, - matmul_precision=student_config.matmul_precision, - ) - elif LearnToInitDense.TENSOR_A in proj_params and LearnToInitDense.TENSOR_B in proj_params: - max_logging.log(f"Computing final learn-to-init weight for: {proj_name}") - A_param = 
proj_params[LearnToInitDense.TENSOR_A] - B_param = proj_params[LearnToInitDense.TENSOR_B] - final_kernel = calculate_attn_weight( - A=A_param, - B=B_param, - C=C_param, - is_output_projection=is_output_proj, - matmul_precision=student_config.matmul_precision, - ) - else: - raise ValueError("Non LTI supported Attention module state.") - - C_param.value = final_kernel - # inject as a regular parameter - proj_params[LTI_ORIGINAL_ATTENTION_PARAMS_NAME] = C_param - # Clean up the LTI-specific parameters - proj_params.pop(LearnToInitDense.TENSOR_W, None) - proj_params.pop(LearnToInitDense.TENSOR_A, None) - proj_params.pop(LearnToInitDense.TENSOR_B, None) - proj_params.pop(LearnToInitDense.TENSOR_C, None) + if student_config.attn_module_name is None: + return + + if getattr(student_config, "scan_layers", True): + layer_modules = [student_model.decoder.layers] + else: + layer_modules = [] + # Collect all possible layer names (e.g. layers_0, dense_layers_0, moe_layers_0) + for name, module in vars( + student_model.decoder + ).items(): # or name.startswith("dense_layers_") or name.startswith("moe_layers_"): + if name.startswith("layers_"): + layer_modules.append(module) + + for layer_module in layer_modules: + attn_state_dict = layer_module.get(student_config.attn_module_name) + if attn_state_dict is None: + raise ValueError("LTI: attn_state_dict wasn't found in the model state dict") + + for proj_name in LTI_MODIFIED_ATTENTION_PARAM_NAMES: + proj_params = attn_state_dict.get(proj_name) + if proj_params is None: + raise ValueError("Non LTI supported Attention module state.") + + is_output_proj = proj_name == "out" + C_param = proj_params.get(LearnToInitDense.TENSOR_C) + if C_param is None: + continue # Not an LTI augmented module + if LearnToInitDense.TENSOR_W in proj_params: + max_logging.log(f"Computing final learn-to-init weight (general map) for: {proj_name}") + W_param = proj_params[LearnToInitDense.TENSOR_W] + final_kernel = calculate_attn_weight( + A=None, + B=None, + 
C=C_param, + general_map=W_param, + is_output_projection=is_output_proj, + matmul_precision=student_config.matmul_precision, + ) + elif LearnToInitDense.TENSOR_A in proj_params and LearnToInitDense.TENSOR_B in proj_params: + max_logging.log(f"Computing final learn-to-init weight for: {proj_name}") + A_param = proj_params[LearnToInitDense.TENSOR_A] + B_param = proj_params[LearnToInitDense.TENSOR_B] + final_kernel = calculate_attn_weight( + A=A_param, + B=B_param, + C=C_param, + is_output_projection=is_output_proj, + matmul_precision=student_config.matmul_precision, + ) + else: + raise ValueError("Non LTI supported Attention module state.") + + C_param.set_value(final_kernel) + # inject as a regular parameter + proj_params[LTI_ORIGINAL_ATTENTION_PARAMS_NAME] = C_param + # Clean up the LTI-specific parameters + proj_params.pop(LearnToInitDense.TENSOR_W, None) + proj_params.pop(LearnToInitDense.TENSOR_A, None) + proj_params.pop(LearnToInitDense.TENSOR_B, None) + proj_params.pop(LearnToInitDense.TENSOR_C, None) def get_decoder_layer_to_linen(config, base_layer_cls): diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py index 1f9275d569..1f986b52c0 100644 --- a/src/maxtext/layers/nnx_wrappers.py +++ b/src/maxtext/layers/nnx_wrappers.py @@ -418,7 +418,7 @@ class ToLinen(linen.Module): metadata_fn: tp.Callable[[variablelib.VariableState], tp.Any] | None = to_linen_var # generic function to augment original nnx module (i.e for learn-to-init distillation) - nxx_module_augment_fn: tp.Callable[[Module], Module] | None = None + nxx_module_augment_fn: tp.Callable[[Module, str | None], Module] | None = None @linen.compact def __call__(self, *args, nnx_method: tp.Callable[..., Any] | str | None = None, **kwargs): @@ -434,7 +434,7 @@ def _module_kwargs(): module = self.nnx_class(*self.args, **_module_kwargs()) if self.nxx_module_augment_fn is not None: - module = self.nxx_module_augment_fn(module) + module = self.nxx_module_augment_fn(module, 
self.name) # TODO: add lazy_init here in case there's an `ToNNX` submodule under `module`. # update linen variables before call module to save initial state @@ -450,7 +450,7 @@ def _module_kwargs(): # Modify the nnx structure BEFORE loading the state # This ensures the tree structure matches the state we are about to inject if self.nxx_module_augment_fn is not None: - module = self.nxx_module_augment_fn(module) + module = self.nxx_module_augment_fn(module, self.name) # update nnx module from linen variables def maybe_unbox(x): diff --git a/src/maxtext/trainers/post_train/distillation/lti_utils.py b/src/maxtext/trainers/post_train/distillation/lti_utils.py index d618d1873c..59be3c55db 100644 --- a/src/maxtext/trainers/post_train/distillation/lti_utils.py +++ b/src/maxtext/trainers/post_train/distillation/lti_utils.py @@ -16,6 +16,7 @@ from flax import nnx from maxtext.utils import max_logging +import re def _nav_to_attr(root, parts): @@ -77,41 +78,56 @@ def prepare_student_weights( student_graph = {"/".join(map(str, path)): node for path, node in nnx.graph.iter_graph(student_model)} # --- Weight sharing (alias destination -> source Variable) --- - for source_path, dest_path in student_weights_share_map.items(): - source_node = student_graph.get(source_path) - dest_node = student_graph.get(dest_path) - assert ( - source_node is not None - ), f"Student parameter sharing: Could not find source_node model parameter at path: {source_path}" - assert ( - dest_node is not None - ), f"Student parameter sharing: Could not find dest_node model parameter at path: {dest_path}" - - assert source_node.value.shape == dest_node.value.shape, ( - f"Shape mismatch for sharing parameter between {source_path} and {dest_path}: " - f"{source_node.value.shape} vs {dest_node.value.shape}" - ) - max_logging.info(f"Sharing parameter {source_path} with {dest_path}") - dest_parts = dest_path.split("/") - - dest_parent = _nav_to_attr(student_model, dest_parts[:-1]) - dest_attr = dest_parts[-1] - 
- if hasattr(dest_parent, dest_attr): - setattr(dest_parent, dest_attr, source_node) - else: - dest_parent[dest_attr] = source_node - - for teacher_path, student_path in teacher_weights_copy_map.items(): - teacher_node = teacher_graph.get(teacher_path) - student_node = student_graph.get(student_path) - assert teacher_node is not None, f"Could not find teacher model parameter at path: {teacher_path}" - assert student_node is not None, f"Could not find student model parameter at path: {student_path}" - assert ( - student_node.value.shape == teacher_node.value.shape - ), f"Shape mismatch for {teacher_path}. Teacher: {teacher_node.value.shape}, Student: {student_node.value.shape}" - student_node.value = teacher_node.value - max_logging.info(f"Inserted teacher weight parameter {teacher_path} to the student at {student_path}") + for source_pattern, dest_pattern in student_weights_share_map.items(): + matched_any = False + for source_path, source_node in student_graph.items(): + match = re.fullmatch(source_pattern, source_path) + if match: + matched_any = True + dest_path = source_path if source_pattern == dest_pattern else match.expand(dest_pattern) + dest_node = student_graph.get(dest_path) + assert ( + dest_node is not None + ), f"Student parameter sharing: Could not find dest_node model parameter at path: {dest_path}" + assert source_node.get_value().shape == dest_node.get_value().shape, ( + f"Shape mismatch for sharing parameter between {source_path} and {dest_path}: " + f"{source_node.get_value().shape} vs {dest_node.get_value().shape}" + ) + max_logging.info(f"Sharing parameter {source_path} with {dest_path}") + dest_parts = dest_path.split("/") + dest_parent = _nav_to_attr(student_model, dest_parts[:-1]) + dest_attr = dest_parts[-1] + if hasattr(dest_parent, dest_attr): + setattr(dest_parent, dest_attr, source_node) + else: + dest_parent[dest_attr] = source_node + + if not matched_any: + raise ValueError(f"Student parameter sharing: No paths matched the source 
pattern: {source_pattern}") + + # --- teacher to student weights copying --- + for teacher_pattern, student_pattern in teacher_weights_copy_map.items(): + matched_any = False + for teacher_path, teacher_node in teacher_graph.items(): + match = re.fullmatch(teacher_pattern, teacher_path) + if match: + matched_any = True + student_path = teacher_path if teacher_pattern == student_pattern else match.expand(student_pattern) + student_path_parts = student_path.split("/") + student_node = _nav_to_attr(student_model, student_path_parts) + + assert ( + student_node is not None + ), f"Could not find student model parameter at path: {student_path} for teacher parameter {teacher_path}" + assert ( + student_node.get_value().shape == teacher_node.get_value().shape + ), f"Shape mismatch for {teacher_path}. Teacher: {teacher_node.get_value().shape}," f" Student: {student_node.get_value().shape}" + student_node.set_value(teacher_node.get_value()) + max_logging.info(f"Inserted teacher weight parameter {teacher_path} to the student at {student_path}") + + if not matched_any: + raise ValueError(f"Teacher weight injection: No paths matched the source pattern: {teacher_pattern}") max_logging.info("Teacher weight injection complete.") diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index 5de310da90..183ca26bff 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -44,6 +44,7 @@ import jax import jax.numpy as jnp import optax +import re from orbax import checkpoint # MaxText Imports @@ -609,10 +610,11 @@ def train_distill( _log_config_details(student_config, "Student") student_model = get_maxtext_model(student_config, mesh) student_params_to_update = getattr(student_config, "student_params_to_update", []) + student_param_update_templates = [re.compile(t) for t in 
student_params_to_update] def student_freeze_param_fn(path) -> bool: path_str = "/".join(str(p) for p in path) - return not any(template in path_str for template in student_params_to_update) + return not any(regex.search(path_str) for regex in student_param_update_templates) # Inject the teacher's frozen weights into the student model if teacher_model: diff --git a/tests/post_training/unit/learn_to_init_test.py b/tests/post_training/unit/learn_to_init_test.py index 42571b8e43..865fcc411e 100644 --- a/tests/post_training/unit/learn_to_init_test.py +++ b/tests/post_training/unit/learn_to_init_test.py @@ -25,7 +25,7 @@ from maxtext.trainers.post_train.distillation.lti_utils import prepare_student_weights from unittest import mock from maxtext.models.llama2 import LlamaDecoderLayer -from maxtext.layers.learn_to_init_layer import LearnToInitDecoderLayer +from maxtext.layers.learn_to_init_layer import apply_lti_modification # Minimal dummy models for testing @@ -90,11 +90,7 @@ def test_prepare_student_weights_share_and_copy(self): # Since student's layer2 was "shared" from layer1, the copy operation # overwrites student's layer1. self.assertTrue(jnp.array_equal(student.layer1.kernel.value, teacher.layer2.kernel.value)) - - # The actual layer2 of the student remains unchanged because the dictionary - # reference was rerouted. We verify it still has its original initialization. 
- student_original_layer2 = DummyModel(nnx.Rngs(1)).layer2.kernel.value - self.assertTrue(jnp.array_equal(student.layer2.kernel.value, student_original_layer2)) + self.assertTrue(jnp.array_equal(student.layer2.kernel.value, teacher.layer2.kernel.value)) def test_prepare_student_weights_shape_mismatch(self): """Verifies that an error is raised when trying to copy misaligned shapes.""" @@ -214,19 +210,21 @@ def test_qkv_projection_general_map(self): self.assertEqual(out.shape, (batch_size, seq_len, student_heads, student_head_dim)) -class LearnToInitDecoderLayerTest(unittest.TestCase): +class ApplyLtiModificationTest(unittest.TestCase): + """Test LTI module augmentation functionality.""" - def test_llama_lti_decoder_layer_initialization(self): - """Verifies LearnToInitDecoderLayer initializes and modifies LlamaDecoderLayer correctly.""" + def get_mock_config(self): + """Setup mock teacher & student configs.""" - # 1. Setup mock teacher config mock_teacher_config = mock.MagicMock() mock_teacher_config.base_num_query_heads = 4 mock_teacher_config.base_num_kv_heads = 2 mock_teacher_config.head_dim = 16 - # 2. 
Setup mock student config + # Setup mock student config mock_config = mock.MagicMock() + mock_config.learn_to_init_mode = True + mock_config.attn_module_name = "self_attention" mock_config.lti_use_general_linear_map = False mock_config.teacher_config = mock_teacher_config @@ -259,48 +257,72 @@ def test_llama_lti_decoder_layer_initialization(self): mock_config.scan_layers = False mock_config.ici_context_autoregressive_parallelism = 1 mock_config.fused_qkv = False + return mock_config + + +def test_apply_lti_modification_initialization(self): + """Verifies apply_lti_modification modifies LlamaDecoderLayer correctly.""" + mock_config = self.get_mock_config() + mesh = jax.sharding.Mesh(jax.devices(), ("data",)) + rngs = nnx.Rngs(0) + + # Patch utility functions to isolate the test from deeper external dependencies + with ( + mock.patch("maxtext.src.maxtext.utils.max_utils.get_batch_seq_len_for_mode", return_value=(2, 32)), + mock.patch("maxtext.src.maxtext.layers.quantizations.configure_kv_quant", return_value=None), + ): + + # This effectively initializes LlamaLTIDecoderLayer and implicitly calls _customize_attention_modules + layer = LlamaDecoderLayer( + config=mock_config, + model_mode="train", + mesh=mesh, + rngs=rngs, + ) + layer = apply_lti_modification(layer) + + self.assertIsInstance(layer, LlamaDecoderLayer) + attention_module = layer.self_attention + + for proj_name in ["query", "key", "value", "out"]: + child = getattr(attention_module, proj_name) + self.assertIsInstance(child, LearnToInitDense, f"{proj_name} was not swapped to LearnToInitDense") + + # Validate that the dummy Teacher Tensor C is dimensioned correctly + if proj_name == "query": + # (emb_dim, teacher_heads, head_dim) -> (64, 4, 16) + self.assertEqual(child.C.value.shape, (64, 4, 16)) + elif proj_name in ("key", "value"): + # (emb_dim, teacher_kv_heads, head_dim) -> (64, 2, 16) + self.assertEqual(child.C.value.shape, (64, 2, 16)) + elif proj_name == "out": + # (teacher_heads, head_dim, emb_dim) 
-> (4, 16, 64) + self.assertEqual(child.C.value.shape, (4, 16, 64)) + + +def test_apply_lti_modification_skips_when_disabled(self): + """Verifies apply_lti_modification does not modify when learn_to_init_mode is False.""" + mock_config = self.get_mock_config() + mock_config.learn_to_init_mode = False + mesh = jax.sharding.Mesh(jax.devices(), ("data",)) + rngs = nnx.Rngs(0) + + with ( + mock.patch("maxtext.src.maxtext.utils.max_utils.get_batch_seq_len_for_mode", return_value=(2, 32)), + mock.patch("maxtext.src.maxtext.layers.quantizations.configure_kv_quant", return_value=None), + ): + layer = LlamaDecoderLayer( + config=mock_config, + model_mode="train", + mesh=mesh, + rngs=rngs, + ) + layer = apply_lti_modification(layer) - # 3. Dummy Jax sharding mesh and NNX Rngs - mesh = jax.sharding.Mesh(jax.devices(), ("data",)) - rngs = nnx.Rngs(0) - - # Patch utility functions to isolate the test from deeper external dependencies - with ( - mock.patch("maxtext.utils.max_utils.get_batch_seq_len_for_mode", return_value=(2, 32)), - mock.patch("maxtext.layers.quantizations.configure_kv_quant", return_value=None), - ): - - # This effectively initializes LlamaLTIDecoderLayer and implicitly calls _customize_attention_modules - layer = LearnToInitDecoderLayer( - base_layer_cls=LlamaDecoderLayer, - config=mock_config, - model_mode="train", - mesh=mesh, - rngs=rngs, - ) - - # 4. Verify initialization result - self.assertIsInstance(layer.learn_to_init_wrapper, LlamaDecoderLayer) - self.assertEqual(layer.self_attention_module_name, "self_attention") - - # 5. 
Verify the behavior of _customize_attention_modules - # It should correctly replace query, key, value, and out with LearnToInitDense - attention_module = layer.learn_to_init_wrapper.self_attention - - for proj_name in ["query", "key", "value", "out"]: - child = getattr(attention_module, proj_name) - self.assertIsInstance(child, LearnToInitDense, f"{proj_name} was not swapped to LearnToInitDense") - - # Validate that the dummy Teacher Tensor C is dimensioned correctly - if proj_name == "query": - # (emb_dim, teacher_heads, head_dim) -> (64, 4, 16) - self.assertEqual(child.C.value.shape, (64, 4, 16)) - elif proj_name in ("key", "value"): - # (emb_dim, teacher_kv_heads, head_dim) -> (64, 2, 16) - self.assertEqual(child.C.value.shape, (64, 2, 16)) - elif proj_name == "out": - # (teacher_heads, head_dim, emb_dim) -> (4, 16, 64) - self.assertEqual(child.C.value.shape, (4, 16, 64)) + attention_module = layer.self_attention + for proj_name in ["query", "key", "value", "out"]: + child = getattr(attention_module, proj_name) + self.assertNotIsInstance(child, LearnToInitDense, f"{proj_name} was unexpectedly swapped to LearnToInitDense") if __name__ == "__main__": From dd3fe3ba794ebd72560423771047924772c90962 Mon Sep 17 00:00:00 2001 From: vlad-karp Date: Fri, 1 May 2026 01:31:53 +0000 Subject: [PATCH 3/5] fixed tests --- src/maxtext/trainers/post_train/distillation/train_distill.py | 2 +- tests/post_training/unit/train_distill_test.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index 183ca26bff..d766ab69be 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -609,7 +609,7 @@ def train_distill( max_logging.log(f"Loading Student from {student_config.load_parameters_path}...") _log_config_details(student_config, "Student") student_model = 
get_maxtext_model(student_config, mesh) - student_params_to_update = getattr(student_config, "student_params_to_update", []) + student_params_to_update = getattr(student_config, "student_params_to_update", []) or [] student_param_update_templates = [re.compile(t) for t in student_params_to_update] def student_freeze_param_fn(path) -> bool: diff --git a/tests/post_training/unit/train_distill_test.py b/tests/post_training/unit/train_distill_test.py index ca2cbfa91f..d32e6bff6a 100644 --- a/tests/post_training/unit/train_distill_test.py +++ b/tests/post_training/unit/train_distill_test.py @@ -1064,6 +1064,7 @@ def test_main_offline_mode_skips_teacher_loading( mock_student_cfg.distill_weights_copy_map = {} mock_student_cfg.distill_student_weights_share_map = {} mock_student_cfg.get_keys.return_value = {} + mock_student_cfg.student_params_to_update = [] # Add scheduling attributes mock_student_cfg.distill_alpha_end = None @@ -1162,6 +1163,7 @@ def test_main_online_mode_loads_teacher( mock_student_cfg.distill_weights_copy_map = {} mock_student_cfg.distill_student_weights_share_map = {} mock_student_cfg.get_keys.return_value = {} + mock_student_cfg.student_params_to_update = [] # Add scheduling attributes mock_student_cfg.distill_alpha_end = None From 3605ce0b61ffe2f3f41fb086042c5144ef128e83 Mon Sep 17 00:00:00 2001 From: vlad-karp Date: Fri, 1 May 2026 01:31:53 +0000 Subject: [PATCH 4/5] fixed tests --- src/maxtext/layers/learn_to_init_layer.py | 115 +----------------- .../post_train/distillation/train_distill.py | 2 +- .../post_training/unit/train_distill_test.py | 2 + 3 files changed, 5 insertions(+), 114 deletions(-) diff --git a/src/maxtext/layers/learn_to_init_layer.py b/src/maxtext/layers/learn_to_init_layer.py index c03bcc4910..f71572bfac 100644 --- a/src/maxtext/layers/learn_to_init_layer.py +++ b/src/maxtext/layers/learn_to_init_layer.py @@ -18,12 +18,12 @@ from flax import nnx from maxtext.layers import linears, initializers from maxtext.common.common_types 
import Config -from jax.sharding import Mesh, NamedSharding +from jax.sharding import NamedSharding import jax.numpy as jnp from typing import Iterable, Optional from maxtext.common.common_types import DType, ShardMode, Array -from maxtext.layers.nnx_wrappers import ToNNX, to_linen_class +from maxtext.layers.nnx_wrappers import ToNNX from maxtext.layers.quantizations import AqtQuantization as Quant from maxtext.layers.initializers import NdInitializer, nd_dense_init from maxtext.utils import max_logging, max_utils @@ -73,93 +73,6 @@ def apply_lti_modification(module: nnx.Module, module_name: str | None = None): return module -class LearnToInitDecoderLayer(nnx.Module): - """ - A generic wrapper that initializes a base decoder layer and dynamically swaps - its DenseGeneral modules for learn-to-init distillation. - - This class instantiates a standard base decoder layer (e.g., LlamaDecoderLayer) - and replaces specific attention projection sub-modules ("query", "key", "value", - "out") with customized `LearnToInitDense` modules. - - Attributes: - learn_to_init_wrapper: The instantiated base decoder layer containing the mutable NNX graph. - config: The model configuration parameters. - rngs: The random number generator state used for initialization. - self_attention_module_name: The target name of the attention module to customize. 
- """ - - def __init__( - self, - base_layer_cls, - config: Config, - model_mode: str, - mesh: Mesh, - rngs: nnx.Rngs, - quant=None, - **kwargs, - ): - # Instantiate the original layer (e.g., LlamaDecoderLayer) - self.learn_to_init_wrapper = base_layer_cls( - config=config, model_mode=model_mode, mesh=mesh, rngs=rngs, quant=quant, **kwargs - ) - - self.config = config - self.rngs = rngs - - self.self_attention_module_name = "self_attention" - - # replace relevant nnx modules with customized LearnToInit modules - self._customize_attention_modules(self.learn_to_init_wrapper) - - def _customize_attention_modules(self, module: nnx.Module): - """Replaces specific DenseGeneral modules (q, k, v projections) in the attention module.""" - attention_module = getattr(module, self.self_attention_module_name, None) - if attention_module is None: - return - - use_general_linear_map = self.config.lti_use_general_linear_map - teacher_config = self.config.teacher_config - - # Target Q, K, V projections sub module names - for name in LTI_MODIFIED_ATTENTION_PARAM_NAMES: - child = getattr(attention_module, name, None) - if isinstance(child, linears.DenseGeneral): - orig_proj_shape = child.kernel.shape - assert len(orig_proj_shape) == 3 - if name in ("query", "key", "value"): - teacher_heads_num = teacher_config.base_num_query_heads if name == "query" else teacher_config.base_num_kv_heads - teacher_shape = (orig_proj_shape[0], teacher_heads_num, teacher_config.head_dim) - elif name == "out": - teacher_shape = (teacher_config.base_num_query_heads, teacher_config.head_dim, orig_proj_shape[2]) - else: - max_logging.warning(f"Non handled LTI projection type {name}") - continue - new_module = LearnToInitDense( - in_features_shape=child.in_features_shape, - out_features_shape=child.out_features_shape, - C=jnp.empty(teacher_shape), - axis=child.axis, - weight_dtype=child.weight_dtype, - dtype=child.dtype, - kernel_init=child.kernel_init, - kernel_axes=child.kernel_axes, - quant=child.quant, - 
use_bias=child.use_bias, - shard_mode=child.shard_mode, - matmul_precision=child.matmul_precision, - is_output_projection=(name == "out"), - use_general_linear_map=use_general_linear_map, - rngs=self.rngs, # Reuse the layer's RNG stream - ) - # Swap the module in the mutable NNX graph - setattr(attention_module, name, new_module) - - def __call__(self, *args, **kwargs): - # Just forward the forward pass arguments to the base layer - return self.learn_to_init_wrapper(*args, **kwargs) - - def _customize_attention_modules(config: Config, attn_module_name: str, module: nnx.Module): """Replaces specific DenseGeneral modules (q, k, v projections) in the attention module.""" attention_module = getattr(module, attn_module_name, None) @@ -532,27 +445,3 @@ def apply_lti_model_update(student_model, student_config): proj_params.pop(LearnToInitDense.TENSOR_A, None) proj_params.pop(LearnToInitDense.TENSOR_B, None) proj_params.pop(LearnToInitDense.TENSOR_C, None) - - -def get_decoder_layer_to_linen(config, base_layer_cls): - """ - Returns the appropriate Linen-wrapped decoder layer for any architecture. - Wraps it in LearnToInitDecoderLayer if config.enable_lti is True. 
- """ - if getattr(config, "learn_to_init_mode", False): - # Dynamically create the Type-bounded LTI wrapper for the provided base layer - class LTIWrapper(LearnToInitDecoderLayer): - - def __init__(self, *args, **kwargs): - super().__init__(*args, base_layer_cls=base_layer_cls, **kwargs) - - return to_linen_class( - LTIWrapper, - base_metadata_fn=initializers.variable_to_logically_partitioned, - ) - - # Return the standard unwrapped Linen class - return to_linen_class( - base_layer_cls, - base_metadata_fn=initializers.variable_to_logically_partitioned, - ) diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index 183ca26bff..d766ab69be 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -609,7 +609,7 @@ def train_distill( max_logging.log(f"Loading Student from {student_config.load_parameters_path}...") _log_config_details(student_config, "Student") student_model = get_maxtext_model(student_config, mesh) - student_params_to_update = getattr(student_config, "student_params_to_update", []) + student_params_to_update = getattr(student_config, "student_params_to_update", []) or [] student_param_update_templates = [re.compile(t) for t in student_params_to_update] def student_freeze_param_fn(path) -> bool: diff --git a/tests/post_training/unit/train_distill_test.py b/tests/post_training/unit/train_distill_test.py index ca2cbfa91f..d32e6bff6a 100644 --- a/tests/post_training/unit/train_distill_test.py +++ b/tests/post_training/unit/train_distill_test.py @@ -1064,6 +1064,7 @@ def test_main_offline_mode_skips_teacher_loading( mock_student_cfg.distill_weights_copy_map = {} mock_student_cfg.distill_student_weights_share_map = {} mock_student_cfg.get_keys.return_value = {} + mock_student_cfg.student_params_to_update = [] # Add scheduling attributes mock_student_cfg.distill_alpha_end = None @@ 
-1162,6 +1163,7 @@ def test_main_online_mode_loads_teacher( mock_student_cfg.distill_weights_copy_map = {} mock_student_cfg.distill_student_weights_share_map = {} mock_student_cfg.get_keys.return_value = {} + mock_student_cfg.student_params_to_update = [] # Add scheduling attributes mock_student_cfg.distill_alpha_end = None From 82611ad7040ebb7995504c18581a3e2b01cefbe4 Mon Sep 17 00:00:00 2001 From: vlad-karp Date: Fri, 1 May 2026 19:15:28 +0000 Subject: [PATCH 5/5] addressed comments --- src/maxtext/layers/learn_to_init_layer.py | 9 +- src/maxtext/layers/nnx_wrappers.py | 16 +-- src/maxtext/models/llama2.py | 2 +- .../post_train/distillation/lti_utils.py | 8 +- .../post_training/unit/learn_to_init_test.py | 124 +++++++++--------- 5 files changed, 78 insertions(+), 81 deletions(-) diff --git a/src/maxtext/layers/learn_to_init_layer.py b/src/maxtext/layers/learn_to_init_layer.py index f71572bfac..2530c17336 100644 --- a/src/maxtext/layers/learn_to_init_layer.py +++ b/src/maxtext/layers/learn_to_init_layer.py @@ -31,6 +31,7 @@ LTI_MODIFIED_ATTENTION_PARAM_NAMES = ["query", "key", "value", "out"] LTI_ORIGINAL_ATTENTION_PARAMS_NAME = "kernel" +LTI_LAYER_PATH_PREFIXES = ("layers_", "dense_layers_", "moe_layers_") def apply_lti_modification(module: nnx.Module, module_name: str | None = None): @@ -59,7 +60,7 @@ def apply_lti_modification(module: nnx.Module, module_name: str | None = None): if layer_idx not in lti_layer_indices: max_logging.info( f"apply_lti_modification: skipping module={module_name} since its index " - "{layer_idx} is not in lti_layer_indices." + f"{layer_idx} is not in lti_layer_indices." ) return module except ValueError: @@ -392,10 +393,8 @@ def apply_lti_model_update(student_model, student_config): else: layer_modules = [] # Collect all possible layer names (e.g. 
layers_0, dense_layers_0, moe_layers_0) - for name, module in vars( - student_model.decoder - ).items(): # or name.startswith("dense_layers_") or name.startswith("moe_layers_"): - if name.startswith("layers_"): + for name, module in vars(student_model.decoder).items(): + if name.startswith(LTI_LAYER_PATH_PREFIXES): layer_modules.append(module) for layer_module in layer_modules: diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py index 1f986b52c0..7e32053a52 100644 --- a/src/maxtext/layers/nnx_wrappers.py +++ b/src/maxtext/layers/nnx_wrappers.py @@ -418,7 +418,7 @@ class ToLinen(linen.Module): metadata_fn: tp.Callable[[variablelib.VariableState], tp.Any] | None = to_linen_var # generic function to augment original nnx module (i.e for learn-to-init distillation) - nxx_module_augment_fn: tp.Callable[[Module, str | None], Module] | None = None + nnx_module_augment_fn: tp.Callable[[Module, str | None], Module] | None = None @linen.compact def __call__(self, *args, nnx_method: tp.Callable[..., Any] | str | None = None, **kwargs): @@ -433,8 +433,8 @@ def _module_kwargs(): if self.is_initializing(): module = self.nnx_class(*self.args, **_module_kwargs()) - if self.nxx_module_augment_fn is not None: - module = self.nxx_module_augment_fn(module, self.name) + if self.nnx_module_augment_fn is not None: + module = self.nnx_module_augment_fn(module, self.name) # TODO: add lazy_init here in case there's an `ToNNX` submodule under `module`. 
# update linen variables before call module to save initial state @@ -449,8 +449,8 @@ def _module_kwargs(): # Modify the nnx structure BEFORE loading the state # This ensures the tree structure matches the state we are about to inject - if self.nxx_module_augment_fn is not None: - module = self.nxx_module_augment_fn(module, self.name) + if self.nnx_module_augment_fn is not None: + module = self.nnx_module_augment_fn(module, self.name) # update nnx module from linen variables def maybe_unbox(x): @@ -565,7 +565,7 @@ def to_linen_class( base_nnx_class: type[M], base_metadata_fn: tp.Callable[[variablelib.VariableState], tp.Any] | None = to_linen_var, base_skip_rng: bool = False, - nxx_module_augment_fn: tp.Callable[[Module], Module] | None = None, + nnx_module_augment_fn: tp.Callable[[Module, str | None], Module] | None = None, **partial_kwargs: tp.Any, ) -> type[ToLinen]: """A dynamically created Linen Module that wraps a specific NNX Module. @@ -608,7 +608,7 @@ def __init__( metadata_fn=None, name=_MISSING, parent=_MISSING, - nxx_module_augment_fn=nxx_module_augment_fn, + nnx_module_augment_fn=nnx_module_augment_fn, **other_kwargs, ): linen_kwargs = {} @@ -623,7 +623,7 @@ def __init__( metadata_fn=metadata_fn or base_metadata_fn, skip_rng=skip_rng or base_skip_rng, kwargs=FrozenDict({**partial_kwargs, **(kwargs or {}), **other_kwargs}), - nxx_module_augment_fn=nxx_module_augment_fn, + nnx_module_augment_fn=nnx_module_augment_fn, **linen_kwargs, ) diff --git a/src/maxtext/models/llama2.py b/src/maxtext/models/llama2.py index 777efd8c1f..a75cefc291 100644 --- a/src/maxtext/models/llama2.py +++ b/src/maxtext/models/llama2.py @@ -231,5 +231,5 @@ def update_cache(cache, val): LlamaDecoderLayerToLinen = nnx_wrappers.to_linen_class( LlamaDecoderLayer, base_metadata_fn=initializers.variable_to_logically_partitioned, - nxx_module_augment_fn=apply_lti_modification, + nnx_module_augment_fn=apply_lti_modification, ) diff --git 
a/src/maxtext/trainers/post_train/distillation/lti_utils.py b/src/maxtext/trainers/post_train/distillation/lti_utils.py index 59be3c55db..895cd4d852 100644 --- a/src/maxtext/trainers/post_train/distillation/lti_utils.py +++ b/src/maxtext/trainers/post_train/distillation/lti_utils.py @@ -120,10 +120,10 @@ def prepare_student_weights( assert ( student_node is not None ), f"Could not find student model parameter at path: {student_path} for teacher parameter {teacher_path}" - assert ( - student_node.get_value().shape == teacher_node.get_value().shape - ), f"Shape mismatch for {teacher_path}. Teacher: {teacher_node.get_value().shape}," - " Student: {student_node.get_value().shape}" + assert student_node.get_value().shape == teacher_node.get_value().shape, ( + f"Shape mismatch for {teacher_path}. Teacher: {teacher_node.get_value().shape}," + f" Student: {student_node.get_value().shape}" + ) student_node.set_value(teacher_node.get_value()) max_logging.info(f"Inserted teacher weight parameter {teacher_path} to the student at {student_path}") diff --git a/tests/post_training/unit/learn_to_init_test.py b/tests/post_training/unit/learn_to_init_test.py index 865fcc411e..6c757c00c6 100644 --- a/tests/post_training/unit/learn_to_init_test.py +++ b/tests/post_training/unit/learn_to_init_test.py @@ -259,70 +259,68 @@ def get_mock_config(self): mock_config.fused_qkv = False return mock_config + def test_apply_lti_modification_initialization(self): + """Verifies apply_lti_modification modifies LlamaDecoderLayer correctly.""" + mock_config = self.get_mock_config() + mesh = jax.sharding.Mesh(jax.devices(), ("data",)) + rngs = nnx.Rngs(0) + + # Patch utility functions to isolate the test from deeper external dependencies + with ( + mock.patch("maxtext.utils.max_utils.get_batch_seq_len_for_mode", return_value=(2, 32)), + mock.patch("maxtext.layers.quantizations.configure_kv_quant", return_value=None), + ): + + # This effectively initializes LlamaLTIDecoderLayer and implicitly calls 
_customize_attention_modules + layer = LlamaDecoderLayer( + config=mock_config, + model_mode="train", + mesh=mesh, + rngs=rngs, + ) + layer = apply_lti_modification(layer) + + self.assertIsInstance(layer, LlamaDecoderLayer) + attention_module = layer.self_attention + + for proj_name in ["query", "key", "value", "out"]: + child = getattr(attention_module, proj_name) + self.assertIsInstance(child, LearnToInitDense, f"{proj_name} was not swapped to LearnToInitDense") + + # Validate that the dummy Teacher Tensor C is dimensioned correctly + if proj_name == "query": + # (emb_dim, teacher_heads, head_dim) -> (64, 4, 16) + self.assertEqual(child.C.value.shape, (64, 4, 16)) + elif proj_name in ("key", "value"): + # (emb_dim, teacher_kv_heads, head_dim) -> (64, 2, 16) + self.assertEqual(child.C.value.shape, (64, 2, 16)) + elif proj_name == "out": + # (teacher_heads, head_dim, emb_dim) -> (4, 16, 64) + self.assertEqual(child.C.value.shape, (4, 16, 64)) + + def test_apply_lti_modification_skips_when_disabled(self): + """Verifies apply_lti_modification does not modify when learn_to_init_mode is False.""" + mock_config = self.get_mock_config() + mock_config.learn_to_init_mode = False + mesh = jax.sharding.Mesh(jax.devices(), ("data",)) + rngs = nnx.Rngs(0) + + with ( + mock.patch("maxtext.utils.max_utils.get_batch_seq_len_for_mode", return_value=(2, 32)), + mock.patch("maxtext.layers.quantizations.configure_kv_quant", return_value=None), + ): + layer = LlamaDecoderLayer( + config=mock_config, + model_mode="train", + mesh=mesh, + rngs=rngs, + ) + layer = apply_lti_modification(layer) -def test_apply_lti_modification_initialization(self): - """Verifies apply_lti_modification modifies LlamaDecoderLayer correctly.""" - mock_config = self.get_mock_config() - mesh = jax.sharding.Mesh(jax.devices(), ("data",)) - rngs = nnx.Rngs(0) - - # Patch utility functions to isolate the test from deeper external dependencies - with ( - 
mock.patch("maxtext.src.maxtext.utils.max_utils.get_batch_seq_len_for_mode", return_value=(2, 32)), - mock.patch("maxtext.src.maxtext.layers.quantizations.configure_kv_quant", return_value=None), - ): - - # This effectively initializes LlamaLTIDecoderLayer and implicitly calls _customize_attention_modules - layer = LlamaDecoderLayer( - config=mock_config, - model_mode="train", - mesh=mesh, - rngs=rngs, - ) - layer = apply_lti_modification(layer) - - self.assertIsInstance(layer, LlamaDecoderLayer) - attention_module = layer.self_attention - - for proj_name in ["query", "key", "value", "out"]: - child = getattr(attention_module, proj_name) - self.assertIsInstance(child, LearnToInitDense, f"{proj_name} was not swapped to LearnToInitDense") - - # Validate that the dummy Teacher Tensor C is dimensioned correctly - if proj_name == "query": - # (emb_dim, teacher_heads, head_dim) -> (64, 4, 16) - self.assertEqual(child.C.value.shape, (64, 4, 16)) - elif proj_name in ("key", "value"): - # (emb_dim, teacher_kv_heads, head_dim) -> (64, 2, 16) - self.assertEqual(child.C.value.shape, (64, 2, 16)) - elif proj_name == "out": - # (teacher_heads, head_dim, emb_dim) -> (4, 16, 64) - self.assertEqual(child.C.value.shape, (4, 16, 64)) - - -def test_apply_lti_modification_skips_when_disabled(self): - """Verifies apply_lti_modification does not modify when learn_to_init_mode is False.""" - mock_config = self.get_mock_config() - mock_config.learn_to_init_mode = False - mesh = jax.sharding.Mesh(jax.devices(), ("data",)) - rngs = nnx.Rngs(0) - - with ( - mock.patch("maxtext.src.maxtext.utils.max_utils.get_batch_seq_len_for_mode", return_value=(2, 32)), - mock.patch("maxtext.src.maxtext.layers.quantizations.configure_kv_quant", return_value=None), - ): - layer = LlamaDecoderLayer( - config=mock_config, - model_mode="train", - mesh=mesh, - rngs=rngs, - ) - layer = apply_lti_modification(layer) - - attention_module = layer.self_attention - for proj_name in ["query", "key", "value", "out"]: - 
child = getattr(attention_module, proj_name) - self.assertNotIsInstance(child, LearnToInitDense, f"{proj_name} was unexpectedly swapped to LearnToInitDense") + attention_module = layer.self_attention + for proj_name in ["query", "key", "value", "out"]: + child = getattr(attention_module, proj_name) + self.assertNotIsInstance(child, LearnToInitDense, f"{proj_name} was unexpectedly swapped to LearnToInitDense") if __name__ == "__main__":