2 changes: 0 additions & 2 deletions src/maxtext/common/common_types.py
@@ -109,8 +109,6 @@ class DecoderBlockType(enum.Enum):
   LLAMA4 = "llama4"
   OLMO3 = "olmo3"
 
-  LLAMA2LTI = "llama2_learn_to_init"
-
 
 class AttentionType(enum.Enum):
   GLOBAL = "global"  # default, with causality
8 changes: 8 additions & 0 deletions src/maxtext/configs/types.py
@@ -1225,6 +1225,14 @@ class Distillation(BaseModel):
       default_factory=dict,
       description="Experimental weight sharing map inside the student model for learn-to-init phase",
   )
+
+  attn_module_name: Optional[str] = Field(
+      None, description="Attention nnx module attribute name to augment with LTI logic"
+  )
+
+  lti_layer_indices: Optional[list[int]] = Field(
+      None, description="List of layer indices to apply LTI modifications. If None, applied to all layers."
+  )
   # ---------------------------------------
 
   # --- Distillation freezing filter --
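Both new fields are optional, so LTI stays off unless a config opts in. Below is a minimal sketch of their intended semantics, using a pared-down stand-in for the real Distillation model in src/maxtext/configs/types.py (the actual class has many more fields):

# Pared-down stand-in for the Distillation config model above; only the two
# fields added by this diff are reproduced.
from typing import Optional

from pydantic import BaseModel, Field


class Distillation(BaseModel):
  attn_module_name: Optional[str] = Field(None)
  lti_layer_indices: Optional[list[int]] = Field(None)


# Augment only the `self_attention` attribute of layers 0 and 5; leaving
# lti_layer_indices as None applies LTI to every layer.
cfg = Distillation(attn_module_name="self_attention", lti_layer_indices=[0, 5])
print(cfg.attn_module_name, cfg.lti_layer_indices)  # self_attention [0, 5]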
3 changes: 0 additions & 3 deletions src/maxtext/layers/decoders.py
@@ -449,8 +449,6 @@ def get_decoder_layers(self):
         return [DecoderLayer]
       case DecoderBlockType.LLAMA2:
         return [llama2.LlamaDecoderLayerToLinen]
-      case DecoderBlockType.LLAMA2LTI:
-        return [llama2.LlamaLTIDecoderLayerToLinen]
       case DecoderBlockType.MISTRAL:
         # TODO(ranran): update to Mistral with sliding window attention
         return [mistral.MistralDecoderLayerToLinen]
@@ -545,7 +543,6 @@ def get_norm_layer(self, num_features: int):
         DecoderBlockType.SIMPLE_MLP,
         DecoderBlockType.LLAMA4,
         DecoderBlockType.OLMO3,
-        DecoderBlockType.LLAMA2LTI,
     ):
       return functools.partial(rms_norm, num_features=num_features, shard_mode=self.config.shard_mode)
     elif self.config.decoder_block == DecoderBlockType.GPT3:
285 changes: 145 additions & 140 deletions src/maxtext/layers/learn_to_init_layer.py
@@ -18,7 +18,7 @@
 from flax import nnx
 from maxtext.layers import linears, initializers
 from maxtext.common.common_types import Config
-from jax.sharding import Mesh, NamedSharding
+from jax.sharding import NamedSharding
 import jax.numpy as jnp
 from typing import Iterable, Optional
 
@@ -29,93 +29,96 @@
 from maxtext.utils import max_logging, max_utils
 
 
-class LearnToInitDecoderLayer(nnx.Module):
-  """
-  A generic wrapper that initializes a base decoder layer and dynamically swaps
-  its DenseGeneral modules for learn-to-init distillation.
-
-  This class instantiates a standard base decoder layer (e.g., LlamaDecoderLayer)
-  and replaces specific attention projection sub-modules ("query", "key", "value",
-  "out") with customized `LearnToInitDense` modules.
-
-  Attributes:
-    learn_to_init_wrapper: The instantiated base decoder layer containing the mutable NNX graph.
-    config: The model configuration parameters.
-    rngs: The random number generator state used for initialization.
-    self_attention_module_name: The target name of the attention module to customize.
-  """
-
-  def __init__(
-      self,
-      base_layer_cls,
-      config: Config,
-      model_mode: str,
-      mesh: Mesh,
-      rngs: nnx.Rngs,
-      quant=None,
-      **kwargs,
-  ):
-    # Instantiate the original layer (e.g., LlamaDecoderLayer)
-    self.learn_to_init_wrapper = base_layer_cls(
-        config=config, model_mode=model_mode, mesh=mesh, rngs=rngs, quant=quant, **kwargs
-    )
-
-    self.config = config
-    self.rngs = rngs
-
-    self.self_attention_module_name = "self_attention"
-
-    # replace relevant nnx modules with customized LearnToInit modules
-    self._customize_attention_modules(self.learn_to_init_wrapper)
-
-  def _customize_attention_modules(self, module: nnx.Module):
-    """Replaces specific DenseGeneral modules (q, k, v projections) in the attention module."""
-    attention_module = getattr(module, self.self_attention_module_name, None)
-    if attention_module is None:
-      return
-
-    # Target Q, K, V, and output projection sub-module names
-    target_names = ["query", "key", "value", "out"]
-
-    use_general_linear_map = self.config.lti_use_general_linear_map
-    teacher_config = self.config.teacher_config
-
-    for name in target_names:
-      child = getattr(attention_module, name, None)
-      if isinstance(child, linears.DenseGeneral):
-        orig_proj_shape = child.kernel.shape
-        assert len(orig_proj_shape) == 3
-        if name in ("query", "key", "value"):
-          teacher_heads_num = teacher_config.base_num_query_heads if name == "query" else teacher_config.base_num_kv_heads
-          teacher_shape = (orig_proj_shape[0], teacher_heads_num, teacher_config.head_dim)
-        elif name == "out":
-          teacher_shape = (teacher_config.base_num_query_heads, teacher_config.head_dim, orig_proj_shape[2])
-        else:
-          max_logging.warning(f"Non handled LTI projection type {name}")
-          continue
-        new_module = LearnToInitDense(
-            in_features_shape=child.in_features_shape,
-            out_features_shape=child.out_features_shape,
-            C=jnp.empty(teacher_shape),
-            axis=child.axis,
-            weight_dtype=child.weight_dtype,
-            dtype=child.dtype,
-            kernel_init=child.kernel_init,
-            kernel_axes=child.kernel_axes,
-            quant=child.quant,
-            use_bias=child.use_bias,
-            shard_mode=child.shard_mode,
-            matmul_precision=child.matmul_precision,
-            is_output_projection=(name == "out"),
-            use_general_linear_map=use_general_linear_map,
-            rngs=self.rngs,  # Reuse the layer's RNG stream
-        )
-        # Swap the module in the mutable NNX graph
-        setattr(attention_module, name, new_module)
-
-  def __call__(self, *args, **kwargs):
-    # Just forward the forward pass arguments to the base layer
-    return self.learn_to_init_wrapper(*args, **kwargs)
+LTI_MODIFIED_ATTENTION_PARAM_NAMES = ["query", "key", "value", "out"]
+LTI_ORIGINAL_ATTENTION_PARAMS_NAME = "kernel"
+LTI_LAYER_PATH_PREFIXES = ("layers_", "dense_layers_", "moe_layers_")
+
+
+def apply_lti_modification(module: nnx.Module, module_name: str | None = None):
+  """
+  Applies Learn-To-Init structural modifications to an instantiated NNX module.
+  Checks the config to determine whether LTI is enabled.
+  """
+  config = getattr(module, "config", None)
+  if not config or not getattr(config, "learn_to_init_mode", False):
+    return module
+
+  lti_layer_indices = getattr(config, "lti_layer_indices", None)
+  if lti_layer_indices is not None and not getattr(config, "scan_layers", True):
+    if module_name is not None and "layers_" in module_name:
+      try:
+        local_idx = int(module_name.split("_")[-1])
+        if module_name.startswith("dense_layers_"):
+          layer_idx = local_idx
+        elif module_name.startswith("moe_layers_"):
+          first_dense = getattr(config, "first_num_dense_layers", 0)
+          layer_idx = first_dense + local_idx
+        else:
+          layer_idx = local_idx
+
+        if layer_idx not in lti_layer_indices:
+          max_logging.info(
+              f"apply_lti_modification: skipping module={module_name} since its index "
+              f"{layer_idx} is not in lti_layer_indices."
+          )
+          return module
+      except ValueError:
+        pass
+
+  attn_module_name = config.attn_module_name
+
+  if attn_module_name:
+    max_logging.info(f"apply_lti_modification: customizing module={attn_module_name} in {module_name}")
+    _customize_attention_modules(config, attn_module_name, module)
+  return module
+
+
+def _customize_attention_modules(config: Config, attn_module_name: str, module: nnx.Module):
+  """Replaces specific DenseGeneral modules (q, k, v projections) in the attention module."""
+  attention_module = getattr(module, attn_module_name, None)
+  if attention_module is None:
+    return
+
+  # Target Q, K, V, and output projection sub-module names
+  target_names = LTI_MODIFIED_ATTENTION_PARAM_NAMES
+
+  use_general_linear_map = config.lti_use_general_linear_map
+  teacher_config = config.teacher_config
+
+  for name in target_names:
+    child = getattr(attention_module, name, None)
+    if isinstance(child, linears.DenseGeneral):
+      orig_proj_shape = child.kernel.shape
+      assert len(orig_proj_shape) == 3
+      if name in ("query", "key", "value"):
+        teacher_heads_num = teacher_config.base_num_query_heads if name == "query" else teacher_config.base_num_kv_heads
+        teacher_shape = (orig_proj_shape[0], teacher_heads_num, teacher_config.head_dim)
+      elif name == "out":
+        teacher_shape = (teacher_config.base_num_query_heads, teacher_config.head_dim, orig_proj_shape[2])
+      else:
+        max_logging.warning(f"Non handled LTI projection type {name}")
+        continue
+      new_module = LearnToInitDense(
+          in_features_shape=child.in_features_shape,
+          out_features_shape=child.out_features_shape,
+          C=jnp.empty(teacher_shape),
+          axis=child.axis,
+          weight_dtype=child.weight_dtype,
+          dtype=child.dtype,
+          kernel_init=child.kernel_init,
+          kernel_axes=child.kernel_axes,
+          quant=child.quant,
+          use_bias=child.use_bias,
+          shard_mode=child.shard_mode,
+          matmul_precision=child.matmul_precision,
+          is_output_projection=(name == "out"),
+          use_general_linear_map=use_general_linear_map,
+          rngs=attention_module.rngs,  # Reuse the original module's RNGs
+      )
+      # Swap the module in the mutable NNX graph
+      setattr(attention_module, name, new_module)
+      max_logging.info(f"Replaced {attn_module_name}.{name} with LearnToInitDense.{name}")
 
 
 class LearnToInitDense(nnx.Module):
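A note on the index arithmetic in apply_lti_modification: in unscanned mode the attribute name carries only a local index, and moe_layers_N is offset by first_num_dense_layers so dense and MoE blocks share one global numbering, which is what lti_layer_indices is checked against. A standalone sketch of that mapping under this reading (the helper is illustrative, not part of the PR):

def global_layer_index(module_name: str, first_num_dense_layers: int = 0) -> int:
  """Maps layers_N / dense_layers_N / moe_layers_N names to a global layer index."""
  local_idx = int(module_name.split("_")[-1])
  if module_name.startswith("moe_layers_"):
    # MoE layers are numbered after the leading dense layers.
    return first_num_dense_layers + local_idx
  # layers_N and dense_layers_N already use the global numbering.
  return local_idx


# With 2 leading dense layers, moe_layers_0 is global layer 2:
assert global_layer_index("dense_layers_1", first_num_dense_layers=2) == 1
assert global_layer_index("moe_layers_0", first_num_dense_layers=2) == 2
print(global_layer_index("moe_layers_0", 2) in [0, 2])  # True -> layer gets LTI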
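The swap itself relies on NNX modules being mutable object graphs: setattr(parent, name, new_child) rewires the parent in place, and later calls flow through the replacement, which is how _customize_attention_modules substitutes LearnToInitDense for DenseGeneral. A minimal, self-contained sketch of that pattern with a hypothetical Scaled wrapper (not from this PR):

import jax.numpy as jnp
from flax import nnx


class Block(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.proj = nnx.Linear(4, 4, rngs=rngs)

  def __call__(self, x):
    return self.proj(x)


class Scaled(nnx.Module):
  """Hypothetical replacement module wrapping the original projection."""

  def __init__(self, base: nnx.Module):
    self.base = base

  def __call__(self, x):
    return 2.0 * self.base(x)


block = Block(nnx.Rngs(0))
setattr(block, "proj", Scaled(block.proj))  # swap the child in the mutable NNX graph
y = block(jnp.ones((1, 4)))                 # forward pass now routes through Scaled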
@@ -373,69 +376,71 @@ def apply_lti_model_update(student_model, student_config):
   It effectively collapses the learn-to-init parameterization back into a standard
   decoder architecture, modifying the `student_model` in-place.
 
-  NOTE: works for ToNNX decoder model and layer-scan mode only
+  NOTE: works for ToNNX decoder model
 
   Args:
     student_model: The trained student model to be updated in-place.
     student_config: The configuration of the student model containing parameters like `matmul_precision`.
   """
 
-  # Access the nested ToNNX dictionary directly
-  assert isinstance(student_model.decoder, ToNNX), "LTI now only supports ToNNX as the student_model's decoder type"
-  lti_wrapped_node = student_model.decoder.layers["learn_to_init_wrapper"]
-  attn_state_dict = lti_wrapped_node["self_attention"]
-
-  # Iterate through known projections and compute final weights
-  for proj_name in ["query", "key", "value", "out"]:
-    if proj_name not in attn_state_dict:
-      raise ValueError("Unsupported structure of LTI-augmented Attention module.")
-
-    proj_params = attn_state_dict[proj_name]
-    is_output_proj = proj_name == "out"
-
-    C_param = proj_params.get(LearnToInitDense.TENSOR_C)
-
-    if C_param is None:
-      raise ValueError("Attention LTI-augmented module has no C parameter.")
-
-    if LearnToInitDense.TENSOR_W in proj_params:
-      max_logging.log(f"Computing final learn-to-init weight (general map) for: {proj_name}")
-      W_param = proj_params[LearnToInitDense.TENSOR_W]
-      final_kernel = calculate_attn_weight(
-          A=None,
-          B=None,
-          C=C_param,
-          general_map=W_param,
-          is_output_projection=is_output_proj,
-          matmul_precision=student_config.matmul_precision,
-      )
-    elif LearnToInitDense.TENSOR_A in proj_params and LearnToInitDense.TENSOR_B in proj_params:
-      max_logging.log(f"Computing final learn-to-init weight for: {proj_name}")
-      A_param = proj_params[LearnToInitDense.TENSOR_A]
-      B_param = proj_params[LearnToInitDense.TENSOR_B]
-      final_kernel = calculate_attn_weight(
-          A=A_param,
-          B=B_param,
-          C=C_param,
-          is_output_projection=is_output_proj,
-          matmul_precision=student_config.matmul_precision,
-      )
-    else:
-      continue
-
-    # 3. Overwrite C with the final computed kernel
-    C_param.set_value(final_kernel)
-
-    # 4. Standardize the structure by placing it under the 'kernel' key
-    proj_params["kernel"] = C_param
-
-    # 5. Clean up the LTI-specific parameters using .pop()
-    # Using pop(key, None) avoids KeyErrors if a tensor was omitted or already shared/deleted
-    proj_params.pop(LearnToInitDense.TENSOR_W, None)
-    proj_params.pop(LearnToInitDense.TENSOR_A, None)
-    proj_params.pop(LearnToInitDense.TENSOR_B, None)
-    proj_params.pop(LearnToInitDense.TENSOR_C, None)
-
-  # unpack the learn_to_init_wrapper to match the standard model structure
-  del student_model.decoder.layers["learn_to_init_wrapper"]
-  student_model.decoder.layers.update(lti_wrapped_node)
+  if student_config.attn_module_name is None:
+    return
+
+  if getattr(student_config, "scan_layers", True):
+    layer_modules = [student_model.decoder.layers]
+  else:
+    layer_modules = []
+    # Collect all possible layer names (e.g. layers_0, dense_layers_0, moe_layers_0)
+    for name, module in vars(student_model.decoder).items():
+      if name.startswith(LTI_LAYER_PATH_PREFIXES):
+        layer_modules.append(module)
+
+  for layer_module in layer_modules:
+    attn_state_dict = layer_module.get(student_config.attn_module_name)
+    if attn_state_dict is None:
+      raise ValueError("LTI: attn_state_dict wasn't found in the model state dict")
+
+    for proj_name in LTI_MODIFIED_ATTENTION_PARAM_NAMES:
+      proj_params = attn_state_dict.get(proj_name)
+      if proj_params is None:
+        raise ValueError("Non LTI supported Attention module state.")
+
+      is_output_proj = proj_name == "out"
+      C_param = proj_params.get(LearnToInitDense.TENSOR_C)
+      if C_param is None:
+        continue  # Not an LTI-augmented module
+      if LearnToInitDense.TENSOR_W in proj_params:
+        max_logging.log(f"Computing final learn-to-init weight (general map) for: {proj_name}")
+        W_param = proj_params[LearnToInitDense.TENSOR_W]
+        final_kernel = calculate_attn_weight(
+            A=None,
+            B=None,
+            C=C_param,
+            general_map=W_param,
+            is_output_projection=is_output_proj,
+            matmul_precision=student_config.matmul_precision,
+        )
+      elif LearnToInitDense.TENSOR_A in proj_params and LearnToInitDense.TENSOR_B in proj_params:
+        max_logging.log(f"Computing final learn-to-init weight for: {proj_name}")
+        A_param = proj_params[LearnToInitDense.TENSOR_A]
+        B_param = proj_params[LearnToInitDense.TENSOR_B]
+        final_kernel = calculate_attn_weight(
+            A=A_param,
+            B=B_param,
+            C=C_param,
+            is_output_projection=is_output_proj,
+            matmul_precision=student_config.matmul_precision,
+        )
+      else:
+        raise ValueError("Non LTI supported Attention module state.")
+
+      C_param.set_value(final_kernel)
+      # Inject as a regular parameter
+      proj_params[LTI_ORIGINAL_ATTENTION_PARAMS_NAME] = C_param
+      # Clean up the LTI-specific parameters
+      proj_params.pop(LearnToInitDense.TENSOR_W, None)
+      proj_params.pop(LearnToInitDense.TENSOR_A, None)
+      proj_params.pop(LearnToInitDense.TENSOR_B, None)
+      proj_params.pop(LearnToInitDense.TENSOR_C, None)
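The per-projection collapse above is dict surgery on the module state: compute the final kernel from the LTI tensors, store it under the regular kernel key (LTI_ORIGINAL_ATTENTION_PARAMS_NAME), and pop the LTI-only entries so the result matches a vanilla DenseGeneral layout and standard checkpoint paths apply. A toy sketch of that flow with plain dicts and numpy arrays in place of NNX params; the combine helper and its low-rank C + A @ B form are assumptions standing in for calculate_attn_weight, whose actual math is defined elsewhere in this file:

import numpy as np


def combine(A, B, C):
  # Assumed composition; stands in for calculate_attn_weight.
  return C + A @ B


proj_params = {"A": np.ones((4, 2)), "B": np.ones((2, 4)), "C": np.zeros((4, 4))}

if "W" in proj_params:  # general linear map branch
  final_kernel = proj_params["W"]
elif "A" in proj_params and "B" in proj_params:
  final_kernel = combine(proj_params["A"], proj_params["B"], proj_params["C"])
else:
  raise ValueError("Unsupported attention module state.")

proj_params["kernel"] = final_kernel  # standard DenseGeneral kernel location
for key in ("W", "A", "B", "C"):  # drop the LTI-only tensors
  proj_params.pop(key, None)
print(sorted(proj_params))  # ['kernel']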