
[Distillation] Layer-wise LTI #3769

Open
vlad-karp wants to merge 6 commits into main from vladk/lti2

Conversation


@vlad-karp vlad-karp commented Apr 29, 2026

Layer-wise LTI

Generalization of Learn-To-Init (LTI) Mechanism

Relevant details, context, and examples:

  • The current LTI implementation is model-specific and imposes several model-structure inconveniences because of the extra intermediate LTI wrapper. It also works only in layer-scanned mode.
  • This PR refactors the Learn-To-Init (LTI) approach to make it model-agnostic and reusable for other models, and enables the non-layer-scan mode so that LTI can be trained layer-wise.
  • Generalized LTI modifications (apply_lti_modification) can now be injected dynamically into any instantiated base NNX layer, which allows layer-wise LTI logic (i.e., augmenting only specific layers).
  • Removed the Llama2-specific LTI decoder.
  • Dynamic module augmentation: NNX modules are LTI-augmented as they are created in the Linen flow (see the sketch after this list).
  • The distillation utilities (lti_utils.py and train_distill.py) were upgraded to use regex patterns for weight sharing, copying, and freezing instead of exact path matching (illustrated under Tests below).
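
As a rough illustration of the dynamic augmentation idea: the hook name mirrors apply_lti_modification, but the signature, the nnx.Linear target, and the added lti_delta parameter below are hypothetical, not the exact code in this PR.

```python
# Minimal sketch of layer-wise LTI augmentation of an NNX module; the signature
# and the lti_delta parameter are illustrative placeholders, not this PR's API.
import re

import jax.numpy as jnp
from flax import nnx


def apply_lti_modification(module: nnx.Module, layer_name: str, lti_layer_regex: str) -> nnx.Module:
  """Attach learnable LTI state only to layers whose name matches the pattern."""
  if re.fullmatch(lti_layer_regex, layer_name) and isinstance(module, nnx.Linear):
    # Layer-wise logic: augment only the matched layers with an extra learnable
    # initialization delta, without inserting an intermediate wrapper module.
    module.lti_delta = nnx.Param(jnp.zeros_like(module.kernel.value))
  return module


# Augment only layer 3 of a hypothetical decoder stack; other layers stay untouched.
layer = nnx.Linear(8, 8, rngs=nnx.Rngs(0))
layer = apply_lti_modification(layer, "dense_layers_3", r"dense_layers_3")
```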

Shortcomings:

  • One still has to pass the apply_lti_modification function (or another model-specific LTI method) when instantiating the model's Linen version class.
  • The change is of limited use until a layer-wise student model configuration is available.

Tests

learn_to_init_test.py and train_distill_test.py were refactored to validate the new generic LTI augmentation functionality and the regex-based weight preparation logic.
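
For illustration, the regex-based selection that the weight-preparation logic relies on amounts to something like the following; the helper name and example parameter paths are made up, not the actual lti_utils.py implementation.

```python
# Illustrative sketch of regex-based parameter-path matching for weight
# sharing/copying/freezing; names and paths here are hypothetical.
import re


def select_params_by_regex(param_paths, patterns):
  """Return every parameter path matched by any of the given regex patterns."""
  compiled = [re.compile(p) for p in patterns]
  return [path for path in param_paths if any(c.search(path) for c in compiled)]


teacher_paths = [
    "decoder/dense_layers_0/mlp/wi/kernel",
    "decoder/dense_layers_1/mlp/wi/kernel",
    "decoder/moe_layers_0/gate/kernel",
]

# e.g. copy every dense-layer MLP kernel teacher -> student and freeze MoE gates,
# without enumerating exact paths.
copy_targets = select_params_by_regex(teacher_paths, [r"dense_layers_\d+/mlp"])
freeze_targets = select_params_by_regex(teacher_paths, [r"moe_layers_\d+/gate"])
print(copy_targets)    # both dense-layer kernels
print(freeze_targets)  # the MoE gate kernel
```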

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@vlad-karp vlad-karp marked this pull request as ready for review May 1, 2026 17:20
@vlad-karp vlad-karp changed the title from "generalized LTI" to "Layer-wise LTI" May 1, 2026
@vlad-karp vlad-karp changed the title from "Layer-wise LTI" to "[Distillation] Layer-wise LTI" May 1, 2026

github-actions Bot commented May 1, 2026

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions Bot left a comment


📋 Review Summary

This PR introduces a significant refactor to the Learn-To-Init (LTI) mechanism, making it model-agnostic and more flexible by using dynamic NNX module augmentation. While the overall design improvement is positive and aligns with the goal of generalizing LTI, there are several critical and high-severity issues that need to be addressed, including syntax errors in assertions, missing f-string prefixes, and logical inconsistencies in layer collection.

🔍 General Feedback

  • Regex Support: The move to regex-based weight sharing and copying is a great addition that improves the flexibility of the distillation pipeline.
  • Model Agnostic LTI: Decoupling LTI from specific model architectures (like Llama2) is a good architectural move.
  • Testing: New tests were added for the generic augmentation, but they currently have structural issues (missing indentation) and incorrect mock patch paths that will prevent them from running correctly.
  • Inconsistencies: There is some inconsistency in how different layer prefixes (e.g., dense_layers_, moe_layers_) are handled between initialization and the final weight update.

@JamesDeng42 JamesDeng42 left a comment


Generally LGTM



3 participants